From 0082300a596690e6333abe90d8a2093e7cb1d970 Mon Sep 17 00:00:00 2001 From: Patrick Tsai <5304405+patosai@users.noreply.github.com> Date: Sun, 3 Mar 2024 11:40:52 -0500 Subject: [PATCH] Fix symbolic negative floordiv (#3594) Co-authored-by: Patrick Tsai --- test/unit/test_symbolic.py | 3 ++- tinygrad/shape/symbolic.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index 749d245d2f..a73348ad38 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -122,7 +122,8 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(Variable("a", 0, 7) // 2, 0, 3, "(a//2)") def test_div_neg_min_max(self): - self.helper_test_variable(Variable("a", 0, 7) // -2, -3, 0, "((a//2)*-1)") + self.helper_test_variable(Variable("a", 0, 7) // -2, -4, 0, "((((a*-1)+8)//2)+-4)") + self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "((((a*-1)+6)//2)+-3)") def test_sum_div_min_max(self): self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)") diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index 1e44edb874..62924a1880 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -56,7 +56,7 @@ class Node: if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node raise RuntimeError(f"not supported: {self} // {b}") assert b != 0 - if b < 0: return (self//-b)*-1 + if b < 0: return (self*-1)//-b if b == 1: return self # the numerator of div is not allowed to be negative