order copies

This commit is contained in:
George Hotz
2025-04-03 14:43:02 +08:00
parent 2006afabf3
commit 3f62c8693b

View File

@@ -1,4 +1,5 @@
import sys, atexit, pickle
import sys, atexit, pickle, heapq
from typing import cast, Any
from collections import defaultdict, deque
from dataclasses import dataclass
from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
@@ -462,17 +463,31 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
children.setdefault(s, []).append(u)
in_degree[u] += 1
queue = deque(k for k,v in in_degree.items() if v == 0)
CPU_DEVICE = {"CPU", "LLVM", "PYTHON", "NPY"}
queue:list[tuple[tuple[str, ...], int, UOp]] = []
ii = 0
# Queue a ready ASSIGN UOp on the `queue` heap, prioritised by the devices of its sources.
# Entries are (priority, ii, u) tuples; heapq is a min-heap, so the lowest priority pops first
# (-10 before 0 before 10).
def push(u:UOp):
# ii is the enclosing scope's insertion counter, bumped once per push
nonlocal ii
assert u.op is Ops.ASSIGN
# base device name (text before the first ":") of every source of u.src[1], passed through dedup
# NOTE(review): the [0]/[-1] checks below presumably rely on dedup keeping first-seen order -- confirm
# NOTE(review): devices[0] raises IndexError if u.src[1].src is empty -- confirm that cannot happen
devices = tuple(dedup([cast(str, x.device).split(":")[0] for x in u.src[1].src]))
priority = 0
# the checks run in order: when both the first and the last device are in CPU_DEVICE
# (e.g. a single CPU device), the second assignment wins and priority ends up -10
if devices[0] in CPU_DEVICE: priority = 10 # copy out
if devices[-1] in CPU_DEVICE: priority = -10 # copy in
# queue is annotated with a tuple[str, ...] first element but an int priority is pushed here;
# cast(Any, ...) sidesteps that annotation mismatch
heapq.heappush(cast(Any, queue), (priority, ii, u))
# ii is unique per entry, so equal priorities pop in insertion order and tuple comparison
# never falls through to comparing the UOps themselves
ii += 1
for k,v in in_degree.items():
if v == 0: push(k)
schedule: list[ScheduleItem] = []
var_vals: dict[Variable, int] = {}
while queue:
u = queue.popleft()
_,_,u = heapq.heappop(queue)
# TODO: move this to create_kernels
k = fix_kernel_ast(u.src[1], var_vals)
schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
for x in children.get(u, []):
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)
if in_degree[x] == 0: push(x)
# confirm everything was scheduled correctly
if len(schedule) != (kc:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")