From c398f2467cfd41ccae4436acc98fe0d423f03078 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 28 Oct 2024 17:52:01 -0400 Subject: [PATCH] test uop mul min/max do not have nan in 0*inf (#7340) --- test/unit/test_uop_vmin_vmax.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/test/unit/test_uop_vmin_vmax.py b/test/unit/test_uop_vmin_vmax.py index 442ed6bef5..222f4f31a5 100644 --- a/test/unit/test_uop_vmin_vmax.py +++ b/test/unit/test_uop_vmin_vmax.py @@ -1,5 +1,6 @@ -import unittest -from tinygrad.ops import UOp, dtypes +import unittest, math +from tinygrad.ops import UOp, UOps +from tinygrad.dtype import dtypes class TestVminVmaxProperties(unittest.TestCase): def test_vmin_vmax_constant(self): @@ -31,6 +32,15 @@ class TestVminVmaxProperties(unittest.TestCase): self.assertEqual(uop.vmin, -6) self.assertEqual(uop.vmax, 8) + def test_vmin_vmax_multiplication_0_inf(self): + # vmin and vmax for multiplication with a variable + x = UOp.const(dtypes.float, 0.0) + y = UOp.load(UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0), UOp.const(dtypes.int, 0), dtype=dtypes.float) + uop = x * y + # TODO: these should be 0, but definitely should not be nan + self.assertEqual(uop.vmin, -math.inf) + self.assertEqual(uop.vmax, math.inf) + def test_vmin_vmax_with_negative_multiplication(self): # vmin and vmax when multiplying by a negative number x = UOp.variable('x', 2, 5)