From 75dcd98e7975c55fe4b8636e41c9b142e9daec1f Mon Sep 17 00:00:00 2001 From: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com> Date: Mon, 28 Oct 2024 21:01:46 +0100 Subject: [PATCH] Fix calculation of vmin and vmax in multiplication when one src is negative and the other src has negative min and positive max (#7333) Co-authored-by: chenyu --- test/unit/test_uop_vmin_vmax.py | 7 +++++++ tinygrad/ops.py | 9 +-------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/unit/test_uop_vmin_vmax.py b/test/unit/test_uop_vmin_vmax.py index fbac2895ae..442ed6bef5 100644 --- a/test/unit/test_uop_vmin_vmax.py +++ b/test/unit/test_uop_vmin_vmax.py @@ -38,6 +38,13 @@ class TestVminVmaxProperties(unittest.TestCase): self.assertEqual(uop.vmin, -15) self.assertEqual(uop.vmax, -6) + def test_vmin_vmax_with_negative_multiplication2(self): + # vmin and vmax when multiplying by a negative number + x = UOp.variable('x', -2, 5) + uop = x * -3 + self.assertEqual(uop.vmin, -15) + self.assertEqual(uop.vmax, 6) + def test_vmin_vmax_nested_min_max(self): # vmin and vmax with nested min/max operations x = UOp.variable('x', 0, 10) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index dd9bae658b..bc267799ee 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -380,14 +380,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is UOps.ALU and not dtypes.is_float(self.dtype): s0,s1,s2 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(3)] if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax - if self.arg is BinaryOps.MUL: - # both are non-positive - if (s0.vmax <= 0 and s1.vmax <= 0): return s0.vmax*s1.vmax, s0.vmin*s1.vmin - # at least one is non-negative - if (s0.vmin >= 0 or s1.vmin >= 0): - Lmin, Lmax = (s0.vmin, s0.vmax) if s1.vmin >= 0 else (s0.vmax, s0.vmin) - Rmin, Rmax = (s1.vmin, s1.vmax) if s0.vmin >= 0 else (s1.vmax, s1.vmin) - return Lmin*Rmin, Lmax*Rmax + if self.arg is BinaryOps.MUL: return min(vals:=(s0.vmin*s1.vmin, s0.vmin*s1.vmax, s0.vmax*s1.vmin, s0.vmax*s1.vmax)), max(vals) if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1 if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST: if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg