mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 16:37:04 +08:00
* get_runner -> get_runtime * do not use get_runner * fix * remove get_tunner * remove * fix * x
2272 lines
135 KiB
Python
2272 lines
135 KiB
Python
# RDNA3 emulator v2 - compiles pcode to UOps executed via tinygrad CPU backend
|
||
# Each instruction is compiled to a kernel that operates on buffers:
|
||
# arg=0: sgpr - sgpr[0-127], inline constants[128-255], PC_LO=256, PC_HI=257, SCC=258, SCRATCH_STRIDE=259
|
||
# arg=1: vgpr - vgpr[reg * 32 + lane]
|
||
# arg=2: vmem - base address 0, INDEX offsets directly to host memory
|
||
# arg=3: lds - local data share
|
||
# arg=4: scratch - per-lane scratch memory
|
||
from __future__ import annotations
|
||
import ctypes, functools, re, platform, subprocess, tempfile
|
||
from typing import Callable
|
||
|
||
# Set/restore DAZ+FTZ (denormals-are-zero + flush-to-zero) to match RDNA3 default float mode
|
||
# x86: MXCSR bits DAZ(6)+FTZ(15), ARM64: FPCR bit FZ(24)
|
||
# Only applied during emulator execution, restored afterward to avoid breaking hypothesis tests
|
||
@functools.cache
|
||
def _get_ftz_lib():
|
||
machine = platform.machine()
|
||
if machine in ('x86_64', 'AMD64'):
|
||
src = b'''
|
||
unsigned int get_fpcr(void){unsigned int m;__asm__ __volatile__("stmxcsr %0":"=m"(m));return m;}
|
||
void set_fpcr(unsigned int m){__asm__ __volatile__("ldmxcsr %0"::"m"(m));}
|
||
'''
|
||
ftz_bits = 0x8040 # DAZ (bit 6) + FTZ (bit 15)
|
||
elif machine in ('arm64', 'aarch64'):
|
||
src = b'''
|
||
unsigned int get_fpcr(void){unsigned long long v;__asm__ __volatile__("mrs %0,fpcr":"=r"(v));return(unsigned int)v;}
|
||
void set_fpcr(unsigned int m){unsigned long long v=m;__asm__ __volatile__("msr fpcr,%0"::"r"(v));}
|
||
'''
|
||
ftz_bits = 1 << 24 # FZ (bit 24)
|
||
else: return None, 0
|
||
try:
|
||
with tempfile.NamedTemporaryFile(suffix='.so', delete=False) as f:
|
||
subprocess.check_output(['clang', '-shared', '-O2', '-x', 'c', '-', '-o', f.name], input=src)
|
||
lib = ctypes.CDLL(f.name)
|
||
lib.get_fpcr.restype = ctypes.c_uint32
|
||
lib.set_fpcr.argtypes = [ctypes.c_uint32]
|
||
return lib, ftz_bits
|
||
except Exception: return None, 0
|
||
|
||
class _MXCSRContext:
  """Context manager that turns on DAZ+FTZ for the duration of the block and restores the prior value on exit.

  When _get_ftz_lib() yields no library, entering and leaving are no-ops.
  """
  __slots__ = ('_saved',)

  def __enter__(self):
    lib, ftz_bits = _get_ftz_lib()
    if lib is not None:
      # remember the current control word so __exit__ can put it back
      self._saved = lib.get_fpcr()
      lib.set_fpcr(self._saved | ftz_bits)
    return self

  def __exit__(self, *args):
    lib, _ = _get_ftz_lib()
    # _saved is only set when __enter__ actually changed the register
    if lib is not None and hasattr(self, '_saved'): lib.set_fpcr(self._saved)
|
||
|
||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
|
||
from tinygrad.dtype import dtypes, AddrSpace
|
||
from tinygrad.device import Buffer, BufferSpec, Device
|
||
from tinygrad.runtime.autogen import hsa
|
||
from tinygrad.helpers import Context, DEBUG, PROFILE, colored
|
||
from tinygrad.engine.realize import get_runtime
|
||
from tinygrad.codegen import to_program
|
||
|
||
from tinygrad.renderer.amd import decode_inst
|
||
from tinygrad.runtime.autogen.amd.rdna3.str_pcode import PCODE as PCODE_RDNA3
|
||
from tinygrad.runtime.autogen.amd.rdna4.str_pcode import PCODE as PCODE_RDNA4
|
||
from tinygrad.runtime.autogen.amd.cdna.str_pcode import PCODE as PCODE_CDNA
|
||
from tinygrad.runtime.autogen.amd.rdna3 import ins as ir3
|
||
from tinygrad.runtime.autogen.amd.rdna4 import ins as ir4
|
||
from tinygrad.runtime.autogen.amd.cdna import ins as irc
|
||
from tinygrad.renderer.amd.dsl import VCC_LO, EXEC_LO, SCC, ttmp
|
||
from tinygrad.runtime.autogen.amd.common import Fmt, OpType
|
||
from test.amd.helpers import decode_dpp16
|
||
from test.mockgpu.amd.pcode import parse_block, _FUNCS, _set_bits, _val_to_bits
|
||
|
||
MASK32 = 0xFFFFFFFF
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
# SQTT TRACE COLLECTION
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
|
||
# Global trace storage: populated by run_asm as raw SQTT blobs, consumed by amdgpu.py
|
||
sqtt_traces: list[bytes] = []
|
||
|
||
# Encoder primitives
|
||
from tinygrad.renderer.amd.sqtt import _build_decode_tables, PACKET_TYPES_RDNA3, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, IMMEDIATE, VALUINST, InstOp
|
||
|
||
_NIB_COUNTS: dict = {cls: nc for _, (cls, nc, *_) in _build_decode_tables(PACKET_TYPES_RDNA3)[0].items()}
|
||
|
||
def _encode_raw(pkt_cls, **kwargs) -> tuple[int, int]:
  """Encode one packet: start from the class's default encoding and set each named field.

  Returns (raw_bits, nibble_count), where nibble_count is looked up in _NIB_COUNTS.
  """
  fields = pkt_cls.__dict__
  raw = pkt_cls.encoding.default
  for name, value in kwargs.items():
    raw = fields[name].set(raw, value)
  return raw, _NIB_COUNTS[pkt_cls]
|
||
|
||
def _emit_nibbles(nibbles: list[int], pkt_cls, **kwargs):
  """Encode a packet and append its nibbles to `nibbles`, least-significant nibble first."""
  raw, nc = _encode_raw(pkt_cls, **kwargs)
  nibbles.extend((raw >> shift) & 0xF for shift in range(0, nc * 4, 4))
|
||
|
||
def _nibbles_to_bytes(nibbles: list[int]) -> bytes:
|
||
result = bytearray()
|
||
for i in range(0, len(nibbles), 2): result.append(nibbles[i] | ((nibbles[i + 1] if i + 1 < len(nibbles) else 0) << 4))
|
||
return bytes(result)
|
||
|
||
def _init_sqtt_encoder():
  """Initialize and return SQTT encoder state. Called once per dispatch with tracing enabled.

  Returns three closures that share one nibble buffer and one set of started waves:
    emit(wave_id, inst, branch_taken): append the packet(s) for one executed instruction
    finish(wave_id): append WAVEEND if that wave was started
    finalize(): zero-pad the buffer and return it packed into bytes
  """
  from tinygrad.runtime.autogen.amd.rdna3.enum import SOPPOp as SOPPOp3
  from tinygrad.runtime.autogen.amd.rdna4.enum import SOPPOp as SOPPOp4
  import re

  # Instruction-format classes grouped across the RDNA3 / RDNA4 / CDNA instruction modules
  sopp_types = (ir3.SOPP, ir4.SOPP, irc.SOPP)
  smem_types = (ir3.SMEM, ir4.SMEM, irc.SMEM)
  valu_types = (ir3.VOP1, ir3.VOP2, ir3.VOP3, ir3.VOP3P, ir3.VOPC, ir3.VOPD, ir3.VOP3SD, ir3.VOP3_SDST, ir3.VOP1_SDST,
                ir4.VOP1, ir4.VOP2, ir4.VOP3, ir4.VOP3P, ir4.VOPC, ir4.VOPD, ir4.VOP3SD, ir4.VOP3_SDST, ir4.VOP1_SDST,
                irc.VOP1, irc.VOP2, irc.VOP3, irc.VOP3P, irc.VOPC, irc.VOP3SD, irc.VOP3_SDST)
  ds_types = (ir3.DS, ir4.DS, irc.DS)
  global_types = (ir3.GLOBAL, ir4.VGLOBAL, irc.GLOBAL)
  flat_types = (ir3.FLAT, ir4.VFLAT, irc.FLAT)
  scratch_types = (ir3.SCRATCH, ir4.VSCRATCH, irc.SCRATCH)

  # SOPP opcode values, bucketed by the kind of packet they produce
  sopp_skip = {SOPPOp3.S_ENDPGM.value, SOPPOp3.S_ENDPGM_SAVED.value, SOPPOp3.S_ENDPGM_ORDERED_PS_DONE.value,
               SOPPOp3.S_DELAY_ALU.value}
  sopp_immediate = {SOPPOp3.S_NOP.value, SOPPOp3.S_CLAUSE.value, SOPPOp3.S_WAITCNT.value, SOPPOp3.S_WAITCNT_DEPCTR.value,
                    SOPPOp3.S_WAIT_IDLE.value, SOPPOp3.S_WAIT_EVENT.value, SOPPOp3.S_SLEEP.value,
                    SOPPOp3.S_SET_INST_PREFETCH_DISTANCE.value}
  sopp_immediate |= {wait_op.value for wait_op in (
    SOPPOp4.S_WAIT_ALU, SOPPOp4.S_WAIT_LOADCNT, SOPPOp4.S_WAIT_STORECNT, SOPPOp4.S_WAIT_SAMPLECNT,
    SOPPOp4.S_WAIT_BVHCNT, SOPPOp4.S_WAIT_EXPCNT, SOPPOp4.S_WAIT_DSCNT, SOPPOp4.S_WAIT_KMCNT,
    SOPPOp4.S_WAIT_LOADCNT_DSCNT, SOPPOp4.S_WAIT_STORECNT_DSCNT)}
  sopp_barrier = {SOPPOp3.S_BARRIER.value}
  for optional in ('S_BARRIER_WAIT', 'S_BARRIER_LEAVE'):
    # these RDNA4 members are only added when the enum defines them
    if hasattr(SOPPOp4, optional): sopp_barrier.add(getattr(SOPPOp4, optional).value)
  sopp_branch = {SOPPOp3.S_BRANCH.value, SOPPOp3.S_CBRANCH_SCC0.value, SOPPOp3.S_CBRANCH_SCC1.value,
                 SOPPOp3.S_CBRANCH_VCCZ.value, SOPPOp3.S_CBRANCH_VCCNZ.value,
                 SOPPOp3.S_CBRANCH_EXECZ.value, SOPPOp3.S_CBRANCH_EXECNZ.value}

  # VALU sub-classification patterns, matched against the opcode name
  valut_4_re = re.compile(r'V_(EXP|LOG|RCP|RSQ|SQRT|SIN|COS|CEIL|FLOOR|TRUNC|RNDNE|FRACT|FREXP)_')
  valub_2_re = re.compile(r'V_(LSHLREV|LSHRREV|ASHRREV)_(B|I)64')
  valub_4_re = re.compile(r'V_MAD_(U|I)64')
  valub_16_re = re.compile(r'V_\w+_F64')

  def _valu_op(op_name: str) -> InstOp|None:
    """Return the special InstOp for a VALU opcode name, or None for a plain VALU instruction."""
    if 'CMPX' in op_name: return InstOp.VALU1_WR_EXEC
    # order matters: the 64-bit shift / MAD patterns are tested before the generic F64 and transcendental ones
    if valub_2_re.search(op_name): return InstOp.VALUB_2
    if valub_4_re.search(op_name): return InstOp.VALUB_4
    if valub_16_re.search(op_name): return InstOp.VALUB_16
    if valut_4_re.search(op_name): return InstOp.VALUT_4
    return None

  def _mem_op(t, op_name: str) -> InstOp:
    """Pick the InstOp for a non-SOPP/VALU/SMEM instruction class; unrecognized classes map to SALU."""
    stores = "STORE" in op_name
    if issubclass(t, ds_types): return InstOp.LDS_WR_2 if stores else InstOp.LDS_RD
    if issubclass(t, global_types): return InstOp.SGMEM_WR_2 if stores else InstOp.SGMEM_RD_1
    if issubclass(t, flat_types): return InstOp.FLAT_WR_3 if stores else InstOp.FLAT_RD_2
    if issubclass(t, scratch_types): return InstOp.FLAT_WR_3 if stores else InstOp.FLAT_RD_2
    return InstOp.SALU

  nibbles: list[int] = []
  started: set[int] = set()
  _emit_nibbles(nibbles, LAYOUT_HEADER, layout=3, sel_a=6)

  def emit(wave_id: int, inst, branch_taken: bool|None):
    """Emit an SQTT packet for one executed instruction."""
    w = wave_id & 0x1F
    if wave_id not in started:
      # first instruction seen for this wave: announce it before anything else
      _emit_nibbles(nibbles, WAVESTART, delta=1, simd=0, wgp=0, wave=w, id7=wave_id)
      started.add(wave_id)
    inst_type = type(inst)
    if hasattr(inst, 'op'): inst_op, op_name = inst.op.value, inst.op.name
    else: inst_op, op_name = 0, ""

    if issubclass(inst_type, sopp_types):
      if inst_op in sopp_skip: return
      if inst_op in sopp_immediate:
        _emit_nibbles(nibbles, IMMEDIATE, delta=1, wave=w)
        return
      if inst_op in sopp_barrier: sopp_op = InstOp.BARRIER
      elif inst_op in sopp_branch: sopp_op = InstOp.JUMP if branch_taken else InstOp.JUMP_NO
      else: sopp_op = InstOp.SALU
      _emit_nibbles(nibbles, INST, delta=1, wave=w, op=sopp_op)
      return

    if issubclass(inst_type, valu_types):
      valu = _valu_op(op_name)
      if valu is None: _emit_nibbles(nibbles, VALUINST, delta=1, wave=w)
      else: _emit_nibbles(nibbles, INST, delta=1, wave=w, op=valu)
      return

    mem = InstOp.SMEM_RD if issubclass(inst_type, smem_types) else _mem_op(inst_type, op_name)
    _emit_nibbles(nibbles, INST, delta=1, wave=w, op=mem)

  def finish(wave_id: int):
    """Emit WAVEEND for a completed wave."""
    if wave_id not in started: return
    _emit_nibbles(nibbles, WAVEEND, delta=1, simd=0, wgp=0, wave=wave_id & 0x1F)

  def finalize() -> bytes:
    """Pad and return the encoded SQTT blob."""
    nibbles.extend([0] * (len(nibbles) % 2))      # round up to a whole byte
    nibbles.extend([0] * 32)                      # 32 trailing zero nibbles
    nibbles.extend([0] * (-len(nibbles) % 64))    # round up to a multiple of 64 nibbles
    return _nibbles_to_bytes(nibbles)

  return emit, finish, finalize
|
||
|
||
def _c(val, dtype=dtypes.uint32):
  """Shorthand for a constant UOp holding `val`, uint32 unless another dtype is given."""
  return UOp.const(dtype, val)
|
||
|
||
def _u64(lo: UOp, hi: UOp) -> UOp:
  """Build a uint64 UOp from a low and a high 32-bit half."""
  lo64, hi64 = lo.cast(dtypes.uint64), hi.cast(dtypes.uint64)
  return lo64 | (hi64 << UOp.const(dtypes.uint64, 32))
|
||
|
||
def _split64(val: UOp) -> tuple[UOp, UOp]:
  """Return the (lo, hi) uint32 halves of a 64-bit value.

  float64 input is bitcast to uint64; any other non-uint64 dtype is value-cast.
  """
  if val.dtype == dtypes.float64: v64 = val.bitcast(dtypes.uint64)
  elif val.dtype != dtypes.uint64: v64 = val.cast(dtypes.uint64)
  else: v64 = val
  lo = v64.cast(dtypes.uint32)
  hi = (v64 >> UOp.const(dtypes.uint64, 32)).cast(dtypes.uint32)
  return lo, hi
|
||
|
||
# bit width -> (unsigned storage dtype, float dtype, mask that clears the sign bit)
_SRC_MOD_TYPES = {
  16: (dtypes.uint16, dtypes.half, 0x7FFF),
  32: (dtypes.uint32, dtypes.float32, 0x7FFFFFFF),
  64: (dtypes.uint64, dtypes.float64, 0x7FFFFFFFFFFFFFFF),
}
def _apply_src_mods(val: UOp, mod_bit: int, abs_bits: int, neg_bits: int, bits: int = 32) -> UOp:
  """Apply the abs/neg source modifiers selected by bit `mod_bit` of abs_bits/neg_bits.

  `bits` picks the operand width (16, 32 or 64). With neither modifier set, `val` is returned untouched.
  16-bit results are widened back to uint32; other widths come back as the unsigned storage dtype.
  """
  selector = 1 << mod_bit
  want_abs, want_neg = bool(abs_bits & selector), bool(neg_bits & selector)
  if not (want_abs or want_neg): return val
  ut, ft, mask = _SRC_MOD_TYPES[bits]
  # get a float view of the operand
  if bits == 16: fv = val.cast(ut).bitcast(ft)
  elif val.dtype == ut: fv = val.bitcast(ft)
  else: fv = val
  if want_abs: fv = (fv.bitcast(ut) & UOp.const(ut, mask)).bitcast(ft)  # clear the sign bit
  if want_neg: fv = fv.neg()
  raw = fv.bitcast(ut)
  return raw.cast(dtypes.uint32) if bits == 16 else raw
|
||
|
||
# Map VOPD ops to VOP2 ops for pcode lookup (both RDNA3 and RDNA4)
|
||
VOPD_TO_VOP2 = {
  # RDNA3 dual-issue op -> single-issue op whose pcode is reused
  ir3.VOPDOp.V_DUAL_FMAC_F32: ir3.VOP2Op.V_FMAC_F32_E32,
  ir3.VOPDOp.V_DUAL_MUL_F32: ir3.VOP2Op.V_MUL_F32_E32,
  ir3.VOPDOp.V_DUAL_ADD_F32: ir3.VOP2Op.V_ADD_F32_E32,
  ir3.VOPDOp.V_DUAL_SUB_F32: ir3.VOP2Op.V_SUB_F32_E32,
  ir3.VOPDOp.V_DUAL_SUBREV_F32: ir3.VOP2Op.V_SUBREV_F32_E32,
  ir3.VOPDOp.V_DUAL_MAX_F32: ir3.VOP2Op.V_MAX_F32_E32,
  ir3.VOPDOp.V_DUAL_MIN_F32: ir3.VOP2Op.V_MIN_F32_E32,
  ir3.VOPDOp.V_DUAL_ADD_NC_U32: ir3.VOP2Op.V_ADD_NC_U32_E32,
  ir3.VOPDOp.V_DUAL_LSHLREV_B32: ir3.VOP2Op.V_LSHLREV_B32_E32,
  ir3.VOPDOp.V_DUAL_AND_B32: ir3.VOP2Op.V_AND_B32_E32,
  ir3.VOPDOp.V_DUAL_MOV_B32: ir3.VOP1Op.V_MOV_B32_E32,  # the only VOP1 target
  ir3.VOPDOp.V_DUAL_CNDMASK_B32: ir3.VOP2Op.V_CNDMASK_B32_E32,
  ir3.VOPDOp.V_DUAL_FMAAK_F32: ir3.VOP2Op.V_FMAAK_F32_E32,
  ir3.VOPDOp.V_DUAL_FMAMK_F32: ir3.VOP2Op.V_FMAMK_F32_E32,
  ir3.VOPDOp.V_DUAL_DOT2ACC_F32_F16: ir3.VOP2Op.V_DOT2ACC_F32_F16_E32,
  # RDNA4 mappings (same VOP1/VOP2 targets, RDNA4 uses _NUM_ suffix for min/max)
  ir4.VOPDOp.V_DUAL_FMAC_F32: ir3.VOP2Op.V_FMAC_F32_E32,
  ir4.VOPDOp.V_DUAL_MUL_F32: ir3.VOP2Op.V_MUL_F32_E32,
  ir4.VOPDOp.V_DUAL_ADD_F32: ir3.VOP2Op.V_ADD_F32_E32,
  ir4.VOPDOp.V_DUAL_SUB_F32: ir3.VOP2Op.V_SUB_F32_E32,
  ir4.VOPDOp.V_DUAL_SUBREV_F32: ir3.VOP2Op.V_SUBREV_F32_E32,
  ir4.VOPDOp.V_DUAL_MAX_NUM_F32: ir3.VOP2Op.V_MAX_F32_E32,
  ir4.VOPDOp.V_DUAL_MIN_NUM_F32: ir3.VOP2Op.V_MIN_F32_E32,
  ir4.VOPDOp.V_DUAL_ADD_NC_U32: ir3.VOP2Op.V_ADD_NC_U32_E32,
  ir4.VOPDOp.V_DUAL_LSHLREV_B32: ir3.VOP2Op.V_LSHLREV_B32_E32,
  ir4.VOPDOp.V_DUAL_AND_B32: ir3.VOP2Op.V_AND_B32_E32,
  ir4.VOPDOp.V_DUAL_MOV_B32: ir3.VOP1Op.V_MOV_B32_E32,
  ir4.VOPDOp.V_DUAL_CNDMASK_B32: ir3.VOP2Op.V_CNDMASK_B32_E32,
  ir4.VOPDOp.V_DUAL_FMAAK_F32: ir3.VOP2Op.V_FMAAK_F32_E32,
  ir4.VOPDOp.V_DUAL_FMAMK_F32: ir3.VOP2Op.V_FMAMK_F32_E32,
  ir4.VOPDOp.V_DUAL_DOT2ACC_F32_F16: ir3.VOP2Op.V_DOT2ACC_F32_F16_E32,
}
|
||
def _wave_size(arch: str) -> int: return 64 if arch.startswith("cdna") else 32
|
||
# Special registers stored after inline constants (256-259)
PC_LO_IDX = 256
PC_HI_IDX = 257
SCRATCH_STRIDE_IDX = 259
# SGPR buffer: 0-127 = SGPRs, 128-255 = inline constants, 256-259 = special registers
SGPR_COUNT = 260
# Sentinel PC value for s_endpgm
ENDPGM_PC = 0xFFFFFFFFFFFFFFFF
|
||
|
||
def _op_name(inst) -> str:
|
||
if hasattr(inst, 'opx'): return f"{inst.opx.name}_{inst.opy.name}" # VOPD has opx/opy not op
|
||
return inst.op.name if hasattr(inst.op, 'name') else str(inst.op)
|
||
|
||
def _to_u32(val: UOp) -> UOp:
  """Coerce `val` to uint32: 4-byte dtypes are bitcast (e.g. float32), every other size is value-cast."""
  dt = val.dtype
  if dt == dtypes.uint32: return val
  # same size: reinterpret the bits; different size (bool, int16, etc): convert the value
  return val.bitcast(dtypes.uint32) if dt.itemsize == 4 else val.cast(dtypes.uint32)
|
||
def _lane_active(exec_mask: UOp, lane: UOp) -> UOp:
  """Boolean UOp: is bit `lane` of `exec_mask` set? Works on uint64 masks and on 32-bit masks."""
  dt = dtypes.uint64 if exec_mask.dtype == dtypes.uint64 else dtypes.uint32
  lane_bit = (exec_mask >> lane.cast(dt)) & UOp.const(dt, 1)
  return lane_bit.ne(UOp.const(dt, 0))
|
||
def _hi16(v: UOp) -> UOp:
  """Upper 16 bits of a 32-bit value, moved down into the low half."""
  return (v >> _c(16)) & _c(0xFFFF)
|
||
def _cond(cond, if_true, if_false):
  """Pick if_true or if_false: a UOp condition becomes a `where`, a plain Python value is tested directly."""
  if isinstance(cond, UOp): return cond.where(if_true, if_false)
  return if_true if cond else if_false
|
||
def _cond_hi16(cond, val: UOp) -> UOp:
  """Return the high 16 bits of `val` when `cond` holds, otherwise `val` itself."""
  return _cond(cond, _hi16(val), val)
|
||
def _apply_opsel(val: UOp, sel_bit: int, opsel: int) -> UOp:
  """Select the high 16 bits of `val` if bit `sel_bit` of `opsel` is set; otherwise return `val` as is."""
  if opsel & (1 << sel_bit): return _hi16(val)
  return val
|
||
|
||
def _set_lane_bit(old: UOp, lane: UOp, val: UOp, exec_mask: UOp) -> UOp:
  """Replace bit `lane` of mask `old` with `val`, leaving `old` unchanged for lanes that are off in exec_mask.

  64-bit masks (uint64/int64) are handled in uint64; everything else in uint32.
  """
  active = _lane_active(exec_mask, lane)
  if old.dtype in (dtypes.uint64, dtypes.int64):
    dt = dtypes.uint64
    lane64, old64 = lane.cast(dt), old.cast(dt)
    keep = old64 & ((UOp.const(dt, 1) << lane64) ^ UOp.const(dt, 0xFFFFFFFFFFFFFFFF))  # old with the lane bit cleared
    return active.where(keep | (_to_u32(val).cast(dt) << lane64), old64)
  lane32 = lane.cast(dtypes.uint32)
  keep = old & ((_c(1) << lane32) ^ _c(MASK32))  # old with the lane bit cleared
  return active.where(keep | (_to_u32(val) << lane32), old)
|
||
|
||
def _val_to_u32(val: UOp) -> UOp:
  """Convert any value to uint32 for storage (bitcast floats, cast ints)."""
  dt = val.dtype
  if dt == dtypes.uint32: return val
  if dt == dtypes.float32: return val.bitcast(dtypes.uint32)
  # half has no 32-bit bitcast: reinterpret as uint16 first, then widen
  if dt == dtypes.half: return val.bitcast(dtypes.uint16).cast(dtypes.uint32)
  # every remaining dtype (including uint16/int16) is value-cast
  return val.cast(dtypes.uint32)
|
||
|
||
_pcode_fixes = {
|
||
'V_DIV_FMAS_F32': ('D0.f32 = 2.0F ** 32 * fma(S0.f32, S1.f32, S2.f32)',
|
||
'D0.f32 = (exponent(S2.f32) > 127) ? (2.0F ** 64 * fma(S0.f32, S1.f32, S2.f32)) : (2.0F ** -64 * fma(S0.f32, S1.f32, S2.f32))'),
|
||
'V_DIV_FMAS_F64': ('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
|
||
'D0.f64 = (exponent(S2.f64) > 1023) ? (2.0 ** 128 * fma(S0.f64, S1.f64, S2.f64)) : (2.0 ** -128 * fma(S0.f64, S1.f64, S2.f64))'),
|
||
'V_DIV_FIXUP_F32': ('D0.f32 = sign_out ? -abs(S0.f32) : abs(S0.f32)',
|
||
'D0.f32 = isNAN(S0.f32) ? (sign_out ? -INF.f32 : +INF.f32) : (sign_out ? -abs(S0.f32) : abs(S0.f32))'),
|
||
'V_DIV_FIXUP_F64': ('D0.f64 = sign_out ? -abs(S0.f64) : abs(S0.f64)',
|
||
'D0.f64 = isNAN(S0.f64) ? (sign_out ? -INF : +INF) : (sign_out ? -abs(S0.f64) : abs(S0.f64))'),
|
||
'V_TRIG_PREOP_F64': ("result = 64'F((1201'B(2.0 / PI)[1200 : 0] << shift.u32) & 1201'0x1fffffffffffff)", "result = trig_preop_result(shift)"),
|
||
}
|
||
|
||
def _get_pcode_dict(op) -> dict:
  """Return the PCODE dictionary for the given opcode, chosen from the module its enum class lives in."""
  module = type(op).__module__
  if 'cdna' in module: return PCODE_CDNA
  if 'rdna4' in module: return PCODE_RDNA4
  return PCODE_RDNA3
|
||
|
||
# Pcode parser
|
||
@functools.cache
def get_pcode(op) -> str:
  """Return the pcode text for `op` with this emulator's textual patches applied.

  Steps: look the op up in its architecture's PCODE dict (an _E64 op missing there falls back to the
  same-named VOP1 _E32 op), apply the matching entry of _pcode_fixes, and for V_DIV_SCALE ops apply an
  ordered series of string rewrites. Raises KeyError when no pcode exists for the op.
  """
  # NOTE(review): nesting was reconstructed from a copy of the file with indentation stripped; the final
  # VCC rewrite is placed inside the V_DIV_SCALE branch - confirm against the upstream file.
  op_name = op.name
  pcode_dict = _get_pcode_dict(op)
  if op not in pcode_dict and op_name.endswith('_E64'):
    # VOP3 ops ending in _E64 may share pcode with VOP1 _E32 equivalents
    import importlib
    vop1_cls = getattr(importlib.import_module(type(op).__module__), 'VOP1Op', None)
    e32_name = op_name.replace('_E64', '_E32')
    if vop1_cls and hasattr(vop1_cls, e32_name): op = vop1_cls[e32_name]
  pcode = pcode_dict[op]
  fix_name = op_name.replace('_E64', '').replace('_E32', '')
  if fix_name in _pcode_fixes: pcode = pcode.replace(*_pcode_fixes[fix_name])
  if 'V_DIV_SCALE' not in op_name: return pcode

  if 'F32' in op_name: dt, exp_lim, ldexp_val = 'f32', '23', '64'
  else: dt, exp_lim, ldexp_val = 'f64', '52', '128'
  # applied strictly in this order: later patterns match text produced by earlier ones
  rewrites = [
    (f'S2.{dt} / S1.{dt} == DENORM.{dt}', f'divWouldBeDenorm(S2.{dt}, S1.{dt})'),
    (f"1.0 / 64'F(S1.{dt}) == DENORM.f64", '0'),
    (f'1.0 / S1.{dt} == DENORM.{dt}', '0'),
    (f'S1.{dt} == DENORM.{dt}', f'isDENORM(S1.{dt})'),
    (f'D0.{dt} = NAN.{dt}', f'VCC = 0x1LL;\nD0.{dt} = NAN.{dt}'),
    (f'elsif isDENORM(S1.{dt}) then\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})', f'elsif 1 == 0 then\nD0.{dt} = S0.{dt}'),
    (f'elsif exponent(S2.{dt}) <= {exp_lim} then\n// Numerator is tiny\n'
     f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})',
     f'elsif exponent(S2.{dt}) <= {exp_lim} then\nVCC = 0x1LL;\n'
     f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})'),
    (f'elsif divWouldBeDenorm(S2.{dt}, S1.{dt}) then\nVCC = 0x1LL;\n'
     f'if S0.{dt} == S2.{dt} then\n// Only scale the numerator\n'
     f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif',
     f'elsif divWouldBeDenorm(S2.{dt}, S1.{dt}) then\n'
     f'VCC = 0x1LL;\nD0.{dt} = S0.{dt}'),
    (f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif\nelsif',
     f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nelse\n'
     f'D0.{dt} = S0.{dt}\nendif\nelsif'),
  ]
  for old, new in rewrites:
    pcode = pcode.replace(old, new)

  # add a default `else` branch just before the last bare `endif` line, if there is one
  lines = pcode.rstrip().split('\n')
  last_endif = next((i for i in reversed(range(len(lines))) if lines[i].strip() == 'endif'), None)
  if last_endif is not None: lines.insert(last_endif, f'else\nD0.{dt} = S0.{dt}')
  pcode = '\n'.join(lines) + f';\nif isDENORM(S1.{dt}) then\nD0.{dt} = NAN.{dt}\nendif'
  # whole-register VCC assignments become per-lane bit writes
  return pcode.replace('VCC = 0x0LL', 'VCC.u64[laneId] = 0').replace('VCC = 0x1LL', 'VCC.u64[laneId] = 1')
|
||
|
||
def parse_pcode(pcode: str, srcs: dict[str, UOp | int] | None = None) -> tuple[dict, list[tuple[str, UOp]]]:
  """Parse pcode text with parse_block and return (env, assigns).

  `srcs` seeds the environment (copied, not mutated here). Blank lines and `//` comment lines are
  dropped, trailing `;` stripped, and a line is glued onto the previous one when that one ends in a
  binary operator. After parsing, final UOp values of the tracked output variables are appended to
  `assigns`, typed as `VAR.suffix` when some line starts with that form, else under the bare name.
  """
  env: dict = srcs.copy() if srcs else {}
  assigns: list[tuple[str, UOp]] = []
  # TODO: pcode.py should tokenize full pcode string instead of line-by-line, then this hack can be removed
  stmts: list[str] = []
  for raw in pcode.split('\n'):
    text = raw.strip()
    if not text or text.startswith('//'): continue
    text = text.rstrip(';')
    if stmts and re.search(r'(&&|\|\||[&|+\-*/^])\s*$', stmts[-1]): stmts[-1] = stmts[-1] + ' ' + text
    else: stmts.append(text)
  _, final, _ = parse_block(stmts, 0, env, assigns=assigns)
  # variables that were written through a bit-slice / index destination
  sliced = {dest.split('[')[0] for dest, _ in assigns if '[' in dest}
  tracked = ('D0', 'S0', 'SCC', 'VCC', 'EXEC', 'PC', 'RETURN_DATA', 'VDATA')
  for var, val in final.items():
    if var not in tracked or not isinstance(val, UOp): continue
    # only sliced writes and no whole-variable typed assignment: nothing more to record
    if var in sliced and not any(re.match(rf'{var}\.\w+\s*=', stmt) for stmt in stmts): continue
    hit = None
    for stmt in stmts:
      hit = re.match(rf'{var}\.(\w+(?:\[\w+\])?)', stmt)
      if hit: break
    assigns.append((f'{var}.{hit.group(1)}', val) if hit else (var, val))
  return env, assigns
|
||
|
||
def _write_64bit(val: UOp, wfn, reg_or_addr, is_mem: bool, *args) -> list[UOp]:
  """Write a 64-bit value as two 32-bit writes (low half first) through `wfn`.

  `wfn` is called as wfn(reg_or_addr, half, *args). The second write targets reg_or_addr + 4 for
  memory addresses or + 1 for register indices; a UOp target gets a constant of its own dtype.
  """
  lo, hi = _split64(val)
  step = 4 if is_mem else 1  # 4 bytes for memory addresses, 1 for register indices
  if isinstance(reg_or_addr, UOp): step = UOp.const(reg_or_addr.dtype, step)
  lo_write = wfn(reg_or_addr, lo, *args)
  hi_write = wfn(reg_or_addr + step, hi, *args)
  return [lo_write, hi_write]
|
||
|
||
def _write_val(bits: int, val: UOp, wfn, reg_or_addr, *args, is_mem: bool = False) -> list[UOp]:
  """Write `val` through `wfn`: bits == 64 is split into two 32-bit writes, anything else is one uint32 write."""
  if bits != 64: return [wfn(reg_or_addr, _to_u32(val), *args)]
  return _write_64bit(val, wfn, reg_or_addr, is_mem, *args)
|
||
|
||
def _mem_store(mem: UOp, addr: UOp, val: UOp, active: UOp, addr_bits: int = 32, data_bits: int = 32) -> list[UOp]:
  """Conditional memory store with sub-word support. Returns list of store UOps.

  `mem` is indexed in 32-bit words (addr >> 2). 32-bit data is written whole; 8- and 16-bit data is
  merged into the existing word, and a 16-bit value at byte offset 3 also updates the following word.
  Inactive lanes write the word's current value back.
  """
  adt = dtypes.uint64 if addr_bits == 64 else dtypes.uint32
  word_addr = addr >> UOp.const(adt, 2)
  word = mem.index(word_addr.cast(dtypes.int), active)
  if data_bits == 32: return [word.store(active.where(_to_u32(val), word))]

  # Sub-word store: read-modify-write with mask
  byte_pos = addr.cast(dtypes.uint32) & _c(3)
  shift = byte_pos * _c(8)
  val_u32 = val.cast(dtypes.uint32)
  size_mask = _c(0xFF if data_bits == 8 else 0xFFFF)
  merged = (word & ((size_mask << shift) ^ _c(0xFFFFFFFF))) | ((val_u32 & size_mask) << shift)
  if data_bits == 8: return [word.store(active.where(merged, word))]

  # 16-bit cross-word case: byte_pos == 3 means value spans two words
  crosses = byte_pos.eq(_c(3))
  first_cross = (word & _c(0x00FFFFFF)) | ((val_u32 & _c(0xFF)) << _c(24))   # low byte into the top of this word
  first_store = word.store(active.where(crosses.where(first_cross, merged), word))
  next_word = mem.index((word_addr + UOp.const(adt, 1)).cast(dtypes.int), active & crosses)
  second_cross = (next_word & _c(0xFFFFFF00)) | ((val_u32 >> _c(8)) & _c(0xFF))  # high byte into the bottom of the next
  return [first_store, next_word.store((active & crosses).where(second_cross, next_word))]
|
||
|
||
def _mem_store_bytes(mem: UOp, addr: UOp, val: UOp, active: UOp, data_bits: int = 32) -> list[UOp]:
  """Store to byte-addressable memory (scratch). addr is byte offset, mem is uint8 buffer.

  Emits one gated uint8 store per byte, least-significant byte at `addr`.
  """
  # NOTE(review): non-uint32 values are value-cast here (not bitcast), unlike _to_u32 - confirm callers never pass floats
  val_u32 = val if val.dtype == dtypes.uint32 else val.cast(dtypes.uint32)

  def store_byte(i: int) -> UOp:
    byte_val = (val_u32 >> UOp.const(dtypes.uint32, i * 8)) & UOp.const(dtypes.uint32, 0xFF)
    slot = mem.index((addr + UOp.const(dtypes.uint64, i)).cast(dtypes.int), active)
    return slot.store(byte_val.cast(dtypes.uint8))

  return [store_byte(i) for i in range(data_bits // 8)]
|
||
|
||
def _collect_data_slices(assigns: list[tuple[str, UOp]], data_prefix: str, pcode_vars: dict | None = None, op_name: str = "") -> dict[int, UOp]:
  """Collect bit slices from assigns into {dword_idx: value} dict.

  `PREFIX[hi:lo]` destinations land at dword lo // 32; any other destination starting with the prefix
  lands at dword 0. Later assigns overwrite earlier ones for the same dword.
  """
  slices: dict[int, UOp] = {}
  bracketed = f'{data_prefix}['
  for dest, val in assigns:
    if not dest.startswith(bracketed):
      if dest.startswith(data_prefix): slices[0] = _to_u32(val)
      continue
    m = re.match(rf'{data_prefix}\[(\d+)\s*:\s*(\d+)\]', dest)
    if not m: continue  # bracketed but not a hi:lo slice: ignored
    hi_bit, low_bit = int(m.group(1)), int(m.group(2))
    dword_idx = low_bit // 32
    # D16 loads preserve bits - use final value from pcode_vars which has hi bits preserved
    if pcode_vars and 'D16' in op_name and dword_idx == 0 and hi_bit < 32:
      slices[0] = _to_u32(pcode_vars.get(data_prefix, val))
    else:
      slices[dword_idx] = _to_u32(val)
  return slices
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
# INSTRUCTION COMPILER - converts decoded instruction to UOp SINK
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
|
||
class _Ctx:
|
||
"""Context for instruction compilation - holds buffers and helpers."""
|
||
__slots__ = ('inst_size', 'dyn_fields', '_axis_id', 'wave_size', 'vgpr', 'accvgpr')
|
||
sgpr = UOp(Ops.PARAM, dtypes.uint32.ptr(SGPR_COUNT), arg=0)
|
||
vmem = UOp(Ops.PARAM, dtypes.uint32.ptr(1 << 46), arg=2)
|
||
lds = UOp(Ops.PARAM, dtypes.uint32.ptr(16384), arg=3)
|
||
scratch = UOp(Ops.PARAM, dtypes.uint8.ptr(1 << 30), arg=4)
|
||
# Cache PARAM UOps by wave_size so all _Ctx instances with same wave_size share identical UOp references
|
||
_vgpr_cache: dict[int, UOp] = {}
|
||
_accvgpr_cache: dict[int, UOp] = {}
|
||
|
||
def __init__(self, inst_size: int, wave_size: int = 32):
  """Set up per-instruction state and bind the register-file PARAMs for this wave size.

  The VGPR PARAM (arg=1) is cached per wave_size on the class. For wave_size == 64 a separate
  ACCVGPR PARAM (arg=5) is cached and used; otherwise accvgpr is the same object as vgpr.
  """
  self.inst_size = inst_size
  self._axis_id = 0
  self.wave_size = wave_size
  self.dyn_fields: list[tuple[int, int]] = []  # (lo, hi) of fields read dynamically
  vgprs = _Ctx._vgpr_cache
  if wave_size not in vgprs: vgprs[wave_size] = UOp(Ops.PARAM, dtypes.uint32.ptr(256 * wave_size), arg=1)
  self.vgpr = vgprs[wave_size]
  if wave_size != 64:
    self.accvgpr = self.vgpr
    return
  accs = _Ctx._accvgpr_cache
  if wave_size not in accs: accs[wave_size] = UOp(Ops.PARAM, dtypes.uint32.ptr(256 * wave_size), arg=5)
  self.accvgpr = accs[wave_size]
|
||
|
||
def range(self, n: int | None = None) -> UOp:
  """Create a LOOP range UOp of length `n` (default: wave_size) with a fresh axis ID."""
  self._axis_id += 1
  length = self.wave_size if n is None else n
  return UOp.range(length, self._axis_id, AxisType.LOOP, dtype=dtypes.int)
|
||
|
||
def unroll_lanes(self, get_lane_bit, exec_mask: UOp, apply_exec: bool = True) -> UOp:
  """Combine lane bits into a mask using RANGE+REDUCE (32-bit for RDNA, 64-bit for CDNA).

  get_lane_bit(lane) supplies each lane's bit; the result is ANDed with exec_mask unless apply_exec is False.
  """
  lane = self.range()
  mask_dt = dtypes.uint32 if self.wave_size <= 32 else dtypes.uint64
  shifted = get_lane_bit(lane).cast(mask_dt) << lane.cast(mask_dt)
  combined = shifted.reduce(lane, arg=Ops.ADD)
  if apply_exec: return combined & exec_mask
  return combined
|
||
|
||
def inst_word(self, dword_idx: int) -> UOp:
  """Read instruction dword from vmem at PC + dword_idx*4."""
  addr = self.rpc()
  if dword_idx != 0: addr = addr + UOp.const(dtypes.uint64, dword_idx * 4)
  word_index = (addr >> UOp.const(dtypes.uint64, 2)).cast(dtypes.int)  # vmem is indexed in 32-bit words
  return self.vmem.index(word_index, ptr=True).load()
|
||
|
||
def inst_field(self, field) -> UOp:
  """Extract field bits from instruction encoding. Tracks field for canonical key computation.

  `field` provides inclusive bit bounds `lo` and `hi`; a field may sit in one dword or straddle two.
  """
  lo, hi = field.lo, field.hi
  self.dyn_fields.append((lo, hi))
  first_dword, lo_off = divmod(lo, 32)
  word = self.inst_word(first_dword)
  u32 = dtypes.uint32
  if first_dword == hi // 32:  # Same dword
    if lo_off != 0: word = word >> UOp.const(u32, lo_off)
    return word & UOp.const(u32, (1 << (hi - lo + 1)) - 1)
  # Spans two dwords: low piece from the top of this dword, high piece from the bottom of the next
  lo_bits = 32 - lo_off
  lo_part = (word >> UOp.const(u32, lo_off)) & UOp.const(u32, (1 << lo_bits) - 1)
  hi_part = self.inst_word(first_dword + 1) & UOp.const(u32, (1 << (hi % 32 + 1)) - 1)
  return lo_part | (hi_part << UOp.const(u32, lo_bits))
|
||
|
||
def inst_field_signed(self, field) -> UOp:
  """Extract field and sign-extend based on field width."""
  sign = _c(1 << (field.hi - field.lo), dtypes.int)  # top bit of the field
  raw = self.inst_field(field).cast(dtypes.int)
  # (x ^ s) - s sign-extends an unsigned field whose top bit is s
  return (raw ^ sign) - sign
|
||
|
||
def canonical_mask(self, inst_bytes: bytes) -> tuple[int, int, int]:
  """Compute canonical (base, mask, size) for cache lookup.

  base = instruction bits with dynamic fields zeroed
  mask = bitmask with 1s for static bits, 0s for dynamic bits
  size = instruction size in bytes"""
  size = self.inst_size
  dynamic = 0
  for lo, hi in self.dyn_fields:
    dynamic |= ((1 << (hi - lo + 1)) - 1) << lo  # every bit covered by some dynamic field
  all_bits = (1 << (size * 8)) - 1
  encoded = int.from_bytes(inst_bytes[:size], 'little')
  return encoded & ~dynamic, all_bits & ~dynamic, size
|
||
|
||
def rexec(self) -> UOp:
  """Read full EXEC mask (32-bit for RDNA, 64-bit for CDNA)."""
  lo = self.rsgpr_dyn(_c(EXEC_LO.offset))
  if self.wave_size > 32:
    # wide waves keep the upper half in the next SGPR
    return _u64(lo, self.rsgpr_dyn(_c(EXEC_LO.offset + 1)))
  return lo
|
||
|
||
# Dynamic register access (takes UOp index instead of int)
|
||
def rsgpr_dyn(self, reg: UOp, valid: UOp | None = None) -> UOp:
  """Read SGPR with dynamic register index; `valid`, when given, gates the index."""
  idx = reg.cast(dtypes.int)
  ptr = self.sgpr.index(idx, ptr=True) if valid is None else self.sgpr.index(idx, valid, ptr=True)
  return ptr.load()
|
||
|
||
def wsgpr_dyn(self, reg: UOp, val: UOp) -> UOp:
  """Write SGPR with dynamic register index. On RDNA, index 124 = NULL (writes discarded). On CDNA, index 124 = M0 (read/write)."""
  # wave_size == 64 is used as the CDNA test: only then is index 124 left writable
  valid = reg.ne(_c(124)) if self.wave_size != 64 else None
  return self.sgpr.index(reg.cast(dtypes.int), valid).store(val.cast(dtypes.uint32))
|
||
|
||
def wmask(self, reg: UOp, val: UOp) -> list[UOp]:
  """Write a lane mask (VCC/EXEC). Splits into lo/hi for wave64."""
  if self.wave_size <= 32: return [self.wsgpr_dyn(reg, val)]
  lo, hi = _split64(val)
  return [self.wsgpr_dyn(reg, lo), self.wsgpr_dyn(reg + _c(1), hi)]
|
||
|
||
def rmask(self, reg: UOp) -> UOp:
  """Read a lane mask (VCC/EXEC). Combines lo/hi for wave64."""
  if self.wave_size <= 32: return self.rsgpr_dyn(reg)
  return _u64(self.rsgpr_dyn(reg), self.rsgpr_dyn(reg + _c(1)))
|
||
|
||
def rvgpr_dyn(self, reg: UOp, lane: UOp, valid: UOp | None = None) -> UOp:
  """Read VGPR with dynamic register index (element reg * wave_size + lane); `valid` optionally gates the index."""
  idx = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
  ptr = self.vgpr.index(idx, ptr=True) if valid is None else self.vgpr.index(idx, valid, ptr=True)
  return ptr.load()
|
||
|
||
def wvgpr_dyn(self, reg: UOp, lane: UOp, val: UOp, exec_mask: UOp, after: UOp | None = None) -> UOp:
  """Write VGPR with dynamic register index; the store is gated on the lane's exec_mask bit.

  When `after` is given, the buffer is taken as vgpr.after(after).
  """
  target = self.vgpr if after is None else self.vgpr.after(after)
  slot = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
  return target.index(slot, _lane_active(exec_mask, lane)).store(val.cast(dtypes.uint32))
|
||
|
||
def raccvgpr_dyn(self, reg: UOp, lane: UOp, valid: UOp | None = None) -> UOp:
  """Load one ACCVGPR element (CDNA only) addressed by a dynamic register index and lane.

  The flat buffer position is `reg * wave_size + lane`. `valid`, when given, is forwarded to the index as a gate.
  """
  flat = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
  if valid is None: return self.accvgpr.index(flat, ptr=True).load()
  return self.accvgpr.index(flat, valid, ptr=True).load()
|
||
|
||
def waccvgpr_dyn(self, reg: UOp, lane: UOp, val: UOp, exec_mask: UOp, after: UOp | None = None) -> UOp:
  """Store `val` (cast to uint32) into the ACCVGPR buffer (CDNA only) at `reg * wave_size + lane`.

  The store is gated by `_lane_active(exec_mask, lane)`. When `after` is given, the buffer is first
  wrapped with `.after(after)` before indexing.
  """
  target = self.accvgpr if after is None else self.accvgpr.after(after)
  flat = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
  lane_on = _lane_active(exec_mask, lane)
  return target.index(flat, lane_on).store(val.cast(dtypes.uint32))
|
||
|
||
def rsrc_dyn(self, off: UOp, lane: UOp | None, bits: int = 32, literal: UOp | None = None, is_f64: bool = False, do_cast: bool = True) -> UOp:
  """Read a source operand whose encoding offset is only known as a UOp.

  Offsets below 256 address the SGPR buffer (registers and inline constants); offsets >= 256 address VGPRs.

  off:     raw source encoding.
  lane:    lane to read for VGPR sources. If None, only the scalar path is built (off must be < 256).
  bits:    operand width. 64 combines two dwords; 16 optionally converts float inline constants to F16.
  literal: value substituted when off == 255.
  is_f64:  True for F64 operations, where the literal is placed in the high 32 bits of the 64-bit value.
  do_cast: for bits == 16, whether float inline constants (240..248) are converted from F32 to F16.
  """
  is_vgpr = off >= _c(256)
  is_sgpr = is_vgpr.ne(True)
  is_float_const = (off >= _c(240)) & (off <= _c(248))
  sgpr_lo = self.rsgpr_dyn(off, is_sgpr)

  if bits != 64:
    scalar_val = sgpr_lo
    if literal is not None: scalar_val = off.eq(_c(255)).where(literal, scalar_val)
    if bits == 16 and do_cast:
      # float inline constants are stored as F32 bit patterns: narrow them to F16 bits
      as_f16 = scalar_val.bitcast(dtypes.float32).cast(dtypes.half).bitcast(dtypes.uint16).cast(dtypes.uint32)
      scalar_val = is_float_const.where(as_f16, scalar_val)
  else:
    sgpr_val = _u64(sgpr_lo, self.rsgpr_dyn(off + _c(1), is_sgpr))
    # integer inline constants are sign-extended to 64 bits, float inline constants are widened F32 -> F64
    int_inline = sgpr_lo.cast(dtypes.int32).cast(dtypes.int64)
    float_inline = sgpr_lo.bitcast(dtypes.float32).cast(dtypes.float64)
    inline = is_float_const.where(float_inline.bitcast(dtypes.uint64), int_inline.bitcast(dtypes.uint64))
    if literal is not None:
      # F64 VOP places the literal in the high dword; B64/I64/U64 VOP and SOP zero-extend it
      if is_f64: lit_val = literal.cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32)
      else: lit_val = literal.cast(dtypes.uint64)
      inline = off.eq(_c(255)).where(lit_val, inline)
    scalar_val = (off < _c(128)).where(sgpr_val, inline)

  if lane is None: return scalar_val

  vgpr_reg = off - _c(256)
  vgpr_val = self.rvgpr_dyn(vgpr_reg, lane, is_vgpr)
  if bits == 64: vgpr_val = _u64(vgpr_val, self.rvgpr_dyn(vgpr_reg + _c(1), lane, is_vgpr))
  return is_vgpr.where(vgpr_val, scalar_val)
|
||
|
||
def rpc(self) -> UOp:
  """Load the program counter as one 64-bit value (byte address)."""
  # point at PC_LO in the SGPR buffer, then reinterpret the pointer as uint64 so one load covers both halves
  pc_slot = self.sgpr.index(_c(PC_LO_IDX, dtypes.int), ptr=True)
  return pc_slot.cast(dtypes.uint64.ptr(SGPR_COUNT // 2)).load()
|
||
|
||
def inc_pc(self) -> list[UOp]:
  """Advance the PC by `self.inst_size` bytes. Returns a one-element list holding the store UOp."""
  advanced = self.rpc() + UOp.const(dtypes.uint64, self.inst_size)
  pc_ptr = self.sgpr.index(_c(PC_LO_IDX, dtypes.int), ptr=True).cast(dtypes.uint64.ptr(SGPR_COUNT // 2))
  return [pc_ptr.store(advanced)]
|
||
|
||
def scalar_stores(self, assigns: list[tuple[str, UOp]], sdst_reg: UOp, sdst_size: int = 1) -> list[UOp]:
  """Turn parsed scalar assignments into SGPR stores.

  Destinations are matched by name prefix: D0 (written to `sdst_reg`, two dwords when `sdst_size == 2`),
  SCC, EXEC and VCC. Assignments to any other destination produce no store.
  """
  out: list[UOp] = []
  for dest, val in assigns:
    if dest.startswith('D0'):
      if sdst_size != 2:
        out.append(self.wsgpr_dyn(sdst_reg, _val_to_u32(val)))
        continue
      low, high = _split64(val)
      out += [self.wsgpr_dyn(sdst_reg, low), self.wsgpr_dyn(sdst_reg + _c(1), high)]
    elif dest.startswith('SCC'):
      out.append(self.wsgpr_dyn(_c(SCC.offset), _to_u32(val)))
    elif dest.startswith('EXEC'):
      is_wide = self.wave_size > 32 and val.dtype in (dtypes.uint64, dtypes.int64)
      if is_wide:
        low, high = _split64(val)
        out += [self.wsgpr_dyn(_c(EXEC_LO.offset), low), self.wsgpr_dyn(_c(EXEC_LO.offset + 1), high)]
      else:
        out.append(self.wsgpr_dyn(_c(EXEC_LO.offset), _to_u32(val)))
    elif dest.startswith('VCC'):
      out.extend(self.wmask(_c(VCC_LO.offset), val))
  return out
|
||
|
||
def compile_sop_pcode(self, op, srcs: dict[str, UOp | int], sdst_reg: UOp, sdst_size: int) -> UOp:
  """Compile one scalar instruction from its pcode, with a dynamic destination register.

  `srcs` is updated in place with VCC, EXEC, SCC and `_wave_size`, plus D0 when the caller did not supply it.
  Returns a sink of the resulting scalar stores followed by the PC increment.
  """
  pcode = get_pcode(op)
  implicit = {'VCC': self.rmask(_c(VCC_LO.offset)), 'EXEC': self.rexec(), 'SCC': self.rsgpr_dyn(_c(SCC.offset)),
              '_wave_size': self.wave_size}
  srcs.update(implicit)
  # D0 defaults to the current destination value so read-modify-write ops see it
  if 'D0' not in srcs: srcs['D0'] = self.rsgpr_dyn(sdst_reg)
  assigns = parse_pcode(pcode, srcs)[1]
  stores = self.scalar_stores(assigns, sdst_reg, sdst_size)
  return UOp.sink(*stores, *self.inc_pc())
|
||
|
||
def compile_lane_pcode(self, op, inst) -> UOp:
  """Compile cross-lane ops (READLANE/WRITELANE/PERMLANE) using pcode parser.

  op:   opcode whose pcode is looked up via get_pcode.
  inst: decoded instruction; only its type's field descriptors (src0, vdst, optional src1/src2) are used here.

  Returns a sink of the generated stores followed by the PC increment.
  """
  pcode = get_pcode(op)
  op_name = op.name if hasattr(op, 'name') else str(op)
  src0_off, vdst_off = self.inst_field(type(inst).src0), self.inst_field(type(inst).vdst)
  src0_reg = (src0_off >= _c(256)).where(src0_off - _c(256), _c(0)) # VGPR index or 0
  # src1/src2 exist only on some encodings; missing ones are represented as None / constant 0
  src1_off = self.inst_field(type(inst).src1) if hasattr(type(inst), 'src1') else None
  src2_off = self.inst_field(type(inst).src2) if hasattr(type(inst), 'src2') else None
  # unlike src0_reg, a non-VGPR src1/src2 keeps its raw encoding instead of collapsing to 0
  src1_reg = (src1_off >= _c(256)).where(src1_off - _c(256), src1_off) if src1_off is not None else _c(0)
  src2_reg = (src2_off >= _c(256)).where(src2_off - _c(256), src2_off) if src2_off is not None else _c(0)
  exec_val = self.rexec()
  # EXEC_LO is always 32-bit and EXEC always 64-bit, whichever width rexec() returned
  exec_lo = exec_val.cast(dtypes.uint32) if exec_val.dtype == dtypes.uint64 else exec_val
  srcs = {
    'SRC0': src0_reg, 'VDST': vdst_off, 'EXEC_LO': exec_lo, 'EXEC': exec_val if exec_val.dtype == dtypes.uint64 else exec_val.cast(dtypes.uint64),
    '_vgpr': self.vgpr, '_wave_size': self.wave_size, 'SRC1': src1_reg, 'SRC2': src2_reg,
    # WRITELANE gets S0 as an operand value (read at lane 0); other ops get the src0 register index
    'S0': self.rsrc_dyn(src0_off, _c(0, dtypes.int)) if 'WRITELANE' in op_name else src0_reg,
    'S1': self.rsrc_dyn(src1_off, _c(0, dtypes.int)) if src1_off is not None else _c(0),
    'S2': self.rsrc_dyn(src2_off, _c(0, dtypes.int)) if src2_off is not None else _c(0),
  }
  _, assigns = parse_pcode(pcode, srcs)
  stores = []
  for dest, val in assigns:
    # D0 results are stored to the SGPR buffer at the vdst field value
    if dest.startswith('D0'): stores.append(self.wsgpr_dyn(vdst_off, val.cast(dtypes.uint32)))
    # VGPR[...] results arrive as a pair: val[0] is used as the flat VGPR buffer index, val[1] as the value
    elif dest.startswith('VGPR['): stores.append(self.vgpr.index(val[0].cast(dtypes.int)).store(val[1].cast(dtypes.uint32)))
  return UOp.sink(*stores, *self.inc_pc())
|
||
|
||
def compile_vop_pcode(self, op, srcs: dict[str, UOp | int], lane: UOp, vdst_reg: UOp, exec_mask: UOp,
                      opsel_dst_hi: bool | UOp = False, sdst_reg: int | None = None, clmp: int = 0,
                      src0_off: UOp | None = None) -> UOp:
  """Compile VOP instruction. Returns sink with stores and inc_pc.

  op:           opcode; its pcode is fetched with get_pcode and its name drives the clamp handling.
  srcs:         pcode variables supplied by the caller; updated in place with VCC/EXEC/SCC/laneId and constants.
  lane:         lane loop variable; per-lane VGPR stores are closed with `.end(lane)`.
  vdst_reg:     destination VGPR index.
  exec_mask:    EXEC value used to gate per-lane stores.
  opsel_dst_hi: for 16-bit results, write the high half of the destination (static bool or dynamic UOp).
  sdst_reg:     SGPR receiving the VCC-style mask result; defaults to VCC_LO.
  clmp:         non-zero enables saturation (integers) or clamping to [0, 1] (floats).
  src0_off:     raw src0 encoding, needed only for pcode that writes back to S0.
  """
  pcode = get_pcode(op)
  vcc_reg = sdst_reg if sdst_reg is not None else VCC_LO.offset
  if 'VCC' not in srcs: srcs['VCC'] = self.rmask(_c(vcc_reg))
  srcs.update({'EXEC': exec_mask, 'SCC': self.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane, 'VDST': vdst_reg,
               'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0), 'ROUND_NEAREST_EVEN': _c(0), '_vgpr': self.vgpr, '_wave_size': self.wave_size,
               'MAX_FLOAT_F32': UOp.const(dtypes.float32, 3.4028234663852886e38),
               # CDNA SDWA byte/word select constants (E32 always uses BYTE0/WORD0 defaults)
               'SDWA_SRC0_SEL': _c(0), 'BYTE0': _c(0), 'BYTE1': _c(1), 'BYTE2': _c(2), 'BYTE3': _c(3),
               'WORD0': _c(0), 'WORD1': _c(1)})  # rounding mode and SDWA constants
  _, assigns = parse_pcode(pcode, srcs)

  # For integer ops with clamp, compute overflow using wide arithmetic
  # NOTE: MUL_LO ops don't saturate - they always return the low bits
  int_saturate = None
  if clmp and any(p in op.name for p in ('_NC_U', '_MAD_U', '_NC_I', '_MAD_I')):
    is_signed, is_16bit = '_I' in op.name and '_U' not in op.name, '16' in op.name
    if not (is_16bit and is_signed):  # Skip 16-bit signed ops due to codegen issues
      s0, s1, s2 = srcs.get('S0'), srcs.get('S1'), srcs.get('S2')
      if s0 is not None and s1 is not None:
        narrow_dt = dtypes.uint16 if is_16bit else (dtypes.int32 if is_signed else dtypes.uint32)
        wide_dt = dtypes.int32 if is_16bit else dtypes.int64
        narrow_max, narrow_min = (0xFFFF, 0) if is_16bit else ((0x7FFFFFFF, -0x80000000) if is_signed else (0xFFFFFFFF, 0))
        def to_wide(x): return (x.bitcast(narrow_dt) if x.dtype.itemsize == narrow_dt.itemsize else x.cast(narrow_dt)).cast(wide_dt)
        is_sub, is_mad = 'SUB' in op.name, 'MAD' in op.name
        full = (to_wide(s0) * to_wide(s1) + to_wide(s2)) if is_mad and s2 is not None else \
               (to_wide(s1) - to_wide(s0)) if is_sub and 'SUBREV' in op.name else \
               (to_wide(s0) - to_wide(s1)) if is_sub else (to_wide(s0) + to_wide(s1))
        int_saturate = full.clamp(narrow_min, narrow_max).cast(narrow_dt)
  # V_SUB / V_SUBREV / V_ADD U32 and U16 with clamp: unsigned saturate (SUB underflow->0, ADD overflow->MAX).
  # The SUBREV names are listed explicitly: "_SUBREV_U32" does not contain "_SUB_U32", so without them the
  # operand swap below could never be reached.
  if clmp and int_saturate is None and any(p in op.name for p in ('_SUB_U32', '_ADD_U32', '_SUB_U16', '_ADD_U16',
                                                                  '_SUBREV_U32', '_SUBREV_U16')):
    s0, s1 = srcs.get('S0'), srcs.get('S1')
    if s0 is not None and s1 is not None:
      assert isinstance(s0, UOp) and isinstance(s1, UOp)
      a, b = (s1.cast(dtypes.uint32), s0.cast(dtypes.uint32)) if 'SUBREV' in op.name else (s0.cast(dtypes.uint32), s1.cast(dtypes.uint32))
      sat_16bit = 'U16' in op.name
      # 16-bit ops only see the low half of each operand and saturate at 0xFFFF, not 0xFFFFFFFF
      if sat_16bit: a, b = a & _c(0xFFFF), b & _c(0xFFFF)
      if 'SUB' in op.name:
        int_saturate = (a < b).where(_c(0), a - b)  # underflow -> 0
      else:
        raw_sum = a + b
        if sat_16bit: int_saturate = (_c(0xFFFF) < raw_sum).where(_c(0xFFFF), raw_sum)  # the 17-bit sum cannot wrap in uint32
        else: int_saturate = (raw_sum < a).where(_c(0xFFFFFFFF), raw_sum)  # overflow -> MAX
      # keep the result 16-bit typed so the store below takes the half-register (opsel-aware) path
      if sat_16bit: int_saturate = int_saturate.cast(dtypes.uint16)

  raw_stores: list = []
  vcc_val, exec_val = None, None
  for dest, val in assigns:
    # VGPR bit-slice assignment: VGPR[lane][reg][hi:lo] = (vgpr_idx, rhs_val, hi, lo[, cond]) -> read-modify-write
    if dest.startswith('VGPR[') and re.search(r'\[\d+:\d+\]', dest):
      # VGPR bit-slice: (vgpr_idx, rhs_val, hi_bit, lo_bit) - hi/lo are UOp constants
      hi_bit, lo_bit = int(val[2].arg), int(val[3].arg)
      width = hi_bit - lo_bit + 1
      old = self.vgpr.index(val[0].cast(dtypes.int), ptr=True).load()
      new_val = _set_bits(old, _val_to_bits(val[1]), width, lo_bit).cast(dtypes.uint32)
      active = _lane_active(exec_mask, lane)
      raw_stores.append(('vgpr_direct', self.vgpr.index(val[0].cast(dtypes.int), active).store(new_val)))
      continue
    if 'D0' in dest and '[laneId]' in dest:
      # NOTE(review): entries tagged 'vcc' are not picked up by any of the filters after this loop
      old_vcc = self.rmask(_c(VCC_LO.offset))
      new_vcc = _set_lane_bit(old_vcc, lane, val, exec_mask)
      raw_stores.extend([('vcc', s) for s in self.wmask(_c(VCC_LO.offset), new_vcc)])
    elif dest.startswith('D0'):
      dest_suffix = re.match(r'D0\.(\w+)', dest)
      if dest_suffix is not None:
        target_dt = {'u16': dtypes.uint16, 'i16': dtypes.int16, 'f16': dtypes.half}.get(dest_suffix.group(1))
        if target_dt is not None and val.dtype != target_dt: val = val.cast(target_dt)
      if (slice_match := re.match(r'D0\[(\d+)\s*:\s*(\d+)\]', dest)):
        d0_hi_bit, d0_lo_bit = int(slice_match.group(1)), int(slice_match.group(2))
        if d0_hi_bit != 31 or d0_lo_bit != 0:
          d0_width, slice_mask = d0_hi_bit - d0_lo_bit + 1, (1 << (d0_hi_bit - d0_lo_bit + 1)) - 1
          val_bits = val.bitcast(dtypes.uint16).cast(dtypes.uint32) if val.dtype == dtypes.half else \
                     val.cast(dtypes.uint32) if val.dtype in (dtypes.uint16, dtypes.int16) else \
                     val.cast(dtypes.uint32) & UOp.const(dtypes.uint32, slice_mask)
          # partial-register writes are merged into one read-modify-write after the loop
          raw_stores.append(('vgpr_slice', (d0_lo_bit, d0_width, val_bits)))
          continue
      # For integer ops with clamp, use pre-computed saturated value; for floats, clamp to [0,1]
      if int_saturate is not None: val = int_saturate
      elif clmp and val.dtype in (dtypes.float32, dtypes.half, dtypes.float64):
        clamped = val.maximum(UOp.const(val.dtype, 0.0)).minimum(UOp.const(val.dtype, 1.0))
        val = _FUNCS['isNAN'](val).where(UOp.const(val.dtype, 0.0), clamped)
      if val.dtype in (dtypes.uint64, dtypes.int64, dtypes.float64):
        lo, hi = _split64(val)
        raw_stores.extend([('vgpr', self.wvgpr_dyn(vdst_reg, lane, lo, exec_mask)),
                           ('vgpr', self.wvgpr_dyn(vdst_reg + _c(1), lane, hi, exec_mask))])
      elif val.dtype in (dtypes.half, dtypes.uint16, dtypes.int16):
        result, old_val = _val_to_u32(val), self.rvgpr_dyn(vdst_reg, lane)
        hi_result = (old_val & UOp.const(dtypes.uint32, 0xFFFF)) | (result << UOp.const(dtypes.uint32, 16))
        # GFX9/CDNA zeroes upper 16 bits on lo-half write; RDNA preserves them
        lo_result = (result & UOp.const(dtypes.uint32, 0xFFFF)) if self.wave_size == 64 else \
                    (old_val & UOp.const(dtypes.uint32, 0xFFFF0000)) | (result & UOp.const(dtypes.uint32, 0xFFFF))
        result = opsel_dst_hi.where(hi_result, lo_result) if isinstance(opsel_dst_hi, UOp) else hi_result if opsel_dst_hi else lo_result
        raw_stores.append(('vgpr', self.wvgpr_dyn(vdst_reg, lane, result, exec_mask)))
      else: raw_stores.append(('vgpr', self.wvgpr_dyn(vdst_reg, lane, _val_to_u32(val), exec_mask)))
    elif dest.startswith('S0') and src0_off is not None:
      # Write back to src0 VGPR (e.g. v_swap_b32). src0_off is raw encoding (256+ = VGPR)
      src0_vgpr = src0_off - _c(256)
      raw_stores.append(('vgpr_s0', self.wvgpr_dyn(src0_vgpr, lane, _val_to_u32(val), exec_mask)))
    elif dest.startswith('VCC'): vcc_val = val
    elif dest.startswith('EXEC'): exec_val = val
    elif dest.startswith('SCC'): raw_stores.append(('scc', self.wsgpr_dyn(_c(SCC.offset), _to_u32(val))))

  lane_stores = [s for t, s in raw_stores if t in ('vgpr', 'vgpr_s0', 'vgpr_direct')]
  stores, scalar_stores = [], [s for t, s in raw_stores if t == 'scc']
  slice_stores = [s for t, s in raw_stores if t == 'vgpr_slice']
  if slice_stores:
    # fold every D0[hi:lo] slice into the current destination value and write it back once
    result = self.rvgpr_dyn(vdst_reg, lane)
    for lo_bit, width, val_bits in slice_stores:
      mask = UOp.const(dtypes.uint32, ((1 << width) - 1) << lo_bit)
      result = (result & (mask ^ UOp.const(dtypes.uint32, 0xFFFFFFFF))) | (val_bits << UOp.const(dtypes.uint32, lo_bit))
    lane_stores.append(self.wvgpr_dyn(vdst_reg, lane, result, exec_mask))
  # VCC/EXEC mask writes must be computed BEFORE VGPR stores to avoid reading modified VGPRs.
  # When vdst overlaps with src operands (e.g. v_add_co_u32 v[0], vcc, s[8], v[0]), the carry
  # computation reads the original source values only if its range loop runs before the VGPR write loop.
  mask_stores: list[UOp] = []
  for mask_val, reg in [(vcc_val, vcc_reg), (exec_val, EXEC_LO.offset)]:
    if mask_val is None: continue
    # v=mask_val binds the current value at definition time (avoids the late-binding closure pitfall)
    def get_bit(l, v=mask_val): return (_to_u32(v.substitute({lane: l})) & _c(1)).cast(dtypes.uint32)
    mask_stores.extend(self.wmask(_c(reg), self.unroll_lanes(get_bit, exec_mask, apply_exec=False)))
  stores.extend(mask_stores)
  if lane_stores: stores.append(UOp.sink(*lane_stores).end(lane))
  stores.extend(scalar_stores)
  return UOp.sink(*stores, *self.inc_pc())
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
# INSTRUCTION HANDLERS
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
|
||
def _compile_sopp(inst: ir3.SOPP | ir4.SOPP, ctx: _Ctx) -> UOp:
  """Compile a SOPP (scalar program-flow) instruction.

  S_ENDPGM stores 0xFFFFFFFF to both PC halves. Barriers, S_NOP and S_WAITCNT only advance the PC.
  Ops that have pcode are compiled from it, and an assignment to PC becomes the new PC.
  Everything else just advances the PC.
  """
  simm16 = ctx.inst_field_signed(type(inst).simm16).cast(dtypes.int16)
  opcode = inst.op

  if opcode in (ir3.SOPPOp.S_ENDPGM, ir4.SOPPOp.S_ENDPGM, irc.SOPPOp.S_ENDPGM):
    halt_lo = ctx.wsgpr_dyn(_c(PC_LO_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF))
    halt_hi = ctx.wsgpr_dyn(_c(PC_HI_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF))
    return UOp.sink(halt_lo, halt_hi)

  # S_BARRIER: only step over the instruction here; the execution loop detects barriers before executing
  # and handles synchronization.
  barrier_ops = {ir3.SOPPOp.S_BARRIER, irc.SOPPOp.S_BARRIER}
  if hasattr(ir4.SOPPOp, 'S_BARRIER_WAIT'): barrier_ops.add(ir4.SOPPOp.S_BARRIER_WAIT)
  if opcode in barrier_ops: return UOp.sink(*ctx.inc_pc())

  # S_NOP and S_WAITCNT are no-ops in the emulator (no pipeline/cache to wait on)
  if opcode in (ir3.SOPPOp.S_NOP, ir4.SOPPOp.S_NOP, irc.SOPPOp.S_NOP, irc.SOPPOp.S_WAITCNT): return UOp.sink(*ctx.inc_pc())

  # NOTE: SOPPs without pcode are ignored (they fall through to the plain PC increment)
  if opcode in _get_pcode_dict(opcode):
    pcode = get_pcode(opcode)
    pc_bytes = ctx.rpc()  # PC is already a 64-bit byte address
    vcc, exec_val = ctx.rmask(_c(VCC_LO.offset)), ctx.rexec()
    srcs = {'PC': pc_bytes.cast(dtypes.int64), 'SIMM16': simm16, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'VCC': vcc,
            'VCCZ': vcc.eq(UOp.const(vcc.dtype, 0)).cast(dtypes.uint32),
            'EXECZ': exec_val.eq(UOp.const(exec_val.dtype, 0)).cast(dtypes.uint32)}
    assigns = parse_pcode(pcode, srcs)[1]
    # the first assignment targeting PC decides the next PC
    new_pc = next((v for d, v in assigns if d == 'PC' or d.startswith('PC.')), None)
    if new_pc is not None:
      lo, hi = _split64(new_pc.cast(dtypes.uint64))
      return UOp.sink(ctx.wsgpr_dyn(_c(PC_LO_IDX), lo), ctx.wsgpr_dyn(_c(PC_HI_IDX), hi))
  return UOp.sink(*ctx.inc_pc())
|
||
|
||
def _compile_smem(inst: ir3.SMEM | ir4.SMEM, ctx: _Ctx) -> UOp:
  """Compile a scalar memory load (S_LOAD_*) into SGPR stores plus the PC increment.

  The address is the 64-bit SGPR pair at sbase*2 plus the signed immediate offset plus the optional soffset
  register. Sub-dword loads (U8/I8/U16/I16) extract the addressed bits from the containing dword.
  Ops whose name contains '_INV' only advance the PC.
  """
  # Cache invalidation instructions are no-ops in the emulator (caches are not modelled)
  if '_INV' in inst.op.name: return UOp.sink(*ctx.inc_pc())

  inst_cls = type(inst)
  # sbase names an SGPR pair: field value * 2 = register offset
  sbase = ctx.inst_field(inst_cls.sbase) * _c(2)
  sdata_reg = ctx.inst_field(inst_cls.sdata)
  # RDNA4 calls the immediate 'ioffset', RDNA3 'offset'
  offset_field = inst_cls.ioffset if hasattr(inst_cls, 'ioffset') else inst_cls.offset  # type: ignore[union-attr]
  offset = ctx.inst_field_signed(offset_field)

  # soffset: SGPR holding an additional offset (NULL=124 reads as 0); CDNA with soffset_en=0 has none
  if not isinstance(inst, irc.SMEM) or inst.soffset_en:
    soffset_val = ctx.rsgpr_dyn(ctx.inst_field(inst_cls.soffset)).cast(dtypes.uint64)
  else:
    soffset_val = _c(0).cast(dtypes.uint64)
  addr = _u64(ctx.rsgpr_dyn(sbase), ctx.rsgpr_dyn(sbase + _c(1))) + offset.cast(dtypes.uint64) + soffset_val

  # Size suffix in dwords: B32/DWORD=1, B64/DWORDX2=2, U8=0.25, I8=-0.25 (negative = sign-extending), etc.
  op_name = _op_name(inst)
  assert (op_name).startswith('S_LOAD_'), f"unexpected SMEM op: {op_name}"
  part = op_name.rpartition('_')[2]
  if 'DWORD' in part:
    nval = int(part.removeprefix('DWORD').removeprefix('X') or '1')
  else:
    nval = int(part[1:]) / 32 * (-1 if part[0] == 'I' else 1)
  magnitude = abs(nval)
  ndwords = max(1, int(magnitude))

  dword_base = addr >> UOp.const(dtypes.uint64, 2)
  vals = [ctx.vmem.index((dword_base + UOp.const(dtypes.uint64, i)).cast(dtypes.int)) for i in range(ndwords)]
  if magnitude < 1:
    nbits = int(magnitude * 32)
    byte_off = (addr & UOp.const(dtypes.uint64, 3)).cast(dtypes.uint32) * UOp.const(dtypes.uint32, 8)
    extracted = (vals[0] >> byte_off) & UOp.const(dtypes.uint32, (1 << nbits) - 1)
    if nval < 0:
      # signed sub-dword load: sign-extend through the narrow signed type
      vals[0] = extracted.cast({8: dtypes.int8, 16: dtypes.int16}[nbits]).cast(dtypes.int32).bitcast(dtypes.uint32)
    else:
      vals[0] = extracted
  stores = [ctx.wsgpr_dyn(sdata_reg + _c(i), v) for i, v in enumerate(vals)]
  return UOp.sink(*stores, *ctx.inc_pc())
|
||
|
||
def _compile_sop(inst: ir3.SOP1|ir3.SOP2|ir3.SOPC|ir3.SOPK|ir4.SOP1|ir4.SOP2|ir4.SOPC|ir4.SOPK|irc.SOP1|irc.SOP2|irc.SOPC|irc.SOPK, ctx: _Ctx) -> UOp:
  """Compile a scalar ALU instruction (SOP1/SOP2/SOPC/SOPK).

  Gathers the source operands for the encoding and hands them to ctx.compile_sop_pcode. S_GETREG_B32 and the
  RDNA4 barrier SOP1 ops are handled directly. Raises RuntimeError for an unrecognised instruction type.
  """
  bits = inst.canonical_op_bits
  literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None  # type: ignore[union-attr]

  if isinstance(inst, (ir3.SOPK, ir4.SOPK, irc.SOPK)):
    sdst_off = ctx.inst_field(type(inst).sdst)
    simm16 = ctx.inst_field(type(inst).simm16)
    # Sign-extend simm16
    simm16_sext = simm16.cast(dtypes.int16).cast(dtypes.int32)
    # RDNA4 pcodes use S0.i16 for the immediate (e.g., S_MULK_I32), RDNA3 uses S0 for the register (e.g., S_CMPK_*)
    # CDNA pcode uses S0 for the immediate in MOVK/MULK/ADDK/CMOVK, but S0 = register for CMPK/SETREG
    op_name = _op_name(inst)
    if isinstance(inst, ir4.SOPK): s0 = simm16
    elif isinstance(inst, irc.SOPK) and 'CMPK' not in op_name and 'SETREG' not in op_name: s0 = simm16_sext
    else: s0 = ctx.rsgpr_dyn(sdst_off)
    srcs = {'S0': s0, 'S1': simm16_sext, 'SIMM16': simm16_sext, 'D0': ctx.rsgpr_dyn(sdst_off)}
    dst_off, dst_size = sdst_off, 1
    # S_GETREG_B32: extract bits from HW register. Handle as special case since HW_REGISTERS is not a normal variable.
    # HW register values are stored at SGPR[SGPR_COUNT-16 + hwRegId] by _init_wave.
    if 'GETREG' in op_name:
      hw_reg_id = simm16.cast(dtypes.uint32) & _c(0x3F)
      offset = (simm16.cast(dtypes.uint32) >> _c(6)) & _c(0x1F)
      size = ((simm16.cast(dtypes.uint32) >> _c(11)) & _c(0x1F)) + _c(1)
      hw_val = ctx.rsgpr_dyn(_c(SGPR_COUNT - 16) + hw_reg_id)
      # size is 1..32. Build the mask by shifting all-ones right by (32 - size), which stays in 0..31;
      # `(1 << size) - 1` would shift a 32-bit value by 32 when size == 32, which is undefined in the generated C.
      mask = _c(0xFFFFFFFF) >> (_c(32) - size)
      result = (hw_val >> offset) & mask
      return UOp.sink(ctx.wsgpr_dyn(sdst_off, result), *ctx.inc_pc())
  elif isinstance(inst, (ir3.SOP1, ir4.SOP1, irc.SOP1)):
    # S_BARRIER_SIGNAL: no-op in emulator, barrier sync handled by execution loop
    if isinstance(inst, ir4.SOP1) and inst.op in _BARRIER_SOP1_OPS: return UOp.sink(*ctx.inc_pc())
    sdst_off = ctx.inst_field(type(inst).sdst)
    ssrc0_off = ctx.inst_field(type(inst).ssrc0)
    srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal)}
    dst_off, dst_size = sdst_off, bits['d'] // 32
  elif isinstance(inst, (ir3.SOP2, ir4.SOP2, irc.SOP2)):
    sdst_off = ctx.inst_field(type(inst).sdst)
    ssrc0_off = ctx.inst_field(type(inst).ssrc0)
    ssrc1_off = ctx.inst_field(type(inst).ssrc1)
    srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal),
            'S1': ctx.rsrc_dyn(ssrc1_off, None, bits['s1'], literal)}
    if literal is not None: srcs['SIMM32'] = literal
    dst_off, dst_size = sdst_off, bits['d'] // 32
  elif isinstance(inst, (ir3.SOPC, ir4.SOPC, irc.SOPC)):
    ssrc0_off = ctx.inst_field(type(inst).ssrc0)
    ssrc1_off = ctx.inst_field(type(inst).ssrc1)
    srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal),
            'S1': ctx.rsrc_dyn(ssrc1_off, None, bits['s1'], literal)}
    dst_off, dst_size = _c(0), 0  # SOPC writes to SCC, not sdst
  else:
    raise RuntimeError(f"unknown SOP type: {type(inst).__name__}")

  return ctx.compile_sop_pcode(inst.op, srcs, dst_off, dst_size)
|
||
|
||
def _sdwa_select(val: UOp, sel: UOp, sext: UOp) -> UOp:
  """Pick the SDWA-selected byte/word out of a 32-bit value, optionally sign-extending it.

  sel:  0-3 = BYTE_0..3, 4 = WORD_0, 5 = WORD_1, 6 = DWORD. Any other value selects BYTE_0.
  sext: non-zero requests sign extension (from bit 7 when sel < 4, otherwise from bit 15).
  """
  byte_mask, word_mask = _c(0xFF), _c(0xFFFF)
  b0 = val & byte_mask
  b1 = (val >> _c(8)) & byte_mask
  b2 = (val >> _c(16)) & byte_mask
  b3 = (val >> _c(24)) & byte_mask
  w0 = val & word_mask
  w1 = (val >> _c(16)) & word_mask
  # build the select chain from the innermost default (BYTE_0) outwards, so sel == 1 ends up outermost
  selected = b0
  for code, candidate in ((6, val), (5, w1), (4, w0), (3, b3), (2, b2), (1, b1)):
    selected = sel.eq(_c(code)).where(candidate, selected)
  # sign extension fills every bit above the selected field with its top bit
  byte_sext = (selected & _c(0x80)).ne(_c(0)).where(selected | _c(0xFFFFFF00), selected)
  word_sext = (selected & _c(0x8000)).ne(_c(0)).where(selected | _c(0xFFFF0000), selected)
  extended = (sel < _c(4)).where(byte_sext, word_sext)
  return sext.ne(_c(0)).where(extended, selected)
|
||
|
||
def _sdwa_write(old: UOp, val: UOp, dst_sel: UOp, dst_unused: UOp) -> UOp:
  """Apply SDWA destination selection: place `val` into the selected byte/word and fill the unused bits.

  old:        current destination register value (used only for PRESERVE).
  val:        32-bit result to write.
  dst_sel:    0-3 = BYTE_0..3, 4 = WORD_0, 5 = WORD_1, 6 = DWORD (val is returned unchanged).
  dst_unused: 0 = PAD (unused bits zero), 1 = SEXT (bits above the field sign-extended, bits below zero),
              2 = PRESERVE (unused bits taken from `old`).
  """
  is_byte = dst_sel < _c(4)
  is_word = (dst_sel >= _c(4)) & (dst_sel < _c(6))
  shift = is_byte.where(dst_sel * _c(8), (dst_sel - _c(4)) * _c(16))
  mask = is_byte.where(_c(0xFF), is_word.where(_c(0xFFFF), _c(0xFFFFFFFF)))
  placed = (val & mask) << shift
  preserve_mask = (mask << shift) ^ _c(0xFFFFFFFF)
  preserved = (old & preserve_mask) | placed
  # SEXT: replicate the field's top bit into every bit above the field; bits below stay zero.
  # upper_mask = bits outside the field that are at or above `shift`, i.e. strictly above the field.
  upper_mask = preserve_mask & (_c(0xFFFFFFFF) << shift)
  sign_bit = ((mask >> _c(1)) + _c(1)) << shift  # 0x80 / 0x8000 moved to the field position
  sign_extended = (placed & sign_bit).ne(_c(0)).where(placed | upper_mask, placed)
  unused_filled = dst_unused.eq(_c(2)).where(preserved, dst_unused.eq(_c(1)).where(sign_extended, placed))
  # For DWORD there are no unused bits: return val as is (the byte/word intermediates are discarded)
  return dst_sel.eq(_c(6)).where(val, unused_filled)
|
||
|
||
def _dpp_quad_sel(quad_lane: UOp, sels: tuple[int, int, int, int]) -> UOp:
  """Map a lane's position within its quad to the source position given by `sels`.

  Builds a select chain over `quad_lane`: position i yields sels[i], with sels[0] as the default.
  """
  chosen = _c(sels[0], dtypes.int)
  for pos, src in enumerate(sels):
    if pos == 0: continue  # already used as the default
    chosen = quad_lane.eq(_c(pos, dtypes.int)).where(_c(src, dtypes.int), chosen)
  return chosen
|
||
|
||
def _dpp16_ctrl(lane: UOp, dpp: int, row_mask: int, bank_mask: int, wave_size: int) -> tuple[UOp, UOp, UOp]:
  """Return (src_lane, row/bank enabled, in-bounds) for a DPP16 swizzle.

  lane:      destination lane.
  dpp:       DPP control value, decoded with decode_dpp16 into (op, arg).
  row_mask:  bit r enables lanes whose row (lane // 16) is r.
  bank_mask: bit b enables lanes whose bank ((lane % 16) >> 2) is b.
  wave_size: number of lanes, used by the wave_* controls.

  Raises NotImplementedError for a control op that is not handled here.
  """
  i32 = dtypes.int
  lane_i = lane.cast(i32)
  row_base, lane_in_row = lane_i & _c(~15, i32), lane_i & _c(15, i32)
  row = lane_i // _c(16, i32)
  bank = lane_in_row >> _c(2, i32)
  row_on = ((_c(row_mask) >> row.cast(dtypes.uint32)) & _c(1)).ne(_c(0))
  bank_on = ((_c(bank_mask) >> bank.cast(dtypes.uint32)) & _c(1)).ne(_c(0))
  enabled = row_on & bank_on

  op, arg = decode_dpp16(dpp)
  # by default the lane reads itself and is always in bounds
  src_lane, valid = lane_i, UOp.const(dtypes.bool, True)

  if op == 'quad_perm':
    assert isinstance(arg, tuple)
    quad_base = lane_i & _c(~3, i32)
    src_lane = quad_base + _dpp_quad_sel(lane_i & _c(3, i32), arg)
    return src_lane, enabled, valid

  assert isinstance(arg, int)
  if op == 'row_shl':
    src_lane = row_base + lane_in_row + _c(arg, i32)
    valid = lane_in_row <= _c(15 - arg, i32)
  elif op == 'row_shr':
    src_lane = row_base + lane_in_row - _c(arg, i32)
    valid = lane_in_row >= _c(arg, i32)
  elif op == 'row_ror':
    src_lane = row_base + ((lane_in_row - _c(arg, i32)) & _c(15, i32))
  elif op == 'row_mirror':
    src_lane = row_base + (_c(15, i32) - lane_in_row)
  elif op == 'row_half_mirror':
    src_lane = row_base + ((lane_in_row & _c(8, i32)) | (_c(7, i32) - (lane_in_row & _c(7, i32))))
  elif op == 'row_bcast':
    src_lane = row_base
  elif op == 'wave_shl':
    src_lane = lane_i + _c(arg, i32)
    valid = lane_i < _c(wave_size - arg, i32)
  elif op == 'wave_rol':
    src_lane = (lane_i + _c(arg, i32)) % _c(wave_size, i32)
  elif op == 'wave_shr':
    src_lane = lane_i - _c(arg, i32)
    valid = lane_i >= _c(arg, i32)
  elif op == 'wave_ror':
    src_lane = (lane_i - _c(arg, i32)) % _c(wave_size, i32)
  else:
    raise NotImplementedError(f"DPP16 control {dpp:#x} ({op}:{arg}) not implemented in emulator")
  return src_lane, enabled, valid
|
||
|
||
def _load_dpp16_src0(ctx: _Ctx, inst, lane: UOp, fallback: UOp) -> UOp:
  """Load the DPP16-swizzled src0 value from vsrc0 for `lane`.

  Lanes whose row/bank is disabled yield `fallback`. Enabled lanes whose source lane is out of bounds yield 0
  when the instruction's `bc` attribute is truthy, and `fallback` otherwise.
  """
  # NOTE(review): `or` also replaces an explicit mask of 0 with 0xf - confirm that is intended
  dpp_ctrl = getattr(inst, 'dpp', 0) or 0
  rows = getattr(inst, 'row_mask', 0xf) or 0xf
  banks = getattr(inst, 'bank_mask', 0xf) or 0xf
  src_lane, enabled, valid = _dpp16_ctrl(lane, dpp_ctrl, rows, banks, ctx.wave_size)
  # read lane 0 instead of a disabled/out-of-bounds lane so the VGPR index is always a real lane
  readable = enabled & valid
  safe_src_lane = readable.where(src_lane, _c(0, dtypes.int))
  swizzled = ctx.rvgpr_dyn(ctx.inst_field(type(inst).vsrc0), safe_src_lane)
  if getattr(inst, 'bc', 0):
    out_of_bounds = UOp.const(fallback.dtype, 0)
  else:
    out_of_bounds = fallback
  return enabled.where(valid.where(swizzled, out_of_bounds), fallback)
|
||
|
||
def _compile_sdwa(inst: irc.VOP1_SDWA | irc.VOP2_SDWA | irc.VOP2_SDWA_SDST | irc.VOPC_SDWA_SDST, ctx: _Ctx) -> UOp:
  """Compile CDNA SDWA (Sub-Dword Access) VOP1/VOP2/VOPC instructions.

  Sources are read from SGPR or VGPR (chosen by inst.s0 / inst.s1), passed through _sdwa_select, and fed to the
  op's pcode. VOPC builds its result mask with ctx.unroll_lanes; the other encodings run a lane loop and write
  the destination VGPR through _sdwa_write when the encoding has a dst_sel field.
  Returns a sink of all stores followed by the PC increment.
  """
  is_vopc = isinstance(inst, irc.VOPC_SDWA_SDST)
  exec_mask = ctx.rexec()
  # sd=1 means use sdst register, sd=0 means use VCC (for VOPC_SDWA_SDST and VOP2_SDWA_SDST)
  if isinstance(inst, (irc.VOP2_SDWA_SDST, irc.VOPC_SDWA_SDST)):
    sdst_off = _c(inst.sdst.offset) if getattr(inst, 'sd', False) else _c(VCC_LO.offset)
  else:
    sdst_off = _c(VCC_LO.offset)
  # Read SDWA fields (these are dynamic but shared across lanes)
  src0_sel = ctx.inst_field(type(inst).src0_sel)
  src0_sext = ctx.inst_field(type(inst).src0_sext)
  vsrc0_reg = ctx.inst_field(type(inst).vsrc0)
  pcode = get_pcode(inst.op)
  # src1 fields exist only on the two-source encodings
  if isinstance(inst, (irc.VOP2_SDWA, irc.VOP2_SDWA_SDST, irc.VOPC_SDWA_SDST)):
    src1_sel = ctx.inst_field(type(inst).src1_sel)
    src1_sext = ctx.inst_field(type(inst).src1_sext)
    vsrc1_reg = ctx.inst_field(type(inst).vsrc1)

  # For VOPC: use unroll_lanes to build the bitmask from scratch (no read-modify-write on stale data)
  if is_vopc:
    def get_cmp_bit(lane) -> UOp:
      # lane may be passed as a UOp or as a plain Python int; normalise to an int-typed UOp
      lc = lane.cast(dtypes.int) if isinstance(lane, UOp) else _c(lane, dtypes.int)
      s0_raw = ctx.rsgpr_dyn(vsrc0_reg) if inst.s0 else ctx.rvgpr_dyn(vsrc0_reg, lc)
      s0 = _sdwa_select(s0_raw, src0_sel, src0_sext)
      s1_raw = ctx.rsgpr_dyn(vsrc1_reg) if inst.s1 else ctx.rvgpr_dyn(vsrc1_reg, lc)
      s1 = _sdwa_select(s1_raw, src1_sel, src1_sext)
      srcs = {'S0': s0, 'S1': s1, 'laneId': lc}
      # the first per-lane D0/EXEC assignment is this lane's comparison result
      for dest, val in parse_pcode(pcode, srcs)[1]:
        if '[laneId]' in dest and ('D0' in dest or 'EXEC' in dest): return val.cast(dtypes.uint32)
      return _c(0)
    # mask the assembled result with EXEC before storing it
    new_result = ctx.unroll_lanes(get_cmp_bit, exec_mask, apply_exec=False) & exec_mask
    stores = ctx.wmask(sdst_off, new_result)
    return UOp.sink(*stores, *ctx.inc_pc())

  # Non-VOPC path: VOP1_SDWA, VOP2_SDWA, VOP2_SDWA_SDST — uses lane loop
  lane = ctx.range()
  vdst_reg = ctx.inst_field(type(inst).vdst)  # type: ignore[union-attr]
  s0_raw = ctx.rsgpr_dyn(vsrc0_reg) if inst.s0 else ctx.rvgpr_dyn(vsrc0_reg, lane)
  s0 = _sdwa_select(s0_raw, src0_sel, src0_sext)
  if isinstance(inst, (irc.VOP2_SDWA, irc.VOP2_SDWA_SDST)):
    s1_raw = ctx.rsgpr_dyn(vsrc1_reg) if inst.s1 else ctx.rvgpr_dyn(vsrc1_reg, lane)
    s1 = _sdwa_select(s1_raw, src1_sel, src1_sext)
    # D0 is the current destination value, available to pcode that reads its destination
    srcs:dict[str, UOp | int] = {'S0': s0, 'S1': s1, 'D0': ctx.rvgpr_dyn(vdst_reg, lane)}
  else:
    srcs = {'S0': s0}
  # dst_sel and dst_unused
  has_dst_sel = hasattr(type(inst), 'dst_sel')
  if has_dst_sel:
    dst_sel = ctx.inst_field(type(inst).dst_sel)  # type: ignore[union-attr]
    dst_unused = ctx.inst_field(type(inst).dst_unused)  # type: ignore[union-attr]
  # implicit pcode variables: masks, lane id, rounding-mode and SDWA select constants
  srcs.update({'VCC': ctx.rmask(_c(VCC_LO.offset)), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)),
               'laneId': lane, 'VDST': vdst_reg, 'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0),
               'ROUND_NEAREST_EVEN': _c(0), '_vgpr': ctx.vgpr, '_wave_size': ctx.wave_size,
               'SDWA_SRC0_SEL': _c(0), 'BYTE0': _c(0), 'BYTE1': _c(1), 'BYTE2': _c(2), 'BYTE3': _c(3),
               'WORD0': _c(0), 'WORD1': _c(1)})
  _, assigns = parse_pcode(pcode, srcs)
  stores = []
  vcc_val = None
  for dest, val in assigns:
    if 'D0' in dest and '[laneId]' in dest:
      # per-lane mask bit: deferred so it can be written to sdst after the loop body is assembled
      vcc_val = val
    elif dest.startswith('D0'):
      result = _val_to_u32(val)
      if has_dst_sel:
        # merge the result into the selected byte/word of the current destination value
        old = ctx.rvgpr_dyn(vdst_reg, lane)
        result = _sdwa_write(old, result, dst_sel, dst_unused)
      stores.append(ctx.wvgpr_dyn(vdst_reg, lane, result, exec_mask))
    elif dest.startswith('VCC'):
      old_vcc = ctx.rmask(_c(VCC_LO.offset))
      stores.extend(ctx.wmask(_c(VCC_LO.offset), _set_lane_bit(old_vcc, lane, val, exec_mask)))
  if vcc_val is not None:
    # Initialize sdst to 0 before lane loop (old value may be unrelated data), then set lane bits in loop
    init_stores = [ctx.wsgpr_dyn(sdst_off, _c(0)), ctx.wsgpr_dyn(sdst_off + _c(1), _c(0))]
    old_sdst = ctx.rmask(sdst_off)
    stores.extend(ctx.wmask(sdst_off, _set_lane_bit(old_sdst, lane, vcc_val, exec_mask)))
    if stores:
      return UOp.sink(*init_stores, UOp.sink(*stores).end(lane), *ctx.inc_pc())
    return UOp.sink(*init_stores, *ctx.inc_pc())
  if stores:
    # close the lane loop around all per-lane stores
    return UOp.sink(UOp.sink(*stores).end(lane), *ctx.inc_pc())
  return UOp.sink(*ctx.inc_pc())
|
||
|
||
def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP1_DPP16 | ir3.VOP2 | ir3.VOP2_DPP16 |
                   ir4.VOP1 | ir4.VOP1_SDST | ir4.VOP1_DPP16 | ir4.VOP2 | ir4.VOP2_DPP16 |
                   irc.VOP1 | irc.VOP1_DPP16 | irc.VOP2 | irc.VOP2_DPP16, ctx: _Ctx) -> UOp:
  """Compile a VOP1/VOP2 instruction (plain or DPP16 encoding) to a UOp sink.

  Lane-crossing ops are forwarded to ctx.compile_lane_pcode, v_accvgpr_mov is handled inline,
  and everything else gathers S0/(S1)/D0 and is lowered through ctx.compile_vop_pcode.
  """
  op_name, inst_cls = _op_name(inst), type(inst)
  if op_name in ('V_READFIRSTLANE_B32_E32', 'V_PERMLANE64_B32_E32'): return ctx.compile_lane_pcode(inst.op, inst)

  # v_accvgpr_mov_b32: ACCVGPR[vdst] = ACCVGPR[src0] (VOP1 encoding, no pcode)
  if 'ACCVGPR_MOV' in op_name:
    lane, exec_mask = ctx.range(), ctx.rexec()
    acc_dst = ctx.inst_field(inst_cls.vdst)  # VGPRField: raw ACCVGPR index (0-255)
    acc_src = ctx.inst_field(inst_cls.src0)  # SrcField: raw 256 + ACCVGPR index
    moved = ctx.raccvgpr_dyn(acc_src - _c(256), lane)
    return UOp.sink(ctx.waccvgpr_dyn(acc_dst, lane, moved, exec_mask).end(lane), *ctx.inc_pc())

  lane, exec_mask, bits = ctx.range(), ctx.rexec(), inst.canonical_op_bits
  literal = ctx.inst_field(inst_cls.literal) if hasattr(inst_cls, 'literal') else None  # type: ignore[union-attr]
  is_f64 = 'F64' in op_name and 'B64' not in op_name
  is_float = any(tag in op_name for tag in ('F16', 'F32', 'F64'))
  is_dpp16 = hasattr(inst_cls, 'dpp') and hasattr(inst_cls, 'vsrc0')

  # 16-bit destinations encoded as vdst >= 128 select the hi half of v[vdst-128]
  vdst_reg = ctx.inst_field(inst_cls.vdst)
  write_hi_half = bits['d'] == 16 and (vdst_reg >= _c(128))
  if isinstance(write_hi_half, UOp): vdst_reg = write_hi_half.where(vdst_reg - _c(128), vdst_reg)
  elif write_hi_half: vdst_reg -= 128

  def read_src0(accum: UOp) -> tuple[UOp, UOp | None]:
    # Shared by VOP1 and VOP2: returns (S0 value, src0 field offset or None for DPP16)
    field_off: UOp | None = None
    if is_dpp16:
      value = _load_dpp16_src0(ctx, inst, lane, accum)
    else:
      field_off = ctx.inst_field(inst_cls.src0)
      value = ctx.rsrc_dyn(field_off, lane, bits['s0'], literal, is_f64)
      if bits['s0'] == 16:
        # src0 >= 384 means hi half of v[src0-384]; index 0 is used as a guard to prevent OOB access
        wants_hi = field_off >= _c(384)
        guarded_reg = wants_hi.where(field_off - _c(384), _c(0))
        value = wants_hi.where(_hi16(ctx.rvgpr_dyn(guarded_reg, lane)), value)
    if is_dpp16 and is_float:
      value = _apply_src_mods(value, 0, 1 if getattr(inst, 'src0_abs', 0) else 0, 1 if getattr(inst, 'src0_neg', 0) else 0, bits['s0'])
    return value, field_off

  if isinstance(inst, (ir3.VOP1, ir4.VOP1, irc.VOP1)):
    d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane))
    s0, src0_off = read_src0(d0)
    srcs: dict[str, UOp | int] = {'S0': s0, 'D0': d0}
  else:
    vsrc1_reg = ctx.inst_field(inst_cls.vsrc1)
    vsrc1_hi = bits['s0'] == 16 and (vsrc1_reg >= _c(128))
    vsrc1_actual = _cond(vsrc1_hi, vsrc1_reg - _c(128), vsrc1_reg)
    if bits['s1'] == 64:
      s1 = _u64(ctx.rvgpr_dyn(vsrc1_reg, lane), ctx.rvgpr_dyn(vsrc1_reg + _c(1), lane))
      d0 = _u64(ctx.rvgpr_dyn(vdst_reg, lane), ctx.rvgpr_dyn(vdst_reg + _c(1), lane))
    else:
      s1 = _cond_hi16(vsrc1_hi, ctx.rvgpr_dyn(vsrc1_actual, lane))
      d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane))  # FMAC/FMAMK hi-half dest needs hi-half accumulator
    s0, src0_off = read_src0(d0)
    if is_dpp16 and is_float:
      s1 = _apply_src_mods(s1, 0, 1 if getattr(inst, 'src1_abs', 0) else 0, 1 if getattr(inst, 'src1_neg', 0) else 0, bits['s1'])
    srcs = {'S0': s0, 'S1': s1, 'D0': d0}

  # FMAAK_(DTYPE)_E32 series: the trailing literal is exposed to pcode as SIMM32
  if 'V_FMAA' in op_name or 'V_FMAM' in op_name:
    assert literal is not None
    srcs['SIMM32'] = literal
  return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=write_hi_half, src0_off=src0_off)
|
||
|
||
def _compile_vopc(inst: ir3.VOPC|ir3.VOPC_DPP16|ir3.VOP3|ir4.VOPC|ir4.VOPC_DPP16|ir4.VOP3|irc.VOPC|irc.VOP3, ctx: _Ctx,
                  opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp:
  """Compile a vector compare in either the e32 (VOPC) or e64 (VOP3) encoding.

  The per-lane compare bit is produced from the instruction's pcode, gathered over all lanes
  with ctx.unroll_lanes, masked by EXEC, and written to EXEC, VCC and/or the encoded destination.
  opsel/abs_bits/neg_bits are the VOP3 operand modifiers passed in by the caller.
  """
  inst_cls = type(inst)
  exec_mask, op_name, bits = ctx.rexec(), _op_name(inst), inst.canonical_op_bits
  is_cmpx = 'CMPX' in op_name
  is_vopc = hasattr(inst, 'vsrc1')  # e32 encodings carry vsrc1, e64 encodings carry src1
  is_dpp16 = hasattr(inst_cls, 'dpp') and hasattr(inst_cls, 'vsrc0')

  # Both encodings read src0 first; only the second operand (and the e64 destination) differ
  src0_off = ctx.inst_field(inst_cls.src0)
  vsrc1_hi = False
  if is_vopc:
    vsrc1_off = ctx.inst_field(inst_cls.vsrc1)  # type: ignore[union-attr]
    if bits['s0'] == 16:
      # For 16-bit ops, vsrc1 >= 128 means hi-half of v[vsrc1-128]
      vsrc1_hi = vsrc1_off >= _c(128)
      src1_off = _c(256) + vsrc1_hi.where(vsrc1_off - _c(128), vsrc1_off)
    else:
      src1_off = _c(256) + vsrc1_off
  else:
    src1_off = ctx.inst_field(inst_cls.src1)  # type: ignore[union-attr]
    dst_off = ctx.inst_field(inst_cls.vdst)  # type: ignore[union-attr]
  literal = ctx.inst_field(inst_cls.literal) if hasattr(inst_cls, 'literal') else None  # type: ignore[union-attr]

  is_float = any(tag in op_name for tag in ('_F32', '_F64', '_F16'))
  is_f64 = '_F64' in op_name
  pcode = get_pcode(inst.op)

  def get_cmp_bit(lane) -> UOp:
    # Builds the compare result for a single lane as a uint32
    lc = lane.cast(dtypes.int) if isinstance(lane, UOp) else _c(lane, dtypes.int)
    if is_dpp16: s0 = _load_dpp16_src0(ctx, inst, lc, _c(0))
    else: s0 = ctx.rsrc_dyn(src0_off, lc, bits['s0'], literal, is_f64)
    s1 = ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)
    if bits['s0'] == 16:
      s1 = _cond_hi16(vsrc1_hi, s1)
      if opsel: s0, s1 = _apply_opsel(s0, 0, opsel), _apply_opsel(s1, 1, opsel)
    if is_float:
      if is_dpp16:
        s0 = _apply_src_mods(s0, 0, 1 if getattr(inst, 'src0_abs', 0) else 0, 1 if getattr(inst, 'src0_neg', 0) else 0, bits['s0'])
        s1 = _apply_src_mods(s1, 0, 1 if getattr(inst, 'src1_abs', 0) else 0, 1 if getattr(inst, 'src1_neg', 0) else 0, bits['s1'])
      s0 = _apply_src_mods(s0, 0, abs_bits, neg_bits, bits['s0'])
      s1 = _apply_src_mods(s1, 1, abs_bits, neg_bits, bits['s1'])
    _, assigns = parse_pcode(pcode, {'S0': s0, 'S1': s1, 'laneId': lc, 'D0': UOp.const(dtypes.uint64, 0)})
    for dest, val in assigns:
      if '[laneId]' in dest and ('D0' in dest or 'EXEC' in dest): return val.cast(dtypes.uint32)
    return _c(0)

  # Both VOPC and VOP3 clear inactive lane bits (hardware verified)
  new_result = ctx.unroll_lanes(get_cmp_bit, exec_mask, apply_exec=False) & exec_mask

  # CMPX e32: writes EXEC only; CMPX e64: writes both EXEC and SDST; non-CMPX: writes dst only
  if is_cmpx:
    stores = ctx.wmask(_c(EXEC_LO.offset), new_result)
    if not is_vopc: stores.extend(ctx.wmask(dst_off, new_result))
  elif is_vopc:
    stores = ctx.wmask(_c(VCC_LO.offset), new_result)
  else:
    stores = ctx.wmask(dst_off, new_result)
  return UOp.sink(*stores, *ctx.inc_pc())
|
||
|
||
|
||
def _compile_bitop3(inst, ctx: _Ctx, exec_mask: UOp, bits: dict, op_name: str) -> UOp:
  """BITOP3: bitwise function of three inputs selected by an 8-bit truth table.

  The abs/neg/omod fields carry the truth table here; they are not applied as source modifiers.
  The result is the OR of the minterms whose table bit is set, written to vdst under exec_mask.
  """
  lane = ctx.range()
  vdst_reg = ctx.inst_field(type(inst).vdst)
  ops = inst.canonical_operands

  def read_operand(key: str, field) -> UOp:
    # key is the canonical operand name ('s0'/'s1'/'s2'); no literal is passed for BITOP3
    return ctx.rsrc_dyn(ctx.inst_field(field), lane, bits[key], None, key in ops and ops[key][0] == Fmt.FMT_NUM_F64)

  raw0 = read_operand('s0', type(inst).src0)
  raw1 = read_operand('s1', type(inst).src1)
  raw2 = read_operand('s2', type(inst).src2)

  # Truth table: TTBL = { omod[1:0], abs[2:0], neg[2:0] } = 8-bit LUT
  omod_part = (getattr(inst, 'omod', 0) or 0) << 6
  abs_part = (getattr(inst, 'abs', 0) or 0) << 3
  neg_part = getattr(inst, 'neg', 0) or 0
  ttbl = omod_part | abs_part | neg_part

  if 'B16' in op_name: dt, mask = dtypes.uint16, 0xFFFF
  else: dt, mask = dtypes.uint32, 0xFFFFFFFF
  inputs = (raw0.cast(dt), raw1.cast(dt), raw2.cast(dt))

  def literal_of(pos: int, minterm: int) -> UOp:
    # Input pos (0..2) appears plain when its selector bit (4, 2, 1) is set in the minterm, inverted otherwise
    plain = inputs[pos]
    return plain if minterm & (4 >> pos) else plain ^ UOp.const(dt, mask)

  result = UOp.const(dt, 0)
  for minterm in range(8):
    if ttbl & (1 << minterm):
      result = result | (literal_of(0, minterm) & literal_of(1, minterm) & literal_of(2, minterm))
  return UOp.sink(ctx.wvgpr_dyn(vdst_reg, lane, result.cast(dtypes.uint32), exec_mask).end(lane), *ctx.inc_pc())
|
||
|
||
def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3 | irc.VOP3, ctx: _Ctx) -> UOp:
  """Compile a VOP3-encoded instruction.

  Dispatches lane ops, e64 compares and BITOP3 to their dedicated compilers, handles the
  scalar-destination V_S_* ops inline, and lowers everything else through ctx.compile_vop_pcode
  after applying opsel and abs/neg source modifiers.
  """
  exec_mask = ctx.rexec()
  bits = inst.canonical_op_bits
  opsel = getattr(inst, 'opsel', 0) or 0
  op_name = _op_name(inst)
  inst_cls = type(inst)

  # Lane operations and V_PERMLANE16_B32 / V_PERMLANEX16_B32 cross-lane swizzles go through lane pcode
  lane_ops = ('V_READLANE_B32', 'V_READFIRSTLANE_B32', 'V_READFIRSTLANE_B32_E64', 'V_WRITELANE_B32')
  if op_name in lane_ops or 'PERMLANE16' in op_name or 'PERMLANEX16' in op_name:
    return ctx.compile_lane_pcode(inst.op, inst)

  # VOP3 VOPC (v_cmp_*_e64 / v_cmpx_*_e64): any name containing V_CMPX also contains V_CMP
  if 'V_CMP' in op_name:
    return _compile_vopc(inst, ctx, opsel=opsel, abs_bits=getattr(inst, 'abs', 0) or 0, neg_bits=getattr(inst, 'neg', 0) or 0)

  # BITOP3: abs/neg/omod encode truth table, not source modifiers
  if 'BITOP3' in op_name:
    return _compile_bitop3(inst, ctx, exec_mask, bits, op_name)

  # VOP3 specific fields
  vdst_reg = ctx.inst_field(inst_cls.vdst)
  literal = ctx.inst_field(inst_cls.literal) if hasattr(inst_cls, 'literal') else None  # type: ignore[union-attr]
  abs_bits = getattr(inst, 'abs', 0) or 0
  neg_bits = getattr(inst, 'neg', 0) or 0

  # VOP3_SDST: v_s_* instructions write their D0 result to an SGPR, evaluated once with laneId 0
  if 'V_S_' in op_name:
    lane0 = _c(0, dtypes.int)
    scalar_in = ctx.rsrc_dyn(ctx.inst_field(inst_cls.src0), lane0, bits['s0'], literal)
    src0 = _apply_src_mods(scalar_in, 0, abs_bits, neg_bits, bits['s0'])
    srcs = {'S0': src0, 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': _c(0, dtypes.int),
            'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0)}
    _, assigns = parse_pcode(get_pcode(inst.op), srcs)
    stores = [ctx.wsgpr_dyn(vdst_reg, _val_to_u32(val)) for dest, val in assigns if dest.startswith('D0')]
    return UOp.sink(*stores, *ctx.inc_pc())

  # Regular VOP3 - read operands dynamically
  lane = ctx.range()
  ops = inst.canonical_operands

  def read_operand(key: str, field) -> UOp:
    # key is the canonical operand name ('s0'/'s1'/'s2') used for both width and f64 detection
    return ctx.rsrc_dyn(ctx.inst_field(field), lane, bits[key], literal, key in ops and ops[key][0] == Fmt.FMT_NUM_F64)

  src0 = read_operand('s0', inst_cls.src0)
  src1 = read_operand('s1', inst_cls.src1)
  src2 = read_operand('s2', inst_cls.src2)
  if bits['s0'] == 16:
    src0, src1, src2 = _apply_opsel(src0, 0, opsel), _apply_opsel(src1, 1, opsel), _apply_opsel(src2, 2, opsel)
  src0 = _apply_src_mods(src0, 0, abs_bits, neg_bits, bits['s0'])
  src1 = _apply_src_mods(src1, 1, abs_bits, neg_bits, bits['s1'])
  src2 = _apply_src_mods(src2, 2, abs_bits, neg_bits, bits['s2'])

  srcs = {'S0': src0, 'S1': src1, 'S2': src2, 'OPSEL': UOp.const(dtypes.uint32, opsel)}
  # For CNDMASK the third source is handed to pcode as VCC
  if 'CNDMASK' in op_name and src2 is not None: srcs['VCC'] = src2
  # FMAC instructions need D0 (accumulator) from destination register
  if 'FMAC' in op_name: srcs['D0'] = ctx.rvgpr_dyn(vdst_reg, lane)
  opsel_dst_hi = bool(opsel & 0b1000) and bits['d'] == 16
  return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=opsel_dst_hi, clmp=getattr(inst, 'clmp', 0))
|
||
|
||
def _compile_vinterp(inst: ir3.VINTERP | ir4.VINTERP, ctx: _Ctx) -> UOp:
  """Compile a VINTERP instruction by handing its three sources to the generic VOP pcode path.

  Besides the operand values S0/S1/S2, pcode receives SRC0 and SRC2: the encoded source
  offsets with 256 subtracted whenever the offset is >= 256.
  """
  lane, exec_mask = ctx.range(), ctx.rexec()
  fields = type(inst)
  vdst_reg = ctx.inst_field(fields.vdst)
  src0_off = ctx.inst_field(fields.src0)
  src1_off = ctx.inst_field(fields.src1)
  src2_off = ctx.inst_field(fields.src2)

  def strip_vgpr_base(off: UOp) -> UOp:
    # Offsets >= 256 are rebased to start at 0; smaller offsets pass through unchanged
    return (off >= _c(256)).where(off - _c(256), off)

  srcs = {'SRC0': strip_vgpr_base(src0_off), 'SRC2': strip_vgpr_base(src2_off)}
  srcs['S0'] = ctx.rsrc_dyn(src0_off, lane)
  srcs['S1'] = ctx.rsrc_dyn(src1_off, lane)
  srcs['S2'] = ctx.rsrc_dyn(src2_off, lane)
  return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask)
|
||
|
||
def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD | irc.VOP3SD, ctx: _Ctx) -> UOp:
  """Compile a VOP3SD instruction (VOP3 with an additional scalar destination).

  When the pcode assigns a per-lane bit to VCC / D0.u64[laneId], the carry mask is gathered
  with ctx.unroll_lanes and written to sdst, and the vector result is stored in a separate
  lane loop. Otherwise the instruction is lowered through ctx.compile_vop_pcode.
  """
  exec_mask = ctx.rexec()
  bits, pcode, ops = inst.canonical_op_bits, get_pcode(inst.op), inst.canonical_operands

  # Read operands dynamically from instruction encoding
  vdst_reg, sdst_off = ctx.inst_field(type(inst).vdst), ctx.inst_field(type(inst).sdst)
  src0_off, src1_off, src2_off = ctx.inst_field(type(inst).src0), ctx.inst_field(type(inst).src1), ctx.inst_field(type(inst).src2)
  literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None  # type: ignore[union-attr]

  # If src2 is a scalar-register operand it supplies the incoming VCC mask; otherwise sdst is read as VCC
  has_carry_in = 's2' in ops and ops['s2'][2] == OpType.OPR_SREG
  vcc_in_off = src2_off if has_carry_in else sdst_off

  def load_srcs(lane_uop):
    # Build the pcode source dict for one lane; S2 is only provided when the op declares an s2 operand
    ret = {'VCC': ctx.rmask(vcc_in_off), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane_uop}
    ret['S0'] = ctx.rsrc_dyn(src0_off, lane_uop, bits['s0'], literal, ops['s0'][0] == Fmt.FMT_NUM_F64)
    ret['S1'] = ctx.rsrc_dyn(src1_off, lane_uop, bits['s1'], literal, ops['s1'][0] == Fmt.FMT_NUM_F64)
    if 's2' in ops: ret['S2'] = ctx.rsrc_dyn(src2_off, lane_uop, bits['s2'], literal, ops['s2'][0] == Fmt.FMT_NUM_F64)
    return ret

  lane = ctx.range()
  srcs = load_srcs(lane)
  # First parse is used to classify the pcode; its srcs are reused by the fallback path below
  _, assigns = parse_pcode(pcode, srcs)

  has_per_lane_vcc = any('[laneId]' in dest for dest, _ in assigns if dest.startswith('VCC') or dest.startswith('D0.u64'))
  clmp = getattr(inst, 'clmp', 0)
  if has_per_lane_vcc:
    # VCC computation: RANGE+REDUCE gets axis ID first (lower ID = runs first)
    # This ensures VCC reads source values BEFORE VGPR stores modify them
    def get_vcc_bit(lane_uop) -> UOp:
      # The last matching assignment wins; 0 if the pcode never assigns a carry bit
      vcc_bit = _c(0)
      for dest, val in parse_pcode(pcode, load_srcs(lane_uop))[1]:
        if dest.startswith('VCC') or (dest.startswith('D0.u64') and '[laneId]' in dest): vcc_bit = val.cast(dtypes.uint32)
      return vcc_bit
    final_vcc = ctx.unroll_lanes(get_vcc_bit, exec_mask)
    # VGPR stores: RANGE gets axis ID second (higher ID = runs after VCC loop)
    lane3 = ctx.range()
    d0_val, vcc_per_lane = None, None
    for dest, val in parse_pcode(pcode, load_srcs(lane3))[1]:
      if dest.startswith('D0') and '[laneId]' not in dest: d0_val = val
      if dest.startswith('VCC') or (dest.startswith('D0.u64') and '[laneId]' in dest): vcc_per_lane = val
    vgpr_stores = []
    if d0_val is not None:
      # Apply clamp using carry/borrow bit: ADD overflow->0xFFFFFFFF, SUB underflow->0
      if clmp and vcc_per_lane is not None:
        is_sub = 'SUB' in inst.op.name
        sat_val = _c(0) if is_sub else _c(0xFFFFFFFF)
        d0_val = vcc_per_lane.cast(dtypes.bool).where(sat_val, d0_val.cast(dtypes.uint32))
      if d0_val.dtype in (dtypes.uint64, dtypes.int64, dtypes.float64):
        # 64-bit results occupy two consecutive VGPRs: lo in vdst, hi in vdst+1
        lo, hi = _split64(d0_val)
        vgpr_stores.extend([ctx.wvgpr_dyn(vdst_reg, lane3, lo, exec_mask), ctx.wvgpr_dyn(vdst_reg + _c(1), lane3, hi, exec_mask)])
      else:
        # Floats keep their bit pattern (bitcast); everything else is value-cast to uint32
        d0_u32 = d0_val.bitcast(dtypes.uint32) if d0_val.dtype in (dtypes.float32, dtypes.half) else d0_val.cast(dtypes.uint32)
        vgpr_stores.append(ctx.wvgpr_dyn(vdst_reg, lane3, d0_u32, exec_mask))
    # Write carry output (wmask handles lo/hi split for wave64)
    vcc_writes = ctx.wmask(sdst_off, final_vcc)
    return UOp.sink(*vcc_writes, UOp.group(*vgpr_stores).end(lane3), *ctx.inc_pc())
  else:
    # No per-lane carry bit: use the generic path, passing the statically decoded sdst offset
    return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, sdst_reg=inst.sdst.offset)
|
||
|
||
def _compile_mfma(inst: irc.VOP3P, ctx: _Ctx) -> UOp:
  """CDNA MFMA matrix multiply-accumulate emulation.

  Uses local temp arrays to cache inputs, avoiding aliasing issues when vdst overlaps src0/src1.
  Phase 1: Read all input values from VGPRs into a temp array (range loop over lanes).
  Phase 2: Compute the output values from the temp array and write them to ACCVGPRs (range loop over lanes).

  Register layout (wave64):
  - 16x16: 4 groups of 16 lanes. Each lane in group holds k_per_grp elements. 4 output ACCVGPRs per lane.
  - 32x32: 2 groups of 32 lanes. lanes%16 = M/N index within block, lanes//16 selects block. 16 output ACCVGPRs per lane.
  - 4x4: 16 groups of 4 lanes. 4 output ACCVGPRs per lane.

  Raises ValueError when the MxNxK dimensions cannot be parsed from the op name and
  RuntimeError for shapes other than 4x4, 16x16 and 32x32.
  """
  op_name = _op_name(inst)
  exec_mask = ctx.rexec()
  vdst_reg = ctx.inst_field(type(inst).vdst)
  src0_off = ctx.inst_field(type(inst).src0)
  src1_off = ctx.inst_field(type(inst).src1)
  src0_r = src0_off - _c(256)  # VGPR-relative index (only valid when src is VGPR)
  src1_r = src1_off - _c(256)
  src2_off = ctx.inst_field(type(inst).src2)
  # Check if sources are VGPRs (offset >= 256) vs inline constants/SGPRs
  src0_is_vgpr = src0_off >= _c(256)
  src1_is_vgpr = src1_off >= _c(256)

  m = re.search(r'(\d+)X(\d+)X(\d+)', op_name)
  if m is None: raise ValueError(f"could not parse MFMA dimensions from {op_name}")
  M, N, K = int(m.group(1)), int(m.group(2)), int(m.group(3))

  is_bf16 = 'BF16' in op_name
  is_fp8 = 'FP8' in op_name or 'F8' in op_name
  is_i8 = 'I8' in op_name
  # Source type is the LAST type in the name: V_MFMA_F32_16X16X32_**F16** -> source is F16, not F32
  src_type = op_name.rsplit('_', 1)[-1]  # e.g. "F16", "BF16", "F32", "I8"
  is_f32_src = src_type == 'F32'
  is_int_out = 'I32' in op_name.split('_')[2]  # V_MFMA_I32_...

  # Elements packed per 32-bit VGPR for the source type
  if is_i8: vpg = 4
  elif is_f32_src: vpg = 1
  elif is_fp8: vpg = 4
  else: vpg = 2

  # For 16x16: grp_size=16, n_grps=4, out_per_lane=4
  # For 32x32: grp_size=32, n_grps=2, out_per_lane=16
  # For 4x4: grp_size=4, n_grps=16, out_per_lane=4
  if M == 16 and N == 16:
    grp_size, n_grps, out_per_lane = 16, 4, 4
  elif M == 32 and N == 32:
    grp_size, n_grps, out_per_lane = 32, 2, 16
  elif M == 4 and N == 4:
    grp_size, n_grps, out_per_lane = 4, 16, 4
  else:
    raise RuntimeError(f"unsupported MFMA shape {M}x{N}x{K}")

  # For 4x4: each group independently computes a 4x4 block. K is NOT split across groups.
  # For 16x16/32x32: K IS split across groups (each group has K/n_grps elements).
  k_per_grp = K if M == 4 else K // n_grps
  # Temp array size: for 4x4, store all 16 independent blocks; for others, store shared MxK/NxK
  n_a_elems = n_grps * M * K if M == 4 else M * K
  n_b_elems = n_grps * N * K if M == 4 else N * K

  # src2 can be VGPR (>=256) or inline constant/SGPR (<256)
  src2_is_vgpr = src2_off >= _c(256)
  src2_r = src2_off - _c(256)
  if is_int_out:
    acc_scalar = ctx.rsgpr_dyn(src2_off, src2_is_vgpr.ne(True)).cast(dtypes.int32)
  else:
    acc_scalar = ctx.rsgpr_dyn(src2_off, src2_is_vgpr.ne(True)).bitcast(dtypes.float32)

  # Phase 1: Read all A and B values from VGPRs into temp arrays.
  # Layout: tmp[0..n_a_elems-1] = A[m][k], tmp[n_a_elems..n_a_elems+n_b_elems-1] = B[n][k]
  # Within each group of lanes, lane%grp_sub gives M/N index, lane//grp_sub gives sub-block
  grp_sub = min(M, 16)  # lanes within group mapped to M/N dimension
  b_off = UOp.const(dtypes.int, n_a_elems)
  acc_dt = dtypes.int32 if is_int_out else dtypes.float32
  # Use uint32 temp array to prevent optimizer from eliminating f16→f32 bitcast chains.
  # The optimizer folds bitcast(uint32→float32) stores to float32 arrays, losing the conversion.
  tmp = UOp(Ops.DEFINE_LOCAL, dtypes.uint32.ptr(n_a_elems + n_b_elems, addrspace=AddrSpace.LOCAL), arg=(n_a_elems + n_b_elems,))

  def cvt_elem(raw: UOp, sub_idx: int) -> UOp:
    # Extract packed element sub_idx from a 32-bit register and convert it to the temp-array representation
    if is_i8:
      # Extract i8, sign-extend to i32
      byte_val = (raw >> UOp.const(dtypes.uint32, sub_idx * 8)) & UOp.const(dtypes.uint32, 0xFF)
      return (byte_val.cast(dtypes.int32) ^ UOp.const(dtypes.int32, 0x80)) - UOp.const(dtypes.int32, 0x80)
    elif is_f32_src:
      return raw  # already uint32 (f32 bit pattern)
    elif is_fp8:
      # NOTE(review): the raw byte is stored without decoding and is later bitcast to float32 in phase 2;
      # confirm whether an FP8→F32 conversion is intended here
      return ((raw >> UOp.const(dtypes.uint32, sub_idx * 8)) & UOp.const(dtypes.uint32, 0xFF)).cast(dtypes.uint32)
    elif is_bf16:
      # bf16→f32 bits: just shift left by 16 (bf16 is upper 16 bits of f32)
      return ((raw >> UOp.const(dtypes.uint32, sub_idx * 16)) & UOp.const(dtypes.uint32, 0xFFFF)) << UOp.const(dtypes.uint32, 16)
    else:
      # f16→f32: rebuild the f32 bit pattern field by field (sign, exponent, mantissa)
      h = (raw >> UOp.const(dtypes.uint32, sub_idx * 16)) & UOp.const(dtypes.uint32, 0xFFFF)
      sign = (h >> UOp.const(dtypes.uint32, 15)) & UOp.const(dtypes.uint32, 1)
      exp = (h >> UOp.const(dtypes.uint32, 10)) & UOp.const(dtypes.uint32, 0x1F)
      mant = h & UOp.const(dtypes.uint32, 0x3FF)
      is_zero = exp.eq(UOp.const(dtypes.uint32, 0))
      # An all-ones f16 exponent (Inf/NaN) must map to the all-ones f32 exponent; re-biasing it by
      # +112 like a normal number would turn Inf/NaN into a finite value
      is_inf_nan = exp.eq(UOp.const(dtypes.uint32, 0x1F))
      f32_exp = is_inf_nan.where(UOp.const(dtypes.uint32, 0xFF), exp + UOp.const(dtypes.uint32, 112))
      f32_bits = (sign << UOp.const(dtypes.uint32, 31)) | \
                 (f32_exp << UOp.const(dtypes.uint32, 23)) | \
                 (mant << UOp.const(dtypes.uint32, 13))
      # Exponent field 0 (zero and f16 denormals) is returned as +0.
      # Return uint32 (f32 bit pattern) — stored directly to uint32 temp array, bitcast to float on read
      return is_zero.where(UOp.const(dtypes.uint32, 0), f32_bits)

  read_lane = ctx.range()
  lane_in_grp = read_lane % UOp.const(dtypes.int, grp_size)
  grp_idx = read_lane // UOp.const(dtypes.int, grp_size)

  if M == 32:
    # 32x32: lane_in_grp%16 = sub-row/col (0-15), lane_in_grp//16 = block (0=rows 0-15, 1=rows 16-31)
    sub_mn = lane_in_grp % UOp.const(dtypes.int, 16)
    block_mn = lane_in_grp // UOp.const(dtypes.int, 16)
    mn_idx = block_mn * UOp.const(dtypes.int, 16) + sub_mn  # actual M/N index (0-31)
  else:
    mn_idx = lane_in_grp  # for 16x16 and 4x4

  read_stores = []
  for kl in range(k_per_grp):
    reg_idx, sub_idx = kl // vpg, kl % vpg
    # Read A/B sources. Use rsrc_dyn for inline constants/SGPRs (src_off < 256), rvgpr_dyn for VGPRs (src_off >= 256).
    a_raw = src0_is_vgpr.where(ctx.rvgpr_dyn(src0_r + _c(reg_idx), read_lane),
                               ctx.rsrc_dyn(src0_off, _c(0, dtypes.int), 32))
    a_val = cvt_elem(a_raw, sub_idx)
    if M == 4:
      a_idx = grp_idx * UOp.const(dtypes.int, M * K) + mn_idx * UOp.const(dtypes.int, K) + UOp.const(dtypes.int, kl)
    else:
      a_idx = mn_idx * UOp.const(dtypes.int, K) + grp_idx * UOp.const(dtypes.int, k_per_grp) + UOp.const(dtypes.int, kl)
    read_stores.append(tmp.index(a_idx).store(a_val))

    b_raw = src1_is_vgpr.where(ctx.rvgpr_dyn(src1_r + _c(reg_idx), read_lane),
                               ctx.rsrc_dyn(src1_off, _c(0, dtypes.int), 32))
    b_val = cvt_elem(b_raw, sub_idx)
    if M == 4:
      b_idx = b_off + grp_idx * UOp.const(dtypes.int, N * K) + mn_idx * UOp.const(dtypes.int, K) + UOp.const(dtypes.int, kl)
    else:
      b_idx = b_off + mn_idx * UOp.const(dtypes.int, K) + grp_idx * UOp.const(dtypes.int, k_per_grp) + UOp.const(dtypes.int, kl)
    read_stores.append(tmp.index(b_idx).store(b_val))

  read_phase = UOp.group(*read_stores).end(read_lane)

  # Phase 2: Compute dot products and write outputs.
  # Reads go through tmp2 so that they are ordered after every phase-1 store.
  tmp2 = tmp.after(read_phase)

  compute_lane = ctx.range()
  compute_stores = []

  if M == 32 and N == 32:
    # 32x32: each lane has 16 output ACCVGPRs.
    # Column for this lane: n = (lane%32)%16 + ((lane%32)//16)*16; half = lane//32
    c_lane_in_32 = compute_lane % UOp.const(dtypes.int, 32)
    c_sub = c_lane_in_32 % UOp.const(dtypes.int, 16)
    c_block = c_lane_in_32 // UOp.const(dtypes.int, 16)
    n_idx = c_block * UOp.const(dtypes.int, 16) + c_sub
    c_half = compute_lane // UOp.const(dtypes.int, 32)  # 0 or 1

    for out_reg in range(16):
      # Row for acc[out_reg]: m = half*16 + (out_reg//4)*4 + (out_reg%4)
      m_base = c_half * UOp.const(dtypes.int, 16) + UOp.const(dtypes.int, (out_reg // 4) * 4 + (out_reg % 4))

      acc_v = ctx.raccvgpr_dyn(src2_r + _c(out_reg), compute_lane, src2_is_vgpr)
      if is_int_out: acc_v = acc_v.cast(dtypes.int32)
      else: acc_v = acc_v.bitcast(dtypes.float32)
      acc = src2_is_vgpr.where(acc_v, acc_scalar)

      for k in range(K):
        a_val = tmp2.index(m_base * UOp.const(dtypes.int, K) + UOp.const(dtypes.int, k)).bitcast(acc_dt)
        b_val = tmp2.index(b_off + n_idx * UOp.const(dtypes.int, K) + UOp.const(dtypes.int, k)).bitcast(acc_dt)
        acc = acc + a_val * b_val

      if is_int_out:
        compute_stores.append(ctx.waccvgpr_dyn(vdst_reg + _c(out_reg), compute_lane, acc.cast(dtypes.uint32), exec_mask))
      else:
        compute_stores.append(ctx.waccvgpr_dyn(vdst_reg + _c(out_reg), compute_lane, acc.bitcast(dtypes.uint32), exec_mask))
  else:
    # 16x16 and 4x4: each lane computes out_per_lane outputs
    n_idx = compute_lane % UOp.const(dtypes.int, grp_sub)
    c_grp = compute_lane // UOp.const(dtypes.int, grp_sub)

    for out_reg in range(out_per_lane):
      acc_v = ctx.raccvgpr_dyn(src2_r + _c(out_reg), compute_lane, src2_is_vgpr)
      if is_int_out: acc_v = acc_v.cast(dtypes.int32)
      else: acc_v = acc_v.bitcast(dtypes.float32)
      acc = src2_is_vgpr.where(acc_v, acc_scalar)

      if M == 4:
        # 4x4: each group is independent. A/B indexed per-group.
        m_base = c_grp * UOp.const(dtypes.int, M * K) + UOp.const(dtypes.int, out_reg * K)
        for k in range(K):
          a_val = tmp2.index(m_base + UOp.const(dtypes.int, k)).bitcast(acc_dt)
          b_val = tmp2.index(b_off + c_grp * UOp.const(dtypes.int, N*K) + n_idx * UOp.const(dtypes.int, K)+UOp.const(dtypes.int, k)).bitcast(acc_dt)
          acc = acc + a_val * b_val
      else:
        # 16x16: K is split across groups. Shared MxK/NxK arrays.
        m_base = c_grp * UOp.const(dtypes.int, out_per_lane) + UOp.const(dtypes.int, out_reg)
        for k in range(K):
          a_val = tmp2.index(m_base * UOp.const(dtypes.int, K) + UOp.const(dtypes.int, k)).bitcast(acc_dt)
          b_val = tmp2.index(b_off + n_idx * UOp.const(dtypes.int, K) + UOp.const(dtypes.int, k)).bitcast(acc_dt)
          acc = acc + a_val * b_val

      if is_int_out:
        compute_stores.append(ctx.waccvgpr_dyn(vdst_reg + _c(out_reg), compute_lane, acc.cast(dtypes.uint32), exec_mask))
      else:
        compute_stores.append(ctx.waccvgpr_dyn(vdst_reg + _c(out_reg), compute_lane, acc.bitcast(dtypes.uint32), exec_mask))

  compute_phase = UOp.group(*compute_stores).end(compute_lane)
  return UOp.sink(read_phase, compute_phase, *ctx.inc_pc())
|
||
|
||
def _compile_wmma(inst: ir3.VOP3P | ir4.VOP3P | irc.VOP3P, ctx: _Ctx) -> UOp:
  """Compile a 16x16x16 F16/BF16 WMMA instruction as a fully unrolled matrix multiply-accumulate.

  A and B are read from VGPRs into flat Python lists of f32 UOps, D = A·Bᵀ-style dot products
  over k plus C are built per (row, col), and the results are stored back to VGPRs as either
  packed f16/bf16 bits or f32 bits depending on the op name. All lane indices are constants.
  """
  op_name = _op_name(inst)
  exec_mask = ctx.rexec()
  vdst_reg = ctx.inst_field(type(inst).vdst)
  # Source fields are rebased by 256 to obtain VGPR indices
  src0_r = ctx.inst_field(type(inst).src0) - _c(256)
  src1_r = ctx.inst_field(type(inst).src1) - _c(256)
  src2_r = ctx.inst_field(type(inst).src2) - _c(256)
  is_f16_output = 'F16_16X16X16_F16' in op_name or 'BF16_16X16X16_BF16' in op_name  # F16/BF16 output vs F32 output
  is_bf16 = 'BF16' in op_name
  cvt = _FUNCS['bf16_to_f32'] if is_bf16 else _FUNCS['f16_to_f32']
  is_rdna4 = isinstance(inst, ir4.VOP3P)
  # read 16x16 F16/BF16 matrix from VGPRs → flat f32 array[row*16+k]
  def read_f16_val(src, lane, vgpr, half):
    # half selects the upper (1) or lower (0) 16 bits of the 32-bit register value
    v = ctx.rvgpr_dyn(src + _c(vgpr), UOp.const(dtypes.int, lane))
    return cvt((v >> UOp.const(dtypes.uint32, 16)) if half else (v & UOp.const(dtypes.uint32, 0xFFFF)))

  # RDNA3: 16 lanes × 8 VGPRs × 2 halves, k maps linearly
  # RDNA4: 32 lanes × 4 VGPRs × 2 halves, k bits are scrambled (k[2] goes to lane bit 4)
  def read_f16_mat(src):
    # (row, k) → (lane, vgpr, half)
    def ab_map(i, k):
      elem, lane = ((k & 3) | ((k >> 1) & 4), i + ((k >> 2) & 1) * 16) if is_rdna4 else (k, i)
      return lane, elem // 2, elem % 2
    return [read_f16_val(src, *ab_map(row, k)) for row in range(16) for k in range(16)]
  mat_a, mat_b = read_f16_mat(src0_r), read_f16_mat(src1_r)
  # (row, col) -> (lane, vgpr)
  def d_map(m, n):
    lane_bit, vgpr = (m >> 3, m & 7) if is_rdna4 else (m & 1, m >> 1)
    return n + lane_bit * 16, vgpr
  if is_f16_output:
    # read accumulator C with f16 layout: for RDNA4, pairs of f32 vgprs pack into one f16 vgpr
    # for RDNA3, same layout as f32 but only lo 16 bits used
    mat_c = [read_f16_val(src2_r, *((lane, vgpr // 2, vgpr % 2) if is_rdna4 else (lane, vgpr, 0)))
             for m in range(16) for n in range(16) for lane, vgpr in [d_map(m, n)]]
    # D[r][c] = sum_k A[r][k] * B[c][k] + C[r][c], accumulated in f32
    mat_d = [sum(mat_a[r*16+k] * mat_b[c*16+k] for k in range(16)) + mat_c[r*16+c] for r in range(16) for c in range(16)]
    def f32_to_f16_bits(v: UOp) -> UOp: return v.cast(dtypes.half).bitcast(dtypes.uint16).cast(dtypes.uint32)
    # bf16 bits are taken as the upper 16 bits of the f32 pattern (no rounding step is applied here)
    def f32_to_bf16_bits(v: UOp) -> UOp: return (v.bitcast(dtypes.uint32) >> UOp.const(dtypes.uint32, 16)) & UOp.const(dtypes.uint32, 0xFFFF)
    out_cvt = f32_to_bf16_bits if is_bf16 else f32_to_f16_bits
    if is_rdna4:  # pack 2 f16 per VGPR: adjacent m values share (lane, vgpr) since vgpr=m&7, half=m&1
      stores = [ctx.wvgpr_dyn(vdst_reg + _c(d_map(m, n)[1] // 2), UOp.const(dtypes.int, d_map(m, n)[0]),
                              out_cvt(mat_d[m*16+n]) | (out_cvt(mat_d[(m+1)*16+n]) << UOp.const(dtypes.uint32, 16)), exec_mask)
                for n in range(16) for m in range(0, 16, 2)]
    else:  # (rdna3) 1 f16 per VGPR (lo half only)
      stores = [ctx.wvgpr_dyn(vdst_reg + _c(d_map(m, n)[1]), UOp.const(dtypes.int, d_map(m, n)[0]), out_cvt(mat_d[m*16+n]), exec_mask)
                for m in range(16) for n in range(16)]
  else:  # f32
    # f32 accumulator: one 32-bit value per (lane, vgpr) given by d_map
    mat_c = [ctx.rvgpr_dyn(src2_r + _c(d_map(m, n)[1]), UOp.const(dtypes.int, d_map(m, n)[0])).bitcast(dtypes.float32)
             for m in range(16) for n in range(16)]
    mat_d = [sum(mat_a[r*16+k] * mat_b[c*16+k] for k in range(16)) + mat_c[r*16+c] for r in range(16) for c in range(16)]
    stores = [ctx.wvgpr_dyn(vdst_reg + _c(d_map(m, n)[1]), UOp.const(dtypes.int, d_map(m, n)[0]), mat_d[m*16+n].bitcast(dtypes.uint32), exec_mask)
              for m in range(16) for n in range(16)]
  return UOp.sink(*stores, *ctx.inc_pc())
|
||
|
||
def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P | irc.VOP3P, ctx: _Ctx) -> UOp:
  """Compile a VOP3P-encoded instruction into a UOp sink.

  Dispatch order, as implemented below:
    1. WMMA 16x16x16 F16/BF16 ops are forwarded to _compile_wmma.
    2. MFMA ops on irc.VOP3P are forwarded to _compile_mfma.
    3. Ops whose src0/vdst operand type is an ACCVGPR type are handled as plain register copies.
    4. PK_MOV_B32 builds a 64-bit result from two opsel-selected dwords and writes vdst, vdst+1.
    5. Everything else builds a pcode source dict (packed F32, FMA_MIX/MAD_MIX, or generic packed 16-bit)
       and is handed to ctx.compile_vop_pcode.

  Args:
    inst: decoded VOP3P instruction (RDNA3, RDNA4 or CDNA flavour).
    ctx: per-instruction compile context providing register accessors and the lane range.
  Returns:
    The UOp produced by the selected path (a sink for the paths built here).
  """
  op_name = _op_name(inst)
  if 'WMMA' in op_name and ('16X16X16_F16' in op_name or '16X16X16_BF16' in op_name): return _compile_wmma(inst, ctx)
  if 'MFMA' in op_name and any(f'{s}X{s}X' in op_name for s in ('4', '16', '32')) and isinstance(inst, irc.VOP3P): return _compile_mfma(inst, ctx)

  # ACCVGPR_WRITE/READ/MOV: copies between VGPR and ACCVGPR register files
  # Detect by checking operand types for ACCVGPR involvement
  ops = inst.operands
  src0_is_acc = ops.get('src0', (None, None, None))[2] in (OpType.OPR_SRC_ACCVGPR, OpType.OPR_ACCVGPR)
  vdst_is_acc = ops.get('vdst', (None, None, None))[2] in (OpType.OPR_ACCVGPR,)
  if src0_is_acc or vdst_is_acc:
    lane = ctx.range()
    exec_mask = ctx.rexec()
    vdst_reg = ctx.inst_field(type(inst).vdst)
    src0_off = ctx.inst_field(type(inst).src0)
    if src0_is_acc and not vdst_is_acc:
      # v_accvgpr_read: VGPR[vdst] = ACCVGPR[src0]
      # the src0 field is rebased by 256 to get a register index
      val = ctx.raccvgpr_dyn(src0_off - _c(256), lane)
      return UOp.sink(ctx.wvgpr_dyn(vdst_reg, lane, val, exec_mask).end(lane), *ctx.inc_pc())
    elif vdst_is_acc and not src0_is_acc:
      # v_accvgpr_write: ACCVGPR[vdst] = src0 (src0 can be VGPR or SGPR/const)
      src0 = ctx.rsrc_dyn(src0_off, lane, 32)
      return UOp.sink(ctx.waccvgpr_dyn(vdst_reg, lane, src0, exec_mask).end(lane), *ctx.inc_pc())
    else:
      # v_accvgpr_mov: ACCVGPR[vdst] = ACCVGPR[src0]
      val = ctx.raccvgpr_dyn(src0_off - _c(256), lane)
      return UOp.sink(ctx.waccvgpr_dyn(vdst_reg, lane, val, exec_mask).end(lane), *ctx.inc_pc())

  lane = ctx.range()
  exec_mask = ctx.rexec()
  vdst_reg = ctx.inst_field(type(inst).vdst)
  is_pk_f32 = 'PK' in op_name and 'F32' in op_name and 'MOV' not in op_name  # CDNA packed F32 ops
  is_pk_mov_b32 = 'PK_MOV_B32' in op_name  # CDNA packed MOV needs special handling
  # do_cast is forwarded to rsrc_dyn for float-named ops, excluding IU ops and packed F32
  do_cast = any(x in op_name for x in ('F16', 'F32', 'BF16')) and 'IU' not in op_name and not is_pk_f32
  literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None  # type: ignore[union-attr]
  src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, 16, literal=literal, do_cast=do_cast)
  src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, 16, literal=literal, do_cast=do_cast)
  src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, 16, literal=literal, do_cast=do_cast)
  # Modifier fields may be absent or None on some encodings; fall back to opsel=0, opsel_hi=3, opsel_hi2=1, neg=neg_hi=0
  opsel, opsel_hi = getattr(inst, 'opsel', 0) or 0, getattr(inst, 'opsel_hi', 3) if getattr(inst, 'opsel_hi', 3) is not None else 3
  opsel_hi2 = getattr(inst, 'opsel_hi2', 1) if getattr(inst, 'opsel_hi2', 1) is not None else 1
  neg, neg_hi = getattr(inst, 'neg', 0) or 0, getattr(inst, 'neg_hi', 0) or 0

  if is_pk_mov_b32:
    # v_pk_mov_b32: D[lo] = src0[opsel_bit0 ? hi : lo], D[hi] = src1[opsel_bit1 ? hi : lo]
    src_offs = [ctx.inst_field(type(inst).src0), ctx.inst_field(type(inst).src1)]
    def _pk_mov_sel(src_lo: UOp, src_off: UOp, sel_bit: int) -> UOp:
      # Pick the dword for one output half: VGPR pair if src_off >= 256, SGPR pair if src_off < 128,
      # otherwise the already-read src_lo value is used for either selection.
      is_vgpr = src_off >= _c(256)
      vgpr_lo = ctx.rvgpr_dyn(src_off - _c(256), lane) if lane is not None else _c(0)
      vgpr_hi = ctx.rvgpr_dyn(src_off - _c(256) + _c(1), lane) if lane is not None else _c(0)
      is_sgpr_pair = src_off < _c(128)
      sgpr_hi = ctx.rsgpr_dyn(src_off + _c(1), is_sgpr_pair)
      scalar_sel = is_sgpr_pair.where(sgpr_hi, src_lo) if sel_bit else src_lo
      return is_vgpr.where(vgpr_hi if sel_bit else vgpr_lo, scalar_sel)
    lo_val = _pk_mov_sel(src0, src_offs[0], opsel & 1)
    hi_val = _pk_mov_sel(src1, src_offs[1], opsel & 2)
    result = _u64(lo_val, hi_val)
    lo_out, hi_out = _split64(result)
    stores = [ctx.wvgpr_dyn(vdst_reg, lane, lo_out, exec_mask), ctx.wvgpr_dyn(vdst_reg + _c(1), lane, hi_out, exec_mask)]
    return UOp.sink(UOp.group(*stores).end(lane), *ctx.inc_pc())

  srcs: dict[str, UOp | int] = {}
  if is_pk_f32:
    # CDNA packed F32: read 32-bit sources, build 64-bit packed values using opsel.
    # For VGPRs: opsel selects between v[reg] (0) and v[reg+1] (1) for each half.
    # For SGPR pairs (off < 128): s[N] = lo float32, s[N+1] = hi float32.
    # For inline constants (128 <= off < 256): broadcast same value to both halves.
    src_offs = [ctx.inst_field(type(inst).src0), ctx.inst_field(type(inst).src1), ctx.inst_field(type(inst).src2)]
    def build_pk_f32(src_lo: UOp, src_off: UOp, opsel_lo: int, opsel_hi_bit: int, neg_lo: int, neg_hi_bit: int) -> UOp:
      # Returns a 64-bit value (lo dword, hi dword); the *_bit arguments are used only for truthiness.
      is_vgpr = src_off >= _c(256)
      vgpr_lo = ctx.rvgpr_dyn(src_off - _c(256), lane) if lane is not None else _c(0)
      vgpr_hi = ctx.rvgpr_dyn(src_off - _c(256) + _c(1), lane) if lane is not None else _c(0)
      # For SGPR pairs, opsel selects between s[N] (0) and s[N+1] (1); inline constants always broadcast.
      is_sgpr_pair = src_off < _c(128)
      sgpr_hi = ctx.rsgpr_dyn(src_off + _c(1), is_sgpr_pair)
      scalar_lo_sel = src_lo if not opsel_lo else is_sgpr_pair.where(sgpr_hi, src_lo)
      scalar_hi_sel = src_lo if not opsel_hi_bit else is_sgpr_pair.where(sgpr_hi, src_lo)
      lo = is_vgpr.where(vgpr_hi if opsel_lo else vgpr_lo, scalar_lo_sel)
      hi = is_vgpr.where(vgpr_hi if opsel_hi_bit else vgpr_lo, scalar_hi_sel)
      # negation flips the float32 sign bit of the selected dword
      if neg_lo: lo = lo ^ UOp.const(dtypes.uint32, 0x80000000)
      if neg_hi_bit: hi = hi ^ UOp.const(dtypes.uint32, 0x80000000)
      return _u64(lo, hi)
    srcs = {'S0': build_pk_f32(src0, src_offs[0], opsel & 1, opsel_hi & 1, neg & 1, neg_hi & 1),
            'S1': build_pk_f32(src1, src_offs[1], opsel & 2, opsel_hi & 2, neg & 2, neg_hi & 2),
            'S2': build_pk_f32(src2, src_offs[2], opsel & 4, 1 if opsel_hi2 else 0, neg & 4, neg_hi & 4)}
  elif 'FMA_MIX' in op_name or 'MAD_MIX' in op_name:
    # bits 0-1 come from opsel_hi, bit 2 from opsel_hi2
    combined_opsel_hi = (opsel_hi & 0x3) | ((opsel_hi2 & 0x1) << 2)
    # For FMA_MIX: neg_hi is ABS (not neg!), neg is actual negation
    def apply_abs(v, bit, opsel_hi_bit, opsel_bit):
      if not (neg_hi & bit): return v
      # Apply abs based on whether source is f32 or f16
      if not (combined_opsel_hi & opsel_hi_bit): return v & UOp.const(dtypes.uint32, 0x7FFFFFFF)  # f32 abs
      if opsel & opsel_bit: return v & UOp.const(dtypes.uint32, 0x7FFF0000)  # f16 hi abs: clears bit 31 (this mask also zeroes the lo 16 bits)
      return v & UOp.const(dtypes.uint32, 0xFFFF7FFF)  # f16 lo abs (preserve hi)
    def apply_neg_mix(v, bit, opsel_hi_bit, opsel_bit):
      if not (neg & bit): return v
      if not (combined_opsel_hi & opsel_hi_bit): return v ^ UOp.const(dtypes.uint32, 0x80000000)  # f32 neg
      if opsel & opsel_bit: return v ^ UOp.const(dtypes.uint32, 0x80000000)  # f16 hi neg
      return v ^ UOp.const(dtypes.uint32, 0x00008000)  # f16 lo neg
    # abs is applied before neg for each source
    s0_mod = apply_neg_mix(apply_abs(src0, 1, 1, 1), 1, 1, 1)
    s1_mod = apply_neg_mix(apply_abs(src1, 2, 2, 2), 2, 2, 2)
    s2_mod = apply_neg_mix(apply_abs(src2, 4, 4, 4), 4, 4, 4)
    srcs = {'S@0': s0_mod, 'S@1': s1_mod, 'S@2': s2_mod,
            'OPSEL_HI': UOp.const(dtypes.uint32, combined_opsel_hi), 'OPSEL': UOp.const(dtypes.uint32, opsel)}
  else:
    def get_half_bits(val: UOp, use_hi: bool, apply_neg: bool = False) -> UOp:
      # Extract one 16-bit half as a uint32; optional negation goes through a half-precision neg() round trip.
      bits = ((val >> UOp.const(dtypes.uint32, 16)) if use_hi else val) & UOp.const(dtypes.uint32, 0xFFFF)
      if apply_neg: bits = bits.cast(dtypes.uint16).bitcast(dtypes.half).neg().bitcast(dtypes.uint16).cast(dtypes.uint32)
      return bits
    def build_remapped_src(src: UOp, opsel_lo_bit: int, opsel_hi_bit: int, neg_lo_bit: int, neg_hi_bit: int) -> UOp:
      # Repack the source so the opsel-selected halves end up in the lo/hi positions.
      lo = get_half_bits(src, bool(opsel_lo_bit), bool(neg_lo_bit))
      hi = get_half_bits(src, bool(opsel_hi_bit), bool(neg_hi_bit))
      return lo | (hi << UOp.const(dtypes.uint32, 16))
    # DOT IU instructions use NEG bits for signed/unsigned selection, not fp16 negation
    is_dot_iu = 'DOT' in op_name and 'IU' in op_name
    n0, n1, n2, nh0, nh1, nh2 = (0, 0, 0, 0, 0, 0) if is_dot_iu else (neg & 1, neg & 2, neg & 4, neg_hi & 1, neg_hi & 2, neg_hi & 4)
    srcs = {'S0': build_remapped_src(src0, opsel & 1, opsel_hi & 1, n0, nh0),
            'S1': build_remapped_src(src1, opsel & 2, opsel_hi & 2, n1, nh1),
            'S2': build_remapped_src(src2, opsel & 4, 1 if opsel_hi2 else 0, n2, nh2)}
    if is_dot_iu: srcs['NEG'] = UOp.const(dtypes.uint32, neg)
  return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask)
|
||
|
||
def _compile_vopd(inst: ir3.VOPD | ir4.VOPD, ctx: _Ctx) -> UOp:
  """Compile a VOPD (dual-issue) instruction into a single UOp sink.

  The X and Y halves are each translated to a VOP2 opcode through VOPD_TO_VOP2, their pcode is parsed
  inside one shared lane range, and every `D0` assignment becomes a VGPR write.

  Args:
    inst: decoded VOPD instruction (RDNA3 or RDNA4).
    ctx: per-instruction compile context.
  Returns:
    A sink containing the grouped VGPR stores (ended on the lane range) and the PC increment.
  Raises:
    AssertionError: if an opcode has no VOPD_TO_VOP2 mapping, or a FMAAK/FMAMK half has no literal field.
  """
  exec_mask = ctx.rexec()
  # Read operands dynamically - use type(inst) to get correct field descriptors
  inst_type = type(inst)
  vdstx_reg = ctx.inst_field(inst_type.vdstx)
  # vdsty has complex encoding: actual = (raw << 1) | ((vdstx & 1) ^ 1)
  vdsty_raw = ctx.inst_field(inst_type.vdsty)
  vdsty_reg = (vdsty_raw << _c(1)) | ((vdstx_reg & _c(1)) ^ _c(1))
  srcx0_off = ctx.inst_field(inst_type.srcx0)
  srcy0_off = ctx.inst_field(inst_type.srcy0)
  vsrcx1_reg = ctx.inst_field(inst_type.vsrcx1)
  vsrcy1_reg = ctx.inst_field(inst_type.vsrcy1)
  literal = ctx.inst_field(inst_type.literal) if hasattr(inst_type, 'literal') else None

  # FMAAK/FMAMK take their constant from the instruction literal (SIMM32 in pcode).
  # The original tuple listed the ir3 members twice; check the RDNA3 and RDNA4 enums alike, the same way the
  # CNDMASK check below does. getattr keeps this safe if an ISA's VOP2Op lacks the member (None never matches).
  fma_literal_ops = tuple(getattr(isa.VOP2Op, name, None) for isa in (ir3, ir4) for name in ('V_FMAAK_F32_E32', 'V_FMAMK_F32_E32'))

  lane = ctx.range()
  # Y sources are read up front, before any destination write is built
  srcy0, srcy1 = ctx.rsrc_dyn(srcy0_off, lane, literal=literal), ctx.rvgpr_dyn(vsrcy1_reg, lane)
  all_stores = []
  srcs:dict[str, UOp | int] = {}
  for op, src0_off, vsrc1_reg, vdst_reg, label in [(inst.opx, srcx0_off, vsrcx1_reg, vdstx_reg, 'X'),
                                                     (inst.opy, srcy0_off, vsrcy1_reg, vdsty_reg, 'Y')]:
    vop = VOPD_TO_VOP2.get(op)
    assert vop is not None, f"no VOP mapping for VOPD {label}: {op}"
    if label == 'Y': srcs = {'S0': srcy0, 'S1': srcy1, 'D0': ctx.rvgpr_dyn(vdst_reg, lane)}
    else: srcs = {'S0': ctx.rsrc_dyn(src0_off, lane, literal=literal), 'S1': ctx.rvgpr_dyn(vsrc1_reg, lane), 'D0': ctx.rvgpr_dyn(vdst_reg, lane)}
    # V_FMAAK_F32 / V_FMAMK_F32 (either ISA) need the 32-bit literal as SIMM32
    if vop in fma_literal_ops:
      assert literal is not None
      srcs['SIMM32'] = literal
    if op in (ir3.VOPDOp.V_DUAL_CNDMASK_B32, ir4.VOPDOp.V_DUAL_CNDMASK_B32): srcs['VCC'] = ctx.rmask(_c(VCC_LO.offset))
    pcode = get_pcode(vop)
    srcs.update({'VCC': ctx.rmask(_c(VCC_LO.offset)), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane})
    for dest, val in parse_pcode(pcode, srcs)[1]:
      # NOTE(review): after=srcy1 presumably orders each write after the Y-source read - confirm against wvgpr_dyn
      if dest.startswith('D0'): all_stores.append(ctx.wvgpr_dyn(vdst_reg, lane, _val_to_u32(val), exec_mask, after=srcy1))
  return UOp.sink(UOp.group(*all_stores).end(lane), *ctx.inc_pc())
|
||
|
||
def _compile_mem_op(inst: ir3.DS|ir3.FLAT|ir3.GLOBAL|ir3.SCRATCH|ir4.DS|ir4.VFLAT|ir4.VGLOBAL|ir4.VSCRATCH
                    |irc.DS|irc.FLAT|irc.GLOBAL|irc.SCRATCH, ctx: _Ctx) -> UOp:
  """Unified memory operation compiler for DS, FLAT, GLOBAL, SCRATCH.

  The instruction's pcode is parsed with a source dict built per lane (make_srcs); every resulting
  assignment is turned into memory or register stores (make_stores). DS instructions that write several
  distinct MEM addresses (or are named 2ADDR, except STOREXCHG) get one lane range per MEM store plus a
  single shared range for all RETURN_DATA writes; everything else uses one lane range.

  Args:
    inst: decoded DS/FLAT/GLOBAL/SCRATCH instruction (RDNA3, RDNA4 or CDNA).
    ctx: per-instruction compile context.
  Returns:
    A sink with the stores and the PC increment.
  """
  exec_mask, op_name = ctx.rexec(), _op_name(inst)
  pcode = get_pcode(inst.op)
  # CDNA pcode uses CalcGlobalAddr/CalcDsAddr to compute address from raw components, but make_addr already handles this.
  # Strip the addr computation line and use pre-computed ADDR directly (rename 'addr' -> 'ADDR' in remaining pcode).
  if isinstance(inst, (irc.GLOBAL, irc.FLAT, irc.SCRATCH, irc.DS, ir4.VSCRATCH)) and 'Calc' in pcode and 'Addr' in pcode:
    pcode = re.sub(r'addr\s*=\s*Calc\w+Addr\([^)]*\)\s*;?\n?', '', pcode).replace('MEM[addr', 'MEM[ADDR')

  is_lds = isinstance(inst, (ir3.DS, ir4.DS, irc.DS))
  is_scratch = isinstance(inst, (ir3.SCRATCH, ir4.VSCRATCH, irc.SCRATCH))
  # CDNA acc bit: when set, VGPR operands (vdst/vdata) target ACCVGPR file instead of VGPR
  use_acc = bool(getattr(inst, 'acc', 0))
  mem = ctx.lds if is_lds else ctx.scratch if is_scratch else ctx.vmem
  # byte address -> dword index shift; uint32 for LDS addresses, uint64 otherwise
  addr_shift = UOp.const(dtypes.uint32 if is_lds else dtypes.uint64, 2)

  # Extract register info - all dynamic for deduplication
  if is_lds:
    addr_reg = ctx.inst_field(type(inst).addr)  # type: ignore[union-attr]
    vdata_reg = ctx.inst_field(type(inst).data0)  # type: ignore[union-attr]
    vdst_reg = ctx.inst_field(type(inst).vdst)
    offset0 = ctx.inst_field(type(inst).offset0)  # type: ignore[union-attr]
    offset1 = ctx.inst_field(type(inst).offset1)  # type: ignore[union-attr]
    offset = (offset1 << _c(8)) | offset0  # DS offset is 16-bit: (offset1 << 8) | offset0
    saddr_reg = None
  elif isinstance(inst, (ir4.VGLOBAL, ir4.VSCRATCH, ir4.VFLAT)):  # RDNA4: vaddr, vsrc, ioffset
    addr_reg = ctx.inst_field(type(inst).vaddr)
    vdata_reg = ctx.inst_field(type(inst).vsrc)
    vdst_reg = ctx.inst_field(type(inst).vdst)
    offset = ctx.inst_field_signed(type(inst).ioffset)
    offset0, offset1 = _c(0), _c(0)
    saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(type(inst), 'saddr') else None
  else:  # RDNA3: addr, data, offset
    addr_reg = ctx.inst_field(type(inst).addr)  # type: ignore[union-attr]
    vdata_reg = ctx.inst_field(type(inst).data)  # type: ignore[union-attr]
    vdst_reg = ctx.inst_field(type(inst).vdst)
    offset = ctx.inst_field_signed(type(inst).offset)  # type: ignore[union-attr]
    offset0, offset1 = _c(0), _c(0)
    saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(type(inst), 'saddr') else None  # type: ignore[union-attr]

  # Data width from canonical_op_bits (32/64/96/128), default to 32 for untyped ops
  data_bits_mem = inst.canonical_op_bits.get('data', 32)
  is_atomic, glc = 'ATOMIC' in op_name, getattr(inst, 'glc', 0)
  has_data1 = is_lds and hasattr(inst, 'data1') and inst.data1 is not None
  data1_reg = ctx.inst_field(type(inst).data1) if is_lds else _c(0)  # type: ignore[union-attr]

  # DS_PERMUTE/DS_BPERMUTE: cross-lane VGPR access via pcode
  if is_lds and 'PERMUTE' in op_name:
    # NOTE(review): this re-fetches the raw pcode, discarding the Calc*Addr stripping done above - confirm intended
    pcode = get_pcode(inst.op)
    srcs = {'ADDR': addr_reg, 'DATA0': vdata_reg, 'VDST': vdst_reg, 'OFFSET': offset,
            'EXEC': exec_mask.cast(dtypes.uint64), '_vgpr': ctx.vgpr, '_wave_size': ctx.wave_size}
    _, assigns = parse_pcode(pcode, srcs)
    # each VGPR[...] assignment carries (index, value); store it straight into the VGPR buffer
    stores = [ctx.vgpr.index(val[0].cast(dtypes.int)).store(val[1].cast(dtypes.uint32)) for dest, val in assigns if dest.startswith('VGPR[')]
    return UOp.sink(*stores, *ctx.inc_pc())

  def make_addr(lane: UOp) -> UOp:
    """Build the per-lane byte address: LDS address, scratch offset, or FLAT/GLOBAL 64-bit address."""
    if is_lds:
      addr = ctx.rvgpr_dyn(addr_reg, lane)
      # Some DS pcode (e.g. DS_STORE_B16) uses MEM[ADDR] without adding OFFSET explicitly.
      # In those cases, add the instruction offset to ADDR here.
      if 'OFFSET' not in pcode: addr = addr + offset
      return addr
    offset64 = offset.cast(dtypes.uint64)
    # Dynamic saddr check: saddr < 124 means valid SGPR, otherwise use VGPR pair for address
    use_saddr = (saddr_reg < _c(124)) if saddr_reg is not None else UOp.const(dtypes.bool, False)
    if is_scratch:
      # per-lane scratch region starts at lane * stride (stride read from the SCRATCH_STRIDE_IDX sgpr slot)
      scratch_stride = ctx.rsgpr_dyn(_c(SCRATCH_STRIDE_IDX)).cast(dtypes.uint64)
      base = lane.cast(dtypes.uint64) * scratch_stride
      # SVE (Scratch VGPR Enable): when SVE=1, VADDR is used as offset; when SVE=0, VADDR is ignored
      sve = getattr(inst, 'sve', 0)
      vaddr = ctx.rvgpr_dyn(addr_reg, lane).cast(dtypes.uint64)
      addr_offset = vaddr if sve == 1 else UOp.const(dtypes.uint64, 0)
      # Add saddr value only if use_saddr is true (saddr < 124)
      saddr_contrib = use_saddr.where(ctx.rsgpr_dyn(saddr_reg).cast(dtypes.uint64), UOp.const(dtypes.uint64, 0)) \
        if saddr_reg is not None else UOp.const(dtypes.uint64, 0)
      return base + addr_offset + saddr_contrib + offset64
    # FLAT/GLOBAL: choose between SGPR base (saddr) or VGPR pair (addr) based on saddr validity
    saddr_base = _u64(ctx.rsgpr_dyn(saddr_reg), ctx.rsgpr_dyn(saddr_reg + _c(1))) if saddr_reg is not None else UOp.const(dtypes.uint64, 0)
    vaddr_base = _u64(ctx.rvgpr_dyn(addr_reg, lane), ctx.rvgpr_dyn(addr_reg + _c(1), lane))
    # When saddr is valid: base = saddr pair, vaddr is 32-bit offset; otherwise: base = 0, vaddr is 64-bit address
    base_addr = use_saddr.where(saddr_base + ctx.rvgpr_dyn(addr_reg, lane).cast(dtypes.uint64), vaddr_base)
    return base_addr + offset64

  def wmem(addr: UOp, val: UOp, active: UOp, data_bits: int = 32) -> UOp:
    """Store `val` at byte address `addr` in `mem` for active lanes; inactive lanes write back the old word."""
    if data_bits < 32:
      # Sub-dword LDS write: read-modify-write within the uint32 slot
      word_addr = (addr >> addr_shift).cast(dtypes.int)
      idx = mem.index(word_addr, active)
      byte_pos = addr.cast(dtypes.uint32) & _c(3)
      byte_shift = byte_pos * _c(8)
      size_mask = _c(0xFF if data_bits == 8 else 0xFFFF)
      mask = size_mask << byte_shift
      # clear the target byte/halfword, then OR in the new bits at the same position
      new_word = (idx & (mask ^ _c(0xFFFFFFFF))) | ((val.cast(dtypes.uint32) & size_mask) << byte_shift)
      return idx.store(active.where(new_word, idx))
    idx = mem.index((addr >> addr_shift).cast(dtypes.int))
    return idx.store(active.where(val, idx.load()))

  def make_srcs(lane: UOp) -> dict:
    """Build the pcode source dict for one lane (different key sets for LDS, atomics and plain load/store)."""
    addr = make_addr(lane)
    if is_lds:
      if data_bits_mem == 128:
        data = {'DATA': ctx.rvgpr_dyn(vdata_reg, lane), 'DATA1': ctx.rvgpr_dyn(vdata_reg + _c(1), lane),
                'DATA2': ctx.rvgpr_dyn(vdata_reg + _c(2), lane), 'DATA3': ctx.rvgpr_dyn(vdata_reg + _c(3), lane)}
      elif data_bits_mem == 96:
        data = {'DATA': ctx.rvgpr_dyn(vdata_reg, lane), 'DATA1': ctx.rvgpr_dyn(vdata_reg + _c(1), lane),
                'DATA2': ctx.rvgpr_dyn(vdata_reg + _c(2), lane)}
      elif data_bits_mem <= 32:
        data = {'DATA': ctx.rvgpr_dyn(vdata_reg, lane), 'DATA2': ctx.rvgpr_dyn(data1_reg, lane) if has_data1 else UOp.const(dtypes.uint32, 0)}
      else:
        # remaining widths: DATA/DATA2 are 64-bit values built from register pairs
        data = {'DATA': _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)),
                'DATA2': _u64(ctx.rvgpr_dyn(data1_reg, lane), ctx.rvgpr_dyn(data1_reg + _c(1), lane)) if has_data1 else UOp.const(dtypes.uint64, 0)}
      # RDNA3 uses ADDR/OFFSET, RDNA4 uses vgpr_a/offset (lowercase) + CalcDsAddr function
      return {'ADDR': addr, 'ADDR_BASE': addr, 'OFFSET': offset, 'OFFSET0': offset0, 'OFFSET1': offset1, '_lds': mem, 'laneId': lane,
              'vgpr_a': ctx.rvgpr_dyn(addr_reg, lane), 'offset': offset, 'offset0': offset0, 'offset1': offset1, **data}
    active = _lane_active(exec_mask, lane)
    # saddr < 124 means valid SGPR pair, otherwise use 0 (NULL means no saddr contribution)
    use_saddr = (saddr_reg < _c(124)) if saddr_reg is not None else UOp.const(dtypes.bool, False)
    saddr_raw = _u64(ctx.rsgpr_dyn(saddr_reg), ctx.rsgpr_dyn(saddr_reg + _c(1))) if saddr_reg is not None else UOp.const(dtypes.uint64, 0)
    saddr_base = use_saddr.where(saddr_raw, UOp.const(dtypes.uint64, 0))
    # Sign-extend offset to 64-bit for the final address calculation
    ioffset64 = offset.cast(dtypes.int64).cast(dtypes.uint64)
    # v_addr for CalcGlobalAddr: when saddr valid, use low 32 bits as offset; otherwise full 64-bit address. Include ioffset.
    vaddr_full = _u64(ctx.rvgpr_dyn(addr_reg, lane), ctx.rvgpr_dyn(addr_reg + _c(1), lane))
    vaddr_lo = ctx.rvgpr_dyn(addr_reg, lane).cast(dtypes.uint64)
    vaddr_base = use_saddr.where(vaddr_lo + ioffset64, vaddr_full + ioffset64)
    if is_atomic:
      atomic_data = _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)) \
        if data_bits_mem == 64 else ctx.rvgpr_dyn(vdata_reg, lane)
      return {'ADDR': addr, 'DATA': atomic_data, '_vmem': mem, '_active': active,
              'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base}
    # acc bit: read/write ACCVGPR instead of VGPR for data operands
    _rvdata = (lambda r, l, *a: ctx.raccvgpr_dyn(r, l)) if use_acc else ctx.rvgpr_dyn
    # VDATA: store data for STORE ops, the current vdst value for D16 ops, zero otherwise
    vdata = _rvdata(vdata_reg, lane).cast(dtypes.uint64) if 'STORE' in op_name \
      else _rvdata(vdst_reg, lane) if 'D16' in op_name else UOp.const(dtypes.uint32, 0)
    if 'STORE' in op_name and data_bits_mem >= 64:
      vdata = vdata | (_rvdata(vdata_reg + _c(1), lane).cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
    srcs = {'ADDR': addr, 'VDATA': vdata, '_vmem': mem, '_active': active,
            'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base, 'SADDR': saddr_base, 'OFFSET': offset}
    # per-dword views VDATA0..VDATA{n-1}
    for i in range(data_bits_mem // 32):
      srcs[f'VDATA{i}'] = _rvdata(vdata_reg + _c(i), lane) if 'STORE' in op_name else UOp.const(dtypes.uint32, 0)
    return srcs

  def make_stores(dest: str, val: UOp, lane: UOp, active: UOp, writes_return_data: bool) -> list[UOp]:
    """Turn one pcode assignment (dest name, value) into store UOps; returns [] for destinations not handled here."""
    # Parse bit width from dest format: MEM[...].b32 or RETURN_DATA[63:32].b64
    parts = dest.rsplit('.', 1)
    data_bits = int(parts[1][1:]) if len(parts) == 2 else 32
    if dest.startswith('MEM['):
      # for MEM destinations val is indexable: val[0] is the address, val[1] the data
      if is_lds or is_atomic:
        if data_bits < 32 and is_lds: return [wmem(val[0], val[1], active, data_bits)]
        return _write_val(data_bits, val[1], wmem, val[0], active, is_mem=True)
      if is_scratch: return _mem_store_bytes(mem, val[0], val[1], active, data_bits)
      return _mem_store(mem, val[0], val[1], active, 64, data_bits)
    if dest.startswith('RETURN_DATA') and writes_return_data:
      _wdata = (lambda r, v, l, e: ctx.waccvgpr_dyn(r, l, v, e)) if use_acc else (lambda r, v, l, e: ctx.wvgpr_dyn(r, l, v, e))
      if (m := re.match(r'RETURN_DATA\[(\d+)\s*:\s*(\d+)\]', dest)):
        # RETURN_DATA[hi:lo] slice: width = hi - lo + 1, destination dword = lo // 32
        bit_width, dword_idx = int(m.group(1)) - int(m.group(2)) + 1, int(m.group(2)) // 32
        return _write_val(bit_width, val, _wdata, vdst_reg + _c(dword_idx), lane, exec_mask)
      return _write_val(data_bits, val, _wdata, vdst_reg, lane, exec_mask)
    return []

  # DS-specific: check for 2ADDR pattern needing separate ranges
  if is_lds:
    # parse once with a throwaway lane range just to inspect which destinations the pcode assigns
    dummy_lane = ctx.range()
    _, assigns = parse_pcode(pcode, make_srcs(dummy_lane))
    mem_assigns = [d for d, _ in assigns if d.startswith('MEM[')]
    mem_addrs = set(m.group(1) if (m := re.match(r'MEM\[([^\]]+)\]', d)) else d for d in mem_assigns)
    use_separate_ranges = (len(mem_addrs) > 1 or '2ADDR' in op_name) and 'STOREXCHG' not in op_name
    if use_separate_ranges:
      # Split assigns into MEM writes (stores) and RETURN_DATA writes (loads).
      # Stores to different addresses need separate lane ranges. Loads must share a single lane range so the
      # addr vgpr is read before any vdst write (hardware reads addr once, then writes all results).
      store_assigns = [(i, d) for i, (d, _) in enumerate(assigns) if d.startswith('MEM[')]
      load_assigns = [(i, d) for i, (d, _) in enumerate(assigns) if d.startswith('RETURN_DATA')]
      ended: list[UOp] = []
      for i, dest in store_assigns:
        # re-parse the pcode against a fresh lane range and pick the i-th assignment
        lane = ctx.range()
        active = _lane_active(exec_mask, lane)
        _, lane_assigns = parse_pcode(pcode, make_srcs(lane))
        ended.extend(s.end(lane) for s in make_stores(dest, lane_assigns[i][1], lane, active, True))
      if load_assigns:
        lane = ctx.range()
        active = _lane_active(exec_mask, lane)
        _, lane_assigns = parse_pcode(pcode, make_srcs(lane))
        load_stores: list[UOp] = []
        for i, dest in load_assigns:
          load_stores.extend(make_stores(dest, lane_assigns[i][1], lane, active, True))
        if load_stores: ended.append(UOp.group(*load_stores).end(lane))
      return UOp.sink(*ended, *ctx.inc_pc())

  # Standard path: single lane range
  writes_return_data = '_RTN' in op_name or (is_lds and (op_name.startswith('DS_LOAD') or op_name.startswith('DS_READ'))) or bool(is_atomic and glc)
  lane = ctx.range()
  active = _lane_active(exec_mask, lane)
  pcode_vars, assigns = parse_pcode(pcode, make_srcs(lane))
  stores = [s for dest, val in assigns for s in make_stores(dest, val, lane, active, writes_return_data)]

  # FLAT/GLOBAL/SCRATCH: collect VDATA slices for loads
  if not is_lds and not is_atomic:
    _wdst = ctx.waccvgpr_dyn if use_acc else ctx.wvgpr_dyn
    for dword_idx, val in sorted(_collect_data_slices(assigns, 'VDATA', pcode_vars, op_name).items()):
      stores.append(_wdst(vdst_reg + _c(dword_idx), lane, val, exec_mask))

  return UOp.sink(UOp.group(*stores).end(lane), *ctx.inc_pc())
|
||
|
||
def _compile_mubuf(inst: irc.MUBUF, ctx: _Ctx) -> UOp:
  """CDNA MUBUF: linear buffer address = base + soffset + (stride * index) + vgpr_offset + inst_offset

  Three paths are generated depending on the opcode name and the `lds` bit:
    * lds set and not a store: buffer dwords are copied into LDS at M0[17:0] + lane * element size.
    * store: VGPR (or ACCVGPR when `acc` is set) dwords are written to the buffer.
    * otherwise: buffer dwords are loaded into VGPR/ACCVGPR; out-of-bounds lanes receive 0.

  Args:
    inst: decoded CDNA MUBUF instruction.
    ctx: per-instruction compile context.
  Returns:
    A sink with the grouped stores (ended on the lane range) and the PC increment.
  """
  exec_mask, op_name = ctx.rexec(), _op_name(inst)
  use_acc, is_store, is_lds = bool(getattr(inst, 'acc', 0)), 'STORE' in op_name, bool(getattr(inst, 'lds', 0))
  # dword count from the opcode suffix; X3 (e.g. *_DWORDX3) previously fell through to 1 dword
  n_dwords = 4 if 'X4' in op_name else 3 if 'X3' in op_name else 2 if 'X2' in op_name else 1

  # instruction fields
  vdata, vaddr = ctx.inst_field(type(inst).vdata), ctx.inst_field(type(inst).vaddr)
  srsrc, soffset = ctx.inst_field(type(inst).srsrc) * _c(4), ctx.inst_field(type(inst).soffset)
  offset, offen, idxen = ctx.inst_field(type(inst).offset), ctx.inst_field(type(inst).offen), ctx.inst_field(type(inst).idxen)

  # V# descriptor: base[0:1], num_records[2], stride=word3[13:0]
  base = _u64(ctx.rsgpr_dyn(srsrc), ctx.rsgpr_dyn(srsrc + _c(1))) & UOp.const(dtypes.uint64, 0xFFFFFFFFFFFF)
  num_records = ctx.rsgpr_dyn(srsrc + _c(2))
  stride = (ctx.rsgpr_dyn(srsrc + _c(3)) & _c(0x3FFF)).cast(dtypes.uint64)

  lane = ctx.range()
  active = _lane_active(exec_mask, lane)

  # soffset: sgpr if < 128, else inline constant
  soff = (soffset < _c(128)).where(ctx.rsgpr_dyn(soffset), soffset - _c(128)).cast(dtypes.uint64)
  # vaddr: index (if idxen) in vaddr, offset (if offen) in vaddr or vaddr+1
  index = idxen.ne(_c(0)).where(ctx.rvgpr_dyn(vaddr, lane), _c(0)).cast(dtypes.uint64)
  voff = offen.ne(_c(0)).where(ctx.rvgpr_dyn(idxen.ne(_c(0)).where(vaddr + _c(1), vaddr), lane), _c(0)).cast(dtypes.uint64)

  # buffer_offset for bounds check, final address
  buffer_offset = (stride * index + voff + offset.cast(dtypes.uint64)).cast(dtypes.uint32)
  in_bounds = active & buffer_offset.__lt__(num_records)
  addr = base + soff + buffer_offset.cast(dtypes.uint64)
  addr = in_bounds.where(addr, UOp.const(dtypes.uint64, 0))  # safe address when OOB
  mem = ctx.vmem

  stores: list[UOp] = []
  if is_lds and not is_store:
    # LDS load: buffer -> LDS (bypass VGPRs), LDS addr = M0[17:0] + lane * elem_size
    lds_base = ctx.rsgpr_dyn(_c(124)) & _c(0x3FFFF)
    lds_addr = lds_base + lane.cast(dtypes.uint32) * _c(n_dwords * 4)
    for i in range(n_dwords):
      word_addr = (addr + UOp.const(dtypes.uint64, i * 4)) >> UOp.const(dtypes.uint64, 2)
      # Gate the buffer read with in_bounds exactly like the plain load path below: addr is forced to 0 for
      # inactive/OOB lanes, so an ungated index would read from address 0.
      val = in_bounds.where(mem.index(word_addr.cast(dtypes.int64), in_bounds, ptr=True).load(), _c(0))
      lds_idx = ((lds_addr + _c(i * 4)) >> _c(2)).cast(dtypes.int)
      stores.append(ctx.lds.index(lds_idx, active).store(active.where(val, ctx.lds.index(lds_idx, active))))
  elif is_store:
    for i in range(n_dwords):
      word_addr = (addr + UOp.const(dtypes.uint64, i * 4)) >> UOp.const(dtypes.uint64, 2)
      idx = mem.index(word_addr.cast(dtypes.int64), in_bounds)
      val = (ctx.raccvgpr_dyn if use_acc else ctx.rvgpr_dyn)(vdata + _c(i), lane)
      # out-of-bounds lanes write the existing value back
      stores.append(idx.store(in_bounds.where(_to_u32(val), idx)))
  else:
    for i in range(n_dwords):
      word_addr = (addr + UOp.const(dtypes.uint64, i * 4)) >> UOp.const(dtypes.uint64, 2)
      val = in_bounds.where(mem.index(word_addr.cast(dtypes.int64), in_bounds, ptr=True).load(), _c(0))
      stores.append((ctx.waccvgpr_dyn if use_acc else ctx.wvgpr_dyn)(vdata + _c(i), lane, val, exec_mask))
  return UOp.sink(UOp.group(*stores).end(lane), *ctx.inc_pc())
|
||
|
||
# Dispatch table: instruction type -> handler function
|
||
# Maps a decoded instruction class to the function that compiles it. Classes not listed here are
# resolved by the caller through their base classes.
_INST_HANDLERS: dict[type, Callable[..., UOp]] = {
  # RDNA3 instruction classes
  ir3.SOPP: _compile_sopp, ir3.SMEM: _compile_smem, ir3.SOP1: _compile_sop, ir3.SOP2: _compile_sop, ir3.SOPC: _compile_sop, ir3.SOPK: _compile_sop,
  ir3.VOP1: _compile_vop12, ir3.VOP1_SDST: _compile_vop12, ir3.VOP1_DPP16: _compile_vop12, ir3.VOP2: _compile_vop12, ir3.VOP2_DPP16: _compile_vop12,
  ir3.VOPC: _compile_vopc, ir3.VOPC_DPP16: _compile_vopc, ir3.VOP3: _compile_vop3, ir3.VINTERP: _compile_vinterp,
  ir3.VOP3_SDST: _compile_vop3, ir3.VOP3SD: _compile_vop3sd, ir3.VOP3P: _compile_vop3p, ir3.VOPD: _compile_vopd,
  ir3.DS: _compile_mem_op, ir3.FLAT: _compile_mem_op, ir3.GLOBAL: _compile_mem_op, ir3.SCRATCH: _compile_mem_op,
  # RDNA4 instruction classes
  ir4.SOPP: _compile_sopp, ir4.SMEM: _compile_smem, ir4.SOP1: _compile_sop, ir4.SOP2: _compile_sop, ir4.SOPC: _compile_sop, ir4.SOPK: _compile_sop,
  ir4.VOP1: _compile_vop12, ir4.VOP1_SDST: _compile_vop12, ir4.VOP1_DPP16: _compile_vop12, ir4.VOP2: _compile_vop12, ir4.VOP2_DPP16: _compile_vop12,
  ir4.VOPC: _compile_vopc, ir4.VOPC_DPP16: _compile_vopc, ir4.VOP3: _compile_vop3, ir4.VINTERP: _compile_vinterp,
  ir4.VOP3_SDST: _compile_vop3, ir4.VOP3SD: _compile_vop3sd, ir4.VOP3P: _compile_vop3p, ir4.VOPD: _compile_vopd,
  ir4.DS: _compile_mem_op, ir4.VFLAT: _compile_mem_op, ir4.VGLOBAL: _compile_mem_op, ir4.VSCRATCH: _compile_mem_op,
  # CDNA instruction classes
  irc.SOPP: _compile_sopp, irc.SMEM: _compile_smem, irc.SOP1: _compile_sop, irc.SOP2: _compile_sop, irc.SOPC: _compile_sop, irc.SOPK: _compile_sop,
  irc.VOP1: _compile_vop12, irc.VOP1_DPP16: _compile_vop12, irc.VOP2: _compile_vop12, irc.VOP2_DPP16: _compile_vop12,
  irc.VOPC: _compile_vopc, irc.VOP3: _compile_vop3,
  irc.VOP3_SDST: _compile_vop3, irc.VOP3SD: _compile_vop3sd, irc.VOP3P: _compile_vop3p,
  # SDWA encodings and MUBUF are only registered for CDNA in this table
  irc.VOP1_SDWA: _compile_sdwa, irc.VOP2_SDWA: _compile_sdwa, irc.VOP2_SDWA_SDST: _compile_sdwa, irc.VOPC_SDWA_SDST: _compile_sdwa,
  irc.DS: _compile_mem_op, irc.FLAT: _compile_mem_op, irc.GLOBAL: _compile_mem_op, irc.SCRATCH: _compile_mem_op,
  irc.MUBUF: _compile_mubuf,
}
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
# PROGRAM DECODE AND COMPILATION
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
|
||
# Compiled runners keyed by canonical encoding: an instruction reuses an entry when its class and size match
# and (encoding & mask) == base.
_canonical_runner_cache: list[tuple[type, int, int, int, tuple[UOp, object]]] = []  # [(inst_type, base, mask, size, (prg, runtime)), ...]
|
||
|
||
@functools.cache
def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
  """Decode one instruction and return its compiled (prg, runtime) pair.

  Results are memoized per (inst_bytes, arch) by functools.cache. Before compiling, the canonical cache is
  searched for an entry of the same instruction class and size whose masked encoding matches, so
  instructions sharing a canonical encoding reuse one compiled kernel.

  Raises:
    RuntimeError: if neither the instruction class nor any of its base classes has a registered handler.
  """
  inst = decode_inst(inst_bytes, arch)
  inst_cls, n_bytes = type(inst), inst.size()
  encoded = int.from_bytes(inst_bytes[:n_bytes], 'little')

  # Check if instruction matches any cached canonical pattern (must also match instruction type to avoid variant conflicts)
  for cached_cls, cached_base, cached_mask, cached_size, cached_entry in _canonical_runner_cache:
    if inst_cls is not cached_cls or n_bytes != cached_size: continue
    if (encoded & cached_mask) == cached_base: return cached_entry

  # Look up handler by type, falling back to base classes for _LIT variants (the MRO starts with the class itself)
  handler = next((_INST_HANDLERS[cls] for cls in inst_cls.__mro__ if cls in _INST_HANDLERS), None)
  if handler is None: raise RuntimeError(f"[emu] unimplemented instruction type: {type(inst).__name__} {_op_name(inst)}")

  ctx = _Ctx(n_bytes, _wave_size(arch))
  sink = handler(inst, ctx)
  base, mask, size = ctx.canonical_mask(inst_bytes)
  canonical_name = f"{_op_name(inst).lower()}_{base.to_bytes(size, 'little').hex()}"
  sink = sink.replace(arg=KernelInfo(name=canonical_name)).rtag(1)

  # NOTE: renderer output is not reproducible because of _MXCSRContext. PROFILE=0 prevents emulator instruction runners from polluting profiling.
  with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0, EMULATED_DTYPES="", CAPTURE_PROCESS_REPLAY=0, PROFILE=0):
    prg = to_program(sink, Device['CPU'].renderer)
    runtime = get_runtime('CPU', prg)
  entry = (prg, runtime)
  _canonical_runner_cache.append((inst_cls, base, mask, size, entry))
  return entry
|
||
|
||
# SOPP opcodes that run_asm treats as a barrier; S_BARRIER_WAIT is only included when the RDNA4 enum defines it
_BARRIER_OPS = {ir3.SOPPOp.S_BARRIER, irc.SOPPOp.S_BARRIER} | ({ir4.SOPPOp.S_BARRIER_WAIT} if hasattr(ir4.SOPPOp, 'S_BARRIER_WAIT') else set())
# RDNA4 SOP1 opcodes that run_asm also treats as a barrier (empty when the enum has no S_BARRIER_SIGNAL)
_BARRIER_SOP1_OPS: set = {ir4.SOP1Op.S_BARRIER_SIGNAL} if hasattr(ir4.SOP1Op, 'S_BARRIER_SIGNAL') else set()
# Integer values of the RDNA3 SOPP branch opcodes (unconditional plus SCC/VCC/EXEC conditional forms)
_BRANCH_OPS: set[int] = {getattr(ir3.SOPPOp, name).value for name in (
  'S_BRANCH', 'S_CBRANCH_SCC0', 'S_CBRANCH_SCC1',
  'S_CBRANCH_VCCZ', 'S_CBRANCH_VCCNZ', 'S_CBRANCH_EXECZ', 'S_CBRANCH_EXECNZ')}
|
||
|
||
def _decode_at(pc: int, arch: str):
  """Decode and compile the instruction stored at absolute host address `pc`.

  Returns `((prg, runtime), inst)`: the pair produced by _get_runner plus the decoded instruction.
  Any exception raised while compiling is re-raised as a RuntimeError that names the instruction.
  """
  # always read a fixed 16-byte window at pc; the decoder works out the real instruction length
  raw = bytes((ctypes.c_char * 16).from_address(pc).raw)
  inst = decode_inst(raw, arch)
  try:
    # the runner gets the instruction plus 4 trailing bytes
    # NOTE(review): presumably so a trailing 32-bit literal is part of the cache key - confirm
    runner = _get_runner(raw[:inst.size() + 4], arch)
  except Exception as e:
    # repr() of a partially supported instruction may itself fail, fall back to the bare type name
    try: inst_str = repr(inst)
    except Exception: inst_str = f"<{type(inst).__name__}>"
    raise RuntimeError(f"[emu] Failed to compile {inst_str}: {type(e).__name__}: {e}") from e
  return runner, inst
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
# WAVE STATE
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
|
||
# Inline float constants (as bit patterns) for GPU instructions
|
||
# index -> IEEE-754 binary32 bit pattern, covering indices 240 through 248 in order
F32_INLINE = dict(zip(range(240, 249), (
  0x3f000000, 0xbf000000,  # 0.5, -0.5
  0x3f800000, 0xbf800000,  # 1.0, -1.0
  0x40000000, 0xc0000000,  # 2.0, -2.0
  0x40800000, 0xc0800000,  # 4.0, -4.0
  0x3e22f983,              # 1/(2*pi)
)))
|
||
|
||
class WaveState:
  """Register state of one wavefront, backed by CPU Buffers that the compiled instruction kernels access.

  vgpr_buf holds 256 registers of `wave_size` uint32 lanes each (register-major: index = reg * wave_size + lane),
  sgpr_buf holds SGPR_COUNT uint32 entries, and the program counter is kept in the PC_LO_IDX / PC_HI_IDX entries.
  """
  __slots__ = ('vgpr_buf', 'sgpr_buf', 'accvgpr_buf', '_vgpr_mv', '_sgpr_mv', 'n_lanes', 'wave_size')

  def __init__(self, n_lanes: int, wave_size: int = 32):
    """Allocate and zero the register files, load the inline constants and enable the low `n_lanes` lanes in EXEC."""
    self.n_lanes, self.wave_size = n_lanes, wave_size
    is_wave64 = wave_size == 64
    vgpr_words = 256 * wave_size
    self.vgpr_buf = Buffer('CPU', vgpr_words, dtypes.uint32).ensure_allocated()
    self.sgpr_buf = Buffer('CPU', SGPR_COUNT, dtypes.uint32).ensure_allocated()
    # CDNA (wave64) has separate ACCVGPR file; RDNA shares with VGPR
    if is_wave64:
      self.accvgpr_buf = Buffer('CPU', vgpr_words, dtypes.uint32).ensure_allocated()
      ctypes.memset(self.accvgpr_buf._buf.va_addr, 0, vgpr_words * 4)
    else:
      self.accvgpr_buf = self.vgpr_buf
    self._vgpr_mv = self.vgpr_buf.as_memoryview(force_zero_copy=True).cast('I')
    self._sgpr_mv = self.sgpr_buf.as_memoryview(force_zero_copy=True).cast('I')
    # memset through ctypes is much faster than clearing the buffers with a Python loop
    ctypes.memset(self.vgpr_buf._buf.va_addr, 0, vgpr_words * 4)
    ctypes.memset(self.sgpr_buf._buf.va_addr, 0, SGPR_COUNT * 4)

    # Inline constants live at indices 128-255 of the SGPR buffer
    for value in range(65): self._write_sgpr(128 + value, value)  # 128-192: integers 0-64
    for k in range(1, 17): self._write_sgpr(192 + k, -k)  # 193-208: -1 to -16 (_write_sgpr masks to 32 bits)
    for idx, bits in F32_INLINE.items(): self._write_sgpr(idx, bits)  # 240-248: float constants

    # EXEC mask: one bit per active lane; 64-lane waves split it across EXEC_LO and EXEC_HI
    lo_lanes = min(n_lanes, 32) if is_wave64 else n_lanes
    self._write_sgpr(EXEC_LO.offset, (1 << lo_lanes) - 1)
    if is_wave64: self._write_sgpr(EXEC_LO.offset + 1, (1 << max(n_lanes - 32, 0)) - 1)  # 0 when n_lanes <= 32

    self._write_sgpr(PC_LO_IDX, 0)
    self._write_sgpr(PC_HI_IDX, 0)

  def _write_sgpr(self, idx: int, val: int):
    """Store the low 32 bits of `val` at SGPR-buffer index `idx`."""
    self._sgpr_mv[idx] = val & MASK32

  def _read_sgpr(self, idx: int) -> int:
    """Return the uint32 stored at SGPR-buffer index `idx`."""
    return self._sgpr_mv[idx]

  def _write_vgpr(self, reg: int, lane: int, val: int):
    """Store the low 32 bits of `val` into lane `lane` of vector register `reg`."""
    self._vgpr_mv[reg * self.wave_size + lane] = val & MASK32

  def _read_vgpr(self, reg: int, lane: int) -> int:
    """Return the uint32 held by lane `lane` of vector register `reg`."""
    return self._vgpr_mv[reg * self.wave_size + lane]

  @property
  def pc(self) -> int:
    """64-bit program counter assembled from the PC_HI_IDX (upper) and PC_LO_IDX (lower) SGPR entries."""
    lo, hi = self._read_sgpr(PC_LO_IDX), self._read_sgpr(PC_HI_IDX)
    return lo | (hi << 32)

  @pc.setter
  def pc(self, val: int):
    # _write_sgpr already truncates to 32 bits, so only the high half needs shifting
    self._write_sgpr(PC_LO_IDX, val)
    self._write_sgpr(PC_HI_IDX, val >> 32)
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
# EXECUTION
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
|
||
def _init_wave(lib: int, wave_start: int, total_threads: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int,
               scratch_size: int, arch: str, gidx: int, gidy: int, gidz: int, user_data: list[int]|None,
               wave_size: int = 32) -> WaveState:
  """Create the WaveState for the wavefront covering threads [wave_start, wave_start + wave_size) of a workgroup.

  Sets the PC to `lib`, loads the user SGPRs (from `user_data` when it is non-empty, otherwise the 64-bit
  `args_ptr` in s[0:1]), the workgroup IDs, the packed per-lane thread IDs in v0, the scratch stride and the
  emulated HW_ID register. The last wave of a workgroup may have fewer than `wave_size` active lanes.
  """
  n_lanes = min(wave_size, total_threads - wave_start)
  st = WaveState(n_lanes, wave_size)
  st.pc = lib

  if user_data:
    for i, val in enumerate(user_data): st._write_sgpr(i, val)
  else:
    # kernel argument pointer, low dword first
    st._write_sgpr(0, args_ptr & MASK32)
    st._write_sgpr(1, (args_ptr >> 32) & MASK32)

  if arch == "rdna4":
    # workgroup IDs only exist in ttmp registers, not normal SGPRs
    st._write_sgpr(ttmp[7].offset, (gidy & 0xFFFF) | ((gidz & 0xFFFF) << 16))
    st._write_sgpr(ttmp[9].offset, gidx)
  else:
    # enabled workgroup IDs are packed into consecutive SGPRs starting right after the user SGPRs
    first_sgpr = (rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT_SHIFT
    id_flags = ((hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_X, gidx),
                (hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Y, gidy),
                (hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Z, gidz))
    enabled_ids = [gid for flag, gid in id_flags if rsrc2 & flag]
    for sgpr_idx, gid in enumerate(enabled_ids, start=first_sgpr): st._write_sgpr(sgpr_idx, gid)

  # v0 carries the local thread ID packed as x in bits [9:0], y in bits [19:10] and z from bit 20 upward
  for lane in range(n_lanes):
    tid = wave_start + lane
    tid_x, tid_y, tid_z = tid % lx, (tid // lx) % ly, tid // (lx * ly)
    st._write_vgpr(0, lane, (tid_z << 20) | (tid_y << 10) | tid_x)

  st._write_sgpr(SCRATCH_STRIDE_IDX, scratch_size)

  # Store HW register values at SGPR[SGPR_COUNT-16 .. SGPR_COUNT-1] for s_getreg_b32 emulation.
  # HW_ID (hwRegId=4): WAVE_ID[3:0], SIMD_ID[5:4], PIPE_ID[7:6], CU_ID[11:8], ...
  wave_idx = wave_start // wave_size  # wave index within this workgroup (0, 1, 2, 3 for 256 threads / 64 wave_size)
  wave_id, simd_id = wave_idx & 0xF, wave_idx & 0x3  # WAVE_ID = wave_idx, SIMD_ID = wave_idx % 4
  st._write_sgpr(SGPR_COUNT - 16 + 4, wave_id | (simd_id << 4))  # HW_REGISTERS[4] = HW_ID
  return st
|
||
|
||
def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int = 0x19c,
            scratch_size: int = 0, arch: str = "rdna3", user_data: list[int]|None = None) -> int:
  """Execute AMD assembly program. scratch_size is private_segment_fixed_size from kernel descriptor (per-lane).

  `lib` is the host address of the first instruction, (gx, gy, gz) the grid size in workgroups and (lx, ly, lz)
  the workgroup size in threads. Workgroups run one after another; inside a workgroup every wave runs until it
  reaches a barrier instruction or ENDPGM_PC, then the next wave gets its turn. Always returns 0.
  Raises RuntimeError when a wave executes 1M instructions without pausing or a workgroup needs 10M rounds.
  """
  from tinygrad.renderer.amd.dsl import Inst
  program: dict[int, tuple[Callable, list[int], bool, Inst]] = {}  # pc -> (fxn, globals, is_barrier, inst)
  lds_size = ((rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE_SHIFT) * 512
  total_threads = lx * ly * lz
  wave_size = _wave_size(arch)

  # vmem is a Buffer with external_ptr=0, so indices into it are host addresses
  vmem_buf = Buffer('CPU', 1 << 40, dtypes.uint32, options=BufferSpec(external_ptr=0)).ensure_allocated()
  lds_buf = Buffer('CPU', max(lds_size // 4, 1), dtypes.uint32).ensure_allocated()
  scratch_buf = Buffer('CPU', scratch_size * wave_size, dtypes.uint8).ensure_allocated() if scratch_size else None

  # Initialize SQTT encoder — emits packets inline as instructions execute (only when profiling)
  if PROFILE:
    sqtt_emit, sqtt_finish, sqtt_finalize = _init_sqtt_encoder()

  def _ensure_compiled(pc: int) -> tuple[Callable, list[int], bool, Inst]:
    """Return the cached (fxn, globals, is_barrier, inst) entry for `pc`, compiling the instruction on first use."""
    cached = program.get(pc)
    if cached is not None: return cached
    n_canonical = len(_canonical_runner_cache)
    (prg, runtime), inst = _decode_at(pc, arch)
    is_barrier = (isinstance(inst, (ir3.SOPP, ir4.SOPP, irc.SOPP)) and inst.op in _BARRIER_OPS) or \
                 (isinstance(inst, (ir4.SOP1,)) and inst.op in _BARRIER_SOP1_OPS)
    entry = program[pc] = (runtime.fxn, prg.arg.globals, is_barrier, inst)
    if DEBUG >= 3:
      msg = f"[emu] PC={pc - lib}: {inst!r}"
      # green marks instructions that added a new canonical kernel instead of reusing one
      print(colored(msg, 'green') if len(_canonical_runner_cache) > n_canonical else msg)
    return entry

  def _spawn_wave(wave_start: int, gidx: int, gidy: int, gidz: int) -> tuple[WaveState, list]:
    """Build one wave's state plus the pointer list its kernels pick their buffers from."""
    st = _init_wave(lib, wave_start, total_threads, lx, ly, lz, args_ptr, rsrc2, scratch_size, arch, gidx, gidy, gidz, user_data,
                    wave_size)
    # position in this list is the kernel argument number: sgpr, vgpr, vmem, lds, scratch, accvgpr
    addrs = (st.sgpr_buf._buf.va_addr, st.vgpr_buf._buf.va_addr, vmem_buf._buf.va_addr, lds_buf._buf.va_addr,
             scratch_buf._buf.va_addr if scratch_buf else 0, st.accvgpr_buf._buf.va_addr)
    return st, [ctypes.c_uint64(addr) for addr in addrs]

  def _run_until_pause(wi: int, st: WaveState, c_bufs: list, gidx: int, gidy: int, gidz: int, tracing: bool) -> bool:
    """Run wave `wi` until it pauses. Returns True when it reached ENDPGM_PC, False when it stopped at a barrier."""
    for _ in range(1_000_000):
      pc = st.pc
      if pc == ENDPGM_PC:
        if tracing: sqtt_finish(wi)
        return True
      fxn, globals_list, is_barrier, inst = _ensure_compiled(pc)
      if DEBUG >= 5: print(f" exec gid=({gidx},{gidy},{gidz}) w={wi} PC={pc - lib}: {inst!r}", flush=True)
      fxn(*[c_bufs[g] for g in globals_list])
      if tracing:
        inst_op = inst.op.value if hasattr(inst, 'op') else 0
        # for branch opcodes report whether the PC left the fall-through path, otherwise pass None
        sqtt_emit(wi, inst, (st.pc != ENDPGM_PC and st.pc != pc + inst.size()) if inst_op in _BRANCH_OPS else None)
      if is_barrier: return False  # s_barrier hit: PC already advanced past it, pause this wave
    raise RuntimeError("exceeded 1M instructions in single wave, likely infinite loop")

  # Only trace the first workgroup (like real HW traces one CU/SIMD), subsequent workgroups run but don't add to trace
  tracing = bool(PROFILE)

  # Set DAZ+FTZ during emulator execution, restore afterward to avoid breaking hypothesis tests
  with _MXCSRContext():
    for gidz in range(gz):
      for gidy in range(gy):
        for gidx in range(gx):
          # Initialize all wavefronts for this workgroup
          waves: list[tuple[WaveState, list]] = [_spawn_wave(wave_start, gidx, gidy, gidz)
                                                 for wave_start in range(0, total_threads, wave_size)]

          # Execute wavefronts with barrier synchronization: each round runs every unfinished wave until it hits
          # a barrier or endpgm, and starting the next round is what releases the waves parked at a barrier
          done = [False] * len(waves)
          for _ in range(10_000_000):
            if all(done): break
            for wi, (st, c_bufs) in enumerate(waves):
              if not done[wi]: done[wi] = _run_until_pause(wi, st, c_bufs, gidx, gidy, gidz, tracing)
          else: raise RuntimeError("exceeded 10M total scheduling rounds")
          tracing = False  # only trace the first workgroup

          # Reset LDS for next workgroup
          if lds_size > 0: ctypes.memset(lds_buf._buf.va_addr, 0, max(lds_size, 4))

  if PROFILE: sqtt_traces.append(sqtt_finalize())
  return 0
|