mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
fix min_max and add test (#10952)
This commit is contained in:
@@ -193,6 +193,17 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_sum_div_no_factor(self):
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
|
||||
|
||||
def test_mod_min_max(self):
|
||||
self.helper_test_variable(Variable("x", 0, 10)%Variable("y", 1, 10), 0, 9, "(x%y)")
|
||||
self.helper_test_variable(Variable("x", -10, 0)%Variable("y", 1, 10), -9, 0, "(((x*-1)%y)*-1)")
|
||||
self.helper_test_variable(Variable("x", 0, 10)%Variable("y", -10, -1), 0, 9, "(x%y)")
|
||||
self.helper_test_variable(Variable("x", -10, 0)%Variable("y", -10, -1), -9, 0, "(((x*-1)%y)*-1)")
|
||||
self.helper_test_variable(Variable("x", -10, 10)%Variable("y", -10, -1), -9, 9, "(x%y)")
|
||||
|
||||
# test _min_max directly without the rewrite taking out the sign
|
||||
self.assertEqual((Variable("x", -10, 0)%Variable("y", -10, -1))._min_max, (-9, 0))
|
||||
self.assertEqual((Variable("x", -10, 0)%Variable("y", 1, 10))._min_max, (-9, 0))
|
||||
|
||||
def test_mod_factor(self):
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 50, "((b%2)*50)")
|
||||
|
||||
|
||||
@@ -466,7 +466,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2]
|
||||
if self.op is Ops.MOD:
|
||||
if s1_vmin > 0: return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), 0) if s0_vmax <= 0 else (-(s1_vmax-1), s1_vmax-1)
|
||||
if s1_vmax < 0: return (0, -s1_vmax-1) if s0_vmin >= 0 else (-(-s1_vmax-1), 0) if s0_vmax <= 0 else (-(-s1_vmax-1), -s1_vmax-1)
|
||||
if s1_vmax < 0: return (0, -s1_vmin-1) if s0_vmin >= 0 else (-(-s1_vmin-1), 0) if s0_vmax <= 0 else (-(-s1_vmin-1), -s1_vmin-1)
|
||||
if self.op is Ops.IDIV:
|
||||
assert isinstance(s0_vmin, int) and isinstance(s0_vmax, int) and isinstance(s1_vmin, int) and isinstance(s1_vmax, int)
|
||||
if (c:=s1_vmin) == s1_vmax: # s1 is a const
|
||||
|
||||
Reference in New Issue
Block a user