From c52cd2b437ed7e018c07aa7a0542ec72ea9df69f Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 10 Feb 2025 17:54:30 +0100 Subject: [PATCH] kernel op cleanups + use ScheduleItem [pr] (#9009) --- tinygrad/engine/schedule.py | 27 ++++++++++----------------- tinygrad/ops.py | 2 +- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index b017d8a7de..1526ed1bee 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -249,7 +249,7 @@ break_sched = PatternMatcher([ (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse), ]) -# **** convert Kernel to a ScheduleItem (for legacy reasons) +# **** ScheduleItem creation (TODO: replace ScheduleItem with the KERNEL UOp) @dataclass(frozen=True) class ScheduleItem: @@ -267,18 +267,6 @@ class ScheduleItem: @functools.cached_property def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,) -def kernel_to_si(k:UOp) -> ScheduleItem: - assert k.op is Ops.KERNEL and isinstance(k.metadata, tuple), f"must be KERNEL {k}" - return ScheduleItem(k.arg.ast, tuple(u.buf_uop.buffer for u in k.src), k.metadata) - -# **** Kernel creation - -@dataclass(frozen=True) -class Kernel: - ast: UOp - metadata: tuple[Metadata, ...] - def __repr__(self): return f"" - @dataclass(frozen=True) class KernelContext: var_vals: dict[Variable, int] @@ -380,14 +368,14 @@ def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp): return var unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),]) -def schedule_uop(pre:UOp, ctx:ScheduleContext) -> UOp: +def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem: # unbind_vars + push views to edges sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=ctx.var_vals), view_right) # remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL ast = graph_rewrite(sink, to_si, si_ctx:=KernelContext(ctx.var_vals)) # NOTE: we only add the metadata for fused tensors metadata = tuple(dedup(m for x in pre.toposort if x.op is not Ops.BUFFER and (m:=ctx.ops_metadata.get(x)) is not None)) - return UOp(Ops.KERNEL, src=tuple(si_ctx.bufs), arg=Kernel(ast, metadata)) + return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs), metadata) PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {} if CAPTURE_PROCESS_REPLAY: @@ -395,8 +383,13 @@ if CAPTURE_PROCESS_REPLAY: def save_process_replay(): for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True) +@dataclass(frozen=True) +class Kernel: + ast: UOp + def __repr__(self): return f"" + # NOTE: realizes become ASSIGN(BUFFER, KERNEL) in the schedule graph -def init_kernel(ctx:dict[UOp, UOp], u:UOp): return u.buf_uop.assign(UOp(Ops.KERNEL, src=u.src, arg=Kernel(ctx[u.buf_uop].sink(), ()))) +def init_kernel(ctx:dict[UOp, UOp], u:UOp): return u.buf_uop.assign(UOp(Ops.KERNEL, src=u.src, arg=Kernel(ctx[u.buf_uop].sink()))) def is_kernel(u:UOp) -> bool: return u.op is Ops.ASSIGN and u.src[1].op is Ops.KERNEL DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.BUFFER} @@ -502,7 +495,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va schedule: list[ScheduleItem] = [] while queue: u = queue.popleft() - schedule.append(kernel_to_si(schedule_uop(u.src[1].arg.ast, ctx))) + schedule.append(schedule_uop(u.src[1].arg.ast, ctx)) for x in children.get(u, []): in_degree[x] -= 1 if in_degree[x] == 0: queue.append(x) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index f90058964f..4794169207 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -500,7 +500,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return ret def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True) @property - def metadata(self) -> tuple[Metadata, ...]|Metadata|None: return self.arg.metadata if self.op is Ops.KERNEL else all_metadata.get(self, None) + def metadata(self) -> Metadata|None: return all_metadata.get(self, None) # *** uop movement ops ***