diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py index 3c008a26e4..1176192df0 100644 --- a/tinygrad/engine/grouper.py +++ b/tinygrad/engine/grouper.py @@ -294,7 +294,7 @@ view_left = merge_views+PatternMatcher([ (UPat(Ops.VIEW, src=(UPat(GroupOp.UnsafePad, name="e"),), name="view"), lambda e,view: e.contiguous().view(view.st) if any(v.mask is not None for v in view.st.views) else None), # view before elementwise ops - (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST}, name="e"),), name="view"), + (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND}, name="e"),), name="view"), lambda e,view: e.replace(src=tuple(s.view(s.st+view.st) if s.op is Ops.VIEW else s.view(view.st) for s in e.src))), # if there's ones added after reduce, put this before the reduce (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones), @@ -344,19 +344,6 @@ view_right = merge_views+PatternMatcher([ lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] == r2.arg[0] else None), ]) -# **** unbind variables - -def unbind_shapetracker(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], x:UOp): - st = unwrap(x.st).simplify() - if any(x.op is Ops.BIND for x in st.vars()): - st, var_vals = st.unbind() - ctx[0].update(var_vals) - return x.replace(arg=st) if st != x.st else None - -def unbind_variable(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], var:UOp, val:UOp): - ctx[0][var.replace(src=())] = val.arg - return var - # **** fix kernel AST add_buffer_ops = PatternMatcher([ @@ -391,9 +378,6 @@ fix_kernel_ops = PatternMatcher([ # remove CONTIGUOUS/DEVICE from kernel AST (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x), (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())), - # BIND in shapetracker becomes DEFINE_VAR - (UPat(Ops.VIEW, name="x"), unbind_shapetracker), - (UPat(Ops.BIND, src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable), # no ImageDType after load (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), # if this kernel also assigns to the loaded buffer, ensure we can index it correctly @@ -437,10 +421,6 @@ pm_fuse = PatternMatcher([ (UPat(Ops.FUSE, name="x"), lambda x: x.src[0].replace(src=tuple(y.fuse() for y in x.src[0].src))), ]) -def get_name(ret:tuple[dict[UOp, UOp], dict[Variable, int]]) -> str: - kcount = len({u.base.src[1] for u in ret[0].values() if u.base.op is Ops.ASSIGN}) - return f"Schedule {pluralize('Kernel', kcount)}"+(f" (with_{pluralize('Var', len(ret[1]))})" if ret[1] else "") - PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {} if CAPTURE_PROCESS_REPLAY: import atexit @@ -448,8 +428,8 @@ if CAPTURE_PROCESS_REPLAY: def save_process_replay(): for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True) -@track_rewrites(name_fxn=get_name) -def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]: +@track_rewrites(name_fxn=lambda ret: f"Schedule {pluralize('Kernel', len({u.base.src[1] for u in ret.values() if u.base.op is Ops.ASSIGN}))}") +def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]: # merge_views + simplify tensor_map = graph_rewrite_map(big_sink, merge_views+sym+reorder_view+replace_contiguous+pm_fuse, ctx={}) @@ -506,4 +486,4 @@ def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]: asts = dedup(u.arg.ast for u in sched_sink.toposort if u.op is Ops.KERNEL) PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, asts)) - return becomes_map, var_vals + return becomes_map diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 5d8d9a89d3..ee59ac2bae 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from collections import deque -from tinygrad.ops import UOp, Variable, Ops, buffers +from tinygrad.ops import UOp, Variable, Ops, UPat, PatternMatcher, graph_rewrite, buffers from tinygrad.device import Buffer from tinygrad.helpers import Metadata, DEBUG, unwrap from tinygrad.engine.grouper import get_becomes_map @@ -13,10 +13,29 @@ class ScheduleItem: bufs: tuple[Buffer, ...] metadata: tuple[Metadata, ...] = () +# **** unbind Variables + +def unbind_view(ctx:dict[Variable, int], x:UOp): + st = unwrap(x.st).simplify() + if any(x.op is Ops.BIND for x in st.vars()): + st, var_vals = st.unbind() + ctx.update(var_vals) + return x.replace(arg=st) if st != x.st else None + +def unbind_bind(ctx:dict[Variable, int], x:UOp): + var, val = x.unbind() + ctx[var.replace(src=())] = val + return var + +pm_unbind = PatternMatcher([ + (UPat(Ops.VIEW, name="x"), unbind_view), + (UPat(Ops.BIND, name="x"), unbind_bind), +]) + # **** schedule linearizer def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]: - becomes_map, var_vals = get_becomes_map(big_sink) + becomes_map = get_becomes_map(big_sink) sched_sink = becomes_map.pop(big_sink) # bfs toposort @@ -32,12 +51,13 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va queue = deque(k for k,v in in_degree.items() if v == 0) schedule: list[ScheduleItem] = [] + var_vals: dict[Variable, int] = {} while queue: u = queue.popleft() # map the BUFFER UOp to a subbuffer if it's a BUFFER_VIEW if (k:=u.src[1]).arg.ast.op is Ops.BUFFER_VIEW: buffers[k.src[0]] = (base:=k.src[1].buf_uop.buffer).view(k.size, k.arg.ast.dtype, k.arg.ast.arg[1]*base.dtype.itemsize) - schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata)) + schedule.append(ScheduleItem(graph_rewrite(k.arg.ast, pm_unbind, ctx=var_vals), tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata)) for x in children.get(u, []): in_degree[x] -= 1 if in_degree[x] == 0: queue.append(x)