diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index c4e6b8ffbf..f513c4e2d5 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -212,11 +212,9 @@ class TestSymbolic(unittest.TestCase): # This is mod reduction self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, {"(-1+a)", "(a+(-1))"}) - @unittest.expectedFailure def test_sum_div_const(self): self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a") - @unittest.expectedFailure def test_sum_div_const_big(self): self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)") @@ -311,7 +309,6 @@ class TestSymbolic(unittest.TestCase): def test_mul_div_factor_div(self): self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)") - @unittest.expectedFailure def test_sum_div_partial_remove(self): self.helper_test_variable(Node.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0") diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 8aa0795666..ae4653bccf 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -255,10 +255,14 @@ constant_folder = PatternMatcher([ (UOp.var('x') % UOp.cvar('c'), lambda x,c: x if 0 <= x.vmin.arg <= x.vmax.arg < c.arg else None), # mod reduction (UOp.var('x') % UOp.cvar('c'), lambda x,c: (x-(x.vmin.arg//c.arg)*c.arg)%c if 0 < c.arg <= x.vmin.arg else None), - # mul -> (sum) -> mod + # mul -> mod ((UOp.cvar('c0')*UOp.var('x')) % UOp.cvar('c1'), lambda x,c0,c1: x*(c0.arg%c1.arg)%c1 if c0.arg >= c1.arg > 0 else None), + # mul -> add -> mod (((UOp.cvar('c0')*UOp.var('x'))+UOp.var('x2')) % UOp.cvar('c1'), lambda x,x2,c0,c1: x2%c1 if (r:=c0.arg%c1.arg) == 0 else (x*r+x2)%c1 if c0.arg >= c1.arg > 0 else None), + # mul -> add -> div + (((UOp.cvar('c0')*UOp.var('x'))+UOp.var('x2')) // UOp.cvar('c1'), lambda x,x2,c0,c1:\ + x*(c0.arg//g)//(c1.arg//g) if c0.arg > 0 and c1.arg > 0 and (g:=math.gcd(c0.arg,c1.arg)) > 1 and g > x2.vmax.arg and x2.vmin.arg >= 0 else None), # mod mod ((UOp.var('x') % UOp.cvar('c0')) % UOp.cvar('c1'), lambda x,c0,c1: x % c0 if 0 < c0.arg < c1.arg else x % c1 if c0.arg % c1.arg == 0 else None), (((UOp.var('x') * UOp.cvar('c0')) % UOp.cvar('c1')) % UOp.cvar('c0'), lambda x,c0,c1: x.const(0)),