Files
onepilot/tinygrad_repo/tinygrad/codegen/late/expander.py
T
firestar5683 d0e1db6766 StarPilot
2026-03-22 03:15:05 -05:00

151 lines
7.5 KiB
Python

# this converts a lowerer program into a vectorized program
import functools, itertools
from tinygrad.dtype import dtypes, PtrDType, AddrSpace
from tinygrad.helpers import dedup, flatten, all_same, prod, partition
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, AxisType, range_start
from tinygrad.schedule.rangeify import BufferizeOpts
def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
  """Return the flat index of the per-axis position `rpk` inside the axes described by `args`.

  `args` is a sequence of (axis, size) pairs and `rpk` maps each axis to its position. The last pair in `args`
  gets stride 1 and every earlier pair gets the product of the sizes that follow it. Axes present in `rpk` but
  not in `args` are ignored; an axis in `args` that is missing from `rpk` raises KeyError.
  """
  flat, stride = 0, 1
  # walk from the innermost (last) axis outwards, growing the stride as we go
  for axis, size in reversed(args):
    flat, stride = flat + rpk[axis] * stride, stride * size
  return flat
def _choices_from_args(args:tuple[tuple[int, int], ...]) -> list[dict[int, int]]:
  """Enumerate every position over the (axis, size) pairs in `args` as a list of {axis: position} dicts.

  Positions are produced in itertools.product order, so the last axis in `args` varies fastest.
  An empty `args` yields a single empty dict; any axis of size 0 yields an empty list.
  """
  per_axis = [[(axis, pos) for pos in range(size)] for axis, size in args]
  return [dict(combo) for combo in itertools.product(*per_axis)]
@functools.cache
def _swizzle_args(cargs:tuple[tuple[int, int], ...], eargs:tuple[tuple[int, int], ...], exclude_args:tuple[int, ...]) -> list[int]:
  """For every position over `cargs`, return the flat index of that position inside `eargs`.

  Axes listed in `exclude_args` are forced to position 0 before the index is computed.
  The result is memoized by functools.cache, so the returned list is shared between calls and must not be mutated.
  """
  zeroed = {axis: 0 for axis in exclude_args}
  idxs = []
  for choice in _choices_from_args(cargs):
    # only build the merged dict when there is something to override
    position = {**choice, **zeroed} if exclude_args else choice
    idxs.append(_expand_arg_to_idx(eargs, position))
  return idxs
def do_expand(root:UOp):
  """Widen `root` over the axes of its UNROLL sources and wrap the widened op in a single UNROLL.

  Returns None when no source of `root` is an UNROLL. Otherwise every source is brought to a common set of
  expand axes: UNROLL sources are unwrapped or GEP-swizzled, other sources are passed through, CAT-repeated or
  broadcast. `root` is then rebuilt with a dtype whose count is `expand_sz` times larger and returned inside an
  UNROLL that carries the combined axes and the original dtype.
  """
  unrolled = [s for s in root.src if s.op is Ops.UNROLL]
  if not unrolled: return None

  # NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
  if root.op is Ops.WMMA: exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2]))))
  else: exclude_args = ()

  unrolled_args = [s.arg for s in unrolled]
  if all_same(unrolled_args) and not exclude_args:
    # if there's only one expand arg, it's okay to use it (optimization)
    expand_args = unrolled[0].arg
  else:
    # otherwise, we sort them and GEP
    expand_args = tuple(ax for ax in sorted(dedup(flatten(unrolled_args))) if ax[0] not in exclude_args)
  expand_sz = prod([ax[1] for ax in expand_args])

  new_srcs = []
  for pos, src in enumerate(root.src):
    if src.op is Ops.UNROLL:
      if src.arg == expand_args:
        # just remove the expand
        new_srcs.append(src.src[0])
        continue
      lanes = _swizzle_args(expand_args, src.arg, exclude_args)
      # if the base dtype is > 1, put those at the end
      width = src.dtype.count
      if width > 1: lanes = flatten([[lane*width+j for j in range(width)] for lane in lanes])
      new_srcs.append(src.src[0].gep(tuple(lanes)))
      continue

    # non-UNROLL input
    # for any range args of STORE/REDUCE, pass them through
    is_range_src = root.op in range_start and pos >= range_start[root.op]
    # sources after the first of an INDEX whose dtype is not a PtrDType are passed through as well
    is_plain_index_src = root.op is Ops.INDEX and pos >= 1 and not isinstance(root.dtype, PtrDType)
    if is_range_src or is_plain_index_src:
      new_srcs.append(src)
    elif src.dtype.count > 1:
      # put any input dtype > 1 grouped together
      new_srcs.append(UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz))
    else:
      # repeat the arg
      new_srcs.append(src.broadcast(expand_sz))

  new_arg = root.arg
  if root.op is Ops.GEP:
    assert root.dtype.count == 1
    # is this right?
    src_count = new_srcs[0].dtype.count
    new_arg = tuple(range(root.arg[0], src_count, src_count // expand_sz))
  widened = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
  return UOp(Ops.UNROLL, root.dtype, (widened,), expand_args)
def do_contract(con:UOp):
  """Resolve a CONTRACT: drop the axes in `con.arg` from the UNROLL it wraps.

  If the source is not an UNROLL, the source tuple is repeated `con.dtype.count` times inside a VECTORIZE.
  Otherwise the UNROLL's inner value is GEP'd so that, for each position of the remaining axes, the elements of
  the contracted axes are adjacent, and the result is wrapped in an UNROLL over the remaining axes only.
  """
  unroll = con.src[0]
  # CONTRACT without UNROLL repeats the element VECTORIZED
  if unroll.op is not Ops.UNROLL: return UOp(Ops.VECTORIZE, con.dtype, con.src*con.dtype.count)
  # CONTRACT may remove several axes from UNROLL
  assert con.dtype == dtypes.void or con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
  kept_args = tuple(ax for ax in unroll.arg if ax not in con.arg)
  contracted_choices = _choices_from_args(con.arg)
  # outer loop over the axes that stay, inner loop over the axes being contracted
  idxs = tuple(_expand_arg_to_idx(unroll.arg, {**outer, **inner})
               for outer in _choices_from_args(kept_args) for inner in contracted_choices)
  return UOp(Ops.UNROLL, con.dtype, (unroll.src[0].gep(idxs),), kept_args)
def end_unrolls(u:UOp):
  """Replace the UNROLL sources after the first source of `u` with a void CONTRACT around the first source.

  Returns None when none of `u.src[1:]` is an UNROLL. The CONTRACT arg is the concatenation of the removed
  UNROLLs' args, in source order; the remaining sources are kept after the CONTRACT.
  """
  unrolls, rest = partition(u.src[1:], lambda s: s.op is Ops.UNROLL)
  if not unrolls: return None
  contract_axes: tuple = ()
  for unroll in unrolls: contract_axes += unroll.arg
  contracted = UOp(Ops.CONTRACT, dtypes.void, (u.src[0],), contract_axes)
  return u.replace(src=(contracted,)+tuple(rest))
# Rewrite rules that resolve UNROLL and CONTRACT ops and widen the ops that consume UNROLLs.
# NOTE(review): rules are presumably tried in list order by PatternMatcher -- confirm before reordering.
expander = PatternMatcher([
# push broadcast through AFTER/END: when the first source is a broadcast of x, apply AFTER/END to x with the
# remaining sources and broadcast the result to the same number of elements
(UPat.var("x").broadcast(name="b").after(name="a", allow_any_len=True), lambda x,b,a: x.after(*a.src[1:]).broadcast(len(b.src))),
(UPat.var("x").broadcast(name="b").end(name="a", allow_any_len=True), lambda x,b,a: x.end(*a.src[1:]).broadcast(len(b.src))),
# END on UNROLL ends the UNROLL (end_unrolls returns None when no later source of the END is an UNROLL)
(UPat(Ops.END, name="u"), end_unrolls),
# BUFFERIZE puts UNROLLs for ranges as contract: for a BUFFERIZE whose sources match (UNROLL, UNROLL), each source
# is wrapped in a CONTRACT using the second UNROLL's arg, vectorized to the count of that UNROLL's inner source dtype
(UPat(Ops.BUFFERIZE, src=(UPat(Ops.UNROLL), UPat(Ops.UNROLL)), name="x"),
lambda x: x.replace(src=tuple(UOp(Ops.CONTRACT, dtype=s.dtype.vec(x.src[1].src[0].dtype.count), src=(s,), arg=x.src[1].arg) for s in x.src))),
# double expand: an UNROLL directly wrapping an UNROLL collapses into one UNROLL with arg inner.arg+outer.arg
(UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion (do_expand returns None when the root has no UNROLL source)
# NOTE(review): custom_early_reject presumably lets the matcher skip roots without an UNROLL source -- confirm
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
Ops.VECTORIZE, Ops.REDUCE, Ops.END, Ops.AFTER), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
# every CONTRACT is resolved by do_contract
(UPat(Ops.CONTRACT, name="con"), do_contract),
# empty UNROLL is NOOP: an UNROLL whose arg is () is replaced by its source
(UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x),
])
# ****
def fix_reduce_unroll(x:UOp):
  """Move the UNROLL sources after the first source of a REDUCE into a CONTRACT on the reduced value.

  Returns None when every source after the first is a RANGE. CONST sources are dropped; anything left that is
  neither a RANGE nor an UNROLL trips the assertion. If the UNROLLs carry any axes, the first source is wrapped
  in a CONTRACT (tag=1) whose dtype is `x.dtype` vectorized to the product of the axis sizes.
  """
  ranges, others = partition(x.src[1:], lambda y: y.op is Ops.RANGE)
  if not others: return None
  reduce_expand = [u for u in others if u.op is not Ops.CONST]
  assert all(u.op is Ops.UNROLL for u in reduce_expand), f"not all UNROLLS in {reduce_expand}"
  contract_axis = flatten(u.arg for u in reduce_expand)
  # REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group
  if not contract_axis: return x.replace(src=(x.src[0],)+tuple(ranges))
  width = prod(ax[1] for ax in contract_axis)
  contracted = UOp(Ops.CONTRACT, x.dtype.vec(width), (x.src[0],), tuple(contract_axis), tag=1)
  return x.replace(src=(contracted,)+tuple(ranges))
def fix_store_unroll(x:UOp):
  """Strip UNROLL sources from positions 2+ of a STORE and wrap the STORE in a void CONTRACT over their axes.

  Returns None when none of `x.src[2:]` is an UNROLL. The CONTRACT is created with tag=1.
  """
  unrolls, kept = partition(x.src[2:], lambda y: y.op is Ops.UNROLL)
  if not unrolls: return None
  stripped_store = x.replace(src=x.src[:2]+tuple(kept))
  contract_axes = tuple(flatten(u.arg for u in unrolls))
  return UOp(Ops.CONTRACT, dtypes.void, (stripped_store,), contract_axes, tag=1)
def fix_group_for_reduce(x:UOp):
  """Split a REDUCE that has GROUP_REDUCE ranges into an early reduce, an AddrSpace.LOCAL buffer and a final reduce.

  Returns None when no source after the first is a RANGE with axis type GROUP_REDUCE.
  """
  def is_group_reduce(u:UOp): return u.op is Ops.RANGE and u.arg[1] == AxisType.GROUP_REDUCE
  reduce_gfr, reduce_r = partition(x.src[1:], is_group_reduce)
  if not reduce_gfr: return None
  # NOTE: if there's other locals here, we need them in the buffer too
  upstream_locals = [u for u in x.toposort() if u.op is Ops.RANGE and u.arg[1] == AxisType.LOCAL]
  # do only the non grouped reduces early
  early = x.replace(src=(x.src[0],)+tuple(reduce_r))
  # each grouped range gets a REDUCE-typed copy whose first arg element is offset by 100
  # NOTE(review): the +100 presumably keeps the new range ids from colliding with existing ones -- confirm
  reduce_loop = [r.replace(arg=(r.arg[0]+100, AxisType.REDUCE)) for r in reduce_gfr]
  opts = BufferizeOpts(reduce_gfr[0].arg[0], AddrSpace.LOCAL)
  buf = early.bufferize(*upstream_locals, *reduce_gfr, arg=opts).index(*upstream_locals, *reduce_loop)
  # do the final reduce (if/barrier are added in gpudims step)
  return buf.reduce(*reduce_loop, arg=x.arg)
# Rules applied before the expander: they introduce the UNROLL/CONTRACT ops that `expander` later resolves.
pm_pre_expander = PatternMatcher([
# rewrite UPCAST/UNROLL range to something to be expanded: a RANGE whose axis type (r.arg[1]) is UNROLL or UPCAST
# becomes an UNROLL of the constant vector (0, 1, ..., s-1) with s = r.vmax+1 and arg ((r.arg[0], s),);
# for any other axis type the lambda returns None
(UPat(Ops.RANGE, name="r"),
lambda r: UOp(Ops.UNROLL, r.dtype, (UOp.const(r.dtype.vec(s:=r.vmax+1), tuple(range(s))),), ((r.arg[0],s),)) \
if r.arg[1] in {AxisType.UNROLL, AxisType.UPCAST} else None),
# fix REDUCEs with UNROLLs: UNROLL sources of a REDUCE become a CONTRACT on the reduced value
(UPat(Ops.REDUCE, name="x"), fix_reduce_unroll),
# fix STOREs with UNROLLs: UNROLL sources of a STORE become a CONTRACT around the STORE
(UPat(Ops.STORE, name="x"), fix_store_unroll),
])
# Single-rule matcher: every REDUCE is offered to fix_group_for_reduce, which returns None unless the REDUCE has
# a RANGE source with axis type GROUP_REDUCE.
pm_group_for_reduce = PatternMatcher([
# fix group for reduce
(UPat(Ops.REDUCE, name="x"), fix_group_for_reduce),
])