Patch summary (free text ahead of the first "diff --git" header; patch tools skip it).

What this patch does to test/unit/test_simplify_valid_idx.py:
  * Imports PtrDType alongside dtypes.
  * Adds get_gated_load_uop(valid, idx): builds a UOps.LOAD of dtypes.float whose
    sources are, in order, a DEFINE_GLOBAL float pointer (arg=0), idx, a 0.0 float
    constant, and valid. The expected strings in the test show such a load rendered
    as "(<valid>?data0[<idx>]:0.0f)".
  * render() now splits the generated code on "val0 = " instead of "float4 val0 = ",
    so the text it returns no longer depends on a "float4" prefix before val0.
  * Replaces the `pass` body of TestValidIdxSimplification.test_conv_backward with
    four gated loads (val0..val3) and asserts the exact rendered text of each. Per
    the in-test comments, the expressions come from
    `DEBUG=4 python3 test/test_ops.py TestOps.test_simple_conv2d`, and the expected
    strings are marked "TODO: simplify these".

NOTE(review): line breaks, blank context lines and two-space indentation below were
reconstructed from the hunk line counts (-2,9 +2,17 / -20,7 +28,7 / -51,7 +59,51);
"import unittest" is assumed to be the hunk heading rather than a context line.
Confirm against the original patch before applying.

diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py
index c87441b21a..7d08905464 100644
--- a/test/unit/test_simplify_valid_idx.py
+++ b/test/unit/test_simplify_valid_idx.py
@@ -2,9 +2,17 @@ import unittest
 from typing import Tuple
 
 from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, is_increasing
-from tinygrad.dtype import dtypes
+from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.ops import UOp, UOps, BinaryOps
 
+def get_gated_load_uop(valid:UOp, idx:UOp):
+  return UOp(UOps.LOAD, dtypes.float, (
+    UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0),
+    idx,
+    UOp.const(dtypes.float, 0.0),
+    valid
+  ))
+
 def get_load_image_uop(image_shape:Tuple[int, ...], valid:UOp, idx:Tuple[UOp, UOp]):
   return UOp(UOps.LOAD, dtypes.float.vec(4), (
     UOp(UOps.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0),
@@ -20,7 +28,7 @@ def render(uop:UOp) -> str:
     code_for_op = {**OpenCLRenderer().code_for_op, BinaryOps.IDIV: lambda a,b,dtype: f"({a}//{b})"}
   fxn = TestRenderer().render("", uops)
   # print(fxn)
-  return fxn.split("float4 val0 = ")[1].split(";")[0]
+  return fxn.split("val0 = ")[1].split(";")[0]
 
 def Special(expr, nmax): return UOp(UOps.SPECIAL, dtypes.int, (), (expr, nmax))
 def Variable(expr, nmin, nmax): return UOp.define_var(expr, dtypes.int, nmin, nmax)
@@ -51,7 +59,51 @@ class TestHelpers(unittest.TestCase):
 
 class TestValidIdxSimplification(unittest.TestCase):
   def test_conv_backward(self):
-    pass
+    # DEBUG=4 python3 test/test_ops.py TestOps.test_simple_conv2d
+    gidx0 = Special("gidx0", 3)
+    gidx1 = Special("gidx1", 3)
+    lidx0 = Special("lidx0", 4)
+    lidx1 = Special("lidx1", 3)
+    lidx2 = Special("lidx2", 3)
+    ridx0 = Range(0, 4)
+    alu0 = gidx0*3
+    alu1 = (alu0+lidx2)
+    alu2 = (gidx1*3)
+    alu3 = (alu1+7)
+    alu4 = (alu1+8)
+    alu5 = (alu1+9)
+    alu6 = ((gidx0+9)//10)
+    alu7 = (alu3%10)
+    alu8 = (alu4%10)
+    alu9 = (alu5%10)
+    alu10 = (gidx1+(ridx0*3))
+    alu11 = (ridx0*9)
+    alu12 = (alu2+lidx1+alu11)
+    alu13 = ((alu6+alu2+lidx1+alu11)%10)
+    alu14 = (alu12%10)
+    alu15 = (((((alu10//10)+lidx0)%4)*441)+(((alu12//10)%3)*3)+(alu14*63))
+    alu16 = alu12.lt(30)
+    alu17 = alu16&(alu14.lt(7))
+
+    # TODO: simplify these
+    val0 = get_gated_load_uop(alu17&(alu9.lt(7)), alu15+(alu5//10)+(alu9*9))
+    self.assertEqual(render(val0),
+      "((((alu2<30)&(alu3<7))&(alu1<7))?data0[(((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((alu2//10)%3)*3)+(alu3*63)+(alu0//10)+(alu1*9)]:0.0f)")
+
+    val1 = get_gated_load_uop(
+      ((alu16&gidx0.lt(1))&alu13.lt(7))&alu7.lt(7),
+      ((((((((((lidx1*10)+gidx0)//3)+3)//10)+alu10)//10)+lidx0)%4)*441)+((((alu6+alu12)//10)%3)*3)+(alu13*63)+(((alu3//10)+2)%3)+(alu7*9)
+    )
+    self.assertEqual(render(val1),
+      "(((((alu5<30)&(gidx0<1))&(alu6<7))&(alu3<7))?data0[((((((((((lidx1*10)+gidx0)//3)+3)//10)+gidx1+(ridx0*3))//10)+lidx0)%4)*441)+((((alu2+alu5)//10)%3)*3)+(alu6*63)+(((alu1//10)+2)%3)+(alu3*9)]:0.0f)") # noqa: E501
+
+    val2 = get_gated_load_uop(alu17&alu1.lt(7), alu15+(gidx0*27)+(lidx2*9))
+    self.assertEqual(render(val2),
+      "((((alu0<30)&(alu1<7))&(((gidx0*3)+lidx2)<7))?data0[(((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((alu0//10)%3)*3)+(alu1*63)+(gidx0*27)+(lidx2*9)]:0.0f)") # noqa: E501
+
+    val3 = get_gated_load_uop(alu17&alu8.lt(7), (alu4//10)+alu15+(alu8*9)+1)
+    self.assertEqual(render(val3),
+      "((((alu2<30)&(alu3<7))&(alu1<7))?data0[(alu0//10)+(((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((alu2//10)%3)*3)+(alu3*63)+(alu1*9)+1]:0.0f)")
 
 class TestImageSimplification(unittest.TestCase):
   def test_idx_gt_c(self):