Files
tinygrad/tinygrad/codegen/devectorizer.py
George Hotz cac8bcf8b5 use Ops.REDUCE (#9721)
* decrease bert python time [pr]

* order copies

* Revert "order copies"

This reverts commit 3f62c8693b.

* rewrite count

* Ops.REDUCE

* acc first in the add chain

* Fix tensor core acc

* arange patterns look good

* fix multireduce gate

* reduce rewrite rule

* bump that to 15 minutes

* multiwmma isn't fusing

* gep through wmma is gep pushing

* bump that timeout too, it's all env setup

* add failing test
2025-04-04 10:14:34 +08:00

336 lines
17 KiB
Python

from typing import Optional, Any, Callable, cast
import functools, operator, itertools
from collections import defaultdict
from dataclasses import dataclass
from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve, graph_rewrite, GroupOp, identity_element
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat, gep_pushing
from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.renderer import Renderer
# ***** load/store grouping *****
def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
# generate the individual indexes
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}")
# extract all the relevant offsets
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
for i in range(vec.dtype.count):
idx: Any = midx.src[i].src[1]
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
offsets_rootsrc[root_src].setdefault(arg, []).append(i)
# the buf.dtype is always a pointer
ptrdtype = cast(PtrDType, buf.dtype)
# then rewrite everything we can into groups
ret = []
idxs: list[int|None] = [None]*vec.dtype.count
global_offset = 0
for offsets in offsets_rootsrc.values():
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
for grp in grouped_offsets:
# get the index offset for this element. using [0] is okay, because they are the same
lidx = midx.src[offsets[grp[0]][0]]
if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, local=ptrdtype.local))
# set the idxs of the output
for i,g in enumerate(grp):
for oo in offsets[g]: idxs[oo] = global_offset+i
# add this lidx to the CAT
ret.append(lidx)
global_offset += len(grp)
assert None not in idxs, f"some idxs are missing {idxs}"
# this base thing is for image, we want the CAT to be a normal pointer
post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret))
return post_cat.gep(tuple(cast(list[int], idxs)))
def cat_after_store(cat:UOp, data:UOp):
# TODO: this is written in many places
offset = 0
ret = []
for s in cat.src:
ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count)))))
offset += s.dtype.count
return UOp.sink(ret[0], *ret[1:])
def gep_on_store(gep:UOp, st:UOp):
# NOTE: we need to invert the gep here, but it may be an expanding gep
# fake argsort. TODO: handle duplicates
a = {}
for i,x in enumerate(gep.arg): a[x] = i
new_arg = tuple(x[1] for x in sorted(a.items()))
return UOp(Ops.STORE, src=(gep.src[0], st.gep(new_arg)))
load_store_folding = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"))), expand_index),
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"),
UPat.var("mask"))), expand_index),
# GEP after LOAD
(UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
# GEP on data of STORE
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st"))), gep_on_store),
# put PTRCAT after LOAD
(UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
# put PTRCAT after STORE
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data"))), cat_after_store),
])
# ***** image load valid simplification *****
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0)
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)
# wait for it to be image indexed before running simplification
if start_idx.dtype.count != 2: return None
# can drop valid if idx is out of bound when valid is False
drop_stmt = []
for stmt in split_uop(valid, Ops.AND):
X, is_upper_bound, c = parse_valid(stmt)
# for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)):
testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, Ops.ADD), idx)
testidx = testidx.simplify()
if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
drop_stmt.append(stmt)
continue
# if X <= c, check if it's out of bound when X = c+1
# if X >= c, check if it's out of bound when X = c-1
test_value = c + 1 if is_upper_bound else c - 1
for i,b in zip(idx.src, (buf.dtype.shape[1], buf.dtype.shape[0])):
if i.is_increasing():
rw = i.substitute({X:X.const_like(test_value)}).simplify()
if rw.vmin >= b or rw.vmax < 0:
drop_stmt.append(stmt)
break
if not drop_stmt and idx is start_idx: return None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None
return buf.index(idx, new_valid)
# ***** optional patterns *****
powers_of_two = {2**i:i for i in range(64)}
@functools.cache
def get_late_rewrite_patterns(ops, force_transcendental=False):
pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
# rewrite SQRT to xpow 0.5
if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
# rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)]
if Ops.SHR in ops:
# no reason to check x>=0 for uints
pat += [(UPat.var("x", dtypes.uints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) else None)]
pat += [(UPat.var("x", dtypes.sints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) and resolve(x>=0,False) else None)]
if Ops.NEG in ops:
pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
return PatternMatcher(pat)
# *** correct load/store ***
def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
# this splits loads and stores into multiple chunks
# if there's only one element to load/store, no splitting needed
if (sz:=ls.src[0].dtype.count) == 1: return None
buf = idx.src[0]
# determine fold lengths
lengths = []
must_divide = True
if ctx is not None and ctx.device == "DSP":
lengths = [128,64,32,16,8,4]
must_divide = False
elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
pass
elif isinstance(buf.dtype, ImageDType):
lengths = [4]
elif ctx is not None and ctx.supports_float4:
# TODO: a better way to get this than ctx
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2])
lengths.append(1) # worst case, it's not folded
# split based on the fold lengths
global_offset = 0
ret = []
ptrdtype = cast(PtrDType, buf.dtype)
while global_offset < sz:
oidx = (idx.src[1] + global_offset).simplify()
# with 1 at the end of the lengths list, this will always hit
for fold_length in lengths:
if global_offset+fold_length > sz: continue
if must_divide and oidx.divides(fold_length) is None: continue
lidx = buf.index(oidx, idx.src[2] if len(idx.src) > 2 else None)
if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, local=ptrdtype.local))
if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
global_offset += fold_length
break
# if it wasn't split, we return None. otherwise we CAT them
return UOp(Ops.CAT, ls.dtype, tuple(ret)) if len(ret) > 1 else None
def image_fixup(ls:UOp):
# normal image load or store, with the CAST from expand_index
if ls.src[0].op is Ops.CAST and isinstance(image_dtype:=ls.src[0].src[0].dtype, ImageDType):
assert ls.src[0].dtype.count == 4, "image must be casted to 4"
idx = ls.src[0].src[0]
oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
return ls.replace(src=(idx,)+ls.src[1:])
# this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores
if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].dtype != dtypes.int.vec(2):
assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it"
idx = ls.src[0]
id4 = idx.src[1] % 4
oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:])
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan')))
return None
correct_load_store = PatternMatcher([
# split LOAD/STORE
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.INDEX, name="idx").cast(),), name="ls", allow_any_len=True), split_load_store),
# image indexing, including unfoldable images
(UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup),
])
# *** uop expander ***
# TODO: there's a lot shared with gep_through_wmma here
def no_vectorized_wmma(wmma:UOp):
out_sz = prod(x[1] for x in wmma.arg[6][-1])
if wmma.dtype.count == out_sz: return None
tsrcs = []
for s,sz in zip(wmma.src, wmma.arg[6]):
ssz = prod(x[1] for x in sz)
tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
wmmas = [UOp(Ops.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))
def no_vectorized_alu(alu):
if alu.dtype.vcount == 1: return None
alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
return UOp(Ops.VECTORIZE, alu.dtype, alus)
def no_vectorized_acc(acc:UOp):
if acc.dtype.count == 1: return None
alus = tuple(UOp(acc.op, acc.dtype.scalar(),
tuple(s.gep(i) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
return UOp(Ops.VECTORIZE, acc.dtype, alus)
devectorize = PatternMatcher([
# no ALU on vectorized dtypes
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN), name="alu"), no_vectorized_alu),
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
(UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
])
def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
if store_gate not in [gate.src[0] for gate in val.toposort if gate.op is Ops.IF]: return None
# remove the gate from the index
return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)
load_store_indexing = PatternMatcher([
# simplify valid
(UPat(Ops.AND, name="valid"), simplify_valid),
# image load valid idx simplification
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
# index True is just Index
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)),
# delete_redundant_gates (after expand)
(UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
UPat.var("val"))), delete_redundant_gates),
])
pm_render = PatternMatcher([
# for rendering, we use explicit VECTORIZE
(UPat(Ops.CONST, name='c'),
lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
(UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
(UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
# give any loads that are masked an alt value
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None),
# gate any stores that aren't gated with ifs
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store"),
lambda store,idx: UOp(Ops.STORE, src=store.src+(UOp(Ops.IF, src=(idx.src[2],)),))),
])
# *** Ops.REDUCE -> Ops.DEFINE_ACC+Ops.ASSIGN ***
@dataclass
class ReduceContext:
acc_num: int = 0
def reduce_to_acc(ctx:ReduceContext, red:UOp):
inp, reduce_range = red.src[0], red.src[1:]
# if this has a horizontal reduction component, do that first
if inp.dtype != red.dtype:
# NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
horizontal_amount = inp.dtype.count//red.dtype.count
lst = [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
else:
lst = [inp]
assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
# if we have a range
if len(reduce_range) != 0:
acc = UOp(Ops.DEFINE_ACC, red.dtype, (red.const_like(identity_element(red.arg, red.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
lst = [acc] + lst # put acc as the first element
ctx.acc_num += 1
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
return acc.assign(ret) if len(reduce_range) != 0 else ret
pm_reduce = PatternMatcher([
# REDUCE -> DEFINE_ACC+ASSIGN
(UPat(Ops.REDUCE, name="red"), reduce_to_acc),
# tensor core built in accumulate
(UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
])
# *** uop graph ***
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
# remove reduce
sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
# devectorize is optional
if DEVECTORIZE >= 2: sink = graph_rewrite(sink, sym+load_store_folding+load_store_indexing, ctx=opts)
elif DEVECTORIZE: sink = graph_rewrite(sink, sym+devectorize+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
else: sink = graph_rewrite(sink, sym+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
# optional pre matcher
if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)
# final rules for the renderer (without sym)
sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher)
return sink