From aadf688aebbea4eb07fdbb221de238cf413a569b Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 25 Oct 2024 20:28:30 +0700 Subject: [PATCH] order flipper as *normal* rewrite rule (#7300) * instant isn't actually used [pr] * order flipper as *normal* rewrite rule * fix inf loop * need simplify now --- test/test_uop_graph.py | 4 ++-- tinygrad/ops.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index b488ddac9a..34abdd356b 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -169,11 +169,11 @@ class TestGraphRewrite(unittest.TestCase): def test_commutative_work(self): a = UOp.variable('a', 0, 1) b = UOp.variable('b', 0, 1) - self.assertIs(a+b, b+a) + self.assertIs((a+b).simplify(), (b+a).simplify()) def test_consts_go_last_right_away(self): a = UOp.variable('a', 0, 1) - tst = 2+a + tst = (2+a).simplify() self.assertIs(tst.src[0], a) self.assertIs(tst.src[1], a.const_like(2)) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index e3646fbfaf..4967ef0e64 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -192,8 +192,6 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s class UOpMetaClass(type): ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary() def __call__(cls, op:UOps, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None): - if op is UOps.ALU and arg in COMMUTATIVE and (src[0] is not src[1] and src[1].tuplize < src[0].tuplize and src[0].op is not UOps.NOOP): - src = src[::-1] if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg) return ret @@ -954,7 +952,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]: for candidate in candidates: # if every branch in candidate gives the same simplified uop, we can rewrite the uop - newuops = [graph_rewrite(uop.substitute({X:newX}), symbolic_flat).substitute({newX:X}) for X,newX in candidate] + newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate] if uop.op is UOps.VECTORIZE and len(uop.src) == 2: if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1])) if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1])) @@ -1001,6 +999,9 @@ symbolic = PatternMatcher([ # ** constant folding ** (UPat(UOps.ALU, name="root", src=UPat((UOps.VCONST, UOps.CONST))), lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))), + # ** COMMUTATIVE flipping ** + *[(UPat(UOps.ALU, arg=cc, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[0] is not x.src[1] \ + and x.src[1].tuplize < x.src[0].tuplize else None) for cc in COMMUTATIVE], # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly (UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y), (UPat.var('x', dtype=dtypes.bool) + UPat.var('y'), lambda x,y: x|y),