# Patch summary (commentary placed ahead of the first `diff --git` header).
#
# Moves the operand reordering of COMMUTATIVE ALU ops out of UOp construction
# and into the `symbolic` PatternMatcher.
#
# * tinygrad/ops.py, UOpMetaClass.__call__: removes the `src = src[::-1]` swap
#   that ran before the ucache lookup when `op is UOps.ALU`, `arg in COMMUTATIVE`,
#   `src[0] is not src[1]`, `src[1].tuplize < src[0].tuplize` and
#   `src[0].op is not UOps.NOOP`.
# * tinygrad/ops.py, symbolic: adds one `UPat(UOps.ALU, arg=cc, name='x')` rule
#   per `cc` in COMMUTATIVE. The rule returns `x.replace(src=x.src[::-1])` when
#   `x.src[0] is not x.src[1]` and `x.src[1].tuplize < x.src[0].tuplize`,
#   otherwise None.
# * tinygrad/ops.py, uop_given_valid: replaces
#   `graph_rewrite(uop.substitute({X:newX}), symbolic_flat).substitute({newX:X})`
#   with `uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify()`.
# * test/test_uop_graph.py: test_commutative_work and
#   test_consts_go_last_right_away now call `.simplify()` on the expression
#   before asserting on identity / operand order.
#
# NOTE(review): the removed constructor check excluded `src[0].op is UOps.NOOP`;
#   the new symbolic rule has no equivalent guard -- confirm this is intentional.
# NOTE(review): uop_given_valid now simplifies twice (before and after
#   substituting newX back to X); which matcher `.simplify()` applies is not
#   visible in this patch -- confirm it is equivalent to `symbolic_flat` here.
# NOTE(review): hunk lines in this copy appear to be joined by spaces rather
#   than newlines; restore the original line breaks before trying to apply it.
diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index b488ddac9a..34abdd356b 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -169,11 +169,11 @@ class TestGraphRewrite(unittest.TestCase): def test_commutative_work(self): a = UOp.variable('a', 0, 1) b = UOp.variable('b', 0, 1) - self.assertIs(a+b, b+a) + self.assertIs((a+b).simplify(), (b+a).simplify()) def test_consts_go_last_right_away(self): a = UOp.variable('a', 0, 1) - tst = 2+a + tst = (2+a).simplify() self.assertIs(tst.src[0], a) self.assertIs(tst.src[1], a.const_like(2)) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index e3646fbfaf..4967ef0e64 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -192,8 +192,6 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s class UOpMetaClass(type): ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary() def __call__(cls, op:UOps, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None): - if op is UOps.ALU and arg in COMMUTATIVE and (src[0] is not src[1] and src[1].tuplize < src[0].tuplize and src[0].op is not UOps.NOOP): - src = src[::-1] if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg) return ret @@ -954,7 +952,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]: for candidate in candidates: # if every branch in candidate gives the same simplified uop, we can rewrite the uop - newuops = [graph_rewrite(uop.substitute({X:newX}), symbolic_flat).substitute({newX:X}) for X,newX in candidate] + newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate] if uop.op is UOps.VECTORIZE and len(uop.src) == 2: if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1])) if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1])) @@ -1001,6 +999,9 @@ 
symbolic = PatternMatcher([ # ** constant folding ** (UPat(UOps.ALU, name="root", src=UPat((UOps.VCONST, UOps.CONST))), lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))), + # ** COMMUTATIVE flipping ** + *[(UPat(UOps.ALU, arg=cc, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[0] is not x.src[1] \ + and x.src[1].tuplize < x.src[0].tuplize else None) for cc in COMMUTATIVE], # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly (UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y), (UPat.var('x', dtype=dtypes.bool) + UPat.var('y'), lambda x,y: x|y),