relax idx simplification given valid (#6669)

apply to kernels in op 0.9.7.
if a valid has a complicated expr, we cannot drop valid but it's possible to simplify idx given valid
This commit is contained in:
chenyu
2024-09-23 03:04:57 -04:00
committed by GitHub
parent 7ca9ffa494
commit 0362dbbbe8
2 changed files with 16 additions and 1 deletions

View File

@@ -192,5 +192,20 @@ class TestValidSimplification(unittest.TestCase):
self.assertEqual(render(data1_shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu5*8))%64),(alu5//8)))),
"((idx0<256)?read_imagef(data0, smp, (int2)((((((idx0//8)%32)//4)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))")
def test_simplify5(self):
# openpilot 0.9.7, chunk replacement to simplify
shape = (10, 384, 4)
idx0 = Special("idx0", 16)
idx1 = Special("idx1", 24)
alu0 = idx0*4
alu1 = (idx1*256)+alu0
alu2 = idx1//3
alu3 = ((alu1+1)%768)
idx = ((idx0+((((alu3//640)+alu2)%8)*16)+128),((alu3//64)%10))
valid = alu3.lt(640)
self.assertEqual(render(shape, valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)),
"((alu0<640)?read_imagef(data0, smp, (int2)((idx0+((idx1//3)*16)+128),(alu0//64))):(float4)(0.0f,0.0f,0.0f,0.0f))")
if __name__ == '__main__':
unittest.main()

View File

@@ -228,7 +228,7 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1]))
if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0]))
elif is_irreducible(uop):
else:
new = UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1])
newidx = replace_uop(graph_rewrite(replace_uop(idx, uop, new), constant_folder), new, uop)
if newidx.key != idx.key: idx = newidx