mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
anon addrspace from new renderer (#16461)
* anon addrspace from new renderer * use max_numel in python renderer * add sizes to ptrs in tests * more * correct fix
This commit is contained in:
@@ -22,30 +22,30 @@ def _test_uop_result(inputs:list[Tensor], sink:UOp, local_size=None):
|
||||
|
||||
def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp):
|
||||
dtype = alu_src_uops[0].dtype
|
||||
a = UOp.param(0, dtype.ptr())
|
||||
b = UOp.param(1, dtype.ptr())
|
||||
a = UOp.param(0, dtype.ptr(1))
|
||||
b = UOp.param(1, dtype.ptr(1))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = b.index(idx)
|
||||
ld = b.index(idx, ptr=True).load()
|
||||
alu = ld.alu(alu_op, *alu_src_uops)
|
||||
store = UOp.store(a.index(idx), alu)
|
||||
store = UOp.store(a.index(idx, ptr=True), alu)
|
||||
return _test_uop_result([Tensor([input_val])], UOp(Ops.SINK, dtypes.void, (store,), arg=KernelInfo()))[0]
|
||||
|
||||
class TestRendererFailures(unittest.TestCase):
|
||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
|
||||
def test_gated_store_with_alu(self):
|
||||
a = UOp.param(0, dtypes.int.ptr())
|
||||
a = UOp.param(0, dtypes.int.ptr(4))
|
||||
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu)), UOp.const(dtypes.int, 1)))
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu), ptr=True), UOp.const(dtypes.int, 1)))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,), arg=KernelInfo())
|
||||
ret = _test_uop_result([], sink, local_size=[4, 1, 1])[0]
|
||||
np.testing.assert_equal(ret, [0, 1, 1, 1])
|
||||
|
||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
|
||||
def test_gated_store_with_alu_2d(self):
|
||||
a = UOp.param(0, dtypes.int.ptr())
|
||||
a = UOp.param(0, dtypes.int.ptr(8))
|
||||
gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
|
||||
gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0)
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1)), UOp.const(dtypes.int, 1)))
|
||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1), ptr=True), UOp.const(dtypes.int, 1)))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,), arg=KernelInfo())
|
||||
ret = _test_uop_result([], sink, local_size=[4, 2, 1])[0]
|
||||
np.testing.assert_equal(ret, [0, 0, 0, 0, 0, 1, 1, 1])
|
||||
|
||||
@@ -27,8 +27,8 @@ def uop(uops:list[UOp], op:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:
|
||||
def _test_single_value(vals, op, dts):
|
||||
uops = []
|
||||
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
|
||||
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(), (), 0)
|
||||
buf_loads = [uop(uops, Ops.PARAM, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
|
||||
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(1), (), 0)
|
||||
buf_loads = [uop(uops, Ops.PARAM, dtype.ptr(1), (), i+1) for i,dtype in enumerate(dts)]
|
||||
loads = (buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0)) for i, dtype in enumerate(dts))
|
||||
alu = uop(uops, op, output_dtype, loads)
|
||||
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True), alu))
|
||||
@@ -42,7 +42,7 @@ def _test_single_value(vals, op, dts):
|
||||
def _test_single_value_const(vals, op, dts):
|
||||
uops = []
|
||||
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
|
||||
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(), (), 0)
|
||||
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(1), (), 0)
|
||||
loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
|
||||
alu = uop(uops, op, output_dtype, loads)
|
||||
out = buf_store[UOp.const(dtypes.int32, 0)].store(alu)
|
||||
@@ -54,7 +54,7 @@ def _test_single_value_const(vals, op, dts):
|
||||
|
||||
def _test_uops_result(output_dtype, uops, res):
|
||||
# uops = []
|
||||
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(), (), 0)
|
||||
buf_store = uop(uops, Ops.PARAM, output_dtype.ptr(1), (), 0)
|
||||
# res = output_fn(uops)
|
||||
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), res))
|
||||
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
|
||||
@@ -221,8 +221,8 @@ class TestLocalAccess(unittest.TestCase):
|
||||
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "This only tests assembly backends")
|
||||
class TestAssembly(unittest.TestCase):
|
||||
def test_bitshift_left(self):
|
||||
g1 = UOp.param(0, dtypes.int32.ptr())
|
||||
out = UOp.param(1, dtypes.int32.ptr())
|
||||
g1 = UOp.param(0, dtypes.int32.ptr(3))
|
||||
out = UOp.param(1, dtypes.int32.ptr(2))
|
||||
c1 = UOp.const(dtypes.int, 2)
|
||||
c2 = UOp.const(dtypes.int, 3)
|
||||
l1 = g1.index(c1)
|
||||
@@ -249,7 +249,7 @@ class TestAssembly(unittest.TestCase):
|
||||
self.assertGreaterEqual(len([x.op for x in uops if x.op is Ops.MULACC]), 4)
|
||||
|
||||
def test_mulacc_shl(self):
|
||||
g1 = UOp.param(0, dtypes.int32.ptr())
|
||||
g1 = UOp.param(0, dtypes.int32.ptr(2))
|
||||
c1 = UOp.const(dtypes.int, 0)
|
||||
c2 = UOp.const(dtypes.int, 1)
|
||||
expr = g1.index(c1) * UOp.const(dtypes.int, 4096) + g1.index(c2)
|
||||
@@ -258,7 +258,7 @@ class TestAssembly(unittest.TestCase):
|
||||
self.assertIn(Ops.MULACC, [x.op for x in uops])
|
||||
|
||||
def test_use_cmpeq(self):
|
||||
g = UOp.param(0, dtypes.uint32.ptr())
|
||||
g = UOp.param(0, dtypes.uint32.ptr(8))
|
||||
c = UOp.const(dtypes.uint, 7)
|
||||
comp = g.index(c).ne(c).ne(True)
|
||||
uops = to_uops_list([comp], ren=Device[Device.DEFAULT].renderer)
|
||||
|
||||
@@ -82,8 +82,8 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None, vals:tuple
|
||||
for buf_dt, data in inputs or []:
|
||||
bufs.append(buf:=allocator.alloc(len(data) * buf_dt.itemsize))
|
||||
allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + (buf_dt.fmt or ""), *data)))
|
||||
g = UOp.param(0, uop.dtype.ptr())
|
||||
prg = to_program(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(arg=KernelInfo()), PythonRenderer(Target("PYTHON")))
|
||||
g = UOp.param(0, uop.dtype.ptr(1))
|
||||
prg = to_program(UOp.store(g.index(UOp.const(dtypes.int, 0), ptr=True), uop).sink(arg=KernelInfo()), PythonRenderer(Target("PYTHON")))
|
||||
prog = PythonProgram("run", PythonCompiler().compile(prg.src[3].arg))
|
||||
prog(out_buf:=allocator.alloc(uop.dtype.itemsize), *bufs, vals=vals)
|
||||
return out_buf.cast(uop.dtype.fmt or "").tolist()[0]
|
||||
|
||||
@@ -10,8 +10,8 @@ class TestTranscendentalFunctions(unittest.TestCase):
|
||||
def test_payne_hanek_reduction(self):
|
||||
# TODO: Test constant input when constant folding is fixed (or maybe test both variants)
|
||||
# Load input value from a buffer to prevent constant folding
|
||||
input_buf = UOp.param(1, dtypes.double.ptr())
|
||||
loaded_value = input_buf.index(UOp.const(dtypes.int, 0))
|
||||
input_buf = UOp.param(1, dtypes.double.ptr(1))
|
||||
loaded_value = input_buf.index(UOp.const(dtypes.int, 0), ptr=True).load()
|
||||
def eval_payne_hanek_reduction(v:float) -> tuple[float, int]:
|
||||
return tuple(eval_uop(u, [(dtypes.float64, [v])]) for u in payne_hanek_reduction(loaded_value))
|
||||
|
||||
|
||||
@@ -380,22 +380,22 @@ class TestUOpGraph(unittest.TestCase):
|
||||
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
|
||||
|
||||
def test_cast_alu_fold(self):
|
||||
d0 = UOp.param(0, dtypes.bool.ptr())
|
||||
d1 = UOp.param(1, dtypes.int.ptr())
|
||||
d0 = UOp.param(0, dtypes.bool.ptr(1))
|
||||
d1 = UOp.param(1, dtypes.int.ptr(1))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = d1.index(idx)
|
||||
alu = (ld<1).cast(dtypes.bool)
|
||||
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
|
||||
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx, ptr=True), alu))
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)
|
||||
|
||||
def test_double_cast_fold(self):
|
||||
d0 = UOp.param(0, dtypes.float.ptr())
|
||||
d1 = UOp.param(1, dtypes.int.ptr())
|
||||
d0 = UOp.param(0, dtypes.float.ptr(1))
|
||||
d1 = UOp.param(1, dtypes.int.ptr(1))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = d1.index(idx)
|
||||
alu = ld.cast(dtypes.float).cast(dtypes.float)
|
||||
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
|
||||
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx, ptr=True), alu))
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1)
|
||||
|
||||
@@ -414,7 +414,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
|
||||
def test_bitcast_to_same_dtype_fold(self):
|
||||
for dt in dtypes.ints + dtypes.floats + (dtypes.bool,):
|
||||
d0 = UOp.param(0, dt.ptr())
|
||||
d0 = UOp.param(0, dt.ptr(1))
|
||||
v = d0.index(UOp.const(dtypes.int, 0))
|
||||
uops = to_uops_list([v.bitcast(dt)])
|
||||
self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST and x.dtype is dt]), 0, f"dtype = {dt}")
|
||||
@@ -427,18 +427,18 @@ class TestUOpGraph(unittest.TestCase):
|
||||
|
||||
def test_where_on_gated_load_fold(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp.param(0, dtypes.long.ptr())
|
||||
d0 = UOp.param(0, dtypes.long.ptr(100))
|
||||
ld = d0.index(ridx0.valid(ridx0<50))
|
||||
w = (ridx0<50).where(ld, 5)
|
||||
out = UOp.param(1, dtypes.long.ptr())
|
||||
uops = to_uops_list([out.index(ridx0).store(w)])
|
||||
out = UOp.param(1, dtypes.long.ptr(100))
|
||||
uops = to_uops_list([out.index(ridx0, ptr=True).store(w)])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg==5
|
||||
|
||||
def test_where_on_gated_load_folds_swapped_branches(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp.param(0, dtypes.long.ptr())
|
||||
d0 = UOp.param(0, dtypes.long.ptr(100))
|
||||
ld = d0.index(ridx0.valid((ridx0<50).logical_not()))
|
||||
w = (ridx0<50).where(5, ld)
|
||||
uops = to_uops_list([w])
|
||||
@@ -448,40 +448,40 @@ class TestUOpGraph(unittest.TestCase):
|
||||
|
||||
def test_where_on_gated_load_with_cast(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp.param(0, dtypes.int.ptr())
|
||||
d0 = UOp.param(0, dtypes.int.ptr(100))
|
||||
gate_idx = ridx0.valid((ridx0<50))
|
||||
ld = d0.index(gate_idx).cast(dtypes.float)
|
||||
w = (ridx0<50).where(ld, 5.0)
|
||||
out = UOp.param(1, dtypes.float.ptr())
|
||||
uops = to_uops_list([out.index(ridx0).store(w)])
|
||||
out = UOp.param(1, dtypes.float.ptr(100))
|
||||
uops = to_uops_list([out.index(ridx0, ptr=True).store(w)])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg == 5
|
||||
|
||||
def test_where_on_casted_gated_load_extra_cond(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp.param(0, dtypes.float.ptr())
|
||||
d0 = UOp.param(0, dtypes.float.ptr(100))
|
||||
ld = d0.index(ridx0.valid(ridx0<50))
|
||||
w = ((ridx0<50) & (ridx0>30)).where(ld, UOp.const(dtypes.float, 0)).cast(dtypes.half)
|
||||
out = UOp.param(1, dtypes.half.ptr())
|
||||
uops = to_uops_list([out.index(ridx0).store(w)])
|
||||
out = UOp.param(1, dtypes.half.ptr(100))
|
||||
uops = to_uops_list([out.index(ridx0, ptr=True).store(w)])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
|
||||
def test_where_on_casted_gated_load_extra_cond_swapped(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp.param(0, dtypes.float.ptr())
|
||||
d0 = UOp.param(0, dtypes.float.ptr(100))
|
||||
ld = d0.index(ridx0.valid(ridx0<50))
|
||||
w = ((ridx0<50) & (ridx0>30)).where(UOp.const(dtypes.float, 0), ld).cast(dtypes.half)
|
||||
out = UOp.param(1, dtypes.half.ptr())
|
||||
uops = to_uops_list([out.index(ridx0).store(w)])
|
||||
out = UOp.param(1, dtypes.half.ptr(100))
|
||||
uops = to_uops_list([out.index(ridx0, ptr=True).store(w)])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
|
||||
def test_where_in_store_becomes_gate(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp.param(0, dtypes.long.ptr())
|
||||
idx = d0.index(ridx0)
|
||||
d0 = UOp.param(0, dtypes.long.ptr(100))
|
||||
idx = d0.index(ridx0, ptr=True)
|
||||
ld = idx.load()
|
||||
val = (ridx0<50).where(5, ld)
|
||||
st = idx.store(val).end(ridx0)
|
||||
@@ -529,33 +529,33 @@ class TestUOpGraph(unittest.TestCase):
|
||||
self.assertNotEqual(u.dtype, dtypes.long)
|
||||
|
||||
def test_fold_gated_load(self):
|
||||
glbl0 = UOp.param(0, dtypes.int.ptr())
|
||||
glbl1 = UOp.param(1, dtypes.int.ptr())
|
||||
glbl2 = UOp.param(2, dtypes.int.ptr())
|
||||
glbl0 = UOp.param(0, dtypes.int.ptr(1))
|
||||
glbl1 = UOp.param(1, dtypes.int.ptr(1))
|
||||
glbl2 = UOp.param(2, dtypes.int.ptr(1))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld0 = glbl1.index(UOp.invalid())
|
||||
ld1 = glbl2.index(idx.valid(UOp.const(dtypes.bool, True)))
|
||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
|
||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx, ptr=True), ld1+ld0))])
|
||||
ld0 = uops[-2].src[-1] # -2 to skip SINK
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int))
|
||||
|
||||
def test_fold_gated_load_local(self):
|
||||
glbl0 = UOp.param(0, dtypes.int.ptr())
|
||||
glbl0 = UOp.param(0, dtypes.int.ptr(16))
|
||||
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp")
|
||||
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
|
||||
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx, ptr=True), glbl0.index(lidx, ptr=True).load()))
|
||||
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
|
||||
ld0 = smem.after(barrier).index(UOp.invalid())
|
||||
ld1 = smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True)))
|
||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
|
||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx, ptr=True), ld1+ld0))])
|
||||
|
||||
ld0 = uops[-2].src[-1] # -2 to skip SINK
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True))
|
||||
|
||||
def test_fold_gated_store(self):
|
||||
glbl = UOp.param(0, dtypes.int.ptr())
|
||||
glbl = UOp.param(0, dtypes.int.ptr(1))
|
||||
idx0 = UOp.const(dtypes.int, 0)
|
||||
idx1 = UOp.const(dtypes.int, 0)
|
||||
val = UOp.const(dtypes.int, 42)
|
||||
|
||||
@@ -951,8 +951,8 @@ class TestSymbolic(unittest.TestCase):
|
||||
expr = cond.where(a, b).cast(dtypes.half)
|
||||
|
||||
# TODO: copied from render, render does not support cast
|
||||
glbl = UOp.param(0, dtypes.int.ptr())
|
||||
uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
|
||||
glbl = UOp.param(0, dtypes.int.ptr(1))
|
||||
uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0), ptr=True), expr)).sink())
|
||||
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1]
|
||||
|
||||
self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))
|
||||
|
||||
@@ -110,10 +110,10 @@ class TestExecALU(unittest.TestCase):
|
||||
|
||||
class TestGatedStoreRewrite(unittest.TestCase):
|
||||
def test_tiny_gate_store(self):
|
||||
gmem = UOp.param(0, dtypes.float.ptr())
|
||||
gmem = UOp.param(0, dtypes.float.ptr(8))
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
|
||||
gate = gidx0<UOp.const(dtypes.int, 1)
|
||||
idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, (gidx0 * UOp.const(dtypes.int, 2)).valid(gate)))
|
||||
idx = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem, (gidx0 * UOp.const(dtypes.int, 2)).valid(gate)))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
store = UOp(Ops.STORE, dtypes.void, (idx, val))
|
||||
uops = to_uops_list([store])
|
||||
@@ -126,12 +126,12 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||
self.assertEqual(len(gated_uops[-1].src), 2)
|
||||
|
||||
def test_gate_some_stores(self):
|
||||
gmem0 = UOp.param(0, dtypes.float.ptr())
|
||||
gmem1 = UOp.param(1, dtypes.float.ptr())
|
||||
gmem0 = UOp.param(0, dtypes.float.ptr(8))
|
||||
gmem1 = UOp.param(1, dtypes.float.ptr(8))
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
|
||||
idx = gidx0 * UOp.const(dtypes.int, 2)
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx.valid(gidx0<UOp.const(dtypes.int, 1))))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem0, idx.valid(gidx0<UOp.const(dtypes.int, 1))))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem1, idx))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
stores = [UOp.store(idx0, val), UOp.store(idx1, val)]
|
||||
uops = to_uops_list(stores)
|
||||
@@ -146,13 +146,13 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||
# scaled down version of TestLinearizerDumb.test_unmerged_ifs
|
||||
@unittest.skip("we don't merge ifs anymore")
|
||||
def test_merge_ifs_alt(self):
|
||||
gmem0 = UOp.param(0, dtypes.float.ptr())
|
||||
gmem1 = UOp.param(1, dtypes.float.ptr())
|
||||
gmem0 = UOp.param(0, dtypes.float.ptr(8))
|
||||
gmem1 = UOp.param(1, dtypes.float.ptr(8))
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
|
||||
idx = gidx0*UOp.const(dtypes.int, 2)
|
||||
gate = gidx0<UOp.const(dtypes.int, 1)
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx.valid(gate)))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx.valid(gate)))
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem0, idx.valid(gate)))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(8), (gmem1, idx.valid(gate)))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
stores = [UOp.store(idx0, val), UOp.store(idx1, val)]
|
||||
uops = to_uops_list(stores)
|
||||
@@ -170,7 +170,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||
class TestFastIdiv(unittest.TestCase):
|
||||
def test_division_power_of_two(self):
|
||||
for dt in (dtypes.int32, dtypes.uint32):
|
||||
g = UOp.param(0, dt.ptr())
|
||||
g = UOp.param(0, dt.ptr(3))
|
||||
c = UOp.const(dt, 2)
|
||||
l = g.index(c)
|
||||
a = UOp(Ops.CDIV, dt, (l, c))
|
||||
@@ -183,7 +183,7 @@ class TestFastIdiv(unittest.TestCase):
|
||||
def test_floormod_power_of_two(self):
|
||||
# FLOORMOD by a power of two lowers to AND (correct floor mod for any sign in two's complement)
|
||||
for dt in (dtypes.int32, dtypes.uint32):
|
||||
g = UOp.param(0, dt.ptr())
|
||||
g = UOp.param(0, dt.ptr(9))
|
||||
c = UOp.const(dt, 8)
|
||||
a = UOp(Ops.FLOORMOD, dt, (g.index(c), c))
|
||||
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
|
||||
@@ -195,7 +195,7 @@ class TestFastIdiv(unittest.TestCase):
|
||||
def test_floordiv_power_of_two_uint(self):
|
||||
# uint FLOORDIV by a power of two lowers to a shift, leaving no IDIV/FLOORDIV in the kernel
|
||||
for dt in (dtypes.uint32, dtypes.uint64):
|
||||
g = UOp.param(0, dt.ptr())
|
||||
g = UOp.param(0, dt.ptr(3))
|
||||
c = UOp.const(dt, 2)
|
||||
a = UOp(Ops.FLOORDIV, dt, (g.index(c), c))
|
||||
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
|
||||
@@ -207,7 +207,7 @@ class TestFastIdiv(unittest.TestCase):
|
||||
@Context(DISABLE_FAST_IDIV=0)
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't support long")
|
||||
def test_fast_idiv_and_mod(self):
|
||||
g = UOp.param(0, dtypes.uint32.ptr())
|
||||
g = UOp.param(0, dtypes.uint32.ptr(4))
|
||||
c = UOp.const(dtypes.uint, 3)
|
||||
l = g.index(c)
|
||||
a = UOp(Ops.CDIV, dtypes.uint, (l, c))
|
||||
@@ -242,7 +242,7 @@ class TestFastIdiv(unittest.TestCase):
|
||||
@unittest.expectedFailure
|
||||
def test_fast_idiv_overflow(self):
|
||||
# This will be possible with a slightly different method for fast_idiv
|
||||
g = UOp.param(0, dtypes.uint32.ptr())
|
||||
g = UOp.param(0, dtypes.uint32.ptr(8))
|
||||
c = UOp.const(dtypes.uint, 7)
|
||||
l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
|
||||
a = UOp(Ops.CDIV, dtypes.uint, (l, c))
|
||||
@@ -253,7 +253,7 @@ class TestFastIdiv(unittest.TestCase):
|
||||
self.assertNotIn(Ops.CDIV, ops)
|
||||
|
||||
def test_disable_fast_idiv(self):
|
||||
g = UOp.param(0, dtypes.uint32.ptr())
|
||||
g = UOp.param(0, dtypes.uint32.ptr(4))
|
||||
c = UOp.const(dtypes.uint, 3)
|
||||
l = g.index(c)
|
||||
a = UOp(Ops.CDIV, dtypes.uint, (l, c))
|
||||
|
||||
@@ -51,7 +51,7 @@ class DTypeMetaClass(type):
|
||||
|
||||
class AddrSpace(Enum):
|
||||
def __repr__(self): return str(self)
|
||||
GLOBAL = auto(); LOCAL = auto(); REG = auto() # noqa: E702
|
||||
GLOBAL = auto(); LOCAL = auto(); REG = auto(); ANON = auto() # noqa: E702
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class DType(metaclass=DTypeMetaClass):
|
||||
|
||||
@@ -79,22 +79,21 @@ class PythonProgram:
|
||||
if u.op is Ops.STORE:
|
||||
assert len(src_values) == 2, f"STORE must be lowered to 2 srcs, got {len(src_values)}"
|
||||
store_gate = exec_masks[-1]
|
||||
for j,val in enumerate(src_values[1] if src_dtypes[1].count > 1 else [src_values[1]]):
|
||||
for j,val in enumerate(src_values[1] if u.max_numel() > 1 else [src_values[1]]):
|
||||
for (m,o),v,g in zip(src_values[0], val, store_gate):
|
||||
if g: _store(m, o+j, v, src_dtypes[1].scalar())
|
||||
i += 1
|
||||
continue
|
||||
if u.op is Ops.AFTER: values[u] = src_values[0]
|
||||
elif u.op in {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
|
||||
assert isinstance(u.dtype, PtrDType), u.dtype
|
||||
storage_fmt = storage_fmt_for_dtype(u.dtype.base.scalar())
|
||||
if storage_fmt is None: raise RuntimeError(f"dtype={u.dtype} is not supported")
|
||||
if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e"
|
||||
if u.op is Ops.DEFINE_REG:
|
||||
# REGs are per thread
|
||||
values[u] = [memoryview(bytearray(u.dtype.size*u.dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
|
||||
values[u] = [memoryview(bytearray(u.max_numel()*u.dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
|
||||
else:
|
||||
buf = memoryview(bytearray(u.dtype.size*u.dtype.itemsize)) if u.op is not Ops.PARAM else pbufs.pop(0)
|
||||
buf = memoryview(bytearray(u.max_numel()*u.dtype.itemsize)) if u.op is not Ops.PARAM else pbufs.pop(0)
|
||||
values[u] = [buf.cast(storage_fmt)] * warp_size
|
||||
elif u.op is Ops.DEFINE_VAR:
|
||||
values[u] = [pvals.pop(0)] * warp_size
|
||||
@@ -129,9 +128,10 @@ class PythonProgram:
|
||||
elif u.op is Ops.CAST:
|
||||
values[u] = [truncate.get(u.dtype, lambda dt: dt)(u.dtype.const(x)) for x in src_values[0]]
|
||||
elif u.op is Ops.LOAD:
|
||||
if u.dtype.count > 1:
|
||||
values[u] = [load([src_values[k][j] if k != 0 and src_dtypes[k].count > 1 else src_values[k] \
|
||||
for k in range(len(src_values))], j, u.dtype.scalar()) for j in range(u.dtype.count)]
|
||||
if (load_sz := u.max_numel()) > 1:
|
||||
# buf and gate are not vecs
|
||||
values[u] = [load([src_values[k] if k in [0,2] else src_values[k][j] \
|
||||
for k in range(len(src_values))], j, u.dtype.scalar()) for j in range(load_sz)]
|
||||
else:
|
||||
values[u] = load(src_values, 0, u.dtype)
|
||||
elif u.op is Ops.GEP: values[u] = src_values[0][get_single_element(u.arg)]
|
||||
|
||||
@@ -758,8 +758,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
if self.op is Ops.BUFFER: return AddrSpace.GLOBAL
|
||||
if self.op is Ops.DEFINE_LOCAL: return AddrSpace.LOCAL
|
||||
if self.op is Ops.DEFINE_REG: return AddrSpace.REG
|
||||
# LOAD brings things into registers
|
||||
if self.op is Ops.LOAD: return AddrSpace.REG
|
||||
if self.op is Ops.LOAD: return AddrSpace.ANON # LOAD brings things into anonymous registers
|
||||
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP}:
|
||||
return self.src[0].addrspace
|
||||
if self.op in GroupOp.Movement: return self.src[0].addrspace
|
||||
|
||||
@@ -13,7 +13,10 @@ def validate_index(uidx:UOp, gate:UOp|None=None):
|
||||
if idx.op is Ops.CONST and idx.arg is Invalid: return True
|
||||
if gate is None: gate = UOp.const(dtypes.bool, True)
|
||||
# TODO: check for overflow
|
||||
if not CHECK_OOB or isinstance(buf.dtype, ImageDType) or (sz := buf.ptrdtype.size) == -1: return True
|
||||
if not CHECK_OOB or isinstance(buf.dtype, ImageDType): return True
|
||||
|
||||
# buffer size
|
||||
sz = buf.max_numel()
|
||||
|
||||
# We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask
|
||||
if 0<=idx.vmin and idx.vmax<sz: return True
|
||||
|
||||
@@ -57,7 +57,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||
Ops.STAGE: "#AC640D", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
|
||||
|
||||
addrspace_colors = {AddrSpace.REG:"#e68181", AddrSpace.LOCAL:"#e7c86a", AddrSpace.GLOBAL:"#75bd7b"}
|
||||
addrspace_colors = {AddrSpace.ANON:"#e68181", AddrSpace.REG: "#ff40a0", AddrSpace.LOCAL:"#e7c86a", AddrSpace.GLOBAL:"#75bd7b"}
|
||||
|
||||
# VIZ API
|
||||
|
||||
|
||||
Reference in New Issue
Block a user