mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
Fix calculation of vmin and vmax in multiplication when one src is negative and the other src has negative min and positive max (#7333)
Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -38,6 +38,13 @@ class TestVminVmaxProperties(unittest.TestCase):
|
||||
self.assertEqual(uop.vmin, -15)
|
||||
self.assertEqual(uop.vmax, -6)
|
||||
|
||||
def test_vmin_vmax_with_negative_multiplication2(self):
|
||||
# vmin and vmax when multiplying by a negative number
|
||||
x = UOp.variable('x', -2, 5)
|
||||
uop = x * -3
|
||||
self.assertEqual(uop.vmin, -15)
|
||||
self.assertEqual(uop.vmax, 6)
|
||||
|
||||
def test_vmin_vmax_nested_min_max(self):
|
||||
# vmin and vmax with nested min/max operations
|
||||
x = UOp.variable('x', 0, 10)
|
||||
|
||||
@@ -380,14 +380,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if self.op is UOps.ALU and not dtypes.is_float(self.dtype):
|
||||
s0,s1,s2 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(3)]
|
||||
if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax
|
||||
if self.arg is BinaryOps.MUL:
|
||||
# both are non-positive
|
||||
if (s0.vmax <= 0 and s1.vmax <= 0): return s0.vmax*s1.vmax, s0.vmin*s1.vmin
|
||||
# at least one is non-negative
|
||||
if (s0.vmin >= 0 or s1.vmin >= 0):
|
||||
Lmin, Lmax = (s0.vmin, s0.vmax) if s1.vmin >= 0 else (s0.vmax, s0.vmin)
|
||||
Rmin, Rmax = (s1.vmin, s1.vmax) if s0.vmin >= 0 else (s1.vmax, s1.vmin)
|
||||
return Lmin*Rmin, Lmax*Rmax
|
||||
if self.arg is BinaryOps.MUL: return min(vals:=(s0.vmin*s1.vmin, s0.vmin*s1.vmax, s0.vmax*s1.vmin, s0.vmax*s1.vmax)), max(vals)
|
||||
if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1
|
||||
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
|
||||
if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg
|
||||
|
||||
Reference in New Issue
Block a user