From b796bbae87bb628e656c7723cbb18be50ddcd919 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 7 May 2026 14:11:28 -0700 Subject: [PATCH] fix valid in indexing tests (#16087) --- test/null/test_simplify_valid_idx.py | 11 ++++++----- test/null/test_uop_graph.py | 2 +- test/null/test_uops.py | 4 ++-- test/null/test_validate_oob.py | 2 +- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/test/null/test_simplify_valid_idx.py b/test/null/test_simplify_valid_idx.py index 3d1bbd3ed4..1fd781651a 100644 --- a/test/null/test_simplify_valid_idx.py +++ b/test/null/test_simplify_valid_idx.py @@ -518,17 +518,18 @@ class TestUnfoldableImage(unittest.TestCase): class TestDropTrueGate(unittest.TestCase): def test_drop_true_gate_on_index(self): - # test that INDEX with a constant True gate gets simplified to drop the gate + # test that INDEX with a constant True valid gets simplified to drop the valid from tinygrad.codegen.late.devectorizer import load_store_indexing from tinygrad.uop.ops import graph_rewrite + from tinygrad.uop.symbolic import sym buf = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0) idx = UOp.const(dtypes.weakint, 0) true_gate = UOp.const(dtypes.bool, True) - index_with_gate = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx, true_gate)) + index_with_gate = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx.valid(true_gate))) # apply the optimization - result = graph_rewrite(index_with_gate, load_store_indexing) - # the True gate should be dropped (INDEX should only have 2 sources) - self.assertEqual(len(result.src), 2, "True gate should be dropped from INDEX") + result = graph_rewrite(index_with_gate, sym+load_store_indexing) + # the True valid should be dropped (INDEX should only have 2 sources) + self.assertEqual(len(result.src), 2, "True valid should be dropped from INDEX") class TestRangeShrink(unittest.TestCase): def get_ranges(self, sink): diff --git a/test/null/test_uop_graph.py b/test/null/test_uop_graph.py index ea84737d98..687467b3a6 100644 --- a/test/null/test_uop_graph.py +++ b/test/null/test_uop_graph.py @@ -756,7 +756,7 @@ class TestLoadStoreFolding(unittest.TestCase): buf = UOp(Ops.PARAM, dtypes.float.vec(4).ptr(), (), 0) idx = UOp.const(dtypes.int, 0) gate = UOp.const(dtypes.bool, True) - gated_index = buf.index(idx, gate) + gated_index = buf.index(idx.valid(gate)) gep = gated_index.gep(0) alt = UOp.const(dtypes.float, 42.0) gated_load = gep.load(alt) diff --git a/test/null/test_uops.py b/test/null/test_uops.py index 33eff659ed..86f3e40b01 100644 --- a/test/null/test_uops.py +++ b/test/null/test_uops.py @@ -149,8 +149,8 @@ class TestGatedStoreRewrite(unittest.TestCase): gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0') idx = gidx0*UOp.const(dtypes.int, 2) gate = gidx0