conv backward tests in test_simplify_valid_idx (#6727)

the backward idx is pretty ugly now
This commit is contained in:
chenyu
2024-09-25 02:51:07 -04:00
committed by GitHub
parent 6c69fec1ef
commit ff25bfb1b0

View File

@@ -2,9 +2,17 @@ import unittest
from typing import Tuple
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, is_increasing
from tinygrad.dtype import dtypes
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.ops import UOp, UOps, BinaryOps
def get_gated_load_uop(valid:UOp, idx:UOp):
return UOp(UOps.LOAD, dtypes.float, (
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0),
idx,
UOp.const(dtypes.float, 0.0),
valid
))
def get_load_image_uop(image_shape:Tuple[int, ...], valid:UOp, idx:Tuple[UOp, UOp]):
return UOp(UOps.LOAD, dtypes.float.vec(4), (
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0),
@@ -20,7 +28,7 @@ def render(uop:UOp) -> str:
code_for_op = {**OpenCLRenderer().code_for_op, BinaryOps.IDIV: lambda a,b,dtype: f"({a}//{b})"}
fxn = TestRenderer().render("", uops)
# print(fxn)
return fxn.split("float4 val0 = ")[1].split(";")[0]
return fxn.split("val0 = ")[1].split(";")[0]
def Special(expr, nmax): return UOp(UOps.SPECIAL, dtypes.int, (), (expr, nmax))
def Variable(expr, nmin, nmax): return UOp.define_var(expr, dtypes.int, nmin, nmax)
@@ -51,7 +59,51 @@ class TestHelpers(unittest.TestCase):
class TestValidIdxSimplification(unittest.TestCase):
def test_conv_backward(self):
pass
# DEBUG=4 python3 test/test_ops.py TestOps.test_simple_conv2d
gidx0 = Special("gidx0", 3)
gidx1 = Special("gidx1", 3)
lidx0 = Special("lidx0", 4)
lidx1 = Special("lidx1", 3)
lidx2 = Special("lidx2", 3)
ridx0 = Range(0, 4)
alu0 = gidx0*3
alu1 = (alu0+lidx2)
alu2 = (gidx1*3)
alu3 = (alu1+7)
alu4 = (alu1+8)
alu5 = (alu1+9)
alu6 = ((gidx0+9)//10)
alu7 = (alu3%10)
alu8 = (alu4%10)
alu9 = (alu5%10)
alu10 = (gidx1+(ridx0*3))
alu11 = (ridx0*9)
alu12 = (alu2+lidx1+alu11)
alu13 = ((alu6+alu2+lidx1+alu11)%10)
alu14 = (alu12%10)
alu15 = (((((alu10//10)+lidx0)%4)*441)+(((alu12//10)%3)*3)+(alu14*63))
alu16 = alu12.lt(30)
alu17 = alu16&(alu14.lt(7))
# TODO: simplify these
val0 = get_gated_load_uop(alu17&(alu9.lt(7)), alu15+(alu5//10)+(alu9*9))
self.assertEqual(render(val0),
"((((alu2<30)&(alu3<7))&(alu1<7))?data0[(((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((alu2//10)%3)*3)+(alu3*63)+(alu0//10)+(alu1*9)]:0.0f)")
val1 = get_gated_load_uop(
((alu16&gidx0.lt(1))&alu13.lt(7))&alu7.lt(7),
((((((((((lidx1*10)+gidx0)//3)+3)//10)+alu10)//10)+lidx0)%4)*441)+((((alu6+alu12)//10)%3)*3)+(alu13*63)+(((alu3//10)+2)%3)+(alu7*9)
)
self.assertEqual(render(val1),
"(((((alu5<30)&(gidx0<1))&(alu6<7))&(alu3<7))?data0[((((((((((lidx1*10)+gidx0)//3)+3)//10)+gidx1+(ridx0*3))//10)+lidx0)%4)*441)+((((alu2+alu5)//10)%3)*3)+(alu6*63)+(((alu1//10)+2)%3)+(alu3*9)]:0.0f)") # noqa: E501
val2 = get_gated_load_uop(alu17&alu1.lt(7), alu15+(gidx0*27)+(lidx2*9))
self.assertEqual(render(val2),
"((((alu0<30)&(alu1<7))&(((gidx0*3)+lidx2)<7))?data0[(((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((alu0//10)%3)*3)+(alu1*63)+(gidx0*27)+(lidx2*9)]:0.0f)") # noqa: E501
val3 = get_gated_load_uop(alu17&alu8.lt(7), (alu4//10)+alu15+(alu8*9)+1)
self.assertEqual(render(val3),
"((((alu2<30)&(alu3<7))&(alu1<7))?data0[(alu0//10)+(((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((alu2//10)%3)*3)+(alu3*63)+(alu1*9)+1]:0.0f)")
class TestImageSimplification(unittest.TestCase):
def test_idx_gt_c(self):