diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index 749d245d2f..a73348ad38 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -122,7 +122,8 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(Variable("a", 0, 7) // 2, 0, 3, "(a//2)") def test_div_neg_min_max(self): - self.helper_test_variable(Variable("a", 0, 7) // -2, -3, 0, "((a//2)*-1)") + self.helper_test_variable(Variable("a", 0, 7) // -2, -4, 0, "((((a*-1)+8)//2)+-4)") + self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "((((a*-1)+6)//2)+-3)") def test_sum_div_min_max(self): self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)") diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index 1e44edb874..62924a1880 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -56,7 +56,7 @@ class Node: if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node raise RuntimeError(f"not supported: {self} // {b}") assert b != 0 - if b < 0: return (self//-b)*-1 + if b < 0: return (self*-1)//-b if b == 1: return self # the numerator of div is not allowed to be negative