diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 64f9cfddf9..0f182a5665 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -2,10 +2,6 @@ import unittest, pickle from typing import Tuple -# TODO: fix all the @unittest.expectedFailure - -# *** fake symobilc uops *** - from tinygrad.dtype import dtypes, ConstType from tinygrad.codegen.linearize import linearize_uop from tinygrad.codegen.uopgraph import full_graph_rewrite, sym @@ -20,15 +16,9 @@ def render(self) -> Tuple[str, ConstType, ConstType]: rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1] return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax -def NumNode(val): return UOp.const(dtypes.int, val) -class Node: - @staticmethod - def sum(ops): return functools.reduce(lambda x,y: x+y, ops) - @staticmethod - def ands(ops): return functools.reduce(lambda x,y: x*y, ops) - def __floordiv__(a,b,unk): return a//b -def SumNode(x): return Node.sum(x) -def MulNode(x, y): return x*y +def uconst(val): return UOp.const(dtypes.int, val) +def usum(ops): return functools.reduce(lambda x,y: x+y, ops) +def uand(ops): return functools.reduce(lambda x,y: x*y, ops) # *** leave tests the same @@ -74,8 +64,8 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(expr, 0, 1, "(idx<128)") def test_lt_divides_and(self): - expr = Node.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512, - (Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512]) + expr = uand([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512, + (Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512]) self.helper_test_variable(expr, 0, 1, "((idx1<128)&(idx2<128))") def test_lt_factors(self): @@ -113,15 +103,9 @@ class TestSymbolic(unittest.TestCase): def test_add_1(self): self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(a+1)") - def test_add_num_1(self): - self.helper_test_variable(Variable("a", 0, 8)+NumNode(1), 1, 9, "(a+1)") - def test_sub_1(self): self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, "(a+-1)") - def test_sub_num_1(self): - self.helper_test_variable(Variable("a", 0, 8)-NumNode(1), -1, 7, "(a+-1)") - def test_add_self(self): a = Variable("a", 0, 8) self.helper_test_variable(a+a, 0, 16, "(a*2)") @@ -165,31 +149,31 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "(a//-2)") def test_sum_div_remove(self): - self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 20, 0, 0, "0") + self.helper_test_variable(usum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 20, 0, 0, "0") def test_sum_div_min_max(self): - self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)") + self.helper_test_variable(usum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)") def test_sum_div_mod_factor(self): - self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))") - self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) % 2, 0, 0, "0") + self.helper_test_variable(usum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))") + self.helper_test_variable(usum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) % 2, 0, 0, "0") def test_sum_div_some_factor(self): - self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, ("(((a*5)//2)+(b*2))", "((b*2)+((a*5)//2))")) + self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, ("(((a*5)//2)+(b*2))", "((b*2)+((a*5)//2))")) def test_sum_div_trim_const(self): self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, "(((a+b)+1)//4)") def test_sum_div_some_partial_factor(self): - self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)") - self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)") + self.helper_test_variable(usum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)") + self.helper_test_variable(usum([uconst(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)") self.helper_test_variable((Variable("a", 0, 7)*30+20)//20, 1, 11, "(((a*3)//2)+1)") def test_sum_div_no_factor(self): - self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)") + self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)") def test_mod_factor(self): - self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 50, "((b%2)*50)") + self.helper_test_variable(usum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 50, "((b%2)*50)") def test_mod_to_sub(self): self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, "(a+-1)") @@ -219,16 +203,16 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((3+Variable("a",4,5))%4, 0, 3, "((a*-3)+15)") def test_sum_div_const(self): - self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a") + self.helper_test_variable(usum([Variable("a", 0, 7)*4, uconst(3)]) // 4, 0, 7, "a") def test_sum_div_const_big(self): - self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)") + self.helper_test_variable(usum([Variable("a", 0, 7)*4, uconst(3)]) // 16, 0, 1, "(a//4)") def test_sum_lt_fold(self): - self.helper_test_variable(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)") - self.helper_test_variable(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1, + self.helper_test_variable(usum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)") + self.helper_test_variable(usum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1, ("(((a*4)+b)<16)", "((b+(a*4))<16)")) - self.helper_test_variable(Node.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]) < (4 * 67), 0, 1, "(a<23)") + self.helper_test_variable(usum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]) < (4 * 67), 0, 1, "(a<23)") def test_mul_mod_large(self): self.helper_test_variable((Variable("a", 0, 20)*10)%9, 0, 8, "(a%9)") @@ -260,14 +244,14 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((a//2+1)//2, 0, 31, "((a+2)//4)") def test_distribute_mul(self): - self.helper_test_variable(Node.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))") + self.helper_test_variable(usum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))") self.helper_test_variable((1+Variable("a", 0, 3))*(-2)+12, 4, 10, "((a*-2)+10)") def test_mod_mul_sum(self): - self.helper_test_variable(Node.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, ("(b+a)", "(a+b)")) + self.helper_test_variable(usum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, ("(b+a)", "(a+b)")) def test_sum_0(self): - self.helper_test_variable(Node.sum([Variable("a", 0, 7)]), 0, 7, "a") + self.helper_test_variable(usum([Variable("a", 0, 7)]), 0, 7, "a") def test_mod_remove(self): self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a") @@ -303,26 +287,26 @@ class TestSymbolic(unittest.TestCase): "((((a*3)+(b*2))+(c*4))<1)") def test_and_fold(self): - self.helper_test_variable(Node.ands([NumNode(0), Variable("a", 0, 1)]), 0, 0, "0") + self.helper_test_variable(uand([uconst(0), Variable("a", 0, 1)]), 0, 0, "0") def test_and_remove(self): - self.helper_test_variable(Node.ands([NumNode(1), Variable("a", 0, 1)]), 0, 1, "a") + self.helper_test_variable(uand([uconst(1), Variable("a", 0, 1)]), 0, 1, "a") def test_mod_factor_negative(self): - self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)") - self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)") + self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)") + self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)") def test_sum_combine_num(self): - self.helper_test_variable(Node.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, "(a+6)") + self.helper_test_variable(usum([uconst(29), Variable("a", 0, 10), uconst(-23)]), 6, 16, "(a+6)") def test_sum_num_hoisted_and_factors_cancel_out(self): - self.helper_test_variable(Node.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1") + self.helper_test_variable(usum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1") def test_div_cancel(self): - self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(b+-1)") + self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(b+-1)") def test_mod_cancel(self): - self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)") + self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)") def test_mul_div(self): self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a") @@ -345,7 +329,7 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)") def test_sum_div_partial_remove(self): - self.helper_test_variable(Node.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0") + self.helper_test_variable(usum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0") @unittest.expectedFailure def test_div_numerator_negative(self): @@ -541,7 +525,7 @@ class TestSymbolicNumeric(unittest.TestCase): MIN, MAX = 0, 10 # one number for i in range(MIN, MAX): - v = graph_rewrite(f(NumNode(i)), sym) + v = graph_rewrite(f(uconst(i)), sym) self.assertEqual(v.vmin, v.vmax) self.assertEqual(v.vmin, f(i)) for kmin in range(MIN, MAX): @@ -565,16 +549,16 @@ class TestSymbolicNumeric(unittest.TestCase): class TestSymbolicVars(unittest.TestCase): def test_simple(self): - z = NumNode(0) + z = uconst(0) a = Variable("a", 0, 10) b = Variable("b", 0, 10) c = Variable("c", 0, 10) assert z.vars() == z.vars() == set() print(a.vars()) assert a.vars() == a.vars() == {a} - m = MulNode(a, 3) + m = a * 3 assert m.vars() == {a} - s = SumNode([a, b, c]) + s = usum([a, b, c]) assert s.vars() == {a, b, c} def test_compound(self): @@ -608,29 +592,19 @@ class TestSymInfer(unittest.TestCase): assert sym_infer(a*b+c, var_vals) == 10 """ -@unittest.skip("not supported on uops yet") -class TestSymRender(unittest.TestCase): - def test_sym_render(self): - a = Variable("a", 1, 8) - b = Variable("b", 1, 10) - assert sym_render(a) == "a" - assert sym_render(1) == "1" - assert sym_render(a+1) == "(1+a)" - assert sym_render(a*b) == "(a*b)" - @unittest.skip("not supported on uops yet") class TestSymbolicSymbolicOps(unittest.TestCase): def test_node_divmod_node(self): i = Variable("i", 1, 10) idx0 = Variable("idx0", 0, i*3-1) - assert NumNode(0) // (Variable("i", 1, 10)*128) == 0 - assert NumNode(0) % (Variable("i", 1, 10)*128) == 0 - assert NumNode(127) // (Variable("i", 1, 10)*128) == 0 - assert NumNode(127) % (Variable("i", 1, 10)*128) == 127 + assert uconst(0) // (Variable("i", 1, 10)*128) == 0 + assert uconst(0) % (Variable("i", 1, 10)*128) == 0 + assert uconst(127) // (Variable("i", 1, 10)*128) == 0 + assert uconst(127) % (Variable("i", 1, 10)*128) == 127 assert 127 // (Variable("i", 1, 10)*128) == 0 assert 127 % (Variable("i", 1, 10)*128) == 127 - assert NumNode(128) // (Variable("i", 1, 10)*128 + 128) == 0 - assert NumNode(128) % (Variable("i", 1, 10)*128 + 128) == 128 + assert uconst(128) // (Variable("i", 1, 10)*128 + 128) == 0 + assert uconst(128) % (Variable("i", 1, 10)*128 + 128) == 128 assert 128 // (Variable("i", 1, 10)*128 + 128) == 0 assert 128 % (Variable("i", 1, 10)*128 + 128) == 128 assert 0 // (Variable("i", 1, 10)*128) == 0 @@ -639,10 +613,10 @@ class TestSymbolicSymbolicOps(unittest.TestCase): assert idx0 % (i*3) == idx0 assert i // i == 1 assert i % i == 0 - assert 128 // NumNode(4) == 32 - assert 128 % NumNode(4) == 0 - assert NumNode(128) // NumNode(4) == 32 - assert NumNode(128) % NumNode(4) == 0 + assert 128 // uconst(4) == 32 + assert 128 % uconst(4) == 0 + assert uconst(128) // uconst(4) == 32 + assert uconst(128) % uconst(4) == 0 def test_mulnode_divmod_node(self): i = Variable("i", 1, 10) @@ -667,12 +641,12 @@ class TestSymbolicSymbolicOps(unittest.TestCase): # assert (i*128+128)*2 // (i*128+128) == 2 # assert (i*128+128)*2 % (i*128+128) == 0 - def test_sumnode_div_numnode_no_factoring(self): + def test_sumnode_div_uconst_no_factoring(self): gid = Variable("gid", 0, 1023) lid = Variable("lid", 0, 3) - expr_before_div = NumNode(-1019)-4*lid-gid - unfactored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), False) - factored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), True) + expr_before_div = uconst(-1019)-4*lid-gid + unfactored_expr = Node.__floordiv__(expr_before_div, uconst(-16), False) + factored_expr = Node.__floordiv__(expr_before_div, uconst(-16), True) self.assertEqual(unfactored_expr.render(), "(((lid*4)+1019+gid)//16)") self.assertEqual(factored_expr.render(), "(((((3+gid)//4)+2+lid)//4)+63)") @@ -698,21 +672,21 @@ class TestSymbolicSymbolicOps(unittest.TestCase): def test_num_node_mul_node(self): a = Variable("a", 1, 5) - b = NumNode(2) * a + b = uconst(2) * a assert b == a * 2 assert isinstance(b, MulNode) - b = NumNode(1) * a + b = uconst(1) * a assert b == a assert isinstance(b, Variable) - b = NumNode(0) * a + b = uconst(0) * a assert b == 0 - assert isinstance(b, NumNode) + assert isinstance(b, uconst) def test_substitute(self): a = Variable("idx0", 1, 3) b = a + 1 - c = b.substitute({a: NumNode(1)}) - assert c == NumNode(2) + c = b.substitute({a: uconst(1)}) + assert c == uconst(2) """ class TestSymbolicRealWorld(unittest.TestCase):