var_vals prereq for deleting LBScheduleItem [run_process_replay] (#6511)

This commit is contained in:
qazal
2024-09-14 17:00:30 +08:00
committed by GitHub
parent 9188245677
commit 4ffb722d4e
4 changed files with 22 additions and 20 deletions

View File

@@ -18,7 +18,7 @@ def fuzz_schedule(outs:List[LazyBuffer]):
for combination in itertools.product(*ctx_vars.values()):
for var, val in zip(ctx_vars, combination): var.value = val
ctx_var_values = dict(zip([v.key for v in ctx_vars], combination))
graph, in_degree = _graph_schedule(outs)
graph, in_degree, _ = _graph_schedule(outs)
for ts in find_all_toposorts(graph, in_degree): unique_ts[ts] = ctx_var_values
toposorts = list(unique_ts.items())
if DEBUG >= 1: print(colored(f"fuzzing {len(toposorts)} schedule permutations", "yellow"))

View File

@@ -17,7 +17,7 @@ def process_replay(outs:List[LazyBuffer], graph:DefaultDict[LBScheduleItem, List
if not os.path.isfile(fp):
shutil.copyfile(fetch(f"https://raw.githubusercontent.com/tinygrad/tinygrad/{ref_schedule}/tinygrad/engine/schedule.py", allow_caching=False), fp)
# create the reference graph
ref_graph, ref_in_degree = importlib.import_module("test.external.process_replay.master_schedule")._graph_schedule(outs, set())
ref_graph, ref_in_degree = importlib.import_module("test.external.process_replay.master_schedule")._graph_schedule(outs)
# compare
diff_schedule([(ref_graph, ref_in_degree), (graph, in_degree)])

View File

@@ -18,8 +18,8 @@ class TestDiffSchedule(unittest.TestCase):
X = Tensor.randn(10, 10).realize()
idxs = Tensor([0, 2]).realize()
xt = cast(LazyBuffer, X[idxs].lazydata)
with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([xt])
with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([xt])
with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree, _ = _graph_schedule([xt])
with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree, _ = _graph_schedule([xt])
# 1 arange LazyBuffer folds, 1 arange child's kernel changes
changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
self.assertEqual(changed, 1)
@@ -30,15 +30,15 @@ class TestDiffSchedule(unittest.TestCase):
for _ in range(2):
X = Tensor.randn(10, 10).realize()
xt = cast(LazyBuffer, X[idxs].lazydata)
with Context(FUSE_ARANGE=0): schedules.append(_graph_schedule([xt]))
with Context(FUSE_ARANGE=1): schedules.append(_graph_schedule([xt]))
with Context(FUSE_ARANGE=0): schedules.append(_graph_schedule([xt])[:-1])
with Context(FUSE_ARANGE=1): schedules.append(_graph_schedule([xt])[:-1])
changed = diff_schedule(schedules)
self.assertEqual(changed, 1)
def test_no_diff(self):
a = cast(LazyBuffer, (Tensor([1])+Tensor([2])).lazydata)
with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([a])
with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([a])
with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree, _ = _graph_schedule([a])
with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree, _ = _graph_schedule([a])
changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
self.assertEqual(changed, 0)
@@ -49,8 +49,8 @@ class TestDiffSchedule(unittest.TestCase):
c1(img).relu().mean().backward()
assert img.grad is not None and c1.weight.grad is not None
outs = [cast(LazyBuffer, img.grad.lazydata), cast(LazyBuffer, c1.weight.grad.lazydata)]
with Context(FUSE_CONV_BW=0): ref_graph, ref_in_degree = _graph_schedule(outs)
with Context(FUSE_CONV_BW=1): compare_graph, compare_in_degree = _graph_schedule(outs)
with Context(FUSE_CONV_BW=0): ref_graph, ref_in_degree, _ = _graph_schedule(outs)
with Context(FUSE_CONV_BW=1): compare_graph, compare_in_degree, _ = _graph_schedule(outs)
changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
# 1 reduceop folds, its child reduceop changes
self.assertEqual(changed, 1)

View File

@@ -38,7 +38,6 @@ class LBScheduleItem:
ast: UOp
outputs: List[LazyBuffer]
inputs: List[LazyBuffer]
var_vals: Dict[Variable, int] = field(default_factory=dict)
metadata: List[Metadata] = field(default_factory=list)
def __hash__(self):
"""The unique identifier of a schedule item in the toposort."""
@@ -159,10 +158,10 @@ reduceop_fusor = PatternMatcher([
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
"""describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
if (out:=outs[0]).op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])
return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs]), {}
# create the stores
var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
@@ -185,7 +184,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
from tinygrad.engine.graph import graph_uop
graph_uop(sink)
raise e
return LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))
return LBScheduleItem(sink, outs, list(inputs), dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])), var_vals
# *** DAG creation: decide which LazyBuffers should realize ***
@@ -360,11 +359,16 @@ def _get_output_groups(outs:List[LazyBuffer]) -> \
SCHEDULES: List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]] = []
def _graph_schedule(outs:List[LazyBuffer]) -> \
Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], # this is the graph
DefaultDict[LBScheduleItem, int]]: # this is the in-degree of the graph
DefaultDict[LBScheduleItem, int], # this is the in-degree of the graph
Dict[Variable, int]]: # this has all the var values of the schedule
"""create a graph for realizing the outputs"""
output_groups, realizes, assign_targets = _get_output_groups(outs)
# preschedule all buffers in realizes
prescheduled = [_lower_lazybuffer(group, realizes) for group in output_groups.values()]
prescheduled: List[LBScheduleItem] = []
var_vals: Dict[Variable, int] = {}
for group in output_groups.values():
prescheduled.append((ret:=_lower_lazybuffer(group, realizes))[0])
var_vals = merge_dicts([var_vals, ret[1]])
schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list)
@@ -388,26 +392,24 @@ def _graph_schedule(outs:List[LazyBuffer]) -> \
with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
if len(SCHEDULES) == 0: atexit.register(_save)
SCHEDULES.append((graph, in_degree))
return graph, in_degree
return graph, in_degree, var_vals
# *** DAG ordering: breadth first search ***
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
graph, in_degree = _graph_schedule(outs)
graph, in_degree, var_vals = _graph_schedule(outs)
if getenv("RUN_PROCESS_REPLAY") and getenv("COMPARE_SCHEDULE", 1):
# NOTE: process relpay needs PYTHONPATH=., remove this once it just pickles LazyBuffers
with contextlib.suppress(Exception): importlib.import_module("test.external.process_replay.diff_schedule").process_replay(outs, graph, in_degree)
queue = deque(lsi for lsi,deg in in_degree.items() if deg == 0)
schedule: List[ScheduleItem] = []
var_vals: Dict[Variable, int] = {}
kernel_number = GlobalCounters.kernel_count
while queue:
lsi = queue.popleft()
if GRAPH:
kernel_number += 1
for out in lsi.outputs: realized_lazybuffer(out, kernel_number)
var_vals = merge_dicts([var_vals, lsi.var_vals])
for out in lsi.outputs: del out.srcs # can only schedule once
schedule.append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata))
for x in graph[lsi]: