Files
tinygrad/extra/hcq2/hcq2.py
2026-05-29 22:09:52 +03:00

396 lines
20 KiB
Python

from __future__ import annotations
from typing import cast, Callable, TypeVar, Generic, Any, TYPE_CHECKING
import struct, functools, time, collections, importlib, itertools
from dataclasses import replace
if TYPE_CHECKING: from tinygrad.engine.realize import ExecContext
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, mv_address, round_up, DEBUG, dedup, pluralize
from tinygrad.device import Device, Buffer, BufferSpec, Compiled, LRUAllocator, MultiBuffer
from tinygrad.uop.ops import Ops, sint, UOp, UPat, PatternMatcher, KernelInfo, graph_rewrite, track_rewrites, GroupOp
from tinygrad.uop.symbolic import symbolic_simple, symbolic
from tinygrad.dtype import dtypes, DType
from dataclasses import dataclass, field
from tinygrad.runtime.support.memory import BumpAllocator
from tinygrad.runtime.support.hcq import MMIOInterface
from tinygrad.renderer import Renderer, Estimates
from tinygrad.engine.realize import to_program, track_stats, get_call_arg_uops, resolve_params, pm_flatten_linear
# type variable bound to HCQ2Compiled; used below to parametrize HCQAllocator over its device class
HCQDeviceType = TypeVar('HCQDeviceType', bound='HCQ2Compiled')
class HCQ2Compiled(Compiled):
  """Base device class for HCQ2 backends.

  Holds the per-device timeline signal/value buffers and the default `pm_bufferize` matcher that maps tagged
  placeholder BUFFER uops to real buffers. Several methods probe for an `iface` attribute with hasattr, and
  `on_device_hang` is called but not defined in this class as shown.
  """
  timestamp_divider: float = 1000.0 # GPU timestamp counter ticks per microsecond; override per device

  def __init__(self, device:str, allocator:'HCQAllocator', compilers:list[type[Renderer]], runtime, can_recover:bool=False, arch=None):
    # NOTE: `runtime` and `can_recover` are accepted but unused here; a no-op callable is handed to Compiled instead.
    # "NAME:N" selects device N, a bare "NAME" means device 0
    self.device_id:int = 0
    if ":" in device: self.device_id = int(device.split(":")[1])

    # default pm bufferize
    self.pm_bufferize = PatternMatcher([
      (UPat(Ops.BUFFER, tag="timeline_signal"), lambda ctx: ctx.timeline_signal),
      (UPat(Ops.BUFFER, tag="timeline_value"), lambda ctx: ctx.timeline_value),
      (UPat(Ops.BUFFER, name="b"), lambda ctx, b: Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=True, uncached=True, cpu_access=True))),
    ])

    super().__init__(device, allocator, compilers, lambda *a, **kw: None, None, arch=arch)

  @functools.cached_property
  def timeline_signal(self) -> Buffer:
    """0x100-byte host, uncached, cpu-accessible buffer on this device; created on first access."""
    spec = BufferSpec(host=True, uncached=True, cpu_access=True)
    return Buffer(self.device, 0x100, dtypes.uint8, options=spec, preallocate=True)

  @functools.cached_property
  def timestamps_buf(self) -> Buffer:
    """0x100-byte cpu-accessible buffer on this device; created on first access."""
    spec = BufferSpec(cpu_access=True)
    return Buffer(self.device, 0x100, dtypes.uint8, options=spec, preallocate=True)

  @functools.cached_property
  def timeline_value(self) -> Buffer:
    """Single uint64 living on "CPU", initialised to 1; created on first access."""
    value_buf = Buffer("CPU", 1, dtypes.uint64, preallocate=True)
    value_buf.as_memoryview(force_zero_copy=True).cast('Q')[0] = 1
    return value_buf

  def synchronize(self, timeout:int|None=None):
    """Spin until the timeline signal reaches `timeline_value - 1`.

    Returns immediately when no `iface` attribute exists. `timeout` is in milliseconds (a falsy value means 3000);
    once exceeded, `on_device_hang()` is invoked on every further iteration of the wait loop.
    """
    if not hasattr(self, 'iface'): return
    signal_view = self.timeline_signal._buf.cpu_view().mv.cast('Q')
    value_view = self.timeline_value.as_memoryview(force_zero_copy=True).cast('Q')
    started = time.perf_counter()
    while signal_view[0] < value_view[0] - 1:
      elapsed = time.perf_counter() - started
      if elapsed > (timeout or 3000) / 1000: self.on_device_hang()

  def device_props(self) -> dict[str,Any]: return {} # to be overridden if needed. dict keys are backend dependent.

  def count(self) -> int:
    """Number of devices reported by the interface, or 1 when there is no interface."""
    if hasattr(self, 'iface'): return self.iface.count
    return 1

  def _select_iface(self):
    """Pick and initialise the first usable interface class from `self.ifaces` for the DEV target.

    The device name is the class name minus its last 6 characters; interface names are class names minus their
    last 5 characters. MOCK interfaces are only kept when the requested interface itself starts with "MOCK".
    """
    assert (v:=getenv(k:=f'{type(self).__name__[:-6].upper()}_IFACE', "")) == "", \
      f"{k}={v} is deprecated, use DEV={replace(DEV.target(type(self).__name__[:-6]), interface=v)} instead"
    assert hasattr(self, "ifaces"), "must have ifaces to select an iface"

    dev = type(self).__name__[:-6]
    t = DEV.target(dev)
    by_name = select_by_name(self.ifaces, lambda i: i.__name__[:-5], t.interface, f"{dev} has no interface {t.interface!r}")
    # never fall back to mock ifaces
    candidates = [i for i in by_name if t.interface.startswith("MOCK") or not i.__name__[:-5].startswith("MOCK")]
    factories = [functools.partial(cast(Callable, iface), self, self.device_id) for iface in candidates]
    return select_first_inited(factories, f"No interface for {dev}:{self.device_id} is available")

  def _is_cpu(self) -> bool:
    """True when a `device` attribute exists and its name part (before ':') is "CPU"."""
    return hasattr(self, 'device') and self.device.split(":")[0] == "CPU"

  def finalize(self):
    """Synchronize (reporting, not raising, a RuntimeError) and then run the interface's `device_fini` if present."""
    # try to finalize the device in any case
    try:
      self.synchronize()
    except RuntimeError as e:
      print(f"{self.device} synchronization failed before finalizing: {e}")

    # if the device has an interface, call device_fini to clean up resources
    if hasattr(self, 'iface') and hasattr(self.iface, 'device_fini'): self.iface.device_fini()
class HCQ2Buffer:
  """A device address range with optional CPU view, metadata, owning device and base buffer."""

  def __init__(self, va_addr:sint, size:int, meta:Any=None, _base:HCQ2Buffer|None=None, view:MMIOInterface|None=None, owner:HCQ2Compiled|None=None):
    self.va_addr = va_addr
    self.size = size
    self.meta = meta
    self._base = _base
    self.view = view
    self.owner = owner

  def offset(self, offset:int=0, size:int|None=None) -> HCQ2Buffer:
    """Return a sub-buffer starting `offset` bytes in.

    A falsy `size` means "everything from `offset` to the end". The result keeps this buffer's owner and meta, and
    its base is this buffer's base (or this buffer itself). The CPU view, if any, is narrowed with the same
    offset/size arguments.
    """
    sub_view = None if self.view is None else self.view.view(offset=offset, size=size)
    sub_size = size or (self.size - offset)
    return HCQ2Buffer(self.va_addr+offset, sub_size, owner=self.owner, meta=self.meta, _base=self._base or self, view=sub_view)

  def cpu_view(self) -> MMIOInterface:
    """Return the CPU view; asserts that one exists."""
    assert self.view is not None, "buffer has no cpu_view"
    return self.view

  @property
  def base(self) -> HCQ2Buffer:
    """The root buffer this one was offset from, or itself when it is not a sub-buffer."""
    return self._base or self
class HCQAllocator(LRUAllocator[HCQDeviceType], Generic[HCQDeviceType]):
  """LRU allocator for HCQ2 devices.

  Mapping and freeing are delegated to optional `_do_map` / `_do_free` hooks looked up with hasattr; copies in and
  out of device memory go through host staging buffers and `run_linear`.
  """

  def _map(self, buf:HCQ2Buffer) -> HCQ2Buffer:
    """Map `buf` via `_do_map`; raises NotImplementedError when that hook is missing."""
    if hasattr(self, '_do_map'): return self._do_map(buf)
    raise NotImplementedError("map failed: no method implemented")

  @suppress_finalizing
  def _free(self, buf:HCQ2Buffer, options:BufferSpec|None=None):
    """Free `buf` via `_do_free` when present; buffers created from an external pointer are left alone."""
    is_external = options is not None and options.external_ptr is not None
    if is_external: return
    if hasattr(self, '_do_free'): self._do_free(buf, options)

  def _unmap(self, mb):
    """Wait for the device, then unmap the page-rounded (0x1000) range of `mb`."""
    self.dev.synchronize()
    nbytes = round_up(mb.size, 0x1000)
    self.dev.iface.dev_impl.mm.unmap_range(int(mb.va_addr), nbytes)

  def _offset(self, buf, size:int, offset:int) -> HCQ2Buffer:
    return buf.offset(offset=offset, size=size)

  def _wrap(self, dev:str, sz:int, opaque:HCQ2Buffer) -> Buffer:
    """Wrap an existing HCQ2Buffer in a uint8 Buffer marked with `external_ptr=1`."""
    spec = BufferSpec(external_ptr=1)
    return Buffer(dev, sz, dtypes.uint8, opaque=opaque, options=spec)

  def _copy(self, dst:Buffer, src:Buffer):
    """Run a single copy call from `src` to `dst` as a LINEAR program, without updating stats."""
    from tinygrad.engine.realize import run_linear
    src_uop = UOp.from_buffer(src)
    copy_call = src_uop.copy_to_device(dst.device).call(UOp.from_buffer(dst), src_uop)
    run_linear(UOp(Ops.LINEAR, dtypes.void, (copy_call,)), update_stats=False)

  def _copyin(self, dest:HCQ2Buffer, src:memoryview):
    """Copy host bytes into `dest` through a freshly allocated host staging buffer."""
    nbytes = len(src)
    staging = Buffer(self.dev.device, nbytes, dtypes.uint8, options=BufferSpec(host=True), preallocate=True)
    staging._buf.cpu_view()[:nbytes] = src
    self._copy(self._wrap(self.dev.device, nbytes, dest), staging)

  def _copyout(self, dest:memoryview, src:HCQ2Buffer):
    """Copy `src` into host memory `dest` through a host staging buffer, synchronizing before the read-back."""
    nbytes = len(dest)
    staging = Buffer(self.dev.device, nbytes, dtypes.uint8, options=BufferSpec(host=True), preallocate=True)
    self._copy(staging, self._wrap(self.dev.device, nbytes, src))
    self.dev.synchronize()
    dest[:] = staging._buf.cpu_view()[:nbytes]

  # def _as_buffer(self, buf): return buf.cpu_view().mv
def unwrap_after(uop):
  """Follow `src[0]` through any chain of AFTER uops and return the first non-AFTER uop."""
  while True:
    if uop.op is not Ops.AFTER: return uop
    uop = uop.src[0]
class HCQEncoder:
  """Accumulates a little-endian byte blob plus a list of (byte offset, value uop, dtype) patches.

  NOTE(review): `get_dev_addr` reads `self.dev`, which `__init__` here does not set; presumably provided by a
  subclass or the caller.
  """

  def __init__(self):
    self.blob = b''
    self.patches = []

  def get_dev_addr(self, uop:UOp) -> UOp:
    """Wrap buffer-like uops (looking through AFTER) in a GETADDR for this encoder's device; pass others through."""
    addressable = (Ops.BUFFER, Ops.SLICE, Ops.BINARY, Ops.MSTACK, Ops.MSELECT)
    if unwrap_after(uop).op in addressable:
      return UOp(Ops.GETADDR, dtypes.uint64, src=(uop, UOp(Ops.DEVICE, arg=self.dev.device)))
    return uop

  def append(self, *data, dtype=dtypes.uint32):
    """Append each item as one `dtype`-sized little-endian word.

    Ints are packed directly. Anything else is recorded as a patch at the current blob offset and a zero
    placeholder word is emitted in its place.
    """
    for item in data:
      if not isinstance(item, int):
        self.patches.append((len(self.blob), self.get_dev_addr(item), dtype))
        item = 0
      self.blob += struct.pack(f'<{dtype.fmt}', item)

  def q(self, *values): self.append(*values)

  def uop(self, dev:str|tuple[str, ...], tag:str|None=None) -> UOp:
    """Build the uop for this encoding: a new uint8 buffer on `dev` AFTER the blob store and every patch store."""
    target = UOp.new_buffer(dev, len(self.blob), dtypes.uint8)
    if tag: target = target.rtag(tag)
    blob_uop = UOp(Ops.BINARY, dtypes.void, src=(), arg=self.blob)
    patch_stores = []
    for off, val, dt in self.patches:
      slot = target.index(UOp.const(dtypes.int, off), dtype=target.dtype.ptr()).cast(dt.ptr())
      patch_stores.append(slot.store(val.cast(dt)))
    return target.after(target.store(blob_uop), *patch_stores)
# *****************
# 0. helpers
# device name prefixes (the part before ':') that are handled by the hcq path
HCQ_DEVS = frozenset(("AMD",))
# HCQ devices plus "CPU"; copies between an HCQ device and one of these are not staged (see _need_staging)
HCQ_P2P_DEVS = HCQ_DEVS | frozenset(("CPU",))
def to_tuple(d):
  """Return `d` unchanged when it is already a tuple, otherwise wrap it in a 1-tuple."""
  if isinstance(d, tuple): return d
  return (d,)
def all_devices_in(d:Any, c:frozenset[str]) -> bool:
  """True when every device in `d` (a device string or tuple of them) has its name part, before ':', in `c`."""
  return all(x.split(":")[0] in c for x in to_tuple(d))
# *****************
# 1.1. prep runtimes: staging copies
def _need_staging(a, b):
  """True when `a` lives only on HCQ devices while `b` is not entirely on HCQ/CPU devices."""
  if not all_devices_in(a.device, HCQ_DEVS): return False
  return not all_devices_in(b.device, HCQ_P2P_DEVS)
def stage_copy(dst:UOp, src:UOp) -> UOp|None:
  """Split a copy into src -> CPU stage -> dst when either direction needs staging; otherwise return None."""
  if not _need_staging(src, dst) and not _need_staging(dst, src): return None
  # byte-sized CPU bounce buffer as large as the source
  stage = UOp.new_buffer("CPU", src.buffer.nbytes, dtypes.uint8)
  to_stage = src.copy_to_device("CPU").call(stage, src)
  from_stage = stage.copy_to_device(dst.device).call(dst, stage)
  return UOp(Ops.LINEAR, dtypes.void, (to_stage, from_stage))
# match CALL(COPY, dst, src) and let stage_copy decide whether the copy has to bounce through a CPU buffer
pm_insert_copy_staging = PatternMatcher([(UPat(Ops.CALL, src=(UPat(Ops.COPY), UPat(name="dst"), UPat(name="src"))), stage_copy)])
# *****************
# 1.2. prep runtimes: programs/kernargs
@functools.cache
def get_pm_prep_program(name:str) -> PatternMatcher|None:
  """Return `pm_prep_program` from `extra.hcq2.ops_<name>2`, or None when an import fails. Cached per name."""
  lowered_name = name.lower()
  try:
    importlib.import_module(f'tinygrad.runtime.ops_{lowered_name}') # TODO: remove that
    backend = importlib.import_module(f'extra.hcq2.ops_{lowered_name}2')
    return backend.pm_prep_program
  except ImportError:
    return None
def prep_program(call:UOp, prg:UOp) -> UOp|None:
  """Lower a PROGRAM with the device-specific matcher and attach its image as a "program"-tagged buffer.

  Returns None when the device has no prep matcher or the matcher does not rewrite `prg`. Otherwise the PROGRAM's
  sources become a single AFTER(buffer, store(image bytes)) and its arg becomes `(data, original arg)`.
  """
  dev = call.src[1].device
  # the matcher is chosen by the name part of the first device
  pm = get_pm_prep_program(to_tuple(dev)[0].split(":")[0])
  if pm is None: return None
  lowered = pm.rewrite(prg)
  if lowered is None: return None

  data, image_bytes = lowered
  image_buf = UOp.new_buffer(dev, len(image_bytes), dtypes.uint8).rtag("program")
  image_blob = UOp(Ops.BINARY, dtypes.void, src=(), arg=image_bytes)
  new_prg = prg.replace(src=(image_buf.after(image_buf.store(image_blob)),), arg=(data, prg.arg))
  return call.replace(src=(new_prg,) + call.src[1:])
def prep_kernargs(call:UOp, prg:UOp) -> UOp:
  """Append a "kernargs"-tagged buffer, patched with argument addresses and variable values, to the PROGRAM sources.

  Layout: one uint64 address per entry of `info.globals` (taken from `call.src[1+gi]`), followed by one uint32 per
  entry of `info.vars`. The buffer is `data.kernargs_alloc_size` bytes on the device of `call.src[1]`.
  """
  data, info = prg.arg
  ptr_size, var_size = dtypes.uint64.itemsize, dtypes.uint32.itemsize

  patches = []
  for i, gi in enumerate(info.globals):
    arg_buf = call.src[1+gi]
    addr = UOp(Ops.GETADDR, dtypes.uint64, src=(arg_buf, UOp(Ops.DEVICE, arg=arg_buf.device)))
    patches.append((i*ptr_size, addr, dtypes.uint64))
  vars_base = len(info.globals)*ptr_size
  for i, v in enumerate(info.vars):
    patches.append((vars_base + i*var_size, v, dtypes.uint32))

  buf = UOp.new_buffer(call.src[1].device, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
  stores = [buf.index(UOp.const(dtypes.int, o), dtype=buf.dtype.ptr()).cast(dt.ptr()).store(val.cast(dt)) for o, val, dt in patches]
  kernargs = buf.after(*stores)
  new_prg = prg.replace(src=prg.src + (kernargs,), arg=(data, info))
  return call.replace(src=(new_prg,) + call.src[1:])
# two-step runtime preparation of CALL(PROGRAM, ...): first attach the lowered program image, then the kernargs
pm_prep_runtime = PatternMatcher([
  # bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering
  (UPat(Ops.CALL, src=(UPat(Ops.PROGRAM, src=(UPat(), UPat(), UPat(), UPat(), UPat(Ops.BINARY)), name="prg"),),
        name="call", allow_any_len=True), prep_program),
  # lower kernargs (PROGRAM.src[0] is now AFTER(BUFFER, STORE(BINARY)) — the lowered program image built by prep_program)
  (UPat(Ops.CALL, src=(UPat(Ops.PROGRAM, src=(UPat(Ops.AFTER),), name="prg"),), name="call", allow_any_len=True), prep_kernargs),
])
# *****************
# 2.1. lowering to hcq ir
def lower_program(call:UOp, prg:UOp) -> UOp:
  """Wrap the prepared PROGRAM in a one-item LINEAR queue tagged (device of call.src[1], "COMPUTE"); tag the call 'hcq'."""
  queue_arg = (call.src[1].device, "COMPUTE")
  queue = UOp(Ops.LINEAR, dtypes.void, (prg,), arg=queue_arg)
  lowered = call.replace(src=(queue,) + call.src[1:])
  return lowered.rtag('hcq')
def lower_copy(call:UOp, copy:UOp) -> UOp|None:
  """Lower CALL(COPY, dst, src) into a one-item "COPY" queue on an HCQ device and tag the call 'hcq'.

  The queue runs on the device of the first of (dst, src) that lives entirely on HCQ devices; returns None when
  neither does. The queued COPY carries the source size in bytes as its arg.
  """
  dst, src = call.src[1], call.src[2]
  # a buffer's device may be a tuple of devices (as the other passes here accept), so a bare `.split` on it is not
  # safe; all_devices_in handles both a single device string and a tuple
  hcq_dev = next((b.device for b in (dst, src) if all_devices_in(b.device, HCQ_DEVS)), None)
  if hcq_dev is None: return None
  copy_op = UOp(Ops.COPY, dtypes.void, src=(dst, src), arg=src.buffer.nbytes)
  q = UOp(Ops.LINEAR, dtypes.void, (copy_op,), arg=(hcq_dev, "COPY"))
  return call.replace(src=(q,) + call.src[1:]).rtag('hcq')
# turn prepared calls into hcq queues: programs that already carry both AFTER sources (image and kernargs), and copies
pm_lower_ops = PatternMatcher([
  (UPat(Ops.CALL, src=(UPat(Ops.PROGRAM, src=(UPat(Ops.AFTER), UPat(Ops.AFTER)), name="prg"),), name="call", allow_any_len=True), lower_program),
  (UPat(Ops.CALL, src=(UPat(Ops.COPY, name="copy"),), name="call", allow_any_len=True), lower_copy),
])
# *****************
# 2.2. queue split
# def split_into_queues(linear:UOp) -> UOp:
# out = []
# for k, grp in itertools.groupby(linear.src, lambda c: c.src[0].arg if c.op is Ops.CALL and c.src[0].op is Ops.LINEAR else None):
# if k is None: out.extend(grp)
# else:
# calls = list(grp)
# items = tuple(x for c in calls for x in c.src[0].src)
# args = tuple(a for c in calls for a in c.src[1:])
# out.append(calls[0].replace(src=(UOp(Ops.LINEAR, dtypes.void, items, arg=k),) + args))
# return linear.replace(src=tuple(out))
# pm_split_into_queues = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), split_into_queues)])
# *****************
# 2.3. barriers / signals / timeline inc
def add_barriers(call:UOp, q:UOp) -> UOp:
  """Prepend a BARRIER uop to the queue's items."""
  barrier = UOp(Ops.BARRIER, dtypes.void)
  guarded = q.replace(src=(barrier, *q.src))
  return call.replace(src=(guarded,) + call.src[1:])
# applies add_barriers to every CALL whose first source is a LINEAR queue
pm_add_barriers = PatternMatcher([(UPat(Ops.CALL, src=(UPat(Ops.LINEAR, name="q"),), name="call", allow_any_len=True), add_barriers)])
def add_signals(call:UOp, q:UOp) -> UOp:
  """Bracket the queue with a wait on `timeline_value - 1` and a store of `timeline_value` to the timeline signal.

  Both the signal and the value are placeholder buffers on the queue's device, tagged so bufferization can swap in
  the real ones.
  """
  queue_dev = q.arg[0]
  sig = UOp.new_buffer(queue_dev, 0x100, dtypes.uint8).rtag("timeline_signal")
  tl = UOp.new_buffer(queue_dev, 1, dtypes.uint64).rtag("timeline_value").index(UOp.const(dtypes.int, 0))
  items = (sig.wait(tl-1), *q.src, sig.store(tl))
  return call.replace(src=(q.replace(src=items, arg=q.arg),) + call.src[1:])
# applies add_signals to every CALL whose first source is a LINEAR queue
pm_add_signals = PatternMatcher([(UPat(Ops.CALL, src=(UPat(Ops.LINEAR, name="q"),), name="call", allow_any_len=True), add_signals)])
# *****************
# 3.1. encode cmdbufs
@functools.cache
def get_pm_lower(name:str) -> PatternMatcher|None:
  """Return `pm_lower` from `extra.hcq2.ops_<name>2`, or None when an import fails. Cached per name."""
  lowered_name = name.lower()
  try:
    importlib.import_module(f'tinygrad.runtime.ops_{lowered_name}') # TODO: remove that
    backend = importlib.import_module(f'extra.hcq2.ops_{lowered_name}2')
    return backend.pm_lower
  except ImportError:
    return None
def encode_cmdbuf(call:UOp, q:UOp) -> UOp|None:
  """Replace the queue with its device-specific encoding; None when no matcher exists or it does not rewrite `q`."""
  # the matcher is chosen by the name part of the queue's first device
  pm = get_pm_lower(to_tuple(q.arg[0])[0].split(":")[0])
  if pm is None: return None
  encoded = pm.rewrite(q)
  if encoded is None: return None
  return call.replace(src=(encoded,) + call.src[1:])
# applies encode_cmdbuf to every CALL whose first source is a LINEAR queue
pm_encode_cmdbufs = PatternMatcher([(UPat(Ops.CALL, src=(UPat(Ops.LINEAR, name="q"),), name="call", allow_any_len=True), encode_cmdbuf)])
# *****************
# 3.2. add timeline inc
def add_timeline_inc(call:UOp, s:UOp) -> UOp:
  """Make the call body a store of `timeline_value[0] + 1` into timeline_value[0], ordered AFTER the previous body `s`."""
  tl = UOp.new_buffer(s.device, 1, dtypes.uint64).rtag("timeline_value")
  zero = UOp.const(dtypes.int, 0)
  slot = tl.after(s).index(zero, dtype=tl.dtype.ptr())
  bumped = tl.index(UOp.const(dtypes.int, 0)) + 1
  return call.replace(src=(slot.store(bumped),) + call.src[1:])
# applies add_timeline_inc to every CALL tagged "hcq"; `s` is the call's first source
pm_add_timeline_inc = PatternMatcher([(UPat(Ops.CALL, tag="hcq", src=(UPat(name="s"),), name="call", allow_any_len=True), add_timeline_inc)])
# *****************
# 3.3. lift patches to the command buffer (root)
def lift_patches_to_cmdbuf(cmdbuf:UOp) -> UOp|None:
  """Hoist the dependencies of AFTER uops nested inside the cmdbuf's stores up to the cmdbuf itself.

  Every AFTER found under `cmdbuf.src[1:]` is replaced by its first source, and its remaining sources are appended
  to the cmdbuf's own sources. Returns None when there is nothing to lift.
  """
  nested = dedup(u for store in cmdbuf.src[1:] for u in store.toposort() if u.op is Ops.AFTER)
  if not nested: return None
  lifted, unwrap = [], {}
  for p in nested:
    lifted.extend(p.src[1:])
    unwrap[p] = p.src[0]
  return cmdbuf.replace(src=cmdbuf.src + tuple(lifted)).substitute(unwrap)
# only AFTER uops rooted at a BUFFER tagged "compute" or "copy" (the command buffers) are considered
pm_lift_patches_to_cmdbuf = PatternMatcher([
  (UPat(Ops.AFTER, src=(UPat(Ops.BUFFER, tag={"compute", "copy"}),), allow_any_len=True, name="cmdbuf"), lift_patches_to_cmdbuf),
])
# *****************
# 4. bufferize placeholders: replace placeholders with real buffers.
def bufferize_buf(buf:UOp) -> UOp|None:
  """Replace a tagged placeholder BUFFER with real buffer uop(s) produced by each device's `pm_bufferize`.

  Untagged buffers are left alone (None). One device yields a single buffer uop; several devices yield an MSTACK.
  """
  if buf.tag is None: return None
  realized = []
  for dev in to_tuple(buf.src[1].arg):
    dv = Device[dev]
    realized.append(UOp.from_buffer(dv.pm_bufferize.rewrite(buf, ctx=dv), dev))
  if len(realized) == 1: return realized[0]
  return UOp(Ops.MSTACK, realized[0].dtype, tuple(realized))
# every BUFFER uop is offered to bufferize_buf, which ignores the untagged ones
pm_bufferize = PatternMatcher([(UPat(Ops.BUFFER, name="buf"), bufferize_buf)])
# *****************
# 5.1. capture buffers reachable from each hcq call as BIND, so we don't drop their refs
def hold_call_buffers(call:UOp) -> UOp|None:
  """Append a BIND holding every BUFFER reachable from the call body that is not already a call source.

  Returns None when there is no such buffer.
  """
  reachable = (u for u in call.src[0].toposort() if u.op is Ops.BUFFER and u not in call.src)
  bufs = tuple(dedup(reachable))
  if not bufs: return None
  holder = UOp(Ops.BIND, dtypes.void, src=bufs)
  return call.replace(src=call.src + (holder,))
# applies hold_call_buffers to every CALL tagged "hcq"
pm_hold_call_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), hold_call_buffers)])
# *****************
# 5.2. resolve patches
def push_stack(op, s):
  """Distribute `op` over the elements of STACK `s`: build one scalar `op` per element and stack the results."""
  scalar = op.dtype.scalar()
  lanes = tuple(op.replace(dtype=op.dtype.scalar(), src=tuple(x if y is s else y for y in op.src)) for x in s.src)
  return UOp(Ops.STACK, scalar.vec(len(s.src)), lanes)
def fold_blob_store(buf:UOp, blob:UOp) -> UOp:
  """Write the blob bytes to the start of the buffer (every buffer of an MSTACK) right now and fold the store to NOOP."""
  targets = buf.src if buf.op is Ops.MSTACK else (buf,)
  payload = blob.arg
  for b in targets:
    host_bytes = b.buffer.ensure_allocated()._buf.cpu_view().mv.cast('B')
    host_bytes[:len(payload)] = payload
  return UOp(Ops.NOOP)
def fold_const_store(buf:UOp, off:UOp, val:UOp) -> UOp:
  """Write a resolved constant into the buffer's CPU view right now and fold the store to NOOP.

  `buf` is a single buffer or an MSTACK of buffers; `val` is a constant or a STACK with one value per buffer. The
  value is packed little-endian at byte offset `off.arg * itemsize` of the buffer's base dtype.
  """
  targets = buf.src if buf.op is Ops.MSTACK else (buf,)
  # a lone constant applies to every target buffer; pairing it via zip() as a 1-tuple would patch only the first
  # buffer of an MSTACK while the whole store is still folded away
  values = val.src if val.op is Ops.STACK else (val,) * len(targets)
  for b, v in zip(targets, values):
    host_bytes = b.buffer.ensure_allocated()._buf.cpu_view().mv.cast('B')
    struct.pack_into(f'<{v.dtype.fmt}', host_bytes, off.arg * b.dtype.base.itemsize, v.arg)
  return UOp(Ops.NOOP)
def resolve_getaddr(buf:UOp, g:UOp) -> UOp:
  """Resolve GETADDR to constant address(es).

  A plain Buffer yields one uint64 const, looked up with `get_buf` for the GETADDR's device; anything else is
  treated as a multi-buffer and yields a STACK with one `va_addr` const per (allocated) sub-buffer.
  """
  target = buf.buffer
  if not isinstance(target, Buffer):
    addrs = tuple(UOp.const(dtypes.uint64, x.ensure_allocated()._buf.va_addr) for x in target.bufs)
    return UOp(Ops.STACK, dtypes.uint64.vec(len(target.bufs)), addrs)
  return UOp.const(dtypes.uint64, target.get_buf(g.src[1].arg).va_addr)
# resolves patch expressions to constants and performs the resulting stores eagerly; combined with symbolic_simple
pm_resolve_patches = PatternMatcher([
  # multi: push ALU-with-const and CAST through a STACK so each element is handled as a scalar
  (UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
  (UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
  # getaddr
  (UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), # getaddr(slice(x)) -> offset+getaddr(x)
   lambda bv, dev: UOp(Ops.GETADDR, dtypes.uint64, src=(bv.src[0], dev)) + UOp.const(dtypes.uint64, bv.src[1].arg * bv.src[0].dtype.itemsize)),
  (UPat(Ops.GETADDR, src=(UPat({Ops.BUFFER, Ops.MSTACK, Ops.MSELECT}, name="buf"), UPat(Ops.DEVICE)), name="g"), resolve_getaddr),
  # folders: stores whose value is a BINARY blob or a constant (or STACK) are written immediately and become NOOP
  (UPat({Ops.BUFFER, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
  (UPat({Ops.BUFFER, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
   fold_const_store),
]) + symbolic_simple
# *****************
# 6. callify hcq programs
# strips the sources from any CONST that still has some
pm_fixup = PatternMatcher([ # TODO: this should be gone?
  (UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
])
def to_param(bufs:list[UOp], ref:UOp) -> UOp:
  """Record `ref` in `bufs` and return a placeholder of the same size/dtype numbered with its position in the list."""
  slot = len(bufs)
  bufs.append(ref)
  return UOp.placeholder((ref.buffer.size,), ref.dtype, slot)
# every MSELECT/MSTACK/BUFFER becomes a numbered placeholder; the rewrite ctx is the list collecting the originals
pm_to_param = PatternMatcher([(UPat({Ops.MSELECT, Ops.MSTACK, Ops.BUFFER}, name="r"), lambda ctx, r: to_param(ctx, r))])
def parametrize_host_buffers(call:UOp) -> UOp:
  """Turn the buffers referenced by the call body into placeholders and pass the real ones as leading call arguments.

  The result is tagged "hcq_param"; the collected buffers come right after the body, ahead of the original arguments.
  """
  collected:list[UOp] = []
  body = graph_rewrite(call.src[0], pm_to_param, ctx=collected, bottom_up=True, name="parametrize host buffers")
  return call.replace(src=(body, *collected) + call.src[1:], tag="hcq_param")
# applies parametrize_host_buffers to every CALL tagged "hcq"
pm_parametrize_host_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), parametrize_host_buffers)])
def callify_hcq(call:UOp) -> UOp:
  """Compile the parametrized call body into a program named "hcq_submit" with the CPU renderer and call it with the same arguments."""
  info = KernelInfo(name="hcq_submit", estimates=Estimates())
  sink = UOp.sink(call.src[0], arg=info, tag=1)
  program = to_program(sink, Device["CPU"].renderer)
  return program.call(*call.src[1:])
# applies callify_hcq to every CALL tagged "hcq_param" (the tag set by parametrize_host_buffers)
pm_callify_hcq = PatternMatcher([(UPat(Ops.CALL, tag="hcq_param", name="call"), callify_hcq)])
@track_rewrites(lambda _,ret: f"HCQ Schedule {pluralize('Kernel', len(ret.src))}")
def hcq_schedule(linear:UOp) -> UOp:
  """Run the HCQ lowering pipeline over a LINEAR schedule and return the rewritten schedule.

  The passes run strictly in the order listed; each entry is a matcher and the keyword arguments for its
  graph_rewrite call.
  """
  passes:list[tuple[PatternMatcher, dict[str, Any]]] = [
    (pm_insert_copy_staging + pm_flatten_linear, dict(name="insert copy staging")),
    (pm_prep_runtime, dict(name="prepare runtime")),
    (pm_lower_ops, dict(name="lower ops into hcq ir")),
    # (pm_split_into_queues, dict(name="split into queues")),
    (pm_add_barriers, dict(walk=True, name="add barriers")),
    (pm_add_signals, dict(walk=True, name="add signals")),
    (pm_encode_cmdbufs, dict(walk=True, name="encode cmdbufs")),
    (pm_add_timeline_inc, dict(walk=True, name="add timeline inc")),
    (pm_lift_patches_to_cmdbuf, dict(name="lift patches to cmdbuf", enter_calls=True)),
    # realize starts from here
    (pm_bufferize, dict(bottom_up=True, name="bufferize placeholders", enter_calls=True)),
    (pm_hold_call_buffers, dict(walk=True, name="hold call buffers")),
    (pm_resolve_patches, dict(bottom_up=False, name="simplify patches", enter_calls=True)),
    (pm_fixup, dict(bottom_up=False, name="fixup", enter_calls=True)),
    (pm_parametrize_host_buffers, dict(name="parametrize host buffers")),
    (pm_callify_hcq, dict(name="callify hcq")),
  ]
  for matcher, opts in passes:
    linear = graph_rewrite(linear, matcher, **opts)
  return linear