mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
order flipper as *normal* rewrite rule (#7300)
* instant isn't actually used [pr] * order flipper as *normal* rewrite rule * fix inf loop * need simplify now
This commit is contained in:
@@ -169,11 +169,11 @@ class TestGraphRewrite(unittest.TestCase):
|
||||
def test_commutative_work(self):
|
||||
a = UOp.variable('a', 0, 1)
|
||||
b = UOp.variable('b', 0, 1)
|
||||
self.assertIs(a+b, b+a)
|
||||
self.assertIs((a+b).simplify(), (b+a).simplify())
|
||||
|
||||
def test_consts_go_last_right_away(self):
|
||||
a = UOp.variable('a', 0, 1)
|
||||
tst = 2+a
|
||||
tst = (2+a).simplify()
|
||||
self.assertIs(tst.src[0], a)
|
||||
self.assertIs(tst.src[1], a.const_like(2))
|
||||
|
||||
|
||||
@@ -192,8 +192,6 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
|
||||
class UOpMetaClass(type):
|
||||
ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
|
||||
def __call__(cls, op:UOps, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
|
||||
if op is UOps.ALU and arg in COMMUTATIVE and (src[0] is not src[1] and src[1].tuplize < src[0].tuplize and src[0].op is not UOps.NOOP):
|
||||
src = src[::-1]
|
||||
if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
|
||||
UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg)
|
||||
return ret
|
||||
@@ -954,7 +952,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
|
||||
|
||||
for candidate in candidates:
|
||||
# if every branch in candidate gives the same simplified uop, we can rewrite the uop
|
||||
newuops = [graph_rewrite(uop.substitute({X:newX}), symbolic_flat).substitute({newX:X}) for X,newX in candidate]
|
||||
newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
|
||||
if uop.op is UOps.VECTORIZE and len(uop.src) == 2:
|
||||
if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
|
||||
if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
|
||||
@@ -1001,6 +999,9 @@ symbolic = PatternMatcher([
|
||||
# ** constant folding **
|
||||
(UPat(UOps.ALU, name="root", src=UPat((UOps.VCONST, UOps.CONST))),
|
||||
lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))),
|
||||
# ** COMMUTATIVE flipping **
|
||||
*[(UPat(UOps.ALU, arg=cc, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[0] is not x.src[1] \
|
||||
and x.src[1].tuplize < x.src[0].tuplize else None) for cc in COMMUTATIVE],
|
||||
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
|
||||
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y),
|
||||
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y'), lambda x,y: x|y),
|
||||
|
||||
Reference in New Issue
Block a user