From 3f62c8693b2ce99e0daae9d7d87f6a35d113da5f Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Thu, 3 Apr 2025 14:43:02 +0800
Subject: [PATCH] order copies

Replace the FIFO deque used by the kernel toposort with a heapq-based
priority queue, so ready kernels are popped by priority instead of by
insertion order:

  * -10 ("copy in")  when the last deduped device of the kernel sources is
    a CPU-type device; heapq pops the smallest entry, so these run first
  * +10 ("copy out") when the first deduped device is a CPU-type device;
    these are popped last
  *   0 otherwise

When both checks match (e.g. every source is on the same CPU-type device)
the second check wins and the priority is -10. Ties are broken by a
monotonically increasing counter, which keeps FIFO order within a priority
level and guarantees heapq never has to compare two UOps.

---
 tinygrad/engine/schedule.py | 23 +++++++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index f0c47f0e41..6c58c025d2 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -1,4 +1,5 @@
-import sys, atexit, pickle
+import sys, atexit, pickle, heapq
+from typing import cast
 from collections import defaultdict, deque
 from dataclasses import dataclass
 from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
@@ -462,17 +463,31 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
       children.setdefault(s, []).append(u)
       in_degree[u] += 1
 
-  queue = deque(k for k,v in in_degree.items() if v == 0)
+  CPU_DEVICE = {"CPU", "LLVM", "PYTHON", "NPY"}
+  queue: list[tuple[int, int, UOp]] = []  # heap of (priority, insertion index, kernel ASSIGN)
+  ii = 0
+  def push(u:UOp):
+    nonlocal ii
+    assert u.op is Ops.ASSIGN
+    devices = tuple(dedup([cast(str, x.device).split(":")[0] for x in u.src[1].src]))
+    priority = 0
+    if devices[0] in CPU_DEVICE: priority = 10 # copy out (popped last)
+    if devices[-1] in CPU_DEVICE: priority = -10 # copy in (popped first, wins if both match)
+    heapq.heappush(queue, (priority, ii, u))
+    ii += 1
+  for k,v in in_degree.items():
+    if v == 0: push(k)
+
   schedule: list[ScheduleItem] = []
   var_vals: dict[Variable, int] = {}
   while queue:
-    u = queue.popleft()
+    _,_,u = heapq.heappop(queue)
     # TODO: move this to create_kernels
     k = fix_kernel_ast(u.src[1], var_vals)
     schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
     for x in children.get(u, []):
       in_degree[x] -= 1
-      if in_degree[x] == 0: queue.append(x)
+      if in_degree[x] == 0: push(x)
 
   # confirm everything was scheduled correctly
   if len(schedule) != (kc:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")