Fix calculation of vmin and vmax in multiplication when one src is negative and the other src has negative min and positive max (#7333)

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
Sieds Lykles
2024-10-28 21:01:46 +01:00
committed by GitHub
parent 603fcc96f2
commit 75dcd98e79
2 changed files with 8 additions and 8 deletions

View File

@@ -38,6 +38,13 @@ class TestVminVmaxProperties(unittest.TestCase):
self.assertEqual(uop.vmin, -15)
self.assertEqual(uop.vmax, -6)
def test_vmin_vmax_with_negative_multiplication2(self):
# vmin and vmax when multiplying by a negative number
x = UOp.variable('x', -2, 5)
uop = x * -3
self.assertEqual(uop.vmin, -15)
self.assertEqual(uop.vmax, 6)
def test_vmin_vmax_nested_min_max(self):
# vmin and vmax with nested min/max operations
x = UOp.variable('x', 0, 10)

View File

@@ -380,14 +380,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is UOps.ALU and not dtypes.is_float(self.dtype):
s0,s1,s2 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(3)]
if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax
if self.arg is BinaryOps.MUL:
# both are non-positive
if (s0.vmax <= 0 and s1.vmax <= 0): return s0.vmax*s1.vmax, s0.vmin*s1.vmin
# at least one is non-negative
if (s0.vmin >= 0 or s1.vmin >= 0):
Lmin, Lmax = (s0.vmin, s0.vmax) if s1.vmin >= 0 else (s0.vmax, s0.vmin)
Rmin, Rmax = (s1.vmin, s1.vmax) if s0.vmin >= 0 else (s1.vmax, s1.vmin)
return Lmin*Rmin, Lmax*Rmax
if self.arg is BinaryOps.MUL: return min(vals:=(s0.vmin*s1.vmin, s0.vmin*s1.vmax, s0.vmax*s1.vmin, s0.vmax*s1.vmax)), max(vals)
if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg