mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-14 00:45:16 +08:00
146 lines
8.8 KiB
Python
146 lines
8.8 KiB
Python
from __future__ import annotations
|
|
import time
|
|
from typing import cast
|
|
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, MultiBuffer
|
|
from tinygrad.dtype import dtypes
|
|
from tinygrad.engine.jit import GraphRunner
|
|
from tinygrad.engine.realize import get_call_outs_ins, get_runtime
|
|
from tinygrad.helpers import round_up, ceildiv
|
|
from tinygrad.runtime.support.memory import BumpAllocator
|
|
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, graph_rewrite
|
|
from extra.hcq2.hcq2 import HCQ2Compiled, HCQ2DeviceCtx, HCQ2LowerCtx, pm_prep_runtime, pm_lower_ops
|
|
from extra.hcq2.hcq2 import pm_split_into_queues, pm_add_barriers, pm_add_signals
|
|
from extra.hcq2.hcq2 import pm_bufferize, pm_lift_patches_to_cmdbuf, pm_resolve_patches, pm_parametrize_host_buffers
|
|
from extra.hcq2.hcq2 import pm_add_timeline_inc, pm_callify, pm_calc_kernargs_sizes
|
|
|
|
# **************** insert deps ****************
|
|
|
|
def insert_deps(ctx:HCQ2Graph, linear:UOp) -> UOp:
  """
  Wrap every call of `linear` in an AFTER node whose extra sources are the calls it depends on.

  Each call is re-tagged with its position in `linear.src`; later passes use that tag to identify the call.
  Dependencies come from GraphRunner._access_resources, with the call's output buffers (listed first) marked as written.
  Calls are processed strictly in order because _access_resources updates the runner's resource tracking state.
  """
  def wrap(idx:int, raw_call:UOp) -> UOp:
    tagged = raw_call.replace(tag=idx)
    _, _, bufs, _ = ctx.calls[idx]
    outs, ins = get_call_outs_ins(tagged)
    touched = [bufs[k] for k in outs + ins]
    written = list(range(len(outs)))  # outputs occupy the first len(outs) slots of `touched`
    deps = ctx._access_resources(touched, written, tagged)
    return UOp(Ops.AFTER, tagged.dtype, (tagged, *deps), tag=tagged.tag)
  return linear.replace(src=tuple(wrap(j, c) for j, c in enumerate(linear.src)))

pm_insert_deps = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), insert_deps)])
|
|
|
|
def _param_to_input_addr(ctx:HCQ2Graph, p:UOp) -> UOp:
  # PARAM n -> slot n of the host-side input address table (filled per launch by HCQ2Graph.__call__)
  return ctx.input_addrs_uop.index(UOp.const(dtypes.int, p.arg))

def _slice_to_byte_addr(ctx:HCQ2Graph, bv:UOp, addr:UOp, off:UOp) -> UOp:
  # SLICE(addr, const element offset) -> addr + offset in bytes, scaled by the itemsize of the input that
  # address slot belongs to (addr.src[1] is the slot index). `bv` is the matched SLICE itself; unused.
  base = addr.cast(dtypes.uint64)
  nbytes = off.arg * ctx.input_uops[addr.src[1].arg].dtype.itemsize
  return base + UOp.const(dtypes.uint64, nbytes)

# Replace graph parameters with loads from the input address table and fold constant slice offsets into byte addresses.
pm_replace_params = PatternMatcher([
  (UPat(Ops.PARAM, name="p"), _param_to_input_addr),
  (UPat(Ops.SLICE, src=(UPat(Ops.INDEX, name="addr"), UPat(Ops.CONST, dtype=dtypes.weakint, name="off")), name="bv"), _slice_to_byte_addr),
])
|
|
|
|
# **************** graph-only passes ****************
|
|
|
|
def alloc_queue_sig(ctx:HCQ2Graph, q:UOp) -> None:
  """
  Allocate a signal buffer for queue `q` the first time its arg is seen.

  Purely a side effect: a 0x100-byte host/uncached/cpu-accessible buffer is appended to ctx.queue_sig_bufs and its UOp
  is recorded in ctx.queue_sigs[q.arg]. Always returns None, so the rewrite leaves the graph unchanged.
  """
  if q.arg not in ctx.queue_sigs:
    # NOTE(review): q.arg[0][0] implies q.arg[0] is a sequence of devices, so the queue_sigs key is probably
    # (devices, kind) rather than the annotated tuple[str, str] -- confirm. Only the first device gets a signal.
    dev = q.arg[0][0]  # TODO: multi device
    sig_spec = BufferSpec(host=True, uncached=True, cpu_access=True)
    sig_buf = Buffer(dev, 0x100, dtypes.uint8, options=sig_spec, preallocate=True)
    ctx.queue_sig_bufs.append(sig_buf)
    ctx.queue_sigs[q.arg] = UOp.from_buffer(sig_buf, dev)
  return None

pm_alloc_queue_sigs = PatternMatcher([(UPat(Ops.LINEAR, src=UPat({Ops.PROGRAM, Ops.COPY}), name="q"), alloc_queue_sig)])
|
|
|
|
def lower_queue_deps(ctx:HCQ2Graph, after:UOp) -> UOp:
  """
  Lower AFTER(call, *deps) into explicit queue-signal waits and stores.

  Every queue of the lowered call first waits for each dependency's queue signal, then runs its commands, and finally
  stores this call's value into its own queue signal.

  Signal values are `call index + 1`, never the raw index. add_queue_sig_resets zeroes every queue signal, and WAITs
  are treated as "signal >= value" (optimize_queue_deps relies on that). Publishing the raw index would make a
  dependency on call 0 a wait for 0, which the reset already satisfies, so it would enforce no ordering at all.

  The third WAIT source is the exact STORE that satisfies it, which lets drop_dead_stores keep only stores somebody
  waits on. Both the wait value and the store are built by the same helpers, so they always agree.
  """
  wrapper, deps, call_idx = after.src[0], after.src[1:], after.tag
  def sig_value(idx:int) -> UOp: return UOp.const(dtypes.uint32, idx + 1)
  def store(q_arg, idx:int) -> UOp: return ctx.queue_sigs[q_arg].store(sig_value(idx))
  # NOTE(review): only the first queue of each dependency (dep.src[0]) is waited on -- fine for single-device graphs
  waits = tuple(UOp(Ops.WAIT, dtypes.void, (ctx.queue_sigs[dep.src[0].arg], sig_value(dep.tag), store(dep.src[0].arg, dep.tag)))
                for dep in deps)
  return wrapper.replace(src=tuple(q.replace(src=(*waits, *q.src, store(q.arg, call_idx))) for q in wrapper.src))

pm_lower_queue_deps = PatternMatcher([(UPat(Ops.AFTER, src=UPat(Ops.LINEAR), name="after"), lower_queue_deps)])
|
|
|
|
def optimize_queue_deps(ctx:HCQ2Graph, queue:UOp) -> UOp|None:
  """
  Minimize the WAITs of one queue.

  - waits on the queue's own signal are dropped;
  - waits already covered by an emitted wait on the same signal with a value >= theirs are dropped;
  - consecutive waits on one signal collapse into the one with the largest value, emitted (in first-seen signal order)
    right before the next non-WAIT entry; waits left at the end of the queue are appended last.

  Returns None when the queue is unchanged.
  """
  own_sig = ctx.queue_sigs[queue.arg]
  out:list[UOp] = []
  emitted:dict[UOp, int] = {}   # signal -> highest value already waited on in `out`
  deferred:dict[UOp, UOp] = {}  # signal -> strongest WAIT not yet emitted
  for op in queue.src:
    if op.op is not Ops.WAIT:
      for w in deferred.values(): emitted[w.src[0]] = w.src[1].arg
      out.extend(deferred.values())
      deferred.clear()
      out.append(op)
      continue
    sig, target = op.src[0], op.src[1].arg
    if sig is own_sig or emitted.get(sig, -1) >= target: continue
    prev = deferred.get(sig)
    if prev is None or prev.src[1].arg < target: deferred[sig] = op
  out.extend(deferred.values())
  new_src = tuple(out)
  return None if new_src == queue.src else queue.replace(src=new_src)

pm_optimize_queue_deps = PatternMatcher([
  (UPat(Ops.LINEAR, src=UPat({Ops.BARRIER, Ops.WAIT, Ops.STORE, Ops.PROGRAM, Ops.COPY}), name="queue"), optimize_queue_deps),
])
|
|
|
|
def drop_dead_stores(ctx:HCQ2Graph, outer:UOp) -> UOp:
  """
  Remove STOREs from the child LINEARs of `outer` that no WAIT anywhere under `outer` refers to.

  A STORE is live when it is the third source of some WAIT; every non-STORE entry is kept as-is.
  """
  live:set[UOp] = set()
  for u in outer.toposort():
    if u.op is Ops.WAIT: live.add(u.src[2])
  def keep(x:UOp) -> bool: return x.op is not Ops.STORE or x in live
  pruned = [q.replace(src=tuple(filter(keep, q.src))) for q in outer.src]
  return outer.replace(src=tuple(pruned))

pm_drop_dead_stores = PatternMatcher([(UPat(Ops.LINEAR, src=UPat(Ops.LINEAR), name="outer"), drop_dead_stores)])
|
|
|
|
def add_queue_sig_resets(ctx:HCQ2Graph, x:UOp, cmdbuf:UOp) -> UOp|None:
  """
  Append stores that zero the first 8 bytes of every queue-signal buffer to a compute/copy command buffer's AFTER.

  Returns None (no rewrite) when there are no queue signals or the cmdbuf is not tagged "compute"/"copy".
  NOTE(review): every matching cmdbuf AFTER receives the full set of resets; confirm duplicates are intended.
  """
  if not ctx.queue_sig_bufs or cmdbuf.tag not in ("compute", "copy"): return None
  def zero_signal(sig_buf:Buffer) -> UOp:
    base = UOp.from_buffer(sig_buf)
    slot0 = base.index(UOp.const(dtypes.int, 0), dtype=base.dtype.ptr())
    return slot0.cast(dtypes.uint64.ptr()).store(UOp.const(dtypes.uint64, 0))
  return x.replace(src=(*x.src, *map(zero_signal, ctx.queue_sig_bufs)))

pm_add_queue_sig_resets = PatternMatcher([
  (UPat(Ops.AFTER, src=(UPat(Ops.BUFFER, name="cmdbuf"),), allow_any_len=True, name="x"), add_queue_sig_resets),
])
|
|
|
|
# **************** Graph ****************
|
|
|
|
class HCQ2Graph(GraphRunner):
  """
  GraphRunner that lowers a LINEAR of PROGRAM/COPY calls into HCQ2 command buffers once at construction time.

  Each __call__ writes the current input addresses into a CPU-side table and runs the CPU host program produced by the
  callify pass. Only single-device HCQ2 graphs are supported (see supports_uop).
  """
  def __init__(self, linear:UOp, input_uops:tuple[UOp, ...]=()):
    super().__init__(linear, input_uops)
    self.dev = cast(HCQ2Compiled, Device[self.device])
    self.hcq_ctx = HCQ2LowerCtx(name="hcq_graph")

    # one uint64 slot per graph input (at least one); __call__ writes each input's device address here every launch
    self.input_addrs = Buffer("CPU", max(len(input_uops), 1), dtypes.uint64, preallocate=True)
    self.input_addrs_uop = UOp.from_buffer(self.input_addrs, "CPU")

    # tag calls + attach dependencies, substitute params with address-table loads, then generic HCQ2 lowering
    self.linear = graph_rewrite(self.linear, pm_insert_deps, ctx=self, name="hcq: insert deps", walk=True)
    self.linear = graph_rewrite(self.linear, pm_replace_params, ctx=self, name="hcq: replace params", walk=True)
    self.linear = graph_rewrite(self.linear, pm_prep_runtime, ctx=self.hcq_ctx, name="hcq: prepare runtime")
    self.linear = graph_rewrite(self.linear, pm_lower_ops, ctx=self.hcq_ctx, name="hcq: lower ops")

    # per-queue signal state — populated as a side-effect by pm_alloc_queue_sigs walking the lowered linear.
    self.queue_sig_bufs:list[Buffer] = []
    self.queue_sigs:dict[tuple[str, str], UOp] = {}
    graph_rewrite(self.linear, pm_alloc_queue_sigs, ctx=self, name="hcq: alloc queue sigs", walk=True)

    # turn AFTER deps into signal waits/stores, split into queues, then prune redundant waits and unused stores
    self.linear = graph_rewrite(self.linear, pm_lower_queue_deps, ctx=self, name="hcq: lower queue deps")
    self.linear = graph_rewrite(self.linear, pm_split_into_queues, ctx=self.hcq_ctx, name="hcq: split into queues")
    self.linear = graph_rewrite(self.linear, pm_add_barriers, ctx=self.hcq_ctx, name="hcq: add barriers", walk=True)
    self.linear = graph_rewrite(self.linear, pm_optimize_queue_deps, ctx=self, name="hcq: optimize queue deps", walk=True)
    self.linear = graph_rewrite(self.linear, pm_drop_dead_stores, ctx=self, name="hcq: drop dead stores")
    self.linear = graph_rewrite(self.linear, pm_add_signals, ctx=self.hcq_ctx, name="hcq: add signals", walk=True)
    self.linear = graph_rewrite(self.linear, pm_add_timeline_inc, ctx=self.hcq_ctx, name="hcq: add submit", walk=True)
    self.linear = graph_rewrite(self.linear, self.dev.pm_lower, ctx=self.hcq_ctx, name=f"hcq: encode cmdbuf {self.dev.device}", walk=True)

    # size the kernarg buffer of every device from the encoded command buffers, then register it in the lower ctx
    graph_rewrite(self.linear, pm_calc_kernargs_sizes, ctx=(sizes:={}), name=None)
    for dev_name, sz in sizes.items():
      buf = Buffer(dev_name, sz, dtypes.uint8, options=BufferSpec(cpu_access=True), preallocate=True)
      self.hcq_ctx.dev_ctx[dev_name] = HCQ2DeviceCtx(dev_name, UOp.from_buffer(buf, dev_name), UOp.const(dtypes.uint64, buf._buf.va_addr))

    self.linear = graph_rewrite(self.linear, pm_bufferize, ctx=self.hcq_ctx, bottom_up=True, name="realize binaries")
    self.linear = graph_rewrite(self.linear, pm_lift_patches_to_cmdbuf, ctx=self.hcq_ctx, bottom_up=False, name="lift patches to cmdbuf")
    self.linear = graph_rewrite(self.linear, pm_resolve_patches, ctx=self.hcq_ctx, bottom_up=False, name="simplify patches")
    self.linear = graph_rewrite(self.linear, pm_add_queue_sig_resets, ctx=self, name="hcq: add queue sig resets", walk=True)
    self.linear = graph_rewrite(self.linear, pm_parametrize_host_buffers, ctx=self.hcq_ctx, bottom_up=True, name="parametrize host buffers")
    self.host_call = graph_rewrite(self.linear, pm_callify, ctx=self.hcq_ctx, name="hcq: callify")

    # CPU runtime for the host program plus the indices of the lower-ctx inputs it takes as global buffers
    self.host_rt, self.host_globals = get_runtime("CPU", self.host_call.src[0]), self.host_call.src[0].arg.globals

  def _device_buffer(self, u:UOp) -> Buffer:
    # Return the Buffer backing `u` on this graph's device; for a MultiBuffer, pick the shard on self.dev.device.
    # Raises RuntimeError (instead of leaking a bare StopIteration) when no shard lives on this device.
    if not isinstance(u.buffer, MultiBuffer): return u.buffer
    if (buf:=next((b for b in u.buffer.bufs if b.device == self.dev.device), None)) is None:
      raise RuntimeError(f"HCQ2Graph: input buffer has no shard on device {self.dev.device}")
    return buf

  def __call__(self, input_uops:tuple[UOp, ...], var_vals:dict[str, int], wait=False) -> float|None:
    """
    Patch the current input addresses into the address table and run the graph.

    Returns the wall-clock seconds from submission until the device is synchronized when `wait` is set, else None.
    Raises RuntimeError if a multi-device input has no shard on this graph's device.
    """
    addrs = self.input_addrs.as_memoryview(force_zero_copy=True).cast('Q')
    for i, u in enumerate(input_uops): addrs[i] = self._device_buffer(u)._buf.va_addr
    # start the clock before submission so the reported time covers the whole launch, not only the final wait
    st = time.perf_counter()
    self.host_rt(*[self.hcq_ctx.inputs[i].get_buf("CPU") for i in self.host_globals], vals=self.host_call.src[0].arg.vals(var_vals), wait=True)
    if not wait: return None
    self.dev.synchronize()
    return time.perf_counter() - st

  @staticmethod
  def supports_uop(batch_devs:list[Compiled], new_call:UOp) -> bool:
    # only PROGRAM/COPY calls, and the whole batch must stay on a single HCQ2 device
    all_devs = GraphRunner._all_devs(batch_devs, new_call)
    return new_call.src[0].op in (Ops.PROGRAM, Ops.COPY) and len(all_devs) == 1 and isinstance(all_devs[0], HCQ2Compiled)
|