From 0362dbbbe8befb293dc5300e1aebd8b5fbbcb6b4 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 23 Sep 2024 03:04:57 -0400 Subject: [PATCH] relax idx simplification given valid (#6669) apply to kernels in op 0.9.7. if a valid has a complicated expr, we cannot drop valid but it's possible to simplify idx given valid --- test/unit/test_image_valid.py | 15 +++++++++++++++ tinygrad/codegen/uopgraph.py | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/test/unit/test_image_valid.py b/test/unit/test_image_valid.py index 61e0b9ed33..5a428c9e4f 100644 --- a/test/unit/test_image_valid.py +++ b/test/unit/test_image_valid.py @@ -192,5 +192,20 @@ class TestValidSimplification(unittest.TestCase): self.assertEqual(render(data1_shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu5*8))%64),(alu5//8)))), "((idx0<256)?read_imagef(data0, smp, (int2)((((((idx0//8)%32)//4)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))") + def test_simplify5(self): + # openpilot 0.9.7, chunk replacement to simplify + shape = (10, 384, 4) + idx0 = Special("idx0", 16) + idx1 = Special("idx1", 24) + alu0 = idx0*4 + alu1 = (idx1*256)+alu0 + alu2 = idx1//3 + alu3 = ((alu1+1)%768) + idx = ((idx0+((((alu3//640)+alu2)%8)*16)+128),((alu3//64)%10)) + valid = alu3.lt(640) + + self.assertEqual(render(shape, valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)), + "((alu0<640)?read_imagef(data0, smp, (int2)((idx0+((idx1//3)*16)+128),(alu0//64))):(float4)(0.0f,0.0f,0.0f,0.0f))") + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 32e82d3827..5f70891adc 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -228,7 +228,7 @@ def simplify_valid_image_load(load:UOp, buf:UOp): if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1])) if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0])) - elif is_irreducible(uop): + else: new = UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1]) newidx = replace_uop(graph_rewrite(replace_uop(idx, uop, new), constant_folder), new, uop) if newidx.key != idx.key: idx = newidx