order flipper as *normal* rewrite rule (#7300)

* instant isn't actually used [pr]

* order flipper as *normal* rewrite rule

* fix inf loop

* need simplify now
This commit is contained in:
George Hotz
2024-10-25 20:28:30 +07:00
committed by GitHub
parent 3c31497f55
commit aadf688aeb
2 changed files with 6 additions and 5 deletions

View File

@@ -169,11 +169,11 @@ class TestGraphRewrite(unittest.TestCase):
def test_commutative_work(self):
a = UOp.variable('a', 0, 1)
b = UOp.variable('b', 0, 1)
self.assertIs(a+b, b+a)
self.assertIs((a+b).simplify(), (b+a).simplify())
def test_consts_go_last_right_away(self):
a = UOp.variable('a', 0, 1)
tst = 2+a
tst = (2+a).simplify()
self.assertIs(tst.src[0], a)
self.assertIs(tst.src[1], a.const_like(2))

View File

@@ -192,8 +192,6 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
class UOpMetaClass(type):
ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
def __call__(cls, op:UOps, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
if op is UOps.ALU and arg in COMMUTATIVE and (src[0] is not src[1] and src[1].tuplize < src[0].tuplize and src[0].op is not UOps.NOOP):
src = src[::-1]
if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg)
return ret
@@ -954,7 +952,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
for candidate in candidates:
# if every branch in candidate gives the same simplified uop, we can rewrite the uop
newuops = [graph_rewrite(uop.substitute({X:newX}), symbolic_flat).substitute({newX:X}) for X,newX in candidate]
newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
if uop.op is UOps.VECTORIZE and len(uop.src) == 2:
if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
@@ -1001,6 +999,9 @@ symbolic = PatternMatcher([
# ** constant folding **
(UPat(UOps.ALU, name="root", src=UPat((UOps.VCONST, UOps.CONST))),
lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))),
# ** COMMUTATIVE flipping **
*[(UPat(UOps.ALU, arg=cc, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[0] is not x.src[1] \
and x.src[1].tuplize < x.src[0].tuplize else None) for cc in COMMUTATIVE],
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y),
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y'), lambda x,y: x|y),