diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index cc643c4994..e0bbdad2d8 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -193,6 +193,17 @@ class TestSymbolic(unittest.TestCase): def test_sum_div_no_factor(self): self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)") + def test_mod_min_max(self): + self.helper_test_variable(Variable("x", 0, 10)%Variable("y", 1, 10), 0, 9, "(x%y)") + self.helper_test_variable(Variable("x", -10, 0)%Variable("y", 1, 10), -9, 0, "(((x*-1)%y)*-1)") + self.helper_test_variable(Variable("x", 0, 10)%Variable("y", -10, -1), 0, 9, "(x%y)") + self.helper_test_variable(Variable("x", -10, 0)%Variable("y", -10, -1), -9, 0, "(((x*-1)%y)*-1)") + self.helper_test_variable(Variable("x", -10, 10)%Variable("y", -10, -1), -9, 9, "(x%y)") + + # test _min_max directly without the rewrite taking out the sign + self.assertEqual((Variable("x", -10, 0)%Variable("y", -10, -1))._min_max, (-9, 0)) + self.assertEqual((Variable("x", -10, 0)%Variable("y", 1, 10))._min_max, (-9, 0)) + def test_mod_factor(self): self.helper_test_variable(usum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 50, "((b%2)*50)") diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 25d8192e8e..accc7ac163 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -466,7 +466,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2] if self.op is Ops.MOD: if s1_vmin > 0: return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), 0) if s0_vmax <= 0 else (-(s1_vmax-1), s1_vmax-1) - if s1_vmax < 0: return (0, -s1_vmax-1) if s0_vmin >= 0 else (-(-s1_vmax-1), 0) if s0_vmax <= 0 else (-(-s1_vmax-1), -s1_vmax-1) + if s1_vmax < 0: return (0, -s1_vmin-1) if s0_vmin >= 0 else (-(-s1_vmin-1), 0) if s0_vmax <= 0 else (-(-s1_vmin-1), -s1_vmin-1) if self.op is Ops.IDIV: assert isinstance(s0_vmin, int) and isinstance(s0_vmax, int) and isinstance(s1_vmin, int) and isinstance(s1_vmax, int) if (c:=s1_vmin) == s1_vmax: # s1 is a const