assembly/amd: remove IMG instruction support and asm.py (#14163)

* assembly/amd: return IMG instruction supports

* remove asm.py

* op2dsl
This commit is contained in:
George Hotz
2026-01-17 06:21:50 +09:00
committed by GitHub
parent dc4ae7dd08
commit 7d1d9d4568
10 changed files with 117 additions and 916 deletions

View File

@@ -1,478 +0,0 @@
# RDNA3/RDNA4/CDNA assembler
from __future__ import annotations
import re
from extra.assembly.amd.dsl import Reg, s, v, ttmp
from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL
# Assembler-specific types (not part of clean DSL)
class RawImm(Reg):
  """Operand whose integer value is passed straight to ``Reg`` without translation.

  Used by the assembler for special-register encodings (see ``SPEC_REGS``).
  """
  def __init__(self, val: int):
    # Second positional argument to Reg is always 1 here.
    super().__init__(val, 1)
class SrcMod(Reg):
  """Integer source operand that also carries negate / absolute-value flags.

  The value handed to ``Reg`` is derived from ``val``:
    * 0..64    -> 128 + val
    * -16..-1  -> 192 - val
    * anything else -> 255
  The original ``val`` and both flags are kept as attributes.
  """
  def __init__(self, val: int, neg: bool = False, abs_: bool = False):
    if val < -16 or val > 64:
      encoded = 255
    elif val >= 0:
      encoded = 128 + val
    else:
      encoded = 192 - val
    super().__init__(encoded, 1)
    self.val = val
    self.neg = neg
    self.abs_ = abs_
# Type aliases for register factories
_RegFactory = type(s)  # class of the `s` object imported from the DSL module
SGPR = s
VGPR = v
TTMP = ttmp
OFF = NULL  # OFF is alias for NULL (encoding 124)
# Float encoding constants: float value -> integer code (240..247)
FLOAT_ENC = {
  0.5: 240, -0.5: 241,
  1.0: 242, -1.0: 243,
  2.0: 244, -2.0: 245,
  4.0: 246, -4.0: 247,
}
from extra.assembly.amd.autogen.rdna3 import ins
from extra.assembly.amd.autogen.rdna3.ins import VOP2Op, VOPDOp, SOPKOp
from extra.assembly.amd.autogen.rdna4 import ins as rdna4_ins
# Re-export disasm for backwards compatibility
from extra.assembly.amd.disasm import disasm, HWREG, HWREG_RDNA4
# ═══════════════════════════════════════════════════════════════════════════════
# CONSTANTS
# ═══════════════════════════════════════════════════════════════════════════════
# RDNA unified buffer format (not in XML, hardcoded from AMD documentation)
# Format names are listed in code order: the first suffix gets code 1 and every following
# suffix is one higher (1..63). BUF_FMT_8_FLOAT is the single out-of-sequence entry (108).
BUF_FMT = {
  **{f'BUF_FMT_{suffix}': code for code, suffix in enumerate((
    '8_UNORM', '8_SNORM', '8_USCALED', '8_SSCALED', '8_UINT', '8_SINT',
    '16_UNORM', '16_SNORM', '16_USCALED', '16_SSCALED', '16_UINT', '16_SINT', '16_FLOAT',
    '8_8_UNORM', '8_8_SNORM', '8_8_USCALED', '8_8_SSCALED', '8_8_UINT', '8_8_SINT',
    '32_UINT', '32_SINT', '32_FLOAT',
    '16_16_UNORM', '16_16_SNORM', '16_16_USCALED', '16_16_SSCALED', '16_16_UINT', '16_16_SINT', '16_16_FLOAT',
    '10_11_11_FLOAT', '11_11_10_FLOAT',
    '10_10_10_2_UNORM', '10_10_10_2_SNORM', '10_10_10_2_UINT', '10_10_10_2_SINT',
    '2_10_10_10_UNORM', '2_10_10_10_SNORM', '2_10_10_10_USCALED', '2_10_10_10_SSCALED', '2_10_10_10_UINT', '2_10_10_10_SINT',
    '8_8_8_8_UNORM', '8_8_8_8_SNORM', '8_8_8_8_USCALED', '8_8_8_8_SSCALED', '8_8_8_8_UINT', '8_8_8_8_SINT',
    '32_32_UINT', '32_32_SINT', '32_32_FLOAT',
    '16_16_16_16_UNORM', '16_16_16_16_SNORM', '16_16_16_16_USCALED', '16_16_16_16_SSCALED',
    '16_16_16_16_UINT', '16_16_16_16_SINT', '16_16_16_16_FLOAT',
    '32_32_32_UINT', '32_32_32_SINT', '32_32_32_FLOAT',
    '32_32_32_32_UINT', '32_32_32_32_SINT', '32_32_32_32_FLOAT',
  ), start=1)},
  'BUF_FMT_8_FLOAT': 108,
}
def _parse_buf_fmt_combo(s: str) -> int:
parts = [p.strip().replace('BUF_DATA_FORMAT_', '').replace('BUF_NUM_FORMAT_', '') for p in s.split(',')]
return BUF_FMT.get(f'BUF_FMT_{parts[0]}_{parts[1]}') if len(parts) == 2 else None
def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int:
  """Pack the three wait counters into a single integer.

  Each counter is masked to its width and placed as follows:
    * expcnt  (3 bits) at bit 0
    * lgkmcnt (6 bits) at bit 4
    * vmcnt   (6 bits) at bit 10
  The defaults are the all-ones value of each field.
  """
  exp_bits = expcnt & 0x7
  lgkm_bits = (lgkmcnt & 0x3f) << 4
  vm_bits = (vmcnt & 0x3f) << 10
  return exp_bits | lgkm_bits | vm_bits
# ═══════════════════════════════════════════════════════════════════════════════
# ASSEMBLER
# ═══════════════════════════════════════════════════════════════════════════════
# Special-register spellings -> raw operand codes. Several spellings share a code:
# vcc/vcc_lo (106), null/off (124), exec/exec_lo (126), scc/src_scc (253).
# Every entry gets its own RawImm instance.
SPEC_REGS = {name: RawImm(code) for name, code in (
  ('vcc_lo', 106), ('vcc_hi', 107), ('vcc', 106),
  ('null', 124), ('off', 124), ('m0', 125),
  ('exec_lo', 126), ('exec_hi', 127), ('exec', 126),
  ('scc', 253), ('src_scc', 253),
)}
# Valid float literal strings: '0.5', '-0.5', '1.0', etc.
FLOATS = {str(value): value for value in FLOAT_ENC}
# Register prefix -> DSL register factory
REG_MAP: dict[str, _RegFactory] = {
  's': s,
  'v': v,
  't': ttmp,
  'ttmp': ttmp,
}
# Mnemonics handled by the SMEM branch of get_dsl: every s_load_* / s_buffer_load_* width
# plus the two s_atc_probe forms.
SMEM_OPS = {
  f'{prefix}_{width}'
  for prefix in ('s_load', 's_buffer_load')
  for width in ('b32', 'b64', 'b96', 'b128', 'b256', 'b512', 'i8', 'u8', 'i16', 'u16')
} | {'s_atc_probe', 's_atc_probe_buffer'}
# Special-register spellings -> the DSL constant name emitted for them
SPEC_DSL = {
  'vcc_lo': 'VCC_LO',
  'vcc_hi': 'VCC_HI',
  'vcc': 'VCC_LO',
  'null': 'NULL',
  'off': 'OFF',
  'm0': 'M0',
  'exec_lo': 'EXEC_LO',
  'exec_hi': 'EXEC_HI',
  'exec': 'EXEC_LO',
  'scc': 'SCC',
  'src_scc': 'SCC',
}
def _op2dsl(op: str) -> str:
  """Convert one operand of assembly text into DSL source text.

  Recognised, in order: a leading ``-`` or ``neg(...)`` negation, ``|...|`` or ``abs(...)``
  absolute value, a trailing ``.h`` / ``.l`` half selector, special register names
  (``SPEC_DSL``), the float literals in ``FLOATS``, register ranges (``s[0:3]``),
  single registers (``v5``), and decimal / hex integers. Anything else is passed through
  with the collected modifiers re-applied.

  Returns:
    DSL text such as ``v[3]``, ``s[0:1]``, ``-abs(v[1]).h`` or
    ``SrcMod(5, neg=True, abs_=False)``.
  """
  op = op.strip()
  # A leading '-' is a negate modifier unless it begins a numeric literal (decimal or 0x hex).
  neg = op.startswith('-') and not (op[1:2].isdigit() or (len(op) > 2 and op[1] == '0' and op[2] in 'xX'))
  if neg: op = op[1:]
  if op.startswith('neg(') and op.endswith(')'): neg = True; op = op[4:-1]
  abs_ = (op.startswith('|') and op.endswith('|')) or (op.startswith('abs(') and op.endswith(')'))
  if abs_: op = op[1:-1] if op.startswith('|') else op[4:-1]
  hi = ".h" if op.endswith('.h') else ".l" if op.endswith('.l') else ""
  if hi: op = op[:-2]
  lo = op.lower()
  def wrap(b): return f"{'-' if neg else ''}abs({b}){hi}" if abs_ else f"-{b}{hi}" if neg else f"{b}{hi}"
  if lo in SPEC_DSL: return wrap(SPEC_DSL[lo])
  if op in FLOATS: return wrap(op)
  rp = {'s': 's', 'v': 'v', 't': 'ttmp', 'ttmp': 'ttmp'}
  # The prefix group lists exactly the keys of `rp`. The earlier pattern `[svt](?:tmp)?` also
  # accepted 'stmp' / 'vtmp', which then raised KeyError in the lookup below; such text now
  # falls through to the generic pass-through at the end.
  if m := re.match(r'^(ttmp|[svt])\[(\d+):(\d+)\]$', lo): return wrap(f"{rp[m.group(1)]}[{m.group(2)}:{m.group(3)}]")
  if m := re.match(r'^(ttmp|[svt])(\d+)$', lo): return wrap(f"{rp[m.group(1)]}[{m.group(2)}]")
  # Integers are emitted verbatim unless a modifier was seen, in which case SrcMod carries the flags.
  if re.match(r'^-?\d+$|^-?0x[0-9a-fA-F]+$', op): return f"SrcMod({op}, neg={neg}, abs_={abs_})" if neg or abs_ else op
  return wrap(op)
def _parse_ops(s: str) -> list[str]:
ops, cur, depth, pipe = [], "", 0, False
for c in s:
if c in '[(': depth += 1
elif c in '])': depth -= 1
elif c == '|': pipe = not pipe
if c == ',' and depth == 0 and not pipe: ops.append(cur.strip()); cur = ""
else: cur += c
if cur.strip(): ops.append(cur.strip())
return ops
def _extract(text: str, pat: str, flags=re.I):
if m := re.search(pat, text, flags): return m, text[:m.start()] + ' ' + text[m.end():]
return None, text
# Instruction aliases: LLVM uses different names for some instructions
# Maps an accepted mnemonic spelling (key) to the mnemonic that is actually emitted (value).
# _apply_alias looks keys up in lowercase, with and without an _e32/_e64 suffix.
_ALIASES = {
# VALU compare / convert / bit-scan spellings
'v_cmp_tru_f16': 'v_cmp_t_f16', 'v_cmp_tru_f32': 'v_cmp_t_f32', 'v_cmp_tru_f64': 'v_cmp_t_f64',
'v_cmpx_tru_f16': 'v_cmpx_t_f16', 'v_cmpx_tru_f32': 'v_cmpx_t_f32', 'v_cmpx_tru_f64': 'v_cmpx_t_f64',
'v_cvt_flr_i32_f32': 'v_cvt_floor_i32_f32', 'v_cvt_rpi_i32_f32': 'v_cvt_nearest_i32_f32',
'v_ffbh_i32': 'v_cls_i32', 'v_ffbh_u32': 'v_clz_i32_u32', 'v_ffbl_b32': 'v_ctz_i32_b32',
'v_cvt_pkrtz_f16_f32': 'v_cvt_pk_rtz_f16_f32', 'v_fmac_legacy_f32': 'v_fmac_dx9_zero_f32', 'v_mul_legacy_f32': 'v_mul_dx9_zero_f32',
# Scalar loads: dword-count spellings -> bit-width spellings
's_load_dword': 's_load_b32', 's_load_dwordx2': 's_load_b64', 's_load_dwordx4': 's_load_b128',
's_load_dwordx8': 's_load_b256', 's_load_dwordx16': 's_load_b512',
's_buffer_load_dword': 's_buffer_load_b32', 's_buffer_load_dwordx2': 's_buffer_load_b64',
's_buffer_load_dwordx4': 's_buffer_load_b128', 's_buffer_load_dwordx8': 's_buffer_load_b256',
's_buffer_load_dwordx16': 's_buffer_load_b512',
# Miscellaneous VALU spellings
'v_cvt_pknorm_i16_f16': 'v_cvt_pk_norm_i16_f16', 'v_cvt_pknorm_u16_f16': 'v_cvt_pk_norm_u16_f16',
'v_add3_nc_u32': 'v_add3_u32', 'v_xor_add_u32': 'v_xad_u32',
'v_interp_p2_new_f32': 'v_interp_p2_f32',
# Scalar bit-scan and and-not / or-not spellings
's_ff1_i32_b32': 's_ctz_i32_b32', 's_ff1_i32_b64': 's_ctz_i32_b64',
's_flbit_i32_b32': 's_clz_i32_u32', 's_flbit_i32_b64': 's_clz_i32_u64', 's_flbit_i32': 's_cls_i32', 's_flbit_i32_i64': 's_cls_i32_i64',
's_andn1_saveexec_b32': 's_and_not0_saveexec_b32', 's_andn1_saveexec_b64': 's_and_not0_saveexec_b64',
's_andn1_wrexec_b32': 's_and_not0_wrexec_b32', 's_andn1_wrexec_b64': 's_and_not0_wrexec_b64',
's_andn2_saveexec_b32': 's_and_not1_saveexec_b32', 's_andn2_saveexec_b64': 's_and_not1_saveexec_b64',
's_andn2_wrexec_b32': 's_and_not1_wrexec_b32', 's_andn2_wrexec_b64': 's_and_not1_wrexec_b64',
's_orn1_saveexec_b32': 's_or_not0_saveexec_b32', 's_orn1_saveexec_b64': 's_or_not0_saveexec_b64',
's_orn2_saveexec_b32': 's_or_not1_saveexec_b32', 's_orn2_saveexec_b64': 's_or_not1_saveexec_b64',
's_andn2_b32': 's_and_not1_b32', 's_andn2_b64': 's_and_not1_b64',
's_orn2_b32': 's_or_not1_b32', 's_orn2_b64': 's_or_not1_b64',
'v_dot2c_f32_f16': 'v_dot2acc_f32_f16',
'v_fma_legacy_f32': 'v_fma_dx9_zero_f32',
# DS: read/write spellings -> load/store spellings
'ds_read_b32': 'ds_load_b32', 'ds_read_b64': 'ds_load_b64', 'ds_read_b96': 'ds_load_b96', 'ds_read_b128': 'ds_load_b128',
'ds_read_i8': 'ds_load_i8', 'ds_read_u8': 'ds_load_u8', 'ds_read_i16': 'ds_load_i16', 'ds_read_u16': 'ds_load_u16',
'ds_read_i8_d16': 'ds_load_i8_d16', 'ds_read_u8_d16': 'ds_load_u8_d16', 'ds_read_i8_d16_hi': 'ds_load_i8_d16_hi', 'ds_read_u8_d16_hi': 'ds_load_u8_d16_hi',
'ds_read_u16_d16': 'ds_load_u16_d16', 'ds_read_u16_d16_hi': 'ds_load_u16_d16_hi',
'ds_read2_b32': 'ds_load_2addr_b32', 'ds_read2_b64': 'ds_load_2addr_b64',
'ds_read2st64_b32': 'ds_load_2addr_stride64_b32', 'ds_read2st64_b64': 'ds_load_2addr_stride64_b64',
'ds_read_addtid_b32': 'ds_load_addtid_b32', 'ds_write_addtid_b32': 'ds_store_addtid_b32',
'ds_write_b32': 'ds_store_b32', 'ds_write_b64': 'ds_store_b64', 'ds_write_b96': 'ds_store_b96', 'ds_write_b128': 'ds_store_b128',
'ds_write_b8': 'ds_store_b8', 'ds_write_b16': 'ds_store_b16',
'ds_write_b8_d16_hi': 'ds_store_b8_d16_hi', 'ds_write_b16_d16_hi': 'ds_store_b16_d16_hi',
'ds_write2_b32': 'ds_store_2addr_b32', 'ds_write2_b64': 'ds_store_2addr_b64',
'ds_write2st64_b32': 'ds_store_2addr_stride64_b32', 'ds_write2st64_b64': 'ds_store_2addr_stride64_b64',
'ds_wrxchg_rtn_b32': 'ds_storexchg_rtn_b32', 'ds_wrxchg_rtn_b64': 'ds_storexchg_rtn_b64',
'ds_wrxchg2_rtn_b32': 'ds_storexchg_2addr_rtn_b32', 'ds_wrxchg2_rtn_b64': 'ds_storexchg_2addr_rtn_b64',
'ds_wrxchg2st64_rtn_b32': 'ds_storexchg_2addr_stride64_rtn_b32', 'ds_wrxchg2st64_rtn_b64': 'ds_storexchg_2addr_stride64_rtn_b64',
}
def _apply_alias(text: str) -> str:
  """Rewrite the leading mnemonic of ``text`` through ``_ALIASES``.

  The mnemonic is the first whitespace-separated token when ``text`` contains a space;
  otherwise it is the whole text with trailing underscores removed. It is matched in
  lowercase, first as-is and then with an ``_e32`` or ``_e64`` suffix removed. On a hit the
  matched prefix is replaced and the rest of ``text`` (including any suffix) is kept.
  Text without a matching alias is returned unchanged.
  """
  if ' ' in text:
    mnemonic = text.split()[0].lower()
  else:
    mnemonic = text.lower().rstrip('_')
  candidates = (mnemonic, mnemonic.removesuffix('_e32'), mnemonic.removesuffix('_e64'))
  for cand in candidates:
    replacement = _ALIASES.get(cand)
    if replacement is not None:
      return replacement + text[len(cand):]
  return text
def _has(op: str, *subs) -> bool: return any(s in op for s in subs)
def get_dsl(text: str, arch: str = "rdna3") -> str:
"""Translate one assembly instruction line into DSL source text.

The result is a Python call expression held in a string (for example
``s_waitcnt(simm16=...)``); ``asm`` evaluates it to build the instruction.

Args:
text: assembly text; it is stripped and passed through ``_apply_alias`` first.
arch: "rdna3" (default) or "rdna4"; several branches pick field names and modifier
handling based on it.

Returns:
The DSL expression string.

Raises:
ValueError: when nothing is left after modifier extraction ("empty instruction"), or
for ``s_setreg_imm32_b32``, which is reported as unsupported.
"""
text, kw = _apply_alias(text.strip()), []
# Extract modifiers
# Output modifier: only the first of mul:2 / mul:4 / div:2 that matches is used (omod=1/2/3).
for pat, val in [(r'\s+mul:2(?:\s|$)', 1), (r'\s+mul:4(?:\s|$)', 2), (r'\s+div:2(?:\s|$)', 3)]:
if (m := _extract(text, pat))[0]: kw.append(f'omod={val}'); text = m[1]; break
clamp_found = False
if (m := _extract(text, r'\s+clamp(?:\s|$)'))[0]: clamp_found = True; text = m[1]
opsel, m, text = None, *_extract(text, r'\s+op_sel:\[([^\]]+)\]')
if m:
bits, mn = [int(x.strip()) for x in m.group(1).split(',')], text.split()[0].lower()
is3p = mn.startswith(('v_pk_', 'v_wmma_', 'v_dot', 'v_fma_mix'))
# Three op_sel bits: for the mnemonic families above they occupy bits 0-2; otherwise the
# third bit is placed at bit 3. Any other count is packed one bit per position.
opsel = (bits[0] | (bits[1] << 1) | (bits[2] << 2)) if len(bits) == 3 and is3p else \
(bits[0] | (bits[1] << 1) | (bits[2] << 3)) if len(bits) == 3 else sum(b << i for i, b in enumerate(bits))
opsel_hi_val, m, text = None, *_extract(text, r'\s+op_sel_hi:\[([^\]]+)\]')
if m: opsel_hi_val = [int(x.strip()) for x in m.group(1).split(',')]
# Remaining named modifiers are removed from the text one at a time; each leaves None when absent.
m, text = _extract(text, r'\s+wait_exp:(\d+)'); waitexp = m.group(1) if m else None
m, text = _extract(text, r'\s+offset:(0x[0-9a-fA-F]+|-?\d+)'); off_val = m.group(1) if m else None
m, text = _extract(text, r'\s+dlc(?:\s|$)'); dlc = 1 if m else None
m, text = _extract(text, r'\s+glc(?:\s|$)'); glc = 1 if m else None
m, text = _extract(text, r'\s+slc(?:\s|$)'); slc = 1 if m else None
m, text = _extract(text, r'\s+tfe(?:\s|$)'); tfe = 1 if m else None
m, text = _extract(text, r'\s+offen(?:\s|$)'); offen = 1 if m else None
m, text = _extract(text, r'\s+idxen(?:\s|$)'); idxen = 1 if m else None
# format:[...] takes precedence over a numeric format:N when both are present.
m, text = _extract(text, r'\s+format:\[([^\]]+)\]'); fmt_val = m.group(1) if m else None
m, text = _extract(text, r'\s+format:(\d+)'); fmt_val = m.group(1) if m and not fmt_val else fmt_val
m, text = _extract(text, r'\s+neg_lo:\[([^\]]+)\]'); neg_lo = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
m, text = _extract(text, r'\s+byte_sel:(\d+)'); byte_sel = int(m.group(1)) if m else None
m, text = _extract(text, r'\s+offset0:(\d+)'); ds_off0 = int(m.group(1)) if m else None
m, text = _extract(text, r'\s+offset1:(\d+)'); ds_off1 = int(m.group(1)) if m else None
m, text = _extract(text, r'\s+index_key:(\d+)'); index_key = int(m.group(1)) if m else None
if waitexp: kw.append(f'waitexp={waitexp}')
# byte_sel is folded into opsel starting at bit 2.
if byte_sel is not None:
if opsel is None: opsel = 0
opsel |= (byte_sel << 2)
if ds_off0 is not None: kw.append(f'offset0={ds_off0}')
if ds_off1 is not None: kw.append(f'offset1={ds_off1}')
# index_key is emitted through the opsel keyword.
if index_key is not None: kw.append(f'opsel={index_key}')
parts = text.replace(',', ' ').split()
if not parts: raise ValueError("empty instruction")
mn, op_str = parts[0].lower(), text[len(parts[0]):].strip()
# ops holds the raw operand strings, args the same operands converted to DSL text.
ops, args = _parse_ops(op_str), [_op2dsl(o) for o in _parse_ops(op_str)]
# s_waitcnt
if mn == 's_waitcnt':
vm, exp, lgkm = 0x3f, 0x7, 0x3f
for p in op_str.replace(',', ' ').split():
if m := re.match(r'vmcnt\((\d+)\)', p): vm = int(m.group(1))
elif m := re.match(r'expcnt\((\d+)\)', p): exp = int(m.group(1))
elif m := re.match(r'lgkmcnt\((\d+)\)', p): lgkm = int(m.group(1))
# A bare number is used directly as simm16 and ends the parse immediately.
elif re.match(r'^0x[0-9a-f]+$|^\d+$', p): return f"s_waitcnt(simm16={int(p, 0)})"
return f"s_waitcnt(simm16={waitcnt(vm, exp, lgkm)})"
# VOPD
# '::' separates the X half from the Y half of the instruction text.
if '::' in text:
xp, yp = text.split('::')
xps, yps = xp.strip().replace(',', ' ').split(), yp.strip().replace(',', ' ').split()
xo, yo = [_op2dsl(p) for p in xps[1:]], [_op2dsl(p) for p in yps[1:]]
vdx, sx0, vsx1 = xo[0], xo[1] if len(xo) > 1 else '0', xo[2] if len(xo) > 2 else 'v[0]'
vdy, sy0, vsy1 = yo[0], yo[1] if len(yo) > 1 else '0', yo[2] if len(yo) > 2 else 'v[0]'
# fmaak carries its literal as the 4th operand; fmamk carries it as the 3rd, with the
# register source following it.
lit = xo[3] if 'fmaak' in xps[0].lower() and len(xo) > 3 else yo[3] if 'fmaak' in yps[0].lower() and len(yo) > 3 else None
if 'fmamk' in xps[0].lower() and len(xo) > 3: lit, vsx1 = xo[2], xo[3]
elif 'fmamk' in yps[0].lower() and len(yo) > 3: lit, vsy1 = yo[2], yo[3]
return f"VOPD(VOPDOp.{xps[0].upper()}, VOPDOp.{yps[0].upper()}, vdstx={vdx}, vdsty={vdy}, srcx0={sx0}, vsrcx1={vsx1}, srcy0={sy0}, vsrcy1={vsy1}{f', literal={lit}' if lit else ''})"
# Special instructions
if mn == 's_setreg_imm32_b32': raise ValueError(f"unsupported: {mn}")
sop1_no_dest = ('s_alloc_vgpr', 's_barrier_init', 's_barrier_join', 's_barrier_signal', 's_barrier_signal_isfirst', 's_sleep_var')
if mn in sop1_no_dest:
return f"{mn}(sdst=RawImm(128), ssrc0={args[0]})"
if mn in ('s_setpc_b64', 's_rfe_b64'): return f"{mn}(ssrc0={args[0]})"
if mn in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'): return f"{mn}(sdst={args[0]}, ssrc0=RawImm({args[1].strip()}))"
if mn == 's_version': return f"{mn}(simm16={args[0]})"
if mn == 's_setreg_b32': return f"{mn}(simm16={args[0]}, sdst={args[1]})"
# Export instructions (RDNA4 VEXPORT)
if mn == 'export':
target_map = {**{f'mrt{i}': i for i in range(8)}, 'mrtz': 8, **{f'pos{i}': 12+i for i in range(4)}}
m, exp_str = _extract(op_str, r'\s+done(?:\s|$)')
done_val = 1 if m else 0
exp_parts = exp_str.replace(',', ' ').split()
target_name = exp_parts[0].lower().strip()
# Unknown target names fall back to 0.
target = target_map.get(target_name, 0)
vsrcs, en = [], 0
# Each of the up to four sources sets its bit in `en` unless it is written as 'off'.
for i, o in enumerate(exp_parts[1:5]):
o = o.strip().lower()
if o == 'off': vsrcs.append('v[0]')
else: vsrcs.append(_op2dsl(o)); en |= (1 << i)
return f"VEXPORT(target={target}, en={en}, vsrc0={vsrcs[0]}, vsrc1={vsrcs[1]}, vsrc2={vsrcs[2]}, vsrc3={vsrcs[3]}, done={done_val})"
# SMEM
if mn in SMEM_OPS:
gs, ds = ", glc=1" if glc else "", ", dlc=1" if dlc else ""
off_field = "ioffset" if arch == "rdna4" else "offset"
th_s, scope_s, smem_str = "", "", op_str
if arch == "rdna4":
m, smem_str = _extract(op_str, r'\s+th:TH_(\w+)')
th_val = {'LOAD_RT': 0, 'LOAD_NT': 1, 'LOAD_HT': 2, 'LOAD_LU': 3, 'STORE_RT': 0, 'STORE_NT': 1, 'STORE_HT': 2, 'STORE_LU': 3}.get(m.group(1), 0) if m else None
m, smem_str = _extract(smem_str, r'\s+scope:SCOPE_(\w+)')
scope_val = {'CU': 0, 'SE': 1, 'DEV': 2, 'SYS': 3}.get(m.group(1), 0) if m else None
# No symbolic scope found: also accept a numeric scope value.
if scope_val is None:
m, smem_str = _extract(smem_str, r'\s+scope:(0?x?[0-9a-fA-F]+)')
scope_val = int(m.group(1), 0) if m else None
th_s = f", th={th_val}" if th_val else ""
scope_s = f", scope={scope_val}" if scope_val else ""
smem_ops = _parse_ops(smem_str)
smem_args = [_op2dsl(o) for o in smem_ops]
# Third operand numeric: it is the immediate offset and soffset is RawImm(124).
if len(smem_ops) >= 3 and re.match(r'^-?[0-9]|^-?0x', smem_ops[2].strip().lower()):
return f"{mn}(sdata={smem_args[0]}, sbase={smem_args[1]}, {off_field}={smem_ops[2].strip()}, soffset=RawImm(124){gs}{ds}{th_s}{scope_s})"
if off_val and len(smem_ops) >= 3: return f"{mn}(sdata={smem_args[0]}, sbase={smem_args[1]}, {off_field}={off_val}, soffset={smem_args[2]}{gs}{ds}{th_s}{scope_s})"
if len(smem_ops) >= 3: return f"{mn}(sdata={smem_args[0]}, sbase={smem_args[1]}, soffset={smem_args[2]}{gs}{ds}{th_s}{scope_s})"
# Buffer (MUBUF/MTBUF/VBUFFER) instructions
if mn.startswith(('buffer_', 'tbuffer_')):
is_tbuf = mn.startswith('tbuffer_')
fmt_num = None
if fmt_val is not None:
if fmt_val.isdigit(): fmt_num = int(fmt_val)
else: fmt_num = BUF_FMT.get(fmt_val.replace(' ', '')) or _parse_buf_fmt_combo(fmt_val)
if mn in ('buffer_gl0_inv', 'buffer_gl1_inv', 'buffer_wbl2', 'buffer_inv'): return f"{mn}()"
if arch == "rdna4":
m, buf_text = _extract(op_str, r'\s+th:TH_(\w+)')
th_val = {'LOAD_RT': 0, 'LOAD_NT': 1, 'LOAD_HT': 2, 'LOAD_BYPASS': 3, 'LOAD_LU': 4, 'LOAD_RT_NT': 5, 'LOAD_NT_HT': 6, 'LOAD_RT_WB': 7,
'STORE_RT': 0, 'STORE_NT': 1, 'STORE_HT': 2, 'STORE_BYPASS': 3, 'STORE_LU': 4, 'STORE_RT_NT': 5, 'STORE_NT_HT': 6,
'ATOMIC_RT': 0, 'ATOMIC_NT': 1, 'ATOMIC_RETURN': 1, 'ATOMIC_RT_RETURN': 1, 'ATOMIC_NT_RETURN': 3, 'ATOMIC_CASCADE_RT': 6, 'ATOMIC_CASCADE_NT': 6}.get(m.group(1), 0) if m else 0
m, buf_text = _extract(buf_text, r'\s+scope:SCOPE_(\w+)')
scope_val = {'CU': 0, 'SE': 1, 'DEV': 2, 'SYS': 3}.get(m.group(1), 0) if m else 0
buf_ops = _parse_ops(buf_text)
buf_args = [_op2dsl(o) for o in buf_ops]
vbuf_mods = "".join([f", ioffset={off_val}" if off_val else "", ", offen=1" if offen else "", ", idxen=1" if idxen else "",
f", th={th_val}" if th_val else "", f", scope={scope_val}" if scope_val else "",
", tfe=1" if tfe else ""])
# NOTE(review): the tbuffer-without-format branch and the non-tbuffer branch both emit format=1.
if is_tbuf and fmt_num is not None: vbuf_mods = f", format={fmt_num}" + vbuf_mods
elif is_tbuf: vbuf_mods = ", format=1" + vbuf_mods
else: vbuf_mods = ", format=1" + vbuf_mods
vaddr_idx = 1
if len(buf_ops) > vaddr_idx and buf_ops[vaddr_idx].strip().lower() == 'off': vaddr_val = "v[0]"
else: vaddr_val = buf_args[vaddr_idx] if len(buf_args) > vaddr_idx else "v[0]"
rsrc_idx, soff_idx = (2, 3) if len(buf_ops) > 1 else (1, 2)
rsrc_raw = buf_ops[rsrc_idx].strip() if len(buf_ops) > rsrc_idx else "s[0:3]"
# rsrc is emitted as a bare register number; ttmp registers are offset by 108.
if m := re.match(r's\[(\d+):\d+\]', rsrc_raw.lower()): rsrc_val = m.group(1)
elif m := re.match(r's(\d+)', rsrc_raw.lower()): rsrc_val = m.group(1)
elif m := re.match(r'ttmp\[(\d+):\d+\]', rsrc_raw.lower()): rsrc_val = str(108 + int(m.group(1)))
elif m := re.match(r'ttmp(\d+)', rsrc_raw.lower()): rsrc_val = str(108 + int(m.group(1)))
else: rsrc_val = "0"
soff_raw = buf_ops[soff_idx].strip() if len(buf_ops) > soff_idx else "0"
soff_lower = soff_raw.lower()
if soff_lower == 'm0': soff_val = "RawImm(125)"
elif soff_lower in ('null', 'off'): soff_val = "RawImm(124)"
elif m := re.match(r's(\d+)', soff_lower): soff_val = f"RawImm({m.group(1)})"
else: soff_val = f"RawImm({soff_raw})"
return f"{mn}(vdata={buf_args[0]}, vaddr={vaddr_val}, rsrc={rsrc_val}, soffset={soff_val}{vbuf_mods})"
buf_mods = "".join([f", offset={off_val}" if off_val else "", ", glc=1" if glc else "", ", dlc=1" if dlc else "",
", slc=1" if slc else "", ", tfe=1" if tfe else "", ", offen=1" if offen else "", ", idxen=1" if idxen else ""])
if is_tbuf and fmt_num is not None: buf_mods = f", format={fmt_num}" + buf_mods
vaddr_idx = 1
if len(ops) > vaddr_idx and ops[vaddr_idx].strip().lower() == 'off': vaddr_val = "v[0]"
else: vaddr_val = args[vaddr_idx] if len(args) > vaddr_idx else "v[0]"
srsrc_idx, soff_idx = (2, 3) if len(ops) > 1 else (1, 2)
srsrc_val = args[srsrc_idx] if len(args) > srsrc_idx else "s[0:3]"
soff_val = args[soff_idx] if len(args) > soff_idx else "0"
return f"{mn}(vdata={args[0]}, vaddr={vaddr_val}, srsrc={srsrc_val}, soffset={soff_val}{buf_mods})"
# FLAT/GLOBAL/SCRATCH load/store/atomic
# A saddr written as OFF/NULL (after _op2dsl conversion) is emitted as RawImm(124).
def _saddr(a): return 'RawImm(124)' if a in ('OFF', 'NULL') else a
flat_mods = f"{f', offset={off_val}' if off_val else ''}{', glc=1' if glc else ''}{', slc=1' if slc else ''}{', dlc=1' if dlc else ''}"
for pre, flds in [('flat_load','vdst,addr,saddr'), ('global_load','vdst,addr,saddr'), ('scratch_load','vdst,addr,saddr'),
('flat_store','addr,data,saddr'), ('global_store','addr,data,saddr'), ('scratch_store','addr,data,saddr')]:
if mn.startswith(pre) and len(args) >= 2:
f0, f1, f2 = flds.split(',')
return f"{mn}({f0}={args[0]}, {f1}={args[1]}{f', {f2}={_saddr(args[2])}' if len(args) >= 3 else ', saddr=RawImm(124)'}{flat_mods})"
for pre in ('flat_atomic', 'global_atomic', 'scratch_atomic'):
if mn.startswith(pre):
# With glc and at least three operands the first operand is treated as vdst.
if glc and len(args) >= 3: return f"{mn}(vdst={args[0]}, addr={args[1]}, data={args[2]}{f', saddr={_saddr(args[3])}' if len(args) >= 4 else ', saddr=RawImm(124)'}{flat_mods})"
if len(args) >= 2: return f"{mn}(addr={args[0]}, data={args[1]}{f', saddr={_saddr(args[2])}' if len(args) >= 3 else ', saddr=RawImm(124)'}{flat_mods})"
# DS instructions
if mn.startswith('ds_'):
# Explicit offset0/offset1 win; otherwise a single offset is split into its low byte
# (offset0) and next byte (offset1).
if ds_off0 is not None or ds_off1 is not None:
off0, off1 = str(ds_off0 or 0), str(ds_off1 or 0)
elif off_val:
off0, off1 = str(int(off_val, 0) & 0xff), str((int(off_val, 0) >> 8) & 0xff)
else:
off0, off1 = "0", "0"
# gds is recognised only as the last whitespace-separated token.
gds_s = ", gds=1" if 'gds' in text.lower().split()[-1:] else ""
off_kw = f", offset0={off0}, offset1={off1}{gds_s}"
if mn == 'ds_nop' or mn in ('ds_gws_sema_v', 'ds_gws_sema_p', 'ds_gws_sema_release_all'): return f"{mn}({off_kw.lstrip(', ')})"
if 'gws_' in mn: return f"{mn}(addr={args[0]}{off_kw})"
if 'consume' in mn or 'append' in mn: return f"{mn}(vdst={args[0]}{off_kw})"
if 'gs_reg' in mn: return f"{mn}(vdst={args[0]}, data0={args[1]}{off_kw})"
if '2addr' in mn:
if 'load' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
if 'store' in mn and 'xchg' not in mn: return f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})"
if 'load' in mn: return f"{mn}(vdst={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
if 'store' in mn and not _has(mn, 'cmp', 'xchg'):
return f"{mn}(data0={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})"
if 'swizzle' in mn or 'ordered_count' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
if 'permute' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})"
if 'bvh' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})"
if 'condxchg' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})"
if _has(mn, 'cmpstore', 'mskor', 'wrap'):
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
# Remaining DS ops: the _rtn forms take a vdst, the others do not.
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})"
# v_fmaak/v_fmamk literal extraction
lit_s = ""
mn_base = mn.replace('_e32', '').replace('_e64', '')
if mn_base in ('v_fmaak_f32', 'v_fmaak_f16') and len(args) == 4: lit_s, args = f", literal={args[3].strip()}", args[:3]
elif mn_base in ('v_fmamk_f32', 'v_fmamk_f16') and len(args) == 4: lit_s, args = f", literal={args[2].strip()}", [args[0], args[1], args[3]]
elif mn_base in ('s_fmaak_f32',) and len(args) == 4: lit_s, args = f", literal={args[3].strip()}", args[:3]
elif mn_base in ('s_fmamk_f32',) and len(args) == 4: lit_s, args = f", literal={args[2].strip()}", [args[0], args[1], args[3]]
# v_cndmask_b32 with an explicit trailing vcc operand: drop it and use the _e32 form.
elif mn in ('v_cndmask_b32', 'v_cndmask_b32_e32') and len(args) == 4 and ops[3].strip().lower() in ('vcc_lo', 'vcc'):
mn, args = 'v_cndmask_b32_e32', args[:3]
_SGPR_NAMES = {'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'null': 124, 'm0': 125, 'exec_lo': 126, 'exec_hi': 127}
vcc_ops = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'}
if mn.replace('_e32', '') in vcc_ops and len(args) >= 5:
carry_in = ops[4].strip().lower() if len(ops) > 4 else 'vcc_lo'
carry_out = ops[1].strip().lower() if len(ops) > 1 else 'vcc_lo'
# Carry in and out both vcc: emit the _e32 form without the carry operands. Otherwise
# emit the base mnemonic with explicit raw sdst / src2 encodings.
if carry_in in ('vcc_lo', 'vcc') and carry_out in ('vcc_lo', 'vcc'):
mn, args = mn.replace('_e32', '') + '_e32', [args[0], args[2], args[3]]
else:
mn_base = mn.replace('_e32', '').replace('_e64', '')
sdst = _SGPR_NAMES.get(carry_out, 124) if carry_out in _SGPR_NAMES else (int(carry_out[1:]) if carry_out.startswith('s') and carry_out[1:].isdigit() else 124)
src2 = _SGPR_NAMES.get(carry_in, 0) if carry_in in _SGPR_NAMES else (int(carry_in[1:]) if carry_in.startswith('s') and carry_in[1:].isdigit() else 0)
return f"{mn_base}(vdst={args[0]}, sdst=RawImm({sdst}), src0={args[2]}, src1={args[3]}, src2=RawImm({src2}))"
if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '')
# Non-_e64 v_cmp* written with a leading vcc destination: drop that operand.
if mn.startswith('v_cmp') and not mn.endswith('_e64') and len(args) >= 3 and ops[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): args = args[1:]
# Two-operand _e64 cmpx: RawImm(126) is prepended as the first argument.
if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2: args = ['RawImm(126)'] + args
# For these mnemonics the first operand is re-emitted as a RawImm register number
# (ttmp registers are offset by 108).
if ((mn.startswith('v_cmp') and 'cmpx' not in mn and mn.endswith('_e64')) or mn.startswith('v_s_') or mn in ('v_readlane_b32', 'v_readfirstlane_b32')) and len(args) >= 1:
dst = ops[0].strip().lower()
if dst.startswith('s') and dst[1:].isdigit(): args[0] = f'RawImm({int(dst[1:])})'
elif dst.startswith('s[') and ':' in dst: args[0] = f'RawImm({int(dst[2:].split(":")[0])})'
elif dst.startswith('ttmp') and dst[4:].isdigit(): args[0] = f'RawImm({108 + int(dst[4:])})'
elif dst.startswith('ttmp[') and ':' in dst: args[0] = f'RawImm({108 + int(dst[5:].split(":")[0])})'
elif dst in _SGPR_NAMES: args[0] = f'RawImm({_SGPR_NAMES[dst]})'
fn = mn.replace('.', '_')
# When op_sel was given explicitly, trailing .h/.l selectors are stripped from the arguments.
if opsel is not None: args = [re.sub(r'\.[hl]$', '', a) for a in args]
# v_fma_mix*: extract inline neg/abs modifiers
if 'fma_mix' in mn and neg_lo is None and neg_hi is None:
inline_neg, inline_abs, clean_args = 0, 0, [args[0]]
for i, op in enumerate(ops[1:4]):
op = op.strip()
neg = op.startswith('-') and not (op[1:2].isdigit() or (len(op) > 2 and op[1] == '0' and op[2] in 'xX'))
if neg: op = op[1:]
abs_ = op.startswith('|') and op.endswith('|')
if abs_: op = op[1:-1]
if neg: inline_neg |= (1 << i)
if abs_: inline_abs |= (1 << i)
clean_args.append(_op2dsl(op))
args = clean_args + args[4:]
# The inline abs bits are carried in neg_hi, the inline neg bits in neg_lo.
if inline_neg: neg_lo = inline_neg
if inline_abs: neg_hi = inline_abs
all_kw = list(kw)
if lit_s: all_kw.append(lit_s.lstrip(', '))
if opsel is not None: all_kw.append(f'opsel={opsel}')
if neg_lo is not None: all_kw.append(f'neg={neg_lo}')
if neg_hi is not None: all_kw.append(f'neg_hi={neg_hi}')
if 'bvh' in mn and 'intersect_ray' in mn: all_kw.extend(['dmask=15', 'unrm=1', 'r128=1'])
vop3p_ops = {'v_pk_', 'v_dot2', 'v_dot4', 'v_dot8', 'v_wmma', 'v_swmmac'}
is_vop3p = any(mn.startswith(p) for p in vop3p_ops)
is_fma_mix = 'fma_mix' in mn
# Explicit op_sel_hi: first two bits go to opsel_hi, the third to opsel_hi2 (defaulting to
# 0 for fma_mix and 1 otherwise). Without it, the VOP3P-style prefixes get opsel_hi=3, opsel_hi2=1.
if opsel_hi_val is not None:
opsel_hi_enc = opsel_hi_val[0] | (opsel_hi_val[1] << 1) if len(opsel_hi_val) >= 2 else opsel_hi_val[0]
opsel_hi2_enc = opsel_hi_val[2] if len(opsel_hi_val) >= 3 else (0 if is_fma_mix else 1)
all_kw.extend([f'opsel_hi={opsel_hi_enc}', f'opsel_hi2={opsel_hi2_enc}'])
elif is_vop3p and not is_fma_mix:
all_kw.extend(['opsel_hi=3', 'opsel_hi2=1'])
# The clamp keyword is named cm on rdna4 and clmp otherwise.
if clamp_found:
if arch == 'rdna4': all_kw.append('cm=1')
else: all_kw.append('clmp=1')
# Join positional arguments and keywords, omitting the separator when either side is empty.
a_str, kw_str = ', '.join(args), ', '.join(all_kw)
return f"{fn}({a_str}, {kw_str})" if kw_str and a_str else f"{fn}({kw_str})" if kw_str else f"{fn}({a_str})"
def _hwreg(id_, offset=0, size=32): return id_ | (offset << 6) | ((size - 1) << 11)
def _sendmsg(id_, op=0, stream=0): return id_ | (op << 4) | (stream << 8)
_HWREG_NAMES = {'HW_REG_MODE': 1, 'HW_REG_STATUS': 2, 'HW_REG_TRAPSTS': 3, 'HW_REG_HW_ID': 4, 'HW_REG_GPR_ALLOC': 5,
'HW_REG_LDS_ALLOC': 6, 'HW_REG_IB_STS': 7, 'HW_REG_PC_LO': 8, 'HW_REG_PC_HI': 9, 'HW_REG_INST_DW0': 10, 'HW_REG_INST_DW1': 11,
'HW_REG_IB_DBG0': 12, 'HW_REG_IB_DBG1': 13, 'HW_REG_FLUSH_IB': 14, 'HW_REG_SH_MEM_BASES': 15, 'HW_REG_SQ_SHADER_TBA_LO': 16,
'HW_REG_SQ_SHADER_TBA_HI': 17, 'HW_REG_SQ_SHADER_TMA_LO': 18, 'HW_REG_SQ_SHADER_TMA_HI': 19, 'HW_REG_FLAT_SCR_LO': 20,
'HW_REG_FLAT_SCR_HI': 21, 'HW_REG_XNACK_MASK': 22, 'HW_REG_HW_ID1': 23, 'HW_REG_HW_ID2': 24, 'HW_REG_POPS_PACKER': 25,
'HW_REG_PERF_SNAPSHOT_DATA': 26, 'HW_REG_PERF_SNAPSHOT_PC_LO': 27, 'HW_REG_PERF_SNAPSHOT_PC_HI': 28, 'HW_REG_SHADER_CYCLES': 29,
'HW_REG_SHADER_CYCLES_HI': 30, 'HW_REG_WAVE_MODE': 31, 'HW_REG_WAVE_SCRATCH_BASE': 32}
# Inverse of HWREG_RDNA4 (imported from the disassembler): its values become keys here.
_HWREG_NAMES_RDNA4 = {v: k for k, v in HWREG_RDNA4.items()}
# Message names -> ids used by sendmsg(). Some ids are shared by two names:
# MSG_GS / MSG_HS_TESSFACTOR are both 2, MSG_GET_DOORBELL / MSG_DEALLOC_VGPRS are both 10.
_SENDMSG_NAMES = {'MSG_INTERRUPT': 1, 'MSG_GS': 2, 'MSG_GS_DONE': 3, 'MSG_SAVEWAVE': 4, 'MSG_STALL_WAVE_GEN': 5,
'MSG_HALT_WAVES': 6, 'MSG_ORDERED_PS_DONE': 7, 'MSG_EARLY_PRIM_DEALLOC': 8, 'MSG_GS_ALLOC_REQ': 9, 'MSG_GET_DOORBELL': 10,
'MSG_GET_DDID': 11, 'MSG_HS_TESSFACTOR': 2, 'MSG_DEALLOC_VGPRS': 10, 'MSG_RTN_GET_DOORBELL': 128, 'MSG_RTN_GET_DDID': 129,
'MSG_RTN_GET_TMA': 130, 'MSG_RTN_GET_REALTIME': 131, 'MSG_RTN_SAVE_WAVE': 132, 'MSG_RTN_GET_TBA': 133,
'MSG_RTN_GET_TBA_TO_PC': 134, 'MSG_RTN_GET_SE_AID_ID': 135}
def asm(text: str, arch: str = "rdna3") -> Inst:
  """Assemble one line of assembly text into an instruction object.

  The text is first turned into a DSL expression by ``get_dsl`` and that expression is
  evaluated in a namespace holding every public name of the selected ``ins`` module, the
  register factories and special-register constants, ``RawImm`` / ``SrcMod``, and
  ``hwreg`` / ``sendmsg`` helpers that accept either a symbolic name or a number.

  If evaluation raises ``NameError`` and the expression is a ``v_*`` call, it is retried
  once with ``_e32`` appended to the function name; otherwise the error propagates.

  NOTE(review): this relies on ``eval`` of a string derived from ``text`` - do not feed it
  untrusted input.
  """
  dsl = get_dsl(text, arch)
  is_rdna4 = arch == "rdna4"
  ins_module = rdna4_ins if is_rdna4 else ins
  hwreg_names = _HWREG_NAMES_RDNA4 if is_rdna4 else _HWREG_NAMES

  ns = {}
  for attr in dir(ins_module):
    if not attr.startswith('_'):
      ns[attr] = getattr(ins_module, attr)

  def hwreg(id_, offset=0, size=32):
    if isinstance(id_, str):
      id_ = hwreg_names.get(id_, id_)
    return _hwreg(id_, offset, size)

  def sendmsg(id_, op=0, stream=0):
    if isinstance(id_, str):
      id_ = _SENDMSG_NAMES.get(id_, id_)
    return _sendmsg(id_, op, stream)

  # DSL names override anything of the same name taken from the ins module.
  ns.update({
    's': s, 'v': v, 'ttmp': ttmp, 'abs': abs, 'RawImm': RawImm, 'SrcMod': SrcMod,
    'VGPR': VGPR, 'SGPR': SGPR, 'TTMP': TTMP,
    'VCC_LO': VCC_LO, 'VCC_HI': VCC_HI, 'VCC': VCC,
    'EXEC_LO': EXEC_LO, 'EXEC_HI': EXEC_HI, 'EXEC': EXEC,
    'SCC': SCC, 'M0': M0, 'NULL': NULL, 'OFF': OFF,
    'hwreg': hwreg, 'sendmsg': sendmsg,
  })
  # Symbolic register / message names evaluate to themselves so the helpers can resolve them.
  ns.update({name: name for name in hwreg_names})
  ns.update({name: name for name in _SENDMSG_NAMES})

  try:
    return eval(dsl, ns)
  except NameError:
    retry = re.match(r'^(v_\w+)(\(.*\))$', dsl)
    if retry is None:
      raise
    return eval(f"{retry.group(1)}_e32{retry.group(2)}", ns)

View File

@@ -2,14 +2,13 @@
from __future__ import annotations
from extra.assembly.amd.dsl import Inst, FixedBitField
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP1_SDST, VOP2, VOP3, VOP3_SDST, VOP3SD, VOP3P, VOPC, VOPD, VINTERP,
SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP)
SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT)
from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP1_SDST as R4_VOP1_SDST, VOP2 as R4_VOP2,
VOP3 as R4_VOP3, VOP3_SDST as R4_VOP3_SDST, VOP3SD as R4_VOP3SD, VOP3P as R4_VOP3P,
VOPC as R4_VOPC, VOPD as R4_VOPD, VINTERP as R4_VINTERP, SOP1 as R4_SOP1, SOP2 as R4_SOP2, SOPC as R4_SOPC, SOPK as R4_SOPK, SOPP as R4_SOPP,
SMEM as R4_SMEM, DS as R4_DS, VBUFFER as R4_VBUFFER, VEXPORT as R4_VEXPORT)
SMEM as R4_SMEM, DS as R4_DS)
from extra.assembly.amd.autogen.cdna.ins import (VOP1 as C_VOP1, VOP2 as C_VOP2, VOPC as C_VOPC, VOP3A, VOP3B, VOP3P as C_VOP3P,
SOP1 as C_SOP1, SOP2 as C_SOP2, SOPC as C_SOPC, SOPK as C_SOPK, SOPP as C_SOPP, SMEM as C_SMEM, DS as C_DS,
FLAT as C_FLAT, MUBUF as C_MUBUF, MTBUF as C_MTBUF)
SOP1 as C_SOP1, SOP2 as C_SOP2, SOPC as C_SOPC, SOPK as C_SOPK, SOPP as C_SOPP, SMEM as C_SMEM, DS as C_DS, FLAT as C_FLAT)
def _matches_encoding(word: int, cls: type[Inst]) -> bool:
"""Check if word matches the encoding pattern of an instruction class."""
@@ -19,12 +18,12 @@ def _matches_encoding(word: int, cls: type[Inst]) -> bool:
return ((word >> bf.lo) & bf.mask) == bf.default
# Order matters: more specific encodings first, VOP2 last (it's a catch-all for bit31=0)
_RDNA_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP]
_RDNA_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, SMEM]
_RDNA_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2] # SOP2/VOP2 are catch-alls
_CDNA_FORMATS_64 = [C_VOP3P, VOP3A, C_DS, C_FLAT, C_MUBUF, C_MTBUF, C_SMEM]
_CDNA_FORMATS_64 = [C_VOP3P, VOP3A, C_DS, C_FLAT, C_SMEM]
_CDNA_FORMATS_32 = [C_SOP1, C_SOPC, C_SOPP, C_SOPK, C_VOPC, C_VOP1, C_SOP2, C_VOP2]
_CDNA_VOP3B_OPS = {281, 282, 283, 284, 285, 286, 480, 481, 488, 489} # VOP3B opcodes
_RDNA4_FORMATS_64 = [R4_VOPD, R4_VOP3P, R4_VINTERP, R4_VOP3, R4_DS, R4_VBUFFER, R4_SMEM, R4_VEXPORT]
_RDNA4_FORMATS_64 = [R4_VOPD, R4_VOP3P, R4_VINTERP, R4_VOP3, R4_DS, R4_SMEM]
_RDNA4_FORMATS_32 = [R4_SOP1, R4_SOPC, R4_SOPP, R4_SOPK, R4_VOPC, R4_VOP1, R4_SOP2, R4_VOP2]
_RDNA4_VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
_RDNA3_VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}

View File

@@ -84,12 +84,12 @@ def _swmmac_regs(name: str) -> tuple[int, int, int, int]:
# IMPORTS
# ═══════════════════════════════════════════════════════════════════════════════
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP1_SDST, VOP2, VOP3, VOP3_SDST, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, GLOBAL, SCRATCH, MUBUF, MTBUF, MIMG, EXP,
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp)
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP1_SDST, VOP2, VOP3, VOP3_SDST, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, GLOBAL, SCRATCH,
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp)
from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP1_SDST as R4_VOP1_SDST, VOP2 as R4_VOP2, VOP3 as R4_VOP3, VOP3_SDST as R4_VOP3_SDST, VOP3SD as R4_VOP3SD, VOP3P as R4_VOP3P,
VOPC as R4_VOPC, VOPD as R4_VOPD, VINTERP as R4_VINTERP, SOP1 as R4_SOP1, SOP2 as R4_SOP2, SOPC as R4_SOPC, SOPK as R4_SOPK, SOPP as R4_SOPP,
SMEM as R4_SMEM, DS as R4_DS, VBUFFER as R4_VBUFFER, VEXPORT as R4_VEXPORT, VOPDOp as R4_VOPDOp)
from extra.assembly.amd.autogen.cdna.ins import FLAT as C_FLAT, MUBUF as C_MUBUF, MTBUF as C_MTBUF
SMEM as R4_SMEM, DS as R4_DS, VOPDOp as R4_VOPDOp)
from extra.assembly.amd.autogen.cdna.ins import FLAT as C_FLAT
def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__
@@ -491,72 +491,6 @@ def _disasm_vop3p(inst: VOP3P) -> str:
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if clamp else [])
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
def _disasm_buf(inst: MUBUF | MTBUF) -> str:
  """Format a MUBUF/MTBUF instruction (RDNA or CDNA flavour) as assembly text.

  The data register count is derived from the mnemonic, the address operand
  from the ``offen``/``idxen`` bits, and typed-buffer instructions get a
  ``format``/``dfmt``/``nfmt`` annotation whose spelling depends on the
  architecture and on the ``acc`` bit.
  """
  name = inst.op_name.lower()
  cdna = _is_cdna(inst)
  acc = getattr(inst, 'acc', 0)
  reg_fn = _areg if acc else _vreg

  # These ops are printed as the bare mnemonic, without operands.
  if cdna:
    if name in ('buffer_wbl2', 'buffer_inv'): return name
  elif inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name

  # Number of data registers, chosen from the mnemonic.
  if 'd16' in name:
    w = 2 if _has(name, 'xyz', 'xyzw') else 1
  elif 'atomic' in name:
    w = (2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)
  else:
    w = {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], 1)
  if getattr(inst, 'tfe', 0): w += 1  # one extra register when tfe is set

  if inst.offen and inst.idxen: vaddr = _vreg(inst.vaddr, 2)
  elif inst.offen or inst.idxen: vaddr = _vreg(inst.vaddr)
  else: vaddr = "off"
  srsrc = _sreg_or_ttmp(_unwrap(inst.srsrc), 4)

  is_mtbuf = isinstance(inst, (MTBUF, C_MTBUF))
  fmt_s = ""
  if is_mtbuf:
    dfmt, nfmt = inst.format & 0xf, (inst.format >> 4) & 0x7
    if acc:
      fmt_s = f" dfmt:{dfmt}, nfmt:{nfmt},"
    elif not cdna:
      if inst.format: fmt_s = f" format:{inst.format}"
    else:
      dfmt_names = ['INVALID', '8', '16', '8_8', '32', '16_16', '10_11_11', '11_11_10', '10_10_10_2', '2_10_10_10', '8_8_8_8', '32_32', '16_16_16_16', '32_32_32', '32_32_32_32', 'RESERVED_15']
      nfmt_names = ['UNORM', 'SNORM', 'USCALED', 'SSCALED', 'UINT', 'SINT', 'RESERVED_6', 'FLOAT']
      # dfmt == 1 / nfmt == 0 are treated as defaults and omitted from the output.
      if dfmt == 1 and nfmt == 0: fmt_s = ""
      elif nfmt == 0: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]}]"
      elif dfmt == 1: fmt_s = f" format:[BUF_NUM_FORMAT_{nfmt_names[nfmt]}]"
      else: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]},BUF_NUM_FORMAT_{nfmt_names[nfmt]}]"

  if cdna:
    flags = [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"glc"),(inst.nt,"slc"),(inst.sc1,"sc1")]
  else:
    flags = [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")]
  mods = [text for cond, text in flags if cond]
  suffix = ' ' + ' '.join(mods) if mods else ''

  soffset_s = decode_src(inst.soffset, cdna)
  vdata = reg_fn(inst.vdata, w)
  # CDNA typed buffers without acc print the format after soffset; every other case prints it before.
  if cdna and not acc and is_mtbuf: return f"{name} {vdata}, {vaddr}, {srsrc}, {soffset_s}{fmt_s}{suffix}"
  return f"{name} {vdata}, {vaddr}, {srsrc},{fmt_s} {soffset_s}{suffix}"
def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
base = [1, 2, 3, 3, 2, 3, 3, 4][dim]
grad = [1, 2, 3, 2, 1, 2, 2, 2][dim]
if 'get_resinfo' in name: return 1
packed, unpacked = 0, 0
if '_mip' in name: packed += 1
elif 'sample' in name or 'gather' in name:
if '_o' in name: unpacked += 1
if re.search(r'_c(_|$)', name): unpacked += 1
if '_d' in name: unpacked += (grad + 1) & ~1 if '_g16' in name else grad*2
if '_b' in name: unpacked += 1
if '_l' in name and '_cl' not in name and '_lz' not in name: packed += 1
if '_cl' in name: packed += 1
return (base + packed + 1) // 2 + unpacked if a16 else base + packed + unpacked
def _disasm_mimg(inst: MIMG) -> str:
name = inst.op_name.lower()
srsrc_str = inst.srsrc.fmt()
if 'bvh' in name:
vaddr = (9 if '64' in name else 8) if inst.a16 else (12 if '64' in name else 11)
return f"{name} {inst.vdata.fmt(sz=4)}, {inst.vaddr.fmt(sz=vaddr)}, {inst.srsrc.fmt(sz=4)}{' a16' if inst.a16 else ''}"
vdata = 4 if 'gather4' in name or 'msaa_load' in name else (bin(inst.dmask).count('1') or 1)
if inst.d16: vdata = (vdata + 1) // 2
if inst.tfe: vdata += 1
dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array']
dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"dim_{inst.dim}"
vaddr = _mimg_vaddr_width(name, inst.dim, inst.a16)
mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask and (inst.dmask != 15 or 'atomic' in name) else []
mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}")
for flag, mod in [(inst.unrm,"unorm"),(inst.glc,"glc"),(inst.slc,"slc"),(inst.dlc,"dlc"),(inst.r128,"r128"),
(inst.a16,"a16"),(inst.tfe,"tfe"),(inst.lwe,"lwe"),(inst.d16,"d16")]:
if flag: mods.append(mod)
ssamp_str = f", {inst.ssamp.fmt()}" if 'sample' in name or 'gather' in name or 'get_lod' in name else ""
return f"{name} {inst.vdata.fmt(sz=vdata)}, {inst.vaddr.fmt(sz=vaddr)}, {srsrc_str}{ssamp_str} {' '.join(mods)}"
def _disasm_sop1(inst: SOP1) -> str:
op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst)
# Use get_field_bits for register sizes
@@ -618,46 +552,14 @@ def _disasm_vinterp(inst: VINTERP) -> str:
mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp"))
return f"{inst.op_name.lower()} {inst.vdst.fmt()}, {_lit(inst, inst.src0, inst.neg & 1)}, {_lit(inst, inst.src1, inst.neg & 2)}, {_lit(inst, inst.src2, inst.neg & 4)}" + (" " + mods if mods else "")
EXP_TARGETS = {0: 'mrt0', 1: 'mrt1', 2: 'mrt2', 3: 'mrt3', 4: 'mrt4', 5: 'mrt5', 6: 'mrt6', 7: 'mrt7',
8: 'mrtz', 9: 'null', 12: 'pos0', 13: 'pos1', 14: 'pos2', 15: 'pos3', 16: 'pos4',
32: 'param0', 33: 'param1', 34: 'param2', 35: 'param3', 36: 'param4', 37: 'param5'}
def _disasm_vexport(inst) -> str:
  """Format an export instruction: target name, four sources (``off`` where the ``en`` bit is clear), then modifiers."""
  # Targets missing from EXP_TARGETS are printed as their numeric value.
  target = EXP_TARGETS.get(inst.target, f'{inst.target}')
  srcs = []
  for i in range(4):
    srcs.append(getattr(inst, f"vsrc{i}").fmt() if inst.en & (1 << i) else 'off')
  mods = _mods((inst.done, "done"), (inst.row, "row_en"))
  text = f"export {target} {', '.join(srcs)}"
  return text + " " + mods if mods else text
def _disasm_vbuffer(inst) -> str:
  """Format a VBUFFER (buffer / typed-buffer) instruction as assembly text.

  Output shape: ``<name> <vdata>, <vaddr|off>, <srsrc>, <soffset>[ format:...][ modifiers]``.
  The data register count comes from the mnemonic (falling back to
  ``inst.canonical_op_regs['d']``), formats above 1 are printed symbolically
  when ``BUF_FMT`` knows them, and ``th``/``scope`` are printed only for the
  values named in the tables below.
  """
  # The original chained .replace('buffer_', 'buffer_').replace('tbuffer_', 'tbuffer_'),
  # which maps every string to itself; the lowered op name is used directly.
  name = inst.op_name.lower()

  # Number of data registers, chosen from the mnemonic.
  if 'd16' in name:
    w = 2 if _has(name, 'xyz', 'xyzw') else 1
  elif 'atomic' in name:
    w = (2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)
  else:
    w = {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], inst.canonical_op_regs['d'])
  if getattr(inst, 'tfe', 0): w += 1

  vdata = inst.vdata.fmt(sz=w)
  vaddr = inst.vaddr.fmt(sz=2) if inst.offen and inst.idxen else (inst.vaddr.fmt() if inst.offen or inst.idxen else 'off')
  # rsrc values >= 108 are printed as ttmp registers, lower ones as SGPRs; both as a 4-register range.
  srsrc = f'ttmp[{inst.rsrc - 108}:{inst.rsrc - 108 + 3}]' if inst.rsrc >= 108 else f's[{inst.rsrc}:{inst.rsrc + 3}]'
  soff = decode_src(inst.soffset.offset) if inst.soffset.offset >= 106 else f's{inst.soffset.offset}'

  fmt = getattr(inst, 'format', 0)
  # Function-scope import kept from the original. NOTE(review): presumably avoids an import cycle with asm.py - confirm.
  from extra.assembly.amd.asm import BUF_FMT
  fmt_s = ""
  if fmt > 1:
    # The reverse map is only needed when a format is actually printed.
    fmt_names = {v: k for k, v in BUF_FMT.items()}
    fmt_s = f" format:[{fmt_names[fmt]}]" if fmt in fmt_names else f" format:{fmt}"

  if 'atomic' in name: th_names = {1: 'TH_ATOMIC_RETURN', 6: 'TH_ATOMIC_CASCADE_NT'}
  elif 'store' in name: th_names = {3: 'TH_STORE_BYPASS', 6: 'TH_STORE_NT_HT'}
  else: th_names = {3: 'TH_LOAD_BYPASS', 6: 'TH_LOAD_NT_HT'}
  scope_names = {1: 'SCOPE_SE', 2: 'SCOPE_DEV', 3: 'SCOPE_SYS'}
  mods = _mods((inst.idxen, "idxen"), (inst.offen, "offen"), (inst.ioffset, f"offset:{inst.ioffset}"),
               (inst.th in th_names, f"th:{th_names.get(inst.th, '')}"), (inst.scope in scope_names, f"scope:{scope_names.get(inst.scope, '')}"))
  return f"{name} {vdata}, {vaddr}, {srsrc}, {soff}{fmt_s}" + (" " + mods if mods else "")
DISASM_HANDLERS: dict[type, callable] = {
VOP1: _disasm_vop1, VOP1_SDST: _disasm_vop1, VOP2: _disasm_vop2, VOPC: _disasm_vopc, VOP3: _disasm_vop3, VOP3_SDST: _disasm_vop3, VOP3SD: _disasm_vop3sd, VOPD: _disasm_vopd, VOP3P: _disasm_vop3p,
VINTERP: _disasm_vinterp, SOPP: _disasm_sopp, SMEM: _disasm_smem, DS: _disasm_ds, FLAT: _disasm_flat, GLOBAL: _disasm_flat, SCRATCH: _disasm_flat,
MUBUF: _disasm_buf, MTBUF: _disasm_buf, MIMG: _disasm_mimg, SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk,
SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk,
# RDNA4
R4_VOP1: _disasm_vop1, R4_VOP1_SDST: _disasm_vop1, R4_VOP2: _disasm_vop2, R4_VOPC: _disasm_vopc, R4_VOP3: _disasm_vop3, R4_VOP3_SDST: _disasm_vop3, R4_VOP3SD: _disasm_vop3sd,
R4_VOPD: _disasm_vopd, R4_VOP3P: _disasm_vop3p, R4_VINTERP: _disasm_vinterp, R4_SOPP: _disasm_sopp, R4_SMEM: _disasm_smem,
R4_DS: _disasm_ds, R4_SOP1: _disasm_sop1, R4_SOP2: _disasm_sop2, R4_SOPC: _disasm_sopc, R4_SOPK: _disasm_sopk,
R4_VEXPORT: _disasm_vexport, R4_VBUFFER: _disasm_vbuffer}
R4_DS: _disasm_ds, R4_SOP1: _disasm_sop1, R4_SOP2: _disasm_sop2, R4_SOPC: _disasm_sopc, R4_SOPK: _disasm_sopk}
def disasm(inst: Inst) -> str:
  """Disassemble ``inst`` with the handler registered for its exact class in ``DISASM_HANDLERS``.

  Raises:
    KeyError: no handler is registered for ``type(inst)``.
  """
  handler = DISASM_HANDLERS[type(inst)]
  return handler(inst)
@@ -668,7 +570,7 @@ def disasm(inst: Inst) -> str: return DISASM_HANDLERS[type(inst)](inst)
try:
from extra.assembly.amd.autogen.cdna.ins import (VOP1 as CDNA_VOP1, VOP2 as CDNA_VOP2, VOPC as CDNA_VOPC, VOP3A, VOP3B, VOP3P as CDNA_VOP3P,
SOP1 as CDNA_SOP1, SOP2 as CDNA_SOP2, SOPC as CDNA_SOPC, SOPK as CDNA_SOPK, SOPP as CDNA_SOPP, SMEM as CDNA_SMEM, DS as CDNA_DS,
FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, VOP1Op as CDNA_VOP1Op, VOP2Op as CDNA_VOP2Op, VOPCOp as CDNA_VOPCOp)
FLAT as CDNA_FLAT, VOP1Op as CDNA_VOP1Op, VOP2Op as CDNA_VOP2Op, VOPCOp as CDNA_VOPCOp)
def _cdna_src(inst, v, neg, abs_=0, n=1):
s = _lit(inst, v) if v == 255 else _fmt_src(v, n, cdna=True)
@@ -678,8 +580,7 @@ try:
_CDNA_VOP3_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32'}
def _disasm_vop3a(inst) -> str:
op_val = inst._values.get('op', 0)
if hasattr(op_val, 'value'): op_val = op_val.value
op_val = inst.op.value if hasattr(inst.op, 'value') else inst.op
name = inst.op_name.lower() or f'vop3a_op_{op_val}'
n = inst.num_srcs() or _num_srcs(inst)
cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
@@ -711,8 +612,7 @@ try:
return f"{name}{suf} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {s0}, {s1}{cl}{om}"
def _disasm_vop3b(inst) -> str:
op_val = inst._values.get('op', 0)
if hasattr(op_val, 'value'): op_val = op_val.value
op_val = inst.op.value if hasattr(inst.op, 'value') else inst.op
name = inst.op_name.lower() or f'vop3b_op_{op_val}'
n = inst.num_srcs() or _num_srcs(inst)
regs = inst.canonical_op_regs
@@ -728,18 +628,44 @@ try:
return f"{name}{suf} {dst}, {sdst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {sdst}, {s0}, {s1}{cl}{om}"
def _disasm_cdna_vop3p(inst) -> str:
name, n, is_mfma = inst.op_name.lower(), inst.num_srcs() or 2, 'mfma' in inst.op_name.lower() or 'smfmac' in inst.op_name.lower()
name, n = inst.op_name.lower(), inst.num_srcs() or 2
is_mfma = 'mfma' in name or 'smfmac' in name
is_accvgpr = 'accvgpr' in name
get_src = lambda v, sc: _lit(inst, v) if v == 255 else _fmt_src(v, sc, cdna=True)
if is_mfma: sc = 2 if 'iu4' in name else 4 if 'iu8' in name or 'i4' in name else 8 if 'f16' in name or 'bf16' in name else 4; src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 16), _vreg(inst.vdst, 16)
else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), _vreg(inst.vdst)
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if n == 3 else 3) else []) + \
# Handle accvgpr read/write (accumulator register operations)
if is_accvgpr:
src0_off = _unwrap(inst.src0)
vdst_off = _vi(inst.vdst)
if 'read' in name:
# v_accvgpr_read_b32 vN, aM - reads from accumulator to VGPR
return f"{name}_b32 v{vdst_off}, a{src0_off - 256 if src0_off >= 256 else src0_off}"
if 'write' in name:
# v_accvgpr_write_b32 aM, src - writes to accumulator from source
src = _lit(inst, inst.src0) if src0_off == 255 else (f"v{src0_off - 256}" if src0_off >= 256 else decode_src(src0_off, cdna=True))
return f"{name}_b32 a{vdst_off}, {src}"
# Handle MFMA instructions with accumulator destinations
if is_mfma:
sc = 2 if 'iu4' in name else 4 if 'iu8' in name or 'i4' in name else 8 if 'f16' in name or 'bf16' in name else 4
src0, src1, src2 = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 16)
dst = _areg(inst.vdst, 16) # MFMA uses accumulator registers
opsel_hi = inst.opsel_hi
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != 3 else []) + \
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}"
# Standard VOP3P instructions
src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), _vreg(inst.vdst)
opsel_hi = inst.opsel_hi # CDNA VOP3P only has 2 bits for opsel_hi (no opsel_hi2)
opsel_hi_default = 3 # CDNA default is 0b11 (2 bits), not 0b111 like RDNA
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != opsel_hi_default else []) + \
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
DISASM_HANDLERS.update({CDNA_VOP1: _disasm_vop1, CDNA_VOP2: _disasm_vop2, CDNA_VOPC: _disasm_vopc,
CDNA_SOP1: _disasm_sop1, CDNA_SOP2: _disasm_sop2, CDNA_SOPC: _disasm_sopc, CDNA_SOPK: _disasm_sopk, CDNA_SOPP: _disasm_sopp,
CDNA_SMEM: _disasm_smem, CDNA_DS: _disasm_ds, CDNA_FLAT: _disasm_flat, CDNA_MUBUF: _disasm_buf, CDNA_MTBUF: _disasm_buf,
CDNA_SMEM: _disasm_smem, CDNA_DS: _disasm_ds, CDNA_FLAT: _disasm_flat,
VOP3A: _disasm_vop3a, VOP3B: _disasm_vop3b, CDNA_VOP3P: _disasm_cdna_vop3p})
except ImportError:
pass

View File

@@ -335,7 +335,7 @@ class Inst:
def _size(cls) -> int: return cls._base_size
def size(self) -> int: return self._base_size + (4 if self._literal is not None else 0)
def disasm(self) -> str:
from extra.assembly.amd.asm import disasm
from extra.assembly.amd.disasm import disasm
return disasm(self)
def to_bytes(self) -> bytes:

View File

@@ -1,178 +1,14 @@
#!/usr/bin/env python3
"""Test MUBUF, MTBUF, MIMG, EXP, DS formats against LLVM."""
"""Test DS and other compute-relevant instruction formats.
Note: Graphics-only formats (EXP, MUBUF, MTBUF, MIMG) are not supported - use GLOBAL/FLAT for memory access in compute.
"""
import unittest
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.dsl import VCC_HI, EXEC_LO, NULL
OFF = NULL # OFF is alias for NULL
from extra.assembly.amd.decode import detect_format
class TestMUBUF(unittest.TestCase):
  """Encoding checks for MUBUF (buffer) instructions against GFX11 byte vectors."""

  def test_buffer_load_b32_basic(self):
    # buffer_load_b32 v5, off, s[8:11], s3 offset:4095
    expected = bytes([0xff,0x0f,0x50,0xe0,0x00,0x05,0x02,0x03])  # GFX11 encoding
    encoded = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095).to_bytes()
    self.assertEqual(encoded, expected)

  def test_buffer_load_b32_idxen(self):
    # buffer_load_b32 v5, v0, s[8:11], s3 idxen offset:4095
    expected = bytes([0xff,0x0f,0x50,0xe0,0x00,0x05,0x82,0x03])  # GFX11 encoding
    encoded = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, idxen=1).to_bytes()
    self.assertEqual(encoded, expected)

  def test_buffer_load_b32_offen(self):
    # buffer_load_b32 v5, v0, s[8:11], s3 offen offset:4095
    expected = bytes([0xff,0x0f,0x50,0xe0,0x00,0x05,0x42,0x03])  # GFX11 encoding
    encoded = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, offen=1).to_bytes()
    self.assertEqual(encoded, expected)

  def test_buffer_load_b32_glc(self):
    # buffer_load_b32 v5, off, s[8:11], s3 offset:4095 glc
    expected = bytes([0xff,0x4f,0x50,0xe0,0x00,0x05,0x02,0x03])  # GFX11 encoding
    encoded = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, glc=1).to_bytes()
    self.assertEqual(encoded, expected)

  def test_buffer_load_b32_slc(self):
    # buffer_load_b32 v5, off, s[8:11], s3 offset:4095 slc
    expected = bytes([0xff,0x1f,0x50,0xe0,0x00,0x05,0x02,0x03])  # GFX11 encoding
    encoded = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, slc=1).to_bytes()
    self.assertEqual(encoded, expected)

  def test_buffer_load_b32_dlc(self):
    # buffer_load_b32 v5, off, s[8:11], s3 offset:4095 dlc
    expected = bytes([0xff,0x2f,0x50,0xe0,0x00,0x05,0x02,0x03])  # GFX11 encoding
    encoded = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, dlc=1).to_bytes()
    self.assertEqual(encoded, expected)

  def test_buffer_load_b32_all_flags(self):
    # buffer_load_b32 v5, off, s[8:11], s3 offset:4095 glc slc dlc
    expected = bytes([0xff,0x7f,0x50,0xe0,0x00,0x05,0x02,0x03])  # GFX11 encoding
    encoded = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, glc=1, slc=1, dlc=1).to_bytes()
    self.assertEqual(encoded, expected)

  def test_buffer_store_b32(self):
    # buffer_store_b32 v1, off, s[12:15], s4 offset:4095
    expected = bytes([0xff,0x0f,0x68,0xe0,0x00,0x01,0x03,0x04])  # GFX11 encoding
    encoded = buffer_store_b32(vdata=v[1], vaddr=v[0], srsrc=s[12:15], soffset=s[4], offset=4095).to_bytes()
    self.assertEqual(encoded, expected)

  def test_buffer_load_b64(self):
    # buffer_load_b64 v[5:6], off, s[8:11], s3 offset:4095
    expected = bytes([0xff,0x0f,0x54,0xe0,0x00,0x05,0x02,0x03])  # GFX11 encoding
    encoded = buffer_load_b64(vdata=v[5:6], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095).to_bytes()
    self.assertEqual(encoded, expected)

  def test_buffer_load_soffset_m0(self):
    # buffer_load_b32 v5, off, s[8:11], m0 offset:4095
    expected = bytes([0xff,0x0f,0x50,0xe0,0x00,0x05,0x02,0x7d])  # GFX11 encoding
    encoded = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=M0, offset=4095).to_bytes()
    self.assertEqual(encoded, expected)

  def test_buffer_load_soffset_inline_const(self):
    # buffer_load_b32 v5, off, s[8:11], 0 offset:4095
    expected = bytes([0xff,0x0f,0x50,0xe0,0x00,0x05,0x02,0x80])  # GFX11 encoding
    encoded = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=0, offset=4095).to_bytes()
    self.assertEqual(encoded, expected)

  def test_buffer_disasm_roundtrip(self):
    # Decoding the encoded bytes and re-encoding must give back the same bytes.
    original = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, glc=1)
    reparsed = MUBUF.from_bytes(original.to_bytes())
    self.assertEqual(reparsed.to_bytes(), original.to_bytes())
class TestMTBUF(unittest.TestCase):
  """Encoding checks for MTBUF (typed buffer) instructions against GFX11 byte vectors."""

  def test_tbuffer_load_format_x(self):
    # tbuffer_load_format_x v5, off, s[8:11], s3 format:[BUF_FMT_32_FLOAT] offset:4095
    # BUF_FMT_32_FLOAT = 22
    expected = bytes([0xff,0x0f,0xb0,0xe8,0x00,0x05,0x02,0x03])  # GFX11 encoding
    encoded = tbuffer_load_format_x(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, format=22).to_bytes()
    self.assertEqual(encoded, expected)

  def test_tbuffer_store_format_x(self):
    # tbuffer_store_format_x v5, off, s[8:11], s3 format:[BUF_FMT_32_FLOAT] offset:4095
    # BUF_FMT_32_FLOAT = 22
    expected = bytes([0xff,0x0f,0xb2,0xe8,0x00,0x05,0x02,0x03])  # GFX11 encoding
    encoded = tbuffer_store_format_x(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, format=22).to_bytes()
    self.assertEqual(encoded, expected)

  def test_tbuffer_load_format_xy(self):
    # tbuffer_load_format_xy v[5:6], off, s[8:11], s3 format:[BUF_FMT_32_32_FLOAT] offset:4095
    # BUF_FMT_32_32_FLOAT = 50
    expected = bytes([0xff,0x8f,0x90,0xe9,0x00,0x05,0x02,0x03])  # GFX11 encoding
    encoded = tbuffer_load_format_xy(vdata=v[5:6], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, format=50).to_bytes()
    self.assertEqual(encoded, expected)
class TestMIMG(unittest.TestCase):
  """Encoding checks for MIMG (image) instructions against GFX11 byte vectors."""

  def test_image_load_2d(self):
    # image_load v[0:3], v[4:7], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D  (dim=1 is SQ_RSRC_IMG_2D)
    expected = bytes([0x04,0x0f,0x00,0xf0,0x04,0x00,0x00,0x00])  # GFX11 encoding
    encoded = image_load(vdata=v[0:3], vaddr=v[4:7], srsrc=s[0:7], dmask=0xf, dim=1).to_bytes()
    self.assertEqual(encoded, expected)

  def test_image_store_2d(self):
    # image_store v[0:3], v[4:7], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D
    expected = bytes([0x04,0x0f,0x18,0xf0,0x04,0x00,0x00,0x00])  # GFX11 encoding
    encoded = image_store(vdata=v[0:3], vaddr=v[4:7], srsrc=s[0:7], dmask=0xf, dim=1).to_bytes()
    self.assertEqual(encoded, expected)

  def test_image_load_1d(self):
    # image_load v[0:3], v[4:7], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_1D  (dim=0 is SQ_RSRC_IMG_1D)
    expected = bytes([0x00,0x0f,0x00,0xf0,0x04,0x00,0x00,0x00])  # GFX11 encoding
    encoded = image_load(vdata=v[0:3], vaddr=v[4:7], srsrc=s[0:7], dmask=0xf, dim=0).to_bytes()
    self.assertEqual(encoded, expected)

  def test_image_sample(self):
    # image_sample v[0:3], v[4:6], s[0:7], s[8:11] dmask:0xf dim:SQ_RSRC_IMG_2D
    expected = bytes([0x04,0x0f,0x6c,0xf0,0x04,0x00,0x00,0x08])  # GFX11 encoding
    encoded = image_sample(vdata=v[0:3], vaddr=v[4:6], srsrc=s[0:7], ssamp=s[8:11], dmask=0xf, dim=1).to_bytes()
    self.assertEqual(encoded, expected)

  def test_image_load_d16(self):
    # image_load v[0:3], v[4:7], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D d16
    expected = bytes([0x04,0x0f,0x02,0xf0,0x04,0x00,0x00,0x00])  # GFX11 encoding
    encoded = image_load(vdata=v[0:3], vaddr=v[4:7], srsrc=s[0:7], dmask=0xf, dim=1, d16=1).to_bytes()
    self.assertEqual(encoded, expected)
class TestEXP(unittest.TestCase):
  """Test EXP (export) instructions against the GFX11 encodings quoted in each test."""

  def test_exp_mrt0(self):
    # exp mrt0 v0, v1, v2, v3
    # GFX11: encoding: [0x0f,0x00,0x00,0xf8,0x00,0x01,0x02,0x03]
    inst = EXP(en=0xf, target=0, vsrc0=v[0], vsrc1=v[1], vsrc2=v[2], vsrc3=v[3])
    self.assertEqual(inst.to_bytes(), bytes([0x0f,0x00,0x00,0xf8,0x00,0x01,0x02,0x03]))

  def test_exp_mrtz(self):
    # exp mrtz v4, v3, v2, v1
    # GFX11: encoding: [0x8f,0x00,0x00,0xf8,0x04,0x03,0x02,0x01]
    inst = EXP(en=0xf, target=8, vsrc0=v[4], vsrc1=v[3], vsrc2=v[2], vsrc3=v[1])
    self.assertEqual(inst.to_bytes(), bytes([0x8f,0x00,0x00,0xf8,0x04,0x03,0x02,0x01]))

  def test_exp_mrtz_done(self):
    # exp mrtz v4, v3, v2, v1 done
    # GFX11: encoding: [0x8f,0x08,0x00,0xf8,0x04,0x03,0x02,0x01]
    # The original built vsrc3=v[3] and expected a trailing 0x03, which did not match the
    # vector quoted above; the test now checks the documented instruction and bytes.
    inst = EXP(en=0xf, target=8, vsrc0=v[4], vsrc1=v[3], vsrc2=v[2], vsrc3=v[1], done=1)
    self.assertEqual(inst.to_bytes(), bytes([0x8f,0x08,0x00,0xf8,0x04,0x03,0x02,0x01]))

  def test_exp_partial_mask(self):
    # exp mrt0 v0, v1, off, off (en=0x3, only first two components)
    # GFX11: encoding: [0x03,0x00,0x00,0xf8,0x00,0x01,0x00,0x00]
    inst = EXP(en=0x3, target=0, vsrc0=v[0], vsrc1=v[1], vsrc2=v[0], vsrc3=v[0])
    self.assertEqual(inst.to_bytes(), bytes([0x03,0x00,0x00,0xf8,0x00,0x01,0x00,0x00]))

  def test_exp_row_en(self):
    # exp mrtz v4, v3, v2, v1 row_en
    # GFX11: encoding: [0x8f,0x20,0x00,0xf8,0x04,0x03,0x02,0x01]
    inst = EXP(en=0xf, target=8, vsrc0=v[4], vsrc1=v[3], vsrc2=v[2], vsrc3=v[1], row=1)
    self.assertEqual(inst.to_bytes(), bytes([0x8f,0x20,0x00,0xf8,0x04,0x03,0x02,0x01]))
class TestDS(unittest.TestCase):
"""Test DS (data share / LDS) instructions."""
@@ -380,18 +216,6 @@ class TestDetectFormat(unittest.TestCase):
self.assertEqual(detect_format(global_load_b32(vdst=v[0], addr=v[1:2], saddr=NULL).to_bytes()), FLAT)
self.assertEqual(detect_format(global_store_b32(addr=v[0:1], data=v[2], saddr=NULL).to_bytes()), FLAT)
def test_detect_mubuf(self):
  """detect_format must classify an encoded buffer_load_b32 as MUBUF."""
  encoded = buffer_load_b32(v[0], v[1], s[0:3], s[5]).to_bytes()
  self.assertEqual(detect_format(encoded), MUBUF)
def test_detect_mtbuf(self):
  """detect_format must classify an encoded tbuffer_load_format_x as MTBUF."""
  encoded = tbuffer_load_format_x(v[0], v[1], s[0:3], s[5], format=22).to_bytes()
  self.assertEqual(detect_format(encoded), MTBUF)
def test_detect_mimg(self):
  """detect_format must classify an encoded image_load as MIMG."""
  encoded = image_load(v[0:3], v[4:7], s[0:7], dmask=0xf, dim=1).to_bytes()
  self.assertEqual(detect_format(encoded), MIMG)
def test_detect_exp(self):
  """detect_format must classify an encoded export instruction as EXP."""
  encoded = EXP(en=0xf, target=0, vsrc0=v[0], vsrc1=v[1], vsrc2=v[2], vsrc3=v[3]).to_bytes()
  self.assertEqual(detect_format(encoded), EXP)
def test_detect_vopd(self):
  """detect_format must classify an encoded dual-mov VOPD instruction as VOPD."""
  dual_mov = VOPD(VOPDOp.V_DUAL_MOV_B32, VOPDOp.V_DUAL_MOV_B32, vdstx=v[0], vdsty=v[1], srcx0=0, srcy0=0)
  encoded = dual_mov.to_bytes()
  self.assertEqual(detect_format(encoded), VOPD)

View File

@@ -4,7 +4,6 @@
import unittest, struct
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.dsl import Inst
from extra.assembly.amd.asm import asm
from extra.assembly.amd.test.test_roundtrip import compile_asm
class TestIntegration(unittest.TestCase):
@@ -13,12 +12,9 @@ class TestIntegration(unittest.TestCase):
if not hasattr(self, 'inst'): return
b = self.inst.to_bytes()
st = self.inst.disasm()
reasm = asm(st)
desc = f"{st:25s} {self.inst} {b!r} {reasm}"
# Test that the instruction can be compiled by LLVM and produces the same bytes
desc = f"{st:25s} {self.inst} {b!r}"
self.assertEqual(b, compile_asm(st), desc)
# TODO: this compare should work for valid things
#self.assertEqual(self.inst, reasm)
self.assertEqual(repr(self.inst), repr(reasm))
print(desc)
def test_wmma(self):

View File

@@ -1,9 +1,10 @@
#!/usr/bin/env python3
"""Integration test: round-trip RDNA3 assembly through AMD toolchain."""
import unittest, re, io, sys, subprocess
import unittest, io, sys
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.asm import waitcnt, asm
from extra.assembly.amd.test.helpers import get_llvm_mc
def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int:
  """Pack three wait counters into one integer immediate.

  Layout produced: ``expcnt`` in bits [2:0], ``lgkmcnt`` in bits [9:4] and
  ``vmcnt`` in bits [15:10]. Each argument is masked to its field width, so
  out-of-range values are truncated rather than rejected.
  """
  exp_bits = expcnt & 0x7
  lgkm_bits = (lgkmcnt & 0x3f) << 4
  vm_bits = (vmcnt & 0x3f) << 10
  return exp_bits | lgkm_bits | vm_bits
def disassemble(lib: bytes, arch: str = "gfx1100") -> str:
"""Disassemble ELF binary using tinygrad's compiler, return raw output."""
@@ -40,7 +41,7 @@ def assemble_and_disassemble(instructions: list, arch: str = "gfx1100") -> list[
return parse_disassembly(disassemble(lib, arch))
class TestIntegration(unittest.TestCase):
"""Test our assembler output matches LLVM disassembly."""
"""Test our DSL output matches LLVM disassembly."""
def test_simple_sop1(self):
"""Test SOP1 instructions round-trip."""
@@ -148,79 +149,6 @@ class TestIntegration(unittest.TestCase):
return
self.fail("Could not find s_mov_b32 in disassembly")
class TestAsm(unittest.TestCase):
  """Checks that asm() parses assembly text into the same bytes as the DSL constructors and LLVM."""

  def test_asm_basic(self):
    """A plain scalar move parses to the same bytes as the DSL call."""
    parsed = asm('s_mov_b32 s0, s1')
    self.assertEqual(parsed.to_bytes(), s_mov_b32(s[0], s[1]).to_bytes())

  def test_asm_with_immediates(self):
    """Integer immediates are accepted as operands."""
    parsed = asm('s_add_u32 s0, s1, 10')
    self.assertEqual(parsed.to_bytes(), s_add_u32(s[0], s[1], 10).to_bytes())

  def test_asm_float_const(self):
    """Float constants are accepted as operands."""
    parsed = asm('v_mul_f32_e32 v0, 1.0, v1')
    self.assertEqual(parsed.to_bytes(), v_mul_f32_e32(v[0], 1.0, v[1]).to_bytes())

  def test_asm_hex_immediate(self):
    """Hex immediates are accepted as operands."""
    parsed = asm('s_waitcnt 0xfc07')
    self.assertEqual(parsed.to_bytes(), s_waitcnt(simm16=0xfc07).to_bytes())

  def test_asm_special_regs(self):
    """Special register names such as vcc_lo are recognised."""
    parsed = asm('s_mov_b32 s0, vcc_lo')
    self.assertEqual(parsed.to_bytes(), s_mov_b32(s[0], VCC_LO).to_bytes())

  def test_asm_register_range(self):
    """Register ranges such as s[4:7] are recognised."""
    parsed = asm('s_load_b128 s[4:7], s[0:1], null')
    self.assertEqual(parsed.to_bytes(), s_load_b128(s[4:7], s[0:1], NULL).to_bytes())

  def test_asm_matches_llvm(self):
    """asm() output matches the bytes produced by compiling the same text with HIPCompiler."""
    from tinygrad.runtime.support.compiler_amd import HIPCompiler
    compiler = HIPCompiler('gfx1100')

    def get_llvm_bytes(instr: str) -> bytes:
      # Compile a one-instruction kernel, then pull the encoding out of the disassembly listing.
      src = f'.text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n{instr}\n'
      listing = disassemble(compiler.compile(src))
      mnemonic = instr.split()[0]
      for line in listing.splitlines():
        if mnemonic not in line or '//' not in line: continue
        hex_str = line.split('//')[1].strip().split(':')[1].strip()
        # The listing's hex word is byte-reversed to obtain the encoded byte order.
        return bytes.fromhex(hex_str)[::-1]
      return b''

    for t in ['s_mov_b32 s0, s1', 's_endpgm', 'v_add_f32_e32 v0, v1, v2']:
      self.assertEqual(asm(t).to_bytes(), get_llvm_bytes(t), f"mismatch for: {t}")

  def test_asm_vop3_modifiers(self):
    """asm() handles VOP3 modifiers (neg, abs, clamp) the same way llvm-mc does."""
    def get_llvm_encoding(instr: str) -> str:
      # Ask llvm-mc for the encoding and flatten "[0x..,0x..]" into a plain hex string.
      cmd = [get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-show-encoding']
      result = subprocess.run(cmd, input=instr, capture_output=True, text=True)
      found = re.search(r'encoding:\s*\[(.*?)\]', result.stdout)
      if not found: return ''
      return found.group(1).replace('0x','').replace(',','').replace(' ','')

    tests = [
      'v_fma_f32 v0, -v1, v2, v3',          # neg on src0
      'v_fma_f32 v0, v1, |v2|, v3',         # abs on src1
      'v_fma_f32 v0, v1, v2, v3 clamp',     # clamp
      'v_fma_f32 v0, -v1, |v2|, v3 clamp',  # all modifiers
      'v_fma_f32 v0, -|v1|, v2, v3',        # neg+abs on same operand
    ]
    for t in tests:
      our_hex = asm(t).to_bytes().hex()
      llvm_hex = get_llvm_encoding(t)
      self.assertEqual(our_hex, llvm_hex, f"mismatch for: {t}")
class TestTinygradIntegration(unittest.TestCase):
"""Test that we can parse disassembled tinygrad kernels."""

View File

@@ -1,37 +1,45 @@
#!/usr/bin/env python3
"""Test AMD assembler/disassembler against LLVM test vectors."""
"""Test AMD assembler/disassembler against LLVM test vectors.
Only compute-relevant instruction formats are tested. Graphics-only formats not supported:
- MUBUF/MTBUF: buffer instructions with resource descriptors (use GLOBAL/FLAT instead)
- MIMG: image/texture instructions
- EXP/VEXPORT: export instructions for pixel/vertex output
- VIMAGE/VSAMPLE: image sampling instructions (RDNA4)
- VBUFFER: buffer instructions (RDNA4)
"""
import unittest, re, subprocess, functools
from tinygrad.helpers import fetch
from extra.assembly.amd.asm import asm, disasm
from extra.assembly.amd.disasm import disasm
from extra.assembly.amd.decode import decode_inst, detect_format
from extra.assembly.amd.test.helpers import get_llvm_mc
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/llvmorg-21.1.0/llvm/test/MC/AMDGPU"
# RDNA3 (gfx11) test files for compute instructions
# Excluded: gfx11_asm_mubuf.s, gfx11_asm_mtbuf.s, gfx11_asm_mimg.s, gfx11_asm_mubuf_alias.s, gfx11_asm_mtbuf_alias.s (graphics-only)
RDNA_FILES = ['gfx11_asm_sop1.s', 'gfx11_asm_sop2.s', 'gfx11_asm_sopp.s', 'gfx11_asm_sopk.s', 'gfx11_asm_sopc.s',
'gfx11_asm_vop1.s', 'gfx11_asm_vop2.s', 'gfx11_asm_vopc.s', 'gfx11_asm_vop3.s', 'gfx11_asm_vop3p.s', 'gfx11_asm_vinterp.s',
'gfx11_asm_vopd.s', 'gfx11_asm_vopcx.s', 'gfx11_asm_vop3_from_vop1.s', 'gfx11_asm_vop3_from_vop2.s', 'gfx11_asm_vop3_from_vopc.s',
'gfx11_asm_vop3_from_vopcx.s', 'gfx11_asm_ds.s', 'gfx11_asm_smem.s', 'gfx11_asm_flat.s', 'gfx11_asm_mubuf.s', 'gfx11_asm_mtbuf.s',
'gfx11_asm_mimg.s', 'gfx11_asm_wmma.s', 'gfx11_asm_vop3_features.s', 'gfx11_asm_vop3p_features.s', 'gfx11_asm_vopd_features.s',
'gfx11_asm_vop3_from_vopcx.s', 'gfx11_asm_ds.s', 'gfx11_asm_smem.s', 'gfx11_asm_flat.s',
'gfx11_asm_wmma.s', 'gfx11_asm_vop3_features.s', 'gfx11_asm_vop3p_features.s', 'gfx11_asm_vopd_features.s',
'gfx11_asm_vop3_alias.s', 'gfx11_asm_vop3p_alias.s', 'gfx11_asm_vopc_alias.s', 'gfx11_asm_vopcx_alias.s', 'gfx11_asm_vinterp_alias.s',
'gfx11_asm_smem_alias.s', 'gfx11_asm_mubuf_alias.s', 'gfx11_asm_mtbuf_alias.s']
# CDNA test files - includes gfx9 files for shared instructions, plus gfx90a/gfx942 specific files
# gfx90a_ldst_acc.s has MIMG mixed in, filtered via is_mimg check
'gfx11_asm_smem_alias.s']
# CDNA (gfx9/gfx90a/gfx942) test files for compute instructions
# Excluded: gfx9_asm_mubuf.s, gfx9_asm_mtbuf.s, gfx90a_ldst_acc.s (has MIMG mixed in)
CDNA_FILES = ['gfx9_asm_sop1.s', 'gfx9_asm_sop2.s', 'gfx9_asm_sopp.s', 'gfx9_asm_sopk.s', 'gfx9_asm_sopc.s',
'gfx9_asm_vop1.s', 'gfx9_asm_vop2.s', 'gfx9_asm_vopc.s', 'gfx9_asm_vop3.s', 'gfx9_asm_vop3p.s',
'gfx9_asm_ds.s', 'gfx9_asm_flat.s', 'gfx9_asm_smem.s', 'gfx9_asm_mubuf.s', 'gfx9_asm_mtbuf.s',
'gfx90a_ldst_acc.s', 'gfx90a_asm_features.s', 'flat-scratch-gfx942.s', 'gfx942_asm_features.s',
'gfx9_asm_ds.s', 'gfx9_asm_flat.s', 'gfx9_asm_smem.s',
'gfx90a_asm_features.s', 'flat-scratch-gfx942.s', 'gfx942_asm_features.s',
'mai-gfx90a.s', 'mai-gfx942.s']
# RDNA4 (gfx12) test files - excludes alias/err/fake16/dpp files, and vimage/vsample (not supported)
# NOTE: vflat/vdsdir excluded - not implemented; features.s has mixed formats
# RDNA4 (gfx12) test files for compute instructions
# Excluded: gfx12_asm_vbuffer_mubuf.s, gfx12_asm_vbuffer_mtbuf.s, gfx12_asm_exp.s (graphics-only)
RDNA4_FILES = ['gfx12_asm_sop1.s', 'gfx12_asm_sop2.s', 'gfx12_asm_sopp.s', 'gfx12_asm_sopk.s', 'gfx12_asm_sopc.s',
'gfx12_asm_vop1.s', 'gfx12_asm_vop2.s', 'gfx12_asm_vopc.s', 'gfx12_asm_vopcx.s', 'gfx12_asm_vop3.s', 'gfx12_asm_vop3c.s',
'gfx12_asm_vop3cx.s', 'gfx12_asm_vop3p.s', 'gfx12_asm_vop3_from_vop1.s', 'gfx12_asm_vop3_from_vop2.s',
'gfx12_asm_vop3p_features.s', 'gfx12_asm_vopd.s', 'gfx12_asm_vopd_features.s',
'gfx12_asm_ds.s', 'gfx12_asm_smem.s',
'gfx12_asm_vbuffer_mubuf.s', 'gfx12_asm_vbuffer_mtbuf.s', 'gfx12_asm_wmma_w32.s', 'gfx12_asm_exp.s']
def _is_mimg(data: bytes) -> bool: return (int.from_bytes(data[:4], 'little') >> 26) & 0x3f == 0b111100
'gfx12_asm_wmma_w32.s']
def _parse_llvm_tests(text: str, pattern: str) -> list[tuple[str, bytes]]:
tests = []
@@ -53,15 +61,16 @@ def _parse_llvm_tests(text: str, pattern: str) -> list[tuple[str, bytes]]:
def _get_tests(f: str, arch: str) -> list[tuple[str, bytes]]:
text = fetch(f"{LLVM_BASE}/{f}").read_bytes().decode('utf-8', errors='ignore')
if arch == "rdna3":
tests = _parse_llvm_tests(text, r'(?:GFX11|W32|W64)')
# Match GFX11 and W32 only (wavefront32 mode)
tests = _parse_llvm_tests(text, r'(?:GFX11|W32)')
elif arch == "rdna4":
# Match GFX12 but not GFX1250 (which has different lit64 encoding)
tests = _parse_llvm_tests(text, r'(?:GFX12(?!50)|W32|W64)')
# Match GFX12 (but not GFX1250) and W32 only (wavefront32 mode)
tests = _parse_llvm_tests(text, r'(?:GFX12(?!50)|W32)')
elif 'gfx90a' in f or 'gfx942' in f:
tests = _parse_llvm_tests(text, r'(?:GFX90A|GFX942)')
else:
tests = _parse_llvm_tests(text, r'(?:VI9|GFX9|CHECK)')
return [(a, d) for a, d in tests if not _is_mimg(d)] if arch == "cdna" else tests
return tests
def _compile_asm_batch(instrs: list[str], arch: str = "rdna3") -> list[bytes]:
if not instrs: return []
@@ -87,14 +96,6 @@ def _make_test(f: str, arch: str, test_type: str):
print(f"{name}: {passed} passed, {skipped} skipped")
if arch in ("rdna3", "rdna4"):
self.assertEqual(skipped, 0, f"{name}: {skipped} tests skipped, expected 0")
elif test_type == "asm":
passed, skipped = 0, 0
for asm_text, expected in tests:
try:
self.assertEqual(asm(asm_text, arch).to_bytes(), expected)
passed += 1
except: skipped += 1
print(f"{name}: {passed} passed, {skipped} skipped")
elif test_type == "disasm":
to_test = []
for _, data in tests:
@@ -114,14 +115,12 @@ class TestLLVM(unittest.TestCase): pass
for f in RDNA_FILES:
setattr(TestLLVM, f"test_rdna3_roundtrip_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "roundtrip"))
setattr(TestLLVM, f"test_rdna3_asm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "asm"))
setattr(TestLLVM, f"test_rdna3_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "disasm"))
for f in CDNA_FILES:
setattr(TestLLVM, f"test_cdna_roundtrip_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "cdna", "roundtrip"))
setattr(TestLLVM, f"test_cdna_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "cdna", "disasm"))
for f in RDNA4_FILES:
setattr(TestLLVM, f"test_rdna4_roundtrip_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna4", "roundtrip"))
setattr(TestLLVM, f"test_rdna4_asm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna4", "asm"))
setattr(TestLLVM, f"test_rdna4_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna4", "disasm"))
if __name__ == "__main__":

View File

@@ -2,7 +2,6 @@
"""Roundtrip tests: generate tinygrad kernels, decode instructions, re-encode, verify match."""
import unittest, io, sys, re, subprocess, os
from extra.assembly.amd.dsl import Inst
from extra.assembly.amd.asm import asm
from extra.assembly.amd.decode import decode_inst, detect_format
from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump
@@ -86,8 +85,8 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
def _test_kernel_roundtrip(self, op_fn):
"""Generate kernel from op_fn, test:
1. decode -> reencode matches original bytes
2. asm(disasm()) matches LLVM output
3. our disasm() matches LLVM's disassembly string exactly
2. disasm() -> LLVM asm -> bytes matches original (validates disasm correctness)
3. our disasm() matches LLVM's disassembly string (informational)
"""
arch = self.arch
mcpu, mattr = ARCH_CONFIG[arch]
@@ -130,19 +129,19 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
offset += size
# Collect disasm strings for batched LLVM calls - skip unknown opcodes (op_X) that LLVM can't compile
asm_test_instrs: list[tuple[int, str]] = [] # (idx, our_disasm) for asm test
asm_test_instrs: list[tuple[int, str, bytes]] = [] # (idx, our_disasm, orig_bytes) for asm test
disasm_test_instrs: list[tuple[int, str]] = [] # (idx, our_disasm) for disasm comparison test
for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs):
if our_disasm is None: continue
# Skip unknown opcodes and malformed instructions for both tests
# Skip unknown opcodes and malformed instructions
if our_disasm.startswith('op_') or re.search(r', \d+, \d+, \d+,', our_disasm): continue
asm_test_instrs.append((idx, our_disasm))
asm_test_instrs.append((idx, our_disasm, orig_bytes))
disasm_test_instrs.append((idx, our_disasm))
# Batch compile for asm test
asm_llvm_results = compile_asm_batch([d for _, d in asm_test_instrs], arch)
asm_llvm_map = {idx: result for (idx, _), result in zip(asm_test_instrs, asm_llvm_results)}
# Batch compile for asm test (our disasm -> LLVM asm -> bytes)
asm_llvm_results = compile_asm_batch([d for _, d, _ in asm_test_instrs], arch)
asm_llvm_map = {idx: (result, orig) for (idx, _, orig), result in zip(asm_test_instrs, asm_llvm_results)}
# Batch compile+disasm for disasm comparison test
disasm_llvm_results = compile_and_disasm_batch([d for _, d in disasm_test_instrs], arch)
@@ -166,20 +165,16 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
decode_failed += 1
decode_failures.append(f"K{ki}@{offset}: {our_disasm}: {decode_err}")
# Asm test
# Asm test: our disasm -> LLVM asm -> compare bytes with original
if our_disasm is None:
asm_skipped += 1
elif idx in asm_llvm_map:
llvm_bytes = asm_llvm_map[idx]
try:
our_bytes = asm(our_disasm).to_bytes()
if our_bytes[:len(llvm_bytes)] == llvm_bytes:
asm_passed += 1
else:
asm_failed += 1
asm_failures.append(f"K{ki}@{offset}: '{our_disasm}': ours={our_bytes[:len(llvm_bytes)].hex()} llvm={llvm_bytes.hex()}")
except Exception:
asm_skipped += 1
llvm_bytes, orig = asm_llvm_map[idx]
if llvm_bytes == orig[:len(llvm_bytes)]:
asm_passed += 1
else:
asm_failed += 1
asm_failures.append(f"K{ki}@{offset}: '{our_disasm}': llvm={llvm_bytes.hex()} orig={orig[:len(llvm_bytes)].hex()}")
else:
asm_skipped += 1
@@ -197,7 +192,7 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
disasm_skipped += 1
print(f"[{arch}] decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped")
print(f"[{arch}] asm vs llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped")
print(f"[{arch}] asm via llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped")
print(f"[{arch}] disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped")
self.assertEqual(decode_failed, 0, f"Decode failures:\n" + "\n".join(decode_failures[:20]))
self.assertEqual(asm_failed, 0, f"Asm failures:\n" + "\n".join(asm_failures[:20]))
@@ -248,10 +243,10 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
# Fused ops
def test_fma(self): self._test_kernel_roundtrip(lambda T: (T([1.0, 2.0]) * T([3.0, 4.0]) + T([5.0, 6.0])))
@unittest.skip("no asm support for RDNA4")
@unittest.skip("RDNA4 decode roundtrip not yet supported")
class TestTinygradKernelRoundtripRDNA4(TestTinygradKernelRoundtrip): arch = 'rdna4'
@unittest.skip("no asm support for CDNA")
@unittest.skip("CDNA decode roundtrip not yet supported")
class TestTinygradKernelRoundtripCDNA(TestTinygradKernelRoundtrip): arch = 'cdna'
if __name__ == "__main__":

View File

@@ -425,10 +425,22 @@ def parse_branch(asm:str) -> int|None:
return (x - 0x10000 if x & 0x8000 else x)*4
return None
def _op2dsl(op: str) -> str:
"""Convert LLVM asm operand (s0, s[0:1], v0) to DSL format (s[0], s[0:1], v[0])."""
import re
op = op.strip()
lo = op.lower()
SPEC_DSL = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC', 'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC',
'scc': 'SCC', 'm0': 'M0', 'null': 'NULL', 'off': 'OFF'}
if lo in SPEC_DSL: return SPEC_DSL[lo]
rp = {'s': 's', 'v': 'v', 't': 'ttmp', 'ttmp': 'ttmp'}
if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', lo): return f"{rp[m.group(1)]}[{m.group(2)}:{m.group(3)}]"
if m := re.match(r'^([svt](?:tmp)?)(\d+)$', lo): return f"{rp[m.group(1)]}[{m.group(2)}]"
return op
def amdgpu_tokenize(st:str) -> list[str]:
try:
from extra.assembly.amd.dsl import s, v, Reg, VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF
from extra.assembly.amd.asm import _op2dsl
dsl = eval(_op2dsl(st), {'s':s, 'v':v, 'VCC_LO':VCC_LO, 'VCC_HI':VCC_HI, 'VCC':VCC, 'EXEC_LO':EXEC_LO, 'EXEC_HI':EXEC_HI, 'EXEC':EXEC,
'SCC':SCC, 'M0':M0, 'NULL':NULL, 'OFF':OFF})
return [f"{type(dsl).__name__[0].lower()}{dsl.offset + i}" for i in range(dsl.sz)] if isinstance(dsl, Reg) else [st]