From 73caa5dd1bb5dd31267613fa4447e1df7f736ae8 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 12 Jul 2025 14:48:47 -0400 Subject: [PATCH] remove Kernel.membufs [pr] (#11200) --- test/external/external_test_hcq_fuzz_failures.py | 4 ++-- test/external/fuzz_linearizer.py | 2 +- test/external/speed_compare_amd_am.py | 2 +- test/external/speed_compare_cuda_nv.py | 2 +- test/unit/test_search.py | 8 ++++---- tinygrad/opt/kernel.py | 9 ++++----- 6 files changed, 13 insertions(+), 14 deletions(-) diff --git a/test/external/external_test_hcq_fuzz_failures.py b/test/external/external_test_hcq_fuzz_failures.py index 2ac55c93c7..641461c67f 100644 --- a/test/external/external_test_hcq_fuzz_failures.py +++ b/test/external/external_test_hcq_fuzz_failures.py @@ -15,8 +15,8 @@ from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View def helper_test_lin(lin: Kernel, opts, failed_platforms, validate_device, rtol=1e-2, atol=1e-2): - if any(b.dtype.base == dtypes.half for b in lin.membufs) and not is_dtype_supported(dtypes.half): return - if any(b.dtype.base == dtypes.bfloat16 for b in lin.membufs) and not is_dtype_supported(dtypes.bfloat16): return + if any(b.dtype.base == dtypes.half for b in lin.bufs) and not is_dtype_supported(dtypes.half): return + if any(b.dtype.base == dtypes.bfloat16 for b in lin.bufs) and not is_dtype_supported(dtypes.bfloat16): return try: lin.apply_opts(opts) diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index c879c30184..77da3c8449 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -114,7 +114,7 @@ def run_linearizer(lin: Kernel, rawbufs=None, var_vals=None) -> tuple[str, Any]: def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2): # TODO: for bfloat16 it compiles linearizer, but it does not run because numpy cannot generate bf16 buffer. - has_bf16 = any(b.dtype.base == dtypes.bfloat16 for b in lin.membufs) + has_bf16 = any(b.dtype.base == dtypes.bfloat16 for b in lin.bufs) # TODO: raise specific fuzzing errors instead of str, and propagate the error message try: diff --git a/test/external/speed_compare_amd_am.py b/test/external/speed_compare_amd_am.py index 3a422d4ca0..2821d2388d 100644 --- a/test/external/speed_compare_amd_am.py +++ b/test/external/speed_compare_amd_am.py @@ -77,7 +77,7 @@ if __name__ == "__main__": with run_amd(): amdlin = ast_str_to_lin(ast, opts=amddev.renderer) amdlin.apply_opts(hand_coded_optimizations(amdlin)) - has_bf16 = any(b.dtype == dtypes.bfloat16 for b in amdlin.membufs) + has_bf16 = any(b.dtype == dtypes.bfloat16 for b in amdlin.bufs) amd_prg = CompiledRunner(amdlin.to_program()) amdbufs = bufs_from_lin(amdlin) diff --git a/test/external/speed_compare_cuda_nv.py b/test/external/speed_compare_cuda_nv.py index f9185a94bf..6fc557bc63 100644 --- a/test/external/speed_compare_cuda_nv.py +++ b/test/external/speed_compare_cuda_nv.py @@ -23,7 +23,7 @@ if __name__ == "__main__": # cuda compile culin = ast_str_to_lin(ast, opts=cudev.renderer) culin.apply_opts(hand_coded_optimizations(culin)) - has_bf16 = any(b.dtype == dtypes.bfloat16 for b in culin.membufs) + has_bf16 = any(b.dtype == dtypes.bfloat16 for b in culin.bufs) cuda_prg = CompiledRunner(culin.to_program()) cubufs = bufs_from_lin(culin) diff --git a/test/unit/test_search.py b/test/unit/test_search.py index cb9eb3d3c6..8d179754b5 100644 --- a/test/unit/test_search.py +++ b/test/unit/test_search.py @@ -15,8 +15,8 @@ class TestSearchUtil(unittest.TestCase): def test_bufs_from_lin(self): a = Tensor([1,2,3,4]).realize() si = (a+1).schedule()[0] - rawbufs = bufs_from_lin(lin:=Kernel(si.ast)) - assert len(rawbufs) == len(lin.membufs) == 2 + rawbufs = bufs_from_lin(Kernel(si.ast)) + assert len(rawbufs) == 2 assert all(r is not None for r in rawbufs) assert all(isinstance(r, Buffer) for r in rawbufs) assert all(r.size > 0 for r in rawbufs) @@ -25,8 +25,8 @@ class TestSearchUtil(unittest.TestCase): a = Tensor.randn(4, 4).realize() b = a+a[0] si = b.schedule()[0] - rawbufs = bufs_from_lin(k:=Kernel(si.ast)) - assert len(rawbufs) == len(k.membufs) == 2 + rawbufs = bufs_from_lin(Kernel(si.ast)) + assert len(rawbufs) == 2 assert all(r is not None for r in rawbufs) assert all(isinstance(r, Buffer) for r in rawbufs) assert all(r.size > 0 for r in rawbufs) diff --git a/tinygrad/opt/kernel.py b/tinygrad/opt/kernel.py index ab7f75827c..a067c4c0e5 100644 --- a/tinygrad/opt/kernel.py +++ b/tinygrad/opt/kernel.py @@ -207,8 +207,10 @@ class Kernel: first_reduce = [resolve(x!=y) for x,y in zip(self.sts[0].shape[:self.first_upcast]+(0,), self.full_shape[:self.first_upcast]+(1,))].index(True) # if it's an image, insert fake strides such that this fusion doesn't happen across image axes - if isinstance(self.membufs[0].dtype, ImageDType): - base_shape = self.membufs[0].dtype.shape + # TODO: remove membufs + membufs = dedup([x.src[0].base for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}]) + if isinstance(membufs[0].base.dtype, ImageDType): + base_shape = membufs[0].base.dtype.shape if shape_idx_groups := get_contraction(self.output_shape, base_shape): special_strides: tuple[sint, ...] = tuple() for i,g in enumerate(shape_idx_groups): @@ -504,9 +506,6 @@ class Kernel: # TODO: update the tests and delete these methods - @property - def membufs(self) -> list[UOp]: return dedup([x.src[0].base for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}]) - def to_program(self, name_override:Optional[str]=None) -> ProgramSpec: from tinygrad.engine.realize import get_program ret = get_program(self.get_optimized_ast(name_override), self.opts)