From 1d89c018fa6ddcc8cf76227f5fbe53fe9f9fbfcb Mon Sep 17 00:00:00 2001 From: Paul Gustafson Date: Sun, 26 Nov 2023 13:05:04 -0800 Subject: [PATCH] Add isinstance check before gcd call in SumNode.__lt__ (#2450) * Add isinstance check before gcd call * Delete blank lines * Fix unit test typo * Delete blank lines again --------- Co-authored-by: Paul Gustafson --- test/unit/test_symbolic.py | 9 ++++++++- tinygrad/shape/symbolic.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index 4d393f8a3a..83a3460085 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -394,6 +394,14 @@ class TestSymbolicSymbolicOps(unittest.TestCase): assert a < 3 and (a < 3).min == 0 and (a < 3).max == 1 assert a > 3 and (a > 3).min == 0 and (a > 3).max == 1 + def test_sumnode_mulnode_lt(self): + a = Variable("a", 1, 2) + b = Variable("b", 1, 2) + c = Variable("c", 1, 2) + x = SumNode([MulNode(a, b), c]) + assert isinstance((x < 3), Node) and (x < 3) == 0 + assert isinstance((x < 4), LtNode) and (x < 4).min == 0 and (x < 4).max == 1 + def test_num_node_mul_node(self): a = Variable("a", 1, 5) b = NumNode(2) * a @@ -448,6 +456,5 @@ class TestSymbolicSymbolicOps(unittest.TestCase): c = b.substitute({a: NumNode(1)}) assert c == NumNode(2) - if __name__ == '__main__': unittest.main() diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index 3530c6aa98..d3b84cce6b 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -290,7 +290,7 @@ class SumNode(RedNode): if muls: # NOTE: gcd in python 3.8 takes exactly 2 args mul_gcd = b - for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell x.b is int here + for x in muls: mul_gcd = gcd(mul_gcd, x.b) if isinstance(x.b, int) else 1 all_others = Variable.sum(others) if all_others.min >= 0 and all_others.max < mul_gcd: lhs, b = Variable.sum([mul//mul_gcd for mul in muls]), b//mul_gcd