# Patch summary (text before the "diff --git" header is not part of any hunk):
#   Reorganises the kernel-AST fixing section of tinygrad/engine/schedule.py.
#   - Replaces the "convert Kernel to a ScheduleItem" / "Kernel creation" section
#     headers with "fix kernel AST" plus sub-section comments.
#   - Moves load_buf and add_buffer_ops from after fix_kernel_ops to before apply_swizzle.
#   - Renames _append_st_vars to unbind_shapetracker and updates its use in fix_kernel_ops.
#   - Moves unbind_variable / unbind_vars to directly after unbind_shapetracker.
#   - Moves the ScheduleItem dataclass to just before schedule_uop and adds a TODO
#     about replacing it with the KERNEL UOp.
#   The moved definitions are textually unchanged; only their position and one name differ.
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index ab74958206..cbd8aa98b9 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -226,15 +226,24 @@ create_kernels = merge_views+PatternMatcher([
   (UPat(Ops.KERNEL, name="x"), append_to_kernel),
 ])
 
-# **** convert Kernel to a ScheduleItem (for legacy reasons)
+# **** fix kernel AST
 
-@dataclass(frozen=True)
-class ScheduleItem:
-  ast: UOp
-  bufs: tuple[Buffer, ...]
-  metadata: tuple[Metadata, ...]
+# ** create buffer ops + enumerate buffers
 
-# **** Kernel creation
+def load_buf(ctx:list[UOp], x:UOp):
+  if x not in ctx: ctx.append(x)
+  return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop()))
+
+add_buffer_ops = PatternMatcher([
+  # LOAD
+  (UPat(Ops.BUFFER, name="x"), load_buf),
+  # STORE (except for COPY/BUFFER_VIEW)
+  (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
+  (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
+   lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
+])
+
+# ** push views to buffer ops
 
 def apply_swizzle(u:UOp) -> UOp:
   with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
@@ -288,13 +297,22 @@ view_right = merge_views+PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
 ])
 
-def _append_st_vars(ctx:dict[Variable, int], x:UOp) -> UOp|None:
+# ** unbind variables
+
+def unbind_shapetracker(ctx:dict[Variable, int], x:UOp) -> UOp|None:
   st = unwrap(x.st).simplify()
   if any(x.op is Ops.BIND for x in st.vars()):
     st, var_vals = st.unbind()
     ctx.update(var_vals)
   return st.to_uop() if st != x.st else None
 
+def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
+  ctx[var.replace(src=())] = val.arg
+  return var
+unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
+
+# ** fix_kernel_ops
+
 def check_load_st(glbl:UOp, view:UOp):
   if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
   # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
@@ -307,7 +325,7 @@ def check_load_st(glbl:UOp, view:UOp):
 
 fix_kernel_ops = PatternMatcher([
   # BIND in shapetracker becomes DEFINE_VAR
-  (UPat(Ops.VIEW, name="x"), _append_st_vars),
+  (UPat(Ops.VIEW, name="x"), unbind_shapetracker),
   # remove CONTIGUOUS/ASSIGN/DEVICE
   (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
   (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
@@ -318,23 +336,12 @@ fix_kernel_ops = PatternMatcher([
   (UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
 ])
 
-def load_buf(ctx:list[UOp], x:UOp):
-  if x not in ctx: ctx.append(x)
-  return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop()))
-
-add_buffer_ops = PatternMatcher([
-  # LOAD
-  (UPat(Ops.BUFFER, name="x"), load_buf),
-  # STORE (except for COPY/BUFFER_VIEW)
-  (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
-  (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
-   lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
-])
-
-def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
-  ctx[var.replace(src=())] = val.arg
-  return var
-unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
+# TODO: replace this with the KERNEL UOp
+@dataclass(frozen=True)
+class ScheduleItem:
+  ast: UOp
+  bufs: tuple[Buffer, ...]
+  metadata: tuple[Metadata, ...]
 
 def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
   assert sink.op is Ops.ASSIGN and sink.src[1].op is Ops.KERNEL, f"{sink} must be ASSIGN"