mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
kernel op cleanups + use ScheduleItem [pr] (#9009)
This commit is contained in:
@@ -249,7 +249,7 @@ break_sched = PatternMatcher([
|
||||
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
|
||||
])
|
||||
|
||||
# **** convert Kernel to a ScheduleItem (for legacy reasons)
|
||||
# **** ScheduleItem creation (TODO: replace ScheduleItem with the KERNEL UOp)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleItem:
|
||||
@@ -267,18 +267,6 @@ class ScheduleItem:
|
||||
@functools.cached_property
|
||||
def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
|
||||
|
||||
def kernel_to_si(k:UOp) -> ScheduleItem:
|
||||
assert k.op is Ops.KERNEL and isinstance(k.metadata, tuple), f"must be KERNEL {k}"
|
||||
return ScheduleItem(k.arg.ast, tuple(u.buf_uop.buffer for u in k.src), k.metadata)
|
||||
|
||||
# **** Kernel creation
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Kernel:
|
||||
ast: UOp
|
||||
metadata: tuple[Metadata, ...]
|
||||
def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op} {self.metadata}>"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class KernelContext:
|
||||
var_vals: dict[Variable, int]
|
||||
@@ -380,14 +368,14 @@ def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
|
||||
return var
|
||||
unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
|
||||
|
||||
def schedule_uop(pre:UOp, ctx:ScheduleContext) -> UOp:
|
||||
def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
|
||||
# unbind_vars + push views to edges
|
||||
sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=ctx.var_vals), view_right)
|
||||
# remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL
|
||||
ast = graph_rewrite(sink, to_si, si_ctx:=KernelContext(ctx.var_vals))
|
||||
# NOTE: we only add the metadata for fused tensors
|
||||
metadata = tuple(dedup(m for x in pre.toposort if x.op is not Ops.BUFFER and (m:=ctx.ops_metadata.get(x)) is not None))
|
||||
return UOp(Ops.KERNEL, src=tuple(si_ctx.bufs), arg=Kernel(ast, metadata))
|
||||
return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs), metadata)
|
||||
|
||||
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
|
||||
if CAPTURE_PROCESS_REPLAY:
|
||||
@@ -395,8 +383,13 @@ if CAPTURE_PROCESS_REPLAY:
|
||||
def save_process_replay():
|
||||
for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Kernel:
|
||||
ast: UOp
|
||||
def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op}>"
|
||||
|
||||
# NOTE: realizes become ASSIGN(BUFFER, KERNEL) in the schedule graph
|
||||
def init_kernel(ctx:dict[UOp, UOp], u:UOp): return u.buf_uop.assign(UOp(Ops.KERNEL, src=u.src, arg=Kernel(ctx[u.buf_uop].sink(), ())))
|
||||
def init_kernel(ctx:dict[UOp, UOp], u:UOp): return u.buf_uop.assign(UOp(Ops.KERNEL, src=u.src, arg=Kernel(ctx[u.buf_uop].sink())))
|
||||
def is_kernel(u:UOp) -> bool: return u.op is Ops.ASSIGN and u.src[1].op is Ops.KERNEL
|
||||
|
||||
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.BUFFER}
|
||||
@@ -502,7 +495,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
schedule: list[ScheduleItem] = []
|
||||
while queue:
|
||||
u = queue.popleft()
|
||||
schedule.append(kernel_to_si(schedule_uop(u.src[1].arg.ast, ctx)))
|
||||
schedule.append(schedule_uop(u.src[1].arg.ast, ctx))
|
||||
for x in children.get(u, []):
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
||||
@@ -500,7 +500,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return ret
|
||||
def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
|
||||
@property
|
||||
def metadata(self) -> tuple[Metadata, ...]|Metadata|None: return self.arg.metadata if self.op is Ops.KERNEL else all_metadata.get(self, None)
|
||||
def metadata(self) -> Metadata|None: return all_metadata.get(self, None)
|
||||
|
||||
# *** uop movement ops ***
|
||||
|
||||
|
||||
Reference in New Issue
Block a user