mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
UOp simple div folding (#5740)
made UOp.divides return the Optional[quotient] and used it for simple div folding
This commit is contained in:
@@ -188,7 +188,6 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_sum_div_min_max(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_sum_div_factor(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))")
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ def float4_expand_load(load, buf, ex, idx=UOp.const(dtypes.int, 0), idx2=None, i
|
||||
if buf.dtype != PtrDType(dtypes.float) and buf.dtype != PtrDType(dtypes.half) and not isinstance(buf.dtype, ImageDType): return None
|
||||
if idx2 is not None: idx = idx + idx2
|
||||
if idx3 is not None: idx = idx + idx3
|
||||
if not idx.divides(len(ex.src)): return None
|
||||
if idx.divides(len(ex.src)) is None: return None
|
||||
|
||||
if load.dtype.scalar() != load.dtype: return None # how does this happen?
|
||||
vec_load = UOp(UOps.LOAD, load.dtype.vec(len(ex.src)), (buf, idx))
|
||||
@@ -46,7 +46,7 @@ def float4_contract_store(buf, ex, var, store_allow_any_len, idx=UOp.const(dtype
|
||||
if buf.dtype != PtrDType(dtypes.float) and buf.dtype != PtrDType(dtypes.half) and not isinstance(buf.dtype, ImageDType): return None
|
||||
if idx2 is not None: idx = idx + idx2
|
||||
if idx3 is not None: idx = idx + idx3
|
||||
if not idx.divides(len(ex.src)): return None
|
||||
if idx.divides(len(ex.src)) is None: return None
|
||||
|
||||
new_var = UOp(UOps.CONTRACT, var.dtype.vec(len(ex.src)), (var,), ((ex.arg[0][0],len(ex.src)),))
|
||||
return UOp(UOps.STORE, None, (buf, idx, new_var) + store_allow_any_len.src[3:])
|
||||
@@ -251,6 +251,7 @@ constant_folder = PatternMatcher([
|
||||
# *** rules from symbolic ***
|
||||
# div folding
|
||||
(NOp.var('x') // NOp.cvar('c'), lambda x,c: x.const(x.vmin.arg//c.arg) if c.arg > 0 and x.vmin.arg//c.arg == x.vmax.arg//c.arg else None),
|
||||
(NOp.var('x') // NOp.cvar('c'), lambda x,c: d if c.arg > 0 and (d:=x.divides(c.arg)) is not None else None),
|
||||
# mod folding
|
||||
(NOp.var('x') % NOp.cvar('c'), lambda x,c: x if 0 <= x.vmin.arg <= x.vmax.arg < c.arg else None),
|
||||
# mod reduction
|
||||
|
||||
@@ -86,13 +86,14 @@ class UOp:
|
||||
@property # parents with self
|
||||
def sparents(self) -> Set[UOp]: return set([self]).union(self.parents)
|
||||
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
|
||||
def divides(self, v):
|
||||
if self.op is UOps.CONST:
|
||||
return self.arg%v == 0
|
||||
def divides(self, v) -> Optional[UOp]:
|
||||
if self.op is UOps.CONST: return self.const(self.arg//v) if self.arg%v == 0 else None
|
||||
if self.op is UOps.ALU:
|
||||
if self.arg is BinaryOps.ADD: return all(x.divides(v) for x in self.src)
|
||||
if self.arg is BinaryOps.MUL: return any(x.divides(v) for x in self.src)
|
||||
return False # generic false if we aren't sure
|
||||
if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
|
||||
if self.arg is BinaryOps.MUL:
|
||||
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
|
||||
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
|
||||
return None # generic None if we aren't sure
|
||||
@functools.cached_property
|
||||
def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None else self.const(dtypes.min(cast(DType, self.dtype)))
|
||||
@functools.cached_property
|
||||
|
||||
Reference in New Issue
Block a user