UOp simple div folding (#5740)

Made `UOp.divides` return `Optional[UOp]` (the quotient when the division is exact, `None` otherwise) and used it for simple div folding.
This commit is contained in:
chenyu
2024-07-26 17:14:32 -04:00
committed by GitHub
parent 671259417f
commit dc7483ee6f
3 changed files with 10 additions and 9 deletions

View File

@@ -188,7 +188,6 @@ class TestSymbolic(unittest.TestCase):
def test_sum_div_min_max(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
@unittest.expectedFailure
def test_sum_div_factor(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))")

View File

@@ -34,7 +34,7 @@ def float4_expand_load(load, buf, ex, idx=UOp.const(dtypes.int, 0), idx2=None, i
if buf.dtype != PtrDType(dtypes.float) and buf.dtype != PtrDType(dtypes.half) and not isinstance(buf.dtype, ImageDType): return None
if idx2 is not None: idx = idx + idx2
if idx3 is not None: idx = idx + idx3
if not idx.divides(len(ex.src)): return None
if idx.divides(len(ex.src)) is None: return None
if load.dtype.scalar() != load.dtype: return None # how does this happen?
vec_load = UOp(UOps.LOAD, load.dtype.vec(len(ex.src)), (buf, idx))
@@ -46,7 +46,7 @@ def float4_contract_store(buf, ex, var, store_allow_any_len, idx=UOp.const(dtype
if buf.dtype != PtrDType(dtypes.float) and buf.dtype != PtrDType(dtypes.half) and not isinstance(buf.dtype, ImageDType): return None
if idx2 is not None: idx = idx + idx2
if idx3 is not None: idx = idx + idx3
if not idx.divides(len(ex.src)): return None
if idx.divides(len(ex.src)) is None: return None
new_var = UOp(UOps.CONTRACT, var.dtype.vec(len(ex.src)), (var,), ((ex.arg[0][0],len(ex.src)),))
return UOp(UOps.STORE, None, (buf, idx, new_var) + store_allow_any_len.src[3:])
@@ -251,6 +251,7 @@ constant_folder = PatternMatcher([
# *** rules from symbolic ***
# div folding
(NOp.var('x') // NOp.cvar('c'), lambda x,c: x.const(x.vmin.arg//c.arg) if c.arg > 0 and x.vmin.arg//c.arg == x.vmax.arg//c.arg else None),
(NOp.var('x') // NOp.cvar('c'), lambda x,c: d if c.arg > 0 and (d:=x.divides(c.arg)) is not None else None),
# mod folding
(NOp.var('x') % NOp.cvar('c'), lambda x,c: x if 0 <= x.vmin.arg <= x.vmax.arg < c.arg else None),
# mod reduction

View File

@@ -86,13 +86,14 @@ class UOp:
@property # parents with self
def sparents(self) -> Set[UOp]: return set([self]).union(self.parents)
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
def divides(self, v):
if self.op is UOps.CONST:
return self.arg%v == 0
def divides(self, v) -> Optional[UOp]:
if self.op is UOps.CONST: return self.const(self.arg//v) if self.arg%v == 0 else None
if self.op is UOps.ALU:
if self.arg is BinaryOps.ADD: return all(x.divides(v) for x in self.src)
if self.arg is BinaryOps.MUL: return any(x.divides(v) for x in self.src)
return False # generic false if we aren't sure
if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
if self.arg is BinaryOps.MUL:
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
return None # generic None if we aren't sure
@functools.cached_property
def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None else self.const(dtypes.min(cast(DType, self.dtype)))
@functools.cached_property