mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
relax idx simplification given valid (#6669)
apply to kernels in op 0.9.7. if a valid has a complicated expr, we cannot drop valid but it's possible to simplify idx given valid
This commit is contained in:
@@ -192,5 +192,20 @@ class TestValidSimplification(unittest.TestCase):
|
||||
self.assertEqual(render(data1_shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu5*8))%64),(alu5//8)))),
|
||||
"((idx0<256)?read_imagef(data0, smp, (int2)((((((idx0//8)%32)//4)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
|
||||
def test_simplify5(self):
|
||||
# openpilot 0.9.7, chunk replacement to simplify
|
||||
shape = (10, 384, 4)
|
||||
idx0 = Special("idx0", 16)
|
||||
idx1 = Special("idx1", 24)
|
||||
alu0 = idx0*4
|
||||
alu1 = (idx1*256)+alu0
|
||||
alu2 = idx1//3
|
||||
alu3 = ((alu1+1)%768)
|
||||
idx = ((idx0+((((alu3//640)+alu2)%8)*16)+128),((alu3//64)%10))
|
||||
valid = alu3.lt(640)
|
||||
|
||||
self.assertEqual(render(shape, valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)),
|
||||
"((alu0<640)?read_imagef(data0, smp, (int2)((idx0+((idx1//3)*16)+128),(alu0//64))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -228,7 +228,7 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
|
||||
if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1]))
|
||||
if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0]))
|
||||
|
||||
elif is_irreducible(uop):
|
||||
else:
|
||||
new = UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1])
|
||||
newidx = replace_uop(graph_rewrite(replace_uop(idx, uop, new), constant_folder), new, uop)
|
||||
if newidx.key != idx.key: idx = newidx
|
||||
|
||||
Reference in New Issue
Block a user