Files
tinygrad/extra/hcq2/ops_amd2.py
2026-05-29 22:09:52 +03:00

552 lines
33 KiB
Python

from __future__ import annotations
from typing import cast
import os, ctypes, struct, hashlib, functools, importlib, mmap, errno, array, contextlib, sys, weakref, itertools, collections, atexit
assert sys.platform != 'win32'
from dataclasses import dataclass
from extra.hcq2.hcq2 import HCQ2Compiled, HCQAllocator, HCQ2Buffer, HCQEncoder
from tinygrad.uop.ops import sint, UOp
from tinygrad.device import Compiled, BufferSpec, Buffer, Device
from tinygrad.dtype import dtypes
from tinygrad.helpers import getenv, round_up, data64_le, DEBUG, PROFILE, ProfileEvent, lo32, hi32, colored, prod, ContextVar, TracingKey
from tinygrad.helpers import VIZ, ceildiv, unwrap, pluralize
from tinygrad.renderer.cstyle import HIPRenderer, HIPCCRenderer
from tinygrad.renderer.llvmir import AMDLLVMRenderer
from tinygrad.runtime.autogen import kfd, hsa, sqtt, amdgpu_kd, amdgpu_drm
from tinygrad.runtime.autogen.am import am
from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.support.am.amdev import AMDev, AMMemoryManager
from tinygrad.runtime.support.amd import AMDReg, AMDIP, import_module, import_soc, import_pmc
from tinygrad.runtime.support.system import PCIIfaceBase, PCIAllocationMeta, USBPCIDevice, MAP_FIXED, MAP_NORESERVE
from tinygrad.runtime.support.usb import USB3
from tinygrad.runtime.support.memory import AddrSpace, BumpAllocator
from tinygrad.runtime.ops_amd import SQTT, SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE, SQTT_SIMD_SEL, SQTT_TOKEN_EXCLUDE, PMC
from tinygrad.runtime.ops_amd import EVENT_INDEX_PARTIAL_FLUSH, WAIT_REG_MEM_FUNCTION_EQ, WAIT_REG_MEM_FUNCTION_NEQ, WAIT_REG_MEM_FUNCTION_GEQ
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
from tinygrad.engine.realize import get_runtime
from tinygrad.uop.ops import Ops, UPat, PatternMatcher, graph_rewrite
class AMDComputeQueue(HCQEncoder):
  # Encoder that turns compute-queue commands (wait/barrier/store/timestamp/program) into PM4 packets.
  # Packets are appended with self.q and buffer addresses are obtained with self.get_dev_addr; neither is defined in this class.
  def __init__(self, dev:AMDDevice, devs:tuple[str, ...]|None=None):
    super().__init__()
    # when no device list is given, fall back to a 1-tuple holding dev.device
    self.dev, self.devs = dev, devs or (dev.device,)
    # keep the device's pm4/gc/nbio/soc definition objects at hand for the encoders below
    self.pm4, self.gc, self.nbio, self.soc = dev.pm4, dev.gc, dev.nbio, dev.soc

  # queue one PACKET3: a header built from (cmd, len(vals) - 1) followed by the payload values
  def pkt3(self, cmd, *vals): self.q(self.pm4.PACKET3(cmd, len(vals) - 1), *vals)

  # write register `reg` with a SET_SH_REG or SET_UCONFIG_REG packet, chosen by the address window reg.addr[0] falls in.
  # the values are given either as raw dwords (*args) or as register fields (**kwargs, packed with reg.encode).
  # raises RuntimeError if both or neither form is used, or if the register is outside both windows.
  def wreg(self, reg:AMDReg, *args:sint, **kwargs:int):
    if bool(args) == bool(kwargs): raise RuntimeError('One (and only one) of *args or **kwargs must be specified')
    if self.pm4.PACKET3_SET_SH_REG_START <= reg.addr[0] < self.pm4.PACKET3_SET_SH_REG_END:
      set_packet, set_packet_start = self.pm4.PACKET3_SET_SH_REG, self.pm4.PACKET3_SET_SH_REG_START
    elif self.pm4.PACKET3_SET_UCONFIG_REG_START <= reg.addr[0] < self.pm4.PACKET3_SET_UCONFIG_REG_START + 2**16-1:
      set_packet, set_packet_start = self.pm4.PACKET3_SET_UCONFIG_REG, self.pm4.PACKET3_SET_UCONFIG_REG_START
    else: raise RuntimeError(f'Cannot set {reg.name} ({reg.addr[0]}) via pm4 packet')
    # the packet carries the register offset relative to the window start, then the values
    self.pkt3(set_packet, reg.addr[0] - set_packet_start, *(args or (reg.encode(**kwargs),)))

  # emit WAIT_REG_MEM comparing against `value` under `mask` using compare function `op`.
  # with `mem` set, the MEM_SPACE bit is 1 and the 64-bit address is the target; otherwise (reg, reg_done) are passed
  # and the OPERATION field is 1 only when reg_done > 0.
  def wait_reg_mem(self, value, mask=0xffffffff, mem=None, reg=None, reg_done=0, op=WAIT_REG_MEM_FUNCTION_GEQ):
    wrm_info_dw = self.pm4.WAIT_REG_MEM_MEM_SPACE(int(mem is not None)) | self.pm4.WAIT_REG_MEM_OPERATION(int(mem is None and reg_done > 0)) \
      | self.pm4.WAIT_REG_MEM_FUNCTION(op) | self.pm4.WAIT_REG_MEM_ENGINE(0)
    # NOTE(review): the trailing 4 is presumably the poll interval -- confirm against the PM4 packet definition
    self.pkt3(self.pm4.PACKET3_WAIT_REG_MEM, wrm_info_dw, *(data64_le(mem) if mem is not None else (reg, reg_done)), value, mask, 4)

  # emit ACQUIRE_MEM over [addr, addr+sz); each gl* argument switches the corresponding cache action on (1) or off (0).
  # the packet layout differs between target[0] == 9 and every other target.
  def acquire_mem(self, addr=0x0, sz=(1 << 64)-1, gli=1, glm=1, glk=1, glv=1, gl1=1, gl2=1):
    if self.dev.target[0] != 9:
      # glm, glk and gl2 each drive both an invalidate and a write-back flag; gli, glv and gl1 only an invalidate flag
      cache_flags_dw = self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLI_INV(gli) \
        | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLM_INV(glm) | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLM_WB(glm) \
        | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLK_INV(glk) | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLK_WB(glk) \
        | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLV_INV(glv) | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GL1_INV(gl1) \
        | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GL2_INV(gl2) | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GL2_WB(gl2)
      self.pkt3(self.pm4.PACKET3_ACQUIRE_MEM, 0, *data64_le(sz), *data64_le(addr), 0, cache_flags_dw)
    else:
      # gfx9 path: flags go into CP_COHER_CNTL; glm and glv are not used here
      cp_coher_cntl = self.pm4.PACKET3_ACQUIRE_MEM_CP_COHER_CNTL_SH_ICACHE_ACTION_ENA(gli) | \
                      self.pm4.PACKET3_ACQUIRE_MEM_CP_COHER_CNTL_SH_KCACHE_ACTION_ENA(glk) | \
                      self.pm4.PACKET3_ACQUIRE_MEM_CP_COHER_CNTL_TC_ACTION_ENA(gl2) | \
                      self.pm4.PACKET3_ACQUIRE_MEM_CP_COHER_CNTL_TCL1_ACTION_ENA(gl1) | \
                      self.pm4.PACKET3_ACQUIRE_MEM_CP_COHER_CNTL_TC_WB_ACTION_ENA(gl2)
      # NOTE(review): meaning of the final 0x0000000A dword is not visible here -- confirm
      self.pkt3(self.pm4.PACKET3_ACQUIRE_MEM, cp_coher_cntl, *data64_le(sz), *data64_le(addr), 0x0000000A)

  # emit RELEASE_MEM with a CACHE_FLUSH_AND_INV_TS_EVENT end-of-pipe event that writes `value` (per data_sel) to `address`.
  # cache_flush=True adds the cache flush/invalidate flags; on target[0] == 9 the ctxid argument is ignored (forced to 0).
  def release_mem(self, address=0x0, value=0, data_sel=0, int_sel=2, ctxid=0, cache_flush=False):
    if self.dev.target[0] != 9:
      cache_flags_dw = 0 if not cache_flush else (self.pm4.PACKET3_RELEASE_MEM_GCR_GLV_INV | self.pm4.PACKET3_RELEASE_MEM_GCR_GL1_INV \
        | self.pm4.PACKET3_RELEASE_MEM_GCR_GL2_INV | self.pm4.PACKET3_RELEASE_MEM_GCR_GLM_WB \
        | self.pm4.PACKET3_RELEASE_MEM_GCR_GLM_INV | self.pm4.PACKET3_RELEASE_MEM_GCR_GL2_WB | self.pm4.PACKET3_RELEASE_MEM_GCR_SEQ)
      event_dw = self.pm4.PACKET3_RELEASE_MEM_EVENT_TYPE(self.pm4.CACHE_FLUSH_AND_INV_TS_EVENT) \
        | self.pm4.PACKET3_RELEASE_MEM_EVENT_INDEX(self.pm4.event_index__mec_release_mem__end_of_pipe)
      memsel_dw = self.pm4.PACKET3_RELEASE_MEM_DATA_SEL(data_sel) | self.pm4.PACKET3_RELEASE_MEM_INT_SEL(int_sel) \
        | self.pm4.PACKET3_RELEASE_MEM_DST_SEL(0)
    else:
      cache_flags_dw = 0 if not cache_flush else (self.pm4.EOP_TC_WB_ACTION_EN | self.pm4.EOP_TC_NC_ACTION_EN)
      event_dw = self.pm4.EVENT_TYPE(self.pm4.CACHE_FLUSH_AND_INV_TS_EVENT) | self.pm4.EVENT_INDEX(self.pm4.event_index__mec_release_mem__end_of_pipe)
      memsel_dw = self.pm4.DATA_SEL(data_sel) | self.pm4.INT_SEL(int_sel)
      ctxid = 0
    self.pkt3(self.pm4.PACKET3_RELEASE_MEM, event_dw | cache_flags_dw, memsel_dw, *data64_le(address), *data64_le(value), ctxid)

  # wait on the nbio GPU_HDP_FLUSH_REQ/GPU_HDP_FLUSH_DONE register pair, then emit a default acquire_mem
  def memory_barrier(self):
    # the register name infix depends on the nbio version: '' for major version 2, '1' for (7, 11), '0' otherwise
    pf = '' if self.nbio.version[0] == 2 else '0' if self.nbio.version[:2] != (7, 11) else '1'
    self.wait_reg_mem(reg=getattr(self.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_REQ').addr[0],
                      reg_done=getattr(self.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_DONE').addr[0], value=0xffffffff)
    self.acquire_mem()

  # handlers called from amd_inner_pm: x is the matched command UOp
  # wait: block until the memory at x.src[0] compares (default op: GEQ) against x.src[1]
  def wait(self, x): self.wait_reg_mem(x.src[1], mem=self.get_dev_addr(x.src[0]))
  def barrier(self, x): self.memory_barrier()

  # store: write the low 32 bits of x.src[1] to x.src[0] with cache flush, raising an interrupt after write confirm
  def store(self, x):
    self.release_mem(self.get_dev_addr(x.src[0]), x.src[1], self.pm4.data_sel__mec_release_mem__send_32_bit_low,
                     self.pm4.int_sel__mec_release_mem__send_interrupt_after_write_confirm, cache_flush=True)

  # timestamp: write the GPU clock counter to x.src[0], no interrupt and no cache flush
  def timestamp(self, x):
    self.release_mem(self.get_dev_addr(x.src[0]), 0, self.pm4.data_sel__mec_release_mem__send_gpu_clock_counter,
                     self.pm4.int_sel__mec_release_mem__none)

  # program: set up the compute registers for one kernel and dispatch it.
  # x.arg is (data, info): data supplies the entry point, rsrc words and segment sizes, info the local/global sizes.
  # x.src is (lib_gpu, args): the uploaded program image and the kernel-argument buffer.
  def program(self, x):
    data, info = x.arg
    lib_gpu, args = x.src
    prog_addr = self.get_dev_addr(lib_gpu) + data.entry_point_offset
    # acquire_mem with the gli and gl2 actions switched off (the other flags keep their default of 1)
    self.acquire_mem(gli=0, gl2=0)
    # scratch is referenced through a uint8 placeholder buffer tagged "scratch"
    scratch_addr = self.get_dev_addr(UOp.new_buffer(self.devs, data.private_segment_size, dtypes.uint8).rtag("scratch"))
    args_addr = self.get_dev_addr(args)
    # user data registers, in order: optional 4-dword scratch value, optional dispatch pointer, kernarg pointer
    user_regs = []
    if data.enable_private_segment_sgpr:
      scratch_hilo = data64_le(scratch_addr)
      # NOTE(review): bit 31 of the high word and the 0x20c14000 constant are presumably buffer-descriptor fields -- confirm
      user_regs = [scratch_hilo[0], scratch_hilo[1] | 1 << 31, 0xffffffff, 0x20c14000]
    # the dispatch pointer is the address directly after the kernarg segment inside the args buffer
    if data.enable_dispatch_ptr: user_regs += [*data64_le(args_addr + data.kernargs_segment_size)]
    user_regs += [*data64_le(args_addr)]
    # the program address is written shifted right by 8 bits
    self.wreg(self.gc.regCOMPUTE_PGM_LO, *data64_le(prog_addr >> 8))
    # rsrc1 and rsrc2 go out as two consecutive dwords starting at COMPUTE_PGM_RSRC1
    self.wreg(self.gc.regCOMPUTE_PGM_RSRC1, data.rsrc1, data.rsrc2)
    self.wreg(self.gc.regCOMPUTE_PGM_RSRC3, data.rsrc3)
    self.wreg(self.gc.regCOMPUTE_TMPRING_SIZE, self.dev.tmpring_size(data.private_segment_size))
    # NOTE(review): every iteration writes the same register, so with xccs > 1 only the last xcc_id's base remains, and the
    # per-xcc offset is derived from data.private_segment_size rather than from the scratch buffer size -- confirm intent
    for xcc_id in range(self.dev.xccs):
      scratch_base = scratch_addr + (data.private_segment_size // self.dev.xccs * xcc_id)
      self.wreg(self.gc.regCOMPUTE_DISPATCH_SCRATCH_BASE_LO, *data64_le(scratch_base >> 8))
    self.wreg(self.gc.regCOMPUTE_RESTART_X, 0, 0, 0)
    self.wreg(self.gc.regCOMPUTE_USER_DATA_0, *user_regs)
    # waves_per_sh comes from the WAVES_PER_SH environment variable
    self.wreg(self.gc.regCOMPUTE_RESOURCE_LIMITS, self.gc.regCOMPUTE_RESOURCE_LIMITS.encode(waves_per_sh=getenv("WAVES_PER_SH")))
    # eight consecutive dwords starting at COMPUTE_START_X: three zeros, the local size (default 1x1x1), two zeros
    self.wreg(self.gc.regCOMPUTE_START_X, 0, 0, 0, *(info.local_size or (1, 1, 1)), 0, 0)
    # cs_w32_en is only encoded when target[0] != 9
    dispatch_init = self.gc.regCOMPUTE_DISPATCH_INITIATOR.encode(
      **({'cs_w32_en': int(data.wave32)} if self.dev.target[0] != 9 else {}), force_start_at_000=1, compute_shader_en=1)
    self.pkt3(self.pm4.PACKET3_DISPATCH_DIRECT, *info.global_size, dispatch_init)
    # follow the dispatch with a CS_PARTIAL_FLUSH event
    self.pkt3(self.pm4.PACKET3_EVENT_WRITE, self.pm4.EVENT_TYPE(self.soc.CS_PARTIAL_FLUSH) | self.pm4.EVENT_INDEX(EVENT_INDEX_PARTIAL_FLUSH))
# per-command PM4 encoding rules: amd_lower_pm4 wraps every command in its own single-source LINEAR, and each rule
# forwards the wrapped command to the matching method of the encoder passed as ctx
amd_inner_pm = PatternMatcher([
  (UPat(Ops.LINEAR, src=(UPat(Ops.WAIT, name="x"),)), lambda ctx, x: ctx.wait(x)),
  (UPat(Ops.LINEAR, src=(UPat(Ops.BARRIER, name="x"),)), lambda ctx, x: ctx.barrier(x)),
  (UPat(Ops.LINEAR, src=(UPat(Ops.PROGRAM, name="x"),)), lambda ctx, x: ctx.program(x)),
  (UPat(Ops.LINEAR, src=(UPat(Ops.CUSTOM_FUNCTION, arg="timestamp", name="x"),)), lambda ctx, x: ctx.timestamp(x)),
  # only a STORE whose first source is a BUFFER or PARAM is encoded
  (UPat(Ops.LINEAR, src=(UPat(Ops.STORE, src=(UPat((Ops.BUFFER, Ops.PARAM)), UPat()), name="x"),)), lambda ctx, x: ctx.store(x)),
])
def amd_lower_pm4(linear, devs):
  # Encode every command of `linear` into PM4 packets and return the encoder's "compute"-tagged UOp.
  # The encoder is created for the first device in `devs`.
  encoder = AMDComputeQueue(Device[devs[0]], devs)
  # wrap each command in its own single-source LINEAR so the per-command rules of amd_inner_pm can match it
  wrapped_cmds = tuple(UOp(Ops.LINEAR, dtypes.void, (cmd,)) for cmd in linear.src)
  graph_rewrite(linear.replace(src=wrapped_cmds), amd_inner_pm, ctx=encoder, name="amd: encode")
  # a single device is passed on as its plain name, several devices as the whole tuple
  target_dev = devs[0] if len(devs) <= 1 else devs
  return encoder.uop(dev=target_dev, tag="compute")
def amd_submit_pm4(cmdbuf, devs):
  # Build the UOp graph that copies an encoded PM4 `cmdbuf` into the compute ring of the target device, advances the
  # put/write pointers and finally rings the doorbell. Returns the doorbell store UOp.
  # cmdbuf.src[0].arg is the command buffer size in bytes; the ring and the pointers are counted in uint32 dwords.
  size, zero = UOp.const(dtypes.uint32, cmdbuf.src[0].arg // dtypes.uint32.itemsize), UOp.const(dtypes.int, 0)
  # the compute queue's ring and its host-side ring/write/put pointers (placeholders, resolved in pm_bufferize)
  # the queue is taken from the device being submitted to (the same choice amd_lower_pm4 makes), not from a fixed 'AMD'
  q = Device[devs[0]].compute_queue
  ring, wptr, doorbell, put_ptr = (UOp.new_buffer(devs, b.size, b.dtype).rtag(("COMPUTE:0", name))
    for name, b in (("ring", q.ring), ("write_ptr", q.write_ptr), ("doorbell", q.doorbell), ("put_value", q.put_value)))
  # place the cmdbuf at the ring's write offset, wrapping the ring
  put = put_ptr.index(zero)
  next_put = put + size.cast(put.dtype)
  i = UOp.range(size, 0, dtype=dtypes.int, src=(cmdbuf,))
  ring_idx = ((put + i.cast(put.dtype)) % q.ring.size).cast(dtypes.int)
  # copy the cmdbuf into the ring and advance the put/write pointers
  copy_to_ring = ring.index(ring_idx, dtype=ring.dtype.ptr()).store(
    cmdbuf.index(i*4, dtype=cmdbuf.dtype.ptr()).cast(dtypes.uint32.ptr()).load()).end(i)
  bump_put_ptr = put_ptr.index(zero, dtype=put_ptr.dtype.ptr()).store(next_put)
  bump_wptr = wptr.index(zero, dtype=wptr.dtype.ptr()).store(next_put)
  # ring the doorbell once the copy and pointer bumps have landed
  flush = UOp.barrier(copy_to_ring, bump_put_ptr, bump_wptr)
  return doorbell.after(flush).index(zero, dtype=doorbell.dtype.ptr()).store(next_put)
class AMDCopyQueue(HCQEncoder):
  # Encoder that turns copy-queue commands (copy/wait/store/timestamp) into SDMA packets.
  # Packets are appended with self.q and buffer addresses are obtained with self.get_dev_addr.
  def __init__(self, dev:AMDDevice, queue_idx=0):
    super().__init__()
    self.dev = dev
    self.sdma = dev.sdma
    self.queue_idx = queue_idx
    self.max_copy_size = dev.max_copy_size

  # copy x.arg bytes from x.src[1] to x.src[0], split into linear-copy packets of at most max_copy_size bytes each
  def copy(self, x):
    sdma = self.sdma
    dst_addr = self.get_dev_addr(x.src[0])
    src_addr = self.get_dev_addr(x.src[1])
    total, done = x.arg, 0
    while done < total:
      chunk = min(total - done, self.max_copy_size)
      header = sdma.SDMA_OP_COPY | sdma.SDMA_PKT_COPY_LINEAR_HEADER_SUB_OP(sdma.SDMA_SUBOP_COPY_LINEAR)
      # the count field is encoded as the chunk length minus one
      count_dw = sdma.SDMA_PKT_COPY_LINEAR_COUNT_COUNT(chunk - 1)
      self.q(header, count_dw, 0, *data64_le(src_addr + done), *data64_le(dst_addr + done))
      done += chunk

  # poll the memory at x.src[0] until it compares GEQ against x.src[1] (full 32-bit mask)
  def wait(self, x):
    sdma = self.sdma
    header = sdma.SDMA_OP_POLL_REGMEM | sdma.SDMA_PKT_POLL_REGMEM_HEADER_FUNC(WAIT_REG_MEM_FUNCTION_GEQ) | \
      sdma.SDMA_PKT_POLL_REGMEM_HEADER_MEM_POLL(1)
    signal_addr = self.get_dev_addr(x.src[0])
    poll_cfg = sdma.SDMA_PKT_POLL_REGMEM_DW5_INTERVAL(0x04) | sdma.SDMA_PKT_POLL_REGMEM_DW5_RETRY_COUNT(0xfff)
    self.q(header, *data64_le(signal_addr), x.src[1], 0xffffffff, poll_cfg)

  # write x.src[1] to x.src[0] with a FENCE packet, then emit a TRAP packet
  def store(self, x):
    # the MTYPE(3) header flag is only added when target[0] != 9
    if self.dev.target[0] != 9: fence_flags = self.sdma.SDMA_PKT_FENCE_HEADER_MTYPE(3)
    else: fence_flags = 0
    fence_header = self.sdma.SDMA_OP_FENCE | fence_flags
    self.q(fence_header, *data64_le(self.get_dev_addr(x.src[0])), x.src[1])
    self.q(self.sdma.SDMA_OP_TRAP, 0)

  # write the global timestamp to x.src[0]
  def timestamp(self, x):
    sdma = self.sdma
    header = sdma.SDMA_OP_TIMESTAMP | sdma.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(sdma.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL)
    self.q(header, *data64_le(self.get_dev_addr(x.src[0])))
def amd_lower_sdma(linear, devs):
  # Encode every command of `linear` into SDMA packets and return the encoder's "copy"-tagged UOp.
  # The encoder is created for the first device in `devs`.
  encoder = AMDCopyQueue(Device[devs[0]])
  # wrap each command in its own single-source LINEAR so the per-command rules of amd_inner_sdma_pm can match it
  wrapped_cmds = tuple(UOp(Ops.LINEAR, dtypes.void, (cmd,)) for cmd in linear.src)
  graph_rewrite(linear.replace(src=wrapped_cmds), amd_inner_sdma_pm, ctx=encoder, name="amd: encode sdma")
  # a single device is passed on as its plain name, several devices as the whole tuple
  target_dev = devs[0] if len(devs) <= 1 else devs
  return encoder.uop(dev=target_dev, tag="copy")
# per-command SDMA encoding rules: amd_lower_sdma wraps every command in its own single-source LINEAR, and each rule
# forwards the wrapped command to the matching method of the encoder passed as ctx
amd_inner_sdma_pm = PatternMatcher([
  (UPat(Ops.LINEAR, src=(UPat(Ops.WAIT, name="x"),)), lambda ctx, x: ctx.wait(x)),
  # BARRIER is matched but emits no packet on the copy queue
  (UPat(Ops.LINEAR, src=(UPat(Ops.BARRIER, name="x"),)), lambda ctx, x: None),
  (UPat(Ops.LINEAR, src=(UPat(Ops.COPY, name="x"),)), lambda ctx, x: ctx.copy(x)),
  (UPat(Ops.LINEAR, src=(UPat(Ops.CUSTOM_FUNCTION, arg="timestamp", name="x"),)), lambda ctx, x: ctx.timestamp(x)),
  # only a STORE whose first source is a BUFFER or PARAM is encoded
  (UPat(Ops.LINEAR, src=(UPat(Ops.STORE, src=(UPat((Ops.BUFFER, Ops.PARAM)), UPat()), name="x"),)), lambda ctx, x: ctx.store(x)),
])
def amd_submit_sdma(cmdbuf, devs):
  # Build the UOp graph that copies an encoded SDMA `cmdbuf` into SDMA ring 0 of the target device, advances the
  # put/write pointers and finally rings the doorbell. Returns the doorbell store UOp.
  # The put/write pointers are kept in bytes here, while ring offsets are handled in uint32 dwords.
  # Raises RuntimeError when the device has no SDMA queue 0 (sdma_queue returned None).
  # the cmdbuf to submit + the patch writes that fill it
  size_dw, zero = cmdbuf.src[0].arg // dtypes.uint32.itemsize, UOp.const(dtypes.int, 0)
  # the sdma queue's ring and its host-side ring/write/put pointers
  # the queue is taken from the device being submitted to (the same choice amd_lower_sdma makes), not from a fixed 'AMD'
  q = Device[devs[0]].sdma_queue(0)
  # sdma_queue returns None when SDMA is disabled or the queue could not be created: fail with a clear message
  if q is None: raise RuntimeError(f"no SDMA queue available on {devs[0]}")
  ring, wptr, doorbell, put_ptr = (UOp.new_buffer(devs, b.size, b.dtype).rtag(("SDMA:0", name))
    for name, b in (("ring", q.ring), ("write_ptr", q.write_ptr), ("doorbell", q.doorbell), ("put_value", q.put_value)))
  # sdma needs the cmdbuf contiguous: if it won't fit before the ring end, restart at 0 and zero the tail
  put_b = put_ptr.index(zero)
  tail_off_dw = ((put_b % (q.ring.size * 4)) // 4).cast(dtypes.int)
  # fits is 1 when the cmdbuf fits between the current offset and the ring end, else 0
  fits = (size_dw <= q.ring.size - tail_off_dw).cast(dtypes.int)
  start_dw = fits * tail_off_dw
  zero_amt_dw = (1 - fits) * (q.ring.size - tail_off_dw)
  # zero the wrapped tail, then copy the cmdbuf into the ring
  zi = UOp.range(zero_amt_dw, 0, dtype=dtypes.int, src=(cmdbuf,))
  zero_tail = ring.index(tail_off_dw + zi, dtype=ring.dtype.ptr()).store(UOp.const(dtypes.uint32, 0)).end(zi)
  i = UOp.range(UOp.const(dtypes.int, size_dw), 0, dtype=dtypes.int, src=(cmdbuf,))
  copy_to_ring = ring.index(start_dw + i, dtype=ring.dtype.ptr()).store(
    cmdbuf.index(i*4, dtype=cmdbuf.dtype.ptr()).cast(dtypes.uint32.ptr()).load()).end(i)
  # advance the put/write pointers past the zeroed tail and the cmdbuf
  next_put_b = put_b + ((zero_amt_dw + size_dw) * 4).cast(put_b.dtype)
  bump_put_ptr = put_ptr.index(zero, dtype=put_ptr.dtype.ptr()).store(next_put_b)
  bump_wptr = wptr.index(zero, dtype=wptr.dtype.ptr()).store(next_put_b)
  # ring the doorbell once the writes have landed
  flush = UOp.barrier(zero_tail, copy_to_ring, bump_put_ptr, bump_wptr)
  return doorbell.after(flush).index(zero, dtype=doorbell.dtype.ptr()).store(next_put_b)
@dataclass(frozen=True)
class AMDProgramData:
  # Immutable per-kernel launch parameters, filled in by amd_build_program from the kernel descriptor.
  entry_point_offset:int  # offset of the kernel entry point inside the loaded image
  rsrc1:int  # value for COMPUTE_PGM_RSRC1
  rsrc2:int  # value for COMPUTE_PGM_RSRC2
  rsrc3:int  # value for COMPUTE_PGM_RSRC3
  wave32:bool
  private_segment_size:int
  kernargs_segment_size:int
  kernargs_alloc_size:int  # kernarg size plus room for a dispatch packet when the dispatch pointer is enabled
  enable_dispatch_ptr:int  # masked kernel_code_properties bit: non-zero means enabled
  enable_private_segment_sgpr:int  # masked kernel_code_properties bit: non-zero means enabled
# parsed programs keyed by (binary, device name), so each binary is loaded once per device
_amd_program_cache:dict[tuple[bytes,str], tuple[AMDProgramData,bytes]] = {}
def amd_build_program(prg:UOp) -> UOp:
  # Load the ELF binary carried by prg.src[4].arg for the device named in prg.src[1].arg and derive its launch parameters.
  # Results are memoised in _amd_program_cache.
  # Raises RuntimeError when the kernel requests more LDS than the device reports.
  # NOTE(review): the annotation says UOp, but the value returned is the cached (AMDProgramData, image bytes) tuple -- confirm
  dev = Device[prg.src[1].arg] # TODO: rm this
  lib = prg.src[4].arg
  key = (lib, dev.device)
  cached = _amd_program_cache.get(key)
  if cached is not None: return cached

  image, sections, relocs = elf_loader(lib)
  rodata = next(sh.header.sh_addr for sh in sections if sh.name == ".rodata")
  # apply relocations in place: only type 5 is handled, patched as a signed 64-bit little-endian value
  for off, sym, typ, addent in relocs:
    assert typ == 5, f"unknown AMD reloc {typ}" # R_AMDGPU_REL64
    image[off:off+8] = struct.pack('<q', sym - off + addent)

  # the kernel descriptor is read from the start of .rodata
  desc_t = amdgpu_kd.llvm_amdhsa_kernel_descriptor_t
  desc = desc_t.from_buffer_copy(bytes(image[rodata:rodata+ctypes.sizeof(desc_t)]))

  # LDS is counted in 512-byte units (9-bit field) and checked against the device's LDS size
  lds = ((desc.group_segment_fixed_size+511)//512)&0x1FF
  if lds > (dev.iface.props['lds_size_in_kb']*1024)//512:
    raise RuntimeError("Too many resources requested: group_segment_size")

  edp = desc.kernel_code_properties & hsa.AMD_KERNEL_CODE_PROPERTIES_ENABLE_SGPR_DISPATCH_PTR
  private_sgpr = desc.kernel_code_properties & hsa.AMD_KERNEL_CODE_PROPERTIES_ENABLE_SGPR_PRIVATE_SEGMENT_BUFFER
  # a dispatch packet is placed after the kernargs when the dispatch pointer is enabled
  dispatch_packet_size = ctypes.sizeof(hsa.hsa_kernel_dispatch_packet_t) if edp else 0
  program_data = AMDProgramData(
    entry_point_offset=rodata + desc.kernel_code_entry_byte_offset,
    rsrc1=desc.compute_pgm_rsrc1 | ((1<<20) if dev.target[0]==11 else 0), # priv=1 on gfx11 for cwsr
    rsrc2=desc.compute_pgm_rsrc2 | (lds<<15), rsrc3=desc.compute_pgm_rsrc3,
    wave32=bool(desc.kernel_code_properties & 0x400),
    private_segment_size=desc.private_segment_fixed_size,
    kernargs_segment_size=desc.kernarg_size,
    kernargs_alloc_size=desc.kernarg_size + dispatch_packet_size,
    enable_dispatch_ptr=edp,
    enable_private_segment_sgpr=private_sgpr)
  cached = (program_data, bytes(image))
  _amd_program_cache[key] = cached
  return cached
# runs amd_build_program on every 5-source PROGRAM whose second source is DEVICE "AMD" and whose fifth source is a BINARY
pm_prep_program = PatternMatcher([
  (UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE, arg="AMD"), UPat(), UPat(), UPat(Ops.BINARY)), name="prg"), amd_build_program),
])
class AMDAllocator(HCQAllocator['AMDDevice']):
  # Allocator that forwards alloc/free/map to the device interface (dev.iface).
  def __init__(self, dev:AMDDevice):
    # copy-from-disk needs an SDMA queue; transfers additionally require a non-USB device
    sdma_available = dev.has_sdma_queue
    super().__init__(dev, supports_copy_from_disk=sdma_available, supports_transfer=sdma_available and not dev.is_usb())

  def _alloc(self, size:int, options:BufferSpec) -> HCQ2Buffer:
    # without an SDMA queue every buffer is allocated with CPU access
    needs_cpu_access = options.cpu_access or not self.dev.has_sdma_queue
    return self.dev.iface.alloc(size, host=options.host, uncached=options.uncached, cpu_access=needs_cpu_access)

  def _do_free(self, opaque, options:BufferSpec):
    self.dev.iface.free(opaque)

  def _do_map(self, buf:HCQ2Buffer):
    # map the base buffer when `buf` has one, otherwise `buf` itself
    target = buf if buf._base is None else buf._base
    return self.dev.iface.map(target)
@dataclass
class AMDQueueDesc:
  # Host-side Buffer handles describing one hardware queue; instances are built by PCIIface.create_queue.
  ring: Buffer # uint32[ring_size//4]
  read_ptr: Buffer # uint64[1]
  write_ptr: Buffer # uint64[1]
  doorbell: Buffer # uint64[1]
  put_value: Buffer # uint64[1], CPU buffer initialised to 0 by PCIIface.create_queue
  params: tuple|None = None # setup_ring params for recovery
class PCIIface(PCIIfaceBase):
  # Device interface that drives the GPU through an AMDev instance (passed to the base class as dev_impl_t and used
  # below as self.dev_impl).
  def __init__(self, dev, dev_id):
    # vendor 0x1002 with the listed PCI device ids; the virtual address window comes from AMMemoryManager.va_allocator
    super().__init__(dev, dev_id, vendor=0x1002, devices=((0xffff, (0x74a1,0x744c,0x7480,0x7550,0x7551,0x7590,0x75a0)),), vram_bar=0,
      va_start=AMMemoryManager.va_allocator.base, va_size=AMMemoryManager.va_allocator.size, dev_impl_t=AMDev)
    self._compute_props()

  # peers inside a hive get their physical addresses translated with paddr2xgmi and tagged AddrSpace.PEER;
  # otherwise the base class decides
  def p2p_paddrs(self, paddrs:list[tuple[int,int]]) -> tuple[list[tuple[int,int]], AddrSpace]:
    return ([(self.dev_impl.paddr2xgmi(p), sz) for p, sz in paddrs], AddrSpace.PEER) if self.dev_impl.is_hive() else super().p2p_paddrs(paddrs)

  def require_profile_mode(self): return True
  def is_wgp_active(self, xcc, se, sa, wgp) -> bool: return True # TODO: account for WGP disablement on some asics.

  # fill self.ip_versions and self.props from the values reported by dev_impl
  def _compute_props(self):
    self.ip_versions = self.dev_impl.ip_ver
    # pack the GC ip version triple into one decimal number, two digits per component
    gfxver = int(f"{self.dev_impl.ip_ver[am.GC_HWIP][0]:02d}{self.dev_impl.ip_ver[am.GC_HWIP][1]:02d}{self.dev_impl.ip_ver[am.GC_HWIP][2]:02d}")
    # gc_info version 2 reports CUs per SH directly; other versions report WGPs per SA, counted as 2 CUs each
    if self.dev_impl.gc_info.header.version_major == 2:
      cu_per_sa = self.dev_impl.gc_info.gc_num_cu_per_sh
      max_sh_per_se = self.dev_impl.gc_info.gc_num_sh_per_se
    else:
      cu_per_sa = 2 * (self.dev_impl.gc_info.gc_num_wgp0_per_sa + self.dev_impl.gc_info.gc_num_wgp1_per_sa)
      max_sh_per_se = self.dev_impl.gc_info.gc_num_sa_per_se
    array_count = max_sh_per_se * self.dev_impl.gc_info.gc_num_se * self.dev_impl.gfx.xccs
    # gfx_target_version 90403 is reported as 90402; every other version is passed through unchanged
    self.props = {'cu_per_simd_array': cu_per_sa, 'simd_count': 2 * cu_per_sa * array_count, 'simd_per_cu': 2, 'array_count': array_count,
      'max_slots_scratch_cu': self.dev_impl.gc_info.gc_max_scratch_slots_per_cu, 'max_waves_per_simd': self.dev_impl.gc_info.gc_max_waves_per_simd,
      'simd_arrays_per_engine': max_sh_per_se, 'lds_size_in_kb': self.dev_impl.gc_info.gc_lds_size, 'num_xcc': self.dev_impl.gfx.xccs,
      'gfx_target_version': {90403: 90402}.get(gfxver, gfxver)}

  # set up a hardware ring through dev_impl and return its AMDQueueDesc.
  # rptr and wptr are byte offsets into `gart`; a cwsr buffer is not supported by this interface.
  def create_queue(self, queue_type, ring, gart, rptr, wptr, eop_buffer=None, cwsr_buffer=None, ctl_stack_size=0, ctx_save_restore_size=0,
                   xcc_id=0, idx=0):
    assert cwsr_buffer is None, "no cwsr buffer for am"
    # the exact setup_ring arguments are remembered in rcvr_params so the ring can be set up again after a recovery
    rcvr_params: tuple
    if queue_type == kfd.KFD_IOC_QUEUE_TYPE_SDMA:
      doorbell_index = self.dev_impl.sdma.setup_ring(*(rcvr_params:=(ring.va_addr, ring.size, gart.va_addr+rptr, gart.va_addr+wptr, idx)))
    else:
      # the AQL flag is passed twice, as the last two setup_ring arguments
      doorbell_index = self.dev_impl.gfx.setup_ring(*(rcvr_params:=(ring.va_addr, ring.size, gart.va_addr+rptr, gart.va_addr+wptr,
        eop_buffer.va_addr, eop_buffer.size, is_aql:=(queue_type==kfd.KFD_IOC_QUEUE_TYPE_COMPUTE_AQL), is_aql)))
    # ext wraps an existing address as a preallocated "CPU" Buffer of n elements of dtype dt
    ext = lambda addr,n,dt: Buffer("CPU", n, dt, options=BufferSpec(external_ptr=addr), preallocate=True)
    # put_value is a fresh one-element uint64 CPU buffer, initialised to 0
    (put_value := Buffer("CPU", 1, dtypes.uint64, preallocate=True))._buf.view.view(fmt='Q')[0] = 0
    # each doorbell occupies 8 bytes inside the doorbell64 region
    return AMDQueueDesc(ring=ext(ring.va_addr, ring.size//4, dtypes.uint32),
                        doorbell=ext(self.dev_impl.doorbell64.addr + doorbell_index*8, 1, dtypes.uint64),
                        read_ptr=ext(gart.va_addr+rptr, 1, dtypes.uint64), write_ptr=ext(gart.va_addr+wptr, 1, dtypes.uint64),
                        put_value=put_value, params=rcvr_params)

  # process pending interrupts (or only drain them); with reset=True also attempt a device recovery
  def _collect_interrupts(self, reset=False, drain_only=False):
    d = self.dev
    if drain_only: d.iface.dev_impl.ih.drain()
    else: d.iface.dev_impl.ih.interrupt_handler()
    if reset and d.iface.dev_impl.recover():
      # after a successful recovery: zero the compute queue pointers and set its ring up again with the saved params
      # NOTE(review): only the compute queue is re-set up here; queues held in d.sdma_queues are not touched -- confirm intent
      cq = d.compute_queue
      for b in (cq.put_value, cq.read_ptr, cq.write_ptr): b._buf.view.view(fmt='Q')[0] = 0
      d.iface.dev_impl.gfx.setup_ring(*cq.params)
      # set the timeline signal to timeline_value - 1
      d.timeline_signal._buf.cpu_view().mv.cast('Q')[0] = d.timeline_value.as_memoryview(force_zero_copy=True).cast('Q')[0] - 1

  # wait up to `timeout` on the irq poller (when the pci device has one), then process interrupts.
  # raises RuntimeError when dev_impl reports the error state.
  def sleep(self, timeout):
    if hasattr(self.pci_dev, 'irq_poller') and self.pci_dev.irq_poller is not None and (events_cnt:=len(self.pci_dev.irq_poller.poll(timeout))):
      # read 8 bytes from the irq fd per reported event
      self.pci_dev.irq_fd.read(8 * events_cnt)
    self._collect_interrupts()
    if self.dev_impl.is_err_state: raise RuntimeError("Device is in error state")

  # run the recovery path, then always report the hang to the caller
  def on_device_hang(self):
    self._collect_interrupts(reset=True)
    raise RuntimeError("Device hang detected")

  def device_fini(self): self.dev_impl.fini()
def _mock(iface, name=None): return type(name or f"MOCK{iface.__name__}", (iface,), {})
def encode_queue(q:UOp) -> UOp|None:
  # Lower and submit a queue whose arg is a (device(s), "COMPUTE" | "COPY") pair; any other arg yields None.
  arg = q.arg
  if not isinstance(arg, tuple) or len(arg) != 2: return None
  dev_spec, queue_kind = arg
  if queue_kind not in ("COMPUTE", "COPY"): return None
  # a single device name is normalised to a 1-tuple
  devs = (dev_spec,) if isinstance(dev_spec, str) else dev_spec # TODO: make this prettier
  if queue_kind == "COMPUTE": return amd_submit_pm4(amd_lower_pm4(q, devs), devs)
  return amd_submit_sdma(amd_lower_sdma(q, devs), devs)
# every LINEAR is offered to encode_queue, which returns None for anything that is not a COMPUTE or COPY queue
pm_lower = PatternMatcher([
  (UPat(Ops.LINEAR, name="q"), encode_queue),
])
class AMDDevice(HCQ2Compiled):
  # AMD GPU device: picks an interface, derives the hardware geometry from its props, loads the matching register and
  # packet definitions, and creates the compute/SDMA queues on demand.
  timestamp_divider = 100.0 # AMD GPU clock: ticks/us
  ifaces = [PCIIface]

  def is_am(self) -> bool: return isinstance(self.iface, (PCIIface,))
  def is_usb(self) -> bool: return False

  # `device` may carry an index after a colon; without one the index is 0
  def __init__(self, device:str=""):
    self.device_id = int(device.split(":")[1]) if ":" in device else 0
    self.iface = self._select_iface()
    # gfx_target_version is decimal-packed (major*10000 + minor*100 + stepping); the arch name prints minor/stepping in hex
    self.target:tuple[int, ...] = ((trgt:=self.iface.props['gfx_target_version']) // 10000, (trgt // 100) % 100, trgt % 100)
    self.arch = "gfx%d%x%x" % self.target
    assert (self.target in ((9,4,2),(9,5,0))) or self.target[0] in (11, 12), f"Unsupported arch: {self.arch}"
    if DEBUG >= 1: print(f"AMDDevice: opening {self.device_id} with target {self.target} arch {self.arch}")

    # geometry: the se/cu counts are per xcc
    self.xccs = self.iface.props.get('num_xcc', 1)
    self.se_cnt = self.iface.props['array_count'] // self.iface.props['simd_arrays_per_engine'] // self.xccs
    self.cu_cnt = self.iface.props['simd_count'] // self.iface.props['simd_per_cu'] // self.xccs
    self.waves_per_cu = self.iface.props['max_waves_per_simd'] * self.iface.props['simd_per_cu']
    self.wave_cnt = (self.cu_cnt * self.waves_per_cu) if self.target[0] != 9 else min(self.cu_cnt * 40, self.se_cnt * self.xccs * 512)

    # register/packet definitions: target[0] == 9 uses the vega/soc15 modules, everything else the navi/nv ones
    self.ip_off = importlib.import_module(f"tinygrad.runtime.autogen.am.{'vega' if self.target[0] == 9 else 'navi'}_offsets")
    self.soc = import_soc(self.target)
    self.pm4 = importlib.import_module(f"tinygrad.runtime.autogen.am.pm4_{'soc15' if self.target[0] == 9 else 'nv'}")
    # the sdma module version is capped at (6, 0, 0)
    self.sdma = import_module('sdma', min(self.iface.ip_versions[am.SDMA0_HWIP], (6, 0, 0)))
    self.gc = AMDIP('gc', self.iface.ip_versions[am.GC_HWIP],
                    bases={i: tuple(getattr(self.ip_off, f'GC_BASE__INST{i}_SEG{s}', 0) for s in range(6)) for i in range(6)})
    self.nbio = AMDIP('nbio' if self.target[0] < 12 else 'nbif', self.iface.ip_versions[am.NBIF_HWIP],
                      bases={i: tuple(getattr(self.ip_off, f'NBIO_BASE__INST{i}_SEG{s}', 0) for s in range(9)) for i in range(6)})

    # AQL is controlled by AMD_AQL and defaults to on only when there is more than one xcc
    self.is_aql = getenv("AMD_AQL", int(self.xccs > 1))
    if self.is_aql:
      self.pm4_ibs = self.iface.alloc(0x2000 if self.is_usb() else (16 << 20), uncached=True, cpu_access=True)
      self.pm4_ib_alloc = BumpAllocator(self.pm4_ibs.size, wrap=True)

    # the per-packet copy limit depends on the SDMA major version
    self.max_copy_size = 0x40000000 if self.iface.ip_versions[am.SDMA0_HWIP][0] >= 5 else 0x400000
    self.sdma_queues:dict = {}
    self.has_sdma_queue = True # self.sdma_queue(0) is not None, TODO: think of this

    super().__init__(device, AMDAllocator(self), [HIPRenderer, AMDLLVMRenderer, HIPCCRenderer], None, can_recover=self.is_am(), arch=self.arch)

    # Scratch setup
    self.max_private_segment_size = 0
    # BUFFER placeholders tagged "scratch" are resolved to the device scratch buffer, sized from the placeholder's arg
    self.pm_bufferize = PatternMatcher([(UPat(Ops.BUFFER, tag="scratch", name="b"), lambda ctx, b: ctx.scratch_buffer(b.arg))]) + self.pm_bufferize

    self.pmc_enabled:bool = PROFILE > 0 and PMC > 0
    if self.pmc_enabled:
      self.iface.require_profile_mode()

      self.pmc_sched:list[PMCSample] = []
      self.pmc_counters = import_pmc(self.target)

      # validate counters: SQ for SIMD busy/instruction counts, LDS stats, GRBM for GPU cycles, L2 cache hits/misses
      l2, lds = ("TCC", "SQ") if self.target[0] == 9 else ("GL2C", "SQC")
      pmc_default = f"SQ_BUSY_CYCLES,SQ_INSTS_VALU,SQ_INSTS_SALU,{lds}_LDS_IDX_ACTIVE,{lds}_LDS_BANK_CONFLICT,GRBM_GUI_ACTIVE,{l2}_HIT,{l2}_MISS"
      for k in (PMC_COUNTERS:=getenv("PMC_COUNTERS", pmc_default).split(",")):
        if k not in self.pmc_counters: raise RuntimeError(f"PMC counter {k} is not supported. Available: {','.join(self.pmc_counters.keys())}")

      # PMC is validated above but cannot be started yet in this implementation
      raise NotImplementedError("PMC start not migrated to hcq2 yet")

    # SQTT is disabled by default because of runtime overhead and big file sizes (~200mb to Tensor.full() two 4096x4096 tensors and matmul them)
    self.sqtt_enabled:bool = PROFILE > 0 and SQTT > 0
    if self.sqtt_enabled:
      self.iface.require_profile_mode()

      SQTT_BUFFER_SIZE = getenv("SQTT_BUFFER_SIZE", 256) # in mb, per shader engine
      # one trace buffer per shader engine per xcc, plus one page-rounded buffer with 4 bytes per trace buffer
      self.sqtt_buffers = [self.allocator.alloc(SQTT_BUFFER_SIZE<<20, BufferSpec(nolru=True, uncached=True)) for _ in range(self.se_cnt * self.xccs)]
      self.sqtt_wptrs = self.allocator.alloc(round_up(self.se_cnt * self.xccs * 4, 0x1000), BufferSpec(cpu_access=True, nolru=True))
      self.sqtt_next_cmd_id = itertools.count(0)

  # allocate the ring and its control buffers, create the queue through the interface, and register pm_bufferize rules
  # that resolve the queue's tagged placeholder buffers. Returns the queue descriptor.
  def create_queue(self, queue_type, ring_size, ctx_save_restore_size=0, eop_buffer_size=0, ctl_stack_size=0, debug_memory_size=0, idx=0):
    ring = self.iface.alloc(ring_size, uncached=True, cpu_access=True)
    gart = self.iface.alloc(0x100, uncached=True, cpu_access=True)

    if queue_type == kfd.KFD_IOC_QUEUE_TYPE_COMPUTE_AQL:
      # for AQL queues an amd_queue_t descriptor is kept on the device object and copied to the start of gart
      self.aql_gart = gart
      self.aql_desc = hsa.amd_queue_t(queue_properties=hsa.AMD_QUEUE_PROPERTIES_IS_PTR64 | hsa.AMD_QUEUE_PROPERTIES_ENABLE_PROFILING,
        read_dispatch_id_field_base_byte_offset=getattr(hsa.amd_queue_t, 'read_dispatch_id').offset,
        max_cu_id=(self.cu_cnt * self.xccs) - 1, max_wave_id=self.waves_per_cu - 1)
      self.aql_gart.cpu_view().view(fmt='B')[:ctypes.sizeof(self.aql_desc)] = bytes(self.aql_desc)

    # the cwsr and eop buffers are only allocated when a non-zero size was requested
    cwsr_buffer_size = round_up((ctx_save_restore_size + debug_memory_size) * self.xccs, mmap.PAGESIZE)
    cwsr_buffer = self.iface.alloc(cwsr_buffer_size) if ctx_save_restore_size else None
    eop_buffer = self.iface.alloc(eop_buffer_size) if eop_buffer_size else None

    # the read/write pointers live in gart at the amd_queue_t offsets of read_dispatch_id / write_dispatch_id
    queue = (self.iface.create_queue(queue_type, ring, gart, rptr=getattr(hsa.amd_queue_t, 'read_dispatch_id').offset,
             wptr=getattr(hsa.amd_queue_t, 'write_dispatch_id').offset, eop_buffer=eop_buffer, cwsr_buffer=cwsr_buffer,
             ctx_save_restore_size=ctx_save_restore_size, ctl_stack_size=ctl_stack_size, idx=idx))

    # placeholders tagged (qname, field) resolve to this queue's buffers; the default argument binds each buffer at
    # definition time, one lambda per field
    qname = f"{'SDMA' if queue_type == kfd.KFD_IOC_QUEUE_TYPE_SDMA else 'COMPUTE'}:{idx}"
    self.pm_bufferize = PatternMatcher([
      (UPat(Ops.BUFFER, tag={(qname, name)}), lambda ctx, b=getattr(queue, name): b) for name in ["ring", "write_ptr", "doorbell", "put_value"]
    ]) + self.pm_bufferize
    return queue

  # the compute queue is created on first access and then cached
  @functools.cached_property
  def compute_queue(self) -> AMDQueueDesc:
    # https://gitlab.freedesktop.org/agd5f/linux/-/blob/a1fc9f584c4aaf8bc1ebfa459fc57a3f26a290d8/drivers/gpu/drm/amd/amdkfd/kfd_queue.c#L391
    sgrp_size_per_cu, hwreg_size_per_cu = 0x4000, 0x1000
    lds_size_per_cu = self.iface.props["lds_size_in_kb"] << 10 if self.target[:2] == (9,5) else 0x10000
    vgpr_size_per_cu = 0x60000 if self.target in {(11,0,0), (11,0,1), (11,5,1), (12,0,0), (12,0,1)} else 0x80000 if self.target[0] == 9 else 0x40000
    wg_data_size = round_up((vgpr_size_per_cu + sgrp_size_per_cu + lds_size_per_cu + hwreg_size_per_cu) * self.cu_cnt, mmap.PAGESIZE)
    ctl_stack_size = round_up((12 if self.target[0] != 9 else 8) * self.wave_cnt + 8 + 40, mmap.PAGESIZE)
    # no context save/restore area is requested when running on the PCIIface (is_am)
    return self.create_queue(kfd.KFD_IOC_QUEUE_TYPE_COMPUTE_AQL if self.is_aql else kfd.KFD_IOC_QUEUE_TYPE_COMPUTE,
      0x2000 if self.is_usb() else (16 << 20), eop_buffer_size=0x1000,
      ctx_save_restore_size=0 if self.is_am() else wg_data_size + ctl_stack_size, ctl_stack_size=ctl_stack_size,
      debug_memory_size=round_up(self.wave_cnt * 32, 64))

  # return SDMA queue `idx`, creating and caching it on first use.
  # returns None when AMD_DISABLE_SDMA is set or when queue creation fails with OSError.
  def sdma_queue(self, idx:int):
    if getenv("AMD_DISABLE_SDMA"): return None
    if idx in self.sdma_queues: return self.sdma_queues[idx]
    with contextlib.suppress(OSError):
      self.sdma_queues[idx] = self.create_queue(kfd.KFD_IOC_QUEUE_TYPE_SDMA, 0x200 if self.is_usb() else (16 << 20), idx=idx)
    return self.sdma_queues.get(idx, None)

  # compute the COMPUTE_TMPRING_SIZE register value for a kernel using `private_segment_size` bytes of scratch.
  # when an AQL descriptor exists, its scratch fields are refreshed and it is copied to the gart buffer again.
  def tmpring_size(self, private_segment_size):
    # at least 128 bytes are always assumed
    private_segment_size = max(private_segment_size, 128)
    lanes_per_wave = 64 # wave64
    mem_alignment_size = 256 if self.target[0] != 9 else 1024
    size_per_thread = round_up(private_segment_size, mem_alignment_size // lanes_per_wave)
    size_per_xcc = size_per_thread * lanes_per_wave * self.iface.props['max_slots_scratch_cu'] * self.cu_cnt
    # NOTE: xcc logic is correct only for GFX9.
    max_scratch_waves = self.cu_cnt * self.iface.props['max_slots_scratch_cu'] * self.xccs
    # WAVESIZE is expressed in units of mem_alignment_size
    wave_scratch = ceildiv(lanes_per_wave * size_per_thread, mem_alignment_size)
    num_waves = (size_per_xcc // (wave_scratch * mem_alignment_size)) // (self.se_cnt if self.target[0] != 9 else 1)
    # the bitfield layout is picked by the target's major version
    tmpring_t = getattr(hsa, f'union_COMPUTE_TMPRING_SIZE{"_GFX"+str(self.target[0]) if self.target[0] != 9 else ""}_bitfields')
    tmpring = int.from_bytes(tmpring_t(WAVES=min(num_waves, max_scratch_waves), WAVESIZE=wave_scratch), 'little')

    if hasattr(self, 'aql_desc'):
      # NOTE(review): this branch reads self.scratch, which is only assigned in scratch_buffer -- confirm it always runs first
      gfx9_rsrc = {'NUM_FORMAT':hsa.BUF_NUM_FORMAT_UINT, 'DATA_FORMAT':hsa.BUF_DATA_FORMAT_32, 'ELEMENT_SIZE':1, 'INDEX_STRIDE':3}
      rsrc = {'DST_SEL_X':hsa.SQ_SEL_X, 'DST_SEL_Y':hsa.SQ_SEL_Y, 'DST_SEL_Z':hsa.SQ_SEL_Z, 'DST_SEL_W':hsa.SQ_SEL_W, 'ADD_TID_ENABLE':1,
              'TYPE':hsa.SQ_RSRC_BUF, **(gfx9_rsrc if self.target[0] == 9 else {'FORMAT':hsa.BUF_FORMAT_32_UINT, 'OOB_SELECT':2})}
      rsrc1_t = getattr(hsa, f'union_SQ_BUF_RSRC_WORD1{"_GFX11" if self.target[0] != 9 else ""}_bitfields')
      rsrc3_t = getattr(hsa, f'union_SQ_BUF_RSRC_WORD3{"_GFX"+str(self.target[0]) if self.target[0] != 9 else ""}_bitfields')

      self.aql_desc.scratch_backing_memory_location = int(self.scratch.get_buf().va_addr)
      self.aql_desc.scratch_wave64_lane_byte_size = self.max_private_segment_size * lanes_per_wave // 64
      # four-dword scratch resource descriptor: address low, word1 (address high + swizzle), size per xcc, word3
      self.aql_desc.scratch_resource_descriptor[:] = [lo32(self.scratch.get_buf().va_addr),
        int.from_bytes(rsrc1_t(BASE_ADDRESS_HI=hi32(self.scratch.get_buf().va_addr), SWIZZLE_ENABLE=1), 'little'),
        lo32(size_per_xcc), int.from_bytes(bytes(rsrc3_t(**rsrc)), 'little')]
      self.aql_desc.compute_tmpring_size = tmpring
      self.aql_gart.cpu_view()[:ctypes.sizeof(self.aql_desc)] = bytes(self.aql_desc)
    return tmpring

  # return the device scratch buffer, reallocating it when a larger private segment than any seen so far is requested
  def scratch_buffer(self, private_segment_size):
    # at least 128 bytes are always assumed
    private_segment_size = max(private_segment_size, 128)
    if self.max_private_segment_size < private_segment_size:
      lanes_per_wave = 64 # wave64
      mem_alignment_size = 256 if self.target[0] != 9 else 1024
      size_per_thread = round_up(private_segment_size, mem_alignment_size // lanes_per_wave)
      size_per_xcc = size_per_thread * lanes_per_wave * self.iface.props['max_slots_scratch_cu'] * self.cu_cnt
      # one region of size_per_xcc bytes per xcc
      self.scratch = Buffer(self.device, size_per_xcc * self.xccs, dtypes.uint8, options=BufferSpec(nolru=True), preallocate=True)
      self.max_private_segment_size = private_segment_size
    return self.scratch

  def on_device_hang(self): self.iface.on_device_hang()

  def device_props(self): return self.iface.props