diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index bbb082725b..d9188fbaab 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -188,7 +188,6 @@ class TestSymbolic(unittest.TestCase): def test_sum_div_min_max(self): self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)") - @unittest.expectedFailure def test_sum_div_factor(self): self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))") diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index b60300eacc..b061950305 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -34,7 +34,7 @@ def float4_expand_load(load, buf, ex, idx=UOp.const(dtypes.int, 0), idx2=None, i if buf.dtype != PtrDType(dtypes.float) and buf.dtype != PtrDType(dtypes.half) and not isinstance(buf.dtype, ImageDType): return None if idx2 is not None: idx = idx + idx2 if idx3 is not None: idx = idx + idx3 - if not idx.divides(len(ex.src)): return None + if idx.divides(len(ex.src)) is None: return None if load.dtype.scalar() != load.dtype: return None # how does this happen? vec_load = UOp(UOps.LOAD, load.dtype.vec(len(ex.src)), (buf, idx)) @@ -46,7 +46,7 @@ def float4_contract_store(buf, ex, var, store_allow_any_len, idx=UOp.const(dtype if buf.dtype != PtrDType(dtypes.float) and buf.dtype != PtrDType(dtypes.half) and not isinstance(buf.dtype, ImageDType): return None if idx2 is not None: idx = idx + idx2 if idx3 is not None: idx = idx + idx3 - if not idx.divides(len(ex.src)): return None + if idx.divides(len(ex.src)) is None: return None new_var = UOp(UOps.CONTRACT, var.dtype.vec(len(ex.src)), (var,), ((ex.arg[0][0],len(ex.src)),)) return UOp(UOps.STORE, None, (buf, idx, new_var) + store_allow_any_len.src[3:]) @@ -251,6 +251,7 @@ constant_folder = PatternMatcher([ # *** rules from symbolic *** # div folding (NOp.var('x') // NOp.cvar('c'), lambda x,c: x.const(x.vmin.arg//c.arg) if c.arg > 0 and x.vmin.arg//c.arg == x.vmax.arg//c.arg else None), + (NOp.var('x') // NOp.cvar('c'), lambda x,c: d if c.arg > 0 and (d:=x.divides(c.arg)) is not None else None), # mod folding (NOp.var('x') % NOp.cvar('c'), lambda x,c: x if 0 <= x.vmin.arg <= x.vmax.arg < c.arg else None), # mod reduction diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index fea7b9ffb9..322d13a426 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -86,13 +86,14 @@ class UOp: @property # parents with self def sparents(self) -> Set[UOp]: return set([self]).union(self.parents) def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR]) - def divides(self, v): - if self.op is UOps.CONST: - return self.arg%v == 0 + def divides(self, v) -> Optional[UOp]: + if self.op is UOps.CONST: return self.const(self.arg//v) if self.arg%v == 0 else None if self.op is UOps.ALU: - if self.arg is BinaryOps.ADD: return all(x.divides(v) for x in self.src) - if self.arg is BinaryOps.MUL: return any(x.divides(v) for x in self.src) - return False # generic false if we aren't sure + if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None + if self.arg is BinaryOps.MUL: + if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1] + if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1 + return None # generic None if we aren't sure @functools.cached_property def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None else self.const(dtypes.min(cast(DType, self.dtype))) @functools.cached_property