move process replay to grouper (#9830)

* simpler

* sched
This commit is contained in:
qazal
2025-04-10 18:27:42 +08:00
committed by GitHub
parent c8f47c1d07
commit 16afe04f45
3 changed files with 24 additions and 19 deletions

View File

@@ -2,11 +2,11 @@
# compare kernels created by HEAD against master
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings
from typing import Callable, cast
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
from tinygrad.engine.schedule import create_schedule_with_vars
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm, dedup
from tinygrad.engine.grouper import get_becomes_map
from tinygrad.codegen.kernel import Kernel, Opt
from tinygrad.renderer import Renderer
from tinygrad.ops import UOp
from tinygrad.ops import UOp, Ops
# *** process replay settings
@@ -34,8 +34,9 @@ class ProcessReplayWarning(Warning): pass
# *** recreators
# Recreate the list of ASTs for a captured `big_sink` (this script compares kernels created by HEAD
# against master, per the header comment of this file).
# NOTE(review): this diff view carries no +/- markers and no indentation. The two statement pairs
# below look like the pre-change body followed by its replacement, which would match the import swap
# from create_schedule_with_vars to get_becomes_map at the top of this file -- confirm against the raw diff.
def recreate_sched(big_sink:UOp) -> list[UOp]:
# variant 1: build the schedule items and read .ast off each one
sched, _, __ = create_schedule_with_vars(big_sink)
return [x.ast for x in sched]
# variant 2: look up big_sink in the first value returned by get_becomes_map, then collect the
# deduplicated .src[1].arg.ast of every ASSIGN found in that sink's toposort
sched_sink = get_becomes_map(big_sink)[0][big_sink]
return dedup(u.src[1].arg.ast for u in sched_sink.toposort if u.op is Ops.ASSIGN)
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str, _) -> str:
k = Kernel(ast, opts=opts)
for opt in applied_opts: k.apply_opt(opt)

View File

@@ -4,8 +4,8 @@ from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, grap
from tinygrad.ops import can_pad, sint, track_rewrites
from tinygrad.codegen.lowerer import get_contraction
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, flatten, getenv, pluralize
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, flatten, getenv, pluralize, ContextVar, Context, diskcache_put
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP, CAPTURE_PROCESS_REPLAY
from tinygrad.dtype import ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape
@@ -439,6 +439,13 @@ def get_name(ret:tuple[dict[UOp, UOp], dict[Variable, int]]) -> str:
kcount = len({u.src[1] for u in ret[0].values() if u.op is Ops.ASSIGN})
return f"Schedule {pluralize('Kernel', kcount)}"+(f" (with_{pluralize('Var', len(ret[1]))})" if ret[1] else "")
# capture key -> already-pickled payload; entries are added by get_becomes_map when capturing is on
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
  import atexit
  def save_process_replay():
    """Write every captured payload to the "schedule_process_replay" disk-cache table."""
    for key, blob in PROCESS_REPLAY_CAPTURE.items():
      # payloads were pickled at capture time, so tell diskcache_put not to pickle them again
      diskcache_put("schedule_process_replay", key, blob, prepickled=True)
  # run the flush once, when the interpreter exits
  atexit.register(save_process_replay)
@track_rewrites(name_fxn=get_name)
def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
# merge_views + simplify
@@ -487,4 +494,12 @@ def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
var_vals: dict[Variable, int] = {}
sched_sink = graph_rewrite(sched_sink, create_ast, ctx=var_vals, bottom_up=True)
becomes_map[big_sink] = sched_sink
# capture process replay
if CAPTURE_PROCESS_REPLAY:
with Context(PICKLE_BUFFERS=0):
import pickle
asts = dedup(u.arg.ast for u in sched_sink.toposort if u.op is Ops.KERNEL)
PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, asts))
return becomes_map, var_vals

View File

@@ -1,9 +1,8 @@
import atexit, pickle
from dataclasses import dataclass
from collections import deque
from tinygrad.ops import UOp, Variable, Ops, buffers
from tinygrad.device import Buffer
from tinygrad.helpers import Metadata, CAPTURE_PROCESS_REPLAY, DEBUG, Context, ContextVar, diskcache_put, unwrap
from tinygrad.helpers import Metadata, DEBUG, unwrap
from tinygrad.engine.grouper import get_becomes_map
# **** ScheduleItem return type
@@ -14,12 +13,6 @@ class ScheduleItem:
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...] = ()
# NOTE(review): the same block also appears in the grouper hunk of this commit, and the new helpers
# import line of this file drops CAPTURE_PROCESS_REPLAY and diskcache_put, so these lines look like
# the removed side of the move -- confirm against the raw diff.
# Capture store: str key -> pickled bytes, filled under CAPTURE_PROCESS_REPLAY by create_schedule_with_vars.
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
# registered with atexit, so the captured entries are written out when the interpreter exits
@atexit.register
def save_process_replay():
# values are already pickled bytes, hence prepickled=True
for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
# **** schedule linearizer
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
@@ -53,10 +46,6 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
if len(schedule) != len(in_degree): raise RuntimeError(f"created {len(in_degree)} kernels but only scheduled {len(schedule)}")
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
# capture process replay
if CAPTURE_PROCESS_REPLAY:
with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, [x.ast for x in schedule]))
# map ASSIGN to BUFFER after ScheduleItems are constructed
for k,v in becomes_map.items():
if v.base.op is Ops.ASSIGN: