diff --git a/tinygrad/codegen/symbolic.py b/tinygrad/codegen/symbolic.py index 9cd5ca7623..dec7be249b 100644 --- a/tinygrad/codegen/symbolic.py +++ b/tinygrad/codegen/symbolic.py @@ -213,8 +213,8 @@ symbolic = symbolic_simple+PatternMatcher([ ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False)) 0 else None), # ** move add/mul consts to end (NOTE: this is still happening before constant folding) ** - (UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1), - (UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1), + ((UPat.var("x") + UPat.cvar("c1")) + UPat.var("y"), lambda x,c1,y: (x+y)+c1), + ((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1), # *** rules from symbolic *** # unrolled arange div folding (UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs),