fix valid in indexing tests (#16087)

This commit is contained in:
George Hotz
2026-05-07 14:11:28 -07:00
committed by GitHub
parent 4d1a9dca41
commit b796bbae87
4 changed files with 10 additions and 9 deletions

View File

@@ -518,17 +518,18 @@ class TestUnfoldableImage(unittest.TestCase):
class TestDropTrueGate(unittest.TestCase):
def test_drop_true_gate_on_index(self):
# test that INDEX with a constant True gate gets simplified to drop the gate
# test that INDEX with a constant True valid gets simplified to drop the valid
from tinygrad.codegen.late.devectorizer import load_store_indexing
from tinygrad.uop.ops import graph_rewrite
from tinygrad.uop.symbolic import sym
buf = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
idx = UOp.const(dtypes.weakint, 0)
true_gate = UOp.const(dtypes.bool, True)
index_with_gate = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx, true_gate))
index_with_gate = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx.valid(true_gate)))
# apply the optimization
result = graph_rewrite(index_with_gate, load_store_indexing)
# the True gate should be dropped (INDEX should only have 2 sources)
self.assertEqual(len(result.src), 2, "True gate should be dropped from INDEX")
result = graph_rewrite(index_with_gate, sym+load_store_indexing)
# the True valid should be dropped (INDEX should only have 2 sources)
self.assertEqual(len(result.src), 2, "True valid should be dropped from INDEX")
class TestRangeShrink(unittest.TestCase):
def get_ranges(self, sink):

View File

@@ -756,7 +756,7 @@ class TestLoadStoreFolding(unittest.TestCase):
buf = UOp(Ops.PARAM, dtypes.float.vec(4).ptr(), (), 0)
idx = UOp.const(dtypes.int, 0)
gate = UOp.const(dtypes.bool, True)
gated_index = buf.index(idx, gate)
gated_index = buf.index(idx.valid(gate))
gep = gated_index.gep(0)
alt = UOp.const(dtypes.float, 42.0)
gated_load = gep.load(alt)

View File

@@ -149,8 +149,8 @@ class TestGatedStoreRewrite(unittest.TestCase):
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
idx = gidx0*UOp.const(dtypes.int, 2)
gate = gidx0<UOp.const(dtypes.int, 1)
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx, gate))
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx, gate))
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx.valid(gate)))
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx.valid(gate)))
val = UOp.const(dtypes.float, 42.0)
stores = [UOp.store(idx0, val), UOp.store(idx1, val)]
uops = to_uops_list(stores)

View File

@@ -154,7 +154,7 @@ class TestValidateOOB(unittest.TestCase):
gate = (gidx<400) & (lidx<8)
local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx, lidx<8), UOp.const(dtypes.uint, 1)))
local_store = sbuf.index(lidx.valid(lidx<8)).store(UOp.const(dtypes.uint, 1))
barrier = UOp(Ops.BARRIER, dtypes.void, (local_store,))
if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier))