From 124d2f82273befaed4d2e40cb8a5bdc51d228090 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 1 Jun 2026 14:42:02 -0700 Subject: [PATCH] anon addrspace from new renderer (#16461) * anon addrspace from new renderer * use max_numel in python renderer * add sizes to ptrs in tests * more * correct fix --- test/backend/test_renderer_failures.py | 16 +++---- test/backend/test_uops.py | 16 +++---- test/helpers.py | 4 +- test/null/test_transcendental_helpers.py | 4 +- test/null/test_uop_graph.py | 58 ++++++++++++------------ test/null/test_uop_symbolic.py | 4 +- test/null/test_uops.py | 32 ++++++------- tinygrad/dtype.py | 2 +- tinygrad/runtime/ops_python.py | 14 +++--- tinygrad/uop/ops.py | 3 +- tinygrad/uop/spec.py | 5 +- tinygrad/viz/serve.py | 2 +- 12 files changed, 81 insertions(+), 79 deletions(-) diff --git a/test/backend/test_renderer_failures.py b/test/backend/test_renderer_failures.py index 5b7945c922..3839f2bab7 100644 --- a/test/backend/test_renderer_failures.py +++ b/test/backend/test_renderer_failures.py @@ -22,30 +22,30 @@ def _test_uop_result(inputs:list[Tensor], sink:UOp, local_size=None): def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp): dtype = alu_src_uops[0].dtype - a = UOp.param(0, dtype.ptr()) - b = UOp.param(1, dtype.ptr()) + a = UOp.param(0, dtype.ptr(1)) + b = UOp.param(1, dtype.ptr(1)) idx = UOp.const(dtypes.int, 0) - ld = b.index(idx) + ld = b.index(idx, ptr=True).load() alu = ld.alu(alu_op, *alu_src_uops) - store = UOp.store(a.index(idx), alu) + store = UOp.store(a.index(idx, ptr=True), alu) return _test_uop_result([Tensor([input_val])], UOp(Ops.SINK, dtypes.void, (store,), arg=KernelInfo()))[0] class TestRendererFailures(unittest.TestCase): @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer") def test_gated_store_with_alu(self): - a = UOp.param(0, dtypes.int.ptr()) + a = UOp.param(0, dtypes.int.ptr(4)) gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0) - gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu)), UOp.const(dtypes.int, 1))) + gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu), ptr=True), UOp.const(dtypes.int, 1))) sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,), arg=KernelInfo()) ret = _test_uop_result([], sink, local_size=[4, 1, 1])[0] np.testing.assert_equal(ret, [0, 1, 1, 1]) @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer") def test_gated_store_with_alu_2d(self): - a = UOp.param(0, dtypes.int.ptr()) + a = UOp.param(0, dtypes.int.ptr(8)) gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0) gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0) - gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1)), UOp.const(dtypes.int, 1))) + gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1), ptr=True), UOp.const(dtypes.int, 1))) sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,), arg=KernelInfo()) ret = _test_uop_result([], sink, local_size=[4, 2, 1])[0] np.testing.assert_equal(ret, [0, 0, 0, 0, 0, 1, 1, 1]) diff --git a/test/backend/test_uops.py b/test/backend/test_uops.py index ce559b89be..e1be405f1c 100644 --- a/test/backend/test_uops.py +++ b/test/backend/test_uops.py @@ -27,8 +27,8 @@ def uop(uops:list[UOp], op:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg: def _test_single_value(vals, op, dts): uops = [] output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1] - buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(), (), 0) - buf_loads = [uop(uops, Ops.PARAM, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)] + buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(1), (), 0) + buf_loads = [uop(uops, Ops.PARAM, dtype.ptr(1), (), i+1) for i,dtype in enumerate(dts)] loads = (buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0)) for i, dtype in enumerate(dts)) alu = uop(uops, op, output_dtype, loads) out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True), alu)) @@ -42,7 +42,7 @@ def _test_single_value(vals, op, dts): def _test_single_value_const(vals, op, dts): uops = [] output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1] - buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(), (), 0) + buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(1), (), 0) loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts)) alu = uop(uops, op, output_dtype, loads) out = buf_store[UOp.const(dtypes.int32, 0)].store(alu) @@ -54,7 +54,7 @@ def _test_single_value_const(vals, op, dts): def _test_uops_result(output_dtype, uops, res): # uops = [] - buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(), (), 0) + buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(1), (), 0) # res = output_fn(uops) out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), res)) buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() @@ -221,8 +221,8 @@ class TestLocalAccess(unittest.TestCase): @unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "This only tests assembly backends") class TestAssembly(unittest.TestCase): def test_bitshift_left(self): - g1 = UOp.param(0, dtypes.int32.ptr()) - out = UOp.param(1, dtypes.int32.ptr()) + g1 = UOp.param(0, dtypes.int32.ptr(3)) + out = UOp.param(1, dtypes.int32.ptr(2)) c1 = UOp.const(dtypes.int, 2) c2 = UOp.const(dtypes.int, 3) l1 = g1.index(c1) @@ -249,7 +249,7 @@ class TestAssembly(unittest.TestCase): self.assertGreaterEqual(len([x.op for x in uops if x.op is Ops.MULACC]), 4) def test_mulacc_shl(self): - g1 = UOp.param(0, dtypes.int32.ptr()) + g1 = UOp.param(0, dtypes.int32.ptr(2)) c1 = UOp.const(dtypes.int, 0) c2 = UOp.const(dtypes.int, 1) expr = g1.index(c1) * UOp.const(dtypes.int, 4096) + g1.index(c2) @@ -258,7 +258,7 @@ class TestAssembly(unittest.TestCase): self.assertIn(Ops.MULACC, [x.op for x in uops]) def test_use_cmpeq(self): - g = UOp.param(0, dtypes.uint32.ptr()) + g = UOp.param(0, dtypes.uint32.ptr(8)) c = UOp.const(dtypes.uint, 7) comp = g.index(c).ne(c).ne(True) uops = to_uops_list([comp], ren=Device[Device.DEFAULT].renderer) diff --git a/test/helpers.py b/test/helpers.py index c93bd451fa..c5cb0b1238 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -82,8 +82,8 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None, vals:tuple for buf_dt, data in inputs or []: bufs.append(buf:=allocator.alloc(len(data) * buf_dt.itemsize)) allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + (buf_dt.fmt or ""), *data))) - g = UOp.param(0, uop.dtype.ptr()) - prg = to_program(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(arg=KernelInfo()), PythonRenderer(Target("PYTHON"))) + g = UOp.param(0, uop.dtype.ptr(1)) + prg = to_program(UOp.store(g.index(UOp.const(dtypes.int, 0), ptr=True), uop).sink(arg=KernelInfo()), PythonRenderer(Target("PYTHON"))) prog = PythonProgram("run", PythonCompiler().compile(prg.src[3].arg)) prog(out_buf:=allocator.alloc(uop.dtype.itemsize), *bufs, vals=vals) return out_buf.cast(uop.dtype.fmt or "").tolist()[0] diff --git a/test/null/test_transcendental_helpers.py b/test/null/test_transcendental_helpers.py index b6e95ee932..107371a13f 100644 --- a/test/null/test_transcendental_helpers.py +++ b/test/null/test_transcendental_helpers.py @@ -10,8 +10,8 @@ class TestTranscendentalFunctions(unittest.TestCase): def test_payne_hanek_reduction(self): # TODO: Test constant input when constant folding is fixed (or maybe test both variants) # Load input value from a buffer to prevent constant folding - input_buf = UOp.param(1, dtypes.double.ptr()) - loaded_value = input_buf.index(UOp.const(dtypes.int, 0)) + input_buf = UOp.param(1, dtypes.double.ptr(1)) + loaded_value = input_buf.index(UOp.const(dtypes.int, 0), ptr=True).load() def eval_payne_hanek_reduction(v:float) -> tuple[float, int]: return tuple(eval_uop(u, [(dtypes.float64, [v])]) for u in payne_hanek_reduction(loaded_value)) diff --git a/test/null/test_uop_graph.py b/test/null/test_uop_graph.py index 2fda5df0f6..6f38064ca7 100644 --- a/test/null/test_uop_graph.py +++ b/test/null/test_uop_graph.py @@ -380,22 +380,22 @@ class TestUOpGraph(unittest.TestCase): self.assertEqual(uops[-2], wmma) # -2 to skip SINK def test_cast_alu_fold(self): - d0 = UOp.param(0, dtypes.bool.ptr()) - d1 = UOp.param(1, dtypes.int.ptr()) + d0 = UOp.param(0, dtypes.bool.ptr(1)) + d1 = UOp.param(1, dtypes.int.ptr(1)) idx = UOp.const(dtypes.int, 0) ld = d1.index(idx) alu = (ld<1).cast(dtypes.bool) - out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu)) + out = UOp(Ops.STORE, dtypes.void, (d0.index(idx, ptr=True), alu)) uops = to_uops_list([out]) self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0) def test_double_cast_fold(self): - d0 = UOp.param(0, dtypes.float.ptr()) - d1 = UOp.param(1, dtypes.int.ptr()) + d0 = UOp.param(0, dtypes.float.ptr(1)) + d1 = UOp.param(1, dtypes.int.ptr(1)) idx = UOp.const(dtypes.int, 0) ld = d1.index(idx) alu = ld.cast(dtypes.float).cast(dtypes.float) - out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu)) + out = UOp(Ops.STORE, dtypes.void, (d0.index(idx, ptr=True), alu)) uops = to_uops_list([out]) self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1) @@ -414,7 +414,7 @@ class TestUOpGraph(unittest.TestCase): def test_bitcast_to_same_dtype_fold(self): for dt in dtypes.ints + dtypes.floats + (dtypes.bool,): - d0 = UOp.param(0, dt.ptr()) + d0 = UOp.param(0, dt.ptr(1)) v = d0.index(UOp.const(dtypes.int, 0)) uops = to_uops_list([v.bitcast(dt)]) self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST and x.dtype is dt]), 0, f"dtype = {dt}") @@ -427,18 +427,18 @@ class TestUOpGraph(unittest.TestCase): def test_where_on_gated_load_fold(self): ridx0 = UOp.range(100, 0) - d0 = UOp.param(0, dtypes.long.ptr()) + d0 = UOp.param(0, dtypes.long.ptr(100)) ld = d0.index(ridx0.valid(ridx0<50)) w = (ridx0<50).where(ld, 5) - out = UOp.param(1, dtypes.long.ptr()) - uops = to_uops_list([out.index(ridx0).store(w)]) + out = UOp.param(1, dtypes.long.ptr(100)) + uops = to_uops_list([out.index(ridx0, ptr=True).store(w)]) for u in uops: assert u.op is not Ops.WHERE if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg==5 def test_where_on_gated_load_folds_swapped_branches(self): ridx0 = UOp.range(100, 0) - d0 = UOp.param(0, dtypes.long.ptr()) + d0 = UOp.param(0, dtypes.long.ptr(100)) ld = d0.index(ridx0.valid((ridx0<50).logical_not())) w = (ridx0<50).where(5, ld) uops = to_uops_list([w]) @@ -448,40 +448,40 @@ class TestUOpGraph(unittest.TestCase): def test_where_on_gated_load_with_cast(self): ridx0 = UOp.range(100, 0) - d0 = UOp.param(0, dtypes.int.ptr()) + d0 = UOp.param(0, dtypes.int.ptr(100)) gate_idx = ridx0.valid((ridx0<50)) ld = d0.index(gate_idx).cast(dtypes.float) w = (ridx0<50).where(ld, 5.0) - out = UOp.param(1, dtypes.float.ptr()) - uops = to_uops_list([out.index(ridx0).store(w)]) + out = UOp.param(1, dtypes.float.ptr(100)) + uops = to_uops_list([out.index(ridx0, ptr=True).store(w)]) for u in uops: assert u.op is not Ops.WHERE if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg == 5 def test_where_on_casted_gated_load_extra_cond(self): ridx0 = UOp.range(100, 0) - d0 = UOp.param(0, dtypes.float.ptr()) + d0 = UOp.param(0, dtypes.float.ptr(100)) ld = d0.index(ridx0.valid(ridx0<50)) w = ((ridx0<50) & (ridx0>30)).where(ld, UOp.const(dtypes.float, 0)).cast(dtypes.half) - out = UOp.param(1, dtypes.half.ptr()) - uops = to_uops_list([out.index(ridx0).store(w)]) + out = UOp.param(1, dtypes.half.ptr(100)) + uops = to_uops_list([out.index(ridx0, ptr=True).store(w)]) for u in uops: assert u.op is not Ops.WHERE def test_where_on_casted_gated_load_extra_cond_swapped(self): ridx0 = UOp.range(100, 0) - d0 = UOp.param(0, dtypes.float.ptr()) + d0 = UOp.param(0, dtypes.float.ptr(100)) ld = d0.index(ridx0.valid(ridx0<50)) w = ((ridx0<50) & (ridx0>30)).where(UOp.const(dtypes.float, 0), ld).cast(dtypes.half) - out = UOp.param(1, dtypes.half.ptr()) - uops = to_uops_list([out.index(ridx0).store(w)]) + out = UOp.param(1, dtypes.half.ptr(100)) + uops = to_uops_list([out.index(ridx0, ptr=True).store(w)]) for u in uops: assert u.op is not Ops.WHERE def test_where_in_store_becomes_gate(self): ridx0 = UOp.range(100, 0) - d0 = UOp.param(0, dtypes.long.ptr()) - idx = d0.index(ridx0) + d0 = UOp.param(0, dtypes.long.ptr(100)) + idx = d0.index(ridx0, ptr=True) ld = idx.load() val = (ridx0<50).where(5, ld) st = idx.store(val).end(ridx0) @@ -529,33 +529,33 @@ class TestUOpGraph(unittest.TestCase): self.assertNotEqual(u.dtype, dtypes.long) def test_fold_gated_load(self): - glbl0 = UOp.param(0, dtypes.int.ptr()) - glbl1 = UOp.param(1, dtypes.int.ptr()) - glbl2 = UOp.param(2, dtypes.int.ptr()) + glbl0 = UOp.param(0, dtypes.int.ptr(1)) + glbl1 = UOp.param(1, dtypes.int.ptr(1)) + glbl2 = UOp.param(2, dtypes.int.ptr(1)) idx = UOp.const(dtypes.int, 0) ld0 = glbl1.index(UOp.invalid()) ld1 = glbl2.index(idx.valid(UOp.const(dtypes.bool, True))) - uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))]) + uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx, ptr=True), ld1+ld0))]) ld0 = uops[-2].src[-1] # -2 to skip SINK # the gate and invalid value are deleted from ld1 self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int)) def test_fold_gated_load_local(self): - glbl0 = UOp.param(0, dtypes.int.ptr()) + glbl0 = UOp.param(0, dtypes.int.ptr(16)) smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp") lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0") st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx, ptr=True), glbl0.index(lidx, ptr=True).load())) barrier = UOp(Ops.BARRIER, dtypes.void, (st, )) ld0 = smem.after(barrier).index(UOp.invalid()) ld1 = smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True))) - uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))]) + uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx, ptr=True), ld1+ld0))]) ld0 = uops[-2].src[-1] # -2 to skip SINK # the gate and invalid value are deleted from ld1 self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True)) def test_fold_gated_store(self): - glbl = UOp.param(0, dtypes.int.ptr()) + glbl = UOp.param(0, dtypes.int.ptr(1)) idx0 = UOp.const(dtypes.int, 0) idx1 = UOp.const(dtypes.int, 0) val = UOp.const(dtypes.int, 42) diff --git a/test/null/test_uop_symbolic.py b/test/null/test_uop_symbolic.py index 884732dc71..b3fa4862df 100644 --- a/test/null/test_uop_symbolic.py +++ b/test/null/test_uop_symbolic.py @@ -951,8 +951,8 @@ class TestSymbolic(unittest.TestCase): expr = cond.where(a, b).cast(dtypes.half) # TODO: copied from render, render does not support cast - glbl = UOp.param(0, dtypes.int.ptr()) - uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink()) + glbl = UOp.param(0, dtypes.int.ptr(1)) + uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0), ptr=True), expr)).sink()) rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1] self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half))) diff --git a/test/null/test_uops.py b/test/null/test_uops.py index 18ff80efc2..f7196cc07e 100644 --- a/test/null/test_uops.py +++ b/test/null/test_uops.py @@ -110,10 +110,10 @@ class TestExecALU(unittest.TestCase): class TestGatedStoreRewrite(unittest.TestCase): def test_tiny_gate_store(self): - gmem = UOp.param(0, dtypes.float.ptr()) + gmem = UOp.param(0, dtypes.float.ptr(8)) gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0') gate = gidx0 1 else [src_values[1]]): + for j,val in enumerate(src_values[1] if u.max_numel() > 1 else [src_values[1]]): for (m,o),v,g in zip(src_values[0], val, store_gate): if g: _store(m, o+j, v, src_dtypes[1].scalar()) i += 1 continue if u.op is Ops.AFTER: values[u] = src_values[0] elif u.op in {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}: - assert isinstance(u.dtype, PtrDType), u.dtype storage_fmt = storage_fmt_for_dtype(u.dtype.base.scalar()) if storage_fmt is None: raise RuntimeError(f"dtype={u.dtype} is not supported") if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e" if u.op is Ops.DEFINE_REG: # REGs are per thread - values[u] = [memoryview(bytearray(u.dtype.size*u.dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)] + values[u] = [memoryview(bytearray(u.max_numel()*u.dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)] else: - buf = memoryview(bytearray(u.dtype.size*u.dtype.itemsize)) if u.op is not Ops.PARAM else pbufs.pop(0) + buf = memoryview(bytearray(u.max_numel()*u.dtype.itemsize)) if u.op is not Ops.PARAM else pbufs.pop(0) values[u] = [buf.cast(storage_fmt)] * warp_size elif u.op is Ops.DEFINE_VAR: values[u] = [pvals.pop(0)] * warp_size @@ -129,9 +128,10 @@ class PythonProgram: elif u.op is Ops.CAST: values[u] = [truncate.get(u.dtype, lambda dt: dt)(u.dtype.const(x)) for x in src_values[0]] elif u.op is Ops.LOAD: - if u.dtype.count > 1: - values[u] = [load([src_values[k][j] if k != 0 and src_dtypes[k].count > 1 else src_values[k] \ - for k in range(len(src_values))], j, u.dtype.scalar()) for j in range(u.dtype.count)] + if (load_sz := u.max_numel()) > 1: + # buf and gate are not vecs + values[u] = [load([src_values[k] if k in [0,2] else src_values[k][j] \ + for k in range(len(src_values))], j, u.dtype.scalar()) for j in range(load_sz)] else: values[u] = load(src_values, 0, u.dtype) elif u.op is Ops.GEP: values[u] = src_values[0][get_single_element(u.arg)] diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 0d211ab016..b5aa11b486 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -758,8 +758,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): if self.op is Ops.BUFFER: return AddrSpace.GLOBAL if self.op is Ops.DEFINE_LOCAL: return AddrSpace.LOCAL if self.op is Ops.DEFINE_REG: return AddrSpace.REG - # LOAD brings things into registers - if self.op is Ops.LOAD: return AddrSpace.REG + if self.op is Ops.LOAD: return AddrSpace.ANON # LOAD brings things into anonymous registers if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP}: return self.src[0].addrspace if self.op in GroupOp.Movement: return self.src[0].addrspace diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 16745dd0fa..87f37a0169 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -13,7 +13,10 @@ def validate_index(uidx:UOp, gate:UOp|None=None): if idx.op is Ops.CONST and idx.arg is Invalid: return True if gate is None: gate = UOp.const(dtypes.bool, True) # TODO: check for overflow - if not CHECK_OOB or isinstance(buf.dtype, ImageDType) or (sz := buf.ptrdtype.size) == -1: return True + if not CHECK_OOB or isinstance(buf.dtype, ImageDType): return True + + # buffer size + sz = buf.max_numel() # We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask if 0<=idx.vmin and idx.vmax