unbind variables when creating ScheduleItems [pr] (#9846)

This commit is contained in:
qazal
2025-04-11 15:23:53 +08:00
committed by GitHub
parent 6896197978
commit cbc5e7ed45
2 changed files with 27 additions and 27 deletions

View File

@@ -294,7 +294,7 @@ view_left = merge_views+PatternMatcher([
(UPat(Ops.VIEW, src=(UPat(GroupOp.UnsafePad, name="e"),), name="view"),
lambda e,view: e.contiguous().view(view.st) if any(v.mask is not None for v in view.st.views) else None),
# view before elementwise ops
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST}, name="e"),), name="view"),
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND}, name="e"),), name="view"),
lambda e,view: e.replace(src=tuple(s.view(s.st+view.st) if s.op is Ops.VIEW else s.view(view.st) for s in e.src))),
# if there's ones added after reduce, put this before the reduce
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
@@ -344,19 +344,6 @@ view_right = merge_views+PatternMatcher([
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] == r2.arg[0] else None),
])
# **** unbind variables
def unbind_shapetracker(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], x:UOp):
  """Simplify the shapetracker of `x` and strip any bound variables out of it.

  Values returned by the shapetracker's `unbind()` are merged into `ctx[0]`.
  Returns `x` with the new shapetracker as its arg, or None when the
  shapetracker is unchanged.
  """
  tracker = unwrap(x.st).simplify()
  has_bind = False
  for v in tracker.vars():
    if v.op is Ops.BIND:
      has_bind = True
      break
  if has_bind:
    # unbind() hands back the tracker without bound values plus the values themselves
    tracker, bound_vals = tracker.unbind()
    ctx[0].update(bound_vals)
  if tracker != x.st: return x.replace(arg=tracker)
  return None
def unbind_variable(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], var:UOp, val:UOp):
  """Store `val.arg` in `ctx[0]` keyed by `var` with its sources removed, then return `var`."""
  value = val.arg
  var_vals = ctx[0]
  key = var.replace(src=())
  var_vals[key] = value
  return var
# **** fix kernel AST
add_buffer_ops = PatternMatcher([
@@ -391,9 +378,6 @@ fix_kernel_ops = PatternMatcher([
# remove CONTIGUOUS/DEVICE from kernel AST
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
# BIND in shapetracker becomes DEFINE_VAR
(UPat(Ops.VIEW, name="x"), unbind_shapetracker),
(UPat(Ops.BIND, src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),
# no ImageDType after load
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
@@ -437,10 +421,6 @@ pm_fuse = PatternMatcher([
(UPat(Ops.FUSE, name="x"), lambda x: x.src[0].replace(src=tuple(y.fuse() for y in x.src[0].src))),
])
def get_name(ret:tuple[dict[UOp, UOp], dict[Variable, int]]) -> str:
  """Build the display name from a (becomes_map, var_vals) result.

  The kernel count is the number of distinct `base.src[1]` among the mapped
  values whose base op is ASSIGN; a variable suffix is appended only when
  `ret[1]` is non-empty.
  """
  assigned = set()
  for u in ret[0].values():
    if u.base.op is Ops.ASSIGN: assigned.add(u.base.src[1])
  name = f"Schedule {pluralize('Kernel', len(assigned))}"
  if ret[1]: name += f" (with_{pluralize('Var', len(ret[1]))})"
  return name
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
import atexit
@@ -448,8 +428,8 @@ if CAPTURE_PROCESS_REPLAY:
def save_process_replay():
  # hand every captured (key, already-pickled payload) pair to diskcache_put
  for key, payload in PROCESS_REPLAY_CAPTURE.items():
    diskcache_put("schedule_process_replay", key, payload, prepickled=True)
@track_rewrites(name_fxn=get_name)
def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
@track_rewrites(name_fxn=lambda ret: f"Schedule {pluralize('Kernel', len({u.base.src[1] for u in ret.values() if u.base.op is Ops.ASSIGN}))}")
def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
# merge_views + simplify
tensor_map = graph_rewrite_map(big_sink, merge_views+sym+reorder_view+replace_contiguous+pm_fuse, ctx={})
@@ -506,4 +486,4 @@ def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
asts = dedup(u.arg.ast for u in sched_sink.toposort if u.op is Ops.KERNEL)
PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, asts))
return becomes_map, var_vals
return becomes_map

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass
from collections import deque
from tinygrad.ops import UOp, Variable, Ops, buffers
from tinygrad.ops import UOp, Variable, Ops, UPat, PatternMatcher, graph_rewrite, buffers
from tinygrad.device import Buffer
from tinygrad.helpers import Metadata, DEBUG, unwrap
from tinygrad.engine.grouper import get_becomes_map
@@ -13,10 +13,29 @@ class ScheduleItem:
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...] = ()
# **** unbind Variables
def unbind_view(ctx:dict[Variable, int], x:UOp):
  """Simplify the shapetracker of `x` and remove bound variables from it.

  When any variable of the simplified shapetracker is a BIND, the tracker is
  unbound and the resulting values are merged into `ctx`. Returns `x` with the
  new shapetracker as its arg, or None if the shapetracker did not change.
  """
  new_st = unwrap(x.st).simplify()
  bind_present = any(v.op is Ops.BIND for v in new_st.vars())
  if bind_present:
    new_st, bound_vals = new_st.unbind()
    ctx.update(bound_vals)
  if new_st != x.st: return x.replace(arg=new_st)
  return None
def unbind_bind(ctx:dict[Variable, int], x:UOp):
  """Split `x` with `unbind()`, record the value in `ctx`, and return the variable.

  The key stored in `ctx` is the variable with its sources removed.
  """
  variable, value = x.unbind()
  key = variable.replace(src=())
  ctx[key] = value
  return variable
# rules used to unbind variables; both handlers write the values they find into the ctx dict
pm_unbind = PatternMatcher([
# every VIEW goes through unbind_view, which returns None when its shapetracker is unchanged
(UPat(Ops.VIEW, name="x"), unbind_view),
# every BIND goes through unbind_bind, which returns the variable half of the bind
(UPat(Ops.BIND, name="x"), unbind_bind),
])
# **** schedule linearizer
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
becomes_map, var_vals = get_becomes_map(big_sink)
becomes_map = get_becomes_map(big_sink)
sched_sink = becomes_map.pop(big_sink)
# bfs toposort
@@ -32,12 +51,13 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
queue = deque(k for k,v in in_degree.items() if v == 0)
schedule: list[ScheduleItem] = []
var_vals: dict[Variable, int] = {}
while queue:
u = queue.popleft()
# map the BUFFER UOp to a subbuffer if it's a BUFFER_VIEW
if (k:=u.src[1]).arg.ast.op is Ops.BUFFER_VIEW:
buffers[k.src[0]] = (base:=k.src[1].buf_uop.buffer).view(k.size, k.arg.ast.dtype, k.arg.ast.arg[1]*base.dtype.itemsize)
schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
schedule.append(ScheduleItem(graph_rewrite(k.arg.ast, pm_unbind, ctx=var_vals), tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
for x in children.get(u, []):
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)