mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 07:27:43 +08:00
fix valid in indexing tests (#16087)
This commit is contained in:
@@ -518,17 +518,18 @@ class TestUnfoldableImage(unittest.TestCase):
|
||||
|
||||
class TestDropTrueGate(unittest.TestCase):
|
||||
def test_drop_true_gate_on_index(self):
|
||||
# test that INDEX with a constant True gate gets simplified to drop the gate
|
||||
# test that INDEX with a constant True valid gets simplified to drop the valid
|
||||
from tinygrad.codegen.late.devectorizer import load_store_indexing
|
||||
from tinygrad.uop.ops import graph_rewrite
|
||||
from tinygrad.uop.symbolic import sym
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
|
||||
idx = UOp.const(dtypes.weakint, 0)
|
||||
true_gate = UOp.const(dtypes.bool, True)
|
||||
index_with_gate = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx, true_gate))
|
||||
index_with_gate = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx.valid(true_gate)))
|
||||
# apply the optimization
|
||||
result = graph_rewrite(index_with_gate, load_store_indexing)
|
||||
# the True gate should be dropped (INDEX should only have 2 sources)
|
||||
self.assertEqual(len(result.src), 2, "True gate should be dropped from INDEX")
|
||||
result = graph_rewrite(index_with_gate, sym+load_store_indexing)
|
||||
# the True valid should be dropped (INDEX should only have 2 sources)
|
||||
self.assertEqual(len(result.src), 2, "True valid should be dropped from INDEX")
|
||||
|
||||
class TestRangeShrink(unittest.TestCase):
|
||||
def get_ranges(self, sink):
|
||||
|
||||
@@ -756,7 +756,7 @@ class TestLoadStoreFolding(unittest.TestCase):
|
||||
buf = UOp(Ops.PARAM, dtypes.float.vec(4).ptr(), (), 0)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
gate = UOp.const(dtypes.bool, True)
|
||||
gated_index = buf.index(idx, gate)
|
||||
gated_index = buf.index(idx.valid(gate))
|
||||
gep = gated_index.gep(0)
|
||||
alt = UOp.const(dtypes.float, 42.0)
|
||||
gated_load = gep.load(alt)
|
||||
|
||||
@@ -149,8 +149,8 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
|
||||
idx = gidx0*UOp.const(dtypes.int, 2)
|
||||
gate = gidx0<UOp.const(dtypes.int, 1)
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx, gate))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx, gate))
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx.valid(gate)))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx.valid(gate)))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
stores = [UOp.store(idx0, val), UOp.store(idx1, val)]
|
||||
uops = to_uops_list(stores)
|
||||
|
||||
@@ -154,7 +154,7 @@ class TestValidateOOB(unittest.TestCase):
|
||||
|
||||
gate = (gidx<400) & (lidx<8)
|
||||
|
||||
local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx, lidx<8), UOp.const(dtypes.uint, 1)))
|
||||
local_store = sbuf.index(lidx.valid(lidx<8)).store(UOp.const(dtypes.uint, 1))
|
||||
|
||||
barrier = UOp(Ops.BARRIER, dtypes.void, (local_store,))
|
||||
if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier))
|
||||
|
||||
Reference in New Issue
Block a user