diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8437e7e64f..c0871d1f31 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -417,7 +417,7 @@ jobs: llvm: 'true' - name: Test openpilot model kernel count and gate usage run: | - ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1486 ALLOWED_GATED_READ_IMAGE=17 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 + ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1486 ALLOWED_GATED_READ_IMAGE=18 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 - name: Test openpilot CL compile fp16 run: FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 - name: Test openpilot CL compile fp32 (test correctness) diff --git a/test/external/fuzz_symbolic.py b/test/external/fuzz_symbolic.py index 641d5eed6e..e97840d0cf 100644 --- a/test/external/fuzz_symbolic.py +++ b/test/external/fuzz_symbolic.py @@ -7,7 +7,7 @@ import z3 from tinygrad import Variable, dtypes from tinygrad.uop.ops import UOp from tinygrad.uop.validate import uops_to_z3 -from tinygrad.helpers import DEBUG, Context +from tinygrad.helpers import DEBUG seed = int(sys.argv[1]) if len(sys.argv) > 1 else random.randint(0, 100) print(f"Seed: {seed}", flush=True) @@ -56,8 +56,7 @@ if __name__ == "__main__": v = [u1,u2,u3] expr = random_int_expr(6) - with Context(CORRECT_DIVMOD_FOLDING=1): - simplified_expr = expr.simplify() + simplified_expr = expr.simplify() solver = z3.Solver(ctx=z3.Context()) solver.set(timeout=5000) # some expressions take very long verify, but its very unlikely they actually return sat @@ -74,10 +73,9 @@ 
if __name__ == "__main__": m = solver.model() n1, n2, n3 = m[v1], m[v2], m[v3] u1_val, u2_val, u3_val = u1.const_like(n1.as_long()), u2.const_like(n2.as_long()), u3.const_like(n3.as_long()) - with Context(CORRECT_DIVMOD_FOLDING=1): - num = expr.simplify().substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify() - rn = expr.substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify() - if num==rn: print("z3 found a mismatch but the expressions are equal!!") + num = expr.simplify().substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify() + rn = expr.substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify() + if num==rn: print("z3 found a mismatch but the expressions are equal!!") assert False, f"mismatched {expr.render()} at v1={m[v1]}; v2={m[v2]}; v3={m[v3]} = {num} != {rn}\n" +\ "Reproduce with:\n" +\ f"v1=Variable(\"{u1.arg[0]}\", {u1.arg[1]}, {u1.arg[2]})\n" +\ diff --git a/test/external/fuzz_symbolic_symbolic_div.py b/test/external/fuzz_symbolic_symbolic_div.py index 38b5a5a638..e0d72266fa 100644 --- a/test/external/fuzz_symbolic_symbolic_div.py +++ b/test/external/fuzz_symbolic_symbolic_div.py @@ -2,7 +2,7 @@ import random, sys import z3 from tinygrad.uop.ops import UOp, Ops from tinygrad.uop.validate import uops_to_z3 -from tinygrad.helpers import DEBUG, Context, colored +from tinygrad.helpers import DEBUG, colored seed = int(sys.argv[1]) if len(sys.argv) > 1 else random.randint(0, 100) print(f"Seed: {seed}", flush=True) @@ -36,8 +36,7 @@ if __name__ == "__main__": variable_names += [f"r{i}" for i in range(num_ranges)] expr = get_random_expr(ranges, factors) - with Context(CORRECT_DIVMOD_FOLDING=1): - simplified_expr = expr.simplify() + simplified_expr = expr.simplify() if DEBUG>=1: print(expr.render(simplify=False), " --> ", simplified_expr.render(simplify=False)) diff --git a/test/null/test_simplify_valid_idx.py b/test/null/test_simplify_valid_idx.py index 2174ce5d7c..3d1bbd3ed4 100644 --- a/test/null/test_simplify_valid_idx.py +++ 
b/test/null/test_simplify_valid_idx.py @@ -1,12 +1,18 @@ import unittest, itertools +from tinygrad.codegen.late.devectorizer import load_store_indexing from tinygrad.dtype import dtypes -from tinygrad.uop.ops import UOp, Ops -from tinygrad.uop.symbolic import simplify_valid +from tinygrad.uop.ops import UOp, Ops, graph_rewrite +from tinygrad.uop.symbolic import simplify_valid, sym, pm_move_where_on_load from tinygrad.helpers import Context from test.helpers import full_rewrite from test.null.test_uop_symbolic import check_uop_against_string +# symbolic-only idx + valid simplification (no late lowering of FLOORDIV/FLOORMOD) +def simplify_valid_idx(sink: UOp) -> UOp: return graph_rewrite(sink, sym+pm_move_where_on_load, name="simplify_valid_idx") +# image-aware idx + valid simplification: adds the codegen-layer matcher that drops provably in-bounds gates +def simplify_image_idx(sink: UOp) -> UOp: return graph_rewrite(sink, sym+pm_move_where_on_load+load_store_indexing, name="simplify_image_idx") + def get_gated_load_uop(valid:UOp, idx:UOp): return UOp(Ops.LOAD, dtypes.float, ( UOp(Ops.PARAM, dtypes.float.ptr(), arg=0).index(idx.valid(valid), ptr=True), @@ -47,11 +53,10 @@ class TestHelpers(unittest.TestCase): class TestValidIdxSimplification(unittest.TestCase): def check(self, load, sidx, svalid, extra=()): - with Context(NOOPT=1, SPEC=0): - load = full_rewrite(UOp.sink(load, *extra)).src[0] - idx, valid = load.src[0].src[1], load.src[0].src[2] - check_uop_against_string(self, idx, sidx) - check_uop_against_string(self, valid, svalid) + load = simplify_valid_idx(UOp.sink(load, *extra)).src[0] + off = load.src[0].src[1] + check_uop_against_string(self, off.get_idx(), sidx) + check_uop_against_string(self, off.get_valid(), svalid) def test_cumsum(self): gidx0 = Special("gidx0", 5) @@ -216,18 +221,18 @@ class TestValidIdxSimplification(unittest.TestCase): class TestImageSimplification(unittest.TestCase): def check(self, load, svalid, sidx0, sidx1): - with 
Context(NOOPT=1, SPEC=0): - load = full_rewrite(load.sink()).src[0] - idx = load.src[0].src[1] + load = simplify_image_idx(load.sink()).src[0] + off = load.src[0].src[1] + idx = off.get_idx() self.assertEqual(idx.op, Ops.STACK) self.assertEqual(len(idx.src), 2) idx0, idx1 = idx.src[0], idx.src[1] check_uop_against_string(self, idx0, sidx0) check_uop_against_string(self, idx1, sidx1) if svalid is not None: - check_uop_against_string(self, load.src[0].src[2], svalid) + check_uop_against_string(self, off.get_valid(), svalid) else: - self.assertEqual(len(load.src[0].src), 2, "svalid is None but load still has a valid") + self.assertEqual(off.get_valid(), UOp.const(dtypes.bool, True), "svalid is None but valid is not True") def test_idx_gt_c(self): # (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid @@ -447,12 +452,12 @@ class TestImageSimplification(unittest.TestCase): load = get_load_image_uop((32, 1024, 4), valid, (alu0, alu1)) self.check(load, None, "(lidx1*128+gidx0//2+144)", "(lidx0*2+r0+-3)") - # TODO: this is the same idx as above, but simplifying idx too early makes it hard to drop the valid + # same idx, written without the inline simplification of the inner div/mod alu0 = ((gidx0*2+lidx1*512+(lidx0*8192+r0*4096)+-11711)//4%1024) alu1 = (lidx0*2+r0+-3) valid = ((lidx1<7)&((((lidx0*2+r0)<3)!=1)&((lidx0*2+r0)<35))) load = get_load_image_uop((32, 1024, 4), valid, (alu0, alu1)) - self.check(load, "(lidx1<7)", "((gidx0*2+lidx1*512+(lidx0*8192+r0*4096)+-11711)//4%1024)", "(lidx0*2+r0+-3)") + self.check(load, None, "(lidx1*128+gidx0//2+144)", "(lidx0*2+r0+-3)") def test_simplify8(self): # from openpilot compile3, kernel r_4_16_8_16_4_4_3_3n1 diff --git a/test/null/test_symbolic_failures.py b/test/null/test_symbolic_failures.py index 8587bf2659..9fc6b1cf57 100644 --- a/test/null/test_symbolic_failures.py +++ b/test/null/test_symbolic_failures.py @@ -1,16 +1,8 @@ import unittest from tinygrad import Variable -from tinygrad.helpers import Context class 
TestFuzzFailure(unittest.TestCase): - def setUp(self): - self.context = Context(CORRECT_DIVMOD_FOLDING=1) - self.context.__enter__() - - def tearDown(self): - self.context.__exit__(None, None, None) - def test_fuzz_failure1(self): v1=Variable('v1', 0, 8) v2=Variable('v2', 0, 2) diff --git a/test/null/test_uop_symbolic.py b/test/null/test_uop_symbolic.py index f159a34483..395c48ba00 100644 --- a/test/null/test_uop_symbolic.py +++ b/test/null/test_uop_symbolic.py @@ -3,7 +3,6 @@ import unittest, pickle, functools, math import z3 from tinygrad.dtype import dtypes, ConstType, DType, Invalid -from tinygrad.helpers import Context from test.helpers import get_uops from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid, pm_move_where_on_load @@ -181,8 +180,8 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(Variable("a", 0, 8)*1, 0, 8, "a") def test_mul_neg_1(self): - self.helper_test_variable((Variable("a", 0, 2)*-1)//3, 0, 0, "0") - self.helper_test_variable((Variable("a", 2, 7)*-1)//3, -2, 0, "((a//3)*-1)") + self.helper_test_variable((Variable("a", 0, 2)*-1)//3, -1, 0, "((a*-1)//3)") + self.helper_test_variable((Variable("a", 2, 7)*-1)//3, -3, -1, "((a*-1)//3)") def test_mul_2(self): self.helper_test_variable(Variable("a", 0, 8)*2, 0, 16, "(a*2)") @@ -203,8 +202,8 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(Variable("a", 0, 7) // 20, 0, 0, "0") def test_div_neg_min_max(self): - self.helper_test_variable(Variable("a", 1, 7) // -2, -3, 0, "((a//2)*-1)") - self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "((a//2)*-1)") + self.helper_test_variable(Variable("a", 1, 7) // -2, -4, -1, "(a//-2)") + self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "(a//-2)") def test_div_mod_zero(self): with self.assertRaises(ZeroDivisionError): @@ -238,14 +237,14 @@ class TestSymbolic(unittest.TestCase): def test_mod_min_max(self): 
self.helper_test_variable(Variable("x", 0, 10)%Variable("y", 1, 10), 0, 9, "(x%y)") - self.helper_test_variable(Variable("x", -10, 0)%Variable("y", 1, 10), -9, 0, "(((x*-1)%y)*-1)") - self.helper_test_variable(Variable("x", 0, 10)%Variable("y", -10, -1), 0, 9, "(x%(y*-1))") - self.helper_test_variable(Variable("x", -10, 0)%Variable("y", -10, -1), -9, 0, "(((x*-1)%(y*-1))*-1)") - self.helper_test_variable(Variable("x", -10, 10)%Variable("y", -10, -1), -9, 9, "(x%(y*-1))") + self.helper_test_variable(Variable("x", -10, 0)%Variable("y", 1, 10), 0, 9, "(x%y)") + self.helper_test_variable(Variable("x", 0, 10)%Variable("y", -10, -1), -9, 0, "(x%y)") + self.helper_test_variable(Variable("x", -10, 0)%Variable("y", -10, -1), -9, 0, "(x%y)") + self.helper_test_variable(Variable("x", -10, 10)%Variable("y", -10, -1), -9, 0, "(x%y)") - # test _min_max directly without the rewrite taking out the sign + # test _min_max directly: floor mod with positive divisor is in [0, c-1]; with negative divisor in [c+1, 0] self.assertEqual((Variable("x", -10, 0)%Variable("y", -10, -1))._min_max, (-9, 0)) - self.assertEqual((Variable("x", -10, 0)%Variable("y", 1, 10))._min_max, (-9, 0)) + self.assertEqual((Variable("x", -10, 0)%Variable("y", 1, 10))._min_max, (0, 9)) def test_range_div_its_symbolic_bound(self): a = Variable("a", 1, 10, dtypes.weakint) @@ -262,12 +261,12 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(Variable("a", 0, 6) // 2, 0, 3, "(a//2)") self.helper_test_variable(Variable("x", 0, 10)//Variable("y", 1, 10), 0, 10, "(x//y)") - self.helper_test_variable(Variable("x", -10, 0)//Variable("y", 1, 10), -10, 0, "(((x*-1)//y)*-1)") - self.helper_test_variable(Variable("x", 0, 10)//Variable("y", -10, -1), -10, 0, "((x//(y*-1))*-1)") - self.helper_test_variable(Variable("x", -10, 0)//Variable("y", -10, -1), 0, 10, "((x*-1)//(y*-1))") + self.helper_test_variable(Variable("x", -10, 0)//Variable("y", 1, 10), -10, 0, "(x//y)") + self.helper_test_variable(Variable("x", 
0, 10)//Variable("y", -10, -1), -10, 0, "(x//y)") + self.helper_test_variable(Variable("x", -10, 0)//Variable("y", -10, -1), 0, 10, "(x//y)") self.helper_test_variable(Variable("x", -10, 10)//Variable("y", 1, 10), -10, 10, "(x//y)") - self.helper_test_variable(Variable("x", -10, 10)//Variable("y", -10, -1), -10, 10, "((x//(y*-1))*-1)") + self.helper_test_variable(Variable("x", -10, 10)//Variable("y", -10, -1), -10, 10, "(x//y)") def test_mod_factor(self): self.helper_test_variable(usum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 50, "((b%2)*50)") @@ -334,12 +333,12 @@ class TestSymbolic(unittest.TestCase): def test_mod_mod_wrong_sign(self): v1=Variable("v1", 0, 128) v3=Variable("v3", 0, 7) - self.helper_test_variable((((((v1%2)*2)+((v3+-1)%5))+-2)%5), -3, 4, "(v1%2*2+(v3+-1)%5+-2)") + self.helper_test_variable((((((v1%2)*2)+((v3+-1)%5))+-2)%5), 0, 4, "((v3+v1%2*2+-3)%5)") def test_mod_mod_wrong_sign2(self): v2=Variable("v2", 0, 8) v3=Variable("v3", 0, 4) - self.helper_test_variable((((((v3+3)%7)+(v2+-2))%7)%7), -2, 6, "(((v2+((v3+3)%7))+-2)%7)") + self.helper_test_variable((((((v3+3)%7)+(v2+-2))%7)%7), 0, 6, "((v2+v3+1)%7)") def test_mul_mul(self): self.helper_test_variable((Variable("a", 0, 5)*10)*9, 0, 5*10*9, "(a*90)") @@ -357,21 +356,21 @@ class TestSymbolic(unittest.TestCase): def test_div_const_div(self): a = Variable("a", 0, 124) self.helper_test_variable((a//2+1)//2, 0, 31, "((a+2)//4)") - self.helper_test_variable(((-a)//2-1)//2, -31, 0, "(((a+2)//4)*-1)") - self.helper_test_variable(((-a)//2+10)//2, -26, 5, "((((a//2)*-1)+10)//2)") + self.helper_test_variable(((-a)//2-1)//2, -32, -1, "((a*-1+2)//4+-1)") + self.helper_test_variable(((-a)//2+10)//2, -26, 5, "(a*-1//4+5)") def test_div_const_div_wrong_sign(self): a = Variable("a", 0, 124) - self.helper_test_variable(((a-10)//2+10)//2, 2, 33, "((((a+-10)//2)+10)//2)") + self.helper_test_variable(((a-10)//2+10)//2, 2, 33, "((a+2)//4+2)") def test_div_const_div_wrong_sign_divisor(self): a = 
Variable("a", 0, 124) - self.helper_test_variable(((a+10)//-2+10)//-4, -1, 14, "(((((a//2)*-1)+5)//4)*-1)") + self.helper_test_variable(((a+10)//-2+10)//-4, -2, 14, "(((a+10)//-2+10)//-4)") def test_neg_mod(self): a = Variable("a", 0, 124) - self.helper_test_variable((-a)%4, -3, 0, "((a%4)*-1)") - self.helper_test_variable(a%-4, 0, 3, "(a%4)") + self.helper_test_variable((-a)%4, 0, 3, "(a*-1%4)") + self.helper_test_variable(a%-4, -3, 0, "(a%-4)") def test_distribute_mul(self): self.helper_test_variable(usum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))") @@ -387,11 +386,11 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a") def test_big_mod(self): - self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)") - self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(((a*-1)%10)*-1)") - self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)") + self.helper_test_variable(Variable("a", -20, 20)%10, 0, 9, "(a%10)") + self.helper_test_variable(Variable("a", -20, 0)%10, 0, 9, "(a%10)") + self.helper_test_variable(Variable("a", -20, 1)%10, 0, 9, "(a%10)") self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)") - self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)") + self.helper_test_variable(Variable("a", -1, 20)%10, 0, 9, "(a%10)") def test_ge_remove(self): self.helper_test_variable(Variable("a", 0, 6) >= 25, 0, 0, "False") @@ -439,8 +438,8 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(c & c.logical_not(), False, False, "False") def test_mod_factor_negative(self): - self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, -27, 27, "(((a+(b*28))+-29)%28)") - self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, -27, 27, "(((a+(b*28))+-29)%28)") + self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) 
% 28, 0, 27, "((a+b*28+-29)%28)") + self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+b*28+-29)%28)") def test_sum_combine_num(self): self.helper_test_variable(usum([uconst(29), Variable("a", 0, 10), uconst(-23)]), 6, 16, "(a+6)") @@ -448,22 +447,12 @@ class TestSymbolic(unittest.TestCase): def test_sum_num_hoisted_and_factors_cancel_out(self): self.helper_test_variable(usum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1") - @unittest.expectedFailure # only correct for floordiv, not truncdiv def test_div_cancel(self): self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(b+-1)") - def test_div_cancel_correct(self): - with Context(CORRECT_DIVMOD_FOLDING=1): - self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(((a+(b*20))+-20)//20)") - - @unittest.expectedFailure # only correct for floordiv, not truncdiv def test_mod_cancel(self): self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)") - def test_mod_cancel_correct(self): - with Context(CORRECT_DIVMOD_FOLDING=1): - self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, -38, 38, "((((a+(b*20))+-20)%20)*2)") - def test_mul_div(self): self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a") @@ -475,22 +464,22 @@ class TestSymbolic(unittest.TestCase): lidx1 = UOp.variable("lidx1", 0, 1) ridx1005 = UOp.variable("ridx1005", 0, 2) ridx1006 = UOp.variable("ridx1006", 0, 2) - self.helper_test_variable((lidx1+((gidx1*18)+(ridx1005*18)+(lidx0*162))+(gidx0*2)+(ridx1006*2)+-40)//18, -2, 20, - "(((((lidx1+(((gidx1*18)+(ridx1005*18))+(lidx0*162)))+(gidx0*2))+(ridx1006*2))+-40)//18)") + self.helper_test_variable((lidx1+((gidx1*18)+(ridx1005*18)+(lidx0*162))+(gidx0*2)+(ridx1006*2)+-40)//18, -3, 20, + 
"(gidx1+ridx1005+lidx0*9+(gidx0+ridx1006+7)//9+-3)") def test_add_div(self): # careful about the lower bounds and upper bounds - self.helper_test_variable((Variable("a", 0, 5)-2)//4, 0, 0, "0") - self.helper_test_variable((Variable("a", 0, 5)-1)//4, 0, 1, "((a+-1)//4)") + self.helper_test_variable((Variable("a", 0, 5)-2)//4, -1, 0, "((a+2)//4+-1)") + self.helper_test_variable((Variable("a", 0, 5)-1)//4, -1, 1, "((a+3)//4+-1)") self.helper_test_variable((Variable("a", 0, 5))//4, 0, 1, "(a//4)") self.helper_test_variable((Variable("a", 0, 5)+1)//4, 0, 1, "((a+1)//4)") self.helper_test_variable((Variable("a", 0, 5)+2)//4, 0, 1, "((a+2)//4)") self.helper_test_variable((Variable("a", 0, 5)+3)//4, 0, 2, "((a+3)//4)") - self.helper_test_variable((Variable("a", 0, 5)+4)//4, 1, 2, "((a//4)+1)") - self.helper_test_variable((Variable("a", 0, 5)+5)//4, 1, 2, "(((a+1)//4)+1)") + self.helper_test_variable((Variable("a", 0, 5)+4)//4, 1, 2, "(a//4+1)") + self.helper_test_variable((Variable("a", 0, 5)+5)//4, 1, 2, "((a+1)//4+1)") def test_div_neg_rem(self): - self.helper_test_variable((-Variable("a", 0, 255)+256)//2, 0, 128, "((((a+1)//2)*-1)+128)") + self.helper_test_variable((-Variable("a", 0, 255)+256)//2, 0, 128, "(a*-1//2+128)") def test_mul_div_factor_mul(self): self.helper_test_variable((Variable("a", 0, 10)*8)//4, 0, 20, "(a*2)") @@ -502,7 +491,7 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)") def test_mul_div_factor_div_neg(self): - self.helper_test_variable((Variable("a", 0, 10)*-4+4)//8, -4, 0, "(((a*-1)+1)//2)") + self.helper_test_variable((Variable("a", 0, 10)*-4+4)//8, -5, 0, "((a*-1+1)//2)") def test_div_symbolic_const_gcd(self): a = Variable("a", -10, 10) @@ -520,8 +509,8 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((d1*a*d2*b*d1)//(d1*d2), -1000, 1000, "(a*(b*d1))", test_z3=False) self.helper_test_variable((d1*a + b*d1)//(d1), -20, 20, "(a+b)", test_z3=False) 
self.helper_test_variable((d1*a + b*d1 + c*d1)//(d1), -30, 30, "(c+(a+b))", test_z3=False) - self.helper_test_variable((3*a*d1 + 9*b*d1)//(3*d1*d2), -40, 40, "(((a+(b*3))//(d2*-1))*-1)", test_z3=False) - self.helper_test_variable((3*a*d1 + 9*b*d1+3)//(3*d1*d2), -401, 399, "(((((a*d1)+((b*d1)*3))+1)//((d1*d2)*-1))*-1)", test_z3=False) + self.helper_test_variable((3*a*d1 + 9*b*d1)//(3*d1*d2), -40, 40, "((a+b*3)//d2)", test_z3=False) + self.helper_test_variable((3*a*d1 + 9*b*d1+3)//(3*d1*d2), -401, 399, "((a*d1+b*d1*3+1)//(d1*d2))", test_z3=False) def test_symbolic_factor_remainder_div(self): a = Variable("a", 0, 10) @@ -532,7 +521,7 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((d*a*20+b*d*5+10)//(5*d), 0, 52, "((b+(a*4))+(2//d))") def test_mod_gcd_factor_neg(self): - self.helper_test_variable((Variable("a", 0, 10)*-4+4)%8, -4, 4, "((((a*-1)+1)%2)*4)") + self.helper_test_variable((Variable("a", 0, 10)*-4+4)%8, 0, 4, "((a*-1+1)%2*4)") def test_mod_gcd_fold_neg(self): self.helper_test_variable((Variable("a", 0, 10)*-8+20)%4, 0, 0, "0") @@ -540,22 +529,21 @@ class TestSymbolic(unittest.TestCase): def test_sum_div_partial_remove(self): self.helper_test_variable(usum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0") - def test_cdiv_const_evaluation(self): - self.helper_test_variable((Variable("a", 0, 2)-12)//8, -1, -1, "-1") - self.helper_test_variable((-Variable("a", 0, 2))//7, 0, 0, "0") + def test_floordiv_const_evaluation(self): + self.helper_test_variable((Variable("a", 0, 2)-12)//8, -2, -2, "-2") + self.helper_test_variable((-Variable("a", 0, 2))//7, -1, 0, "(a*-1//7)") - def test_cmod_const_evaluation(self): - self.helper_test_variable((Variable("a", 1, 1)*-3)%8, -3, -3, "-3") - self.helper_test_variable((-Variable("a", 10, 10))%7, -3, -3, "-3") + def test_floormod_const_evaluation(self): + self.helper_test_variable((Variable("a", 1, 1)*-3)%8, 5, 5, "5") + self.helper_test_variable((-Variable("a", 10, 10))%7, 4, 4, "4") 
def test_div_numerator_negative(self): - with Context(CORRECT_DIVMOD_FOLDING=1): - self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -8, 0, "(((idx*10)//11)*-1)") + self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "(idx*-1)") def test_nest_div_negative_factor(self): ridx0=Variable("ridx0", 0, 9) ridx1=Variable("ridx1", 0, 6) - self.helper_test_variable(((((ridx0*-7)+ridx1)+63)//35), 0, 1, "(((ridx0//5)*-1)+1)") + self.helper_test_variable(((((ridx0*-7)+ridx1)+63)//35), 0, 1, "((ridx1+ridx0*-7+28)//35+1)") def test_div_into_mod(self): self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)") @@ -568,11 +556,11 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(x%12//4*4 + x%4 + x//12*12, 0, 23, "x") def test_div_neg_cancel(self): - self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 1, 26, "((idx//4)+1)") - self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "((idx+3)//4)") - self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, 0, 25, "((idx+2)//4)") - self.helper_test_variable((-Variable("idx", 0, 100))//2, -50, 0, "((idx//2)*-1)") - self.helper_test_variable(Variable("idx", 0, 100)//-2, -50, 0, "((idx//2)*-1)") + self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 0, 25, "((idx*-1+199)//-4+50)") + self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "((idx*-1+200)//-4+50)") + self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, -1, 24, "((idx*-1+201)//-4+50)") + self.helper_test_variable((-Variable("idx", 0, 100))//2, -50, 0, "(idx*-1//2)") + self.helper_test_variable(Variable("idx", 0, 100)//-2, -50, 0, "(idx//-2)") def test_sum_div_big_const(self): gidx0 = Variable("gidx0", 0, 24) @@ -647,22 +635,22 @@ class TestSymbolic(unittest.TestCase): def test_div_neg_all_range(self): gidx = Variable("gidx", 0, 124) lidx = Variable("lidx", 0, 7) - self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 1, 250, 
"(((gidx*2)+(lidx//4))+1)") - self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 250, "((gidx*2)+((lidx+3)//4))") - self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, 0, 250, "((gidx*2)+((lidx+2)//4))") - self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, 0, 250, "((gidx*2)+((lidx+1)//4))") + self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 0, 250, "((gidx*-8+lidx*-1+999)//-4+250)") + self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 249, "((gidx*-8+lidx*-1+1000)//-4+250)") + self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, -1, 249, "((gidx*-8+lidx*-1+1001)//-4+250)") + self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, -1, 249, "((gidx*-8+lidx*-1+1002)//-4+250)") def test_div_neg_then_neg(self): # taken from arange opts lidx0 = Variable("lidx0", 0, 7) lidx1 = Variable("lidx1", 0, 7) alu2 = -lidx0-lidx1 - self.helper_test_variable((((alu2+14)//(-32))+4), 4, 4, "4") - self.helper_test_variable(-(((alu2+14)//(-32))+4), -4, -4, "-4") - self.helper_test_variable((((alu2+134)//(-32))+4), 0, 1, "(((lidx0+lidx1)+25)//32)") - self.helper_test_variable((((alu2+142)//(-32))+4), 0, 0, "0") - self.helper_test_variable((((alu2+150)//(-32))+4), 0, 0, "0") - self.helper_test_variable((((alu2+158)//(-32))+4), 0, 0, "0") + self.helper_test_variable((((alu2+14)//(-32))+4), 3, 4, "((lidx0*-1+lidx1*-1+14)//-32+4)") + self.helper_test_variable(-(((alu2+14)//(-32))+4), -4, -3, "((lidx0*-1+lidx1*-1+14)//-32*-1+-4)") + self.helper_test_variable((((alu2+134)//(-32))+4), -1, 0, "((lidx0*-1+lidx1*-1+134)//-32+4)") + self.helper_test_variable((((alu2+142)//(-32))+4), -1, 0, "((lidx0*-1+lidx1*-1+142)//-32+4)") + self.helper_test_variable((((alu2+150)//(-32))+4), -1, -1, "-1") + self.helper_test_variable((((alu2+158)//(-32))+4), -1, -1, "-1") def test_div_mod_recombine(self): gidx = Variable("gidx", 0, 124) @@ -696,7 +684,7 @@ class TestSymbolic(unittest.TestCase): # negative variable range xn = Variable("x", -1000, 1000) 
self.helper_test_variable(xn//3%224*3 + xn%3 + xn//672*672, -1000, 1000, "x") - self.helper_test_variable(xn//3%7*3 + xn//21*21, -999, 999, "(x//3*3)") + self.helper_test_variable(xn//3%7*3 + xn//21*21, -1002, 999, "(x//3*3)") # should NOT simplify: a*c1 != b (3*224 != 600) self.helper_test_variable(gidx//3%224*3 + gidx//600*600, 0, 150669, "(gidx//600*600+gidx//3%224*3)") # should NOT simplify: c1*c2 != c3 (224*3 != 700) @@ -709,7 +697,7 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((30 * b + 1) % 18 + ((30 * b + 1) // 18) * 18, 1, 3001, "((b*30)+1)") def test_div_partial_quotient(self): - # IDIV should extract partial quotients when const_factor > divisor, matching what MOD already does + # FLOORDIV should extract partial quotients when const_factor > divisor, matching what FLOORMOD already does # (f*x+c)//d -> (f%d*x+c)//d + (f//d)*x when f >= d b = Variable("b", 0, 100) self.helper_test_variable((31*b+1)//18, 0, 172, "(((b*13)+1)//18+b)") @@ -730,8 +718,7 @@ class TestSymbolic(unittest.TestCase): def test_div_by_factor_tie_break(self): a = Variable("a", 0, 1) b = Variable("b", 0, 1) - with Context(CORRECT_DIVMOD_FOLDING=1): - self.helper_test_variable((a*2+b*3+2)//6, 0, 1, "((a+b+1)//3)") + self.helper_test_variable((a*2+b*3+2)//6, 0, 1, "((a+b+1)//3)") def test_div_mod_recombine_large_coeff(self): # recombine must work even when coeff > divisor: both mod and div reduce the coeff the same way @@ -741,7 +728,7 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((25*a+3)%10 + ((25*a+3)//10)*10, 3, 253, "((a*25)+3)") def test_mod_nest_by_factor(self): - # (a*f+b) % (f*k) = (a%k)*f + b when 0<=b x0 idx = Variable("idx", 0, 24) self.helper_test_variable((idx//4<3), 0, 1, "(idx<12)") - self.helper_test_variable(((idx-20)//4<-3), 0, 1, "(idx<5)") - self.helper_test_variable(((idx-10)//4<0), 0, 1, "(idx<7)") - self.helper_test_variable((idx//-4<-3), 0, 1, "(((idx//4)*-1)<-3)") + self.helper_test_variable(((idx-20)//4<-3), 0, 1, 
"(idx<8)") + self.helper_test_variable(((idx-10)//4<0), 0, 1, "(idx<10)") + self.helper_test_variable((idx//-4<-3), 0, 1, "((idx//-4)<-3)") + + def test_nested_div_mod_negative_inner_divisor(self): + # (x % (k*c)) // c -> (x // c) % k requires k>0; (x % (k*c)) % c -> x % c is unconditional for c>0 + a = Variable("a", 0, 100) + self.helper_test_variable((a % -8) // 2, -4, 0, "(a%-8//2)") + self.helper_test_variable((a % -8) % 2, 0, 1, "(a%2)") + + def test_floordiv_lt_negative_c(self): + # x//d0 + idx = Variable("idx", -20, 20) + self.helper_test_variable((idx//4 < 0), 0, 1, "(idx<0)") + self.helper_test_variable((idx//4 < -1), 0, 1, "(idx<-4)") + self.helper_test_variable((idx//4 < -2), 0, 1, "(idx<-8)") def test_simplex_lt(self): a = Variable("a", 0, 3) @@ -981,10 +982,10 @@ class TestSymbolic(unittest.TestCase): self.assertIn((a.cast(dtypes.long)*b.cast(dtypes.long)).render(), "(long)((a*b))") def test_nested_mod_negative_range(self): - # (x%(k*c))%c = x%c holds for cmod regardless of signs since sign(x%(k*c)) = sign(x) + # (x%(k*c))%c = x%c for positive c x = Variable("x", 0, 1575) - self.helper_test_variable(((x + (-1064)) % 512) % 4, -3, 3, "((x+-1064)%4)") - self.helper_test_variable(((x + (-1064)) % 512) % 128, -127, 127, "((x+-1064)%128)") + self.helper_test_variable(((x + (-1064)) % 512) % 4, 0, 3, "((x+-1064)%4)") + self.helper_test_variable(((x + (-1064)) % 512) % 128, 0, 127, "((x+-1064)%128)") class TestSymbolicNumeric(unittest.TestCase): def helper_test_numeric(self, f): @@ -1062,12 +1063,13 @@ class TestSymInfer(unittest.TestCase): assert sym_infer(a+b+c, var_vals) == 9 assert sym_infer(a*b, var_vals) == 6 assert sym_infer(a*b+c, var_vals) == 10 - def test_sym_infer_cdiv_cmod(self): + def test_sym_infer_floordiv_floormod(self): a = Variable("a", -1000, 1) b = Variable("b", -1000, 1) var_vals = {a.expr: 1, b.expr: -1000} - assert sym_infer(a%b, var_vals) == 1 - assert sym_infer(a//b, var_vals) == 0 + # floor: 1 % -1000 = -999, 1 // -1000 = -1 + assert 
sym_infer(a%b, var_vals) == -999 + assert sym_infer(a//b, var_vals) == -1 def test_sym_infer_with_bitcast(self): a = Variable("a", 1, 10, dtypes.int) expr = ((a.bitcast(dtypes.uint) << UOp.const(dtypes.uint, 1)).bitcast(dtypes.int) + 2) @@ -1286,7 +1288,8 @@ class TestGatedUopGivenValid(unittest.TestCase): idx:UOp = (r0 < 3).where((r0 + uconst(-1)) // uconst(3), UOp.invalid()) idx = graph_rewrite(idx, pm_simplify_valid) - self.assertEqual(idx, (r0 < 3).where(uconst(0), UOp.invalid())) + # (r0-1)//3 = (r0+2)//3 - 1 (constant offset split) + self.assertEqual(idx, (r0 < 3).where((r0 + uconst(2)) // uconst(3) + uconst(-1), UOp.invalid())) def test_invalid_gate_simplifies_vectorize(self): r0 = Variable("r0", 0, 2) @@ -1295,8 +1298,8 @@ class TestGatedUopGivenValid(unittest.TestCase): idx1 = r0 % uconst(3) idx:UOp = (r0 < 3).where(UOp(Ops.STACK, dtypes.weakint.vec(2), (idx0, idx1)), UOp.invalid()) idx = graph_rewrite(idx, pm_simplify_valid) - # NOTE: independent simplification: (r0-1)//3 -> 0, r0%3 -> r0 when r0 in [0,2] - expected_vec = UOp(Ops.STACK, dtypes.weakint.vec(2), (uconst(0), r0)) + # independent simplification: (r0-1)//3 -> (r0+2)//3 - 1, and r0%3 -> r0 when r0 in [0,2] + expected_vec = UOp(Ops.STACK, dtypes.weakint.vec(2), ((r0 + uconst(2)) // uconst(3) + uconst(-1), r0)) self.assertEqual(idx, (r0 < 3).where(expected_vec, UOp.invalid())) class TestRangeSplitting(unittest.TestCase): @@ -1335,8 +1338,8 @@ class TestBounds(unittest.TestCase): alu0 = gidx0 * -1 assert alu0.vmin == -2559 and alu0.vmax == 0 assert (alu0+2559).vmin == 0 and (alu0+2559).vmax == 2559 - assert ((alu0+2559)//-4).vmin == -639 and ((alu0+2559)//-4).vmax == 0 - assert (((alu0+2559)//-4)*(-1)).vmin == 0 and (((alu0+2559)//-4)*(-1)).vmax == 639 + assert ((alu0+2559)//-4).vmin == -640 and ((alu0+2559)//-4).vmax == 0 + assert (((alu0+2559)//-4)*(-1)).vmin == 0 and (((alu0+2559)//-4)*(-1)).vmax == 640 class TestFuzzFailure(unittest.TestCase): def test_fuzz_failure1(self): diff --git 
a/test/null/test_uop_vmin_vmax.py b/test/null/test_uop_vmin_vmax.py index 583326c2b2..0e4c1ed9b9 100644 --- a/test/null/test_uop_vmin_vmax.py +++ b/test/null/test_uop_vmin_vmax.py @@ -173,17 +173,15 @@ class TestVminVmaxDivMod(unittest.TestCase): self.assertEqual(uop.vmax, 10) def test_vmin_vmax_division_negative(self): - # vmin and vmax for division of a variable by a negative constant - # always positive + # floor division of a variable by a negative constant x = UOp.variable('x', 10, 20) uop = x // -2 self.assertEqual(uop.vmin, -10) self.assertEqual(uop.vmax, -5) uop = x // -3 - self.assertEqual(uop.vmin, -6) - self.assertEqual(uop.vmax, -3) + self.assertEqual(uop.vmin, -7) + self.assertEqual(uop.vmax, -4) - # always negative x = UOp.variable('x', -20, -10) uop = x // -2 self.assertEqual(uop.vmin, 5) @@ -193,7 +191,6 @@ class TestVminVmaxDivMod(unittest.TestCase): self.assertEqual(uop.vmax, 6) def test_vmin_vmax_floordiv_floormod(self): - # FLOORDIV/FLOORMOD ranges differ from IDIV/MOD when the dividend can be negative x = UOp.variable('x', -7, 7) floordiv = x.alu(Ops.FLOORDIV, x.const_like(3)) self.assertEqual(floordiv.vmin, -3) @@ -212,32 +209,42 @@ class TestVminVmaxDivMod(unittest.TestCase): self.assertEqual(uop.vmin, -5) self.assertEqual(uop.vmax, 5) uop = x // -3 - self.assertEqual(uop.vmin, -3) + self.assertEqual(uop.vmin, -4) self.assertEqual(uop.vmax, 3) + def test_vmin_vmax_floordiv_floormod_empty_range(self): + # empty numerator range (vmin > vmax, e.g. 
RANGE with end=0) short-circuits to (0, 0) + rng = UOp.range(0, 0) + self.assertEqual(rng.vmin, 0) + self.assertEqual(rng.vmax, -1) + self.assertEqual((rng // 4).vmin, 0) + self.assertEqual((rng // 4).vmax, 0) + self.assertEqual((rng % 4).vmin, 0) + self.assertEqual((rng % 4).vmax, 0) + def test_vmin_vmax_div_symbolic(self): x = UOp.variable('x', 1, 10) y = UOp.variable('y', 3, 5) self.assertEqual((x//y).vmin, 0) self.assertEqual((x//y).vmax, 3) - self.assertEqual(((-x)//y).vmin, -3) - self.assertEqual(((-x)//y).vmax, 0) - self.assertEqual((x//(-y)).vmin, -3) - self.assertEqual((x//(-y)).vmax, 0) + self.assertEqual(((-x)//y).vmin, -4) + self.assertEqual(((-x)//y).vmax, -1) + self.assertEqual((x//(-y)).vmin, -4) + self.assertEqual((x//(-y)).vmax, -1) self.assertEqual(((-x)//(-y)).vmin, 0) self.assertEqual(((-x)//(-y)).vmax, 3) self.assertEqual((100//y).vmin, 20) self.assertEqual((100//y).vmax, 33) - self.assertEqual(((-100)//y).vmin, -33) + self.assertEqual(((-100)//y).vmin, -34) self.assertEqual(((-100)//y).vmax, -20) - self.assertEqual((100//(-y)).vmin, -33) + self.assertEqual((100//(-y)).vmin, -34) self.assertEqual((100//(-y)).vmax, -20) self.assertEqual(((-100)//(-y)).vmin, 20) self.assertEqual(((-100)//(-y)).vmax, 33) def test_vmin_vmax_mod_positive(self): - # vmin and vmax for modulo of a variable by a positive constant + # floor mod with positive divisor: result in [0, c-1] regardless of dividend sign positive = UOp.variable('positive', 10, 20) uop = positive % 3 self.assertEqual(uop.vmin, 0) @@ -245,20 +252,20 @@ class TestVminVmaxDivMod(unittest.TestCase): negative = UOp.variable('negative', -20, -10) uop = negative % 3 - self.assertEqual(uop.vmin, -2) - self.assertEqual(uop.vmax, 0) + self.assertEqual(uop.vmin, 0) + self.assertEqual(uop.vmax, 2) mixed = UOp.variable('mixed', -20, 20) uop = mixed % 3 - self.assertEqual(uop.vmin, -2) + self.assertEqual(uop.vmin, 0) self.assertEqual(uop.vmax, 2) def test_vmin_vmax_mod_negative(self): - # vmin and vmax for 
modulo of a variable by a negative constant + # floor mod with negative divisor: result in [c+1, 0] regardless of dividend sign positive = UOp.variable('positive', 10, 20) uop = positive % -3 - self.assertEqual(uop.vmin, 0) - self.assertEqual(uop.vmax, 2) + self.assertEqual(uop.vmin, -2) + self.assertEqual(uop.vmax, 0) negative = UOp.variable('negative', -20, -10) uop = negative % -3 @@ -268,7 +275,7 @@ class TestVminVmaxDivMod(unittest.TestCase): mixed = UOp.variable('mixed', -20, 20) uop = mixed % -3 self.assertEqual(uop.vmin, -2) - self.assertEqual(uop.vmax, 2) + self.assertEqual(uop.vmax, 0) class TestVminVmaxVConst(unittest.TestCase): def test_vmin_vmax_vconst_single_element(self): diff --git a/test/null/test_uops.py b/test/null/test_uops.py index 28e0784ff5..3c5043bfb5 100644 --- a/test/null/test_uops.py +++ b/test/null/test_uops.py @@ -177,6 +177,18 @@ class TestFastIdiv(unittest.TestCase): self.assertIn(Ops.SHR, ops, f"For dtype={dt} divison by power of two did not simplify to shift") self.assertNotIn(Ops.IDIV, ops, f"For dtype={dt} divison by power of two did not simplify to shift") + def test_floordiv_power_of_two_uint(self): + # uint FLOORDIV by a power of two lowers to a shift, leaving no IDIV/FLOORDIV in the kernel + for dt in (dtypes.uint32, dtypes.uint64): + g = UOp(Ops.PARAM, dt.ptr(), (), 0) + c = UOp.const(dt, 2) + a = UOp(Ops.FLOORDIV, dt, (g.index(c), c)) + uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer) + ops = [x.op for x in uops] + self.assertIn(Ops.SHR, ops, f"For dtype={dt} FLOORDIV by power of two did not simplify to shift") + self.assertNotIn(Ops.IDIV, ops, f"For dtype={dt} FLOORDIV by power of two did not simplify to shift") + self.assertNotIn(Ops.FLOORDIV, ops, f"For dtype={dt} FLOORDIV survived past late rewrite") + @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't support long") def test_fast_idiv_and_mod(self): g = UOp(Ops.PARAM, dtypes.uint32.ptr(), (), 0) diff --git a/tinygrad/codegen/simplify.py 
b/tinygrad/codegen/simplify.py index 145071c312..29ca29f403 100644 --- a/tinygrad/codegen/simplify.py +++ b/tinygrad/codegen/simplify.py @@ -17,7 +17,8 @@ pm_flatten_range = PatternMatcher([ (UPat((Ops.REDUCE, Ops.END), name="r"), flatten_range), ]) -def count_divmod(x:UOp) -> int: return sum(u.op in {Ops.IDIV, Ops.MOD} for u in x.backward_slice) +# index/range arithmetic uses FLOORDIV/FLOORMOD prior to late rewrite +def count_divmod(x:UOp) -> int: return sum(u.op in {Ops.FLOORDIV, Ops.FLOORMOD} for u in x.backward_slice) def simplify_merge_adjacent(u:UOp) -> UOp|None: reduce_ranges = [x.ranges for x in u.backward_slice_with_self if x.op is Ops.REDUCE] # on END we only want to merge adjacent ranges, on REDUCE we want to try all combinations diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index bbe84e5602..a699eba15f 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -241,7 +241,7 @@ SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), Contex RING, ALL2ALL, ALLREDUCE_CAST = ContextVar("RING", 1), ContextVar("ALL2ALL", 0), ContextVar("ALLREDUCE_CAST", 1) CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1) VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0) -CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0) +FUSE_OPTIM = ContextVar("FUSE_OPTIM", 0) ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0) MAX_KERNEL_BUFFERS = ContextVar("MAX_KERNEL_BUFFERS", 0) EMULATED_DTYPES = ContextVar("EMULATED_DTYPES", "") diff --git a/tinygrad/mixin/elementwise.py b/tinygrad/mixin/elementwise.py index 1f55b0afd5..6a7ec473de 100644 --- a/tinygrad/mixin/elementwise.py +++ b/tinygrad/mixin/elementwise.py @@ -181,7 +181,7 @@ class ElementwiseMixin(DTypeMixin, CreationMixin): return 
self._binop(Ops.IDIV, x, reverse) def mod(self, x: Self | ConstType, reverse: bool = False) -> Self: - return self._binop(Ops.MOD, x, reverse) + return self._binop(Ops.FLOORMOD, x, reverse) def div(self, x: Self | ConstType, reverse: bool = False) -> Self: lhs, rhs = self._broadcasted(x, reverse) @@ -206,7 +206,7 @@ class ElementwiseMixin(DTypeMixin, CreationMixin): return self.div(x) def __floordiv__(self, x: Self | ConstType) -> Self: - return self.idiv(x) # TODO: idiv is trunc div, not floordiv + return self._binop(Ops.FLOORDIV, x, False) def __mod__(self, x: Self | ConstType) -> Self: return self.mod(x) @@ -233,7 +233,7 @@ class ElementwiseMixin(DTypeMixin, CreationMixin): return self.div(x, True) def __rfloordiv__(self, x: Self | ConstType) -> Self: - return self.idiv(x, True) + return self._binop(Ops.FLOORDIV, x, True) def __rand__(self, x: Self | ConstType) -> Self: return self.bitwise_and(x, True) diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py index 435f1cfcba..37d3cc7759 100644 --- a/tinygrad/uop/decompositions.py +++ b/tinygrad/uop/decompositions.py @@ -290,8 +290,10 @@ def fast_idiv(target: Target, x: UOp, d: int, dont_cast=False) -> UOp|None: if m*vmin >= x.dtype.min and m*vmax <= x.dtype.max: return ((x*m) >> s) if is_unsigned else ((x*m) >> s) + (x<0).where(x.ufix(1), 0) # before we try casting to a larger dtype (slow), we see if there are powers of two in d we can shift to make x smaller + # use explicit Ops.IDIV (trunc) since the recursion assumes trunc semantics throughout if (largest_factor_of_two_in_d := (d & -d)) > 1: - if (ret:=fast_idiv(target, x//largest_factor_of_two_in_d, d//largest_factor_of_two_in_d, dont_cast=True)) is not None: return ret + if (ret:=fast_idiv(target, x.alu(Ops.IDIV, x.const_like(largest_factor_of_two_in_d)), + d//largest_factor_of_two_in_d, dont_cast=True)) is not None: return ret if dont_cast: return None # promo_lattice needs to return an unsigned type if the type is unsigned if 
dtypes.is_int(next_dtype := promo_lattice[x.dtype.scalar()][-1]) and is_dtype_supported(next_dtype, target): @@ -459,22 +461,30 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> Pa if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32)) # MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends) if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0]))) - # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1) - # TODO: drop the x.vmin>=0 guard once UOp `%` lowers to FLOORMOD instead of MOD + # rewrite FLOORMOD to AND on power-of-2 const: x % (2**y) -> x & (2**y-1) (correct floor mod for any sign in two's complement) if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), - lambda x,c: x & (c.arg-1) if c.arg in powers_of_two and x.vmin >= 0 else None)] + lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)] if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(), lambda x,y: (x | y).logical_not())] # rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y) if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)] if Ops.SHR in ops: - # no reason to check x<0 for uints - pat += [(UPat.var("x", dtypes.uints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) else None)] - pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: (x+(l.const_like(l.vmin) if (l:=(x<0)).vmin==l.vmax else l).where( - c-1, 0)) >> v if (v:=powers_of_two.get(c.arg, 0)) else None)] # (x+(x<0).where(c-1, 0)) >> v + # uint floor==trunc, so safe for both ops + pat += [(UPat((Ops.IDIV, Ops.FLOORDIV), src=(UPat.var("x", 
dtypes.uints), UPat.cvar("c"))), + lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) else None)] + # signed ints: the pattern below matches raw Ops.IDIV only; signed FLOORDIV is not matched by it + # signed IDIV (trunc) by 2**v -> (x + (x<0 ? c-1 : 0)) >> v; only correct for trunc, so match raw Ops.IDIV + pat += [(UPat(Ops.IDIV, src=(UPat.var("x", dtypes.ints), UPat.cvar("c"))), + lambda x,c: (x+(l.const_like(l.vmin) if (l:=(x<0)).vmin==l.vmax else l).where(c-1, 0)) >> v + if (v:=powers_of_two.get(c.arg, 0)) else None)] if not disable_fast_idiv: - pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d", vec=False), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))] - pat += [(UPat.var("x", dtypes.ints)%UPat.var("d"), lambda x, d: x-d*(x//d))] + # fast_idiv handles non-pow2: only fire on non-negative inputs (signed magic-mul is unreliable for x<0) + pat += [(UPat(Ops.IDIV, src=(UPat.var("x", dtypes.ints), UPat.cvar("d", vec=False))), + lambda ctx, x, d: fast_idiv(ctx, x, d.arg) if x.vmin >= 0 or x.dtype in dtypes.uints else None)] + # rewrite raw MOD -> x - d*IDIV(x,d) so fast_idiv can pick up the IDIV. 
only on non-negative inputs; + # avoids disturbing floormod_to_mod's general-path output (which uses a trunc Ops.MOD as an implementation detail) + pat += [(UPat(Ops.MOD, src=(UPat.var("x", dtypes.ints), UPat.var("d"))), + lambda x, d: x - d * x.alu(Ops.IDIV, d) if x.vmin >= 0 or x.dtype in dtypes.uints else None)] if Ops.NEG in ops: pat += [(UPat.var('x')*-1, lambda ctx,x: x.alu(Ops.NEG))] if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda ctx,x,y: x.alu(Ops.SUB, y))] diff --git a/tinygrad/uop/divandmod.py b/tinygrad/uop/divandmod.py index 92d697ae5c..f8012195d2 100644 --- a/tinygrad/uop/divandmod.py +++ b/tinygrad/uop/divandmod.py @@ -1,19 +1,19 @@ import functools, itertools, math from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp from tinygrad.dtype import dtypes -from tinygrad.helpers import cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap +from tinygrad.helpers import floordiv, floormod, unwrap # NOTE: this cache is only on index UOps @functools.cache -def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None: +def fold_divmod_general(d: UOp) -> UOp|None: x, y = d.src # cancel_divmod: simple cancel div/mod case when the range of the numerator lies within a single denominator interval x_min, x_max, y_min, y_max = x.vmin, x.vmax, y.vmin, y.vmax assert isinstance(x_min, int) and isinstance(x_max, int) and isinstance(y_min, int) and isinstance(y_max, int) - if y_min==y_max==0: raise ZeroDivisionError(f"{'Division' if d.op is Ops.IDIV else 'Mod'} by zero trying to rewrite {x.alu(d.op, y)}") - if y_min*y_max > 0 and (qv:=cdiv(x_min,y_min)) == cdiv(x_min,y_max) == cdiv(x_max,y_min) == cdiv(x_max,y_max): - return x - qv*y if d.op is Ops.MOD else d.const_like(qv) + if y_min==y_max==0: raise ZeroDivisionError(f"{'Division' if d.op is Ops.FLOORDIV else 'Mod'} by zero trying to rewrite {x.alu(d.op, y)}") + if y_min*y_max > 0 and (qv:=floordiv(x_min,y_min)) == floordiv(x_min,y_max) == floordiv(x_max,y_min) == 
floordiv(x_max,y_max): + return x - qv*y if d.op is Ops.FLOORMOD else d.const_like(qv) # split uops for the rest of the processing x_peeled, const = x.pop_const() @@ -22,19 +22,20 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None: # ** Constant Denominator Rules ** # these rules strictly require y to be a scalar constant > 0 if y.op is Ops.CONST and (c := y.arg) > 0: - # nested_div_mod: (x%(k*c))//c -> (x//c)%k, and (x%(k*c))%c -> x%c - if x.op is Ops.MOD and (k := x.src[1].divides(c)) is not None: - return x.src[0] // y % k if d.op is Ops.IDIV else x.src[0] % y + # nested_div_mod: (x%(k*c))//c -> (x//c)%k (requires k>0), and (x%(k*c))%c -> x%c + if x.op is Ops.FLOORMOD and (k := x.src[1].divides(c)) is not None: + if d.op is Ops.FLOORMOD: return x.src[0] % y + if k > 0: return x.src[0] // y % k - # remove_nested_mod in sum: (a%4 + b)%2 -> (a+b)%2, requires non-negative sums - if d.op is Ops.MOD and x.vmin >= 0: + # remove_nested_mod in sum: (a%4 + b)%2 -> (a+b)%2 + if d.op is Ops.FLOORMOD: new_xs, changed = [], False for u in uops_no_const: - if u.op is Ops.MOD and u.src[1].divides(c) is not None: + if u.op is Ops.FLOORMOD and u.src[1].divides(c) is not None: u = u.src[0] changed = True new_xs.append(u) - if changed and (new_x:=(UOp.usum(*new_xs) + const)).vmin >= 0: return new_x % y + if changed: return (UOp.usum(*new_xs) + const) % y # Shared decomposition for folding rules decomp = [(u.divides(f:=u.const_factor()),f) for u in uops_no_const] @@ -42,32 +43,31 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None: # fold_binary_numerator: fold if expression has one non-constant term that takes on two values if len(terms)==1 and (v:=terms[0]).vmax-v.vmin == 1: - y1 = (cmod if d.op is Ops.MOD else cdiv)(factors[0]*v.vmin+const, c) - y2 = (cmod if d.op is Ops.MOD else cdiv)(factors[0]*v.vmax+const, c) + y1 = (floormod if d.op is Ops.FLOORMOD else floordiv)(factors[0]*v.vmin+const, c) + y2 = (floormod if d.op is 
Ops.FLOORMOD else floordiv)(factors[0]*v.vmax+const, c) return (y2-y1)*(v-v.vmin) + y1 # fold_divmod_congruence: fold if a is congruent to an expression whose range is between 0 and c - if not (x.vmin<0 and correct_divmod_folding): - # when f%c == c//2, abs(r) == abs(r-c) is a tie, try both signs since either may fit in one period - rem_choices = [(r, r-c) if (r:=f%c)*2 == c else (min(r, r-c, key=abs),) for f in factors] - for rems in itertools.product(*rem_choices): - if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c==rem.vmax//c: - if d.op is Ops.MOD: return rem - rem.vmin//c*c - return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + const//c + rem.vmin//c + # when f%c == c//2, abs(r) == abs(r-c) is a tie, try both signs since either may fit in one period + rem_choices = [(r, r-c) if (r:=f%c)*2 == c else (min(r, r-c, key=abs),) for f in factors] + for rems in itertools.product(*rem_choices): + if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c==rem.vmax//c: + if d.op is Ops.FLOORMOD: return rem - rem.vmin//c*c + return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + const//c + rem.vmin//c # gcd_with_remainder: factor out common gcd from numerator if x.vmin >= 0 and (g:=math.gcd(*factors, c)) > 1: new_x = unwrap(x_peeled.divides(g)).simplify() + (const//g)%(c//g) if new_x.vmin >= 0: - if d.op is Ops.MOD: return new_x % (c//g) * g + const%g + if d.op is Ops.FLOORMOD: return new_x % (c//g) * g + const%g return new_x // (c//g) + const//c # nest_by_factor: x//c -> (x//f)//(c//f), x%c -> (x//f%(c//f))*f + b where b=x%f if x.vmin >= 0: results = [] for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and 1 < abs(f) < c and (c%f)==0}: - if (newxs := fold_divmod_general(x//div, correct_divmod_folding)) is not None and newxs.vmin >= 0: - if d.op is Ops.IDIV: + if (newxs := fold_divmod_general(x//div)) is not None and newxs.vmin >= 0: + if d.op is Ops.FLOORDIV: 
results.append((len(newxs.backward_slice), newxs // (c // div))) else: b_parts = [f%div*t for f, t in zip(factors, terms) if f%div] @@ -86,7 +86,7 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None: gcd = UOp.gcd(*all_uops, y).simplify() if not (gcd.op is Ops.CONST and gcd.arg==1): ret = unwrap(x.divide_exact(gcd)).alu(d.op, unwrap(y.divide_exact(gcd))) - return ret*gcd if d.op is Ops.MOD else ret + return ret*gcd if d.op is Ops.FLOORMOD else ret # factor_remainder: (d*x+y)//d -> x+y//d if y.vmin<0 or x.vmin<0: return None @@ -95,29 +95,22 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None: if (q:=u.divide_exact(y)) is not None: quo.append(q) elif y.op is Ops.CONST and (c:=u.const_factor())%y.arg!=c: rem.append(u.divides(c)*(c%y.arg)) - quo.append(u.divides(c)*(c//y.arg) if d.op is Ops.IDIV else u.const_like(0)) + quo.append(u.divides(c)*(c//y.arg) if d.op is Ops.FLOORDIV else u.const_like(0)) else: rem.append(u) if not quo: return None new_x = sum(rem)+x.const_like(0) if new_x.vmin<0: return None - return new_x%y if d.op is Ops.MOD else new_x//y+sum(quo) + return new_x%y if d.op is Ops.FLOORMOD else new_x//y+sum(quo) div_and_mod_symbolic = PatternMatcher([ # ** 1. 
Fast Inline Rules ** - ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d) - if c.vmin>0 and d.vmin>0 and x.vmin>=0 and a.vmin>=0 else None), # (x//c+a)//d -> (x+a*c)//(c*d) - (UPat.var("x", dtypes.weakint) // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None), - (UPat.var("x", dtypes.weakint) // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <= 0 else None), - ((UPat.var("x", dtypes.weakint)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False), - lambda x,c,n,d: ((x+c.arg%d.arg)//d + c.arg//d.arg) if c.arg%d.arg!=c.arg and x.vmin>=0 and n.vmin>=0 and d.arg>0 else None), - ((UPat.var("x", dtypes.weakint)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False), - lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None), + # (x//c+a)//d -> (x+a*c)//(c*d) for c>0, d>0 + ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d) if c.vmin>0 and d.vmin>0 else None), + # (x+c)//d -> (x+c%d)//d + c//d for d>0 (split out the multiple of d in the constant) + ((UPat.var("x", dtypes.weakint)+UPat.cvar("c", vec=False))//UPat.cvar("d", vec=False), + lambda x,c,d: (x+c.arg%d.arg)//d + c.arg//d.arg if c.arg%d.arg!=c.arg and d.arg>0 else None), # ** 2. 
Slow Rules ** - (UPat((Ops.IDIV, Ops.MOD), dtypes.weakint, name="d"), lambda d: fold_divmod_general(d, bool(CORRECT_DIVMOD_FOLDING))), - - # NOTE: these have to go at the bottom or TestSymbolicOps.test_var loops - (UPat.var("x", dtypes.weakint) % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None), - (UPat.var("x", dtypes.weakint) % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None), -]) \ No newline at end of file + (UPat((Ops.FLOORDIV, Ops.FLOORMOD), dtypes.weakint, name="d"), lambda d: fold_divmod_general(d)), +]) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 042879772c..8a1ad13bec 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -869,9 +869,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass): return min(vals:=(cdiv(s0_vmin, s1_vmin), cdiv(s0_vmin, s1_vmax), cdiv(s0_vmax, s1_vmin), cdiv(s0_vmax, s1_vmax))), max(vals) if self.op is Ops.FLOORDIV: assert isinstance(s0_vmin, int) and isinstance(s0_vmax, int) and isinstance(s1_vmin, int) and isinstance(s1_vmax, int) + if s0_vmin > s0_vmax: return 0, 0 # numerator range is empty (e.g. RANGE with end=0) if s1_vmin*s1_vmax>0: return min(vals:=(s0_vmin//s1_vmin, s0_vmin//s1_vmax, s0_vmax//s1_vmin, s0_vmax//s1_vmax)), max(vals) if self.op is Ops.FLOORMOD: assert isinstance(s0_vmin, int) and isinstance(s0_vmax, int) and isinstance(s1_vmin, int) and isinstance(s1_vmax, int) + if s0_vmin > s0_vmax: return 0, 0 # numerator range is empty (e.g. 
RANGE with end=0) if (c:=s1_vmin) == s1_vmax > 0: return (s0_vmin%c, s0_vmax%c) if s0_vmin//c == s0_vmax//c else (0, c-1) if (c:=s1_vmin) == s1_vmax < 0: return (s0_vmin%c, s0_vmax%c) if s0_vmin//c == s0_vmax//c else (c+1, 0) if s1_vmin > 0: return (0, s1_vmax-1) @@ -906,7 +908,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): # TODO: sanitize varnames, or don't use naked eval while staying fast ret = _render_with_splits(list(sself.toposort()), renderer_infer, {sself}) lines = [f" {k}={v}" for k,v in ret.items() if k != "ast"] + [f" return {ret['ast']}"] - ns: dict[str, Any] = {"max": max, "cdiv": cdiv, "cmod": cmod, "bitcast": bitcast, "dtypes": dtypes} + ns: dict[str, Any] = {"max": max, "cdiv": cdiv, "cmod": cmod, "floordiv": floordiv, "floormod": floormod, "bitcast": bitcast, "dtypes": dtypes} exec(f"def _f({','.join(varnames)}):\n"+'\n'.join(lines), ns) # pylint: disable=exec-used return ns["_f"], varnames diff --git a/tinygrad/uop/render.py b/tinygrad/uop/render.py index fc37e0e0d0..7b3c58458c 100644 --- a/tinygrad/uop/render.py +++ b/tinygrad/uop/render.py @@ -23,10 +23,10 @@ def print_uops(uops:list[UOp]): print(f"{i:4d} {str(u.op):20s}: {multirange_str(u.ranges, color=True, pad=10)} {str(u.dtype):40s} " f"{str(formatted_srcs):32s} {u.arg}") # for debug -syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>", +syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.FLOORDIV: "//", Ops.FLOORMOD: "%", Ops.SHL: "<<", Ops.SHR: ">>", Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"} # comparison operators are not in here because they are chained in python, not left-associative -precedence = {Ops.MUL:1, Ops.IDIV:1, Ops.MOD:1, Ops.ADD:2, Ops.SUB:2, Ops.SHL:3, Ops.SHR:3, Ops.AND:4, Ops.XOR:5, Ops.OR:6} +precedence = {Ops.MUL:1, Ops.FLOORDIV:1, Ops.FLOORMOD:1, Ops.ADD:2, Ops.SUB:2, Ops.SHL:3, Ops.SHR:3, Ops.AND:4, Ops.XOR:5, Ops.OR:6} def strip_binary_parens(x:UOp, left:str, right:str, code_for_op) -> 
str: if x.op not in precedence: return code_for_op(left, right) return code_for_op(strip_parens(left) if precedence.get(x.src[0].op,99)<=precedence[x.op] else left, strip_parens(right) if @@ -46,6 +46,8 @@ renderer = PatternMatcher([ (UPat(Ops.MAX, name="x"), lambda ctx,x: f"max({ctx[x.src[0]]}, {ctx[x.src[1]]})"), (UPat(Ops.MULACC, name="x"), lambda ctx,x: f"({ctx[x.src[0]]}*{ctx[x.src[1]]}+{ctx[x.src[2]]})"), (UPat(Ops.WHERE, name="x"), lambda ctx,x: f"({ctx[x.src[1]]} if {ctx[x.src[0]]} else {ctx[x.src[2]]})"), + (UPat(Ops.IDIV, name="x"), lambda ctx,x: f"cdiv({ctx[x.src[0]]}, {ctx[x.src[1]]})"), + (UPat(Ops.MOD, name="x"), lambda ctx,x: f"cmod({ctx[x.src[0]]}, {ctx[x.src[1]]})"), (UPat(set(syms.keys()), name="x"), lambda ctx,x: strip_binary_parens(x, ctx[x.src[0]], ctx[x.src[1]], lambda a,b: f"({a}{syms[x.op]}{b})")), (UPat((Ops.INDEX, Ops.BUFFERIZE), name="x"), lambda x, ctx: ''.join([f"[{strip_parens(ctx[y])}]" for y in x.src[1:]])), (UPat(Ops.STACK, name="x"), @@ -56,6 +58,8 @@ renderer = PatternMatcher([ renderer_infer = PatternMatcher([ (UPat(Ops.MOD, name="x"), lambda ctx,x: f"cmod({ctx[x.src[0]]}, {ctx[x.src[1]]})"), (UPat(Ops.IDIV, name="x"), lambda ctx,x: f"cdiv({ctx[x.src[0]]}, {ctx[x.src[1]]})"), + (UPat(Ops.FLOORMOD, name="x"), lambda ctx,x: f"floormod({ctx[x.src[0]]}, {ctx[x.src[1]]})"), + (UPat(Ops.FLOORDIV, name="x"), lambda ctx,x: f"floordiv({ctx[x.src[0]]}, {ctx[x.src[1]]})"), (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast({ctx[x.src[0]]}, {x.src[0].dtype!r}, {x.dtype!r})"), ]) + renderer @@ -99,13 +103,16 @@ pm_pyrender_extra = PatternMatcher([ # TODO: movement ops simplify stuff, this can break SPEC=2 #(UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({render_marg(ctx,x)})"), # NOTE: CMPNE doesn't work cause there's no __rne__ + # explicit trunc ops: `//` and `%` parse as FLOORDIV/FLOORMOD, so render IDIV/MOD via their named methods + (UPat(Ops.IDIV, name="x"), lambda ctx,x: 
f"{ctx[x.src[0]]}.idiv({ctx[x.src[1]]})"), + (UPat(Ops.MOD, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.alu(Ops.MOD, {ctx[x.src[1]]})"), # NOTE: only match CONSTs without UNIQUE (len(src)==1), unique_const needs explicit rendering - (UPat(set(syms.keys())-{Ops.SUB, Ops.CMPNE}, src=(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="y"), UPat(name="z")), name="x"), + (UPat(set(syms.keys())-{Ops.SUB, Ops.CMPNE, Ops.IDIV, Ops.MOD}, src=(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="y"), UPat(name="z")), name="x"), lambda ctx,x,y,z: strip_binary_parens(x, str(y.arg), ctx[z], lambda a,b: f"({a}{syms[x.op]}{b})")), # NOTE: sub doesn't work cause it's written as add/mul - (UPat(set(syms.keys())-{Ops.SUB}, src=(UPat(name="y"), UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="z")), name="x"), lambda ctx,x,y,z: - strip_binary_parens(x, ctx[y], str(z.arg), lambda a,b: f"({a}{syms[x.op]}{b})")), - (UPat(set(syms.keys())-{Ops.SUB}, name="x"), lambda ctx,x: + (UPat(set(syms.keys())-{Ops.SUB, Ops.IDIV, Ops.MOD}, src=(UPat(name="y"), UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="z")), name="x"), + lambda ctx,x,y,z: strip_binary_parens(x, ctx[y], str(z.arg), lambda a,b: f"({a}{syms[x.op]}{b})")), + (UPat(set(syms.keys())-{Ops.SUB, Ops.IDIV, Ops.MOD}, name="x"), lambda ctx,x: strip_binary_parens(x, ctx[x.src[0]], ctx[x.src[1]], lambda a,b: f"({a}{syms[x.op]}{b})")), (UPat(sugar, src=(), name="x"), lambda x: f"UOp.{x.op.name.lower()}("+', '.join(([f'arg={repr(x.arg)}'] if x.arg is not None else []))+")"), (UPat(sugar, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}("+', '.join([ctx[y] for y in x.src[1:]] + \ diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index beb8346695..105b5759c1 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -28,8 +28,8 @@ invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat) def fold_add_divmod_recombine(x:UOp) -> UOp|None: terms = list(x.split_uop(Ops.ADD)) for i,u in enumerate(terms): - if u.op 
is Ops.MOD and u.src[1].op is Ops.CONST: base, div, mul = u.src[0], u.src[1].arg, 1 - elif u.op is Ops.MUL and u.src[1].op is Ops.CONST and (m:=u.src[0]).op is Ops.MOD and m.src[1].op is Ops.CONST: + if u.op is Ops.FLOORMOD and u.src[1].op is Ops.CONST: base, div, mul = u.src[0], u.src[1].arg, 1 + elif u.op is Ops.MUL and u.src[1].op is Ops.CONST and (m:=u.src[0]).op is Ops.FLOORMOD and m.src[1].op is Ops.CONST: base, div, mul = m.src[0], m.src[1].arg, u.src[1].arg else: continue for j,v in enumerate(terms): @@ -37,13 +37,13 @@ def fold_add_divmod_recombine(x:UOp) -> UOp|None: if v.op is not Ops.MUL or v.src[1].op is not Ops.CONST or v.src[1].arg != div*mul: continue q, exact = v.src[0], False # (base%div)*mul + (base//div)*(div*mul) -> base*mul - if q.op is Ops.IDIV and q.src[1].op is Ops.CONST and q.src[1].arg == div: exact = q.src[0] is base + if q.op is Ops.FLOORDIV and q.src[1].op is Ops.CONST and q.src[1].arg == div: exact = q.src[0] is base # ((base//d)%div)*mul + (base//(d*div))*(div*mul) -> (base//d)*mul - if not exact and base.op is Ops.IDIV and base.src[1].op is Ops.CONST: - exact = q.op is Ops.IDIV and q.src[1].op is Ops.CONST and q.src[0] is base.src[0] and q.src[1].arg == base.src[1].arg*div + if not exact and base.op is Ops.FLOORDIV and base.src[1].op is Ops.CONST: + exact = q.op is Ops.FLOORDIV and q.src[1].op is Ops.CONST and q.src[0] is base.src[0] and q.src[1].arg == base.src[1].arg*div if exact: return (base*mul).usum(*[t for k,t in enumerate(terms) if k not in (i,j)]) # ((base//div)%d)*div + base%div -> base%(div*d) - if mul == 1 and div > 0 and q.op is Ops.MOD and q.src[1].op is Ops.CONST and (d:=q.src[1].arg) > 0 and q.src[0].op is Ops.IDIV: + if mul == 1 and div > 0 and q.op is Ops.FLOORMOD and q.src[1].op is Ops.CONST and (d:=q.src[1].arg) > 0 and q.src[0].op is Ops.FLOORDIV: if q.src[0].src[0] is base and q.src[0].src[1].op is Ops.CONST and q.src[0].src[1].arg == div: return (base % (div*d)).usum(*[t for k,t in enumerate(terms) if k not in 
(i,j)]) return None @@ -244,7 +244,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ ((UPat.var("y")+UPat.var("c").where(UPat.var("t"), UPat.var("f"))) + UPat.var("c").where(UPat.var("tt"), UPat.var("ff")), \ lambda y,c,t,tt,f,ff: y+c.where(t+tt, f+ff) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None), # ALU/variable min==max -> CONST - (UPat({Ops.CMPLT, Ops.CMPNE, Ops.IDIV, Ops.MOD, Ops.DEFINE_VAR, Ops.BIND, Ops.SPECIAL}, name="x"), + (UPat({Ops.CMPLT, Ops.CMPNE, Ops.FLOORDIV, Ops.FLOORMOD, Ops.DEFINE_VAR, Ops.BIND, Ops.SPECIAL}, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), (UPat(Ops.RANGE, src=(UPat(Ops.CONST,)), name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), # max folding @@ -263,9 +263,9 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ # c0*x x0 ((UPat.var("x", dtype=dtypes.weakint)//UPat.cvar("d", vec=False)) 0 else x<(c.arg*d.arg-(d.arg-1))) if d.arg > 0 else None), + lambda x,d,c: x<(c.arg*d.arg) if d.arg > 0 else None), # ** move add/mul consts to end (NOTE: this is still happening before constant folding) ** ((UPat.var("x") + UPat.cvar("c1")) + UPat.var("y"), lambda x,c1,y: (x+y)+c1), ((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1), @@ -408,7 +408,7 @@ pm_move_where_on_load = PatternMatcher([ def gated_given_valid(cond:UOp, x:UOp, i:UOp) -> UOp|None: if x.dtype.scalar() is not dtypes.weakint: return None # Skip if x contains DIV/MOD AND IMAGE mode is enabled -> image index e.g. openpilot - if IMAGE.value > 0 and x.op_in_backward_slice_with_self(Ops.IDIV, Ops.MOD): return None + if IMAGE.value > 0 and x.op_in_backward_slice_with_self(Ops.IDIV, Ops.MOD, Ops.FLOORDIV, Ops.FLOORMOD): return None return cond.where(uop_given_valid(cond, x, try_simplex=False), i) # TODO: this is O(number of WHERE * number of node)