am: only support CDNA3/4 and RDNA3/4 (#16017)

This commit is contained in:
Christopher Milan
2026-05-01 21:02:14 -07:00
committed by GitHub
parent 4a2e1f1076
commit 637bdd5530
3 changed files with 33 additions and 43 deletions

View File

@@ -5,7 +5,7 @@ tinygrad supports various runtimes, enabling your code to scale across a wide ra
| Runtime | Description | Compiler Options | Requirements |
|---------|-------------|------------------|--------------|
| [NV](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_nv.py) | Provides acceleration for NVIDIA GPUs | nvrtc (default)<br>PTX (`DEV=NV:PTX`) | Ampere/Ada/Blackwell series GPUs.<br>You can select an interface via [the `DEV` variable](env_vars.md#dev-variable). See [NV interfaces](#nv-interfaces) for details. |
| [AMD](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_amd.py) | Provides acceleration for AMD GPUs | LLVM (`DEV=AMD:LLVM`)<br>HIP/COMGR (`DEV=AMD:HIP`) | RDNA2 or newer GPUs.<br>You can select an interface via [the `DEV` variable](env_vars.md#dev-variable). See [AMD interfaces](#amd-interfaces) for details. |
| [AMD](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_amd.py) | Provides acceleration for AMD GPUs | LLVM (`DEV=AMD:LLVM`)<br>HIP/COMGR (`DEV=AMD:HIP`) | CDNA3, CDNA4, RDNA3 or RDNA4 GPUs.<br>You can select an interface via [the `DEV` variable](env_vars.md#dev-variable). See [AMD interfaces](#amd-interfaces) for details. |
| [QCOM](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_qcom.py) | Provides acceleration for QCOM GPUs | - | 6xx series GPUs |
| [METAL](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_metal.py) | Utilizes Metal for acceleration on Apple devices | - | M1+ Macs; Metal 3.0+ for `bfloat` support |
| [CUDA](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cuda.py) | Utilizes CUDA for acceleration on NVIDIA GPUs | nvrtc (default)<br> PTX (`DEV=CUDA:PTX`) | NVIDIA GPU with CUDA support |

View File

@@ -15,7 +15,7 @@ from tinygrad.runtime.autogen import kfd, hsa, sqtt, amdgpu_kd, amdgpu_drm
from tinygrad.runtime.autogen.am import am
from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.support.am.amdev import AMDev, AMMemoryManager
from tinygrad.runtime.support.amd import AMDReg, AMDIP, import_module, import_soc, import_ip_offsets, import_pmc
from tinygrad.runtime.support.amd import AMDReg, AMDIP, import_module, import_soc, import_pmc
from tinygrad.runtime.support.system import System, PCIIfaceBase, PCIAllocationMeta, USBPCIDevice, MAP_FIXED, MAP_NORESERVE
from tinygrad.runtime.support.usb import USB3
from tinygrad.runtime.support.memory import AddrSpace
@@ -90,7 +90,7 @@ class AMDComputeQueue(HWQueue):
return self
def acquire_mem(self, addr=0x0, sz=(1 << 64)-1, gli=1, glm=1, glk=1, glv=1, gl1=1, gl2=1):
if self.dev.target >= (10,0,0):
if self.dev.target[0] != 9:
cache_flags_dw = self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLI_INV(gli) \
| self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLM_INV(glm) | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLM_WB(glm) \
| self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLK_INV(glk) | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLK_WB(glk) \
@@ -108,7 +108,7 @@ class AMDComputeQueue(HWQueue):
return self
def release_mem(self, address=0x0, value=0, data_sel=0, int_sel=2, ctxid=0, cache_flush=False):
if self.dev.target >= (10,0,0):
if self.dev.target[0] != 9:
cache_flags_dw = 0 if not cache_flush else (self.pm4.PACKET3_RELEASE_MEM_GCR_GLV_INV | self.pm4.PACKET3_RELEASE_MEM_GCR_GL1_INV \
| self.pm4.PACKET3_RELEASE_MEM_GCR_GL2_INV | self.pm4.PACKET3_RELEASE_MEM_GCR_GLM_WB \
| self.pm4.PACKET3_RELEASE_MEM_GCR_GLM_INV | self.pm4.PACKET3_RELEASE_MEM_GCR_GL2_WB | self.pm4.PACKET3_RELEASE_MEM_GCR_SEQ)
@@ -151,7 +151,7 @@ class AMDComputeQueue(HWQueue):
def pmc_start(self, counters):
self.pmc_reset_counters(en=False)
self.wreg(self.gc.regSQ_PERFCOUNTER_CTRL, cs_en=1, ps_en=1, gs_en=1, hs_en=1, **({'vmid_mask':0xffff} if (gfx9:=self.dev.target[0] == 9) else {}))
if self.dev.target[0] >= 11: self.wreg(self.gc.regSQ_PERFCOUNTER_CTRL2, force_en=1, vmid_en=0xffff)
if not gfx9: self.wreg(self.gc.regSQ_PERFCOUNTER_CTRL2, force_en=1, vmid_en=0xffff)
end_off = 0
block2pid:dict[str, itertools.count] = collections.defaultdict(lambda: itertools.count())
@@ -163,7 +163,7 @@ class AMDComputeQueue(HWQueue):
# gfx11+ and later require even-numbered SQ *_SELECT registers
regsample = f'reg{block}_PERFCOUNTER{(pcid:=next(block2pid[block]))}'
if (regsel:=getattr(self.gc, (f'reg{block}_PERFCOUNTER{(pcid*2) if self.dev.target[0]>=11 and block=="SQ" else pcid}_SELECT'), None)) is None:
if (regsel:=getattr(self.gc, (f'reg{block}_PERFCOUNTER{(pcid*2) if not gfx9 and block=="SQ" else pcid}_SELECT'), None)) is None:
raise RuntimeError(f'{block} is out of perfcounter registers: ({regsample} is not found)')
self.wreg(regsel, perf_sel=idx, **({'simd_mask':0xf, 'sqc_bank_mask':0xf, 'sqc_client_mask':0xf} if gfx9 and block == "SQ" else {}))
@@ -206,7 +206,7 @@ class AMDComputeQueue(HWQueue):
cu_per_se = prod([x if isinstance(x, int) else 1 for x in global_size]) // ((self.dev.cu_cnt // self.dev.se_cnt) * 4)
for xcc in range(self.dev.xccs):
with self.pred_exec(xcc_mask=1 << xcc):
for i in range(8 if prg.dev.target >= (11,0,0) else 4):
for i in range(8 if prg.dev.target[0] != 9 else 4):
if SQTT_LIMIT_SE > 1: mask = 1 if SQTT_ITRACE_SE_MASK.value & (1 << i) else 0 # only run unmasked shader engines
else:
sa_mask = (1 << (self.dev.iface.props['cu_per_simd_array'] // 2)) - 1
@@ -282,7 +282,7 @@ class AMDComputeQueue(HWQueue):
self.sqtt_config(tracing=True)
self.set_grbm()
if self.dev.target[0] > 9: self.wreg(self.gc.regCOMPUTE_THREAD_TRACE_ENABLE, 1)
if self.dev.target[0] != 9: self.wreg(self.gc.regCOMPUTE_THREAD_TRACE_ENABLE, 1)
self.memory_barrier()
return self
@@ -303,7 +303,7 @@ class AMDComputeQueue(HWQueue):
self.set_grbm(se=se % self.dev.se_cnt, sh=0)
regstatus = self.gc.regSQ_THREAD_TRACE_STATUS.addr[0] - (self.pm4.PACKET3_SET_UCONFIG_REG_START if self.dev.target[0] == 9 else 0)
if self.dev.target >= (10,0,0):
if self.dev.target[0] != 9:
self.wait_reg_mem(reg=regstatus, mask=self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('finish_pending'), op=WAIT_REG_MEM_FUNCTION_EQ, value=0)
self.sqtt_config(tracing=False)
self.wait_reg_mem(reg=regstatus, mask=self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('busy'), op=WAIT_REG_MEM_FUNCTION_EQ, value=0)
@@ -313,7 +313,7 @@ class AMDComputeQueue(HWQueue):
self.pkt3(self.pm4.PACKET3_COPY_DATA, 1 << 20 | 2 << 8 | 4, self.gc.regSQ_THREAD_TRACE_WPTR.addr[0], 0, *data64_le(wptrs.va_addr+(se*4)))
self.set_grbm()
if self.dev.target[0] > 9: self.spi_config(tracing=False)
if self.dev.target[0] != 9: self.spi_config(tracing=False)
self.memory_barrier()
return self
@@ -348,22 +348,20 @@ class AMDComputeQueue(HWQueue):
self.wreg(self.gc.regCOMPUTE_PGM_RSRC3, prg.rsrc3)
self.wreg(self.gc.regCOMPUTE_TMPRING_SIZE, prg.dev.tmpring_size)
if prg.dev.has_scratch_base_registers:
for xcc_id in range(self.dev.xccs):
with self.pred_exec(xcc_mask=1<<xcc_id):
scratch_base = prg.dev.scratch.va_addr + (prg.dev.scratch.size // self.dev.xccs * xcc_id)
self.wreg(self.gc.regCOMPUTE_DISPATCH_SCRATCH_BASE_LO, *data64_le(scratch_base >> 8))
if (10,0,0) <= prg.dev.target < (11,0,0): self.wreg(self.gc.mmCP_COHER_START_DELAY, 0x20)
# this is what llvm refers to as "architected flat scratch"
for xcc_id in range(self.dev.xccs):
with self.pred_exec(xcc_mask=1<<xcc_id):
scratch_base = prg.dev.scratch.va_addr + (prg.dev.scratch.size // self.dev.xccs * xcc_id)
self.wreg(self.gc.regCOMPUTE_DISPATCH_SCRATCH_BASE_LO, *data64_le(scratch_base >> 8))
self.wreg(self.gc.regCOMPUTE_RESTART_X, 0, 0, 0)
self.wreg(self.gc.regCOMPUTE_USER_DATA_0, *user_regs)
self.wreg(self.gc.regCOMPUTE_RESOURCE_LIMITS, waves_per_sh=getenv("WAVES_PER_SH"))
self.wreg(self.gc.regCOMPUTE_START_X, 0, 0, 0, *local_size, 0, 0)
gfx10p = {'cs_w32_en': int(prg.wave32)} if prg.dev.target >= (10,0,0) else {}
self.pkt3(self.pm4.PACKET3_DISPATCH_DIRECT, *global_size,
self.gc.regCOMPUTE_DISPATCH_INITIATOR.encode(**gfx10p, force_start_at_000=1, compute_shader_en=1))
self.gc.regCOMPUTE_DISPATCH_INITIATOR.encode(**({'cs_w32_en': int(prg.wave32)} if prg.dev.target[0] != 9 else {}),
force_start_at_000=1, compute_shader_en=1))
if prg.dev.sqtt_enabled: self.pkt3(self.pm4.PACKET3_EVENT_WRITE, self.pm4.EVENT_TYPE(self.soc.THREAD_TRACE_MARKER) | self.pm4.EVENT_INDEX(0))
self.pkt3(self.pm4.PACKET3_EVENT_WRITE, self.pm4.EVENT_TYPE(self.soc.CS_PARTIAL_FLUSH) | self.pm4.EVENT_INDEX(EVENT_INDEX_PARTIAL_FLUSH))
@@ -486,7 +484,7 @@ class AMDCopyQueue(HWQueue):
return self
def signal(self, signal:AMDSignal, value:sint=0):
fence_flags = self.sdma.SDMA_PKT_FENCE_HEADER_MTYPE(3) if self.dev.target >= (10,0,0) else 0
fence_flags = self.sdma.SDMA_PKT_FENCE_HEADER_MTYPE(3) if self.dev.target[0] != 9 else 0
self.q(self.sdma.SDMA_OP_FENCE | fence_flags, *data64_le(signal.value_addr), value)
if (dev:=signal.owner) is not None and signal.is_timeline and not dev.is_am():
@@ -593,7 +591,7 @@ class AMDProgram(HCQProgram):
self.wave32: bool = desc.kernel_code_properties & 0x400 == 0x400
# Set rsrc1.priv=1 on gfx11 to workaround cwsr.
self.rsrc1: int = desc.compute_pgm_rsrc1 | ((1 << 20) if (11,0,0) <= self.dev.target < (12,0,0) else 0)
self.rsrc1: int = desc.compute_pgm_rsrc1 | ((1 << 20) if self.dev.target[0] == 11 else 0)
self.rsrc2: int = desc.compute_pgm_rsrc2 | (lds_size << 15)
self.rsrc3: int = desc.compute_pgm_rsrc3
self.aql_prog_addr: int = self.lib_gpu.va_addr + rodata_entry
@@ -961,30 +959,26 @@ class AMDDevice(HCQCompiled):
self.target:tuple[int, ...] = ((trgt:=self.iface.props['gfx_target_version']) // 10000, (trgt // 100) % 100, trgt % 100)
self.arch = "gfx%d%x%x" % self.target
if self.target < (9,4,2) or self.target >= (13,0,0): raise RuntimeError(f"Unsupported arch: {self.arch}")
assert (self.target in ((9,4,2),(9,5,0))) or self.target[0] in (11, 12), f"Unsupported arch: {self.arch}"
if DEBUG >= 1: print(f"AMDDevice: opening {self.device_id} with target {self.target} arch {self.arch}")
self.xccs = self.iface.props.get('num_xcc', 1)
self.se_cnt = self.iface.props['array_count'] // self.iface.props['simd_arrays_per_engine'] // self.xccs
self.cu_cnt = self.iface.props['simd_count'] // self.iface.props['simd_per_cu'] // self.xccs
self.waves_per_cu = self.iface.props['max_waves_per_simd'] * self.iface.props['simd_per_cu']
self.wave_cnt = (self.cu_cnt * self.waves_per_cu) if self.target >= (10,1,0) else min(self.cu_cnt * 40, self.se_cnt * self.xccs * 512)
# this is what llvm refers to as "architected flat scratch"
self.has_scratch_base_registers = self.target >= (11,0,0) or self.target in {(9,4,2), (9,5,0)}
self.wave_cnt = (self.cu_cnt * self.waves_per_cu) if self.target[0] != 9 else min(self.cu_cnt * 40, self.se_cnt * self.xccs * 512)
# https://gitlab.freedesktop.org/agd5f/linux/-/blob/a1fc9f584c4aaf8bc1ebfa459fc57a3f26a290d8/drivers/gpu/drm/amd/amdkfd/kfd_queue.c#L391
sgrp_size_per_cu, lds_size_per_cu, hwreg_size_per_cu = 0x4000, 0x10000, 0x1000
if self.target[:2] == (9,5): lds_size_per_cu = self.iface.props["lds_size_in_kb"] << 10
vgpr_size_per_cu = 0x60000 if self.target in {(11,0,0), (11,0,1), (11,5,1), (12,0,0), (12,0,1)} else \
0x80000 if (self.target[:2]) in {(9,4), (9,5)} or self.target in {(9,0,8), (9,0,10)} else 0x40000
sgrp_size_per_cu, hwreg_size_per_cu = 0x4000, 0x1000
lds_size_per_cu = self.iface.props["lds_size_in_kb"] << 10 if self.target[:2] == (9,5) else 0x10000
vgpr_size_per_cu = 0x60000 if self.target in {(11,0,0), (11,0,1), (11,5,1), (12,0,0), (12,0,1)} else 0x80000 if self.target[0] == 9 else 0x40000
wg_data_size = round_up((vgpr_size_per_cu + sgrp_size_per_cu + lds_size_per_cu + hwreg_size_per_cu) * self.cu_cnt, mmap.PAGESIZE)
ctl_stack_size = round_up((12 if self.target >= (10,1,0) else 8) * self.wave_cnt + 8 + 40, mmap.PAGESIZE)
if self.target[0] == 10: ctl_stack_size = min(ctl_stack_size, 0x7000)
ctl_stack_size = round_up((12 if self.target[0] != 9 else 8) * self.wave_cnt + 8 + 40, mmap.PAGESIZE)
debug_memory_size = round_up(self.wave_cnt * 32, 64)
self.ip_off = import_ip_offsets(self.target)
self.ip_off = importlib.import_module(f"tinygrad.runtime.autogen.am.{'vega' if self.target[0] == 9 else 'navi'}_offsets")
self.soc = import_soc(self.target)
self.pm4 = importlib.import_module(f"tinygrad.runtime.autogen.am.pm4_{'nv' if self.target[0] >= 10 else 'soc15'}")
self.pm4 = importlib.import_module(f"tinygrad.runtime.autogen.am.pm4_{'soc15' if self.target[0] == 9 else 'nv'}")
self.sdma = import_module('sdma', min(self.iface.ip_versions[am.SDMA0_HWIP], (6, 0, 0)))
self.gc = AMDIP('gc', self.iface.ip_versions[am.GC_HWIP],
bases={i: tuple(getattr(self.ip_off, f'GC_BASE__INST{i}_SEG{s}', 0) for s in range(6)) for i in range(6)})
@@ -1017,7 +1011,6 @@ class AMDDevice(HCQCompiled):
self.pmc_enabled:bool = PROFILE > 0 and PMC > 0
if self.pmc_enabled:
if self.target[0] not in {9, 11, 12}: raise RuntimeError(f'PMC are not supported on gc:{self.target}')
self.iface.require_profile_mode()
self.pmc_sched:list[PMCSample] = []
@@ -1036,7 +1029,6 @@ class AMDDevice(HCQCompiled):
# SQTT is disabled by default because of runtime overhead and big file sizes (~200mb to Tensor.full() two 4096x4096 tensors and matmul them)
self.sqtt_enabled:bool = PROFILE > 0 and SQTT > 0
if self.sqtt_enabled:
if self.target[0] not in {9, 11, 12}: raise RuntimeError(f'SQ Thread Tracing is not supported on gc:{self.target}')
self.iface.require_profile_mode()
SQTT_BUFFER_SIZE = getenv("SQTT_BUFFER_SIZE", 256) # in mb, per shader engine
@@ -1074,7 +1066,7 @@ class AMDDevice(HCQCompiled):
if self.max_private_segment_size >= private_segment_size: return
lanes_per_wave = 64 # wave64
mem_alignment_size = 256 if self.target >= (11,0,0) else 1024
mem_alignment_size = 256 if self.target[0] != 9 else 1024
size_per_thread = round_up(private_segment_size, mem_alignment_size // lanes_per_wave)
size_per_xcc = size_per_thread * lanes_per_wave * self.iface.props['max_slots_scratch_cu'] * self.cu_cnt
self.scratch, ok = self._realloc(getattr(self, 'scratch', None), size_per_xcc * self.xccs)
@@ -1082,18 +1074,18 @@ class AMDDevice(HCQCompiled):
# NOTE: xcc logic is correct only for GFX9.
max_scratch_waves = self.cu_cnt * self.iface.props['max_slots_scratch_cu'] * self.xccs
wave_scratch = ceildiv(lanes_per_wave * size_per_thread, mem_alignment_size)
num_waves = (size_per_xcc // (wave_scratch * mem_alignment_size)) // (self.se_cnt if self.target >= (11,0,0) else 1)
num_waves = (size_per_xcc // (wave_scratch * mem_alignment_size)) // (self.se_cnt if self.target[0] != 9 else 1)
tmpring_t = getattr(hsa, f'union_COMPUTE_TMPRING_SIZE{"_GFX"+str(self.target[0]) if self.target[0] >= 11 else ""}_bitfields')
tmpring_t = getattr(hsa, f'union_COMPUTE_TMPRING_SIZE{"_GFX"+str(self.target[0]) if self.target[0] != 9 else ""}_bitfields')
self.tmpring_size = int.from_bytes(tmpring_t(WAVES=min(num_waves, max_scratch_waves), WAVESIZE=wave_scratch), 'little')
self.max_private_segment_size = private_segment_size
if hasattr(self, 'aql_desc'):
gfx9_rsrc = {'NUM_FORMAT':hsa.BUF_NUM_FORMAT_UINT, 'DATA_FORMAT':hsa.BUF_DATA_FORMAT_32, 'ELEMENT_SIZE':1, 'INDEX_STRIDE':3}
rsrc = {'DST_SEL_X':hsa.SQ_SEL_X, 'DST_SEL_Y':hsa.SQ_SEL_Y, 'DST_SEL_Z':hsa.SQ_SEL_Z, 'DST_SEL_W':hsa.SQ_SEL_W, 'ADD_TID_ENABLE':1,
'TYPE':hsa.SQ_RSRC_BUF, **(gfx9_rsrc if self.target[0] < 10 else {'FORMAT':hsa.BUF_FORMAT_32_UINT, 'OOB_SELECT':2})}
rsrc1_t = getattr(hsa, f'union_SQ_BUF_RSRC_WORD1{"_GFX11" if self.target[0] >= 11 else ""}_bitfields')
rsrc3_t = getattr(hsa, f'union_SQ_BUF_RSRC_WORD3{"_GFX"+str(self.target[0]) if self.target[0] >= 10 else ""}_bitfields')
'TYPE':hsa.SQ_RSRC_BUF, **(gfx9_rsrc if self.target[0] == 9 else {'FORMAT':hsa.BUF_FORMAT_32_UINT, 'OOB_SELECT':2})}
rsrc1_t = getattr(hsa, f'union_SQ_BUF_RSRC_WORD1{"_GFX11" if self.target[0] != 9 else ""}_bitfields')
rsrc3_t = getattr(hsa, f'union_SQ_BUF_RSRC_WORD3{"_GFX"+str(self.target[0]) if self.target[0] != 9 else ""}_bitfields')
self.aql_desc.scratch_backing_memory_location = int(self.scratch.va_addr)
self.aql_desc.scratch_wave64_lane_byte_size = self.max_private_segment_size * lanes_per_wave // 64

View File

@@ -62,8 +62,6 @@ def import_soc(ip):
# rocm soc headers have more profiling enums than upstream linux
return type("SOC", (object,), import_header(f"aqlprofile/linux/{({9: 'vega10', 10: 'navi10', 11: 'soc21', 12: 'soc24'}[ip[0]])}_enum.h", ROCM_URL))
def import_ip_offsets(ip): return getattr(tinygrad.runtime.autogen.am, f"{'navi' if ip[0] > 9 else 'vega'}_offsets")
def import_pmc(ip) -> dict[str, tuple[str, int]]:
res:dict[str, tuple[str, int]] = {}