anon addrspace from new renderer (#16461)

* anon addrspace from new renderer

* use max_numel in python renderer

* add sizes to ptrs in tests

* more

* correct fix
This commit is contained in:
George Hotz
2026-06-01 14:42:02 -07:00
committed by GitHub
parent 517eea5985
commit 124d2f8227
12 changed files with 81 additions and 79 deletions

View File

@@ -22,30 +22,30 @@ def _test_uop_result(inputs:list[Tensor], sink:UOp, local_size=None):
def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp):
dtype = alu_src_uops[0].dtype
a = UOp.param(0, dtype.ptr())
b = UOp.param(1, dtype.ptr())
a = UOp.param(0, dtype.ptr(1))
b = UOp.param(1, dtype.ptr(1))
idx = UOp.const(dtypes.int, 0)
ld = b.index(idx)
ld = b.index(idx, ptr=True).load()
alu = ld.alu(alu_op, *alu_src_uops)
store = UOp.store(a.index(idx), alu)
store = UOp.store(a.index(idx, ptr=True), alu)
return _test_uop_result([Tensor([input_val])], UOp(Ops.SINK, dtypes.void, (store,), arg=KernelInfo()))[0]
class TestRendererFailures(unittest.TestCase):
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
def test_gated_store_with_alu(self):
a = UOp.param(0, dtypes.int.ptr())
a = UOp.param(0, dtypes.int.ptr(4))
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu)), UOp.const(dtypes.int, 1)))
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu), ptr=True), UOp.const(dtypes.int, 1)))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,), arg=KernelInfo())
ret = _test_uop_result([], sink, local_size=[4, 1, 1])[0]
np.testing.assert_equal(ret, [0, 1, 1, 1])
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
def test_gated_store_with_alu_2d(self):
a = UOp.param(0, dtypes.int.ptr())
a = UOp.param(0, dtypes.int.ptr(8))
gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0)
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1)), UOp.const(dtypes.int, 1)))
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1), ptr=True), UOp.const(dtypes.int, 1)))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,), arg=KernelInfo())
ret = _test_uop_result([], sink, local_size=[4, 2, 1])[0]
np.testing.assert_equal(ret, [0, 0, 0, 0, 0, 1, 1, 1])

View File

@@ -27,8 +27,8 @@ def uop(uops:list[UOp], op:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:
def _test_single_value(vals, op, dts):
uops = []
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(), (), 0)
buf_loads = [uop(uops, Ops.PARAM, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(1), (), 0)
buf_loads = [uop(uops, Ops.PARAM, dtype.ptr(1), (), i+1) for i,dtype in enumerate(dts)]
loads = (buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0)) for i, dtype in enumerate(dts))
alu = uop(uops, op, output_dtype, loads)
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True), alu))
@@ -42,7 +42,7 @@ def _test_single_value(vals, op, dts):
def _test_single_value_const(vals, op, dts):
uops = []
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(), (), 0)
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(1), (), 0)
loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
alu = uop(uops, op, output_dtype, loads)
out = buf_store[UOp.const(dtypes.int32, 0)].store(alu)
@@ -54,7 +54,7 @@ def _test_single_value_const(vals, op, dts):
def _test_uops_result(output_dtype, uops, res):
# uops = []
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(), (), 0)
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(1), (), 0)
# res = output_fn(uops)
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), res))
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
@@ -221,8 +221,8 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "This only tests assembly backends")
class TestAssembly(unittest.TestCase):
def test_bitshift_left(self):
g1 = UOp.param(0, dtypes.int32.ptr())
out = UOp.param(1, dtypes.int32.ptr())
g1 = UOp.param(0, dtypes.int32.ptr(3))
out = UOp.param(1, dtypes.int32.ptr(2))
c1 = UOp.const(dtypes.int, 2)
c2 = UOp.const(dtypes.int, 3)
l1 = g1.index(c1)
@@ -249,7 +249,7 @@ class TestAssembly(unittest.TestCase):
self.assertGreaterEqual(len([x.op for x in uops if x.op is Ops.MULACC]), 4)
def test_mulacc_shl(self):
g1 = UOp.param(0, dtypes.int32.ptr())
g1 = UOp.param(0, dtypes.int32.ptr(2))
c1 = UOp.const(dtypes.int, 0)
c2 = UOp.const(dtypes.int, 1)
expr = g1.index(c1) * UOp.const(dtypes.int, 4096) + g1.index(c2)
@@ -258,7 +258,7 @@ class TestAssembly(unittest.TestCase):
self.assertIn(Ops.MULACC, [x.op for x in uops])
def test_use_cmpeq(self):
g = UOp.param(0, dtypes.uint32.ptr())
g = UOp.param(0, dtypes.uint32.ptr(8))
c = UOp.const(dtypes.uint, 7)
comp = g.index(c).ne(c).ne(True)
uops = to_uops_list([comp], ren=Device[Device.DEFAULT].renderer)

View File

@@ -82,8 +82,8 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None, vals:tuple
for buf_dt, data in inputs or []:
bufs.append(buf:=allocator.alloc(len(data) * buf_dt.itemsize))
allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + (buf_dt.fmt or ""), *data)))
g = UOp.param(0, uop.dtype.ptr())
prg = to_program(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(arg=KernelInfo()), PythonRenderer(Target("PYTHON")))
g = UOp.param(0, uop.dtype.ptr(1))
prg = to_program(UOp.store(g.index(UOp.const(dtypes.int, 0), ptr=True), uop).sink(arg=KernelInfo()), PythonRenderer(Target("PYTHON")))
prog = PythonProgram("run", PythonCompiler().compile(prg.src[3].arg))
prog(out_buf:=allocator.alloc(uop.dtype.itemsize), *bufs, vals=vals)
return out_buf.cast(uop.dtype.fmt or "").tolist()[0]

View File

@@ -10,8 +10,8 @@ class TestTranscendentalFunctions(unittest.TestCase):
def test_payne_hanek_reduction(self):
# TODO: Test constant input when constant folding is fixed (or maybe test both variants)
# Load input value from a buffer to prevent constant folding
input_buf = UOp.param(1, dtypes.double.ptr())
loaded_value = input_buf.index(UOp.const(dtypes.int, 0))
input_buf = UOp.param(1, dtypes.double.ptr(1))
loaded_value = input_buf.index(UOp.const(dtypes.int, 0), ptr=True).load()
def eval_payne_hanek_reduction(v:float) -> tuple[float, int]:
return tuple(eval_uop(u, [(dtypes.float64, [v])]) for u in payne_hanek_reduction(loaded_value))

View File

@@ -380,22 +380,22 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
def test_cast_alu_fold(self):
d0 = UOp.param(0, dtypes.bool.ptr())
d1 = UOp.param(1, dtypes.int.ptr())
d0 = UOp.param(0, dtypes.bool.ptr(1))
d1 = UOp.param(1, dtypes.int.ptr(1))
idx = UOp.const(dtypes.int, 0)
ld = d1.index(idx)
alu = (ld<1).cast(dtypes.bool)
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx, ptr=True), alu))
uops = to_uops_list([out])
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)
def test_double_cast_fold(self):
d0 = UOp.param(0, dtypes.float.ptr())
d1 = UOp.param(1, dtypes.int.ptr())
d0 = UOp.param(0, dtypes.float.ptr(1))
d1 = UOp.param(1, dtypes.int.ptr(1))
idx = UOp.const(dtypes.int, 0)
ld = d1.index(idx)
alu = ld.cast(dtypes.float).cast(dtypes.float)
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx, ptr=True), alu))
uops = to_uops_list([out])
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1)
@@ -414,7 +414,7 @@ class TestUOpGraph(unittest.TestCase):
def test_bitcast_to_same_dtype_fold(self):
for dt in dtypes.ints + dtypes.floats + (dtypes.bool,):
d0 = UOp.param(0, dt.ptr())
d0 = UOp.param(0, dt.ptr(1))
v = d0.index(UOp.const(dtypes.int, 0))
uops = to_uops_list([v.bitcast(dt)])
self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST and x.dtype is dt]), 0, f"dtype = {dt}")
@@ -427,18 +427,18 @@ class TestUOpGraph(unittest.TestCase):
def test_where_on_gated_load_fold(self):
ridx0 = UOp.range(100, 0)
d0 = UOp.param(0, dtypes.long.ptr())
d0 = UOp.param(0, dtypes.long.ptr(100))
ld = d0.index(ridx0.valid(ridx0<50))
w = (ridx0<50).where(ld, 5)
out = UOp.param(1, dtypes.long.ptr())
uops = to_uops_list([out.index(ridx0).store(w)])
out = UOp.param(1, dtypes.long.ptr(100))
uops = to_uops_list([out.index(ridx0, ptr=True).store(w)])
for u in uops:
assert u.op is not Ops.WHERE
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg==5
def test_where_on_gated_load_folds_swapped_branches(self):
ridx0 = UOp.range(100, 0)
d0 = UOp.param(0, dtypes.long.ptr())
d0 = UOp.param(0, dtypes.long.ptr(100))
ld = d0.index(ridx0.valid((ridx0<50).logical_not()))
w = (ridx0<50).where(5, ld)
uops = to_uops_list([w])
@@ -448,40 +448,40 @@ class TestUOpGraph(unittest.TestCase):
def test_where_on_gated_load_with_cast(self):
ridx0 = UOp.range(100, 0)
d0 = UOp.param(0, dtypes.int.ptr())
d0 = UOp.param(0, dtypes.int.ptr(100))
gate_idx = ridx0.valid((ridx0<50))
ld = d0.index(gate_idx).cast(dtypes.float)
w = (ridx0<50).where(ld, 5.0)
out = UOp.param(1, dtypes.float.ptr())
uops = to_uops_list([out.index(ridx0).store(w)])
out = UOp.param(1, dtypes.float.ptr(100))
uops = to_uops_list([out.index(ridx0, ptr=True).store(w)])
for u in uops:
assert u.op is not Ops.WHERE
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg == 5
def test_where_on_casted_gated_load_extra_cond(self):
ridx0 = UOp.range(100, 0)
d0 = UOp.param(0, dtypes.float.ptr())
d0 = UOp.param(0, dtypes.float.ptr(100))
ld = d0.index(ridx0.valid(ridx0<50))
w = ((ridx0<50) & (ridx0>30)).where(ld, UOp.const(dtypes.float, 0)).cast(dtypes.half)
out = UOp.param(1, dtypes.half.ptr())
uops = to_uops_list([out.index(ridx0).store(w)])
out = UOp.param(1, dtypes.half.ptr(100))
uops = to_uops_list([out.index(ridx0, ptr=True).store(w)])
for u in uops:
assert u.op is not Ops.WHERE
def test_where_on_casted_gated_load_extra_cond_swapped(self):
ridx0 = UOp.range(100, 0)
d0 = UOp.param(0, dtypes.float.ptr())
d0 = UOp.param(0, dtypes.float.ptr(100))
ld = d0.index(ridx0.valid(ridx0<50))
w = ((ridx0<50) & (ridx0>30)).where(UOp.const(dtypes.float, 0), ld).cast(dtypes.half)
out = UOp.param(1, dtypes.half.ptr())
uops = to_uops_list([out.index(ridx0).store(w)])
out = UOp.param(1, dtypes.half.ptr(100))
uops = to_uops_list([out.index(ridx0, ptr=True).store(w)])
for u in uops:
assert u.op is not Ops.WHERE
def test_where_in_store_becomes_gate(self):
ridx0 = UOp.range(100, 0)
d0 = UOp.param(0, dtypes.long.ptr())
idx = d0.index(ridx0)
d0 = UOp.param(0, dtypes.long.ptr(100))
idx = d0.index(ridx0, ptr=True)
ld = idx.load()
val = (ridx0<50).where(5, ld)
st = idx.store(val).end(ridx0)
@@ -529,33 +529,33 @@ class TestUOpGraph(unittest.TestCase):
self.assertNotEqual(u.dtype, dtypes.long)
def test_fold_gated_load(self):
glbl0 = UOp.param(0, dtypes.int.ptr())
glbl1 = UOp.param(1, dtypes.int.ptr())
glbl2 = UOp.param(2, dtypes.int.ptr())
glbl0 = UOp.param(0, dtypes.int.ptr(1))
glbl1 = UOp.param(1, dtypes.int.ptr(1))
glbl2 = UOp.param(2, dtypes.int.ptr(1))
idx = UOp.const(dtypes.int, 0)
ld0 = glbl1.index(UOp.invalid())
ld1 = glbl2.index(idx.valid(UOp.const(dtypes.bool, True)))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx, ptr=True), ld1+ld0))])
ld0 = uops[-2].src[-1] # -2 to skip SINK
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int))
def test_fold_gated_load_local(self):
glbl0 = UOp.param(0, dtypes.int.ptr())
glbl0 = UOp.param(0, dtypes.int.ptr(16))
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp")
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx, ptr=True), glbl0.index(lidx, ptr=True).load()))
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
ld0 = smem.after(barrier).index(UOp.invalid())
ld1 = smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True)))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx, ptr=True), ld1+ld0))])
ld0 = uops[-2].src[-1] # -2 to skip SINK
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True))
def test_fold_gated_store(self):
glbl = UOp.param(0, dtypes.int.ptr())
glbl = UOp.param(0, dtypes.int.ptr(1))
idx0 = UOp.const(dtypes.int, 0)
idx1 = UOp.const(dtypes.int, 0)
val = UOp.const(dtypes.int, 42)

View File

@@ -951,8 +951,8 @@ class TestSymbolic(unittest.TestCase):
expr = cond.where(a, b).cast(dtypes.half)
# TODO: copied from render, render does not support cast
glbl = UOp.param(0, dtypes.int.ptr())
uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
glbl = UOp.param(0, dtypes.int.ptr(1))
uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0), ptr=True), expr)).sink())
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1]
self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))

View File

@@ -110,10 +110,10 @@ class TestExecALU(unittest.TestCase):
class TestGatedStoreRewrite(unittest.TestCase):
def test_tiny_gate_store(self):
gmem = UOp.param(0, dtypes.float.ptr())
gmem = UOp.param(0, dtypes.float.ptr(8))
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
gate = gidx0<UOp.const(dtypes.int, 1)
idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, (gidx0 * UOp.const(dtypes.int, 2)).valid(gate)))
idx = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem, (gidx0 * UOp.const(dtypes.int, 2)).valid(gate)))
val = UOp.const(dtypes.float, 42.0)
store = UOp(Ops.STORE, dtypes.void, (idx, val))
uops = to_uops_list([store])
@@ -126,12 +126,12 @@ class TestGatedStoreRewrite(unittest.TestCase):
self.assertEqual(len(gated_uops[-1].src), 2)
def test_gate_some_stores(self):
gmem0 = UOp.param(0, dtypes.float.ptr())
gmem1 = UOp.param(1, dtypes.float.ptr())
gmem0 = UOp.param(0, dtypes.float.ptr(8))
gmem1 = UOp.param(1, dtypes.float.ptr(8))
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
idx = gidx0 * UOp.const(dtypes.int, 2)
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx.valid(gidx0<UOp.const(dtypes.int, 1))))
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem0, idx.valid(gidx0<UOp.const(dtypes.int, 1))))
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem1, idx))
val = UOp.const(dtypes.float, 42.0)
stores = [UOp.store(idx0, val), UOp.store(idx1, val)]
uops = to_uops_list(stores)
@@ -146,13 +146,13 @@ class TestGatedStoreRewrite(unittest.TestCase):
# scaled down version of TestLinearizerDumb.test_unmerged_ifs
@unittest.skip("we don't merge ifs anymore")
def test_merge_ifs_alt(self):
gmem0 = UOp.param(0, dtypes.float.ptr())
gmem1 = UOp.param(1, dtypes.float.ptr())
gmem0 = UOp.param(0, dtypes.float.ptr(8))
gmem1 = UOp.param(1, dtypes.float.ptr(8))
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
idx = gidx0*UOp.const(dtypes.int, 2)
gate = gidx0<UOp.const(dtypes.int, 1)
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx.valid(gate)))
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx.valid(gate)))
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem0, idx.valid(gate)))
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem1, idx.valid(gate)))
val = UOp.const(dtypes.float, 42.0)
stores = [UOp.store(idx0, val), UOp.store(idx1, val)]
uops = to_uops_list(stores)
@@ -170,7 +170,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
class TestFastIdiv(unittest.TestCase):
def test_division_power_of_two(self):
for dt in (dtypes.int32, dtypes.uint32):
g = UOp.param(0, dt.ptr())
g = UOp.param(0, dt.ptr(3))
c = UOp.const(dt, 2)
l = g.index(c)
a = UOp(Ops.CDIV, dt, (l, c))
@@ -183,7 +183,7 @@ class TestFastIdiv(unittest.TestCase):
def test_floormod_power_of_two(self):
# FLOORMOD by a power of two lowers to AND (correct floor mod for any sign in two's complement)
for dt in (dtypes.int32, dtypes.uint32):
g = UOp.param(0, dt.ptr())
g = UOp.param(0, dt.ptr(9))
c = UOp.const(dt, 8)
a = UOp(Ops.FLOORMOD, dt, (g.index(c), c))
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
@@ -195,7 +195,7 @@ class TestFastIdiv(unittest.TestCase):
def test_floordiv_power_of_two_uint(self):
# uint FLOORDIV by a power of two lowers to a shift, leaving no IDIV/FLOORDIV in the kernel
for dt in (dtypes.uint32, dtypes.uint64):
g = UOp.param(0, dt.ptr())
g = UOp.param(0, dt.ptr(3))
c = UOp.const(dt, 2)
a = UOp(Ops.FLOORDIV, dt, (g.index(c), c))
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
@@ -207,7 +207,7 @@ class TestFastIdiv(unittest.TestCase):
@Context(DISABLE_FAST_IDIV=0)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't support long")
def test_fast_idiv_and_mod(self):
g = UOp.param(0, dtypes.uint32.ptr())
g = UOp.param(0, dtypes.uint32.ptr(4))
c = UOp.const(dtypes.uint, 3)
l = g.index(c)
a = UOp(Ops.CDIV, dtypes.uint, (l, c))
@@ -242,7 +242,7 @@ class TestFastIdiv(unittest.TestCase):
@unittest.expectedFailure
def test_fast_idiv_overflow(self):
# This will be possible with a slightly different method for fast_idiv
g = UOp.param(0, dtypes.uint32.ptr())
g = UOp.param(0, dtypes.uint32.ptr(8))
c = UOp.const(dtypes.uint, 7)
l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
a = UOp(Ops.CDIV, dtypes.uint, (l, c))
@@ -253,7 +253,7 @@ class TestFastIdiv(unittest.TestCase):
self.assertNotIn(Ops.CDIV, ops)
def test_disable_fast_idiv(self):
g = UOp.param(0, dtypes.uint32.ptr())
g = UOp.param(0, dtypes.uint32.ptr(4))
c = UOp.const(dtypes.uint, 3)
l = g.index(c)
a = UOp(Ops.CDIV, dtypes.uint, (l, c))

View File

@@ -51,7 +51,7 @@ class DTypeMetaClass(type):
class AddrSpace(Enum):
def __repr__(self): return str(self)
GLOBAL = auto(); LOCAL = auto(); REG = auto() # noqa: E702
GLOBAL = auto(); LOCAL = auto(); REG = auto(); ANON = auto() # noqa: E702
@dataclass(frozen=True, eq=False)
class DType(metaclass=DTypeMetaClass):

View File

@@ -79,22 +79,21 @@ class PythonProgram:
if u.op is Ops.STORE:
assert len(src_values) == 2, f"STORE must be lowered to 2 srcs, got {len(src_values)}"
store_gate = exec_masks[-1]
for j,val in enumerate(src_values[1] if src_dtypes[1].count > 1 else [src_values[1]]):
for j,val in enumerate(src_values[1] if u.max_numel() > 1 else [src_values[1]]):
for (m,o),v,g in zip(src_values[0], val, store_gate):
if g: _store(m, o+j, v, src_dtypes[1].scalar())
i += 1
continue
if u.op is Ops.AFTER: values[u] = src_values[0]
elif u.op in {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
assert isinstance(u.dtype, PtrDType), u.dtype
storage_fmt = storage_fmt_for_dtype(u.dtype.base.scalar())
if storage_fmt is None: raise RuntimeError(f"dtype={u.dtype} is not supported")
if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e"
if u.op is Ops.DEFINE_REG:
# REGs are per thread
values[u] = [memoryview(bytearray(u.dtype.size*u.dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
values[u] = [memoryview(bytearray(u.max_numel()*u.dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
else:
buf = memoryview(bytearray(u.dtype.size*u.dtype.itemsize)) if u.op is not Ops.PARAM else pbufs.pop(0)
buf = memoryview(bytearray(u.max_numel()*u.dtype.itemsize)) if u.op is not Ops.PARAM else pbufs.pop(0)
values[u] = [buf.cast(storage_fmt)] * warp_size
elif u.op is Ops.DEFINE_VAR:
values[u] = [pvals.pop(0)] * warp_size
@@ -129,9 +128,10 @@ class PythonProgram:
elif u.op is Ops.CAST:
values[u] = [truncate.get(u.dtype, lambda dt: dt)(u.dtype.const(x)) for x in src_values[0]]
elif u.op is Ops.LOAD:
if u.dtype.count > 1:
values[u] = [load([src_values[k][j] if k != 0 and src_dtypes[k].count > 1 else src_values[k] \
for k in range(len(src_values))], j, u.dtype.scalar()) for j in range(u.dtype.count)]
if (load_sz := u.max_numel()) > 1:
# buf and gate are not vecs
values[u] = [load([src_values[k] if k in [0,2] else src_values[k][j] \
for k in range(len(src_values))], j, u.dtype.scalar()) for j in range(load_sz)]
else:
values[u] = load(src_values, 0, u.dtype)
elif u.op is Ops.GEP: values[u] = src_values[0][get_single_element(u.arg)]

View File

@@ -758,8 +758,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op is Ops.BUFFER: return AddrSpace.GLOBAL
if self.op is Ops.DEFINE_LOCAL: return AddrSpace.LOCAL
if self.op is Ops.DEFINE_REG: return AddrSpace.REG
# LOAD brings things into registers
if self.op is Ops.LOAD: return AddrSpace.REG
if self.op is Ops.LOAD: return AddrSpace.ANON # LOAD brings things into anonymous registers
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP}:
return self.src[0].addrspace
if self.op in GroupOp.Movement: return self.src[0].addrspace

View File

@@ -13,7 +13,10 @@ def validate_index(uidx:UOp, gate:UOp|None=None):
if idx.op is Ops.CONST and idx.arg is Invalid: return True
if gate is None: gate = UOp.const(dtypes.bool, True)
# TODO: check for overflow
if not CHECK_OOB or isinstance(buf.dtype, ImageDType) or (sz := buf.ptrdtype.size) == -1: return True
if not CHECK_OOB or isinstance(buf.dtype, ImageDType): return True
# buffer size
sz = buf.max_numel()
# We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask
if 0<=idx.vmin and idx.vmax<sz: return True

View File

@@ -57,7 +57,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
Ops.STAGE: "#AC640D", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
addrspace_colors = {AddrSpace.REG:"#e68181", AddrSpace.LOCAL:"#e7c86a", AddrSpace.GLOBAL:"#75bd7b"}
addrspace_colors = {AddrSpace.ANON:"#e68181", AddrSpace.REG: "#ff40a0", AddrSpace.LOCAL:"#e7c86a", AddrSpace.GLOBAL:"#75bd7b"}
# VIZ API