mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
fix divmod recombine for floordiv (#16062)
This commit is contained in:
@@ -745,6 +745,11 @@ class TestSymbolic(unittest.TestCase):
|
||||
a = Variable("a", 0, 10)
|
||||
self.helper_test_variable((25*a+3)%10 + ((25*a+3)//10)*10, 3, 253, "((a*25)+3)")
|
||||
|
||||
def test_div_mod_recombine_negative_div_unsound(self):
|
||||
# ((b//d)%div)*mul + (b//(d*div))*(div*mul) only equals (b//d)*mul when div>0
|
||||
b = Variable("b", -100, 100)
|
||||
self.helper_test_variable(((b//(-3))%(-2)) + (b//6)*(-2), -33, 34, "(b//6*-2+b//-3%-2)")
|
||||
|
||||
def test_mod_nest_by_factor(self):
|
||||
# (a*f+b) % (f*k) = (a%k)*f + b when 0<=b<f — mirrors nest_div_by_factor for FLOORMOD
|
||||
gidx0 = Variable("gidx0", 0, 15)
|
||||
|
||||
@@ -38,8 +38,8 @@ def fold_add_divmod_recombine(x:UOp) -> UOp|None:
|
||||
q, exact = v.src[0], False
|
||||
# (base%div)*mul + (base//div)*(div*mul) -> base*mul
|
||||
if q.op is Ops.FLOORDIV and q.src[1].op is Ops.CONST and q.src[1].arg == div: exact = q.src[0] is base
|
||||
# ((base//d)%div)*mul + (base//(d*div))*(div*mul) -> (base//d)*mul
|
||||
if not exact and base.op is Ops.FLOORDIV and base.src[1].op is Ops.CONST:
|
||||
# ((base//d)%div)*mul + (base//(d*div))*(div*mul) -> (base//d)*mul if div>0
|
||||
if not exact and div > 0 and base.op is Ops.FLOORDIV and base.src[1].op is Ops.CONST:
|
||||
exact = q.op is Ops.FLOORDIV and q.src[1].op is Ops.CONST and q.src[0] is base.src[0] and q.src[1].arg == base.src[1].arg*div
|
||||
if exact: return (base*mul).usum(*[t for k,t in enumerate(terms) if k not in (i,j)])
|
||||
# ((base//div)%d)*div + base%div -> base%(div*d)
|
||||
|
||||
Reference in New Issue
Block a user