mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
unbind variables when creating ScheduleItems [pr] (#9846)
This commit is contained in:
@@ -294,7 +294,7 @@ view_left = merge_views+PatternMatcher([
|
||||
(UPat(Ops.VIEW, src=(UPat(GroupOp.UnsafePad, name="e"),), name="view"),
|
||||
lambda e,view: e.contiguous().view(view.st) if any(v.mask is not None for v in view.st.views) else None),
|
||||
# view before elementwise ops
|
||||
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST}, name="e"),), name="view"),
|
||||
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND}, name="e"),), name="view"),
|
||||
lambda e,view: e.replace(src=tuple(s.view(s.st+view.st) if s.op is Ops.VIEW else s.view(view.st) for s in e.src))),
|
||||
# if there's ones added after reduce, put this before the reduce
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
|
||||
@@ -344,19 +344,6 @@ view_right = merge_views+PatternMatcher([
|
||||
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] == r2.arg[0] else None),
|
||||
])
|
||||
|
||||
# **** unbind variables
|
||||
|
||||
def unbind_shapetracker(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], x:UOp):
|
||||
st = unwrap(x.st).simplify()
|
||||
if any(x.op is Ops.BIND for x in st.vars()):
|
||||
st, var_vals = st.unbind()
|
||||
ctx[0].update(var_vals)
|
||||
return x.replace(arg=st) if st != x.st else None
|
||||
|
||||
def unbind_variable(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], var:UOp, val:UOp):
|
||||
ctx[0][var.replace(src=())] = val.arg
|
||||
return var
|
||||
|
||||
# **** fix kernel AST
|
||||
|
||||
add_buffer_ops = PatternMatcher([
|
||||
@@ -391,9 +378,6 @@ fix_kernel_ops = PatternMatcher([
|
||||
# remove CONTIGUOUS/DEVICE from kernel AST
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
|
||||
# BIND in shapetracker becomes DEFINE_VAR
|
||||
(UPat(Ops.VIEW, name="x"), unbind_shapetracker),
|
||||
(UPat(Ops.BIND, src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),
|
||||
# no ImageDType after load
|
||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
|
||||
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
|
||||
@@ -437,10 +421,6 @@ pm_fuse = PatternMatcher([
|
||||
(UPat(Ops.FUSE, name="x"), lambda x: x.src[0].replace(src=tuple(y.fuse() for y in x.src[0].src))),
|
||||
])
|
||||
|
||||
def get_name(ret:tuple[dict[UOp, UOp], dict[Variable, int]]) -> str:
|
||||
kcount = len({u.base.src[1] for u in ret[0].values() if u.base.op is Ops.ASSIGN})
|
||||
return f"Schedule {pluralize('Kernel', kcount)}"+(f" (with_{pluralize('Var', len(ret[1]))})" if ret[1] else "")
|
||||
|
||||
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
|
||||
if CAPTURE_PROCESS_REPLAY:
|
||||
import atexit
|
||||
@@ -448,8 +428,8 @@ if CAPTURE_PROCESS_REPLAY:
|
||||
def save_process_replay():
|
||||
for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
|
||||
|
||||
@track_rewrites(name_fxn=get_name)
|
||||
def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
|
||||
@track_rewrites(name_fxn=lambda ret: f"Schedule {pluralize('Kernel', len({u.base.src[1] for u in ret.values() if u.base.op is Ops.ASSIGN}))}")
|
||||
def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
# merge_views + simplify
|
||||
tensor_map = graph_rewrite_map(big_sink, merge_views+sym+reorder_view+replace_contiguous+pm_fuse, ctx={})
|
||||
|
||||
@@ -506,4 +486,4 @@ def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
|
||||
asts = dedup(u.arg.ast for u in sched_sink.toposort if u.op is Ops.KERNEL)
|
||||
PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, asts))
|
||||
|
||||
return becomes_map, var_vals
|
||||
return becomes_map
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from collections import deque
|
||||
from tinygrad.ops import UOp, Variable, Ops, buffers
|
||||
from tinygrad.ops import UOp, Variable, Ops, UPat, PatternMatcher, graph_rewrite, buffers
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.helpers import Metadata, DEBUG, unwrap
|
||||
from tinygrad.engine.grouper import get_becomes_map
|
||||
@@ -13,10 +13,29 @@ class ScheduleItem:
|
||||
bufs: tuple[Buffer, ...]
|
||||
metadata: tuple[Metadata, ...] = ()
|
||||
|
||||
# **** unbind Variables
|
||||
|
||||
def unbind_view(ctx:dict[Variable, int], x:UOp):
|
||||
st = unwrap(x.st).simplify()
|
||||
if any(x.op is Ops.BIND for x in st.vars()):
|
||||
st, var_vals = st.unbind()
|
||||
ctx.update(var_vals)
|
||||
return x.replace(arg=st) if st != x.st else None
|
||||
|
||||
def unbind_bind(ctx:dict[Variable, int], x:UOp):
|
||||
var, val = x.unbind()
|
||||
ctx[var.replace(src=())] = val
|
||||
return var
|
||||
|
||||
pm_unbind = PatternMatcher([
|
||||
(UPat(Ops.VIEW, name="x"), unbind_view),
|
||||
(UPat(Ops.BIND, name="x"), unbind_bind),
|
||||
])
|
||||
|
||||
# **** schedule linearizer
|
||||
|
||||
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
|
||||
becomes_map, var_vals = get_becomes_map(big_sink)
|
||||
becomes_map = get_becomes_map(big_sink)
|
||||
sched_sink = becomes_map.pop(big_sink)
|
||||
|
||||
# bfs toposort
|
||||
@@ -32,12 +51,13 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
|
||||
queue = deque(k for k,v in in_degree.items() if v == 0)
|
||||
schedule: list[ScheduleItem] = []
|
||||
var_vals: dict[Variable, int] = {}
|
||||
while queue:
|
||||
u = queue.popleft()
|
||||
# map the BUFFER UOp to a subbuffer if it's a BUFFER_VIEW
|
||||
if (k:=u.src[1]).arg.ast.op is Ops.BUFFER_VIEW:
|
||||
buffers[k.src[0]] = (base:=k.src[1].buf_uop.buffer).view(k.size, k.arg.ast.dtype, k.arg.ast.arg[1]*base.dtype.itemsize)
|
||||
schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
|
||||
schedule.append(ScheduleItem(graph_rewrite(k.arg.ast, pm_unbind, ctx=var_vals), tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
|
||||
for x in children.get(u, []):
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
||||
Reference in New Issue
Block a user