mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
order copies
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import sys, atexit, pickle
|
||||
import sys, atexit, pickle, heapq
|
||||
from typing import cast, Any
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
|
||||
@@ -462,17 +463,31 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
children.setdefault(s, []).append(u)
|
||||
in_degree[u] += 1
|
||||
|
||||
queue = deque(k for k,v in in_degree.items() if v == 0)
|
||||
CPU_DEVICE = {"CPU", "LLVM", "PYTHON", "NPY"}
|
||||
queue:list[tuple[tuple[str, ...], int, UOp]] = []
|
||||
ii = 0
|
||||
def push(u:UOp):
|
||||
nonlocal ii
|
||||
assert u.op is Ops.ASSIGN
|
||||
devices = tuple(dedup([cast(str, x.device).split(":")[0] for x in u.src[1].src]))
|
||||
priority = 0
|
||||
if devices[0] in CPU_DEVICE: priority = 10 # copy out
|
||||
if devices[-1] in CPU_DEVICE: priority = -10 # copy in
|
||||
heapq.heappush(cast(Any, queue), (priority, ii, u))
|
||||
ii += 1
|
||||
for k,v in in_degree.items():
|
||||
if v == 0: push(k)
|
||||
|
||||
schedule: list[ScheduleItem] = []
|
||||
var_vals: dict[Variable, int] = {}
|
||||
while queue:
|
||||
u = queue.popleft()
|
||||
_,_,u = heapq.heappop(queue)
|
||||
# TODO: move this to create_kernels
|
||||
k = fix_kernel_ast(u.src[1], var_vals)
|
||||
schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
|
||||
for x in children.get(u, []):
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
if in_degree[x] == 0: push(x)
|
||||
|
||||
# confirm everything was scheduled correctly
|
||||
if len(schedule) != (kc:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")
|
||||
|
||||
Reference in New Issue
Block a user