remove Kernel.membufs [pr] (#11200)

This commit is contained in:
chenyu
2025-07-12 14:48:47 -04:00
committed by GitHub
parent 5ce278b245
commit 73caa5dd1b
6 changed files with 13 additions and 14 deletions

View File

@@ -15,8 +15,8 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
def helper_test_lin(lin: Kernel, opts, failed_platforms, validate_device, rtol=1e-2, atol=1e-2):
if any(b.dtype.base == dtypes.half for b in lin.membufs) and not is_dtype_supported(dtypes.half): return
if any(b.dtype.base == dtypes.bfloat16 for b in lin.membufs) and not is_dtype_supported(dtypes.bfloat16): return
if any(b.dtype.base == dtypes.half for b in lin.bufs) and not is_dtype_supported(dtypes.half): return
if any(b.dtype.base == dtypes.bfloat16 for b in lin.bufs) and not is_dtype_supported(dtypes.bfloat16): return
try:
lin.apply_opts(opts)

View File

@@ -114,7 +114,7 @@ def run_linearizer(lin: Kernel, rawbufs=None, var_vals=None) -> tuple[str, Any]:
def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
# TODO: for bfloat16 it compiles linearizer, but it does not run because numpy cannot generate bf16 buffer.
has_bf16 = any(b.dtype.base == dtypes.bfloat16 for b in lin.membufs)
has_bf16 = any(b.dtype.base == dtypes.bfloat16 for b in lin.bufs)
# TODO: raise specific fuzzing errors instead of str, and propagate the error message
try:

View File

@@ -77,7 +77,7 @@ if __name__ == "__main__":
with run_amd():
amdlin = ast_str_to_lin(ast, opts=amddev.renderer)
amdlin.apply_opts(hand_coded_optimizations(amdlin))
has_bf16 = any(b.dtype == dtypes.bfloat16 for b in amdlin.membufs)
has_bf16 = any(b.dtype == dtypes.bfloat16 for b in amdlin.bufs)
amd_prg = CompiledRunner(amdlin.to_program())
amdbufs = bufs_from_lin(amdlin)

View File

@@ -23,7 +23,7 @@ if __name__ == "__main__":
# cuda compile
culin = ast_str_to_lin(ast, opts=cudev.renderer)
culin.apply_opts(hand_coded_optimizations(culin))
has_bf16 = any(b.dtype == dtypes.bfloat16 for b in culin.membufs)
has_bf16 = any(b.dtype == dtypes.bfloat16 for b in culin.bufs)
cuda_prg = CompiledRunner(culin.to_program())
cubufs = bufs_from_lin(culin)

View File

@@ -15,8 +15,8 @@ class TestSearchUtil(unittest.TestCase):
def test_bufs_from_lin(self):
a = Tensor([1,2,3,4]).realize()
si = (a+1).schedule()[0]
rawbufs = bufs_from_lin(lin:=Kernel(si.ast))
assert len(rawbufs) == len(lin.membufs) == 2
rawbufs = bufs_from_lin(Kernel(si.ast))
assert len(rawbufs) == 2
assert all(r is not None for r in rawbufs)
assert all(isinstance(r, Buffer) for r in rawbufs)
assert all(r.size > 0 for r in rawbufs)
@@ -25,8 +25,8 @@ class TestSearchUtil(unittest.TestCase):
a = Tensor.randn(4, 4).realize()
b = a+a[0]
si = b.schedule()[0]
rawbufs = bufs_from_lin(k:=Kernel(si.ast))
assert len(rawbufs) == len(k.membufs) == 2
rawbufs = bufs_from_lin(Kernel(si.ast))
assert len(rawbufs) == 2
assert all(r is not None for r in rawbufs)
assert all(isinstance(r, Buffer) for r in rawbufs)
assert all(r.size > 0 for r in rawbufs)

View File

@@ -207,8 +207,10 @@ class Kernel:
first_reduce = [resolve(x!=y) for x,y in zip(self.sts[0].shape[:self.first_upcast]+(0,), self.full_shape[:self.first_upcast]+(1,))].index(True)
# if it's an image, insert fake strides such that this fusion doesn't happen across image axes
if isinstance(self.membufs[0].dtype, ImageDType):
base_shape = self.membufs[0].dtype.shape
# TODO: remove membufs
membufs = dedup([x.src[0].base for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
if isinstance(membufs[0].base.dtype, ImageDType):
base_shape = membufs[0].base.dtype.shape
if shape_idx_groups := get_contraction(self.output_shape, base_shape):
special_strides: tuple[sint, ...] = tuple()
for i,g in enumerate(shape_idx_groups):
@@ -504,9 +506,6 @@ class Kernel:
# TODO: update the tests and delete these methods
@property
def membufs(self) -> list[UOp]: return dedup([x.src[0].base for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
def to_program(self, name_override:Optional[str]=None) -> ProgramSpec:
from tinygrad.engine.realize import get_program
ret = get_program(self.get_optimized_ast(name_override), self.opts)