mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import cast, Callable, TypeVar, Generic, Any, TYPE_CHECKING
|
||||
import struct, functools, time, collections, importlib, itertools
|
||||
import struct, functools, time, collections, importlib, itertools, weakref
|
||||
from dataclasses import replace
|
||||
if TYPE_CHECKING: from tinygrad.engine.realize import ExecContext
|
||||
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, mv_address, round_up, DEBUG, dedup, pluralize
|
||||
@@ -13,6 +13,7 @@ from tinygrad.runtime.support.memory import BumpAllocator
|
||||
from tinygrad.runtime.support.hcq import MMIOInterface
|
||||
from tinygrad.renderer import Renderer, Estimates
|
||||
from tinygrad.engine.realize import to_program, track_stats, get_call_arg_uops, resolve_params, pm_flatten_linear
|
||||
from tinygrad.engine.jit import DepsTracker
|
||||
|
||||
HCQDeviceType = TypeVar('HCQDeviceType', bound='HCQ2Compiled')
|
||||
|
||||
@@ -26,7 +27,8 @@ class HCQ2Compiled(Compiled):
|
||||
self.pm_bufferize = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, tag="timeline_signal"), lambda ctx: ctx.timeline_signal),
|
||||
(UPat(Ops.BUFFER, tag="timeline_value"), lambda ctx: ctx.timeline_value),
|
||||
(UPat(Ops.BUFFER, name="b"), lambda ctx, b: Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=True, uncached=True, cpu_access=True))),
|
||||
(UPat(Ops.BUFFER, name="b"), lambda ctx, b:
|
||||
Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=True, uncached=True, cpu_access=True, nolru=True))), # TODO: remove nolru
|
||||
])
|
||||
|
||||
super().__init__(device, allocator, compilers, lambda *a, **kw: None, None, arch=arch)
|
||||
@@ -45,6 +47,16 @@ class HCQ2Compiled(Compiled):
|
||||
buf.as_memoryview(force_zero_copy=True).cast('Q')[0] = 1
|
||||
return buf
|
||||
|
||||
@functools.cache
|
||||
def queue_timeline_signal(self, queue:str) -> Buffer:
|
||||
return Buffer(self.device, 0x100, dtypes.uint8, options=BufferSpec(host=True, uncached=True, cpu_access=True), preallocate=True)
|
||||
|
||||
@functools.cache
|
||||
def queue_timeline_value(self, queue:str) -> Buffer:
|
||||
buf = Buffer("CPU", 1, dtypes.uint64, preallocate=True)
|
||||
buf.as_memoryview(force_zero_copy=True).cast('Q')[0] = 1
|
||||
return buf
|
||||
|
||||
def synchronize(self, timeout:int|None=None):
|
||||
if not hasattr(self, 'iface'): return
|
||||
sig = self.timeline_signal._buf.cpu_view().mv.cast('Q')
|
||||
@@ -98,6 +110,7 @@ class HCQAllocator(LRUAllocator[HCQDeviceType], Generic[HCQDeviceType]):
|
||||
|
||||
@suppress_finalizing
|
||||
def _free(self, buf:HCQ2Buffer, options:BufferSpec|None=None):
|
||||
self.dev.synchronize()
|
||||
if options is not None and options.external_ptr is not None: return
|
||||
if hasattr(self, '_do_free'): self._do_free(buf, options)
|
||||
|
||||
@@ -132,6 +145,8 @@ def unwrap_after(uop):
|
||||
while uop.op is Ops.AFTER: uop = uop.src[0]
|
||||
return uop
|
||||
|
||||
def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops))
|
||||
|
||||
class HCQEncoder:
|
||||
def __init__(self): self.blob, self.patches = b'', []
|
||||
|
||||
@@ -220,14 +235,14 @@ pm_prep_runtime = PatternMatcher([
|
||||
# 2.1. lowering to hcq ir
|
||||
|
||||
def lower_program(call:UOp, prg:UOp) -> UOp:
|
||||
q = UOp(Ops.LINEAR, dtypes.void, (prg,), arg=(call.src[1].device, "COMPUTE"))
|
||||
q = UOp(Ops.LINEAR, dtypes.void, (prg,), arg=(call.src[1].device, "COMPUTE:0"))
|
||||
return call.replace(src=(q,) + call.src[1:]).rtag('hcq')
|
||||
|
||||
def lower_copy(call:UOp, copy:UOp) -> UOp|None:
|
||||
dst, src = call.src[1], call.src[2]
|
||||
if (hcq_dev:=next((b.device for b in (dst, src) if b.device.split(":")[0] in HCQ_DEVS), None)) is None: return None
|
||||
|
||||
q = UOp(Ops.LINEAR, dtypes.void, (UOp(Ops.COPY, dtypes.void, src=(dst, src), arg=src.buffer.nbytes),), arg=(hcq_dev, "COPY"))
|
||||
q = UOp(Ops.LINEAR, dtypes.void, (UOp(Ops.COPY, dtypes.void, src=(dst, src), arg=src.buffer.nbytes),), arg=(hcq_dev, "COPY:0"))
|
||||
return call.replace(src=(q,) + call.src[1:]).rtag('hcq')
|
||||
|
||||
pm_lower_ops = PatternMatcher([
|
||||
@@ -236,19 +251,87 @@ pm_lower_ops = PatternMatcher([
|
||||
])
|
||||
|
||||
# *****************
|
||||
# 2.2. queue split
|
||||
# 2.2. deps tracking
|
||||
# device.timeline_signal/value are the per-device schedule epoch. Before a schedule queue accesses memory owned by device N for the first time,
|
||||
# it waits for device[N].timeline_signal >= device[N].timeline_value - 1. This orders the schedule after all prior schedules that touched device N.
|
||||
#
|
||||
# queue.timeline_signal/value are per-queue progress counters used only inside a schedule.
|
||||
# Only the owner queue signals its queue.timeline_signal. Values are monotonic.
|
||||
#
|
||||
# At schedule end, one finalizer queue per touched device[N] waits for every active queue on device[N] to reach its schedule-local
|
||||
# final queue.timeline value, then signals device[N].timeline_signal with the schedule's reserved device epoch. After that, buffers/transients
|
||||
# for device N from this schedule are safe for the next schedule
|
||||
#
|
||||
# C programs reserve and bump timeline values, then patch command buffers with the concrete wait/signal values.
|
||||
|
||||
# def split_into_queues(linear:UOp) -> UOp:
|
||||
# out = []
|
||||
# for k, grp in itertools.groupby(linear.src, lambda c: c.src[0].arg if c.op is Ops.CALL and c.src[0].op is Ops.LINEAR else None):
|
||||
# if k is None: out.extend(grp)
|
||||
# else:
|
||||
# calls = list(grp)
|
||||
# items = tuple(x for c in calls for x in c.src[0].src)
|
||||
# args = tuple(a for c in calls for a in c.src[1:])
|
||||
# out.append(calls[0].replace(src=(UOp(Ops.LINEAR, dtypes.void, items, arg=k),) + args))
|
||||
# return linear.replace(src=tuple(out))
|
||||
# pm_split_into_queues = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), split_into_queues)])
|
||||
@dataclass
|
||||
class DepsCtx:
|
||||
deps:DepsTracker = field(default_factory=DepsTracker)
|
||||
evid:itertools.count = field(default_factory=lambda: itertools.count(0))
|
||||
last_per_queue:weakref.WeakValueDictionary[tuple[Any, str], UOp] = field(default_factory=weakref.WeakValueDictionary)
|
||||
|
||||
def get_writes_ids(call:UOp) -> tuple[int, ...]:
|
||||
ast, writes = call.src[0].src[0], set()
|
||||
for ast in call.src[0].src:
|
||||
if ast.op is Ops.PROGRAM: writes.update(ast.arg[1].outs)
|
||||
elif ast.op in (Ops.COPY, Ops.SLICE, Ops.CUSTOM_FUNCTION): writes.add(0)
|
||||
return tuple(writes)
|
||||
|
||||
def insert_deps(ctx:DepsCtx, call:UOp) -> UOp|None:
|
||||
q, refs, write = call.src[0].rtag(next(ctx.evid)), [b.buffer for b in get_call_arg_uops(call)], get_writes_ids(call)
|
||||
if q.arg not in ctx.last_per_queue:
|
||||
sig = UOp.new_buffer(q.arg[0], 0x100, dtypes.uint8).rtag("timeline_signal")
|
||||
tl = UOp.new_buffer(q.arg[0], 1, dtypes.uint64).rtag("timeline_value").index(UOp.const(dtypes.int, 0))
|
||||
q = q.replace(src=(sig.wait(tl - 1), *q.src))
|
||||
ctx.last_per_queue[q.arg] = q
|
||||
|
||||
deps = []
|
||||
for lane in range(len(refs[0].bufs) if isinstance(refs[0], MultiBuffer) else 1):
|
||||
deps += ctx.deps.access_resources([b.bufs[lane] if isinstance(b, MultiBuffer) else b for b in refs], write, q)
|
||||
return call.replace(src=(q.after(*dps).rtag("deps") if (dps:=dedup(deps)) else q,) + call.src[1:])
|
||||
pm_insert_deps = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call", allow_any_len=True), insert_deps)])
|
||||
|
||||
def make_finalizer(devs:tuple[str, ...], queues:list[UOp], nbump:int) -> UOp:
|
||||
sig = UOp.new_buffer(devs, 0x100, dtypes.uint8).rtag("timeline_signal")
|
||||
tl = UOp.new_buffer(devs, 1, dtypes.uint64).rtag("timeline_value")
|
||||
q = UOp(Ops.LINEAR, dtypes.void, (sig.store(tl.index(UOp.const(dtypes.int, 0))),), arg=(devs, "COMPUTE:0"), tag="finalizer")
|
||||
|
||||
def bump(b, by): return b.index(UOp.const(dtypes.int, 0), dtype=b.dtype.ptr()).store(b.index(UOp.const(dtypes.int, 0)) + by)
|
||||
bumps = (bump(tl, 1),) + tuple(bump(UOp.new_buffer(devs, 1, dtypes.uint64).rtag((ty, "timeline_value")), nbump) for ty in dedup([q.arg[1] for q in queues]))
|
||||
return UOp(Ops.CALL, dtypes.void, (q.after(*bumps).after(*queues).rtag("deps"),), tag="hcq")
|
||||
|
||||
def add_finalizer(ctx:DepsCtx, linear:UOp) -> UOp:
|
||||
fams:dict[str, list[UOp]] = collections.defaultdict(list)
|
||||
for q in ctx.last_per_queue.values(): fams[to_tuple(q.arg[0])[0].split(":")[0]].append(q)
|
||||
|
||||
nbump = next(ctx.evid)
|
||||
finalizers = []
|
||||
for queues in fams.values():
|
||||
devs = tuple(sorted(dedup(d for q in queues for d in to_tuple(q.arg[0]))))
|
||||
finalizers.append(make_finalizer(devs, queues, nbump))
|
||||
return linear.replace(src=linear.src + tuple(finalizers))
|
||||
|
||||
def add_loads(ctx:set[int], call:UOp, after:UOp) -> UOp:
|
||||
q = unwrap_after(after.src[0])
|
||||
cur_devs = to_tuple(q.arg[0])
|
||||
|
||||
waits = []
|
||||
for dq in [unwrap_after(dep) for dep in after.src[1:]]:
|
||||
ctx.add(dq.tag)
|
||||
dq_devs = to_tuple(dq.arg[0])
|
||||
sigs = [UOp.new_buffer(d, 0x100, dtypes.uint8).rtag((dq.arg[1], "timeline_signal") if d in dq_devs else "max_sentinel_signal") for d in cur_devs]
|
||||
orig_val = UOp.new_buffer(cur_devs, 1, dtypes.uint64).rtag((dq.arg[1], "timeline_value")).index(UOp.const(dtypes.int, 0))
|
||||
waits.append(make_mstack(sigs).wait(orig_val + dq.tag))
|
||||
return call.replace(src=(after.src[0].substitute({q: q.replace(src=(*waits, *q.src))}),) + call.src[1:])
|
||||
pm_add_loads = PatternMatcher([(UPat(Ops.CALL, src=(UPat(Ops.AFTER, tag="deps", name="after"),), name="call", allow_any_len=True), add_loads)])
|
||||
|
||||
def add_stores(ctx:set[int], call:UOp) -> UOp|None:
|
||||
if (q:=unwrap_after(call.src[0])).tag not in ctx: return None
|
||||
sig = UOp.new_buffer(q.arg[0], 0x100, dtypes.uint8).rtag((q.arg[1], "timeline_signal"))
|
||||
val = UOp.new_buffer(q.arg[0], 1, dtypes.uint64).rtag((q.arg[1], "timeline_value")).index(UOp.const(dtypes.int, 0))
|
||||
newq = q.replace(src=q.src + (sig.store(val + q.tag),))
|
||||
return call.replace(src=(call.src[0].substitute({q: newq}),) + call.src[1:])
|
||||
pm_add_stores = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call", allow_any_len=True), add_stores)])
|
||||
|
||||
# *****************
|
||||
# 2.3. barriers / signals / timeline inc
|
||||
@@ -257,12 +340,6 @@ def add_barriers(call:UOp, q:UOp) -> UOp:
|
||||
return call.replace(src=(q.replace(src=(UOp(Ops.BARRIER, dtypes.void), *q.src)),) + call.src[1:])
|
||||
pm_add_barriers = PatternMatcher([(UPat(Ops.CALL, src=(UPat(Ops.LINEAR, name="q"),), name="call", allow_any_len=True), add_barriers)])
|
||||
|
||||
def add_signals(call:UOp, q:UOp) -> UOp:
|
||||
sig = UOp.new_buffer(q.arg[0], 0x100, dtypes.uint8).rtag("timeline_signal")
|
||||
tl = UOp.new_buffer(q.arg[0], 1, dtypes.uint64).rtag("timeline_value").index(UOp.const(dtypes.int, 0))
|
||||
return call.replace(src=(q.replace(src=(sig.wait(tl-1), *q.src, sig.store(tl)), arg=q.arg),) + call.src[1:])
|
||||
pm_add_signals = PatternMatcher([(UPat(Ops.CALL, src=(UPat(Ops.LINEAR, name="q"),), name="call", allow_any_len=True), add_signals)])
|
||||
|
||||
# *****************
|
||||
# 3.1. encode cmdbufs
|
||||
|
||||
@@ -273,10 +350,16 @@ def get_pm_lower(name:str) -> PatternMatcher|None:
|
||||
return importlib.import_module(f'extra.hcq2.ops_{name.lower()}2').pm_lower
|
||||
except ImportError: return None
|
||||
|
||||
def encode_cmdbuf(call:UOp, q:UOp) -> UOp|None:
|
||||
if (pm:=get_pm_lower(to_tuple(q.arg[0])[0].split(":")[0])) is None or (encoded:=pm.rewrite(q)) is None: return None
|
||||
def encode_cmdbuf(call:UOp) -> UOp|None:
|
||||
if (q:=unwrap_after(call.src[0])).op is not Ops.LINEAR: return None
|
||||
if (pm:=get_pm_lower(to_tuple(q.arg[0])[0].split(":")[0])) is None or (encoded:=pm.rewrite(call.src[0])) is None: return None
|
||||
return call.replace(src=(encoded,) + call.src[1:])
|
||||
pm_encode_cmdbufs = PatternMatcher([(UPat(Ops.CALL, src=(UPat(Ops.LINEAR, name="q"),), name="call", allow_any_len=True), encode_cmdbuf)])
|
||||
pm_encode_cmdbufs = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call", allow_any_len=True), encode_cmdbuf)])
|
||||
|
||||
pm_compose_submit = PatternMatcher([
|
||||
(UPat(Ops.CALL, tag="hcq", src=(UPat(Ops.CUSTOM_FUNCTION, arg="submit", name="sub"),), allow_any_len=True, name="call"),
|
||||
lambda call, sub: call.replace(src=(UOp.group(*sub.src),) + call.src[1:])),
|
||||
])
|
||||
|
||||
# *****************
|
||||
# 3.2. add timeline inc
|
||||
@@ -303,7 +386,7 @@ pm_lift_patches_to_cmdbuf = PatternMatcher([
|
||||
def bufferize_buf(buf:UOp) -> UOp|None:
|
||||
if buf.tag is None: return None
|
||||
uops = tuple(UOp.from_buffer((dv:=Device[dev]).pm_bufferize.rewrite(buf, ctx=dv), dev) for dev in to_tuple(buf.src[1].arg))
|
||||
return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, uops)
|
||||
return make_mstack(uops)
|
||||
pm_bufferize = PatternMatcher([(UPat(Ops.BUFFER, name="buf"), bufferize_buf)])
|
||||
|
||||
# *****************
|
||||
@@ -377,11 +460,13 @@ def hcq_schedule(linear:UOp) -> UOp:
|
||||
linear = graph_rewrite(linear, pm_prep_runtime, name="prepare runtime")
|
||||
|
||||
linear = graph_rewrite(linear, pm_lower_ops, name="lower ops into hcq ir")
|
||||
# linear = graph_rewrite(linear, pm_split_into_queues, name="split into queues")
|
||||
linear = graph_rewrite(linear, pm_insert_deps, ctx=(deps_ctx:=DepsCtx()), walk=True, name="insert deps")
|
||||
linear = add_finalizer(deps_ctx, linear)
|
||||
linear = graph_rewrite(linear, pm_add_loads, ctx=(waited:=set()), walk=True, name="add loads")
|
||||
linear = graph_rewrite(linear, pm_add_stores, ctx=waited, walk=True, name="add stores")
|
||||
linear = graph_rewrite(linear, pm_add_barriers, walk=True, name="add barriers")
|
||||
linear = graph_rewrite(linear, pm_add_signals, walk=True, name="add signals")
|
||||
linear = graph_rewrite(linear, pm_encode_cmdbufs, walk=True, name="encode cmdbufs")
|
||||
linear = graph_rewrite(linear, pm_add_timeline_inc, walk=True, name="add timeline inc")
|
||||
linear = graph_rewrite(linear, pm_compose_submit, walk=True, name="compose submit")
|
||||
linear = graph_rewrite(linear, pm_lift_patches_to_cmdbuf, name="lift patches to cmdbuf", enter_calls=True)
|
||||
|
||||
# realize starts from here
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import cast
|
||||
import os, ctypes, struct, hashlib, functools, importlib, mmap, errno, array, contextlib, sys, weakref, itertools, collections, atexit
|
||||
assert sys.platform != 'win32'
|
||||
from dataclasses import dataclass
|
||||
from extra.hcq2.hcq2 import HCQ2Compiled, HCQAllocator, HCQ2Buffer, HCQEncoder
|
||||
from extra.hcq2.hcq2 import HCQ2Compiled, HCQAllocator, HCQ2Buffer, HCQEncoder, to_tuple
|
||||
from tinygrad.uop.ops import sint, UOp
|
||||
from tinygrad.device import Compiled, BufferSpec, Buffer, Device
|
||||
from tinygrad.dtype import dtypes
|
||||
@@ -219,7 +219,7 @@ def amd_submit_sdma(cmdbuf, devs):
|
||||
|
||||
# the sdma queue's ring and its host-side ring/write/put pointers
|
||||
q = Device['AMD'].sdma_queue(0)
|
||||
ring, wptr, doorbell, put_ptr = (UOp.new_buffer(devs, b.size, b.dtype).rtag(("SDMA:0", name))
|
||||
ring, wptr, doorbell, put_ptr = (UOp.new_buffer(devs, b.size, b.dtype).rtag(("COPY:0", name))
|
||||
for name, b in (("ring", q.ring), ("write_ptr", q.write_ptr), ("doorbell", q.doorbell), ("put_value", q.put_value)))
|
||||
|
||||
# sdma needs the cmdbuf contiguous: if it won't fit before the ring end, restart at 0 and zero the tail
|
||||
@@ -374,12 +374,14 @@ class PCIIface(PCIIfaceBase):
|
||||
def _mock(iface, name=None): return type(name or f"MOCK{iface.__name__}", (iface,), {})
|
||||
|
||||
def encode_queue(q:UOp) -> UOp|None:
|
||||
if not (isinstance(q.arg, tuple) and len(q.arg) == 2 and q.arg[1] in ("COMPUTE", "COPY")): return None
|
||||
devs = (q.arg[0],) if isinstance(q.arg[0], str) else q.arg[0] # TODO: make this prettier
|
||||
return amd_submit_pm4(amd_lower_pm4(q, devs), devs) if q.arg[1] == "COMPUTE" else amd_submit_sdma(amd_lower_sdma(q, devs), devs)
|
||||
q, post = (q.src[0], q.src[1:]) if q.op is Ops.AFTER else (q, ())
|
||||
if not (isinstance(q.arg, tuple) and len(q.arg) == 2 and isinstance(q.arg[1], str) and q.arg[1].startswith(("COMPUTE", "COPY"))): return None
|
||||
devs = to_tuple(q.arg[0])
|
||||
ring = amd_submit_pm4(amd_lower_pm4(q, devs), devs) if q.arg[1].startswith("COMPUTE") else amd_submit_sdma(amd_lower_sdma(q, devs), devs)
|
||||
return UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(ring, *post), arg="submit")
|
||||
|
||||
pm_lower = PatternMatcher([
|
||||
(UPat(Ops.LINEAR, name="q"), encode_queue),
|
||||
(UPat({Ops.LINEAR, Ops.AFTER}, name="q"), encode_queue),
|
||||
])
|
||||
|
||||
class AMDDevice(HCQ2Compiled):
|
||||
@@ -475,9 +477,12 @@ class AMDDevice(HCQ2Compiled):
|
||||
wptr=getattr(hsa.amd_queue_t, 'write_dispatch_id').offset, eop_buffer=eop_buffer, cwsr_buffer=cwsr_buffer,
|
||||
ctx_save_restore_size=ctx_save_restore_size, ctl_stack_size=ctl_stack_size, idx=idx))
|
||||
|
||||
qname = f"{'SDMA' if queue_type == kfd.KFD_IOC_QUEUE_TYPE_SDMA else 'COMPUTE'}:{idx}"
|
||||
qname = f"{'COPY' if queue_type == kfd.KFD_IOC_QUEUE_TYPE_SDMA else 'COMPUTE'}:{idx}"
|
||||
self.pm_bufferize = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, tag={(qname, name)}), lambda ctx, b=getattr(queue, name): b) for name in ["ring", "write_ptr", "doorbell", "put_value"]
|
||||
] + [
|
||||
(UPat(Ops.BUFFER, tag={(qname, "timeline_signal")}), lambda ctx, q=qname: ctx.queue_timeline_signal(q)),
|
||||
(UPat(Ops.BUFFER, tag={(qname, "timeline_value")}), lambda ctx, q=qname: ctx.queue_timeline_value(q)),
|
||||
]) + self.pm_bufferize
|
||||
|
||||
return queue
|
||||
|
||||
@@ -87,6 +87,34 @@ def _check_no_non_tensor_return(ret):
|
||||
|
||||
def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph
|
||||
|
||||
class DepsTracker:
|
||||
def __init__(self):
|
||||
# tracks (offset, end, dep) ranges per base buffer id to handle suballocated buffers correctly.
|
||||
self.w_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)
|
||||
self.r_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)
|
||||
|
||||
@staticmethod
|
||||
def _buf_key(buf:Buffer) -> int: return id(buf.base)
|
||||
|
||||
def access_resources(self, bufs:list[Buffer], write:list[int], new_dependency:Any):
|
||||
wait_nodes = []
|
||||
for i,buf in enumerate(bufs):
|
||||
key, s, e = self._buf_key(buf), buf.offset, buf.offset + buf.nbytes
|
||||
wait_nodes += [dep for st,en,dep in self.w_dependency_map[key] if st < e and s < en]
|
||||
if i in write: wait_nodes += [dep for st,en,dep in self.r_dependency_map[key] if st < e and s < en]
|
||||
for i,buf in enumerate(bufs):
|
||||
key, s, e = self._buf_key(buf), buf.offset, buf.offset + buf.nbytes
|
||||
if i in write:
|
||||
for dmap in [self.w_dependency_map, self.r_dependency_map]:
|
||||
kept = []
|
||||
for st,en,dep in dmap[key]:
|
||||
if st < min(s, en): kept.append((st, min(s, en), dep))
|
||||
if max(e, st) < en: kept.append((max(e, st), en, dep))
|
||||
dmap[key] = kept
|
||||
self.w_dependency_map[key].append((s, e, new_dependency))
|
||||
else: self.r_dependency_map[key].append((s, e, new_dependency))
|
||||
return list({id(x):x for x in wait_nodes}.values())
|
||||
|
||||
class GraphRunner:
|
||||
def __init__(self, linear:UOp, input_uops:tuple[UOp, ...]=()):
|
||||
self.linear = linear.src[0]
|
||||
@@ -123,9 +151,8 @@ class GraphRunner:
|
||||
|
||||
estimates = sum((estimate_uop(call) for call in self.linear.src), Estimates())
|
||||
|
||||
# used in MultiGraphRunner. tracks (offset, end, dep) ranges per base buffer id to handle suballocated buffers correctly.
|
||||
self.w_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)
|
||||
self.r_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)
|
||||
# used in MultiGraphRunner
|
||||
self.deps = DepsTracker()
|
||||
|
||||
self.device, self.estimates = self.calls[0][2][0].device.split(":")[0], estimates.simplify()
|
||||
|
||||
@@ -142,23 +169,7 @@ class GraphRunner:
|
||||
yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])
|
||||
|
||||
def _access_resources(self, bufs:list[Buffer], write:list[int], new_dependency:Any):
|
||||
wait_nodes = []
|
||||
for i,buf in enumerate(bufs):
|
||||
key, s, e = id(buf.base._buf), buf.offset, buf.offset + buf.nbytes
|
||||
wait_nodes += [dep for st,en,dep in self.w_dependency_map[key] if st < e and s < en]
|
||||
if i in write: wait_nodes += [dep for st,en,dep in self.r_dependency_map[key] if st < e and s < en]
|
||||
for i,buf in enumerate(bufs):
|
||||
key, s, e = id(buf.base._buf), buf.offset, buf.offset + buf.nbytes
|
||||
if i in write:
|
||||
for dmap in [self.w_dependency_map, self.r_dependency_map]:
|
||||
kept = []
|
||||
for st,en,dep in dmap[key]:
|
||||
if st < min(s, en): kept.append((st, min(s, en), dep))
|
||||
if max(e, st) < en: kept.append((max(e, st), en, dep))
|
||||
dmap[key] = kept
|
||||
self.w_dependency_map[key].append((s, e, new_dependency))
|
||||
else: self.r_dependency_map[key].append((s, e, new_dependency))
|
||||
return list({id(x):x for x in wait_nodes}.values())
|
||||
return self.deps.access_resources(bufs, write, new_dependency)
|
||||
|
||||
@staticmethod
|
||||
def _all_devs(batch_devs:list[Compiled], new_call:UOp) -> list[Compiled]:
|
||||
|
||||
Reference in New Issue
Block a user