mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
11
test/external/process_replay/process_replay.py
vendored
11
test/external/process_replay/process_replay.py
vendored
@@ -2,11 +2,11 @@
|
||||
# compare kernels created by HEAD against master
|
||||
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings
|
||||
from typing import Callable, cast
|
||||
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
|
||||
from tinygrad.engine.schedule import create_schedule_with_vars
|
||||
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm, dedup
|
||||
from tinygrad.engine.grouper import get_becomes_map
|
||||
from tinygrad.codegen.kernel import Kernel, Opt
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.ops import UOp
|
||||
from tinygrad.ops import UOp, Ops
|
||||
|
||||
# *** process replay settings
|
||||
|
||||
@@ -34,8 +34,9 @@ class ProcessReplayWarning(Warning): pass
|
||||
# *** recreators
|
||||
|
||||
def recreate_sched(big_sink:UOp) -> list[UOp]:
|
||||
sched, _, __ = create_schedule_with_vars(big_sink)
|
||||
return [x.ast for x in sched]
|
||||
sched_sink = get_becomes_map(big_sink)[0][big_sink]
|
||||
return dedup(u.src[1].arg.ast for u in sched_sink.toposort if u.op is Ops.ASSIGN)
|
||||
|
||||
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str, _) -> str:
|
||||
k = Kernel(ast, opts=opts)
|
||||
for opt in applied_opts: k.apply_opt(opt)
|
||||
|
||||
@@ -4,8 +4,8 @@ from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, grap
|
||||
from tinygrad.ops import can_pad, sint, track_rewrites
|
||||
from tinygrad.codegen.lowerer import get_contraction
|
||||
from tinygrad.codegen.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, flatten, getenv, pluralize
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, flatten, getenv, pluralize, ContextVar, Context, diskcache_put
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP, CAPTURE_PROCESS_REPLAY
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
@@ -439,6 +439,13 @@ def get_name(ret:tuple[dict[UOp, UOp], dict[Variable, int]]) -> str:
|
||||
kcount = len({u.src[1] for u in ret[0].values() if u.op is Ops.ASSIGN})
|
||||
return f"Schedule {pluralize('Kernel', kcount)}"+(f" (with_{pluralize('Var', len(ret[1]))})" if ret[1] else "")
|
||||
|
||||
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
|
||||
if CAPTURE_PROCESS_REPLAY:
|
||||
import atexit
|
||||
@atexit.register
|
||||
def save_process_replay():
|
||||
for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
|
||||
|
||||
@track_rewrites(name_fxn=get_name)
|
||||
def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
|
||||
# merge_views + simplify
|
||||
@@ -487,4 +494,12 @@ def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
|
||||
var_vals: dict[Variable, int] = {}
|
||||
sched_sink = graph_rewrite(sched_sink, create_ast, ctx=var_vals, bottom_up=True)
|
||||
becomes_map[big_sink] = sched_sink
|
||||
|
||||
# capture process replay
|
||||
if CAPTURE_PROCESS_REPLAY:
|
||||
with Context(PICKLE_BUFFERS=0):
|
||||
import pickle
|
||||
asts = dedup(u.arg.ast for u in sched_sink.toposort if u.op is Ops.KERNEL)
|
||||
PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, asts))
|
||||
|
||||
return becomes_map, var_vals
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import atexit, pickle
|
||||
from dataclasses import dataclass
|
||||
from collections import deque
|
||||
from tinygrad.ops import UOp, Variable, Ops, buffers
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.helpers import Metadata, CAPTURE_PROCESS_REPLAY, DEBUG, Context, ContextVar, diskcache_put, unwrap
|
||||
from tinygrad.helpers import Metadata, DEBUG, unwrap
|
||||
from tinygrad.engine.grouper import get_becomes_map
|
||||
|
||||
# **** ScheduleItem return type
|
||||
@@ -14,12 +13,6 @@ class ScheduleItem:
|
||||
bufs: tuple[Buffer, ...]
|
||||
metadata: tuple[Metadata, ...] = ()
|
||||
|
||||
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
|
||||
if CAPTURE_PROCESS_REPLAY:
|
||||
@atexit.register
|
||||
def save_process_replay():
|
||||
for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
|
||||
|
||||
# **** schedule linearizer
|
||||
|
||||
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
|
||||
@@ -53,10 +46,6 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
if len(schedule) != len(in_degree): raise RuntimeError(f"created {len(in_degree)} kernels but only scheduled {len(schedule)}")
|
||||
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
|
||||
|
||||
# capture process replay
|
||||
if CAPTURE_PROCESS_REPLAY:
|
||||
with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, [x.ast for x in schedule]))
|
||||
|
||||
# map ASSIGN to BUFFER after ScheduleItems are constructed
|
||||
for k,v in becomes_map.items():
|
||||
if v.base.op is Ops.ASSIGN:
|
||||
|
||||
Reference in New Issue
Block a user