kernel op cleanups + use ScheduleItem [pr] (#9009)

This commit is contained in:
qazal
2025-02-10 17:54:30 +01:00
committed by GitHub
parent 25fa5e4d5f
commit c52cd2b437
2 changed files with 11 additions and 18 deletions

View File

@@ -249,7 +249,7 @@ break_sched = PatternMatcher([
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
])
# **** convert Kernel to a ScheduleItem (for legacy reasons)
# **** ScheduleItem creation (TODO: replace ScheduleItem with the KERNEL UOp)
@dataclass(frozen=True)
class ScheduleItem:
@@ -267,18 +267,6 @@ class ScheduleItem:
@functools.cached_property
def output_idxs(self) -> tuple[int, ...]:
  # a non-SINK ast reports the single index 0
  if self.ast.op is not Ops.SINK: return (0,)
  # for a SINK, collect the arg of the first source of every SINK source
  return tuple(out.src[0].arg for out in self.ast.src)
def kernel_to_si(k:UOp) -> ScheduleItem:
  # Build a ScheduleItem from a KERNEL UOp: the kernel's ast, the buffer behind each source's buf_uop, and the metadata.
  # The input must be an Ops.KERNEL UOp whose metadata is a tuple.
  assert k.op is Ops.KERNEL and isinstance(k.metadata, tuple), f"must be KERNEL {k}"
  kernel_ast = k.arg.ast
  buffers = tuple(src.buf_uop.buffer for src in k.src)
  return ScheduleItem(kernel_ast, buffers, k.metadata)
# **** Kernel creation
@dataclass(frozen=True)
class Kernel:
  # frozen record pairing an ast with its metadata tuple
  ast: UOp
  metadata: tuple[Metadata, ...]
  def __repr__(self):
    # summary: number of entries in the ast's toposort, the ast's op, then the metadata
    uop_count = len(list(self.ast.toposort))
    return f"<Kernel {uop_count} {self.ast.op} {self.metadata}>"
@dataclass(frozen=True)
class KernelContext:
var_vals: dict[Variable, int]
@@ -380,14 +368,14 @@ def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
return var
unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
def schedule_uop(pre:UOp, ctx:ScheduleContext) -> UOp:
def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
# unbind_vars + push views to edges
sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=ctx.var_vals), view_right)
# remove extra uops from SINK + substitute BUFFER with DEFINE_GLOBAL
ast = graph_rewrite(sink, to_si, si_ctx:=KernelContext(ctx.var_vals))
# NOTE: we only add the metadata for fused tensors
metadata = tuple(dedup(m for x in pre.toposort if x.op is not Ops.BUFFER and (m:=ctx.ops_metadata.get(x)) is not None))
return UOp(Ops.KERNEL, src=tuple(si_ctx.bufs), arg=Kernel(ast, metadata))
return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs), metadata)
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
@@ -395,8 +383,13 @@ if CAPTURE_PROCESS_REPLAY:
def save_process_replay():
for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
@dataclass(frozen=True)
class Kernel:
  # frozen wrapper around a single ast
  ast: UOp
  def __repr__(self):
    # summary: number of entries in the ast's toposort followed by the ast's op
    uop_count = len(list(self.ast.toposort))
    return f"<Kernel {uop_count} {self.ast.op}>"
# NOTE: realizes become ASSIGN(BUFFER, KERNEL) in the schedule graph
def init_kernel(ctx:dict[UOp, UOp], u:UOp): return u.buf_uop.assign(UOp(Ops.KERNEL, src=u.src, arg=Kernel(ctx[u.buf_uop].sink(), ())))
def init_kernel(ctx:dict[UOp, UOp], u:UOp): return u.buf_uop.assign(UOp(Ops.KERNEL, src=u.src, arg=Kernel(ctx[u.buf_uop].sink())))
def is_kernel(u:UOp) -> bool:
  # True only for an ASSIGN whose second source is a KERNEL
  if u.op is not Ops.ASSIGN: return False
  return u.src[1].op is Ops.KERNEL
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.BUFFER}
@@ -502,7 +495,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
schedule: list[ScheduleItem] = []
while queue:
u = queue.popleft()
schedule.append(kernel_to_si(schedule_uop(u.src[1].arg.ast, ctx)))
schedule.append(schedule_uop(u.src[1].arg.ast, ctx))
for x in children.get(u, []):
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)

View File

@@ -500,7 +500,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return ret
def clone(self) -> UOp:
  # delegate to copy_to_device, targeting self.device and passing clone=True
  return self.copy_to_device(
    self.device,
    clone=True,
  )
@property
def metadata(self) -> tuple[Metadata, ...]|Metadata|None: return self.arg.metadata if self.op is Ops.KERNEL else all_metadata.get(self, None)
def metadata(self) -> Metadata|None: return all_metadata.get(self, None)
# *** uop movement ops ***