mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
remove Kernel.membufs [pr] (#11200)
This commit is contained in:
@@ -15,8 +15,8 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
def helper_test_lin(lin: Kernel, opts, failed_platforms, validate_device, rtol=1e-2, atol=1e-2):
|
||||
if any(b.dtype.base == dtypes.half for b in lin.membufs) and not is_dtype_supported(dtypes.half): return
|
||||
if any(b.dtype.base == dtypes.bfloat16 for b in lin.membufs) and not is_dtype_supported(dtypes.bfloat16): return
|
||||
if any(b.dtype.base == dtypes.half for b in lin.bufs) and not is_dtype_supported(dtypes.half): return
|
||||
if any(b.dtype.base == dtypes.bfloat16 for b in lin.bufs) and not is_dtype_supported(dtypes.bfloat16): return
|
||||
|
||||
try:
|
||||
lin.apply_opts(opts)
|
||||
|
||||
2
test/external/fuzz_linearizer.py
vendored
2
test/external/fuzz_linearizer.py
vendored
@@ -114,7 +114,7 @@ def run_linearizer(lin: Kernel, rawbufs=None, var_vals=None) -> tuple[str, Any]:
|
||||
|
||||
def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
|
||||
# TODO: for bfloat16 it compiles linearizer, but it does not run because numpy cannot generate bf16 buffer.
|
||||
has_bf16 = any(b.dtype.base == dtypes.bfloat16 for b in lin.membufs)
|
||||
has_bf16 = any(b.dtype.base == dtypes.bfloat16 for b in lin.bufs)
|
||||
|
||||
# TODO: raise specific fuzzing errors instead of str, and propagate the error message
|
||||
try:
|
||||
|
||||
2
test/external/speed_compare_amd_am.py
vendored
2
test/external/speed_compare_amd_am.py
vendored
@@ -77,7 +77,7 @@ if __name__ == "__main__":
|
||||
with run_amd():
|
||||
amdlin = ast_str_to_lin(ast, opts=amddev.renderer)
|
||||
amdlin.apply_opts(hand_coded_optimizations(amdlin))
|
||||
has_bf16 = any(b.dtype == dtypes.bfloat16 for b in amdlin.membufs)
|
||||
has_bf16 = any(b.dtype == dtypes.bfloat16 for b in amdlin.bufs)
|
||||
|
||||
amd_prg = CompiledRunner(amdlin.to_program())
|
||||
amdbufs = bufs_from_lin(amdlin)
|
||||
|
||||
2
test/external/speed_compare_cuda_nv.py
vendored
2
test/external/speed_compare_cuda_nv.py
vendored
@@ -23,7 +23,7 @@ if __name__ == "__main__":
|
||||
# cuda compile
|
||||
culin = ast_str_to_lin(ast, opts=cudev.renderer)
|
||||
culin.apply_opts(hand_coded_optimizations(culin))
|
||||
has_bf16 = any(b.dtype == dtypes.bfloat16 for b in culin.membufs)
|
||||
has_bf16 = any(b.dtype == dtypes.bfloat16 for b in culin.bufs)
|
||||
|
||||
cuda_prg = CompiledRunner(culin.to_program())
|
||||
cubufs = bufs_from_lin(culin)
|
||||
|
||||
@@ -15,8 +15,8 @@ class TestSearchUtil(unittest.TestCase):
|
||||
def test_bufs_from_lin(self):
|
||||
a = Tensor([1,2,3,4]).realize()
|
||||
si = (a+1).schedule()[0]
|
||||
rawbufs = bufs_from_lin(lin:=Kernel(si.ast))
|
||||
assert len(rawbufs) == len(lin.membufs) == 2
|
||||
rawbufs = bufs_from_lin(Kernel(si.ast))
|
||||
assert len(rawbufs) == 2
|
||||
assert all(r is not None for r in rawbufs)
|
||||
assert all(isinstance(r, Buffer) for r in rawbufs)
|
||||
assert all(r.size > 0 for r in rawbufs)
|
||||
@@ -25,8 +25,8 @@ class TestSearchUtil(unittest.TestCase):
|
||||
a = Tensor.randn(4, 4).realize()
|
||||
b = a+a[0]
|
||||
si = b.schedule()[0]
|
||||
rawbufs = bufs_from_lin(k:=Kernel(si.ast))
|
||||
assert len(rawbufs) == len(k.membufs) == 2
|
||||
rawbufs = bufs_from_lin(Kernel(si.ast))
|
||||
assert len(rawbufs) == 2
|
||||
assert all(r is not None for r in rawbufs)
|
||||
assert all(isinstance(r, Buffer) for r in rawbufs)
|
||||
assert all(r.size > 0 for r in rawbufs)
|
||||
|
||||
@@ -207,8 +207,10 @@ class Kernel:
|
||||
first_reduce = [resolve(x!=y) for x,y in zip(self.sts[0].shape[:self.first_upcast]+(0,), self.full_shape[:self.first_upcast]+(1,))].index(True)
|
||||
|
||||
# if it's an image, insert fake strides such that this fusion doesn't happen across image axes
|
||||
if isinstance(self.membufs[0].dtype, ImageDType):
|
||||
base_shape = self.membufs[0].dtype.shape
|
||||
# TODO: remove membufs
|
||||
membufs = dedup([x.src[0].base for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
|
||||
if isinstance(membufs[0].base.dtype, ImageDType):
|
||||
base_shape = membufs[0].base.dtype.shape
|
||||
if shape_idx_groups := get_contraction(self.output_shape, base_shape):
|
||||
special_strides: tuple[sint, ...] = tuple()
|
||||
for i,g in enumerate(shape_idx_groups):
|
||||
@@ -504,9 +506,6 @@ class Kernel:
|
||||
|
||||
# TODO: update the tests and delete these methods
|
||||
|
||||
@property
|
||||
def membufs(self) -> list[UOp]: return dedup([x.src[0].base for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
|
||||
|
||||
def to_program(self, name_override:Optional[str]=None) -> ProgramSpec:
|
||||
from tinygrad.engine.realize import get_program
|
||||
ret = get_program(self.get_optimized_ast(name_override), self.opts)
|
||||
|
||||
Reference in New Issue
Block a user