From bfc68d195343e9eb025685b5d3afd9e31fa8bb7d Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 13 Mar 2025 09:46:25 +0800 Subject: [PATCH] add gep rules to simplify (#9419) * add gep rules to simplify * ws * flipped direction --- tinygrad/codegen/devectorizer.py | 18 ++++-------------- tinygrad/codegen/symbolic.py | 22 +++++++++++----------- 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/tinygrad/codegen/devectorizer.py b/tinygrad/codegen/devectorizer.py index 9782040509..a0a9d4e87c 100644 --- a/tinygrad/codegen/devectorizer.py +++ b/tinygrad/codegen/devectorizer.py @@ -11,26 +11,16 @@ from tinygrad.renderer import Renderer # ***** load/store grouping ***** -def fancy_gep(vec:UOp, i:int): - # if there's a vectorized ADD here, expand through it - if vec.op is Ops.ADD: - if vec.src[0].op is Ops.VECTORIZE and vec.src[1].op is Ops.VCONST: return vec.src[0].gep(i) + vec.src[1].gep(i) - if vec.src[1].op is Ops.VECTORIZE and vec.src[0].op is Ops.VCONST: return vec.src[1].gep(i) + vec.src[0].gep(i) - # if there's a vectorized AND here, expand through it - if vec.op is Ops.AND: - if vec.src[0].op is Ops.VECTORIZE and vec.src[1].op is Ops.VCONST: return vec.src[0].gep(i) & vec.src[1].gep(i) - if vec.src[1].op is Ops.VECTORIZE and vec.src[0].op is Ops.VCONST: return vec.src[1].gep(i) & vec.src[0].gep(i) - return vec.gep(i) - def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None): # first, extract all the relevant offsets offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict) for i in range(vec.dtype.count): - idx = fancy_gep(vec, i) + idx = vec.gep(i).simplify() if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg + elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg else: root_src, arg = idx, 0 - if mask is not None: root_src = (fancy_gep(mask, i), root_src) + if mask is not None: root_src = (mask.gep(i).simplify(), root_src) offsets_rootsrc[root_src].setdefault(arg, []).append(i) # the buf.dtype is always a pointer @@ -44,7 +34,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None): grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])] for grp in grouped_offsets: # get the index offset for this element. using [0] is okay, because they are the same - oidx = fancy_gep(vec, offsets[grp[0]][0]) + oidx = vec.gep(offsets[grp[0]][0]) lidx = UOp(Ops.INDEX, buf.dtype, (buf, oidx, rootsrc[0]) if mask is not None else (buf, oidx)) if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, local=ptrdtype.local)) # set the idxs of the output diff --git a/tinygrad/codegen/symbolic.py b/tinygrad/codegen/symbolic.py index dec7be249b..82fb566ebc 100644 --- a/tinygrad/codegen/symbolic.py +++ b/tinygrad/codegen/symbolic.py @@ -230,6 +230,17 @@ symbolic = symbolic_simple+PatternMatcher([ # ** mod ** # mod folding (UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)), + # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST + (UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'), + lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))), + (UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"), + lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]), + (UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)), + (UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))), + # push all GEPs through ALUs (fix arange stuff) + (UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'), + lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \ + if not isinstance(gep.dtype, PtrDType) else None), ]) symbolic_flat = symbolic+PatternMatcher([ @@ -399,17 +410,6 @@ sym = symbolic_flat+PatternMatcher([ # VECTORIZE void is SINK (UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b), (UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)), - # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST - (UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'), - lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))), - (UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"), - lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]), - (UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)), - (UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))), - # push all GEPs through ALUs (fix arange stuff) - (UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'), - lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \ - if not isinstance(gep.dtype, PtrDType) else None), # push some GEPs through WMMAs (UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma), # CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)