mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-14 00:45:16 +08:00
396 lines
20 KiB
Python
396 lines
20 KiB
Python
from __future__ import annotations
|
|
from typing import cast, Callable, TypeVar, Generic, Any, TYPE_CHECKING
|
|
import struct, functools, time, collections, importlib, itertools
|
|
from dataclasses import replace
|
|
if TYPE_CHECKING: from tinygrad.engine.realize import ExecContext
|
|
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, mv_address, round_up, DEBUG, dedup, pluralize
|
|
from tinygrad.device import Device, Buffer, BufferSpec, Compiled, LRUAllocator, MultiBuffer
|
|
from tinygrad.uop.ops import Ops, sint, UOp, UPat, PatternMatcher, KernelInfo, graph_rewrite, track_rewrites, GroupOp
|
|
from tinygrad.uop.symbolic import symbolic_simple, symbolic
|
|
from tinygrad.dtype import dtypes, DType
|
|
from dataclasses import dataclass, field
|
|
from tinygrad.runtime.support.memory import BumpAllocator
|
|
from tinygrad.runtime.support.hcq import MMIOInterface
|
|
from tinygrad.renderer import Renderer, Estimates
|
|
from tinygrad.engine.realize import to_program, track_stats, get_call_arg_uops, resolve_params, pm_flatten_linear
|
|
|
|
HCQDeviceType = TypeVar('HCQDeviceType', bound='HCQ2Compiled')
|
|
|
|
class HCQ2Compiled(Compiled):
|
|
timestamp_divider: float = 1000.0 # GPU timestamp counter ticks per microsecond; override per device
|
|
|
|
def __init__(self, device:str, allocator:'HCQAllocator', compilers:list[type[Renderer]], runtime, can_recover:bool=False, arch=None):
|
|
self.device_id:int = int(device.split(":")[1]) if ":" in device else 0
|
|
|
|
# default pm bufferize
|
|
self.pm_bufferize = PatternMatcher([
|
|
(UPat(Ops.BUFFER, tag="timeline_signal"), lambda ctx: ctx.timeline_signal),
|
|
(UPat(Ops.BUFFER, tag="timeline_value"), lambda ctx: ctx.timeline_value),
|
|
(UPat(Ops.BUFFER, name="b"), lambda ctx, b: Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=True, uncached=True, cpu_access=True))),
|
|
])
|
|
|
|
super().__init__(device, allocator, compilers, lambda *a, **kw: None, None, arch=arch)
|
|
|
|
@functools.cached_property
|
|
def timeline_signal(self) -> Buffer:
|
|
return Buffer(self.device, 0x100, dtypes.uint8, options=BufferSpec(host=True, uncached=True, cpu_access=True), preallocate=True)
|
|
|
|
@functools.cached_property
|
|
def timestamps_buf(self) -> Buffer:
|
|
return Buffer(self.device, 0x100, dtypes.uint8, options=BufferSpec(cpu_access=True), preallocate=True)
|
|
|
|
@functools.cached_property
|
|
def timeline_value(self) -> Buffer:
|
|
buf = Buffer("CPU", 1, dtypes.uint64, preallocate=True)
|
|
buf.as_memoryview(force_zero_copy=True).cast('Q')[0] = 1
|
|
return buf
|
|
|
|
def synchronize(self, timeout:int|None=None):
|
|
if not hasattr(self, 'iface'): return
|
|
sig = self.timeline_signal._buf.cpu_view().mv.cast('Q')
|
|
tl = self.timeline_value.as_memoryview(force_zero_copy=True).cast('Q')
|
|
st = time.perf_counter()
|
|
while sig[0] < tl[0] - 1:
|
|
if time.perf_counter() - st > (timeout or 3000) / 1000: self.on_device_hang()
|
|
|
|
def device_props(self) -> dict[str,Any]: return {} # to be overridden if needed. dict keys are backend dependent.
|
|
|
|
def count(self) -> int: return self.iface.count if hasattr(self, 'iface') else 1
|
|
|
|
def _select_iface(self):
|
|
assert (v:=getenv(k:=f'{type(self).__name__[:-6].upper()}_IFACE', "")) == "", \
|
|
f"{k}={v} is deprecated, use DEV={replace(DEV.target(type(self).__name__[:-6]), interface=v)} instead"
|
|
assert hasattr(self, "ifaces"), "must have ifaces to select an iface"
|
|
t = DEV.target(dev:=type(self).__name__[:-6])
|
|
filtered = select_by_name(self.ifaces, lambda i: i.__name__[:-5], t.interface, f"{dev} has no interface {t.interface!r}")
|
|
filtered = [i for i in filtered if t.interface.startswith("MOCK") or not i.__name__[:-5].startswith("MOCK")] # never fall back to mock ifaces
|
|
return select_first_inited([functools.partial(cast(Callable, iface), self, self.device_id) for iface in filtered],
|
|
f"No interface for {dev}:{self.device_id} is available")
|
|
|
|
def _is_cpu(self) -> bool: return hasattr(self, 'device') and self.device.split(":")[0] == "CPU"
|
|
|
|
def finalize(self):
|
|
try: self.synchronize() # try to finalize the device in any case
|
|
except RuntimeError as e: print(f"{self.device} synchronization failed before finalizing: {e}")
|
|
|
|
# if the device has an interface, call device_fini to clean up resources
|
|
if hasattr(self, 'iface') and hasattr(self.iface, 'device_fini'): self.iface.device_fini()
|
|
|
|
class HCQ2Buffer:
|
|
def __init__(self, va_addr:sint, size:int, meta:Any=None, _base:HCQ2Buffer|None=None, view:MMIOInterface|None=None, owner:HCQ2Compiled|None=None):
|
|
self.va_addr, self.size, self.meta, self._base, self.view, self.owner = va_addr, size, meta, _base, view, owner
|
|
|
|
def offset(self, offset:int=0, size:int|None=None) -> HCQ2Buffer:
|
|
return HCQ2Buffer(self.va_addr+offset, size or (self.size - offset), owner=self.owner, meta=self.meta,
|
|
_base=self._base or self, view=(self.view.view(offset=offset, size=size) if self.view is not None else None))
|
|
|
|
def cpu_view(self) -> MMIOInterface:
|
|
assert self.view is not None, "buffer has no cpu_view"
|
|
return self.view
|
|
|
|
@property
|
|
def base(self) -> HCQ2Buffer: return self._base or self
|
|
|
|
class HCQAllocator(LRUAllocator[HCQDeviceType], Generic[HCQDeviceType]):
|
|
def _map(self, buf:HCQ2Buffer) -> HCQ2Buffer:
|
|
if not hasattr(self, '_do_map'): raise NotImplementedError("map failed: no method implemented")
|
|
return self._do_map(buf)
|
|
|
|
@suppress_finalizing
|
|
def _free(self, buf:HCQ2Buffer, options:BufferSpec|None=None):
|
|
if options is not None and options.external_ptr is not None: return
|
|
if hasattr(self, '_do_free'): self._do_free(buf, options)
|
|
|
|
def _unmap(self, mb):
|
|
self.dev.synchronize()
|
|
self.dev.iface.dev_impl.mm.unmap_range(int(mb.va_addr), round_up(mb.size, 0x1000))
|
|
|
|
def _offset(self, buf, size:int, offset:int) -> HCQ2Buffer: return buf.offset(offset=offset, size=size)
|
|
|
|
def _wrap(self, dev:str, sz:int, opaque:HCQ2Buffer) -> Buffer:
|
|
return Buffer(dev, sz, dtypes.uint8, opaque=opaque, options=BufferSpec(external_ptr=1))
|
|
|
|
def _copy(self, dst:Buffer, src:Buffer):
|
|
from tinygrad.engine.realize import run_linear
|
|
su = UOp.from_buffer(src)
|
|
run_linear(UOp(Ops.LINEAR, dtypes.void, (su.copy_to_device(dst.device).call(UOp.from_buffer(dst), su),)), update_stats=False)
|
|
|
|
def _copyin(self, dest:HCQ2Buffer, src:memoryview):
|
|
s = Buffer(self.dev.device, len(src), dtypes.uint8, options=BufferSpec(host=True), preallocate=True)
|
|
s._buf.cpu_view()[:len(src)] = src
|
|
self._copy(self._wrap(self.dev.device, len(src), dest), s)
|
|
|
|
def _copyout(self, dest:memoryview, src:HCQ2Buffer):
|
|
d = Buffer(self.dev.device, len(dest), dtypes.uint8, options=BufferSpec(host=True), preallocate=True)
|
|
self._copy(d, self._wrap(self.dev.device, len(dest), src))
|
|
self.dev.synchronize()
|
|
dest[:] = d._buf.cpu_view()[:len(dest)]
|
|
|
|
# def _as_buffer(self, buf): return buf.cpu_view().mv
|
|
|
|
def unwrap_after(uop):
|
|
while uop.op is Ops.AFTER: uop = uop.src[0]
|
|
return uop
|
|
|
|
class HCQEncoder:
|
|
def __init__(self): self.blob, self.patches = b'', []
|
|
|
|
def get_dev_addr(self, uop:UOp) -> UOp:
|
|
if unwrap_after(uop).op not in (Ops.BUFFER, Ops.SLICE, Ops.BINARY, Ops.MSTACK, Ops.MSELECT): return uop
|
|
return UOp(Ops.GETADDR, dtypes.uint64, src=(uop, UOp(Ops.DEVICE, arg=self.dev.device)))
|
|
|
|
def append(self, *data, dtype=dtypes.uint32):
|
|
for d in data:
|
|
if isinstance(d, int): self.blob += struct.pack(f'<{dtype.fmt}', d)
|
|
else:
|
|
self.patches.append((len(self.blob), self.get_dev_addr(d), dtype))
|
|
self.blob += struct.pack(f'<{dtype.fmt}', 0)
|
|
|
|
def q(self, *values): self.append(*values)
|
|
|
|
def uop(self, dev:str|tuple[str, ...], tag:str|None=None) -> UOp:
|
|
buf = UOp.new_buffer(dev, len(self.blob), dtypes.uint8)
|
|
if tag: buf = buf.rtag(tag)
|
|
blob_uop = UOp(Ops.BINARY, dtypes.void, src=(), arg=self.blob)
|
|
stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(dt.ptr()).store(val.cast(dt)) for off, val, dt in self.patches]
|
|
return buf.after(buf.store(blob_uop), *stores)
|
|
|
|
# *****************
|
|
# 0. helpers
|
|
|
|
HCQ_DEVS = frozenset(("AMD",))
|
|
HCQ_P2P_DEVS = HCQ_DEVS | frozenset(("CPU",))
|
|
|
|
def to_tuple(d): return d if isinstance(d, tuple) else (d,)
|
|
|
|
def all_devices_in(d:Any, c:frozenset[str]) -> bool: return {x.split(":")[0] for x in to_tuple(d)} <= c
|
|
|
|
# *****************
|
|
# 1.1. prep runtimes: staging copies
|
|
|
|
def _need_staging(a, b): return all_devices_in(a.device, HCQ_DEVS) and not all_devices_in(b.device, HCQ_P2P_DEVS)
|
|
|
|
def stage_copy(dst:UOp, src:UOp) -> UOp|None:
|
|
if not (_need_staging(src, dst) or _need_staging(dst, src)): return None
|
|
|
|
stage = UOp.new_buffer("CPU", src.buffer.nbytes, dtypes.uint8)
|
|
return UOp(Ops.LINEAR, dtypes.void, (src.copy_to_device("CPU").call(stage, src), stage.copy_to_device(dst.device).call(dst, stage)))
|
|
pm_insert_copy_staging = PatternMatcher([(UPat(Ops.CALL, src=(UPat(Ops.COPY), UPat(name="dst"), UPat(name="src"))), stage_copy)])
|
|
|
|
# *****************
|
|
# 1.2. prep runtimes: programs/kernargs
|
|
|
|
@functools.cache
|
|
def get_pm_prep_program(name:str) -> PatternMatcher|None:
|
|
try:
|
|
importlib.import_module(f'tinygrad.runtime.ops_{name.lower()}') # TODO: remove that
|
|
return importlib.import_module(f'extra.hcq2.ops_{name.lower()}2').pm_prep_program
|
|
except ImportError: return None
|
|
|
|
def prep_program(call:UOp, prg:UOp) -> UOp|None:
|
|
dev = call.src[1].device
|
|
if (pm:=get_pm_prep_program(to_tuple(dev)[0].split(":")[0])) is None or (lowered:=pm.rewrite(prg)) is None: return None
|
|
|
|
data, image_bytes = lowered
|
|
buf = UOp.new_buffer(dev, len(image_bytes), dtypes.uint8).rtag("program")
|
|
blob = UOp(Ops.BINARY, dtypes.void, src=(), arg=image_bytes)
|
|
return call.replace(src=(prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)),) + call.src[1:])
|
|
|
|
def prep_kernargs(call:UOp, prg:UOp) -> UOp:
|
|
data, info = prg.arg
|
|
patches = [(i*dtypes.uint64.itemsize, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], UOp(Ops.DEVICE, arg=call.src[1+gi].device))),
|
|
dtypes.uint64) for i,gi in enumerate(info.globals)] \
|
|
+ [(len(info.globals)*dtypes.uint64.itemsize + i*dtypes.uint32.itemsize, v, dtypes.uint32) for i,v in enumerate(info.vars)]
|
|
|
|
buf = UOp.new_buffer(call.src[1].device, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
|
|
kernargs = buf.after(*tuple(buf.index(UOp.const(dtypes.int, o), dtype=buf.dtype.ptr()).cast(dt.ptr()).store(val.cast(dt)) for o, val, dt in patches))
|
|
|
|
return call.replace(src=(prg.replace(src=prg.src + (kernargs,), arg=(data, info)),) + call.src[1:])
|
|
|
|
pm_prep_runtime = PatternMatcher([
|
|
# bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering
|
|
(UPat(Ops.CALL, src=(UPat(Ops.PROGRAM, src=(UPat(), UPat(), UPat(), UPat(), UPat(Ops.BINARY)), name="prg"),),
|
|
name="call", allow_any_len=True), prep_program),
|
|
|
|
# lower kernargs (PROGRAM.src[0] is now AFTER(BUFFER, COPY) — the lowered program image)
|
|
(UPat(Ops.CALL, src=(UPat(Ops.PROGRAM, src=(UPat(Ops.AFTER),), name="prg"),), name="call", allow_any_len=True), prep_kernargs),
|
|
])
|
|
|
|
# *****************
|
|
# 2.1. lowering to hcq ir
|
|
|
|
def lower_program(call:UOp, prg:UOp) -> UOp:
|
|
q = UOp(Ops.LINEAR, dtypes.void, (prg,), arg=(call.src[1].device, "COMPUTE"))
|
|
return call.replace(src=(q,) + call.src[1:]).rtag('hcq')
|
|
|
|
def lower_copy(call:UOp, copy:UOp) -> UOp|None:
|
|
dst, src = call.src[1], call.src[2]
|
|
if (hcq_dev:=next((b.device for b in (dst, src) if b.device.split(":")[0] in HCQ_DEVS), None)) is None: return None
|
|
|
|
q = UOp(Ops.LINEAR, dtypes.void, (UOp(Ops.COPY, dtypes.void, src=(dst, src), arg=src.buffer.nbytes),), arg=(hcq_dev, "COPY"))
|
|
return call.replace(src=(q,) + call.src[1:]).rtag('hcq')
|
|
|
|
pm_lower_ops = PatternMatcher([
|
|
(UPat(Ops.CALL, src=(UPat(Ops.PROGRAM, src=(UPat(Ops.AFTER), UPat(Ops.AFTER)), name="prg"),), name="call", allow_any_len=True), lower_program),
|
|
(UPat(Ops.CALL, src=(UPat(Ops.COPY, name="copy"),), name="call", allow_any_len=True), lower_copy),
|
|
])
|
|
|
|
# *****************
|
|
# 2.2. queue split
|
|
|
|
# def split_into_queues(linear:UOp) -> UOp:
|
|
# out = []
|
|
# for k, grp in itertools.groupby(linear.src, lambda c: c.src[0].arg if c.op is Ops.CALL and c.src[0].op is Ops.LINEAR else None):
|
|
# if k is None: out.extend(grp)
|
|
# else:
|
|
# calls = list(grp)
|
|
# items = tuple(x for c in calls for x in c.src[0].src)
|
|
# args = tuple(a for c in calls for a in c.src[1:])
|
|
# out.append(calls[0].replace(src=(UOp(Ops.LINEAR, dtypes.void, items, arg=k),) + args))
|
|
# return linear.replace(src=tuple(out))
|
|
# pm_split_into_queues = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), split_into_queues)])
|
|
|
|
# *****************
|
|
# 2.3. barriers / signals / timeline inc
|
|
|
|
def add_barriers(call:UOp, q:UOp) -> UOp:
|
|
return call.replace(src=(q.replace(src=(UOp(Ops.BARRIER, dtypes.void), *q.src)),) + call.src[1:])
|
|
pm_add_barriers = PatternMatcher([(UPat(Ops.CALL, src=(UPat(Ops.LINEAR, name="q"),), name="call", allow_any_len=True), add_barriers)])
|
|
|
|
def add_signals(call:UOp, q:UOp) -> UOp:
|
|
sig = UOp.new_buffer(q.arg[0], 0x100, dtypes.uint8).rtag("timeline_signal")
|
|
tl = UOp.new_buffer(q.arg[0], 1, dtypes.uint64).rtag("timeline_value").index(UOp.const(dtypes.int, 0))
|
|
return call.replace(src=(q.replace(src=(sig.wait(tl-1), *q.src, sig.store(tl)), arg=q.arg),) + call.src[1:])
|
|
pm_add_signals = PatternMatcher([(UPat(Ops.CALL, src=(UPat(Ops.LINEAR, name="q"),), name="call", allow_any_len=True), add_signals)])
|
|
|
|
# *****************
|
|
# 3.1. encode cmdbufs
|
|
|
|
@functools.cache
|
|
def get_pm_lower(name:str) -> PatternMatcher|None:
|
|
try:
|
|
importlib.import_module(f'tinygrad.runtime.ops_{name.lower()}') # TODO: remove that
|
|
return importlib.import_module(f'extra.hcq2.ops_{name.lower()}2').pm_lower
|
|
except ImportError: return None
|
|
|
|
def encode_cmdbuf(call:UOp, q:UOp) -> UOp|None:
|
|
if (pm:=get_pm_lower(to_tuple(q.arg[0])[0].split(":")[0])) is None or (encoded:=pm.rewrite(q)) is None: return None
|
|
return call.replace(src=(encoded,) + call.src[1:])
|
|
pm_encode_cmdbufs = PatternMatcher([(UPat(Ops.CALL, src=(UPat(Ops.LINEAR, name="q"),), name="call", allow_any_len=True), encode_cmdbuf)])
|
|
|
|
# *****************
|
|
# 3.2. add timeline inc
|
|
|
|
def add_timeline_inc(call:UOp, s:UOp) -> UOp:
|
|
tl = UOp.new_buffer(s.device, 1, dtypes.uint64).rtag("timeline_value")
|
|
return call.replace(src=(tl.after(s).index(UOp.const(dtypes.int, 0), dtype=tl.dtype.ptr()).store(tl.index(UOp.const(dtypes.int, 0)) + 1),) + call.src[1:])
|
|
pm_add_timeline_inc = PatternMatcher([(UPat(Ops.CALL, tag="hcq", src=(UPat(name="s"),), name="call", allow_any_len=True), add_timeline_inc)])
|
|
|
|
# *****************
|
|
# 3.3. lift patches to the command buffer (root)
|
|
|
|
def lift_patches_to_cmdbuf(cmdbuf:UOp) -> UOp|None:
|
|
if not (patches:=dedup(u for store in cmdbuf.src[1:] for u in store.toposort() if u.op is Ops.AFTER)): return None
|
|
deps = tuple(d for p in patches for d in p.src[1:])
|
|
return cmdbuf.replace(src=cmdbuf.src + deps).substitute({p: p.src[0] for p in patches})
|
|
pm_lift_patches_to_cmdbuf = PatternMatcher([
|
|
(UPat(Ops.AFTER, src=(UPat(Ops.BUFFER, tag={"compute", "copy"}),), allow_any_len=True, name="cmdbuf"), lift_patches_to_cmdbuf),
|
|
])
|
|
|
|
# *****************
|
|
# 4. bufferize placeholders: replace placeholders with real buffers.
|
|
|
|
def bufferize_buf(buf:UOp) -> UOp|None:
|
|
if buf.tag is None: return None
|
|
uops = tuple(UOp.from_buffer((dv:=Device[dev]).pm_bufferize.rewrite(buf, ctx=dv), dev) for dev in to_tuple(buf.src[1].arg))
|
|
return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, uops)
|
|
pm_bufferize = PatternMatcher([(UPat(Ops.BUFFER, name="buf"), bufferize_buf)])
|
|
|
|
# *****************
|
|
# 5.1. capture buffers reachable from each hcq call as BIND, so we don't drop their refs
|
|
|
|
def hold_call_buffers(call:UOp) -> UOp|None:
|
|
if not (bufs:=tuple(dedup(u for u in call.src[0].toposort() if u.op is Ops.BUFFER and u not in call.src))): return None
|
|
return call.replace(src=call.src + (UOp(Ops.BIND, dtypes.void, src=bufs),))
|
|
pm_hold_call_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), hold_call_buffers)])
|
|
|
|
# *****************
|
|
# 5.2. resolve patches
|
|
|
|
def push_stack(op, s): return UOp(Ops.STACK, op.dtype.scalar().vec(len(s.src)),
|
|
tuple(op.replace(dtype=op.dtype.scalar(), src=tuple(x if y is s else y for y in op.src)) for x in s.src))
|
|
|
|
def fold_blob_store(buf:UOp, blob:UOp) -> UOp:
|
|
for b in (buf.src if buf.op is Ops.MSTACK else (buf,)): b.buffer.ensure_allocated()._buf.cpu_view().mv.cast('B')[:len(blob.arg)] = blob.arg
|
|
return UOp(Ops.NOOP)
|
|
|
|
def fold_const_store(buf:UOp, off:UOp, val:UOp) -> UOp:
|
|
for b, v in zip((buf.src if buf.op is Ops.MSTACK else (buf,)), (val.src if val.op is Ops.STACK else (val,))):
|
|
struct.pack_into(f'<{v.dtype.fmt}', b.buffer.ensure_allocated()._buf.cpu_view().mv.cast('B'), off.arg * b.dtype.base.itemsize, v.arg)
|
|
return UOp(Ops.NOOP)
|
|
|
|
def resolve_getaddr(buf:UOp, g:UOp) -> UOp:
|
|
if isinstance(b:=buf.buffer, Buffer): return UOp.const(dtypes.uint64, b.get_buf(g.src[1].arg).va_addr)
|
|
return UOp(Ops.STACK, dtypes.uint64.vec(len(b.bufs)), tuple(UOp.const(dtypes.uint64, x.ensure_allocated()._buf.va_addr) for x in b.bufs))
|
|
|
|
pm_resolve_patches = PatternMatcher([
|
|
# multi
|
|
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
|
|
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
|
|
|
|
# getaddr
|
|
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), # getaddr(slice(x)) -> offset+getaddr(x)
|
|
lambda bv, dev: UOp(Ops.GETADDR, dtypes.uint64, src=(bv.src[0], dev)) + UOp.const(dtypes.uint64, bv.src[1].arg * bv.src[0].dtype.itemsize)),
|
|
(UPat(Ops.GETADDR, src=(UPat({Ops.BUFFER, Ops.MSTACK, Ops.MSELECT}, name="buf"), UPat(Ops.DEVICE)), name="g"), resolve_getaddr),
|
|
|
|
# folders
|
|
(UPat({Ops.BUFFER, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
|
|
(UPat({Ops.BUFFER, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
|
|
fold_const_store),
|
|
]) + symbolic_simple
|
|
|
|
# *****************
|
|
# 6. callify hcq programs
|
|
|
|
pm_fixup = PatternMatcher([ # TODO: this should gone?
|
|
(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
|
|
])
|
|
|
|
def to_param(bufs:list[UOp], ref:UOp) -> UOp:
|
|
bufs.append(ref)
|
|
return UOp.placeholder((ref.buffer.size,), ref.dtype, len(bufs)-1)
|
|
pm_to_param = PatternMatcher([(UPat({Ops.MSELECT, Ops.MSTACK, Ops.BUFFER}, name="r"), lambda ctx, r: to_param(ctx, r))])
|
|
|
|
def parametrize_host_buffers(call:UOp) -> UOp:
|
|
body = graph_rewrite(call.src[0], pm_to_param, ctx=(bufs:=[]), bottom_up=True, name="parametrize host buffers")
|
|
return call.replace(src=(body, *bufs) + call.src[1:], tag="hcq_param")
|
|
pm_parametrize_host_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), parametrize_host_buffers)])
|
|
|
|
def callify_hcq(call:UOp) -> UOp:
|
|
sink = UOp.sink(call.src[0], arg=KernelInfo(name="hcq_submit", estimates=Estimates()), tag=1)
|
|
return to_program(sink, Device["CPU"].renderer).call(*call.src[1:])
|
|
pm_callify_hcq = PatternMatcher([(UPat(Ops.CALL, tag="hcq_param", name="call"), callify_hcq)])
|
|
|
|
@track_rewrites(lambda _,ret: f"HCQ Schedule {pluralize('Kernel', len(ret.src))}")
|
|
def hcq_schedule(linear:UOp) -> UOp:
|
|
linear = graph_rewrite(linear, pm_insert_copy_staging + pm_flatten_linear, name="insert copy staging")
|
|
linear = graph_rewrite(linear, pm_prep_runtime, name="prepare runtime")
|
|
|
|
linear = graph_rewrite(linear, pm_lower_ops, name="lower ops into hcq ir")
|
|
# linear = graph_rewrite(linear, pm_split_into_queues, name="split into queues")
|
|
linear = graph_rewrite(linear, pm_add_barriers, walk=True, name="add barriers")
|
|
linear = graph_rewrite(linear, pm_add_signals, walk=True, name="add signals")
|
|
linear = graph_rewrite(linear, pm_encode_cmdbufs, walk=True, name="encode cmdbufs")
|
|
linear = graph_rewrite(linear, pm_add_timeline_inc, walk=True, name="add timeline inc")
|
|
linear = graph_rewrite(linear, pm_lift_patches_to_cmdbuf, name="lift patches to cmdbuf", enter_calls=True)
|
|
|
|
# realize starts from here
|
|
linear = graph_rewrite(linear, pm_bufferize, bottom_up=True, name="bufferize placeholders", enter_calls=True)
|
|
linear = graph_rewrite(linear, pm_hold_call_buffers, walk=True, name="hold call buffers")
|
|
linear = graph_rewrite(linear, pm_resolve_patches, bottom_up=False, name="simplify patches", enter_calls=True)
|
|
linear = graph_rewrite(linear, pm_fixup, bottom_up=False, name="fixup", enter_calls=True)
|
|
linear = graph_rewrite(linear, pm_parametrize_host_buffers, name="parametrize host buffers")
|
|
linear = graph_rewrite(linear, pm_callify_hcq, name="callify hcq")
|
|
|
|
return linear
|