# RDNA3 emulator v2 - compiles pcode to UOps executed via tinygrad CPU backend
# Each instruction is compiled to a kernel that operates on buffers:
#   arg=0: sgpr - sgpr[0-127], inline constants[128-255], PC_LO=256, PC_HI=257, SCC=258, SCRATCH_STRIDE=259
#   arg=1: vgpr - vgpr[reg * 32 + lane]
#   arg=2: vmem - base address 0, INDEX offsets directly to host memory
#   arg=3: lds - local data share
#   arg=4: scratch - per-lane scratch memory
from __future__ import annotations
import ctypes, functools, os, re, platform, subprocess, tempfile
from typing import Callable

# Set/restore DAZ+FTZ (denormals-are-zero + flush-to-zero) to match RDNA3 default float mode
# x86: MXCSR bits DAZ(6)+FTZ(15), ARM64: FPCR bit FZ(24)
# Only applied during emulator execution, restored afterward to avoid breaking hypothesis tests
@functools.cache
def _get_ftz_lib():
  """Build and load a tiny C helper exposing get_fpcr()/set_fpcr() for the host FP control register.

  Returns (lib, ftz_bits) where ftz_bits is the mask to OR into the register, or (None, 0) when the
  machine is not x86_64/arm64 or when compiling/loading the helper fails (deliberately best-effort).
  The result is cached, so the helper is compiled at most once per process."""
  machine = platform.machine()
  if machine in ('x86_64', 'AMD64'):
    src = b'''
unsigned int get_fpcr(void){unsigned int m;__asm__ __volatile__("stmxcsr %0":"=m"(m));return m;}
void set_fpcr(unsigned int m){__asm__ __volatile__("ldmxcsr %0"::"m"(m));}
'''
    ftz_bits = 0x8040  # DAZ (bit 6) + FTZ (bit 15)
  elif machine in ('arm64', 'aarch64'):
    src = b'''
unsigned int get_fpcr(void){unsigned long long v;__asm__ __volatile__("mrs %0,fpcr":"=r"(v));return(unsigned int)v;}
void set_fpcr(unsigned int m){unsigned long long v=m;__asm__ __volatile__("msr fpcr,%0"::"r"(v));}
'''
    ftz_bits = 1 << 24  # FZ (bit 24)
  else: return None, 0
  so_path = None
  try:
    with tempfile.NamedTemporaryFile(suffix='.so', delete=False) as f:
      so_path = f.name
      subprocess.check_output(['clang', '-shared', '-O2', '-x', 'c', '-', '-o', so_path], input=src)
      lib = ctypes.CDLL(so_path)
    lib.get_fpcr.restype = ctypes.c_uint32
    lib.set_fpcr.argtypes = [ctypes.c_uint32]
    return lib, ftz_bits
  except Exception: return None, 0
  finally:
    # delete=False would otherwise leave the .so behind forever; the file is no longer needed on disk once
    # it has been loaded (or once the build failed). Removal is best-effort: some platforms refuse to delete
    # a library that is still loaded.
    if so_path is not None:
      try: os.unlink(so_path)
      except OSError: pass

class _MXCSRContext:
  """Context manager to set DAZ+FTZ during emulator execution and restore afterward."""
  __slots__ = ('_saved',)
  def __enter__(self):
    lib, ftz_bits = _get_ftz_lib()
    if lib is None: return self
    self._saved = lib.get_fpcr()
    lib.set_fpcr(self._saved | ftz_bits)
    return self
  def __exit__(self, *args):
    lib, _ = _get_ftz_lib()
    # _saved is only set when __enter__ actually changed the register
    if lib is None or not hasattr(self, '_saved'): return
    lib.set_fpcr(self._saved)

from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.device import Buffer, BufferSpec, Device
from tinygrad.runtime.autogen import hsa
from tinygrad.helpers import Context, DEBUG, PROFILE, colored
from tinygrad.engine.realize import get_runtime
from tinygrad.codegen import to_program
from tinygrad.renderer.amd import decode_inst
from tinygrad.runtime.autogen.amd.rdna3.str_pcode import PCODE as PCODE_RDNA3
from tinygrad.runtime.autogen.amd.rdna4.str_pcode import PCODE as PCODE_RDNA4
from tinygrad.runtime.autogen.amd.cdna.str_pcode import PCODE as PCODE_CDNA
from tinygrad.runtime.autogen.amd.rdna3 import ins as ir3
from tinygrad.runtime.autogen.amd.rdna4 import ins as ir4
from tinygrad.runtime.autogen.amd.cdna import ins as irc
from tinygrad.renderer.amd.dsl import VCC_LO, EXEC_LO, SCC, ttmp
from tinygrad.runtime.autogen.amd.common import Fmt, OpType
from test.amd.helpers import decode_dpp16
from test.mockgpu.amd.pcode import parse_block, _FUNCS, _set_bits, _to_bool, _val_to_bits

MASK32 = 0xFFFFFFFF

# ═══════════════════════════════════════════════════════════════════════════════
# SQTT TRACE COLLECTION
# ═══════════════════════════════════════════════════════════════════════════════

# Global trace storage: populated by run_asm as raw SQTT blobs, consumed by amdgpu.py
sqtt_traces: list[bytes] = []

# Encoder primitives
from tinygrad.renderer.amd.sqtt import _build_decode_tables, PACKET_TYPES_RDNA3, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, IMMEDIATE, VALUINST, InstOp

_NIB_COUNTS: dict = {cls: nc for _, (cls, nc, *_) in _build_decode_tables(PACKET_TYPES_RDNA3)[0].items()}

def _encode_raw(pkt_cls, **kwargs) -> tuple[int, int]:
  """Encode one packet: start from the class default encoding and set each named field. Returns (raw, nibble_count)."""
  raw = pkt_cls.encoding.default
  for k, v in kwargs.items(): raw = pkt_cls.__dict__[k].set(raw, v)
  return raw, _NIB_COUNTS[pkt_cls]

def _emit_nibbles(nibbles: list[int], pkt_cls, **kwargs):
  """Encode a packet and append its nibbles to `nibbles`, least-significant nibble first."""
  raw, nc = _encode_raw(pkt_cls, **kwargs)
  for i in range(nc): nibbles.append((raw >> (i * 4)) & 0xF)

def _nibbles_to_bytes(nibbles: list[int]) -> bytes:
  """Pack nibbles two per byte (first nibble in the low half); an odd trailing nibble is paired with 0."""
  result = bytearray()
  for i in range(0, len(nibbles), 2): result.append(nibbles[i] | ((nibbles[i + 1] if i + 1 < len(nibbles) else 0) << 4))
  return bytes(result)

def _init_sqtt_encoder():
  """Initialize and return SQTT encoder state. Called once per dispatch with tracing enabled.

  Returns the closures (emit, finish, finalize) which share one nibble buffer and one set of started waves."""
  from tinygrad.runtime.autogen.amd.rdna3.enum import SOPPOp as SOPPOp3
  from tinygrad.runtime.autogen.amd.rdna4.enum import SOPPOp as SOPPOp4

  _SOPP = (ir3.SOPP, ir4.SOPP, irc.SOPP)
  _SMEM = (ir3.SMEM, ir4.SMEM, irc.SMEM)
  _VALU = (ir3.VOP1, ir3.VOP2, ir3.VOP3, ir3.VOP3P, ir3.VOPC, ir3.VOPD, ir3.VOP3SD, ir3.VOP3_SDST, ir3.VOP1_SDST,
           ir4.VOP1, ir4.VOP2, ir4.VOP3, ir4.VOP3P, ir4.VOPC, ir4.VOPD, ir4.VOP3SD, ir4.VOP3_SDST, ir4.VOP1_SDST,
           irc.VOP1, irc.VOP2, irc.VOP3, irc.VOP3P, irc.VOPC, irc.VOP3SD, irc.VOP3_SDST)
  _DS = (ir3.DS, ir4.DS, irc.DS)
  _GLOBAL = (ir3.GLOBAL, ir4.VGLOBAL, irc.GLOBAL)
  _FLAT = (ir3.FLAT, ir4.VFLAT, irc.FLAT)
  _SCRATCH = (ir3.SCRATCH, ir4.VSCRATCH, irc.SCRATCH)

  # SOPP classification sets
  _SOPP_SKIP = {SOPPOp3.S_ENDPGM.value, SOPPOp3.S_ENDPGM_SAVED.value, SOPPOp3.S_ENDPGM_ORDERED_PS_DONE.value, SOPPOp3.S_DELAY_ALU.value}
  _SOPP_IMMEDIATE = {SOPPOp3.S_NOP.value, SOPPOp3.S_CLAUSE.value, SOPPOp3.S_WAITCNT.value, SOPPOp3.S_WAITCNT_DEPCTR.value,
                     SOPPOp3.S_WAIT_IDLE.value, SOPPOp3.S_WAIT_EVENT.value, SOPPOp3.S_SLEEP.value,
                     SOPPOp3.S_SET_INST_PREFETCH_DISTANCE.value}
  for _op in (SOPPOp4.S_WAIT_ALU, SOPPOp4.S_WAIT_LOADCNT, SOPPOp4.S_WAIT_STORECNT, SOPPOp4.S_WAIT_SAMPLECNT, SOPPOp4.S_WAIT_BVHCNT,
              SOPPOp4.S_WAIT_EXPCNT, SOPPOp4.S_WAIT_DSCNT, SOPPOp4.S_WAIT_KMCNT, SOPPOp4.S_WAIT_LOADCNT_DSCNT,
              SOPPOp4.S_WAIT_STORECNT_DSCNT):
    _SOPP_IMMEDIATE.add(_op.value)
  _SOPP_BARRIER = {SOPPOp3.S_BARRIER.value}
  if hasattr(SOPPOp4, 'S_BARRIER_WAIT'): _SOPP_BARRIER.add(SOPPOp4.S_BARRIER_WAIT.value)
  if hasattr(SOPPOp4, 'S_BARRIER_LEAVE'): _SOPP_BARRIER.add(SOPPOp4.S_BARRIER_LEAVE.value)
  _SOPP_BRANCH = {SOPPOp3.S_BRANCH.value, SOPPOp3.S_CBRANCH_SCC0.value, SOPPOp3.S_CBRANCH_SCC1.value,
                  SOPPOp3.S_CBRANCH_VCCZ.value, SOPPOp3.S_CBRANCH_VCCNZ.value,
                  SOPPOp3.S_CBRANCH_EXECZ.value, SOPPOp3.S_CBRANCH_EXECNZ.value}

  # VALU sub-classification patterns
  _VALUT_4_RE = re.compile(r'V_(EXP|LOG|RCP|RSQ|SQRT|SIN|COS|CEIL|FLOOR|TRUNC|RNDNE|FRACT|FREXP)_')
  _VALUB_2_RE = re.compile(r'V_(LSHLREV|LSHRREV|ASHRREV)_(B|I)64')
  _VALUB_4_RE = re.compile(r'V_MAD_(U|I)64')
  _VALUB_16_RE = re.compile(r'V_\w+_F64')

  def _valu_op(op_name: str) -> InstOp|None:
    """Pick the InstOp for a VALU op name, or None when no special class matches. Checks are ordered; first match wins."""
    if 'CMPX' in op_name: return InstOp.VALU1_WR_EXEC
    if _VALUB_2_RE.search(op_name): return InstOp.VALUB_2
    if _VALUB_4_RE.search(op_name): return InstOp.VALUB_4
    if _VALUB_16_RE.search(op_name): return InstOp.VALUB_16
    if _VALUT_4_RE.search(op_name): return InstOp.VALUT_4
    return None

  def _mem_op(t, op_name: str) -> InstOp:
    """Pick the InstOp for a memory instruction class; anything unrecognized falls back to InstOp.SALU."""
    is_store = "STORE" in op_name
    if issubclass(t, _DS): return InstOp.LDS_WR_2 if is_store else InstOp.LDS_RD
    if issubclass(t, _GLOBAL): return InstOp.SGMEM_WR_2 if is_store else InstOp.SGMEM_RD_1
    if issubclass(t, _FLAT): return InstOp.FLAT_WR_3 if is_store else InstOp.FLAT_RD_2
    if issubclass(t, _SCRATCH): return InstOp.FLAT_WR_3 if is_store else InstOp.FLAT_RD_2
    return InstOp.SALU

  nibbles: list[int] = []
  started: set[int] = set()

  _emit_nibbles(nibbles, LAYOUT_HEADER, layout=3, sel_a=6)

  def emit(wave_id: int, inst, branch_taken: bool|None):
    """Emit an SQTT packet for one executed instruction."""
    w = wave_id & 0x1F
    if wave_id not in started:
      _emit_nibbles(nibbles, WAVESTART, delta=1, simd=0, wgp=0, wave=w, id7=wave_id)
      started.add(wave_id)

    inst_type, inst_op, op_name = type(inst), inst.op.value if hasattr(inst, 'op') else 0, inst.op.name if hasattr(inst, 'op') else ""

    if issubclass(inst_type, _SOPP):
      if inst_op in _SOPP_SKIP: return
      elif inst_op in _SOPP_IMMEDIATE: _emit_nibbles(nibbles, IMMEDIATE, delta=1, wave=w)
      elif inst_op in _SOPP_BARRIER: _emit_nibbles(nibbles, INST, delta=1, wave=w, op=InstOp.BARRIER)
      elif inst_op in _SOPP_BRANCH: _emit_nibbles(nibbles, INST, delta=1, wave=w, op=InstOp.JUMP if branch_taken else InstOp.JUMP_NO)
      else: _emit_nibbles(nibbles, INST, delta=1, wave=w, op=InstOp.SALU)
    elif issubclass(inst_type, _VALU):
      op = _valu_op(op_name)
      if op is None: _emit_nibbles(nibbles, VALUINST, delta=1, wave=w)
      else: _emit_nibbles(nibbles, INST, delta=1, wave=w, op=op)
    elif issubclass(inst_type, _SMEM): _emit_nibbles(nibbles, INST, delta=1, wave=w, op=InstOp.SMEM_RD)
    else: _emit_nibbles(nibbles, INST, delta=1, wave=w, op=_mem_op(inst_type, op_name))

  def finish(wave_id: int):
    """Emit WAVEEND for a completed wave."""
    if wave_id in started: _emit_nibbles(nibbles, WAVEEND, delta=1, simd=0, wgp=0, wave=wave_id & 0x1F)

  def finalize() -> bytes:
    """Pad and return the encoded SQTT blob."""
    while len(nibbles) % 2 != 0: nibbles.append(0)
    nibbles.extend([0] * 32)
    while len(nibbles) % 64 != 0: nibbles.append(0)
    return _nibbles_to_bytes(nibbles)

  return emit, finish, finalize

def _c(val, dtype=dtypes.uint32):
  """Shorthand for a constant UOp (uint32 by default)."""
  return UOp.const(dtype, val)

def _u64(lo: UOp, hi: UOp) -> UOp:
  """Combine two 32-bit UOps into a 64-bit UOp."""
  return lo.cast(dtypes.uint64) | (hi.cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))

def _split64(val: UOp) -> tuple[UOp, UOp]:
  """Split a 64-bit value into (lo, hi) 32-bit values."""
  v64 = val.bitcast(dtypes.uint64) if val.dtype == dtypes.float64 else val.cast(dtypes.uint64) if val.dtype != dtypes.uint64 else val
  return v64.cast(dtypes.uint32), (v64 >> UOp.const(dtypes.uint64, 32)).cast(dtypes.uint32)

# bits -> (unsigned type, float type, mask that clears the sign bit)
_SRC_MOD_TYPES = {16: (dtypes.uint16, dtypes.half, 0x7FFF), 32: (dtypes.uint32, dtypes.float32, 0x7FFFFFFF),
                  64: (dtypes.uint64, dtypes.float64, 0x7FFFFFFFFFFFFFFF)}

def _apply_src_mods(val: UOp, mod_bit: int, abs_bits: int, neg_bits: int, bits: int = 32) -> UOp:
  """Apply abs/neg modifiers to source value based on bit width (16, 32, or 64)."""
  if not (abs_bits & (1 << mod_bit)) and not (neg_bits & (1 << mod_bit)): return val
  ut, ft, mask = _SRC_MOD_TYPES[bits]
  fv = val.cast(ut).bitcast(ft) if bits == 16 else val.bitcast(ft) if val.dtype == ut else val
  if abs_bits & (1 << mod_bit): fv = (fv.bitcast(ut) & UOp.const(ut, mask)).bitcast(ft)
  if neg_bits & (1 << mod_bit): fv = fv.neg()
  return fv.bitcast(ut).cast(dtypes.uint32) if bits == 16 else fv.bitcast(ut)

# Map VOPD ops to VOP2 ops for pcode lookup (both RDNA3 and RDNA4)
VOPD_TO_VOP2 = {
  ir3.VOPDOp.V_DUAL_FMAC_F32: ir3.VOP2Op.V_FMAC_F32_E32, ir3.VOPDOp.V_DUAL_MUL_F32: ir3.VOP2Op.V_MUL_F32_E32,
  ir3.VOPDOp.V_DUAL_ADD_F32: ir3.VOP2Op.V_ADD_F32_E32, ir3.VOPDOp.V_DUAL_SUB_F32: ir3.VOP2Op.V_SUB_F32_E32,
  ir3.VOPDOp.V_DUAL_SUBREV_F32: ir3.VOP2Op.V_SUBREV_F32_E32, ir3.VOPDOp.V_DUAL_MAX_F32: ir3.VOP2Op.V_MAX_F32_E32,
  ir3.VOPDOp.V_DUAL_MIN_F32: ir3.VOP2Op.V_MIN_F32_E32, ir3.VOPDOp.V_DUAL_ADD_NC_U32: ir3.VOP2Op.V_ADD_NC_U32_E32,
  ir3.VOPDOp.V_DUAL_LSHLREV_B32: ir3.VOP2Op.V_LSHLREV_B32_E32, ir3.VOPDOp.V_DUAL_AND_B32: ir3.VOP2Op.V_AND_B32_E32,
  ir3.VOPDOp.V_DUAL_MOV_B32: ir3.VOP1Op.V_MOV_B32_E32, ir3.VOPDOp.V_DUAL_CNDMASK_B32: ir3.VOP2Op.V_CNDMASK_B32_E32,
  ir3.VOPDOp.V_DUAL_FMAAK_F32: ir3.VOP2Op.V_FMAAK_F32_E32, ir3.VOPDOp.V_DUAL_FMAMK_F32: ir3.VOP2Op.V_FMAMK_F32_E32,
  ir3.VOPDOp.V_DUAL_DOT2ACC_F32_F16: ir3.VOP2Op.V_DOT2ACC_F32_F16_E32,
  # RDNA4 mappings (same VOP1/VOP2 targets, RDNA4 uses _NUM_ suffix for min/max)
  ir4.VOPDOp.V_DUAL_FMAC_F32: ir3.VOP2Op.V_FMAC_F32_E32, ir4.VOPDOp.V_DUAL_MUL_F32: ir3.VOP2Op.V_MUL_F32_E32,
  ir4.VOPDOp.V_DUAL_ADD_F32: ir3.VOP2Op.V_ADD_F32_E32, ir4.VOPDOp.V_DUAL_SUB_F32: ir3.VOP2Op.V_SUB_F32_E32,
  ir4.VOPDOp.V_DUAL_SUBREV_F32: ir3.VOP2Op.V_SUBREV_F32_E32, ir4.VOPDOp.V_DUAL_MAX_NUM_F32: ir3.VOP2Op.V_MAX_F32_E32,
  ir4.VOPDOp.V_DUAL_MIN_NUM_F32: ir3.VOP2Op.V_MIN_F32_E32, ir4.VOPDOp.V_DUAL_ADD_NC_U32: ir3.VOP2Op.V_ADD_NC_U32_E32,
  ir4.VOPDOp.V_DUAL_LSHLREV_B32: ir3.VOP2Op.V_LSHLREV_B32_E32, ir4.VOPDOp.V_DUAL_AND_B32: ir3.VOP2Op.V_AND_B32_E32,
  ir4.VOPDOp.V_DUAL_MOV_B32: ir3.VOP1Op.V_MOV_B32_E32, ir4.VOPDOp.V_DUAL_CNDMASK_B32: ir3.VOP2Op.V_CNDMASK_B32_E32,
  ir4.VOPDOp.V_DUAL_FMAAK_F32: ir3.VOP2Op.V_FMAAK_F32_E32, ir4.VOPDOp.V_DUAL_FMAMK_F32: ir3.VOP2Op.V_FMAMK_F32_E32,
  ir4.VOPDOp.V_DUAL_DOT2ACC_F32_F16: ir3.VOP2Op.V_DOT2ACC_F32_F16_E32,
}

def _wave_size(arch: str) -> int:
  """Lanes per wave: 64 when arch starts with "cdna", otherwise 32."""
  return 64 if arch.startswith("cdna") else 32

# Special registers stored after inline constants (256-259)
PC_LO_IDX, PC_HI_IDX, SCRATCH_STRIDE_IDX = 256, 257, 259
# SGPR buffer: 0-127 = SGPRs, 128-255 = inline constants, 256-259 = special registers
SGPR_COUNT = 260
# Sentinel PC value for s_endpgm
ENDPGM_PC = 0xFFFFFFFFFFFFFFFF

def _op_name(inst) -> str:
  """Readable opcode name for an instruction (VOPD joins its two op names with '_')."""
  if hasattr(inst, 'opx'): return f"{inst.opx.name}_{inst.opy.name}"  # VOPD has opx/opy not op
  return inst.op.name if hasattr(inst.op, 'name') else str(inst.op)

def _to_u32(val: UOp) -> UOp:
  """Convert a value to uint32: bitcast when it is already 4 bytes wide, value-cast otherwise."""
  if val.dtype == dtypes.uint32: return val
  if val.dtype.itemsize == 4: return val.bitcast(dtypes.uint32)  # same size: bitcast (float32->uint32)
  return val.cast(dtypes.uint32)  # different size: cast (bool, int16, etc)

def _lane_active(exec_mask: UOp, lane: UOp) -> UOp:
  """Boolean UOp: is bit `lane` set in exec_mask (uint32 or uint64 mask)."""
  if exec_mask.dtype == dtypes.uint64:
    return ((exec_mask >> lane.cast(dtypes.uint64)) & UOp.const(dtypes.uint64, 1)).ne(UOp.const(dtypes.uint64, 0))
  return ((exec_mask >> lane.cast(dtypes.uint32)) & _c(1)).ne(_c(0))

def _hi16(v: UOp) -> UOp:
  """Upper 16 bits of a 32-bit value, moved down to bits [15:0]."""
  return (v >> _c(16)) & _c(0xFFFF)

def _cond(cond, if_true, if_false):
  """Select between values based on condition (works with UOp or bool)."""
  return cond.where(if_true, if_false) if isinstance(cond, UOp) else if_true if cond else if_false

def _cond_hi16(cond, val: UOp) -> UOp:
  """Return the high half of val when cond holds, val itself otherwise."""
  return _cond(cond, _hi16(val), val)

def _apply_opsel(val: UOp, sel_bit: int, opsel: int) -> UOp:
  """Return the high half of val when bit sel_bit of opsel is set, val itself otherwise."""
  return _hi16(val) if opsel & (1 << sel_bit) else val

def _set_lane_bit(old: UOp, lane: UOp, val: UOp, exec_mask: UOp) -> UOp:
  """Set/clear a single bit in a mask based on lane index, respecting exec mask."""
  if old.dtype in (dtypes.uint64, dtypes.int64):
    dt = dtypes.uint64
    mask = UOp.const(dt, 1) << lane.cast(dt)
    new_bit = _to_u32(val).cast(dt) << lane.cast(dt)
    cleared = old.cast(dt) & (mask ^ UOp.const(dt, 0xFFFFFFFFFFFFFFFF))
    return _lane_active(exec_mask, lane).where(cleared | new_bit, old.cast(dt))
  mask = _c(1) << lane.cast(dtypes.uint32)
  new_bit = _to_u32(val) << lane.cast(dtypes.uint32)
  cleared = old & (mask ^ _c(MASK32))
  return _lane_active(exec_mask, lane).where(cleared | new_bit, old)

def _val_to_u32(val: UOp) -> UOp:
  """Convert any value to uint32 for storage (bitcast floats, cast ints)."""
  if val.dtype == dtypes.uint32: return val
  if val.dtype == dtypes.float32: return val.bitcast(dtypes.uint32)
  if val.dtype == dtypes.half: return val.bitcast(dtypes.uint16).cast(dtypes.uint32)
  # every remaining dtype (including uint16/int16) is value-cast
  return val.cast(dtypes.uint32)

# op base name -> (substring to find in the pcode, replacement)
_pcode_fixes = {
  'V_DIV_FMAS_F32': ('D0.f32 = 2.0F ** 32 * fma(S0.f32, S1.f32, S2.f32)',
                     'D0.f32 = (exponent(S2.f32) > 127) ? (2.0F ** 64 * fma(S0.f32, S1.f32, S2.f32)) : (2.0F ** -64 * fma(S0.f32, S1.f32, S2.f32))'),
  'V_DIV_FMAS_F64': ('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
                     'D0.f64 = (exponent(S2.f64) > 1023) ? (2.0 ** 128 * fma(S0.f64, S1.f64, S2.f64)) : (2.0 ** -128 * fma(S0.f64, S1.f64, S2.f64))'),
  'V_DIV_FIXUP_F32': ('D0.f32 = sign_out ? -abs(S0.f32) : abs(S0.f32)',
                      'D0.f32 = isNAN(S0.f32) ? (sign_out ? -INF.f32 : +INF.f32) : (sign_out ? -abs(S0.f32) : abs(S0.f32))'),
  'V_DIV_FIXUP_F64': ('D0.f64 = sign_out ? -abs(S0.f64) : abs(S0.f64)',
                      'D0.f64 = isNAN(S0.f64) ? (sign_out ? -INF : +INF) : (sign_out ? -abs(S0.f64) : abs(S0.f64))'),
  'V_TRIG_PREOP_F64': ("result = 64'F((1201'B(2.0 / PI)[1200 : 0] << shift.u32) & 1201'0x1fffffffffffff)", "result = trig_preop_result(shift)"),
}

def _get_pcode_dict(op) -> dict:
  """Return the PCODE dictionary for the given opcode based on its architecture."""
  return PCODE_CDNA if 'cdna' in type(op).__module__ else PCODE_RDNA4 if 'rdna4' in type(op).__module__ else PCODE_RDNA3

# Pcode parser
@functools.cache
def get_pcode(op) -> str:
  """Look up the pcode text for an opcode and apply the emulator's textual patches.

  Raises KeyError when neither the op nor (for _E64 names) its VOP1 _E32 counterpart is in the pcode table."""
  op_name = op.name
  pcode_dict = _get_pcode_dict(op)
  if op not in pcode_dict and op_name.endswith('_E64'):
    # VOP3 ops ending in _E64 may share pcode with VOP1 _E32 equivalents
    import importlib
    enum_mod = importlib.import_module(type(op).__module__)
    vop1_cls = getattr(enum_mod, 'VOP1Op', None)
    e32_name = op_name.replace('_E64', '_E32')
    if vop1_cls and hasattr(vop1_cls, e32_name): op = vop1_cls[e32_name]
  pcode = pcode_dict[op]
  fix_name = op_name.replace('_E64', '').replace('_E32', '')
  if fix_name in _pcode_fixes: pcode = pcode.replace(*_pcode_fixes[fix_name])
  if 'V_DIV_SCALE' in op_name:
    dt, exp_lim, ldexp_val = ('f32', '23', '64') if 'F32' in op_name else ('f64', '52', '128')
    for old, new in [(f'S2.{dt} / S1.{dt} == DENORM.{dt}', f'divWouldBeDenorm(S2.{dt}, S1.{dt})'),
                     (f"1.0 / 64'F(S1.{dt}) == DENORM.f64", '0'),
                     (f'1.0 / S1.{dt} == DENORM.{dt}', '0'),
                     (f'S1.{dt} == DENORM.{dt}', f'isDENORM(S1.{dt})'),
                     (f'D0.{dt} = NAN.{dt}', f'VCC = 0x1LL;\nD0.{dt} = NAN.{dt}'),
                     (f'elsif isDENORM(S1.{dt}) then\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})', f'elsif 1 == 0 then\nD0.{dt} = S0.{dt}'),
                     (f'elsif exponent(S2.{dt}) <= {exp_lim} then\n// Numerator is tiny\n'
                      f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})',
                      f'elsif exponent(S2.{dt}) <= {exp_lim} then\nVCC = 0x1LL;\n'
                      f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})'),
                     (f'elsif divWouldBeDenorm(S2.{dt}, S1.{dt}) then\nVCC = 0x1LL;\n'
                      f'if S0.{dt} == S2.{dt} then\n// Only scale the numerator\n'
                      f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif',
                      f'elsif divWouldBeDenorm(S2.{dt}, S1.{dt}) then\n'
                      f'VCC = 0x1LL;\nD0.{dt} = S0.{dt}'),
                     (f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif\nelsif', f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nelse\n'
                      f'D0.{dt} = S0.{dt}\nendif\nelsif')]:
      pcode = pcode.replace(old, new)
    lines = pcode.rstrip().split('\n')
    # insert a default branch before the last 'endif'
    for i in range(len(lines) - 1, -1, -1):
      if lines[i].strip() == 'endif':
        lines.insert(i, f'else\nD0.{dt} = S0.{dt}')
        break
    pcode = '\n'.join(lines) + f';\nif isDENORM(S1.{dt}) then\nD0.{dt} = NAN.{dt}\nendif'
    pcode = pcode.replace('VCC = 0x0LL', 'VCC.u64[laneId] = 0').replace('VCC = 0x1LL', 'VCC.u64[laneId] = 1')
  return pcode

def parse_pcode(pcode: str, srcs: dict[str, UOp | int] | None = None) -> tuple[dict, list]:
  """Parse pcode text into (env, assigns).

  env is a copy of srcs (or an empty dict) handed to parse_block as its environment; assigns is the list of
  (destination, value) pairs, extended with the final values of the tracked output variables."""
  env: dict = srcs.copy() if srcs else {}
  assigns: list[tuple[str, UOp]] = []
  raw_lines = [l.strip().rstrip(';') for l in pcode.split('\n') if l.strip() and not l.strip().startswith('//')]
  # TODO: pcode.py should tokenize full pcode string instead of line-by-line, then this hack can be removed
  lines: list[str] = []
  for l in raw_lines:
    # a line ending in a binary operator continues on the next line
    if lines and re.search(r'(&&|\|\||[&|+\-*/^])\s*$', lines[-1]): lines[-1] = lines[-1] + ' ' + l
    else: lines.append(l)
  _, final, _ = parse_block(lines, 0, env, assigns=assigns)
  sliced = set(d.split('[')[0] for d, _ in assigns if '[' in d)
  for var, val in final.items():
    if var in ['D0', 'S0', 'SCC', 'VCC', 'EXEC', 'PC', 'RETURN_DATA', 'VDATA'] and isinstance(val, UOp):
      if var in sliced and not any(re.match(rf'{var}\.\w+\s*=', l) for l in lines): continue
      for l in lines:
        if (m := re.match(rf'{var}\.(\w+(?:\[\w+\])?)', l)):
          assigns.append((f'{var}.{m.group(1)}', val))
          break
      else: assigns.append((var, val))
  return env, assigns

def _write_64bit(val: UOp, wfn, reg_or_addr, is_mem: bool, *args) -> list[UOp]:
  """Write a 64-bit value as two 32-bit writes. args passed to wfn after reg/addr and lo/hi value."""
  lo, hi = _split64(val)
  incr = 4 if is_mem else 1  # 4 bytes for memory addresses, 1 for register indices
  return [wfn(reg_or_addr, lo, *args), wfn(reg_or_addr + (UOp.const(reg_or_addr.dtype, incr) if isinstance(reg_or_addr, UOp) else incr), hi, *args)]

def _write_val(bits: int, val: UOp, wfn, reg_or_addr, *args, is_mem: bool = False) -> list[UOp]:
  """Write value, splitting 64-bit if needed. bits=64 for 64-bit writes, otherwise 32-bit."""
  return _write_64bit(val, wfn, reg_or_addr, is_mem, *args) if bits == 64 else [wfn(reg_or_addr, _to_u32(val), *args)]

def _mem_store(mem: UOp, addr: UOp, val: UOp, active: UOp, addr_bits: int = 32, data_bits: int = 32) -> list[UOp]:
  """Conditional memory store with sub-word support. Returns list of store UOps."""
  adt = dtypes.uint64 if addr_bits == 64 else dtypes.uint32
  word_addr = addr >> UOp.const(adt, 2)
  idx = mem.index(word_addr.cast(dtypes.int).valid(active))
  if data_bits == 32: return [idx.store(active.where(_to_u32(val), idx))]
  # Sub-word store: read-modify-write with mask
  byte_pos = addr.cast(dtypes.uint32) & _c(3)
  byte_shift = byte_pos * _c(8)
  val_u32, size_mask = val.cast(dtypes.uint32), _c(0xFF if data_bits == 8 else 0xFFFF)
  mask = size_mask << byte_shift
  new_word = (idx & (mask ^ _c(0xFFFFFFFF))) | ((val_u32 & size_mask) << byte_shift)
  if data_bits == 8: return [idx.store(active.where(new_word, idx))]
  # 16-bit cross-word case: byte_pos == 3 means value spans two words
  is_cross = byte_pos.eq(_c(3))
  cross_word0 = (idx & _c(0x00FFFFFF)) | ((val_u32 & _c(0xFF)) << _c(24))
  store0 = idx.store(active.where(is_cross.where(cross_word0, new_word), idx))
  next_idx = mem.index((word_addr + UOp.const(adt, 1)).cast(dtypes.int).valid(active & is_cross))
  cross_word1 = (next_idx & _c(0xFFFFFF00)) | ((val_u32 >> _c(8)) & _c(0xFF))
  return [store0, next_idx.store((active & is_cross).where(cross_word1, next_idx))]

def _mem_store_bytes(mem: UOp, addr: UOp, val: UOp, active: UOp, data_bits: int = 32) -> list[UOp]:
  """Store to byte-addressable memory (scratch). addr is byte offset, mem is uint8 buffer."""
  stores = []
  val_u32 = val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val
  # one store per byte, least-significant byte at the lowest offset
  for i in range(data_bits // 8):
    byte_val = (val_u32 >> UOp.const(dtypes.uint32, i * 8)) & UOp.const(dtypes.uint32, 0xFF)
    stores.append(mem.index((addr + UOp.const(dtypes.uint64, i)).cast(dtypes.int).valid(active)).store(byte_val.cast(dtypes.uint8)))
  return stores

def _collect_data_slices(assigns: list[tuple[str, UOp]], data_prefix: str, pcode_vars: dict | None = None, op_name: str = "") -> dict[int, UOp]:
  """Collect bit slices from assigns into {dword_idx: value} dict."""
  slices = {}
  for dest, val in assigns:
    if dest.startswith(f'{data_prefix}['):
      if (m := re.match(rf'{data_prefix}\[(\d+)\s*:\s*(\d+)\]', dest)):
        hi_bit, low_bit = int(m.group(1)), int(m.group(2))
        dword_idx = low_bit // 32
        # D16 loads preserve bits - use final value from pcode_vars which has hi bits preserved
        if pcode_vars and 'D16' in op_name and dword_idx == 0 and hi_bit < 32: slices[0] = _to_u32(pcode_vars.get(data_prefix, val))
        else: slices[dword_idx] = _to_u32(val)
    elif dest.startswith(data_prefix): slices[0] = _to_u32(val)
  return slices

# ═══════════════════════════════════════════════════════════════════════════════
# INSTRUCTION COMPILER - converts decoded instruction to UOp SINK
# ═══════════════════════════════════════════════════════════════════════════════

class _Ctx:
  """Context for instruction compilation - holds buffers and helpers."""
  __slots__ = ('inst_size', 'dyn_fields', '_axis_id', 'wave_size', 'vgpr', 'accvgpr')
  sgpr = UOp.param(0, dtypes.uint32.ptr(SGPR_COUNT))
  vmem = UOp.param(2, dtypes.uint32.ptr(1 << 46))
  lds = UOp.param(3, dtypes.uint32.ptr(16384))
  scratch = UOp.param(4, dtypes.uint8.ptr(1 << 30))
  # Cache PARAM UOps by wave_size so all _Ctx instances with same wave_size share identical UOp references
  _vgpr_cache: dict[int, UOp] = {}
  _accvgpr_cache: dict[int, UOp] = {}

  def __init__(self, inst_size: int, wave_size: int = 32):
    """inst_size is the instruction size in bytes (used by inc_pc and canonical_mask); wave_size is lanes per wave."""
    self.inst_size, self._axis_id, self.wave_size = inst_size, 0, wave_size
    self.dyn_fields: list[tuple[int, int]] = []  # (lo, hi) of fields read dynamically
    if wave_size not in _Ctx._vgpr_cache: _Ctx._vgpr_cache[wave_size] = UOp.param(1, dtypes.uint32.ptr(256 * wave_size))
    self.vgpr = _Ctx._vgpr_cache[wave_size]
    if wave_size == 64:
      if wave_size not in _Ctx._accvgpr_cache: _Ctx._accvgpr_cache[wave_size] = UOp.param(5, dtypes.uint32.ptr(256 * wave_size))
      self.accvgpr = _Ctx._accvgpr_cache[wave_size]
    else: self.accvgpr = self.vgpr  # no separate ACCVGPR buffer: alias the VGPR buffer

  def range(self, n: int | None = None) -> UOp:
    """Create a lane range UOp with unique axis ID."""
    if n is None: n = self.wave_size
    self._axis_id += 1
    return UOp.range(n, self._axis_id, AxisType.LOOP, dtype=dtypes.int)

  def unroll_lanes(self, get_lane_bit, exec_mask: UOp, apply_exec: bool = True) -> UOp:
    """Combine lane bits into a mask using RANGE+REDUCE (32-bit for RDNA, 64-bit for CDNA)."""
    lane = self.range()
    mask_dt = dtypes.uint32 if self.wave_size <= 32 else dtypes.uint64
    bit = get_lane_bit(lane).cast(mask_dt) << lane.cast(mask_dt)
    result = bit.reduce(lane, arg=Ops.ADD)
    return result & exec_mask if apply_exec else result

  def inst_word(self, dword_idx: int) -> UOp:
    """Read instruction dword from vmem at PC + dword_idx*4."""
    pc = self.rpc()
    addr = pc if dword_idx == 0 else pc + UOp.const(dtypes.uint64, dword_idx * 4)
    return self.vmem.index((addr >> UOp.const(dtypes.uint64, 2)).cast(dtypes.int), ptr=True).load()

  def inst_field(self, field) -> UOp:
    """Extract field bits from instruction encoding. Tracks field for canonical key computation."""
    lo, hi = field.lo, field.hi
    self.dyn_fields.append((lo, hi))
    dword_idx = lo // 32
    lo_in_dword = lo % 32
    hi_in_dword = hi % 32
    word = self.inst_word(dword_idx)
    if lo // 32 == hi // 32:  # Same dword
      mask = (1 << (hi - lo + 1)) - 1
      shifted = word if lo_in_dword == 0 else word >> UOp.const(dtypes.uint32, lo_in_dword)
      return shifted & UOp.const(dtypes.uint32, mask)
    else:  # Spans two dwords
      lo_bits = 32 - lo_in_dword
      lo_mask = (1 << lo_bits) - 1
      hi_mask = (1 << (hi_in_dword + 1)) - 1
      lo_part = (word >> UOp.const(dtypes.uint32, lo_in_dword)) & UOp.const(dtypes.uint32, lo_mask)
      hi_part = self.inst_word(dword_idx + 1) & UOp.const(dtypes.uint32, hi_mask)
      return lo_part | (hi_part << UOp.const(dtypes.uint32, lo_bits))

  def inst_field_signed(self, field) -> UOp:
    """Extract field and sign-extend based on field width."""
    val = self.inst_field(field)
    width = field.hi - field.lo + 1
    sign_bit = 1 << (width - 1)
    return (val.cast(dtypes.int) ^ _c(sign_bit, dtypes.int)) - _c(sign_bit, dtypes.int)

  def canonical_mask(self, inst_bytes: bytes) -> tuple[int, int, int]:
    """Compute canonical (base, mask, size) for cache lookup.
    base = instruction bits with dynamic fields zeroed
    mask = bitmask with 1s for static bits, 0s for dynamic bits
    size = instruction size in bytes"""
    size = self.inst_size
    base = int.from_bytes(inst_bytes[:size], 'little')
    mask = (1 << (size * 8)) - 1  # all 1s initially
    for lo, hi in self.dyn_fields:
      field_mask = ((1 << (hi - lo + 1)) - 1) << lo
      base &= ~field_mask  # zero dynamic bits in base
      mask &= ~field_mask  # zero dynamic bits in mask
    return base, mask, size

  def rexec(self) -> UOp:
    """Read full EXEC mask (32-bit for RDNA, 64-bit for CDNA)."""
    lo = self.rsgpr_dyn(_c(EXEC_LO.offset))
    if self.wave_size <= 32: return lo
    hi = self.rsgpr_dyn(_c(EXEC_LO.offset + 1))
    return _u64(lo, hi)

  # Dynamic register access (takes UOp index instead of int)
  def rsgpr_dyn(self, reg: UOp, valid: UOp | None = None) -> UOp:
    """Read SGPR with dynamic register index."""
    if valid is not None: return self.sgpr.index(reg.cast(dtypes.int).valid(valid), ptr=True).load()
    return self.sgpr.index(reg.cast(dtypes.int), ptr=True).load()

  def wsgpr_dyn(self, reg: UOp, val: UOp) -> UOp:
    """Write SGPR with dynamic register index. On RDNA, index 124 = NULL (writes discarded).
    On CDNA, index 124 = M0 (read/write)."""
    # RDNA: NULL (124) discards writes. CDNA: M0 (124) is writable.
    valid = None if self.wave_size == 64 else reg.ne(_c(124))
    return self.sgpr.index(reg.cast(dtypes.int).valid(valid) if valid is not None else reg.cast(dtypes.int), ptr=True).store(val.cast(dtypes.uint32))

  def wmask(self, reg: UOp, val: UOp) -> list[UOp]:
    """Write a lane mask (VCC/EXEC). Splits into lo/hi for wave64."""
    if self.wave_size > 32:
      lo, hi = _split64(val)
      return [self.wsgpr_dyn(reg, lo), self.wsgpr_dyn(reg + _c(1), hi)]
    return [self.wsgpr_dyn(reg, val)]

  def rmask(self, reg: UOp) -> UOp:
    """Read a lane mask (VCC/EXEC). Combines lo/hi for wave64."""
    if self.wave_size > 32: return _u64(self.rsgpr_dyn(reg), self.rsgpr_dyn(reg + _c(1)))
    return self.rsgpr_dyn(reg)

  def rvgpr_dyn(self, reg: UOp, lane: UOp, valid: UOp | None = None) -> UOp:
    """Read VGPR with dynamic register index."""
    idx = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
    return self.vgpr.index(idx.valid(valid), ptr=True).load() if valid is not None else self.vgpr.index(idx, ptr=True).load()

  def wvgpr_dyn(self, reg: UOp, lane: UOp, val: UOp, exec_mask: UOp, after: UOp | None = None) -> UOp:
    """Write VGPR with dynamic register index."""
    buf = self.vgpr.after(after) if after is not None else self.vgpr
    offset = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
    return buf.index(offset.valid(_lane_active(exec_mask, lane))).store(val.cast(dtypes.uint32))

  def raccvgpr_dyn(self, reg: UOp, lane: UOp, valid: UOp | None = None) -> UOp:
    """Read ACCVGPR with dynamic register index (CDNA only)."""
    idx = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
    return self.accvgpr.index(idx.valid(valid), ptr=True).load() if valid is not None else self.accvgpr.index(idx, ptr=True).load()

  def waccvgpr_dyn(self, reg: UOp, lane: UOp, val: UOp, exec_mask: UOp, after: UOp | None = None) -> UOp:
    """Write ACCVGPR with dynamic register index (CDNA only)."""
    buf = self.accvgpr.after(after) if after is not None else self.accvgpr
    offset = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
    return buf.index(offset.valid(_lane_active(exec_mask, lane))).store(val.cast(dtypes.uint32))

  def rsrc_dyn(self, off: UOp, lane: UOp | None, bits: int = 32, literal: UOp | None = None, is_f64: bool = False, do_cast: bool = True) -> UOp:
    """Read source operand with dynamic offset. Handles SGPR/inline constants (<256), VGPR (>=256).
    If lane is None, only scalar access is supported (off must be < 256).
    is_f64: True for F64 operations where 64-bit literals go in high 32 bits."""
    is_float_const = (off >= _c(240)) & (off <= _c(248))
    is_vgpr = off >= _c(256)
    is_sgpr = is_vgpr.ne(True)
    sgpr_lo = self.rsgpr_dyn(off, is_sgpr)

    if lane is not None:
      vgpr_reg = off - _c(256)
      vgpr_lo = self.rvgpr_dyn(vgpr_reg, lane, is_vgpr)
      vgpr_val = _u64(vgpr_lo, self.rvgpr_dyn(vgpr_reg + _c(1), lane, is_vgpr)) if bits == 64 else vgpr_lo

    if bits == 64:
      sgpr_hi = self.rsgpr_dyn(off + _c(1), is_sgpr)
      sgpr_val = _u64(sgpr_lo, sgpr_hi)
      # Integer inline constants: sign-extend 32-bit value from buffer to 64-bit
      # Float constants: cast F32 to F64
      int_inline = sgpr_lo.cast(dtypes.int32).cast(dtypes.int64)
      float_inline = sgpr_lo.bitcast(dtypes.float32).cast(dtypes.float64)
      # compute inline
      inline = is_float_const.where(float_inline.bitcast(dtypes.uint64), int_inline.bitcast(dtypes.uint64))
      # Literal handling: F64 VOP puts literal in high 32 bits; B64/I64/U64 VOP and SOP zero-extend
      if literal is not None:
        lit_val = literal.cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32) if is_f64 else literal.cast(dtypes.uint64)
        inline = off.eq(_c(255)).where(lit_val, inline)
      scalar_val = (off < _c(128)).where(sgpr_val, inline)
    else:
      scalar_val = sgpr_lo
      if literal is not None: scalar_val = off.eq(_c(255)).where(literal, scalar_val)
      if bits == 16 and do_cast:
        # Float constants: cast F32 to F16
        scalar_val = is_float_const.where(scalar_val.bitcast(dtypes.float32).cast(dtypes.half).bitcast(dtypes.uint16).cast(dtypes.uint32), scalar_val)

    return is_vgpr.where(vgpr_val, scalar_val) if lane is not None else scalar_val

  def rpc(self) -> UOp:
    """Read PC as 64-bit byte address."""
    # Index at PC_LO, then cast to uint64 ptr and load
    return self.sgpr.index(_c(PC_LO_IDX, dtypes.int), ptr=True).cast(dtypes.uint64.ptr(SGPR_COUNT // 2)).load()

  def inc_pc(self) -> list[UOp]:
    """Increment PC by instruction size in bytes. Returns [store]."""
    new_pc = self.rpc() + UOp.const(dtypes.uint64, self.inst_size)
    return [self.sgpr.index(_c(PC_LO_IDX, dtypes.int), ptr=True).cast(dtypes.uint64.ptr(SGPR_COUNT // 2)).store(new_pc)]

  def scalar_stores(self, assigns: list[tuple[str, UOp]], sdst_reg: UOp, sdst_size: int = 1) -> list[UOp]:
    """Generate stores for scalar assigns with dynamic destination register (D0, SCC, EXEC, VCC)."""
    stores: list[UOp] = []
    for dest, val in assigns:
      if dest.startswith('D0'):
        if sdst_size == 2:
          lo, hi = _split64(val)
          stores.extend([self.wsgpr_dyn(sdst_reg, lo), self.wsgpr_dyn(sdst_reg + _c(1), hi)])
        else: stores.append(self.wsgpr_dyn(sdst_reg, _val_to_u32(val)))
      elif dest.startswith('SCC'): stores.append(self.wsgpr_dyn(_c(SCC.offset), _to_u32(val)))
      elif dest.startswith('EXEC'):
        if self.wave_size > 32 and val.dtype in (dtypes.uint64, dtypes.int64):
          lo, hi = _split64(val)
          stores.extend([self.wsgpr_dyn(_c(EXEC_LO.offset), lo), self.wsgpr_dyn(_c(EXEC_LO.offset + 1), hi)])
        else: stores.append(self.wsgpr_dyn(_c(EXEC_LO.offset), _to_u32(val)))
      elif dest.startswith('VCC'): stores.extend(self.wmask(_c(VCC_LO.offset), val))
    return stores

  def compile_sop_pcode(self, op, srcs: dict[str, UOp | int], sdst_reg: UOp, sdst_size: int) -> UOp:
    """Compile a scalar instruction with dynamic destination register."""
    pcode = get_pcode(op)
    # note: srcs is updated in place
    srcs.update({'VCC': self.rmask(_c(VCC_LO.offset)), 'EXEC': self.rexec(), 'SCC': self.rsgpr_dyn(_c(SCC.offset)),
                 '_wave_size': self.wave_size})
    if 'D0' not in srcs: srcs['D0'] = self.rsgpr_dyn(sdst_reg)  # D0 is current dest value for read-modify-write ops
    _, assigns = parse_pcode(pcode, srcs)
    return UOp.sink(*self.scalar_stores(assigns, sdst_reg, sdst_size), *self.inc_pc())

  def compile_lane_pcode(self, op, inst) -> UOp:
    """Compile cross-lane ops (READLANE/WRITELANE/PERMLANE) using pcode parser."""
    pcode = get_pcode(op)
    op_name = op.name if hasattr(op, 'name') else str(op)
    src0_off, vdst_off = self.inst_field(type(inst).src0), self.inst_field(type(inst).vdst)
    src0_reg = (src0_off >= _c(256)).where(src0_off - _c(256), _c(0))  # VGPR index or 0
    src1_off = self.inst_field(type(inst).src1) if hasattr(type(inst), 'src1') else None
    src2_off = self.inst_field(type(inst).src2) if hasattr(type(inst), 'src2') else None
    src1_reg = (src1_off >= _c(256)).where(src1_off - _c(256), src1_off) if src1_off is not None else _c(0)
    src2_reg = (src2_off >= _c(256)).where(src2_off - _c(256), src2_off) if src2_off is not None else _c(0)
    exec_val = self.rexec()
    exec_lo = exec_val.cast(dtypes.uint32) if exec_val.dtype == dtypes.uint64 else exec_val
    srcs = {
      'SRC0': src0_reg, 'VDST': vdst_off, 'EXEC_LO': exec_lo,
      'EXEC': exec_val if exec_val.dtype == dtypes.uint64 else exec_val.cast(dtypes.uint64),
      '_vgpr': self.vgpr, '_wave_size': self.wave_size,
      'SRC1': src1_reg, 'SRC2': src2_reg,
      'S0': self.rsrc_dyn(src0_off, _c(0, dtypes.int)) if 'WRITELANE' in op_name else src0_reg,
      'S1': self.rsrc_dyn(src1_off, _c(0, dtypes.int)) if src1_off is not None else _c(0),
      'S2': self.rsrc_dyn(src2_off, _c(0, dtypes.int)) if src2_off is not None else _c(0),
    }
    _, assigns = parse_pcode(pcode, srcs)
    stores = []
    for dest, val in assigns:
      if dest.startswith('D0'): stores.append(self.wsgpr_dyn(vdst_off, val.cast(dtypes.uint32)))
      # VGPR[...] assigns carry an (index, value) pair
      elif dest.startswith('VGPR['): stores.append(self.vgpr.index(val[0].cast(dtypes.int)).store(val[1].cast(dtypes.uint32)))
    return UOp.sink(*stores, *self.inc_pc())

  def compile_vop_pcode(self, op, srcs: dict[str, UOp | int], lane: UOp, vdst_reg: UOp, exec_mask: UOp, opsel_dst_hi: bool | UOp = False,
                        sdst_reg: int | None = None, clmp: int = 0, src0_off: UOp | None = None) -> UOp:
    """Compile VOP instruction. Returns sink with stores and inc_pc."""
    pcode = get_pcode(op)
    vcc_reg = sdst_reg if sdst_reg is not None else VCC_LO.offset
    if 'VCC' not in srcs: srcs['VCC'] = self.rmask(_c(vcc_reg))
    srcs.update({'EXEC': exec_mask, 'SCC': self.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane, 'VDST': vdst_reg,
                 'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0), 'ROUND_NEAREST_EVEN': _c(0), '_vgpr': self.vgpr, '_wave_size': self.wave_size,
                 'MAX_FLOAT_F32': UOp.const(dtypes.float32, 3.4028234663852886e38),
                 # CDNA SDWA byte/word select constants (E32 always uses BYTE0/WORD0 defaults)
                 'SDWA_SRC0_SEL': _c(0), 'BYTE0': _c(0), 'BYTE1': _c(1), 'BYTE2': _c(2), 'BYTE3': _c(3),
                 'WORD0': _c(0), 'WORD1': _c(1)})  # rounding mode and SDWA constants
    _, assigns = parse_pcode(pcode, srcs)

    # For integer ops with clamp, compute overflow using wide arithmetic
    # NOTE: MUL_LO ops don't saturate - they always return the low bits
    int_saturate = None
    if clmp and any(p in op.name for p in ('_NC_U', '_MAD_U', '_NC_I', '_MAD_I')):
      is_signed, is_16bit = '_I' in op.name and '_U' not in op.name, '16' in op.name
      if not (is_16bit and is_signed):  # Skip 16-bit signed ops due to codegen issues
        s0, s1, s2 = srcs.get('S0'), srcs.get('S1'), srcs.get('S2')
        if s0 is not None and s1 is not None:
          narrow_dt = dtypes.uint16 if is_16bit else (dtypes.int32 if is_signed else dtypes.uint32)
          wide_dt = dtypes.int32 if is_16bit else dtypes.int64
          narrow_max, narrow_min = (0xFFFF, 0) if is_16bit else ((0x7FFFFFFF, -0x80000000) if is_signed else (0xFFFFFFFF, 0))
          def to_wide(x): return (x.bitcast(narrow_dt) if x.dtype.itemsize == narrow_dt.itemsize else x.cast(narrow_dt)).cast(wide_dt)
          is_sub, is_mad = 'SUB' in op.name, 'MAD' in op.name
          full = (to_wide(s0) * to_wide(s1) + to_wide(s2)) if is_mad and s2 is not None else \
                 (to_wide(s1) - to_wide(s0)) if is_sub and 'SUBREV' in op.name else \
                 (to_wide(s0) - to_wide(s1)) if is_sub else (to_wide(s0) + to_wide(s1))
          int_saturate = full.clamp(narrow_min, narrow_max).cast(narrow_dt)
    # V_SUB_U32 / V_ADD_U32
with clamp: unsigned saturate (SUB underflow->0, ADD overflow->0xFFFFFFFF) if clmp and int_saturate is None and any(p in op.name for p in ('_SUB_U32', '_ADD_U32', '_SUB_U16', '_ADD_U16')): s0, s1 = srcs.get('S0'), srcs.get('S1') if s0 is not None and s1 is not None: assert isinstance(s0, UOp) and isinstance(s1, UOp) a, b = (s1.cast(dtypes.uint32), s0.cast(dtypes.uint32)) if 'SUBREV' in op.name else (s0.cast(dtypes.uint32), s1.cast(dtypes.uint32)) if 'SUB' in op.name: int_saturate = (a < b).where(_c(0), a - b) # underflow -> 0 else: raw_sum = a + b int_saturate = (raw_sum < a).where(_c(0xFFFFFFFF), raw_sum) # overflow -> MAX raw_stores: list = [] vcc_val, exec_val = None, None for dest, val in assigns: # VGPR bit-slice assignment: VGPR[lane][reg][hi:lo] = (vgpr_idx, rhs_val, hi, lo[, cond]) -> read-modify-write if dest.startswith('VGPR[') and re.search(r'\[\d+:\d+\]', dest): # VGPR bit-slice: (vgpr_idx, rhs_val, hi_bit, lo_bit) - hi/lo are UOp constants hi_bit, lo_bit = int(val[2].arg), int(val[3].arg) width = hi_bit - lo_bit + 1 old = self.vgpr.index(val[0].cast(dtypes.int), ptr=True).load() new_val = _set_bits(old, _val_to_bits(val[1]), width, lo_bit).cast(dtypes.uint32) active = _lane_active(exec_mask, lane) if len(val) > 4: active = active & _to_bool(val[4]) raw_stores.append(('vgpr_direct', self.vgpr.index(val[0].cast(dtypes.int).valid(active)).store(new_val))) continue if 'D0' in dest and '[laneId]' in dest: old_vcc = self.rmask(_c(VCC_LO.offset)) new_vcc = _set_lane_bit(old_vcc, lane, val, exec_mask) raw_stores.extend([('vcc', s) for s in self.wmask(_c(VCC_LO.offset), new_vcc)]) elif dest.startswith('D0'): dest_suffix = re.match(r'D0\.(\w+)', dest) if dest_suffix is not None: target_dt = {'u16': dtypes.uint16, 'i16': dtypes.int16, 'f16': dtypes.half}.get(dest_suffix.group(1)) if target_dt is not None and val.dtype != target_dt: val = val.cast(target_dt) if (slice_match := re.match(r'D0\[(\d+)\s*:\s*(\d+)\]', dest)): d0_hi_bit, d0_lo_bit = 
int(slice_match.group(1)), int(slice_match.group(2)) if d0_hi_bit != 31 or d0_lo_bit != 0: d0_width, slice_mask = d0_hi_bit - d0_lo_bit + 1, (1 << (d0_hi_bit - d0_lo_bit + 1)) - 1 val_bits = val.bitcast(dtypes.uint16).cast(dtypes.uint32) if val.dtype == dtypes.half else \ val.cast(dtypes.uint32) if val.dtype in (dtypes.uint16, dtypes.int16) else \ val.cast(dtypes.uint32) & UOp.const(dtypes.uint32, slice_mask) raw_stores.append(('vgpr_slice', (d0_lo_bit, d0_width, val_bits))) continue # For integer ops with clamp, use pre-computed saturated value; for floats, clamp to [0,1] if int_saturate is not None: val = int_saturate elif clmp and val.dtype in (dtypes.float32, dtypes.half, dtypes.float64): clamped = val.maximum(UOp.const(val.dtype, 0.0)).minimum(UOp.const(val.dtype, 1.0)) val = _FUNCS['isNAN'](val).where(UOp.const(val.dtype, 0.0), clamped) if val.dtype in (dtypes.uint64, dtypes.int64, dtypes.float64): lo, hi = _split64(val) raw_stores.extend([('vgpr', self.wvgpr_dyn(vdst_reg, lane, lo, exec_mask)), ('vgpr', self.wvgpr_dyn(vdst_reg + _c(1), lane, hi, exec_mask))]) elif val.dtype in (dtypes.half, dtypes.uint16, dtypes.int16): result, old_val = _val_to_u32(val), self.rvgpr_dyn(vdst_reg, lane) hi_result = (old_val & UOp.const(dtypes.uint32, 0xFFFF)) | (result << UOp.const(dtypes.uint32, 16)) # GFX9/CDNA zeroes upper 16 bits on lo-half write; RDNA preserves them lo_result = (result & UOp.const(dtypes.uint32, 0xFFFF)) if self.wave_size == 64 else \ (old_val & UOp.const(dtypes.uint32, 0xFFFF0000)) | (result & UOp.const(dtypes.uint32, 0xFFFF)) result = opsel_dst_hi.where(hi_result, lo_result) if isinstance(opsel_dst_hi, UOp) else hi_result if opsel_dst_hi else lo_result raw_stores.append(('vgpr', self.wvgpr_dyn(vdst_reg, lane, result, exec_mask))) else: raw_stores.append(('vgpr', self.wvgpr_dyn(vdst_reg, lane, _val_to_u32(val), exec_mask))) elif dest.startswith('S0') and src0_off is not None: # Write back to src0 VGPR (e.g. v_swap_b32). 
src0_off is raw encoding (256+ = VGPR) src0_vgpr = src0_off - _c(256) raw_stores.append(('vgpr_s0', self.wvgpr_dyn(src0_vgpr, lane, _val_to_u32(val), exec_mask))) elif dest.startswith('VCC'): vcc_val = val elif dest.startswith('EXEC'): exec_val = val elif dest.startswith('SCC'): raw_stores.append(('scc', self.wsgpr_dyn(_c(SCC.offset), _to_u32(val)))) lane_stores = [s for t, s in raw_stores if t in ('vgpr', 'vgpr_s0', 'vgpr_direct')] stores, scalar_stores = [], [s for t, s in raw_stores if t == 'scc'] slice_stores = [s for t, s in raw_stores if t == 'vgpr_slice'] if slice_stores: result = self.rvgpr_dyn(vdst_reg, lane) for lo_bit, width, val_bits in slice_stores: mask = UOp.const(dtypes.uint32, ((1 << width) - 1) << lo_bit) result = (result & (mask ^ UOp.const(dtypes.uint32, 0xFFFFFFFF))) | (val_bits << UOp.const(dtypes.uint32, lo_bit)) lane_stores.append(self.wvgpr_dyn(vdst_reg, lane, result, exec_mask)) # VCC/EXEC mask writes must be computed BEFORE VGPR stores to avoid reading modified VGPRs. # When vdst overlaps with src operands (e.g. v_add_co_u32 v[0], vcc, s[8], v[0]), the carry # computation reads the original source values only if its range loop runs before the VGPR write loop. 
mask_stores: list[UOp] = [] for mask_val, reg in [(vcc_val, vcc_reg), (exec_val, EXEC_LO.offset)]: if mask_val is None: continue def get_bit(l, v=mask_val): return (_to_u32(v.substitute({lane: l})) & _c(1)).cast(dtypes.uint32) mask_stores.extend(self.wmask(_c(reg), self.unroll_lanes(get_bit, exec_mask, apply_exec=False))) stores.extend(mask_stores) if lane_stores: stores.append(UOp.sink(*lane_stores).end(lane)) stores.extend(scalar_stores) return UOp.sink(*stores, *self.inc_pc()) # ═══════════════════════════════════════════════════════════════════════════════ # INSTRUCTION HANDLERS # ═══════════════════════════════════════════════════════════════════════════════ def _compile_sopp(inst: ir3.SOPP | ir4.SOPP, ctx: _Ctx) -> UOp: simm16 = ctx.inst_field_signed(type(inst).simm16).cast(dtypes.int16) if inst.op in (ir3.SOPPOp.S_ENDPGM, ir4.SOPPOp.S_ENDPGM, irc.SOPPOp.S_ENDPGM): return UOp.sink(ctx.wsgpr_dyn(_c(PC_LO_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF)), ctx.wsgpr_dyn(_c(PC_HI_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF))) # S_BARRIER: advance PC past the barrier instruction. The execution loop detects barriers before executing and handles synchronization. 
barrier_ops = {ir3.SOPPOp.S_BARRIER, irc.SOPPOp.S_BARRIER} if hasattr(ir4.SOPPOp, 'S_BARRIER_WAIT'): barrier_ops.add(ir4.SOPPOp.S_BARRIER_WAIT) if inst.op in barrier_ops: return UOp.sink(*ctx.inc_pc()) # S_NOP and S_WAITCNT are no-ops in emulator (no pipeline/cache to wait on) if inst.op in (ir3.SOPPOp.S_NOP, ir4.SOPPOp.S_NOP, irc.SOPPOp.S_NOP, irc.SOPPOp.S_WAITCNT): return UOp.sink(*ctx.inc_pc()) # NOTE: we ignore SOPPs without PCODE if inst.op in _get_pcode_dict(inst.op): pcode = get_pcode(inst.op) pc_bytes = ctx.rpc() # PC is already 64-bit byte address vcc, exec_val = ctx.rmask(_c(VCC_LO.offset)), ctx.rexec() srcs = {'PC': pc_bytes.cast(dtypes.int64), 'SIMM16': simm16, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'VCC': vcc, 'VCCZ': vcc.eq(UOp.const(vcc.dtype, 0)).cast(dtypes.uint32), 'EXECZ': exec_val.eq(UOp.const(exec_val.dtype, 0)).cast(dtypes.uint32)} for dest, val in parse_pcode(pcode, srcs)[1]: if dest == 'PC' or dest.startswith('PC.'): lo, hi = _split64(val.cast(dtypes.uint64)) return UOp.sink(ctx.wsgpr_dyn(_c(PC_LO_IDX), lo), ctx.wsgpr_dyn(_c(PC_HI_IDX), hi)) return UOp.sink(*ctx.inc_pc()) def _compile_smem(inst: ir3.SMEM | ir4.SMEM, ctx: _Ctx) -> UOp: # Cache invalidation instructions are no-ops in the emulator (we don't model caches) if '_INV' in inst.op.name: return UOp.sink(*ctx.inc_pc()) # Dynamic sbase field (bits 5:0) - SGPR pair, field value * 2 = register offset sbase = ctx.inst_field(type(inst).sbase) * _c(2) # Dynamic sdata field (bits 12:6) - destination SGPR sdata_reg = ctx.inst_field(type(inst).sdata) # RDNA4 uses 'ioffset', RDNA3 uses 'offset' - use type(inst) to get correct field offset_field = type(inst).ioffset if hasattr(type(inst), 'ioffset') else type(inst).offset # type: ignore[union-attr] offset = ctx.inst_field_signed(offset_field) # signed immediate # Dynamic soffset field - SGPR for additional offset (NULL=124 reads as 0, CDNA soffset_en=0 means no soffset) soffset_val = _c(0).cast(dtypes.uint64) if not (isinstance(inst, irc.SMEM) and 
not inst.soffset_en): soffset_val = ctx.rsgpr_dyn(ctx.inst_field(type(inst).soffset)).cast(dtypes.uint64) addr = _u64(ctx.rsgpr_dyn(sbase), ctx.rsgpr_dyn(sbase + _c(1))) + offset.cast(dtypes.uint64) + soffset_val # S_LOAD_(DTYPE) series: B32/DWORD=1, B64/DWORDX2=2, U8=0.25, I8=-0.25, etc. op_name = _op_name(inst) assert (op_name).startswith('S_LOAD_'), f"unexpected SMEM op: {op_name}" part = op_name.rsplit('_', 1)[1] # B32, DWORD, DWORDX2, U8, I8, etc. nval = int(part.removeprefix('DWORD').removeprefix('X') or '1') if 'DWORD' in part else int(part[1:]) / 32 * (-1 if part[0] == 'I' else 1) ndwords = max(1, int(abs(nval))) dword_base = addr >> UOp.const(dtypes.uint64, 2) vals = [ctx.vmem.index((dword_base + UOp.const(dtypes.uint64, i)).cast(dtypes.int)) for i in range(ndwords)] if abs(nval) < 1: nbits = int(abs(nval) * 32) byte_off = (addr & UOp.const(dtypes.uint64, 3)).cast(dtypes.uint32) * UOp.const(dtypes.uint32, 8) extracted = (vals[0] >> byte_off) & UOp.const(dtypes.uint32, (1 << nbits) - 1) vals[0] = extracted.cast({8: dtypes.int8, 16: dtypes.int16}[nbits]).cast(dtypes.int32).bitcast(dtypes.uint32) if nval < 0 else extracted stores = [ctx.wsgpr_dyn(sdata_reg + _c(i), vals[i]) for i in range(ndwords)] return UOp.sink(*stores, *ctx.inc_pc()) def _compile_sop(inst: ir3.SOP1|ir3.SOP2|ir3.SOPC|ir3.SOPK|ir4.SOP1|ir4.SOP2|ir4.SOPC|ir4.SOPK|irc.SOP1|irc.SOP2|irc.SOPC|irc.SOPK, ctx: _Ctx) -> UOp: bits = inst.canonical_op_bits literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr] if isinstance(inst, (ir3.SOPK, ir4.SOPK, irc.SOPK)): sdst_off = ctx.inst_field(type(inst).sdst) simm16 = ctx.inst_field(type(inst).simm16) # Sign-extend simm16 simm16_sext = simm16.cast(dtypes.int16).cast(dtypes.int32) # RDNA4 pcodes use S0.i16 for the immediate (e.g., S_MULK_I32), RDNA3 uses S0 for the register (e.g., S_CMPK_*) # CDNA pcode uses S0 for the immediate in MOVK/MULK/ADDK/CMOVK, but S0 = register for CMPK/SETREG op_name = 
_op_name(inst) if isinstance(inst, ir4.SOPK): s0 = simm16 elif isinstance(inst, irc.SOPK) and 'CMPK' not in op_name and 'SETREG' not in op_name: s0 = simm16_sext else: s0 = ctx.rsgpr_dyn(sdst_off) srcs = {'S0': s0, 'S1': simm16_sext, 'SIMM16': simm16_sext, 'D0': ctx.rsgpr_dyn(sdst_off)} dst_off, dst_size = sdst_off, 1 # S_GETREG_B32: extract bits from HW register. Handle as special case since HW_REGISTERS is not a normal variable. # HW register values are stored at SGPR[SGPR_COUNT-16 + hwRegId] by _init_wave. if 'GETREG' in op_name: hw_reg_id = simm16.cast(dtypes.uint32) & _c(0x3F) offset = (simm16.cast(dtypes.uint32) >> _c(6)) & _c(0x1F) size = ((simm16.cast(dtypes.uint32) >> _c(11)) & _c(0x1F)) + _c(1) hw_val = ctx.rsgpr_dyn(_c(SGPR_COUNT - 16) + hw_reg_id) mask = (_c(1) << size) - _c(1) result = (hw_val >> offset) & mask return UOp.sink(ctx.wsgpr_dyn(sdst_off, result), *ctx.inc_pc()) elif isinstance(inst, (ir3.SOP1, ir4.SOP1, irc.SOP1)): # S_BARRIER_SIGNAL: no-op in emulator, barrier sync handled by execution loop if isinstance(inst, ir4.SOP1) and inst.op in _BARRIER_SOP1_OPS: return UOp.sink(*ctx.inc_pc()) sdst_off = ctx.inst_field(type(inst).sdst) ssrc0_off = ctx.inst_field(type(inst).ssrc0) srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal)} dst_off, dst_size = sdst_off, bits['d'] // 32 elif isinstance(inst, (ir3.SOP2, ir4.SOP2, irc.SOP2)): sdst_off = ctx.inst_field(type(inst).sdst) ssrc0_off = ctx.inst_field(type(inst).ssrc0) ssrc1_off = ctx.inst_field(type(inst).ssrc1) srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal), 'S1': ctx.rsrc_dyn(ssrc1_off, None, bits['s1'], literal)} if literal is not None: srcs['SIMM32'] = literal dst_off, dst_size = sdst_off, bits['d'] // 32 elif isinstance(inst, (ir3.SOPC, ir4.SOPC, irc.SOPC)): ssrc0_off = ctx.inst_field(type(inst).ssrc0) ssrc1_off = ctx.inst_field(type(inst).ssrc1) srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal), 'S1': ctx.rsrc_dyn(ssrc1_off, None, bits['s1'], 
literal)} dst_off, dst_size = _c(0), 0 # SOPC writes to SCC, not sdst else: raise RuntimeError(f"unknown SOP type: {type(inst).__name__}") return ctx.compile_sop_pcode(inst.op, srcs, dst_off, dst_size) def _sdwa_select(val: UOp, sel: UOp, sext: UOp) -> UOp: """Apply SDWA byte/word selection and optional sign extension to a 32-bit value.""" # sel: 0-3=BYTE_0..3, 4=WORD_0, 5=WORD_1, 6=DWORD b0 = val & _c(0xFF) b1 = (val >> _c(8)) & _c(0xFF) b2 = (val >> _c(16)) & _c(0xFF) b3 = (val >> _c(24)) & _c(0xFF) w0 = val & _c(0xFFFF) w1 = (val >> _c(16)) & _c(0xFFFF) selected = sel.eq(_c(1)).where(b1, sel.eq(_c(2)).where(b2, sel.eq(_c(3)).where(b3, sel.eq(_c(4)).where(w0, sel.eq(_c(5)).where(w1, sel.eq(_c(6)).where(val, b0)))))) # Sign extend when sext=1 is_byte = sel < _c(4) byte_sext = (selected & _c(0x80)).ne(_c(0)).where(selected | _c(0xFFFFFF00), selected) word_sext = (selected & _c(0x8000)).ne(_c(0)).where(selected | _c(0xFFFF0000), selected) return sext.ne(_c(0)).where(is_byte.where(byte_sext, word_sext), selected) def _sdwa_write(old: UOp, val: UOp, dst_sel: UOp, dst_unused: UOp) -> UOp: """Apply SDWA destination selection: write selected byte/word, handle unused bits.""" # dst_unused: 0=PAD(zero), 1=SEXT, 2=PRESERVE # dst_sel: 0-3=BYTE, 4=WORD_0, 5=WORD_1, 6=DWORD is_byte = dst_sel < _c(4) is_word = (dst_sel >= _c(4)) & (dst_sel < _c(6)) shift = is_byte.where(dst_sel * _c(8), (dst_sel - _c(4)) * _c(16)) mask = is_byte.where(_c(0xFF), is_word.where(_c(0xFFFF), _c(0xFFFFFFFF))) placed = (val & mask) << shift preserve_mask = (mask << shift) ^ _c(0xFFFFFFFF) preserved = (old & preserve_mask) | placed # For PAD and SEXT, unused bits are zero (PAD) or sign-extended (SEXT). For DWORD, just return val. 
return dst_sel.eq(_c(6)).where(val, dst_unused.eq(_c(2)).where(preserved, placed)) def _dpp_quad_sel(quad_lane: UOp, sels: tuple[int, int, int, int]) -> UOp: sel = _c(sels[0], dtypes.int) for i, src in enumerate(sels[1:], start=1): sel = quad_lane.eq(_c(i, dtypes.int)).where(_c(src, dtypes.int), sel) return sel def _dpp16_ctrl(lane: UOp, dpp: int, row_mask: int, bank_mask: int, wave_size: int) -> tuple[UOp, UOp, UOp]: """Return (src_lane, row/bank enabled, in-bounds) for a DPP16 swizzle.""" lane_i = lane.cast(dtypes.int) row_base, lane_in_row = lane_i & _c(~15, dtypes.int), lane_i & _c(15, dtypes.int) row = lane_i // _c(16, dtypes.int) bank = lane_in_row >> _c(2, dtypes.int) enabled = (((_c(row_mask) >> row.cast(dtypes.uint32)) & _c(1)).ne(_c(0)) & (((_c(bank_mask) >> bank.cast(dtypes.uint32)) & _c(1)).ne(_c(0)))) op, arg = decode_dpp16(dpp) src_lane, valid = lane_i, UOp.const(dtypes.bool, True) if op == 'quad_perm': assert isinstance(arg, tuple) src_lane = (lane_i & _c(~3, dtypes.int)) + _dpp_quad_sel(lane_i & _c(3, dtypes.int), arg) else: assert isinstance(arg, int) if op == 'row_shl': src_lane, valid = row_base + lane_in_row + _c(arg, dtypes.int), lane_in_row <= _c(15 - arg, dtypes.int) elif op == 'row_shr': src_lane, valid = row_base + lane_in_row - _c(arg, dtypes.int), lane_in_row >= _c(arg, dtypes.int) elif op == 'row_ror': src_lane = row_base + ((lane_in_row - _c(arg, dtypes.int)) & _c(15, dtypes.int)) elif op == 'row_mirror': src_lane = row_base + (_c(15, dtypes.int) - lane_in_row) elif op == 'row_half_mirror': src_lane = row_base + ((lane_in_row & _c(8, dtypes.int)) | (_c(7, dtypes.int) - (lane_in_row & _c(7, dtypes.int)))) elif op == 'row_bcast': src_lane = row_base elif op == 'wave_shl': src_lane, valid = lane_i + _c(arg, dtypes.int), lane_i < _c(wave_size - arg, dtypes.int) elif op == 'wave_rol': src_lane = (lane_i + _c(arg, dtypes.int)) % _c(wave_size, dtypes.int) elif op == 'wave_shr': src_lane, valid = lane_i - _c(arg, dtypes.int), lane_i >= _c(arg, 
dtypes.int) elif op == 'wave_ror': src_lane = (lane_i - _c(arg, dtypes.int)) % _c(wave_size, dtypes.int) else: raise NotImplementedError(f"DPP16 control {dpp:#x} ({op}:{arg}) not implemented in emulator") return src_lane, enabled, valid def _load_dpp16_src0(ctx: _Ctx, inst, lane: UOp, fallback: UOp) -> UOp: """Load a DPP16-swizzled src0 value from vsrc0.""" src_lane, enabled, valid = _dpp16_ctrl(lane, getattr(inst, 'dpp', 0) or 0, getattr(inst, 'row_mask', 0xf) or 0xf, getattr(inst, 'bank_mask', 0xf) or 0xf, ctx.wave_size) safe_src_lane = (enabled & valid).where(src_lane, _c(0, dtypes.int)) swizzled = ctx.rvgpr_dyn(ctx.inst_field(type(inst).vsrc0), safe_src_lane) invalid = UOp.const(fallback.dtype, 0) if getattr(inst, 'bc', 0) else fallback return enabled.where(valid.where(swizzled, invalid), fallback) def _compile_sdwa(inst: irc.VOP1_SDWA | irc.VOP2_SDWA | irc.VOP2_SDWA_SDST | irc.VOPC_SDWA_SDST, ctx: _Ctx) -> UOp: """Compile CDNA SDWA (Sub-Dword Access) VOP1/VOP2/VOPC instructions.""" is_vopc = isinstance(inst, irc.VOPC_SDWA_SDST) exec_mask = ctx.rexec() # sd=1 means use sdst register, sd=0 means use VCC (for VOPC_SDWA_SDST and VOP2_SDWA_SDST) if isinstance(inst, (irc.VOP2_SDWA_SDST, irc.VOPC_SDWA_SDST)): sdst_off = _c(inst.sdst.offset) if getattr(inst, 'sd', False) else _c(VCC_LO.offset) else: sdst_off = _c(VCC_LO.offset) # Read SDWA fields (these are dynamic but shared across lanes) src0_sel = ctx.inst_field(type(inst).src0_sel) src0_sext = ctx.inst_field(type(inst).src0_sext) vsrc0_reg = ctx.inst_field(type(inst).vsrc0) pcode = get_pcode(inst.op) if isinstance(inst, (irc.VOP2_SDWA, irc.VOP2_SDWA_SDST, irc.VOPC_SDWA_SDST)): src1_sel = ctx.inst_field(type(inst).src1_sel) src1_sext = ctx.inst_field(type(inst).src1_sext) vsrc1_reg = ctx.inst_field(type(inst).vsrc1) # For VOPC: use unroll_lanes to build the bitmask from scratch (no read-modify-write on stale data) if is_vopc: def get_cmp_bit(lane) -> UOp: lc = lane.cast(dtypes.int) if isinstance(lane, UOp) else 
_c(lane, dtypes.int) s0_raw = ctx.rsgpr_dyn(vsrc0_reg) if inst.s0 else ctx.rvgpr_dyn(vsrc0_reg, lc) s0 = _sdwa_select(s0_raw, src0_sel, src0_sext) s1_raw = ctx.rsgpr_dyn(vsrc1_reg) if inst.s1 else ctx.rvgpr_dyn(vsrc1_reg, lc) s1 = _sdwa_select(s1_raw, src1_sel, src1_sext) srcs = {'S0': s0, 'S1': s1, 'laneId': lc} for dest, val in parse_pcode(pcode, srcs)[1]: if '[laneId]' in dest and ('D0' in dest or 'EXEC' in dest): return val.cast(dtypes.uint32) return _c(0) new_result = ctx.unroll_lanes(get_cmp_bit, exec_mask, apply_exec=False) & exec_mask stores = ctx.wmask(sdst_off, new_result) return UOp.sink(*stores, *ctx.inc_pc()) # Non-VOPC path: VOP1_SDWA, VOP2_SDWA, VOP2_SDWA_SDST — uses lane loop lane = ctx.range() vdst_reg = ctx.inst_field(type(inst).vdst) # type: ignore[union-attr] s0_raw = ctx.rsgpr_dyn(vsrc0_reg) if inst.s0 else ctx.rvgpr_dyn(vsrc0_reg, lane) s0 = _sdwa_select(s0_raw, src0_sel, src0_sext) if isinstance(inst, (irc.VOP2_SDWA, irc.VOP2_SDWA_SDST)): s1_raw = ctx.rsgpr_dyn(vsrc1_reg) if inst.s1 else ctx.rvgpr_dyn(vsrc1_reg, lane) s1 = _sdwa_select(s1_raw, src1_sel, src1_sext) srcs:dict[str, UOp | int] = {'S0': s0, 'S1': s1, 'D0': ctx.rvgpr_dyn(vdst_reg, lane)} else: srcs = {'S0': s0} # dst_sel and dst_unused has_dst_sel = hasattr(type(inst), 'dst_sel') if has_dst_sel: dst_sel = ctx.inst_field(type(inst).dst_sel) # type: ignore[union-attr] dst_unused = ctx.inst_field(type(inst).dst_unused) # type: ignore[union-attr] srcs.update({'VCC': ctx.rmask(_c(VCC_LO.offset)), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane, 'VDST': vdst_reg, 'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0), 'ROUND_NEAREST_EVEN': _c(0), '_vgpr': ctx.vgpr, '_wave_size': ctx.wave_size, 'SDWA_SRC0_SEL': _c(0), 'BYTE0': _c(0), 'BYTE1': _c(1), 'BYTE2': _c(2), 'BYTE3': _c(3), 'WORD0': _c(0), 'WORD1': _c(1)}) _, assigns = parse_pcode(pcode, srcs) stores = [] vcc_val = None for dest, val in assigns: if 'D0' in dest and '[laneId]' in dest: vcc_val = val elif 
dest.startswith('D0'): result = _val_to_u32(val) if has_dst_sel: old = ctx.rvgpr_dyn(vdst_reg, lane) result = _sdwa_write(old, result, dst_sel, dst_unused) stores.append(ctx.wvgpr_dyn(vdst_reg, lane, result, exec_mask)) elif dest.startswith('VCC'): old_vcc = ctx.rmask(_c(VCC_LO.offset)) stores.extend(ctx.wmask(_c(VCC_LO.offset), _set_lane_bit(old_vcc, lane, val, exec_mask))) if vcc_val is not None: # Initialize sdst to 0 before lane loop (old value may be unrelated data), then set lane bits in loop init_stores = [ctx.wsgpr_dyn(sdst_off, _c(0)), ctx.wsgpr_dyn(sdst_off + _c(1), _c(0))] old_sdst = ctx.rmask(sdst_off) stores.extend(ctx.wmask(sdst_off, _set_lane_bit(old_sdst, lane, vcc_val, exec_mask))) if stores: return UOp.sink(*init_stores, UOp.sink(*stores).end(lane), *ctx.inc_pc()) return UOp.sink(*init_stores, *ctx.inc_pc()) if stores: return UOp.sink(UOp.sink(*stores).end(lane), *ctx.inc_pc()) return UOp.sink(*ctx.inc_pc()) def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP1_DPP16 | ir3.VOP2 | ir3.VOP2_DPP16 | ir4.VOP1 | ir4.VOP1_SDST | ir4.VOP1_DPP16 | ir4.VOP2 | ir4.VOP2_DPP16 | irc.VOP1 | irc.VOP1_DPP16 | irc.VOP2 | irc.VOP2_DPP16, ctx: _Ctx) -> UOp: op_name = _op_name(inst) if op_name in ('V_READFIRSTLANE_B32_E32', 'V_PERMLANE64_B32_E32'): return ctx.compile_lane_pcode(inst.op, inst) # v_accvgpr_mov_b32: ACCVGPR[vdst] = ACCVGPR[src0] (VOP1 encoding, no pcode) if 'ACCVGPR_MOV' in op_name: lane, exec_mask = ctx.range(), ctx.rexec() vdst_reg = ctx.inst_field(type(inst).vdst) # VGPRField: raw ACCVGPR index (0-255) acc_src0_off = ctx.inst_field(type(inst).src0) # SrcField: raw 256 + ACCVGPR index val = ctx.raccvgpr_dyn(acc_src0_off - _c(256), lane) return UOp.sink(ctx.waccvgpr_dyn(vdst_reg, lane, val, exec_mask).end(lane), *ctx.inc_pc()) lane, exec_mask, bits = ctx.range(), ctx.rexec(), inst.canonical_op_bits literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr] is_f64 = 'F64' in op_name and 
'B64' not in op_name is_float = any(x in op_name for x in ('F16', 'F32', 'F64')) is_dpp16 = hasattr(type(inst), 'dpp') and hasattr(type(inst), 'vsrc0') vdst_reg = ctx.inst_field(type(inst).vdst) write_hi_half = bits['d'] == 16 and (vdst_reg >= _c(128)) if isinstance(write_hi_half, UOp): vdst_reg = write_hi_half.where(vdst_reg - _c(128), vdst_reg) elif write_hi_half: vdst_reg -= 128 src0_off: UOp | None = None if isinstance(inst, (ir3.VOP1, ir4.VOP1, irc.VOP1)): # Handle VOP1 hi-half source operand (src0 >= v[128] for 16-bit ops) d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane)) if is_dpp16: s0 = _load_dpp16_src0(ctx, inst, lane, d0) else: src0_off = ctx.inst_field(type(inst).src0) s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal, is_f64) if bits['s0'] == 16 and not is_dpp16: src0_hi = src0_off >= _c(384) # Only compute hi-half when src0_off >= 384, use guarded index to prevent OOB access src0_reg = src0_hi.where(src0_off - _c(384), _c(0)) s0 = src0_hi.where(_hi16(ctx.rvgpr_dyn(src0_reg, lane)), s0) if is_dpp16 and is_float: s0 = _apply_src_mods(s0, 0, 1 if getattr(inst, 'src0_abs', 0) else 0, 1 if getattr(inst, 'src0_neg', 0) else 0, bits['s0']) srcs:dict[str, UOp | int] = {'S0': s0, 'D0': d0} else: vsrc1_reg = ctx.inst_field(type(inst).vsrc1) vsrc1_hi = bits['s0'] == 16 and (vsrc1_reg >= _c(128)) vsrc1_actual = _cond(vsrc1_hi, vsrc1_reg - _c(128), vsrc1_reg) if bits['s1'] == 64: s1 = _u64(ctx.rvgpr_dyn(vsrc1_reg, lane), ctx.rvgpr_dyn(vsrc1_reg + _c(1), lane)) d0 = _u64(ctx.rvgpr_dyn(vdst_reg, lane), ctx.rvgpr_dyn(vdst_reg + _c(1), lane)) else: s1 = _cond_hi16(vsrc1_hi, ctx.rvgpr_dyn(vsrc1_actual, lane)) d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane)) # FMAC/FMAMK hi-half dest needs hi-half accumulator # Handle VOP2 hi-half src0 operand (src0 >= v[128] for 16-bit ops) if is_dpp16: s0 = _load_dpp16_src0(ctx, inst, lane, d0) else: src0_off = ctx.inst_field(type(inst).src0) s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal, is_f64) 
if bits['s0'] == 16 and not is_dpp16: src0_hi = src0_off >= _c(384) # Only compute hi-half when src0_off >= 384, use guarded index to prevent OOB access src0_reg = src0_hi.where(src0_off - _c(384), _c(0)) s0 = src0_hi.where(_hi16(ctx.rvgpr_dyn(src0_reg, lane)), s0) if is_dpp16 and is_float: s0 = _apply_src_mods(s0, 0, 1 if getattr(inst, 'src0_abs', 0) else 0, 1 if getattr(inst, 'src0_neg', 0) else 0, bits['s0']) s1 = _apply_src_mods(s1, 0, 1 if getattr(inst, 'src1_abs', 0) else 0, 1 if getattr(inst, 'src1_neg', 0) else 0, bits['s1']) srcs = {'S0': s0, 'S1': s1, 'D0': d0} # FMAAK_(DTYPE)_E32 series if 'V_FMAA' in _op_name(inst) or 'V_FMAM' in _op_name(inst): assert literal is not None srcs['SIMM32'] = literal return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=write_hi_half, src0_off=src0_off) def _compile_vopc(inst: ir3.VOPC|ir3.VOPC_DPP16|ir3.VOP3|ir4.VOPC|ir4.VOPC_DPP16|ir4.VOP3|irc.VOPC|irc.VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp: exec_mask, op_name, bits = ctx.rexec(), _op_name(inst), inst.canonical_op_bits is_cmpx, is_vopc = 'CMPX' in op_name, hasattr(inst, 'vsrc1') # is_vopc: e32 vs e64 is_dpp16 = hasattr(type(inst), 'dpp') and hasattr(type(inst), 'vsrc0') # Handle both VOPC (vsrc1) and VOP3 (src1) instruction formats - read operands dynamically if is_vopc: src0_off = ctx.inst_field(type(inst).src0) vsrc1_off = ctx.inst_field(type(inst).vsrc1) # type: ignore[union-attr] # For 16-bit ops, vsrc1 >= 128 means hi-half of v[vsrc1-128] if bits['s0'] == 16: vsrc1_hi = vsrc1_off >= _c(128) src1_off = _c(256) + vsrc1_hi.where(vsrc1_off - _c(128), vsrc1_off) else: vsrc1_hi = False src1_off = _c(256) + vsrc1_off else: src0_off = ctx.inst_field(type(inst).src0) src1_off = ctx.inst_field(type(inst).src1) # type: ignore[union-attr] dst_off = ctx.inst_field(type(inst).vdst) # type: ignore[union-attr] vsrc1_hi = False literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None 
# type: ignore[union-attr] is_float, is_f64, pcode = any(x in op_name for x in ('_F32', '_F64', '_F16')), '_F64' in op_name, get_pcode(inst.op) def get_cmp_bit(lane) -> UOp: lc = lane.cast(dtypes.int) if isinstance(lane, UOp) else _c(lane, dtypes.int) s0 = _load_dpp16_src0(ctx, inst, lc, _c(0)) if is_dpp16 else ctx.rsrc_dyn(src0_off, lc, bits['s0'], literal, is_f64) s1 = _cond_hi16(vsrc1_hi, ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)) if bits['s0'] == 16 \ else ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64) if bits['s0'] == 16 and opsel: s0, s1 = _apply_opsel(s0, 0, opsel), _apply_opsel(s1, 1, opsel) if is_float: if is_dpp16: s0 = _apply_src_mods(s0, 0, 1 if getattr(inst, 'src0_abs', 0) else 0, 1 if getattr(inst, 'src0_neg', 0) else 0, bits['s0']) s1 = _apply_src_mods(s1, 0, 1 if getattr(inst, 'src1_abs', 0) else 0, 1 if getattr(inst, 'src1_neg', 0) else 0, bits['s1']) s0 = _apply_src_mods(s0, 0, abs_bits, neg_bits, bits['s0']) s1 = _apply_src_mods(s1, 1, abs_bits, neg_bits, bits['s1']) for dest, val in parse_pcode(pcode, {'S0': s0, 'S1': s1, 'laneId': lc, 'D0': UOp.const(dtypes.uint64, 0)})[1]: if '[laneId]' in dest and ('D0' in dest or 'EXEC' in dest): return val.cast(dtypes.uint32) return _c(0) new_bits = ctx.unroll_lanes(get_cmp_bit, exec_mask, apply_exec=False) # Both VOPC and VOP3 clear inactive lane bits (hardware verified) new_result = new_bits & exec_mask # CMPX e32: writes EXEC only; CMPX e64: writes both EXEC and SDST; non-CMPX: writes dst only if is_cmpx: stores = ctx.wmask(_c(EXEC_LO.offset), new_result) if not is_vopc: stores.extend(ctx.wmask(dst_off, new_result)) else: stores = ctx.wmask(dst_off, new_result) if not is_vopc else ctx.wmask(_c(VCC_LO.offset), new_result) return UOp.sink(*stores, *ctx.inc_pc()) def _compile_bitop3(inst, ctx: _Ctx, exec_mask: UOp, bits: dict, op_name: str) -> UOp: """BITOP3: 3-input truth table. 
abs/neg/omod encode the truth table, not source modifiers.""" lane = ctx.range() vdst_reg = ctx.inst_field(type(inst).vdst) ops = inst.canonical_operands src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, bits['s0'], None, 's0' in ops and ops['s0'][0] == Fmt.FMT_NUM_F64) src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, bits['s1'], None, 's1' in ops and ops['s1'][0] == Fmt.FMT_NUM_F64) src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, bits['s2'], None, 's2' in ops and ops['s2'][0] == Fmt.FMT_NUM_F64) # Truth table: TTBL = { omod[1:0], abs[2:0], neg[2:0] } = 8-bit LUT ttbl = ((getattr(inst, 'omod', 0) or 0) << 6) | ((getattr(inst, 'abs', 0) or 0) << 3) | (getattr(inst, 'neg', 0) or 0) is_16 = 'B16' in op_name dt, mask = (dtypes.uint16, 0xFFFF) if is_16 else (dtypes.uint32, 0xFFFFFFFF) s0, s1, s2 = src0.cast(dt), src1.cast(dt), src2.cast(dt) def bnot(v): return v ^ UOp.const(dt, mask) result = UOp.const(dt, 0) for i in range(8): if not (ttbl & (1 << i)): continue result = result | ((s0 if i & 4 else bnot(s0)) & (s1 if i & 2 else bnot(s1)) & (s2 if i & 1 else bnot(s2))) return UOp.sink(ctx.wvgpr_dyn(vdst_reg, lane, result.cast(dtypes.uint32), exec_mask).end(lane), *ctx.inc_pc()) def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3 | irc.VOP3, ctx: _Ctx) -> UOp: exec_mask = ctx.rexec() bits = inst.canonical_op_bits opsel, op_name = getattr(inst, 'opsel', 0) or 0, _op_name(inst) # Lane operations if op_name in ('V_READLANE_B32', 'V_READFIRSTLANE_B32', 'V_READFIRSTLANE_B32_E64', 'V_WRITELANE_B32'): return ctx.compile_lane_pcode(inst.op, inst) # V_PERMLANE16_B32 / V_PERMLANEX16_B32: cross-lane swizzle via pcode if 'PERMLANE16' in op_name or 'PERMLANEX16' in op_name: return ctx.compile_lane_pcode(inst.op, inst) # VOP3 VOPC (v_cmp_*_e64) - delegate to unified VOPC handler if 'V_CMP' in op_name or 'V_CMPX' in op_name: return _compile_vopc(inst, ctx, opsel=opsel, abs_bits=getattr(inst, 'abs', 0) or 0, neg_bits=getattr(inst, 'neg', 0) or 0) # BITOP3: 
def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3 | irc.VOP3, ctx: _Ctx) -> UOp:
  """Compile a VOP3-encoded instruction.

  Dispatch order:
    1. readlane / readfirstlane / writelane and the PERMLANE16 / PERMLANEX16 ops -> ctx.compile_lane_pcode
    2. V_CMP* (which includes V_CMPX*) -> _compile_vopc, forwarding opsel/abs/neg
    3. BITOP3 -> _compile_bitop3
    4. ops with 'V_S_' in the name -> evaluated once with lane 0 and stored to an SGPR
    5. everything else -> per-lane pcode via ctx.compile_vop_pcode
  """
  exec_mask = ctx.rexec()
  bits = inst.canonical_op_bits
  opsel = getattr(inst, 'opsel', 0) or 0
  op_name = _op_name(inst)
  inst_cls = type(inst)
  abs_bits, neg_bits = getattr(inst, 'abs', 0) or 0, getattr(inst, 'neg', 0) or 0

  # Cross-lane operations are compiled wholesale from their pcode
  lane_op_names = ('V_READLANE_B32', 'V_READFIRSTLANE_B32', 'V_READFIRSTLANE_B32_E64', 'V_WRITELANE_B32')
  if op_name in lane_op_names or 'PERMLANE16' in op_name or 'PERMLANEX16' in op_name:
    return ctx.compile_lane_pcode(inst.op, inst)

  # VOP3-encoded compares share the VOPC handler ('V_CMP' also matches every 'V_CMPX' name)
  if 'V_CMP' in op_name:
    return _compile_vopc(inst, ctx, opsel=opsel, abs_bits=abs_bits, neg_bits=neg_bits)

  # BITOP3 re-purposes abs/neg/omod as a truth table, so it must not go through the modifier path below
  if 'BITOP3' in op_name:
    return _compile_bitop3(inst, ctx, exec_mask, bits, op_name)

  vdst_reg = ctx.inst_field(inst_cls.vdst)
  literal = ctx.inst_field(inst_cls.literal) if hasattr(inst_cls, 'literal') else None

  # v_s_* ops: one source read at lane 0, every D0 assignment is written to the SGPR named by vdst
  if 'V_S_' in op_name:
    scalar_src = ctx.rsrc_dyn(ctx.inst_field(inst_cls.src0), _c(0, dtypes.int), bits['s0'], literal)
    scalar_srcs = {
      'S0': _apply_src_mods(scalar_src, 0, abs_bits, neg_bits, bits['s0']),
      'EXEC': exec_mask,
      'SCC': ctx.rsgpr_dyn(_c(SCC.offset)),
      'laneId': _c(0, dtypes.int),
      'ROUND_MODE': _c(0),
      'ROUND_TOWARD_ZERO': _c(0),
    }
    sgpr_stores = []
    for dest, val in parse_pcode(get_pcode(inst.op), scalar_srcs)[1]:
      if dest.startswith('D0'): sgpr_stores.append(ctx.wsgpr_dyn(vdst_reg, _val_to_u32(val)))
    return UOp.sink(*sgpr_stores, *ctx.inc_pc())

  # Regular VOP3: read the three sources for the current lane
  lane = ctx.range()
  operands = inst.canonical_operands
  vals = []
  for idx in range(3):
    key = f's{idx}'
    is_f64 = key in operands and operands[key][0] == Fmt.FMT_NUM_F64
    vals.append(ctx.rsrc_dyn(ctx.inst_field(getattr(inst_cls, f'src{idx}')), lane, bits[key], literal, is_f64))

  # opsel half-selection is applied to all three sources when src0 is 16-bit, then abs/neg to each source
  if bits['s0'] == 16: vals = [_apply_opsel(v, idx, opsel) for idx, v in enumerate(vals)]
  vals = [_apply_src_mods(v, idx, abs_bits, neg_bits, bits[f's{idx}']) for idx, v in enumerate(vals)]

  pcode_srcs = {'S0': vals[0], 'S1': vals[1], 'S2': vals[2], 'OPSEL': UOp.const(dtypes.uint32, opsel)}
  # CNDMASK takes its select mask from src2
  if 'CNDMASK' in op_name and vals[2] is not None: pcode_srcs['VCC'] = vals[2]
  # FMAC accumulates into the destination, so its current value is supplied as D0
  if 'FMAC' in op_name: pcode_srcs['D0'] = ctx.rvgpr_dyn(vdst_reg, lane)

  # opsel bit 3 selects the high half of a 16-bit destination
  writes_hi_half = bool(opsel & 0b1000) and bits['d'] == 16
  return ctx.compile_vop_pcode(inst.op, pcode_srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=writes_hi_half, clmp=getattr(inst, 'clmp', 0))
def _compile_vinterp(inst: ir3.VINTERP | ir4.VINTERP, ctx: _Ctx) -> UOp:
  """Compile a VINTERP instruction through the generic per-lane pcode path.

  The pcode receives the three source values for the current lane (S0/S1/S2) plus SRC0 and SRC2,
  which are the src0/src2 encodings with the 256 bias removed whenever the encoding is >= 256.
  """
  lane, exec_mask = ctx.range(), ctx.rexec()
  inst_cls = type(inst)
  vdst_reg = ctx.inst_field(inst_cls.vdst)
  offs = [ctx.inst_field(field) for field in (inst_cls.src0, inst_cls.src1, inst_cls.src2)]

  def unbias(off: UOp) -> UOp:
    # Encodings >= 256 lose the 256 bias; smaller encodings pass through unchanged
    return (off >= _c(256)).where(off - _c(256), off)

  pcode_srcs = {'SRC0': unbias(offs[0]), 'SRC2': unbias(offs[2])}
  for idx, off in enumerate(offs):
    pcode_srcs[f'S{idx}'] = ctx.rsrc_dyn(off, lane)
  return ctx.compile_vop_pcode(inst.op, pcode_srcs, lane, vdst_reg, exec_mask)
def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD | irc.VOP3SD, ctx: _Ctx) -> UOp:
  """Compile a VOP3SD instruction (VOP3 with an additional scalar destination, sdst).

  Two shapes are handled:
    * The pcode assigns a per-lane bit (a VCC or D0.u64 destination indexed by [laneId]): the carry mask
      is built for all lanes first, then the VGPR result is stored in a second lane loop.
    * Otherwise the instruction goes through ctx.compile_vop_pcode with sdst passed as sdst_reg.

  The mask presented to the pcode as VCC is read from src2 when operand s2 is an SREG operand (carry-in),
  and from sdst otherwise.
  """
  exec_mask = ctx.rexec()
  bits, pcode, ops = inst.canonical_op_bits, get_pcode(inst.op), inst.canonical_operands

  # Read operands dynamically from instruction encoding
  vdst_reg, sdst_off = ctx.inst_field(type(inst).vdst), ctx.inst_field(type(inst).sdst)
  src0_off, src1_off, src2_off = ctx.inst_field(type(inst).src0), ctx.inst_field(type(inst).src1), ctx.inst_field(type(inst).src2)
  literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None  # type: ignore[union-attr]

  # Carry-in forms take the incoming mask from src2; everything else reads it from sdst
  has_carry_in = 's2' in ops and ops['s2'][2] == OpType.OPR_SREG
  vcc_in_off = src2_off if has_carry_in else sdst_off

  def load_srcs(lane_uop):
    # Build the pcode source dict for one lane; called once per lane loop so each loop gets its own reads
    ret = {'VCC': ctx.rmask(vcc_in_off), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane_uop}
    ret['S0'] = ctx.rsrc_dyn(src0_off, lane_uop, bits['s0'], literal, ops['s0'][0] == Fmt.FMT_NUM_F64)
    ret['S1'] = ctx.rsrc_dyn(src1_off, lane_uop, bits['s1'], literal, ops['s1'][0] == Fmt.FMT_NUM_F64)
    if 's2' in ops: ret['S2'] = ctx.rsrc_dyn(src2_off, lane_uop, bits['s2'], literal, ops['s2'][0] == Fmt.FMT_NUM_F64)
    return ret

  # First parse: used to classify the instruction, and consumed directly by the non-per-lane branch below
  lane = ctx.range()
  srcs = load_srcs(lane)
  _, assigns = parse_pcode(pcode, srcs)

  has_per_lane_vcc = any('[laneId]' in dest for dest, _ in assigns if dest.startswith('VCC') or dest.startswith('D0.u64'))
  clmp = getattr(inst, 'clmp', 0)
  if has_per_lane_vcc:
    # VCC computation: RANGE+REDUCE gets axis ID first (lower ID = runs first)
    # This ensures VCC reads source values BEFORE VGPR stores modify them
    def get_vcc_bit(lane_uop) -> UOp:
      # The last matching assignment wins; lanes with no carry assignment contribute 0
      vcc_bit = _c(0)
      for dest, val in parse_pcode(pcode, load_srcs(lane_uop))[1]:
        if dest.startswith('VCC') or (dest.startswith('D0.u64') and '[laneId]' in dest): vcc_bit = val.cast(dtypes.uint32)
      return vcc_bit
    final_vcc = ctx.unroll_lanes(get_vcc_bit, exec_mask)
    # VGPR stores: RANGE gets axis ID second (higher ID = runs after VCC loop)
    lane3 = ctx.range()
    d0_val, vcc_per_lane = None, None
    for dest, val in parse_pcode(pcode, load_srcs(lane3))[1]:
      # D0 without a [laneId] index is the vector result; the indexed form is the carry bit
      if dest.startswith('D0') and '[laneId]' not in dest: d0_val = val
      if dest.startswith('VCC') or (dest.startswith('D0.u64') and '[laneId]' in dest): vcc_per_lane = val
    vgpr_stores = []
    if d0_val is not None:
      # Apply clamp using carry/borrow bit: ADD overflow->0xFFFFFFFF, SUB underflow->0
      if clmp and vcc_per_lane is not None:
        is_sub = 'SUB' in inst.op.name
        sat_val = _c(0) if is_sub else _c(0xFFFFFFFF)
        d0_val = vcc_per_lane.cast(dtypes.bool).where(sat_val, d0_val.cast(dtypes.uint32))
      if d0_val.dtype in (dtypes.uint64, dtypes.int64, dtypes.float64):
        # 64-bit results occupy two consecutive VGPRs (low dword first)
        lo, hi = _split64(d0_val)
        vgpr_stores.extend([ctx.wvgpr_dyn(vdst_reg, lane3, lo, exec_mask), ctx.wvgpr_dyn(vdst_reg + _c(1), lane3, hi, exec_mask)])
      else:
        # Float results are stored as their bit pattern; integer results are converted with a cast
        d0_u32 = d0_val.bitcast(dtypes.uint32) if d0_val.dtype in (dtypes.float32, dtypes.half) else d0_val.cast(dtypes.uint32)
        vgpr_stores.append(ctx.wvgpr_dyn(vdst_reg, lane3, d0_u32, exec_mask))
    # Write carry output (wmask handles lo/hi split for wave64)
    vcc_writes = ctx.wmask(sdst_off, final_vcc)
    return UOp.sink(*vcc_writes, UOp.group(*vgpr_stores).end(lane3), *ctx.inc_pc())
  else:
    return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, sdst_reg=inst.sdst.offset)
def _compile_mfma(inst: irc.VOP3P, ctx: _Ctx) -> UOp:
  """CDNA MFMA matrix multiply-accumulate emulation.

  Uses a local temp array to cache inputs, avoiding aliasing issues when vdst overlaps src0/src1.
  Phase 1: read every A/B element from the sources, convert it to the accumulator's 32-bit pattern and
           store it in the temp array (one range loop over the lanes).
  Phase 2: compute each lane's outputs from the temp array and write them to ACCVGPRs (second range loop).

  Register layout (wave64):
  - 16x16: 4 groups of 16 lanes. Each lane in group holds k_per_grp elements. 4 output ACCVGPRs per lane.
  - 32x32: 2 groups of 32 lanes. lanes%16 = M/N index within block, lanes//16 selects block. 16 output ACCVGPRs per lane.
  - 4x4: 16 groups of 4 lanes. 4 output ACCVGPRs per lane.

  Raises:
    ValueError: the op name contains no <M>X<N>X<K> dimension triple.
    RuntimeError: the MxN shape is not 4x4, 16x16 or 32x32.
  """
  def ci(v: int) -> UOp: return UOp.const(dtypes.int, v)
  def cu(v: int) -> UOp: return UOp.const(dtypes.uint32, v)

  op_name = _op_name(inst)
  exec_mask = ctx.rexec()
  vdst_reg = ctx.inst_field(type(inst).vdst)
  src0_off = ctx.inst_field(type(inst).src0)
  src1_off = ctx.inst_field(type(inst).src1)
  src0_r = src0_off - _c(256)  # VGPR-relative index (only valid when src is VGPR)
  src1_r = src1_off - _c(256)
  src2_off = ctx.inst_field(type(inst).src2)

  # Check if sources are VGPRs (offset >= 256) vs inline constants/SGPRs
  src0_is_vgpr = src0_off >= _c(256)
  src1_is_vgpr = src1_off >= _c(256)

  # `re` is imported at module level; no function-local import is needed
  m = re.search(r'(\d+)X(\d+)X(\d+)', op_name)
  if m is None: raise ValueError(f"could not parse MFMA dimensions from {op_name}")
  M, N, K = int(m.group(1)), int(m.group(2)), int(m.group(3))

  is_bf16 = 'BF16' in op_name
  is_fp8 = 'FP8' in op_name or 'F8' in op_name
  is_i8 = 'I8' in op_name
  # Source type is the LAST type in the name: V_MFMA_F32_16X16X32_**F16** -> source is F16, not F32
  src_type = op_name.rsplit('_', 1)[-1]  # e.g. "F16", "BF16", "F32", "I8"
  is_f32_src = src_type == 'F32'
  is_int_out = 'I32' in op_name.split('_')[2]  # V_MFMA_I32_...

  # Elements packed per 32-bit VGPR
  if is_i8: vpg = 4
  elif is_f32_src: vpg = 1
  elif is_fp8: vpg = 4
  else: vpg = 2

  # For 16x16: grp_size=16, n_grps=4, out_per_lane=4
  # For 32x32: grp_size=32, n_grps=2, out_per_lane=16
  # For 4x4: grp_size=4, n_grps=16, out_per_lane=4
  if M == 16 and N == 16: grp_size, n_grps, out_per_lane = 16, 4, 4
  elif M == 32 and N == 32: grp_size, n_grps, out_per_lane = 32, 2, 16
  elif M == 4 and N == 4: grp_size, n_grps, out_per_lane = 4, 16, 4
  else: raise RuntimeError(f"unsupported MFMA shape {M}x{N}x{K}")

  # For 4x4: each group independently computes a 4x4 block. K is NOT split across groups.
  # For 16x16/32x32: K IS split across groups (each group has K/n_grps elements).
  k_per_grp = K if M == 4 else K // n_grps

  # Temp array size: for 4x4, store all 16 independent blocks; for others, store shared MxK/NxK
  n_a_elems = n_grps * M * K if M == 4 else M * K
  n_b_elems = n_grps * N * K if M == 4 else N * K

  # src2 can be VGPR (>=256) or inline constant/SGPR (<256)
  src2_is_vgpr = src2_off >= _c(256)
  src2_r = src2_off - _c(256)
  if is_int_out: acc_scalar = ctx.rsgpr_dyn(src2_off, src2_is_vgpr.ne(True)).cast(dtypes.int32)
  else: acc_scalar = ctx.rsgpr_dyn(src2_off, src2_is_vgpr.ne(True)).bitcast(dtypes.float32)

  # Phase 1: Read all A and B values from VGPRs into temp arrays.
  # Layout: tmp[0..n_a_elems-1] = A[m][k], tmp[n_a_elems..n_a_elems+n_b_elems-1] = B[n][k]
  grp_sub = min(M, 16)  # lanes within group mapped to M/N dimension
  b_off = ci(n_a_elems)
  acc_dt = dtypes.int32 if is_int_out else dtypes.float32
  # Use uint32 temp array to prevent optimizer from eliminating f16→f32 bitcast chains.
  # The optimizer folds bitcast(uint32→float32) stores to float32 arrays, losing the conversion.
  tmp = UOp(Ops.DEFINE_LOCAL, dtypes.uint32.ptr(n_a_elems + n_b_elems, addrspace=AddrSpace.LOCAL), arg=(n_a_elems + n_b_elems,))

  def cvt_elem(raw: UOp, sub_idx: int) -> UOp:
    """Extract packed element `sub_idx` from a 32-bit register value as the 32-bit pattern stored in tmp."""
    if is_i8:
      # Extract i8, sign-extend to i32
      byte_val = (raw >> cu(sub_idx * 8)) & cu(0xFF)
      return (byte_val.cast(dtypes.int32) ^ UOp.const(dtypes.int32, 0x80)) - UOp.const(dtypes.int32, 0x80)
    elif is_f32_src:
      return raw  # already uint32 (f32 bit pattern)
    elif is_fp8:
      # NOTE(review): only the raw byte is extracted here, no fp8->f32 conversion is applied - confirm intent
      return ((raw >> cu(sub_idx * 8)) & cu(0xFF)).cast(dtypes.uint32)
    elif is_bf16:
      # bf16→f32 bits: just shift left by 16 (bf16 is upper 16 bits of f32)
      return ((raw >> cu(sub_idx * 16)) & cu(0xFFFF)) << cu(16)
    else:
      # f16→f32: build the f32 bit pattern with integer ops. It is stored as uint32 and only bitcast to
      # float when read back, so the optimizer cannot fold the conversion away.
      h = (raw >> cu(sub_idx * 16)) & cu(0xFFFF)
      sign = (h >> cu(15)) & cu(1)
      exp = (h >> cu(10)) & cu(0x1F)
      mant = h & cu(0x3FF)
      # Rebias the exponent (127 - 15 = 112). An all-ones f16 exponent encodes inf/NaN and must map to the
      # all-ones f32 exponent, otherwise inf/NaN inputs would be turned into finite values.
      is_inf_nan = exp.eq(cu(0x1F))
      exp_f32 = is_inf_nan.where(cu(0xFF), exp + cu(112))
      f32_bits = (sign << cu(31)) | (exp_f32 << cu(23)) | (mant << cu(13))
      # Zero exponent (zero or denormal) is flushed to zero
      is_zero = exp.eq(cu(0))
      return is_zero.where(cu(0), f32_bits)

  read_lane = ctx.range()
  # For 32x32: lane%16 = M/N index within 16-wide block, lane//16 = which of 4 quarter-waves
  # Groups: lanes 0-31 = group 0, lanes 32-63 = group 1
  lane_in_grp = read_lane % ci(grp_size)
  grp_idx = read_lane // ci(grp_size)
  if M == 32:
    # 32x32: lane_in_grp%16 = sub-row/col (0-15), lane_in_grp//16 = block (0=rows 0-15, 1=rows 16-31)
    sub_mn = lane_in_grp % ci(16)
    block_mn = lane_in_grp // ci(16)
    mn_idx = block_mn * ci(16) + sub_mn  # actual M/N index (0-31)
  else:
    mn_idx = lane_in_grp  # for 16x16 and 4x4

  read_stores = []
  for kl in range(k_per_grp):
    reg_idx, sub_idx = kl // vpg, kl % vpg
    # Read A/B sources. Use rsrc_dyn for inline constants/SGPRs (src_off < 256), rvgpr_dyn for VGPRs (src_off >= 256).
    a_raw = src0_is_vgpr.where(ctx.rvgpr_dyn(src0_r + _c(reg_idx), read_lane), ctx.rsrc_dyn(src0_off, _c(0, dtypes.int), 32))
    a_val = cvt_elem(a_raw, sub_idx)
    if M == 4: a_idx = grp_idx * ci(M * K) + mn_idx * ci(K) + ci(kl)
    else: a_idx = mn_idx * ci(K) + grp_idx * ci(k_per_grp) + ci(kl)
    read_stores.append(tmp.index(a_idx).store(a_val))

    b_raw = src1_is_vgpr.where(ctx.rvgpr_dyn(src1_r + _c(reg_idx), read_lane), ctx.rsrc_dyn(src1_off, _c(0, dtypes.int), 32))
    b_val = cvt_elem(b_raw, sub_idx)
    if M == 4: b_idx = b_off + grp_idx * ci(N * K) + mn_idx * ci(K) + ci(kl)
    else: b_idx = b_off + mn_idx * ci(K) + grp_idx * ci(k_per_grp) + ci(kl)
    read_stores.append(tmp.index(b_idx).store(b_val))
  read_phase = UOp.group(*read_stores).end(read_lane)

  # Phase 2: Compute dot products and write outputs. All reads of tmp are ordered after phase 1.
  tmp2 = tmp.after(read_phase)
  compute_lane = ctx.range()
  compute_stores = []

  def load_acc(out_reg: int) -> UOp:
    # Initial accumulator: ACCVGPR src2+out_reg when src2 is a VGPR, otherwise the scalar/inline value
    acc_v = ctx.raccvgpr_dyn(src2_r + _c(out_reg), compute_lane, src2_is_vgpr)
    acc_v = acc_v.cast(dtypes.int32) if is_int_out else acc_v.bitcast(dtypes.float32)
    return src2_is_vgpr.where(acc_v, acc_scalar)

  def dot(acc: UOp, a_base: UOp, b_base: UOp) -> UOp:
    # acc += sum_k tmp[a_base + k] * tmp[b_base + k], accumulated in k order
    for k in range(K):
      a_val = tmp2.index(a_base + ci(k)).bitcast(acc_dt)
      b_val = tmp2.index(b_base + ci(k)).bitcast(acc_dt)
      acc = acc + a_val * b_val
    return acc

  def store_acc(out_reg: int, acc: UOp) -> None:
    out_bits = acc.cast(dtypes.uint32) if is_int_out else acc.bitcast(dtypes.uint32)
    compute_stores.append(ctx.waccvgpr_dyn(vdst_reg + _c(out_reg), compute_lane, out_bits, exec_mask))

  if M == 32 and N == 32:
    # 32x32: each lane has 16 output ACCVGPRs. Column n = (lane%32)%16 + ((lane%32)//16)*16.
    c_lane_in_32 = compute_lane % ci(32)
    c_sub = c_lane_in_32 % ci(16)
    c_block = c_lane_in_32 // ci(16)
    n_idx = c_block * ci(16) + c_sub
    c_half = compute_lane // ci(32)  # 0 or 1
    for out_reg in range(16):
      # Row used here: m = half*16 + (out_reg//4)*4 + out_reg%4, which equals half*16 + out_reg.
      # NOTE(review): confirm this row mapping against the ISA's 32x32 output layout.
      m_base = c_half * ci(16) + ci((out_reg // 4) * 4 + (out_reg % 4))
      acc = load_acc(out_reg)
      acc = dot(acc, m_base * ci(K), b_off + n_idx * ci(K))
      store_acc(out_reg, acc)
  else:
    # 16x16 and 4x4: each lane computes out_per_lane outputs
    n_idx = compute_lane % ci(grp_sub)
    c_grp = compute_lane // ci(grp_sub)
    for out_reg in range(out_per_lane):
      acc = load_acc(out_reg)
      if M == 4:
        # 4x4: each group is independent. A/B indexed per-group.
        m_base = c_grp * ci(M * K) + ci(out_reg * K)
        acc = dot(acc, m_base, b_off + c_grp * ci(N * K) + n_idx * ci(K))
      else:
        # 16x16: K is split across groups. Shared MxK/NxK arrays.
        m_base = c_grp * ci(out_per_lane) + ci(out_reg)
        acc = dot(acc, m_base * ci(K), b_off + n_idx * ci(K))
      store_acc(out_reg, acc)

  compute_phase = UOp.group(*compute_stores).end(compute_lane)
  return UOp.sink(read_phase, compute_phase, *ctx.inc_pc())
def _compile_wmma(inst: ir3.VOP3P | ir4.VOP3P | irc.VOP3P, ctx: _Ctx) -> UOp:
  """Compile a 16x16x16 F16/BF16 WMMA as a fully unrolled set of VGPR reads and writes.

  Every matrix element is addressed with a constant lane index, so no lane range is created.
  A and B are read as flat 256-entry lists of f32 values indexed [row*16 + k]; the result is
  D[r][c] = sum_k A[r][k] * B[c][k] + C[r][c]. The output is stored as f32, or converted to
  f16/bf16 when the op name marks a 16-bit output. An ir4.VOP3P instance selects the RDNA4 register layout.
  """
  op_name = _op_name(inst)
  exec_mask = ctx.rexec()
  vdst_reg = ctx.inst_field(type(inst).vdst)
  # Source encodings are turned into VGPR indices by subtracting the 256 bias
  src0_r = ctx.inst_field(type(inst).src0) - _c(256)
  src1_r = ctx.inst_field(type(inst).src1) - _c(256)
  src2_r = ctx.inst_field(type(inst).src2) - _c(256)
  is_f16_output = 'F16_16X16X16_F16' in op_name or 'BF16_16X16X16_BF16' in op_name  # F16/BF16 output vs F32 output
  is_bf16 = 'BF16' in op_name
  cvt = _FUNCS['bf16_to_f32'] if is_bf16 else _FUNCS['f16_to_f32']
  is_rdna4 = isinstance(inst, ir4.VOP3P)

  # read 16x16 F16/BF16 matrix from VGPRs → flat f32 array[row*16+k]
  def read_f16_val(src, lane, vgpr, half):
    # `half` picks the upper (truthy) or lower (falsy) 16 bits of the VGPR before conversion to f32
    v = ctx.rvgpr_dyn(src + _c(vgpr), UOp.const(dtypes.int, lane))
    return cvt((v >> UOp.const(dtypes.uint32, 16)) if half else (v & UOp.const(dtypes.uint32, 0xFFFF)))

  # RDNA3: 16 lanes × 8 VGPRs × 2 halves, k maps linearly
  # RDNA4: 32 lanes × 4 VGPRs × 2 halves, k bits are scrambled (k[2] goes to lane bit 4)
  def read_f16_mat(src):
    # (row, k) → (lane, vgpr, half)
    def ab_map(i, k):
      elem, lane = ((k & 3) | ((k >> 1) & 4), i + ((k >> 2) & 1) * 16) if is_rdna4 else (k, i)
      return lane, elem // 2, elem % 2
    return [read_f16_val(src, *ab_map(row, k)) for row in range(16) for k in range(16)]
  mat_a, mat_b = read_f16_mat(src0_r), read_f16_mat(src1_r)

  # (row, col) -> (lane, vgpr)
  def d_map(m, n):
    lane_bit, vgpr = (m >> 3, m & 7) if is_rdna4 else (m & 1, m >> 1)
    return n + lane_bit * 16, vgpr

  if is_f16_output:
    # read accumulator C with f16 layout: for RDNA4, pairs of f32 vgprs pack into one f16 vgpr
    # for RDNA3, same layout as f32 but only lo 16 bits used
    mat_c = [read_f16_val(src2_r, *((lane, vgpr // 2, vgpr % 2) if is_rdna4 else (lane, vgpr, 0)))
             for m in range(16) for n in range(16) for lane, vgpr in [d_map(m, n)]]
    mat_d = [sum(mat_a[r*16+k] * mat_b[c*16+k] for k in range(16)) + mat_c[r*16+c] for r in range(16) for c in range(16)]
    def f32_to_f16_bits(v: UOp) -> UOp: return v.cast(dtypes.half).bitcast(dtypes.uint16).cast(dtypes.uint32)
    # bf16 output keeps the upper 16 bits of the f32 pattern (truncation, no rounding step)
    def f32_to_bf16_bits(v: UOp) -> UOp: return (v.bitcast(dtypes.uint32) >> UOp.const(dtypes.uint32, 16)) & UOp.const(dtypes.uint32, 0xFFFF)
    out_cvt = f32_to_bf16_bits if is_bf16 else f32_to_f16_bits
    if is_rdna4:
      # pack 2 f16 per VGPR: rows m (even) and m+1 land in the same (lane, vgpr // 2); m goes in the low half, m+1 in the high half
      stores = [ctx.wvgpr_dyn(vdst_reg + _c(d_map(m, n)[1] // 2), UOp.const(dtypes.int, d_map(m, n)[0]),
                              out_cvt(mat_d[m*16+n]) | (out_cvt(mat_d[(m+1)*16+n]) << UOp.const(dtypes.uint32, 16)), exec_mask)
                for n in range(16) for m in range(0, 16, 2)]
    else:
      # (rdna3) 1 f16 per VGPR (lo half only)
      stores = [ctx.wvgpr_dyn(vdst_reg + _c(d_map(m, n)[1]), UOp.const(dtypes.int, d_map(m, n)[0]), out_cvt(mat_d[m*16+n]), exec_mask)
                for m in range(16) for n in range(16)]
  else:
    # f32
    mat_c = [ctx.rvgpr_dyn(src2_r + _c(d_map(m, n)[1]), UOp.const(dtypes.int, d_map(m, n)[0])).bitcast(dtypes.float32)
             for m in range(16) for n in range(16)]
    mat_d = [sum(mat_a[r*16+k] * mat_b[c*16+k] for k in range(16)) + mat_c[r*16+c] for r in range(16) for c in range(16)]
    stores = [ctx.wvgpr_dyn(vdst_reg + _c(d_map(m, n)[1]), UOp.const(dtypes.int, d_map(m, n)[0]), mat_d[m*16+n].bitcast(dtypes.uint32), exec_mask)
              for m in range(16) for n in range(16)]
  return UOp.sink(*stores, *ctx.inc_pc())
def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P | irc.VOP3P, ctx: _Ctx) -> UOp:
  """Compile a VOP3P (packed) instruction.

  Dispatch order:
    1. 16X16X16 F16/BF16 WMMA ops -> _compile_wmma
    2. 4x4 / 16x16 / 32x32 MFMA ops on irc.VOP3P -> _compile_mfma
    3. ops whose src0 or vdst operand type is an ACCVGPR -> direct register-file copies
    4. PK_MOV_B32 -> two 32-bit selects written to a VGPR pair
    5. packed F32, FMA_MIX/MAD_MIX and generic packed-16 ops -> sources are pre-arranged according to
       opsel/opsel_hi/neg/neg_hi and handed to ctx.compile_vop_pcode
  """
  op_name = _op_name(inst)
  if 'WMMA' in op_name and ('16X16X16_F16' in op_name or '16X16X16_BF16' in op_name): return _compile_wmma(inst, ctx)
  if 'MFMA' in op_name and any(f'{s}X{s}X' in op_name for s in ('4', '16', '32')) and isinstance(inst, irc.VOP3P): return _compile_mfma(inst, ctx)

  # ACCVGPR_WRITE/READ/MOV: copies between VGPR and ACCVGPR register files
  # Detect by checking operand types for ACCVGPR involvement
  ops = inst.operands
  src0_is_acc = ops.get('src0', (None, None, None))[2] in (OpType.OPR_SRC_ACCVGPR, OpType.OPR_ACCVGPR)
  vdst_is_acc = ops.get('vdst', (None, None, None))[2] in (OpType.OPR_ACCVGPR,)
  if src0_is_acc or vdst_is_acc:
    lane = ctx.range()
    exec_mask = ctx.rexec()
    vdst_reg = ctx.inst_field(type(inst).vdst)
    src0_off = ctx.inst_field(type(inst).src0)
    if src0_is_acc and not vdst_is_acc:
      # v_accvgpr_read: VGPR[vdst] = ACCVGPR[src0]
      val = ctx.raccvgpr_dyn(src0_off - _c(256), lane)
      return UOp.sink(ctx.wvgpr_dyn(vdst_reg, lane, val, exec_mask).end(lane), *ctx.inc_pc())
    elif vdst_is_acc and not src0_is_acc:
      # v_accvgpr_write: ACCVGPR[vdst] = src0 (src0 can be VGPR or SGPR/const)
      src0 = ctx.rsrc_dyn(src0_off, lane, 32)
      return UOp.sink(ctx.waccvgpr_dyn(vdst_reg, lane, src0, exec_mask).end(lane), *ctx.inc_pc())
    else:
      # v_accvgpr_mov: ACCVGPR[vdst] = ACCVGPR[src0]
      val = ctx.raccvgpr_dyn(src0_off - _c(256), lane)
      return UOp.sink(ctx.waccvgpr_dyn(vdst_reg, lane, val, exec_mask).end(lane), *ctx.inc_pc())

  lane = ctx.range()
  exec_mask = ctx.rexec()
  vdst_reg = ctx.inst_field(type(inst).vdst)
  is_pk_f32 = 'PK' in op_name and 'F32' in op_name and 'MOV' not in op_name  # CDNA packed F32 ops
  is_pk_mov_b32 = 'PK_MOV_B32' in op_name  # CDNA packed MOV needs special handling
  # do_cast is forwarded to rsrc_dyn for float-typed ops only (not IU dot products, not packed F32)
  do_cast = any(x in op_name for x in ('F16', 'F32', 'BF16')) and 'IU' not in op_name and not is_pk_f32
  literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None  # type: ignore[union-attr]
  src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, 16, literal=literal, do_cast=do_cast)
  src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, 16, literal=literal, do_cast=do_cast)
  src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, 16, literal=literal, do_cast=do_cast)
  # Missing or None modifier fields fall back to: opsel=0, opsel_hi=3, opsel_hi2=1, neg=0, neg_hi=0
  opsel, opsel_hi = getattr(inst, 'opsel', 0) or 0, getattr(inst, 'opsel_hi', 3) if getattr(inst, 'opsel_hi', 3) is not None else 3
  opsel_hi2 = getattr(inst, 'opsel_hi2', 1) if getattr(inst, 'opsel_hi2', 1) is not None else 1
  neg, neg_hi = getattr(inst, 'neg', 0) or 0, getattr(inst, 'neg_hi', 0) or 0

  if is_pk_mov_b32:
    # v_pk_mov_b32: D[lo] = src0[opsel_bit0 ? hi : lo], D[hi] = src1[opsel_bit1 ? hi : lo]
    src_offs = [ctx.inst_field(type(inst).src0), ctx.inst_field(type(inst).src1)]
    def _pk_mov_sel(src_lo: UOp, src_off: UOp, sel_bit: int) -> UOp:
      # VGPR sources pick v[reg] or v[reg+1]; SGPR sources (< 128) pick s[N] or s[N+1];
      # any other scalar encoding always yields src_lo
      is_vgpr = src_off >= _c(256)
      vgpr_lo = ctx.rvgpr_dyn(src_off - _c(256), lane) if lane is not None else _c(0)
      vgpr_hi = ctx.rvgpr_dyn(src_off - _c(256) + _c(1), lane) if lane is not None else _c(0)
      is_sgpr_pair = src_off < _c(128)
      sgpr_hi = ctx.rsgpr_dyn(src_off + _c(1), is_sgpr_pair)
      scalar_sel = is_sgpr_pair.where(sgpr_hi, src_lo) if sel_bit else src_lo
      return is_vgpr.where(vgpr_hi if sel_bit else vgpr_lo, scalar_sel)
    lo_val = _pk_mov_sel(src0, src_offs[0], opsel & 1)
    hi_val = _pk_mov_sel(src1, src_offs[1], opsel & 2)
    result = _u64(lo_val, hi_val)
    lo_out, hi_out = _split64(result)
    stores = [ctx.wvgpr_dyn(vdst_reg, lane, lo_out, exec_mask), ctx.wvgpr_dyn(vdst_reg + _c(1), lane, hi_out, exec_mask)]
    return UOp.sink(UOp.group(*stores).end(lane), *ctx.inc_pc())

  srcs: dict[str, UOp | int] = {}
  if is_pk_f32:
    # CDNA packed F32: read 32-bit sources, build 64-bit packed values using opsel.
    # For VGPRs: opsel selects between v[reg] (0) and v[reg+1] (1) for each half.
    # For SGPR pairs (off < 128): s[N] = lo float32, s[N+1] = hi float32.
    # For inline constants (128 <= off < 256): broadcast same value to both halves.
    src_offs = [ctx.inst_field(type(inst).src0), ctx.inst_field(type(inst).src1), ctx.inst_field(type(inst).src2)]
    def build_pk_f32(src_lo: UOp, src_off: UOp, opsel_lo: int, opsel_hi_bit: int, neg_lo: int, neg_hi_bit: int) -> UOp:
      is_vgpr = src_off >= _c(256)
      vgpr_lo = ctx.rvgpr_dyn(src_off - _c(256), lane) if lane is not None else _c(0)
      vgpr_hi = ctx.rvgpr_dyn(src_off - _c(256) + _c(1), lane) if lane is not None else _c(0)
      # For SGPR pairs, opsel selects between s[N] (0) and s[N+1] (1); inline constants always broadcast.
      is_sgpr_pair = src_off < _c(128)
      sgpr_hi = ctx.rsgpr_dyn(src_off + _c(1), is_sgpr_pair)
      scalar_lo_sel = src_lo if not opsel_lo else is_sgpr_pair.where(sgpr_hi, src_lo)
      scalar_hi_sel = src_lo if not opsel_hi_bit else is_sgpr_pair.where(sgpr_hi, src_lo)
      lo = is_vgpr.where(vgpr_hi if opsel_lo else vgpr_lo, scalar_lo_sel)
      hi = is_vgpr.where(vgpr_hi if opsel_hi_bit else vgpr_lo, scalar_hi_sel)
      # Negation flips the f32 sign bit of the selected half
      if neg_lo: lo = lo ^ UOp.const(dtypes.uint32, 0x80000000)
      if neg_hi_bit: hi = hi ^ UOp.const(dtypes.uint32, 0x80000000)
      return _u64(lo, hi)
    srcs = {'S0': build_pk_f32(src0, src_offs[0], opsel & 1, opsel_hi & 1, neg & 1, neg_hi & 1),
            'S1': build_pk_f32(src1, src_offs[1], opsel & 2, opsel_hi & 2, neg & 2, neg_hi & 2),
            'S2': build_pk_f32(src2, src_offs[2], opsel & 4, 1 if opsel_hi2 else 0, neg & 4, neg_hi & 4)}
  elif 'FMA_MIX' in op_name or 'MAD_MIX' in op_name:
    # The third source's opsel_hi bit lives in the separate opsel_hi2 field; merge into one 3-bit value
    combined_opsel_hi = (opsel_hi & 0x3) | ((opsel_hi2 & 0x1) << 2)
    # For FMA_MIX: neg_hi is ABS (not neg!), neg is actual negation
    def apply_abs(v, bit, opsel_hi_bit, opsel_bit):
      if not (neg_hi & bit): return v
      # Apply abs based on whether source is f32 or f16
      if not (combined_opsel_hi & opsel_hi_bit): return v & UOp.const(dtypes.uint32, 0x7FFFFFFF)  # f32 abs
      # NOTE(review): 0x7FFF0000 also clears the low 16 bits, which contradicts "preserve lo" - confirm nothing reads the low half
      if opsel & opsel_bit: return v & UOp.const(dtypes.uint32, 0x7FFF0000)  # f16 hi abs (preserve lo)
      return v & UOp.const(dtypes.uint32, 0xFFFF7FFF)  # f16 lo abs (preserve hi)
    def apply_neg_mix(v, bit, opsel_hi_bit, opsel_bit):
      if not (neg & bit): return v
      if not (combined_opsel_hi & opsel_hi_bit): return v ^ UOp.const(dtypes.uint32, 0x80000000)  # f32 neg
      if opsel & opsel_bit: return v ^ UOp.const(dtypes.uint32, 0x80000000)  # f16 hi neg
      return v ^ UOp.const(dtypes.uint32, 0x00008000)  # f16 lo neg
    # abs is applied before neg for each source
    s0_mod = apply_neg_mix(apply_abs(src0, 1, 1, 1), 1, 1, 1)
    s1_mod = apply_neg_mix(apply_abs(src1, 2, 2, 2), 2, 2, 2)
    s2_mod = apply_neg_mix(apply_abs(src2, 4, 4, 4), 4, 4, 4)
    srcs = {'S@0': s0_mod, 'S@1': s1_mod, 'S@2': s2_mod,
            'OPSEL_HI': UOp.const(dtypes.uint32, combined_opsel_hi), 'OPSEL': UOp.const(dtypes.uint32, opsel)}
  else:
    def get_half_bits(val: UOp, use_hi: bool, apply_neg: bool = False) -> UOp:
      # Return the selected 16-bit half of val in the low 16 bits, optionally negated as an f16 value
      bits = ((val >> UOp.const(dtypes.uint32, 16)) if use_hi else val) & UOp.const(dtypes.uint32, 0xFFFF)
      if apply_neg: bits = bits.cast(dtypes.uint16).bitcast(dtypes.half).neg().bitcast(dtypes.uint16).cast(dtypes.uint32)
      return bits
    def build_remapped_src(src: UOp, opsel_lo_bit: int, opsel_hi_bit: int, neg_lo_bit: int, neg_hi_bit: int) -> UOp:
      # opsel chooses which input half feeds the low result lane, opsel_hi which feeds the high lane
      lo = get_half_bits(src, bool(opsel_lo_bit), bool(neg_lo_bit))
      hi = get_half_bits(src, bool(opsel_hi_bit), bool(neg_hi_bit))
      return lo | (hi << UOp.const(dtypes.uint32, 16))
    # DOT IU instructions use NEG bits for signed/unsigned selection, not fp16 negation
    is_dot_iu = 'DOT' in op_name and 'IU' in op_name
    n0, n1, n2, nh0, nh1, nh2 = (0, 0, 0, 0, 0, 0) if is_dot_iu else (neg & 1, neg & 2, neg & 4, neg_hi & 1, neg_hi & 2, neg_hi & 4)
    srcs = {'S0': build_remapped_src(src0, opsel & 1, opsel_hi & 1, n0, nh0),
            'S1': build_remapped_src(src1, opsel & 2, opsel_hi & 2, n1, nh1),
            'S2': build_remapped_src(src2, opsel & 4, 1 if opsel_hi2 else 0, n2, nh2)}
    if is_dot_iu: srcs['NEG'] = UOp.const(dtypes.uint32, neg)
  return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask)
def _compile_vopd(inst: ir3.VOPD | ir4.VOPD, ctx: _Ctx) -> UOp:
  """Compile a VOPD (dual-issue) instruction: two VOP2-style operations, X and Y, in one lane loop.

  Each half is mapped to a VOP2 op through VOPD_TO_VOP2 and evaluated from that op's pcode. Every
  D0 assignment becomes a VGPR store; the stores of both halves are grouped under a single lane range.

  Raises:
    AssertionError: a half has no VOP2 mapping, or an FMAAK/FMAMK half has no literal field.
  """
  exec_mask = ctx.rexec()
  # Read operands dynamically - use type(inst) to get correct field descriptors
  inst_type = type(inst)
  vdstx_reg = ctx.inst_field(inst_type.vdstx)
  # vdsty has complex encoding: actual = (raw << 1) | ((vdstx & 1) ^ 1)
  vdsty_raw = ctx.inst_field(inst_type.vdsty)
  vdsty_reg = (vdsty_raw << _c(1)) | ((vdstx_reg & _c(1)) ^ _c(1))
  srcx0_off = ctx.inst_field(inst_type.srcx0)
  srcy0_off = ctx.inst_field(inst_type.srcy0)
  vsrcx1_reg = ctx.inst_field(inst_type.vsrcx1)
  vsrcy1_reg = ctx.inst_field(inst_type.vsrcy1)
  literal = ctx.inst_field(inst_type.literal) if hasattr(inst_type, 'literal') else None

  lane = ctx.range()
  # Y's sources are read up front; srcy1 is also handed to every store below as its `after` argument
  # (presumably so no store is ordered before Y's source read - confirm against wvgpr_dyn)
  srcy0, srcy1 = ctx.rsrc_dyn(srcy0_off, lane, literal=literal), ctx.rvgpr_dyn(vsrcy1_reg, lane)
  all_stores = []
  halves = [(inst.opx, srcx0_off, vsrcx1_reg, vdstx_reg, 'X'), (inst.opy, srcy0_off, vsrcy1_reg, vdsty_reg, 'Y')]
  for op, src0_off, vsrc1_reg, vdst_reg, label in halves:
    vop = VOPD_TO_VOP2.get(op)
    assert vop is not None, f"no VOP mapping for VOPD {label}: {op}"
    srcs: dict[str, UOp | int]
    if label == 'Y': srcs = {'S0': srcy0, 'S1': srcy1, 'D0': ctx.rvgpr_dyn(vdst_reg, lane)}
    else: srcs = {'S0': ctx.rsrc_dyn(src0_off, lane, literal=literal), 'S1': ctx.rvgpr_dyn(vsrc1_reg, lane), 'D0': ctx.rvgpr_dyn(vdst_reg, lane)}
    # FMAAK/FMAMK take their constant from the literal dword. Match on the op name so the check covers the
    # VOP2 op of either ISA (the previous tuple listed the two ir3 members twice).
    if vop.name.startswith(('V_FMAAK_F32', 'V_FMAMK_F32')):
      assert literal is not None
      srcs['SIMM32'] = literal
    pcode = get_pcode(vop)
    # VCC is supplied to every op (this also covers V_DUAL_CNDMASK_B32, which selects on it)
    srcs.update({'VCC': ctx.rmask(_c(VCC_LO.offset)), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane})
    for dest, val in parse_pcode(pcode, srcs)[1]:
      if dest.startswith('D0'): all_stores.append(ctx.wvgpr_dyn(vdst_reg, lane, _val_to_u32(val), exec_mask, after=srcy1))
  return UOp.sink(UOp.group(*all_stores).end(lane), *ctx.inc_pc())
GLOBAL, SCRATCH.""" exec_mask, op_name = ctx.rexec(), _op_name(inst) pcode = get_pcode(inst.op) # CDNA pcode uses CalcGlobalAddr/CalcDsAddr to compute address from raw components, but make_addr already handles this. # Strip the addr computation line and use pre-computed ADDR directly (rename 'addr' -> 'ADDR' in remaining pcode). if isinstance(inst, (irc.GLOBAL, irc.FLAT, irc.SCRATCH, irc.DS, ir4.VSCRATCH)) and 'Calc' in pcode and 'Addr' in pcode: pcode = re.sub(r'addr\s*=\s*Calc\w+Addr\([^)]*\)\s*;?\n?', '', pcode).replace('MEM[addr', 'MEM[ADDR') is_lds = isinstance(inst, (ir3.DS, ir4.DS, irc.DS)) is_scratch = isinstance(inst, (ir3.SCRATCH, ir4.VSCRATCH, irc.SCRATCH)) # CDNA acc bit: when set, VGPR operands (vdst/vdata) target ACCVGPR file instead of VGPR use_acc = bool(getattr(inst, 'acc', 0)) mem = ctx.lds if is_lds else ctx.scratch if is_scratch else ctx.vmem addr_shift = UOp.const(dtypes.uint32 if is_lds else dtypes.uint64, 2) # Extract register info - all dynamic for deduplication if is_lds: addr_reg = ctx.inst_field(type(inst).addr) # type: ignore[union-attr] vdata_reg = ctx.inst_field(type(inst).data0) # type: ignore[union-attr] vdst_reg = ctx.inst_field(type(inst).vdst) offset0 = ctx.inst_field(type(inst).offset0) # type: ignore[union-attr] offset1 = ctx.inst_field(type(inst).offset1) # type: ignore[union-attr] offset = (offset1 << _c(8)) | offset0 # DS offset is 16-bit: (offset1 << 8) | offset0 saddr_reg = None elif isinstance(inst, (ir4.VGLOBAL, ir4.VSCRATCH, ir4.VFLAT)): # RDNA4: vaddr, vsrc, ioffset addr_reg = ctx.inst_field(type(inst).vaddr) vdata_reg = ctx.inst_field(type(inst).vsrc) vdst_reg = ctx.inst_field(type(inst).vdst) offset = ctx.inst_field_signed(type(inst).ioffset) offset0, offset1 = _c(0), _c(0) saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(type(inst), 'saddr') else None else: # RDNA3: addr, data, offset addr_reg = ctx.inst_field(type(inst).addr) # type: ignore[union-attr] vdata_reg = ctx.inst_field(type(inst).data) # type: 
ignore[union-attr] vdst_reg = ctx.inst_field(type(inst).vdst) offset = ctx.inst_field_signed(type(inst).offset) # type: ignore[union-attr] offset0, offset1 = _c(0), _c(0) saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(type(inst), 'saddr') else None # type: ignore[union-attr] # Data width from canonical_op_bits (32/64/96/128), default to 32 for untyped ops data_bits_mem = inst.canonical_op_bits.get('data', 32) is_atomic, glc = 'ATOMIC' in op_name, getattr(inst, 'glc', 0) has_data1 = is_lds and hasattr(inst, 'data1') and inst.data1 is not None data1_reg = ctx.inst_field(type(inst).data1) if is_lds else _c(0) # type: ignore[union-attr] # DS_PERMUTE/DS_BPERMUTE: cross-lane VGPR access via pcode if is_lds and 'PERMUTE' in op_name: pcode = get_pcode(inst.op) srcs = {'ADDR': addr_reg, 'DATA0': vdata_reg, 'VDST': vdst_reg, 'OFFSET': offset, 'EXEC': exec_mask.cast(dtypes.uint64), '_vgpr': ctx.vgpr, '_wave_size': ctx.wave_size} _, assigns = parse_pcode(pcode, srcs) stores = [ctx.vgpr.index(val[0].cast(dtypes.int)).store(val[1].cast(dtypes.uint32)) for dest, val in assigns if dest.startswith('VGPR[')] return UOp.sink(*stores, *ctx.inc_pc()) def make_addr(lane: UOp) -> UOp: if is_lds: addr = ctx.rvgpr_dyn(addr_reg, lane) # Some DS pcode (e.g. DS_STORE_B16) uses MEM[ADDR] without adding OFFSET explicitly. # In those cases, add the instruction offset to ADDR here. 
if 'OFFSET' not in pcode: addr = addr + offset return addr offset64 = offset.cast(dtypes.uint64) # Dynamic saddr check: saddr < 124 means valid SGPR, otherwise use VGPR pair for address use_saddr = (saddr_reg < _c(124)) if saddr_reg is not None else UOp.const(dtypes.bool, False) if is_scratch: scratch_stride = ctx.rsgpr_dyn(_c(SCRATCH_STRIDE_IDX)).cast(dtypes.uint64) base = lane.cast(dtypes.uint64) * scratch_stride # SVE (Scratch VGPR Enable): when SVE=1, VADDR is used as offset; when SVE=0, VADDR is ignored sve = getattr(inst, 'sve', 0) vaddr = ctx.rvgpr_dyn(addr_reg, lane).cast(dtypes.uint64) addr_offset = vaddr if sve == 1 else UOp.const(dtypes.uint64, 0) # Add saddr value only if use_saddr is true (saddr < 124) saddr_contrib = use_saddr.where(ctx.rsgpr_dyn(saddr_reg).cast(dtypes.uint64), UOp.const(dtypes.uint64, 0)) \ if saddr_reg is not None else UOp.const(dtypes.uint64, 0) return base + addr_offset + saddr_contrib + offset64 # FLAT/GLOBAL: choose between SGPR base (saddr) or VGPR pair (addr) based on saddr validity saddr_base = _u64(ctx.rsgpr_dyn(saddr_reg), ctx.rsgpr_dyn(saddr_reg + _c(1))) if saddr_reg is not None else UOp.const(dtypes.uint64, 0) vaddr_base = _u64(ctx.rvgpr_dyn(addr_reg, lane), ctx.rvgpr_dyn(addr_reg + _c(1), lane)) # When saddr is valid: base = saddr pair, vaddr is 32-bit offset; otherwise: base = 0, vaddr is 64-bit address base_addr = use_saddr.where(saddr_base + ctx.rvgpr_dyn(addr_reg, lane).cast(dtypes.uint64), vaddr_base) return base_addr + offset64 def wmem(addr: UOp, val: UOp, active: UOp, data_bits: int = 32) -> UOp: if data_bits < 32: # Sub-dword LDS write: read-modify-write within the uint32 slot word_addr = (addr >> addr_shift).cast(dtypes.int) idx = mem.index(word_addr.valid(active)) byte_pos = addr.cast(dtypes.uint32) & _c(3) byte_shift = byte_pos * _c(8) size_mask = _c(0xFF if data_bits == 8 else 0xFFFF) mask = size_mask << byte_shift new_word = (idx & (mask ^ _c(0xFFFFFFFF))) | ((val.cast(dtypes.uint32) & size_mask) << 
byte_shift) return idx.store(active.where(new_word, idx)) idx = mem.index((addr >> addr_shift).cast(dtypes.int)) return idx.store(active.where(val, idx.load())) def make_srcs(lane: UOp) -> dict: addr = make_addr(lane) if is_lds: if data_bits_mem == 128: data = {'DATA': ctx.rvgpr_dyn(vdata_reg, lane), 'DATA1': ctx.rvgpr_dyn(vdata_reg + _c(1), lane), 'DATA2': ctx.rvgpr_dyn(vdata_reg + _c(2), lane), 'DATA3': ctx.rvgpr_dyn(vdata_reg + _c(3), lane)} elif data_bits_mem == 96: data = {'DATA': ctx.rvgpr_dyn(vdata_reg, lane), 'DATA1': ctx.rvgpr_dyn(vdata_reg + _c(1), lane), 'DATA2': ctx.rvgpr_dyn(vdata_reg + _c(2), lane)} elif data_bits_mem <= 32: data = {'DATA': ctx.rvgpr_dyn(vdata_reg, lane), 'DATA2': ctx.rvgpr_dyn(data1_reg, lane) if has_data1 else UOp.const(dtypes.uint32, 0)} else: data = {'DATA': _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)), 'DATA2': _u64(ctx.rvgpr_dyn(data1_reg, lane), ctx.rvgpr_dyn(data1_reg + _c(1), lane)) if has_data1 else UOp.const(dtypes.uint64, 0)} # RDNA3 uses ADDR/OFFSET, RDNA4 uses vgpr_a/offset (lowercase) + CalcDsAddr function return {'ADDR': addr, 'ADDR_BASE': addr, 'OFFSET': offset, 'OFFSET0': offset0, 'OFFSET1': offset1, '_lds': mem, 'laneId': lane, 'vgpr_a': ctx.rvgpr_dyn(addr_reg, lane), 'offset': offset, 'offset0': offset0, 'offset1': offset1, **data} active = _lane_active(exec_mask, lane) # saddr < 124 means valid SGPR pair, otherwise use 0 (NULL means no saddr contribution) use_saddr = (saddr_reg < _c(124)) if saddr_reg is not None else UOp.const(dtypes.bool, False) saddr_raw = _u64(ctx.rsgpr_dyn(saddr_reg), ctx.rsgpr_dyn(saddr_reg + _c(1))) if saddr_reg is not None else UOp.const(dtypes.uint64, 0) saddr_base = use_saddr.where(saddr_raw, UOp.const(dtypes.uint64, 0)) # Sign-extend offset to 64-bit for the final address calculation ioffset64 = offset.cast(dtypes.int64).cast(dtypes.uint64) # v_addr for CalcGlobalAddr: when saddr valid, use low 32 bits as offset; otherwise full 64-bit address. 
Include ioffset. vaddr_full = _u64(ctx.rvgpr_dyn(addr_reg, lane), ctx.rvgpr_dyn(addr_reg + _c(1), lane)) vaddr_lo = ctx.rvgpr_dyn(addr_reg, lane).cast(dtypes.uint64) vaddr_base = use_saddr.where(vaddr_lo + ioffset64, vaddr_full + ioffset64) if is_atomic: atomic_data = _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)) \ if data_bits_mem == 64 else ctx.rvgpr_dyn(vdata_reg, lane) return {'ADDR': addr, 'DATA': atomic_data, '_vmem': mem, '_active': active, 'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base} # acc bit: read/write ACCVGPR instead of VGPR for data operands _rvdata = (lambda r, l, *a: ctx.raccvgpr_dyn(r, l)) if use_acc else ctx.rvgpr_dyn vdata = _rvdata(vdata_reg, lane).cast(dtypes.uint64) if 'STORE' in op_name \ else _rvdata(vdst_reg, lane) if 'D16' in op_name else UOp.const(dtypes.uint32, 0) if 'STORE' in op_name and data_bits_mem >= 64: vdata = vdata | (_rvdata(vdata_reg + _c(1), lane).cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32)) srcs = {'ADDR': addr, 'VDATA': vdata, '_vmem': mem, '_active': active, 'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base, 'SADDR': saddr_base, 'OFFSET': offset} for i in range(data_bits_mem // 32): srcs[f'VDATA{i}'] = _rvdata(vdata_reg + _c(i), lane) if 'STORE' in op_name else UOp.const(dtypes.uint32, 0) return srcs def make_stores(dest: str, val: UOp, lane: UOp, active: UOp, writes_return_data: bool) -> list[UOp]: # Parse bit width from dest format: MEM[...].b32 or RETURN_DATA[63:32].b64 parts = dest.rsplit('.', 1) data_bits = int(parts[1][1:]) if len(parts) == 2 else 32 if dest.startswith('MEM['): if is_lds or is_atomic: if data_bits < 32 and is_lds: return [wmem(val[0], val[1], active, data_bits)] return _write_val(data_bits, val[1], wmem, val[0], active, is_mem=True) if is_scratch: return _mem_store_bytes(mem, val[0], val[1], active, data_bits) return _mem_store(mem, val[0], val[1], active, 64, data_bits) if dest.startswith('RETURN_DATA') and writes_return_data: _wdata = 
(lambda r, v, l, e: ctx.waccvgpr_dyn(r, l, v, e)) if use_acc else (lambda r, v, l, e: ctx.wvgpr_dyn(r, l, v, e)) if (m := re.match(r'RETURN_DATA\[(\d+)\s*:\s*(\d+)\]', dest)): bit_width, dword_idx = int(m.group(1)) - int(m.group(2)) + 1, int(m.group(2)) // 32 return _write_val(bit_width, val, _wdata, vdst_reg + _c(dword_idx), lane, exec_mask) return _write_val(data_bits, val, _wdata, vdst_reg, lane, exec_mask) return [] # DS-specific: check for 2ADDR pattern needing separate ranges if is_lds: dummy_lane = ctx.range() _, assigns = parse_pcode(pcode, make_srcs(dummy_lane)) mem_assigns = [d for d, _ in assigns if d.startswith('MEM[')] mem_addrs = set(m.group(1) if (m := re.match(r'MEM\[([^\]]+)\]', d)) else d for d in mem_assigns) use_separate_ranges = (len(mem_addrs) > 1 or '2ADDR' in op_name) and 'STOREXCHG' not in op_name if use_separate_ranges: # Split assigns into MEM writes (stores) and RETURN_DATA writes (loads). # Stores to different addresses need separate lane ranges. Loads must share a single lane range so the # addr vgpr is read before any vdst write (hardware reads addr once, then writes all results). 
def _compile_mubuf(inst: irc.MUBUF, ctx: _Ctx) -> UOp:
  """CDNA MUBUF: linear buffer address = base + soffset + (stride * index) + vgpr_offset + inst_offset

  Emits one lane range that, depending on the instruction, either copies buffer dwords into LDS
  (`lds` bit set on a non-store), writes VGPR/ACCVGPR dwords to the buffer (store), or reads buffer
  dwords into VGPR/ACCVGPR (load). Lanes that are inactive or whose buffer offset is not below
  num_records are redirected to address 0 and masked out of the access.
  """
  exec_mask, op_name = ctx.rexec(), _op_name(inst)
  use_acc = bool(getattr(inst, 'acc', 0))
  is_store = 'STORE' in op_name
  is_lds = bool(getattr(inst, 'lds', 0))
  # dword count comes from the mnemonic suffix; X4 is checked before X2
  n_dwords = next((count for tag, count in (('X4', 4), ('X2', 2)) if tag in op_name), 1)

  # instruction fields (read in encoding order: vdata, vaddr, srsrc, soffset, offset, offen, idxen)
  enc = type(inst)
  vdata, vaddr = ctx.inst_field(enc.vdata), ctx.inst_field(enc.vaddr)
  srsrc = ctx.inst_field(enc.srsrc) * _c(4)
  soffset = ctx.inst_field(enc.soffset)
  offset = ctx.inst_field(enc.offset)
  offen = ctx.inst_field(enc.offen)
  idxen = ctx.inst_field(enc.idxen)

  # V# descriptor: base[0:1], num_records[2], stride=word3[13:0]
  base = _u64(ctx.rsgpr_dyn(srsrc), ctx.rsgpr_dyn(srsrc + _c(1))) & UOp.const(dtypes.uint64, 0xFFFFFFFFFFFF)
  num_records = ctx.rsgpr_dyn(srsrc + _c(2))
  stride = (ctx.rsgpr_dyn(srsrc + _c(3)) & _c(0x3FFF)).cast(dtypes.uint64)

  lane = ctx.range()
  active = _lane_active(exec_mask, lane)
  has_index, has_offset = idxen.ne(_c(0)), offen.ne(_c(0))

  # soffset: sgpr if < 128, else inline constant
  soff = (soffset < _c(128)).where(ctx.rsgpr_dyn(soffset), soffset - _c(128)).cast(dtypes.uint64)
  # vaddr holds the index when idxen is set; the offset then lives in vaddr+1, otherwise in vaddr itself
  index = has_index.where(ctx.rvgpr_dyn(vaddr, lane), _c(0)).cast(dtypes.uint64)
  offset_reg = has_index.where(vaddr + _c(1), vaddr)
  voff = has_offset.where(ctx.rvgpr_dyn(offset_reg, lane), _c(0)).cast(dtypes.uint64)

  # the 32-bit buffer offset feeds the bounds check; out-of-bounds lanes fall back to address 0
  buffer_offset = (stride * index + voff + offset.cast(dtypes.uint64)).cast(dtypes.uint32)
  in_bounds = active & (buffer_offset < num_records)
  addr = in_bounds.where(base + soff + buffer_offset.cast(dtypes.uint64), UOp.const(dtypes.uint64, 0))

  mem = ctx.vmem

  def word_index(dword: int) -> UOp:
    # index of the dword-th 32-bit word starting at addr
    return ((addr + UOp.const(dtypes.uint64, dword * 4)) >> UOp.const(dtypes.uint64, 2)).cast(dtypes.int64)

  stores: list[UOp] = []
  if is_lds and not is_store:
    # LDS load: buffer -> LDS (bypass VGPRs), LDS addr = M0[17:0] + lane * elem_size
    lds_addr = (ctx.rsgpr_dyn(_c(124)) & _c(0x3FFFF)) + lane.cast(dtypes.uint32) * _c(n_dwords * 4)
    for i in range(n_dwords):
      loaded = in_bounds.where(mem.index(word_index(i), ptr=True).load(), _c(0))
      slot = ctx.lds.index(((lds_addr + _c(i * 4)) >> _c(2)).cast(dtypes.int).valid(active))
      stores.append(slot.store(active.where(loaded, slot)))
  elif is_store:
    read_reg = ctx.raccvgpr_dyn if use_acc else ctx.rvgpr_dyn
    for i in range(n_dwords):
      target = mem.index(word_index(i).valid(in_bounds))
      stores.append(target.store(in_bounds.where(_to_u32(read_reg(vdata + _c(i), lane)), target)))
  else:
    write_reg = ctx.waccvgpr_dyn if use_acc else ctx.wvgpr_dyn
    for i in range(n_dwords):
      loaded = in_bounds.where(mem.index(word_index(i).valid(in_bounds), ptr=True).load(), _c(0))
      stores.append(write_reg(vdata + _c(i), lane, loaded, exec_mask))

  return UOp.sink(UOp.group(*stores).end(lane), *ctx.inc_pc())
# Dispatch table: instruction type -> handler function.
# Written as (handler, classes) rows per architecture; rows are expanded in order, so key insertion
# order is the same as listing every class individually.
_INST_HANDLERS: dict[type, Callable[..., UOp]] = {
  inst_cls: handler
  for handler, inst_classes in (
    # RDNA3 instruction classes
    (_compile_sopp, (ir3.SOPP,)),
    (_compile_smem, (ir3.SMEM,)),
    (_compile_sop, (ir3.SOP1, ir3.SOP2, ir3.SOPC, ir3.SOPK)),
    (_compile_vop12, (ir3.VOP1, ir3.VOP1_SDST, ir3.VOP1_DPP16, ir3.VOP2, ir3.VOP2_DPP16)),
    (_compile_vopc, (ir3.VOPC, ir3.VOPC_DPP16)),
    (_compile_vop3, (ir3.VOP3,)),
    (_compile_vinterp, (ir3.VINTERP,)),
    (_compile_vop3, (ir3.VOP3_SDST,)),
    (_compile_vop3sd, (ir3.VOP3SD,)),
    (_compile_vop3p, (ir3.VOP3P,)),
    (_compile_vopd, (ir3.VOPD,)),
    (_compile_mem_op, (ir3.DS, ir3.FLAT, ir3.GLOBAL, ir3.SCRATCH)),
    # RDNA4 instruction classes
    (_compile_sopp, (ir4.SOPP,)),
    (_compile_smem, (ir4.SMEM,)),
    (_compile_sop, (ir4.SOP1, ir4.SOP2, ir4.SOPC, ir4.SOPK)),
    (_compile_vop12, (ir4.VOP1, ir4.VOP1_SDST, ir4.VOP1_DPP16, ir4.VOP2, ir4.VOP2_DPP16)),
    (_compile_vopc, (ir4.VOPC, ir4.VOPC_DPP16)),
    (_compile_vop3, (ir4.VOP3,)),
    (_compile_vinterp, (ir4.VINTERP,)),
    (_compile_vop3, (ir4.VOP3_SDST,)),
    (_compile_vop3sd, (ir4.VOP3SD,)),
    (_compile_vop3p, (ir4.VOP3P,)),
    (_compile_vopd, (ir4.VOPD,)),
    (_compile_mem_op, (ir4.DS, ir4.VFLAT, ir4.VGLOBAL, ir4.VSCRATCH)),
    # CDNA instruction classes
    (_compile_sopp, (irc.SOPP,)),
    (_compile_smem, (irc.SMEM,)),
    (_compile_sop, (irc.SOP1, irc.SOP2, irc.SOPC, irc.SOPK)),
    (_compile_vop12, (irc.VOP1, irc.VOP1_DPP16, irc.VOP2, irc.VOP2_DPP16)),
    (_compile_vopc, (irc.VOPC,)),
    (_compile_vop3, (irc.VOP3, irc.VOP3_SDST)),
    (_compile_vop3sd, (irc.VOP3SD,)),
    (_compile_vop3p, (irc.VOP3P,)),
    (_compile_sdwa, (irc.VOP1_SDWA, irc.VOP2_SDWA, irc.VOP2_SDWA_SDST, irc.VOPC_SDWA_SDST)),
    (_compile_mem_op, (irc.DS, irc.FLAT, irc.GLOBAL, irc.SCRATCH)),
    (_compile_mubuf, (irc.MUBUF,)),
  )
  for inst_cls in inst_classes
}
# ═══════════════════════════════════════════════════════════════════════════════
# PROGRAM DECODE AND COMPILATION
# ═══════════════════════════════════════════════════════════════════════════════

# One entry per compiled canonical pattern: (inst_type, base, mask, size, (prg, runtime))
_canonical_runner_cache: list[tuple[type, int, int, int, tuple[UOp, object]]] = []

@functools.cache
def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
  """Return the (prg, runtime) pair that executes the instruction encoded at the start of inst_bytes.

  Two cache levels apply: functools.cache keys on the exact (inst_bytes, arch) arguments, and
  _canonical_runner_cache lets a different encoding reuse an existing entry when it has the same
  instruction class and size and agrees with it on the bits selected by that entry's mask.

  Raises:
    RuntimeError: no handler is registered for the instruction class or any class in its MRO.
  """
  inst = decode_inst(inst_bytes, arch)
  inst_cls, inst_size = type(inst), inst.size()
  encoding = int.from_bytes(inst_bytes[:inst_size], 'little')

  # Canonical reuse: the class must match too, so variants sharing a bit pattern do not collide
  for cached_cls, cached_base, cached_mask, cached_size, cached_entry in _canonical_runner_cache:
    if inst_cls is cached_cls and inst_size == cached_size and (encoding & cached_mask) == cached_base:
      return cached_entry

  # The MRO starts with the class itself, so one walk covers the exact match and the base-class
  # fallback used by _LIT variants
  handler = next((_INST_HANDLERS[cls] for cls in inst_cls.__mro__ if cls in _INST_HANDLERS), None)
  if handler is None:
    raise RuntimeError(f"[emu] unimplemented instruction type: {type(inst).__name__} {_op_name(inst)}")

  ctx = _Ctx(inst_size, _wave_size(arch))
  sink = handler(inst, ctx)
  base, mask, size = ctx.canonical_mask(inst_bytes)
  canonical_name = f"{_op_name(inst).lower()}_{base.to_bytes(size, 'little').hex()}"
  sink = sink.replace(arg=KernelInfo(name=canonical_name)).rtag(1)

  # NOTE: renderer output is not reproducible because of _MXCSRContext.
  # PROFILE=0 prevents emulator instruction runners from polluting profiling.
  with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0, EMULATED_DTYPES="", CAPTURE_PROCESS_REPLAY=0, PROFILE=0):
    prg = to_program(sink, Device['CPU'].renderer)
    runtime = get_runtime('CPU', prg)

  entry = (prg, runtime)
  _canonical_runner_cache.append((inst_cls, base, mask, size, entry))
  return entry
# SOPP opcodes at which a wave pauses; the RDNA4 member is added only if the enum defines it
_BARRIER_OPS = {ir3.SOPPOp.S_BARRIER, irc.SOPPOp.S_BARRIER} | \
  ({ir4.SOPPOp.S_BARRIER_WAIT} if hasattr(ir4.SOPPOp, 'S_BARRIER_WAIT') else set())
# SOP1 opcodes treated as barriers (RDNA4 only, when present)
_BARRIER_SOP1_OPS: set = {ir4.SOP1Op.S_BARRIER_SIGNAL} if hasattr(ir4.SOP1Op, 'S_BARRIER_SIGNAL') else set()
# Numeric values of the RDNA3 SOPP branch opcodes
_BRANCH_OPS: set[int] = {
  op.value for op in (
    ir3.SOPPOp.S_BRANCH,
    ir3.SOPPOp.S_CBRANCH_SCC0, ir3.SOPPOp.S_CBRANCH_SCC1,
    ir3.SOPPOp.S_CBRANCH_VCCZ, ir3.SOPPOp.S_CBRANCH_VCCNZ,
    ir3.SOPPOp.S_CBRANCH_EXECZ, ir3.SOPPOp.S_CBRANCH_EXECNZ,
  )
}

def _decode_at(pc: int, arch: str):
  """Decode and compile the instruction at absolute address pc.

  Reads a fixed 16-byte window at `pc`, decodes it, and requests a runner for the instruction
  bytes plus 4 trailing bytes.

  Returns:
    (runner, decoded_inst), where runner is whatever _get_runner returns.

  Raises:
    RuntimeError: compiling the instruction failed; the original exception is chained as the cause.
  """
  window = (ctypes.c_char * 16).from_address(pc).raw
  inst = decode_inst(window, arch)
  try:
    runner = _get_runner(window[:inst.size() + 4], arch)
  except Exception as e:
    # repr() of a decoded instruction can itself fail, so fall back to just the class name
    try:
      inst_str = repr(inst)
    except Exception:
      inst_str = f"<{type(inst).__name__}>"
    raise RuntimeError(f"[emu] Failed to compile {inst_str}: {type(e).__name__}: {e}") from e
  return runner, inst
# ═══════════════════════════════════════════════════════════════════════════════
# WAVE STATE
# ═══════════════════════════════════════════════════════════════════════════════

# Inline float constants (as bit patterns) for GPU instructions
F32_INLINE = {240: 0x3f000000, 241: 0xbf000000, 242: 0x3f800000, 243: 0xbf800000,  # 0.5, -0.5, 1.0, -1.0
              244: 0x40000000, 245: 0xc0000000, 246: 0x40800000, 247: 0xc0800000, 248: 0x3e22f983}  # 2.0, -2.0, 4.0, -4.0, 1/(2*pi)

class WaveState:
  """Register state of one wavefront, backed by CPU buffers.

  vgpr_buf holds 256 * wave_size uint32 words addressed as reg * wave_size + lane; sgpr_buf holds
  SGPR_COUNT uint32 words. For wave_size == 64 a separate, zeroed accvgpr_buf is allocated;
  otherwise accvgpr_buf is the same Buffer object as vgpr_buf.
  """
  __slots__ = ('vgpr_buf', 'sgpr_buf', 'accvgpr_buf', '_vgpr_mv', '_sgpr_mv', 'n_lanes', 'wave_size')

  def __init__(self, n_lanes: int, wave_size: int = 32):
    self.n_lanes, self.wave_size = n_lanes, wave_size
    vgpr_words = 256 * wave_size
    self.vgpr_buf = Buffer('CPU', vgpr_words, dtypes.uint32).ensure_allocated()
    self.sgpr_buf = Buffer('CPU', SGPR_COUNT, dtypes.uint32).ensure_allocated()
    # CDNA (wave64) has separate ACCVGPR file; RDNA shares with VGPR
    if wave_size != 64:
      self.accvgpr_buf = self.vgpr_buf
    else:
      self.accvgpr_buf = Buffer('CPU', vgpr_words, dtypes.uint32).ensure_allocated()
      ctypes.memset(self.accvgpr_buf._buf.va_addr, 0, vgpr_words * 4)
    self._vgpr_mv = self.vgpr_buf.as_memoryview(force_zero_copy=True).cast('I')
    self._sgpr_mv = self.sgpr_buf.as_memoryview(force_zero_copy=True).cast('I')

    # Clear both register files with memset rather than a Python loop
    for buf, n_words in ((self.vgpr_buf, vgpr_words), (self.sgpr_buf, SGPR_COUNT)):
      ctypes.memset(buf._buf.va_addr, 0, n_words * 4)

    # Inline constants live at SGPR indices 128-255 (_write_sgpr truncates every value to 32 bits)
    for value in range(65):
      self._write_sgpr(128 + value, value)          # 128-192: integers 0-64
    for magnitude in range(1, 17):
      self._write_sgpr(192 + magnitude, -magnitude)  # 193-208: -1 to -16
    for off, bits in F32_INLINE.items():
      self._write_sgpr(off, bits)                   # 240-248: float constants

    # EXEC: one bit per lane. EXEC_LO takes the low 32 bits; 64-lane waves also set EXEC_HI.
    lane_mask = (1 << n_lanes) - 1
    self._write_sgpr(EXEC_LO.offset, lane_mask)
    if wave_size == 64:
      self._write_sgpr(EXEC_LO.offset + 1, lane_mask >> 32)

    self.pc = 0

  def _write_sgpr(self, idx: int, val: int):
    self._sgpr_mv[idx] = val & MASK32

  def _read_sgpr(self, idx: int) -> int:
    return self._sgpr_mv[idx]

  def _write_vgpr(self, reg: int, lane: int, val: int):
    self._vgpr_mv[reg * self.wave_size + lane] = val & MASK32

  def _read_vgpr(self, reg: int, lane: int) -> int:
    return self._vgpr_mv[reg * self.wave_size + lane]

  @property
  def pc(self) -> int:
    """64-bit program counter assembled from the PC_LO/PC_HI SGPR slots."""
    lo, hi = self._read_sgpr(PC_LO_IDX), self._read_sgpr(PC_HI_IDX)
    return lo | (hi << 32)

  @pc.setter
  def pc(self, val: int):
    self._write_sgpr(PC_LO_IDX, val & MASK32)
    self._write_sgpr(PC_HI_IDX, (val >> 32) & MASK32)
# ═══════════════════════════════════════════════════════════════════════════════
# EXECUTION
# ═══════════════════════════════════════════════════════════════════════════════

def _init_wave(lib: int, wave_start: int, total_threads: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int,
               scratch_size: int, arch: str, gidx: int, gidy: int, gidz: int, user_data: list[int]|None,
               wave_size: int = 32) -> WaveState:
  """Create the WaveState for the wave whose first thread is wave_start.

  Sets the PC to lib, loads the user SGPRs (user_data if given, else the 64-bit args pointer in
  s[0:1]), places the workgroup IDs, writes the packed thread ID into v0 of every lane, and stores
  the scratch stride and an HW_ID value in their reserved SGPR slots. lz is accepted but not read.
  """
  n_lanes = min(wave_size, total_threads - wave_start)
  st = WaveState(n_lanes, wave_size)
  st.pc = lib

  # User SGPRs: explicit user_data wins, otherwise s[0:1] carry the kernel-argument pointer
  if user_data:
    for reg, word in enumerate(user_data):
      st._write_sgpr(reg, word)
  else:
    st._write_sgpr(0, args_ptr & MASK32)
    st._write_sgpr(1, (args_ptr >> 32) & MASK32)

  if arch == "rdna4":
    # workgroup IDs only exist in ttmp registers, not normal SGPRs
    st._write_sgpr(ttmp[7].offset, (gidy & 0xFFFF) | ((gidz & 0xFFFF) << 16))
    st._write_sgpr(ttmp[9].offset, gidx)
  else:
    # Enabled workgroup IDs are packed into consecutive SGPRs right after the user SGPRs
    next_sgpr = (rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT_SHIFT
    enable_bits = (hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_X,
                   hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Y,
                   hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Z)
    for enable_bit, gid in zip(enable_bits, (gidx, gidy, gidz)):
      if not (rsrc2 & enable_bit): continue
      st._write_sgpr(next_sgpr, gid)
      next_sgpr += 1

  # v0 holds the packed local thread ID: x in bits [9:0], y from bit 10, z from bit 20
  for lane in range(n_lanes):
    row, tid_x = divmod(wave_start + lane, lx)
    tid_z, tid_y = divmod(row, ly)
    st._write_vgpr(0, lane, (tid_z << 20) | (tid_y << 10) | tid_x)

  st._write_sgpr(SCRATCH_STRIDE_IDX, scratch_size)

  # Store HW register values at SGPR[SGPR_COUNT-16 .. SGPR_COUNT-1] for s_getreg_b32 emulation.
  # HW_ID (hwRegId=4): WAVE_ID[3:0], SIMD_ID[5:4], PIPE_ID[7:6], CU_ID[11:8], ...
  wave_idx = wave_start // wave_size  # wave index within this workgroup
  wave_id, simd_id = wave_idx & 0xF, wave_idx & 0x3  # WAVE_ID = wave_idx, SIMD_ID = wave_idx % 4
  st._write_sgpr(SGPR_COUNT - 16 + 4, wave_id | (simd_id << 4))  # HW_REGISTERS[4] = HW_ID
  return st
def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int = 0x19c,
            scratch_size: int = 0, arch: str = "rdna3", user_data: list[int]|None = None) -> int:
  """Execute AMD assembly program. scratch_size is private_segment_fixed_size from kernel descriptor (per-lane).

  Workgroups are run one after another over the gx * gy * gz grid. Within a workgroup, every wave is run
  until it reaches a barrier instruction or ENDPGM_PC; rounds repeat until all waves have finished.

  Args:
    lib: address of the first instruction; used as the initial PC of every wave.
    lib_sz: not read by this function.
    gx, gy, gz: number of workgroups along each grid dimension.
    lx, ly, lz: workgroup dimensions; lx * ly * lz threads are split into waves of the arch's wave size.
    args_ptr: kernel argument pointer, forwarded to _init_wave.
    rsrc2: compute_pgm_rsrc2 word; its granulated LDS size field (times 512) sizes the LDS buffer.
    scratch_size: per-lane scratch size; no scratch buffer is allocated when it is 0.
    arch: architecture name passed to the decoder and to _wave_size.
    user_data: optional initial user SGPR values, forwarded to _init_wave.

  Returns:
    Always 0.

  Raises:
    RuntimeError: a wave ran 1M instructions without reaching a barrier or the end of the program, or
      10M scheduling rounds passed without all waves finishing.
  """
  from tinygrad.renderer.amd.dsl import Inst
  program: dict[int, tuple[Callable, list[int], bool, Inst]] = {}  # pc -> (fxn, globals, is_barrier, inst)
  lds_size = ((rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE_SHIFT) * 512
  total_threads = lx * ly * lz
  wave_size = _wave_size(arch)

  # Use Buffer objects with external_ptr=0 for vmem
  vmem_buf = Buffer('CPU', 1 << 40, dtypes.uint32, options=BufferSpec(external_ptr=0)).ensure_allocated()
  lds_buf = Buffer('CPU', max(lds_size // 4, 1), dtypes.uint32).ensure_allocated()
  scratch_buf = Buffer('CPU', scratch_size * wave_size, dtypes.uint8).ensure_allocated() if scratch_size else None

  # Initialize SQTT encoder — emits packets inline as instructions execute (only when profiling)
  # NOTE: the sqtt_* names are bound only when PROFILE is set; every use below is guarded by `tracing` or PROFILE
  if PROFILE: sqtt_emit, sqtt_finish, sqtt_finalize = _init_sqtt_encoder()

  def _ensure_compiled(pc: int) -> tuple[Callable, list[int], bool, Inst]:
    """Return the (fxn, globals, is_barrier, inst) entry for pc, decoding and compiling it on first use."""
    if pc not in program:
      # Remember the canonical cache size so DEBUG output can tell whether this decode compiled a new kernel
      prev_len = len(_canonical_runner_cache)
      (prg, runtime), inst = _decode_at(pc, arch)
      is_barrier = (isinstance(inst, (ir3.SOPP, ir4.SOPP, irc.SOPP)) and inst.op in _BARRIER_OPS) or \
                   (isinstance(inst, (ir4.SOP1,)) and inst.op in _BARRIER_SOP1_OPS)
      program[pc] = (runtime.fxn, prg.arg.globals, is_barrier, inst)
      if DEBUG >= 3:
        msg = f"[emu] PC={pc - lib}: {inst!r}"
        print(colored(msg, 'green') if len(_canonical_runner_cache) > prev_len else msg)
    return program[pc]

  # Set DAZ+FTZ during emulator execution, restore afterward to avoid breaking hypothesis tests
  # Only trace the first workgroup (like real HW traces one CU/SIMD), subsequent workgroups run but don't add to trace
  tracing = bool(PROFILE)
  with _MXCSRContext():
    for gidz in range(gz):
      for gidy in range(gy):
        for gidx in range(gx):
          # Initialize all wavefronts for this workgroup
          waves: list[tuple[WaveState, list]] = []
          for wave_start in range(0, total_threads, wave_size):
            st = _init_wave(lib, wave_start, total_threads, lx, ly, lz, args_ptr, rsrc2, scratch_size, arch, gidx, gidy, gidz, user_data, wave_size)
            # Buffer addresses in kernel-argument order; each kernel picks the ones it needs via its globals list
            c_bufs = [ctypes.c_uint64(st.sgpr_buf._buf.va_addr), ctypes.c_uint64(st.vgpr_buf._buf.va_addr),
                      ctypes.c_uint64(vmem_buf._buf.va_addr), ctypes.c_uint64(lds_buf._buf.va_addr),
                      ctypes.c_uint64(scratch_buf._buf.va_addr if scratch_buf else 0),
                      ctypes.c_uint64(st.accvgpr_buf._buf.va_addr)]
            waves.append((st, c_bufs))

          # Execute wavefronts with barrier synchronization
          # Each wave runs until it hits s_barrier or s_endpgm. When all waves have stopped, release barrier waves.
          done = [False] * len(waves)
          for total_inst in range(10_000_000):
            if all(done): break
            for wi, (st, c_bufs) in enumerate(waves):
              if done[wi]: continue
              # Run this wave until barrier or endpgm
              for _ in range(1_000_000):
                pc = st.pc
                if pc == ENDPGM_PC:
                  done[wi] = True
                  if tracing: sqtt_finish(wi)
                  break
                fxn, globals_list, is_barrier, inst = _ensure_compiled(pc)
                if DEBUG >= 5: print(f"  exec gid=({gidx},{gidy},{gidz}) w={wi} PC={pc - lib}: {inst!r}", flush=True)
                fxn(*[c_bufs[g] for g in globals_list])
                if tracing:
                  inst_op = inst.op.value if hasattr(inst, 'op') else 0
                  # Third argument: for opcodes in _BRANCH_OPS, whether the PC left the fall-through path; None otherwise
                  sqtt_emit(wi, inst, (st.pc != ENDPGM_PC and st.pc != pc + inst.size()) if inst_op in _BRANCH_OPS else None)
                if is_barrier: break  # s_barrier hit: PC already advanced past it, pause this wave
              else: raise RuntimeError("exceeded 1M instructions in single wave, likely infinite loop")
            # All waves have either hit barrier or endpgm — release barrier waves for next round
          else: raise RuntimeError("exceeded 10M total scheduling rounds")

          tracing = False  # only trace the first workgroup
          # Reset LDS for next workgroup
          if lds_size > 0: ctypes.memset(lds_buf._buf.va_addr, 0, max(lds_size, 4))
  if PROFILE: sqtt_traces.append(sqtt_finalize())
  return 0