hcq2: deps (#16459)

* start

* sin

* f
This commit is contained in:
nimlgen
2026-06-02 22:34:25 +03:00
committed by GitHub
parent 82f1c983d4
commit 99e37b1ee3
3 changed files with 157 additions and 56 deletions

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import cast, Callable, TypeVar, Generic, Any, TYPE_CHECKING
import struct, functools, time, collections, importlib, itertools
import struct, functools, time, collections, importlib, itertools, weakref
from dataclasses import replace
if TYPE_CHECKING: from tinygrad.engine.realize import ExecContext
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, mv_address, round_up, DEBUG, dedup, pluralize
@@ -13,6 +13,7 @@ from tinygrad.runtime.support.memory import BumpAllocator
from tinygrad.runtime.support.hcq import MMIOInterface
from tinygrad.renderer import Renderer, Estimates
from tinygrad.engine.realize import to_program, track_stats, get_call_arg_uops, resolve_params, pm_flatten_linear
from tinygrad.engine.jit import DepsTracker
HCQDeviceType = TypeVar('HCQDeviceType', bound='HCQ2Compiled')
@@ -26,7 +27,8 @@ class HCQ2Compiled(Compiled):
self.pm_bufferize = PatternMatcher([
(UPat(Ops.BUFFER, tag="timeline_signal"), lambda ctx: ctx.timeline_signal),
(UPat(Ops.BUFFER, tag="timeline_value"), lambda ctx: ctx.timeline_value),
(UPat(Ops.BUFFER, name="b"), lambda ctx, b: Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=True, uncached=True, cpu_access=True))),
(UPat(Ops.BUFFER, name="b"), lambda ctx, b:
Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=True, uncached=True, cpu_access=True, nolru=True))), # TODO: remove nolru
])
super().__init__(device, allocator, compilers, lambda *a, **kw: None, None, arch=arch)
@@ -45,6 +47,16 @@ class HCQ2Compiled(Compiled):
buf.as_memoryview(force_zero_copy=True).cast('Q')[0] = 1
return buf
@functools.cache
def queue_timeline_signal(self, queue:str) -> Buffer:
  """Return the signal buffer belonging to the named queue of this device.

  The call is memoized by `functools.cache`, so `queue` only acts as part of the cache key:
  every distinct (device, queue) pair gets its own buffer, created on first request.
  """
  signal_spec = BufferSpec(host=True, uncached=True, cpu_access=True)
  return Buffer(self.device, 0x100, dtypes.uint8, options=signal_spec, preallocate=True)
@functools.cache
def queue_timeline_value(self, queue:str) -> Buffer:
  """Return the timeline counter buffer belonging to the named queue of this device.

  The counter is a single uint64 allocated on "CPU" and initialised to 1. The call is memoized by
  `functools.cache`, so `queue` only acts as part of the cache key.
  """
  counter = Buffer("CPU", 1, dtypes.uint64, preallocate=True)
  as_u64 = counter.as_memoryview(force_zero_copy=True).cast('Q')
  as_u64[0] = 1
  return counter
def synchronize(self, timeout:int|None=None):
if not hasattr(self, 'iface'): return
sig = self.timeline_signal._buf.cpu_view().mv.cast('Q')
@@ -98,6 +110,7 @@ class HCQAllocator(LRUAllocator[HCQDeviceType], Generic[HCQDeviceType]):
@suppress_finalizing
def _free(self, buf:HCQ2Buffer, options:BufferSpec|None=None):
  """Release `buf` through the optional `_do_free` hook once the device has been synchronized.

  Buffers whose options carry an `external_ptr` are left alone after the synchronize.
  """
  self.dev.synchronize()
  wraps_external_memory = options is not None and options.external_ptr is not None
  if wraps_external_memory: return
  if hasattr(self, '_do_free'): self._do_free(buf, options)
@@ -132,6 +145,8 @@ def unwrap_after(uop):
while uop.op is Ops.AFTER: uop = uop.src[0]
return uop
def make_mstack(uops):
  """Collapse a sequence of uops: a lone uop is returned unchanged, several are stacked into one MSTACK uop."""
  if len(uops) == 1: return uops[0]
  return UOp(Ops.MSTACK, uops[0].dtype, tuple(uops))
class HCQEncoder:
def __init__(self): self.blob, self.patches = b'', []
@@ -220,14 +235,14 @@ pm_prep_runtime = PatternMatcher([
# 2.1. lowering to hcq ir
def lower_program(call:UOp, prg:UOp) -> UOp:
  """Wrap a program into a one-item LINEAR queue and mark the call as an hcq call.

  The queue arg is (device of the call's first argument, "COMPUTE:0"). The returned call has the
  queue as src[0], keeps its remaining sources and is tagged 'hcq'.
  """
  # a single assignment: the earlier "COMPUTE" queue was a dead store, immediately overwritten by this one
  q = UOp(Ops.LINEAR, dtypes.void, (prg,), arg=(call.src[1].device, "COMPUTE:0"))
  return call.replace(src=(q,) + call.src[1:]).rtag('hcq')
def lower_copy(call:UOp, copy:UOp) -> UOp|None:
  """Wrap a copy into a one-item LINEAR queue on "COPY:0" and mark the call as an hcq call.

  The queue is placed on the first of (dst, src) whose device family is listed in HCQ_DEVS.
  Returns None when neither side lives on such a device.
  """
  dst, src = call.src[1], call.src[2]
  hcq_dev = next((b.device for b in (dst, src) if b.device.split(":")[0] in HCQ_DEVS), None)
  if hcq_dev is None: return None
  # a single queue is built: the earlier "COPY" queue was a dead store, immediately overwritten by the "COPY:0" one
  xfer = UOp(Ops.COPY, dtypes.void, src=(dst, src), arg=src.buffer.nbytes)
  q = UOp(Ops.LINEAR, dtypes.void, (xfer,), arg=(hcq_dev, "COPY:0"))
  return call.replace(src=(q,) + call.src[1:]).rtag('hcq')
pm_lower_ops = PatternMatcher([
@@ -236,19 +251,87 @@ pm_lower_ops = PatternMatcher([
])
# *****************
# 2.2. queue split
# 2.2. deps tracking
# device.timeline_signal/value are the per-device schedule epoch. Before a schedule queue accesses memory owned by device N for the first time,
# it waits for device[N].timeline_signal >= device[N].timeline_value - 1. This orders the schedule after all prior schedules that touched device N.
#
# queue.timeline_signal/value are per-queue progress counters used only inside a schedule.
# Only the owner queue signals its queue.timeline_signal. Values are monotonic.
#
# At schedule end, one finalizer queue per touched device[N] waits for every active queue on device[N] to reach its schedule-local
# final queue.timeline value, then signals device[N].timeline_signal with the schedule's reserved device epoch. After that, buffers/transients
# for device N from this schedule are safe for the next schedule.
#
# C programs reserve and bump timeline values, then patch command buffers with the concrete wait/signal values.
# def split_into_queues(linear:UOp) -> UOp:
# out = []
# for k, grp in itertools.groupby(linear.src, lambda c: c.src[0].arg if c.op is Ops.CALL and c.src[0].op is Ops.LINEAR else None):
# if k is None: out.extend(grp)
# else:
# calls = list(grp)
# items = tuple(x for c in calls for x in c.src[0].src)
# args = tuple(a for c in calls for a in c.src[1:])
# out.append(calls[0].replace(src=(UOp(Ops.LINEAR, dtypes.void, items, arg=k),) + args))
# return linear.replace(src=tuple(out))
# pm_split_into_queues = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), split_into_queues)])
@dataclass
class DepsCtx:
  """Mutable state shared by the dependency passes (insert_deps, add_finalizer) while one linear is processed."""
  # tracker asked which earlier queues a buffer access overlaps with
  deps:DepsTracker = field(default_factory=DepsTracker)
  # counter handing out the integer tags put on queues; starts at 0
  evid:itertools.count = field(default_factory=lambda: itertools.count(0))
  # latest queue uop recorded per queue arg; values are held weakly
  last_per_queue:weakref.WeakValueDictionary[tuple[Any, str], UOp] = field(default_factory=weakref.WeakValueDictionary)
def get_writes_ids(call:UOp) -> tuple[int, ...]:
  """Collect the argument indices written by the items of the queue in `call.src[0]`.

  PROGRAM items contribute the indices listed in `arg[1].outs`; COPY, SLICE and CUSTOM_FUNCTION
  items contribute index 0. Any other item contributes nothing, and an empty queue yields `()`.
  """
  # no eager `call.src[0].src[0]` lookup: it was never used and raised IndexError on an empty queue
  writes:set[int] = set()
  for item in call.src[0].src:
    if item.op is Ops.PROGRAM: writes.update(item.arg[1].outs)
    elif item.op in (Ops.COPY, Ops.SLICE, Ops.CUSTOM_FUNCTION): writes.add(0)
  return tuple(writes)
def insert_deps(ctx:DepsCtx, call:UOp) -> UOp|None:
  """Tag the call's queue with a fresh id and record which earlier queues it has to come after.

  ctx: shared state (id counter, deps tracker, last queue seen per queue arg).
  call: a CALL tagged 'hcq' whose src[0] is the queue.
  Returns the call with src[0] replaced by the tagged queue; when the tracker reports dependencies,
  the queue is wrapped in an AFTER over them, tagged "deps".
  NOTE(review): no path in this body returns None despite the `UOp|None` annotation.
  """
  # q: the queue retagged with the next id, refs: buffers behind the call arguments, write: argument indices written
  q, refs, write = call.src[0].rtag(next(ctx.evid)), [b.buffer for b in get_call_arg_uops(call)], get_writes_ids(call)
  # first queue seen for this queue arg: prepend a wait on "timeline_signal" for ("timeline_value" - 1)
  if q.arg not in ctx.last_per_queue:
    sig = UOp.new_buffer(q.arg[0], 0x100, dtypes.uint8).rtag("timeline_signal")
    tl = UOp.new_buffer(q.arg[0], 1, dtypes.uint64).rtag("timeline_value").index(UOp.const(dtypes.int, 0))
    q = q.replace(src=(sig.wait(tl - 1), *q.src))
  ctx.last_per_queue[q.arg] = q
  deps = []
  # with MultiBuffer arguments the accesses are registered once per lane; the lane count is taken from the first argument
  for lane in range(len(refs[0].bufs) if isinstance(refs[0], MultiBuffer) else 1):
    deps += ctx.deps.access_resources([b.bufs[lane] if isinstance(b, MultiBuffer) else b for b in refs], write, q)
  # the condition of the conditional expression runs first, so `dps` is bound before `q.after(*dps)` is evaluated
  return call.replace(src=(q.after(*dps).rtag("deps") if (dps:=dedup(deps)) else q,) + call.src[1:])
pm_insert_deps = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call", allow_any_len=True), insert_deps)])
def make_finalizer(devs:tuple[str, ...], queues:list[UOp], nbump:int) -> UOp:
  """Build the 'hcq' CALL that finishes a schedule for the devices in `devs`.

  devs: devices the finalizer queue is created on.
  queues: queues the finalizer has to come after.
  nbump: amount added to the timeline value of every queue type found in `queues`.
  """
  sig = UOp.new_buffer(devs, 0x100, dtypes.uint8).rtag("timeline_signal")
  tl = UOp.new_buffer(devs, 1, dtypes.uint64).rtag("timeline_value")
  # the finalizer queue has one item: store element 0 of the timeline value into the timeline signal
  q = UOp(Ops.LINEAR, dtypes.void, (sig.store(tl.index(UOp.const(dtypes.int, 0))),), arg=(devs, "COMPUTE:0"), tag="finalizer")
  # bump(b, by): store b[0] + by back into b[0]
  def bump(b, by): return b.index(UOp.const(dtypes.int, 0), dtype=b.dtype.ptr()).store(b.index(UOp.const(dtypes.int, 0)) + by)
  # the device timeline value goes up by 1, each distinct queue type's timeline value by nbump
  bumps = (bump(tl, 1),) + tuple(bump(UOp.new_buffer(devs, 1, dtypes.uint64).rtag((ty, "timeline_value")), nbump) for ty in dedup([q.arg[1] for q in queues]))
  # inner AFTER carries the bumps, outer AFTER (tagged "deps") carries the queues
  return UOp(Ops.CALL, dtypes.void, (q.after(*bumps).after(*queues).rtag("deps"),), tag="hcq")
def add_finalizer(ctx:DepsCtx, linear:UOp) -> UOp:
  """Append one finalizer call per device family to `linear`.

  The last queue recorded for every queue arg is grouped by the family of its first device (the part
  of the device name before ":"). Each group gets a finalizer over the sorted, de-duplicated devices
  of its queues; all finalizers share one bump value taken from the ctx id counter.
  """
  by_family:dict[str, list[UOp]] = collections.defaultdict(list)
  for last_q in ctx.last_per_queue.values():
    first_dev = to_tuple(last_q.arg[0])[0]
    by_family[first_dev.split(":")[0]].append(last_q)
  nbump = next(ctx.evid)
  def family_devs(members): return tuple(sorted(dedup(d for q in members for d in to_tuple(q.arg[0]))))
  finalizers = tuple(make_finalizer(family_devs(members), members, nbump) for members in by_family.values())
  return linear.replace(src=linear.src + finalizers)
def add_loads(ctx:set[int], call:UOp, after:UOp) -> UOp:
  """Turn the dependencies listed on a "deps" AFTER into waits placed in front of the queue's items.

  ctx: receives the tag of every queue that is waited on; the same set is handed to the add_stores pass afterwards.
  call: the CALL whose src[0] is `after`.
  after: the AFTER tagged "deps"; src[0] leads to the queue, src[1:] are the dependencies.
  """
  q = unwrap_after(after.src[0])
  cur_devs = to_tuple(q.arg[0])
  waits = []
  for dq in [unwrap_after(dep) for dep in after.src[1:]]:
    ctx.add(dq.tag)
    dq_devs = to_tuple(dq.arg[0])
    # one signal buffer per device of the current queue: dq's queue signal if dq is on that device, otherwise a "max_sentinel_signal" buffer
    sigs = [UOp.new_buffer(d, 0x100, dtypes.uint8).rtag((dq.arg[1], "timeline_signal") if d in dq_devs else "max_sentinel_signal") for d in cur_devs]
    orig_val = UOp.new_buffer(cur_devs, 1, dtypes.uint64).rtag((dq.arg[1], "timeline_value")).index(UOp.const(dtypes.int, 0))
    # wait on the stacked signals for dq's queue timeline value plus dq's tag
    waits.append(make_mstack(sigs).wait(orig_val + dq.tag))
  # the "deps" AFTER itself is dropped: the new src[0] is after.src[0] with the waits prepended to q's items
  return call.replace(src=(after.src[0].substitute({q: q.replace(src=(*waits, *q.src))}),) + call.src[1:])
pm_add_loads = PatternMatcher([(UPat(Ops.CALL, src=(UPat(Ops.AFTER, tag="deps", name="after"),), name="call", allow_any_len=True), add_loads)])
def add_stores(ctx:set[int], call:UOp) -> UOp|None:
  """Append a queue-signal store to the call's queue when the queue's tag is in `ctx`.

  The stored value is the queue's timeline value plus the queue's tag. Returns None, leaving the call
  untouched, when the tag is not in `ctx`.
  """
  q = unwrap_after(call.src[0])
  if q.tag not in ctx: return None
  devs, qname = q.arg[0], q.arg[1]
  sig = UOp.new_buffer(devs, 0x100, dtypes.uint8).rtag((qname, "timeline_signal"))
  val = UOp.new_buffer(devs, 1, dtypes.uint64).rtag((qname, "timeline_value")).index(UOp.const(dtypes.int, 0))
  signalling_q = q.replace(src=(*q.src, sig.store(val + q.tag)))
  return call.replace(src=(call.src[0].substitute({q: signalling_q}), *call.src[1:]))
pm_add_stores = PatternMatcher([
  (UPat(Ops.CALL, tag="hcq", name="call", allow_any_len=True), add_stores),
])
# *****************
# 2.3. barriers / signals / timeline inc
@@ -257,12 +340,6 @@ def add_barriers(call:UOp, q:UOp) -> UOp:
return call.replace(src=(q.replace(src=(UOp(Ops.BARRIER, dtypes.void), *q.src)),) + call.src[1:])
pm_add_barriers = PatternMatcher([(UPat(Ops.CALL, src=(UPat(Ops.LINEAR, name="q"),), name="call", allow_any_len=True), add_barriers)])
def add_signals(call:UOp, q:UOp) -> UOp:
  """Bracket the queue's items with a timeline wait in front and a timeline store behind."""
  sig = UOp.new_buffer(q.arg[0], 0x100, dtypes.uint8).rtag("timeline_signal")
  tl = UOp.new_buffer(q.arg[0], 1, dtypes.uint64).rtag("timeline_value").index(UOp.const(dtypes.int, 0))
  # first item: wait on the signal for (timeline value - 1); last item: store the timeline value into the signal
  return call.replace(src=(q.replace(src=(sig.wait(tl-1), *q.src, sig.store(tl)), arg=q.arg),) + call.src[1:])
pm_add_signals = PatternMatcher([(UPat(Ops.CALL, src=(UPat(Ops.LINEAR, name="q"),), name="call", allow_any_len=True), add_signals)])
# *****************
# 3.1. encode cmdbufs
@@ -273,10 +350,16 @@ def get_pm_lower(name:str) -> PatternMatcher|None:
return importlib.import_module(f'extra.hcq2.ops_{name.lower()}2').pm_lower
except ImportError: return None
def encode_cmdbuf(call:UOp) -> UOp|None:
  """Replace src[0] of an 'hcq' call with the result of the device family's `pm_lower` rewrite.

  The queue is looked up through any AFTER wrappers, but the rewrite is applied to `call.src[0]` as a
  whole so the wrappers are visible to the lowering. Returns None when src[0] does not lead to a
  LINEAR, when no lowering module exists for the device family, or when the rewrite does not match.
  """
  q = unwrap_after(call.src[0])
  if q.op is not Ops.LINEAR: return None
  # the device family is the part before ":" of the queue's first device
  pm = get_pm_lower(to_tuple(q.arg[0])[0].split(":")[0])
  if pm is None: return None
  encoded = pm.rewrite(call.src[0])
  if encoded is None: return None
  return call.replace(src=(encoded,) + call.src[1:])
# single definition: the earlier (call, q) variant and its LINEAR-only matcher were shadowed by this one
pm_encode_cmdbufs = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call", allow_any_len=True), encode_cmdbuf)])
# For an 'hcq' call whose src[0] is a "submit" CUSTOM_FUNCTION: replace that src[0] with a group of the function's sources.
pm_compose_submit = PatternMatcher([
  (UPat(Ops.CALL, tag="hcq", src=(UPat(Ops.CUSTOM_FUNCTION, arg="submit", name="sub"),), allow_any_len=True, name="call"),
    lambda call, sub: call.replace(src=(UOp.group(*sub.src),) + call.src[1:])),
])
# *****************
# 3.2. add timeline inc
@@ -303,7 +386,7 @@ pm_lift_patches_to_cmdbuf = PatternMatcher([
def bufferize_buf(buf:UOp) -> UOp|None:
  """Resolve a tagged BUFFER uop into real buffers, one per device named in `buf.src[1].arg`.

  Every device's `pm_bufferize` is applied to `buf` with the device as ctx and the result is wrapped
  with `UOp.from_buffer`. Untagged buffers are left alone (None is returned).
  """
  if buf.tag is None: return None
  uops = tuple(UOp.from_buffer((dv:=Device[dev]).pm_bufferize.rewrite(buf, ctx=dv), dev) for dev in to_tuple(buf.src[1].arg))
  # single return through the shared helper: one uop is passed through, several become an MSTACK
  return make_mstack(uops)
pm_bufferize = PatternMatcher([(UPat(Ops.BUFFER, name="buf"), bufferize_buf)])
# *****************
@@ -377,11 +460,13 @@ def hcq_schedule(linear:UOp) -> UOp:
linear = graph_rewrite(linear, pm_prep_runtime, name="prepare runtime")
linear = graph_rewrite(linear, pm_lower_ops, name="lower ops into hcq ir")
# linear = graph_rewrite(linear, pm_split_into_queues, name="split into queues")
linear = graph_rewrite(linear, pm_insert_deps, ctx=(deps_ctx:=DepsCtx()), walk=True, name="insert deps")
linear = add_finalizer(deps_ctx, linear)
linear = graph_rewrite(linear, pm_add_loads, ctx=(waited:=set()), walk=True, name="add loads")
linear = graph_rewrite(linear, pm_add_stores, ctx=waited, walk=True, name="add stores")
linear = graph_rewrite(linear, pm_add_barriers, walk=True, name="add barriers")
linear = graph_rewrite(linear, pm_add_signals, walk=True, name="add signals")
linear = graph_rewrite(linear, pm_encode_cmdbufs, walk=True, name="encode cmdbufs")
linear = graph_rewrite(linear, pm_add_timeline_inc, walk=True, name="add timeline inc")
linear = graph_rewrite(linear, pm_compose_submit, walk=True, name="compose submit")
linear = graph_rewrite(linear, pm_lift_patches_to_cmdbuf, name="lift patches to cmdbuf", enter_calls=True)
# realize starts from here

View File

@@ -3,7 +3,7 @@ from typing import cast
import os, ctypes, struct, hashlib, functools, importlib, mmap, errno, array, contextlib, sys, weakref, itertools, collections, atexit
assert sys.platform != 'win32'
from dataclasses import dataclass
from extra.hcq2.hcq2 import HCQ2Compiled, HCQAllocator, HCQ2Buffer, HCQEncoder
from extra.hcq2.hcq2 import HCQ2Compiled, HCQAllocator, HCQ2Buffer, HCQEncoder, to_tuple
from tinygrad.uop.ops import sint, UOp
from tinygrad.device import Compiled, BufferSpec, Buffer, Device
from tinygrad.dtype import dtypes
@@ -219,7 +219,7 @@ def amd_submit_sdma(cmdbuf, devs):
# the sdma queue's ring and its host-side ring/write/put pointers
q = Device['AMD'].sdma_queue(0)
ring, wptr, doorbell, put_ptr = (UOp.new_buffer(devs, b.size, b.dtype).rtag(("SDMA:0", name))
ring, wptr, doorbell, put_ptr = (UOp.new_buffer(devs, b.size, b.dtype).rtag(("COPY:0", name))
for name, b in (("ring", q.ring), ("write_ptr", q.write_ptr), ("doorbell", q.doorbell), ("put_value", q.put_value)))
# sdma needs the cmdbuf contiguous: if it won't fit before the ring end, restart at 0 and zero the tail
@@ -374,12 +374,14 @@ class PCIIface(PCIIfaceBase):
def _mock(iface, name=None):
  """Create an empty subclass of `iface`, named `name` or "MOCK" + the interface's class name when `name` is falsy."""
  mock_name = name or f"MOCK{iface.__name__}"
  return type(mock_name, (iface,), {})
def encode_queue(q:UOp) -> UOp|None:
  """Lower a COMPUTE*/COPY* queue, optionally wrapped in an AFTER, into a "submit" CUSTOM_FUNCTION.

  For an AFTER, src[0] is taken as the queue and the remaining sources are carried over behind the
  ring in the result. Returns None when the queue arg is not a (devices, name) pair whose name starts
  with "COMPUTE" or "COPY".
  """
  q, post = (q.src[0], q.src[1:]) if q.op is Ops.AFTER else (q, ())
  # prefix match, so indexed queue names such as "COMPUTE:0" / "COPY:0" are accepted; the former exact-name
  # guard rejected them and returned before the AFTER handling could run
  if not (isinstance(q.arg, tuple) and len(q.arg) == 2 and isinstance(q.arg[1], str) and q.arg[1].startswith(("COMPUTE", "COPY"))): return None
  devs = to_tuple(q.arg[0])
  if q.arg[1].startswith("COMPUTE"): ring = amd_submit_pm4(amd_lower_pm4(q, devs), devs)
  else: ring = amd_submit_sdma(amd_lower_sdma(q, devs), devs)
  return UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(ring, *post), arg="submit")
# One rule covers both shapes a queue can arrive in: a bare LINEAR or an AFTER; the separate LINEAR-only rule was redundant.
pm_lower = PatternMatcher([
  (UPat({Ops.LINEAR, Ops.AFTER}, name="q"), encode_queue),
])
class AMDDevice(HCQ2Compiled):
@@ -475,9 +477,12 @@ class AMDDevice(HCQ2Compiled):
wptr=getattr(hsa.amd_queue_t, 'write_dispatch_id').offset, eop_buffer=eop_buffer, cwsr_buffer=cwsr_buffer,
ctx_save_restore_size=ctx_save_restore_size, ctl_stack_size=ctl_stack_size, idx=idx))
qname = f"{'SDMA' if queue_type == kfd.KFD_IOC_QUEUE_TYPE_SDMA else 'COMPUTE'}:{idx}"
qname = f"{'COPY' if queue_type == kfd.KFD_IOC_QUEUE_TYPE_SDMA else 'COMPUTE'}:{idx}"
self.pm_bufferize = PatternMatcher([
(UPat(Ops.BUFFER, tag={(qname, name)}), lambda ctx, b=getattr(queue, name): b) for name in ["ring", "write_ptr", "doorbell", "put_value"]
] + [
(UPat(Ops.BUFFER, tag={(qname, "timeline_signal")}), lambda ctx, q=qname: ctx.queue_timeline_signal(q)),
(UPat(Ops.BUFFER, tag={(qname, "timeline_value")}), lambda ctx, q=qname: ctx.queue_timeline_value(q)),
]) + self.pm_bufferize
return queue

View File

@@ -87,6 +87,34 @@ def _check_no_non_tensor_return(ret):
def graph_class(dev):
  """Return the graph of `dev`, looking through a `functools.partial` wrapper to the wrapped callable."""
  graph = dev.graph
  if isinstance(graph, functools.partial): return graph.func
  return graph
class DepsTracker:
  """Tracks which dependency last wrote / read each byte range of a buffer.

  Ranges are stored as (start, end, dep) per `id(buf.base)`, so views into the same base buffer are
  compared by their byte ranges.
  """
  def __init__(self):
    # (start, end, dep) ranges per base buffer id, one map for writers and one for readers
    self.w_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)
    self.r_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)

  @staticmethod
  def _buf_key(buf:Buffer) -> int: return id(buf.base)

  def access_resources(self, bufs:list[Buffer], write:list[int], new_dependency:Any):
    """Register an access to `bufs` by `new_dependency` and return the dependencies it must wait for.

    bufs: accessed buffers; positions listed in `write` are written, all others are read.
    Every access waits for overlapping earlier writers; a write additionally waits for overlapping
    earlier readers. The result holds each dependency once, in first-seen order.
    """
    spans = [(self._buf_key(b), b.offset, b.offset + b.nbytes) for b in bufs]

    # pass 1: collect everything to wait for before any range is modified
    pending = []
    for idx, (key, lo, hi) in enumerate(spans):
      consulted = [self.w_dependency_map[key]]
      if idx in write: consulted.append(self.r_dependency_map[key])
      for ranges in consulted:
        pending.extend(dep for st, en, dep in ranges if st < hi and lo < en)

    # pass 2: record the new access
    for idx, (key, lo, hi) in enumerate(spans):
      if idx not in write:
        self.r_dependency_map[key].append((lo, hi, new_dependency))
        continue
      # a write supersedes older accesses on [lo, hi): only the parts sticking out on either side survive
      for dmap in (self.w_dependency_map, self.r_dependency_map):
        survivors = []
        for st, en, dep in dmap[key]:
          left_end, right_start = min(lo, en), max(hi, st)
          if st < left_end: survivors.append((st, left_end, dep))
          if right_start < en: survivors.append((right_start, en, dep))
        dmap[key] = survivors
      self.w_dependency_map[key].append((lo, hi, new_dependency))

    # de-duplicate by identity, keeping first-seen order
    unique: dict[int, Any] = {}
    for dep in pending: unique.setdefault(id(dep), dep)
    return list(unique.values())
class GraphRunner:
def __init__(self, linear:UOp, input_uops:tuple[UOp, ...]=()):
self.linear = linear.src[0]
@@ -123,9 +151,8 @@ class GraphRunner:
estimates = sum((estimate_uop(call) for call in self.linear.src), Estimates())
# used in MultiGraphRunner. tracks (offset, end, dep) ranges per base buffer id to handle suballocated buffers correctly.
self.w_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)
self.r_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)
# used in MultiGraphRunner
self.deps = DepsTracker()
self.device, self.estimates = self.calls[0][2][0].device.split(":")[0], estimates.simplify()
@@ -142,23 +169,7 @@ class GraphRunner:
yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])
def _access_resources(self, bufs:list[Buffer], write:list[int], new_dependency:Any):
  """Register an access to `bufs` by `new_dependency` and return the dependencies it has to wait for.

  Thin wrapper kept for existing callers: the range bookkeeping lives in `self.deps` (a DepsTracker).
  The former inline copy returned first and read per-runner maps that `__init__` no longer creates.
  """
  return self.deps.access_resources(bufs, write, new_dependency)
@staticmethod
def _all_devs(batch_devs:list[Compiled], new_call:UOp) -> list[Compiled]: