diff --git a/extra/assembly/amd/asm.py b/extra/assembly/amd/asm.py deleted file mode 100644 index 849e844440..0000000000 --- a/extra/assembly/amd/asm.py +++ /dev/null @@ -1,478 +0,0 @@ -# RDNA3/RDNA4/CDNA assembler -from __future__ import annotations -import re -from extra.assembly.amd.dsl import Reg, s, v, ttmp -from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL - -# Assembler-specific types (not part of clean DSL) -class RawImm(Reg): - """Raw immediate value - bypasses normal encoding, used for special register encodings.""" - def __init__(self, val: int): super().__init__(val, 1) - -class SrcMod(Reg): - """Source with modifiers - wraps a value with neg/abs flags.""" - def __init__(self, val: int, neg: bool = False, abs_: bool = False): - super().__init__(255 if not (-16 <= val <= 64) else (128 + val if val >= 0 else 192 - val), 1) - self.val, self.neg, self.abs_ = val, neg, abs_ - -# Type aliases for register factories -_RegFactory = type(s) -SGPR, VGPR, TTMP = s, v, ttmp -OFF = NULL # OFF is alias for NULL (encoding 124) - -# Float encoding constants -FLOAT_ENC = {0.5: 240, -0.5: 241, 1.0: 242, -1.0: 243, 2.0: 244, -2.0: 245, 4.0: 246, -4.0: 247} -from extra.assembly.amd.autogen.rdna3 import ins -from extra.assembly.amd.autogen.rdna3.ins import VOP2Op, VOPDOp, SOPKOp -from extra.assembly.amd.autogen.rdna4 import ins as rdna4_ins - -# Re-export disasm for backwards compatibility -from extra.assembly.amd.disasm import disasm, HWREG, HWREG_RDNA4 - -# ═══════════════════════════════════════════════════════════════════════════════ -# CONSTANTS -# ═══════════════════════════════════════════════════════════════════════════════ - -# RDNA unified buffer format (not in XML, hardcoded from AMD documentation) -BUF_FMT = {'BUF_FMT_8_UNORM': 1, 'BUF_FMT_8_SNORM': 2, 'BUF_FMT_8_USCALED': 3, 'BUF_FMT_8_SSCALED': 4, - 'BUF_FMT_8_UINT': 5, 'BUF_FMT_8_SINT': 6, 'BUF_FMT_16_UNORM': 7, 'BUF_FMT_16_SNORM': 8, - 'BUF_FMT_16_USCALED': 9, 'BUF_FMT_16_SSCALED': 10, 'BUF_FMT_16_UINT': 11, 'BUF_FMT_16_SINT': 12, - 'BUF_FMT_16_FLOAT': 13, 'BUF_FMT_8_8_UNORM': 14, 'BUF_FMT_8_8_SNORM': 15, 'BUF_FMT_8_8_USCALED': 16, - 'BUF_FMT_8_8_SSCALED': 17, 'BUF_FMT_8_8_UINT': 18, 'BUF_FMT_8_8_SINT': 19, 'BUF_FMT_32_UINT': 20, - 'BUF_FMT_32_SINT': 21, 'BUF_FMT_32_FLOAT': 22, 'BUF_FMT_16_16_UNORM': 23, 'BUF_FMT_16_16_SNORM': 24, - 'BUF_FMT_16_16_USCALED': 25, 'BUF_FMT_16_16_SSCALED': 26, 'BUF_FMT_16_16_UINT': 27, 'BUF_FMT_16_16_SINT': 28, - 'BUF_FMT_16_16_FLOAT': 29, 'BUF_FMT_10_11_11_FLOAT': 30, 'BUF_FMT_11_11_10_FLOAT': 31, - 'BUF_FMT_10_10_10_2_UNORM': 32, 'BUF_FMT_10_10_10_2_SNORM': 33, 'BUF_FMT_10_10_10_2_UINT': 34, - 'BUF_FMT_10_10_10_2_SINT': 35, 'BUF_FMT_2_10_10_10_UNORM': 36, 'BUF_FMT_2_10_10_10_SNORM': 37, - 'BUF_FMT_2_10_10_10_USCALED': 38, 'BUF_FMT_2_10_10_10_SSCALED': 39, 'BUF_FMT_2_10_10_10_UINT': 40, - 'BUF_FMT_2_10_10_10_SINT': 41, 'BUF_FMT_8_8_8_8_UNORM': 42, 'BUF_FMT_8_8_8_8_SNORM': 43, - 'BUF_FMT_8_8_8_8_USCALED': 44, 'BUF_FMT_8_8_8_8_SSCALED': 45, 'BUF_FMT_8_8_8_8_UINT': 46, - 'BUF_FMT_8_8_8_8_SINT': 47, 'BUF_FMT_32_32_UINT': 48, 'BUF_FMT_32_32_SINT': 49, 'BUF_FMT_32_32_FLOAT': 50, - 'BUF_FMT_16_16_16_16_UNORM': 51, 'BUF_FMT_16_16_16_16_SNORM': 52, 'BUF_FMT_16_16_16_16_USCALED': 53, - 'BUF_FMT_16_16_16_16_SSCALED': 54, 'BUF_FMT_16_16_16_16_UINT': 55, 'BUF_FMT_16_16_16_16_SINT': 56, - 'BUF_FMT_16_16_16_16_FLOAT': 57, 'BUF_FMT_32_32_32_UINT': 58, 'BUF_FMT_32_32_32_SINT': 59, - 'BUF_FMT_32_32_32_FLOAT': 60, 'BUF_FMT_32_32_32_32_UINT': 61, 'BUF_FMT_32_32_32_32_SINT': 62, - 'BUF_FMT_32_32_32_32_FLOAT': 63, 'BUF_FMT_8_FLOAT': 108} -def _parse_buf_fmt_combo(s: str) -> int: - parts = [p.strip().replace('BUF_DATA_FORMAT_', '').replace('BUF_NUM_FORMAT_', '') for p in s.split(',')] - return BUF_FMT.get(f'BUF_FMT_{parts[0]}_{parts[1]}') if len(parts) == 2 else None - -def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int: - return (expcnt & 0x7) | ((lgkmcnt & 0x3f) << 4) | ((vmcnt & 0x3f) << 10) - -# ═══════════════════════════════════════════════════════════════════════════════ -# ASSEMBLER -# ═══════════════════════════════════════════════════════════════════════════════ - -SPEC_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'vcc': RawImm(106), 'null': RawImm(124), 'off': RawImm(124), 'm0': RawImm(125), - 'exec_lo': RawImm(126), 'exec_hi': RawImm(127), 'exec': RawImm(126), 'scc': RawImm(253), 'src_scc': RawImm(253)} -FLOATS = {str(k): k for k in FLOAT_ENC} # Valid float literal strings: '0.5', '-0.5', '1.0', etc. -REG_MAP: dict[str, _RegFactory] = {'s': s, 'v': v, 't': ttmp, 'ttmp': ttmp} -SMEM_OPS = {'s_load_b32', 's_load_b64', 's_load_b96', 's_load_b128', 's_load_b256', 's_load_b512', - 's_load_i8', 's_load_u8', 's_load_i16', 's_load_u16', - 's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b96', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512', - 's_buffer_load_i8', 's_buffer_load_u8', 's_buffer_load_i16', 's_buffer_load_u16', - 's_atc_probe', 's_atc_probe_buffer'} -SPEC_DSL = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC_LO', 'null': 'NULL', 'off': 'OFF', 'm0': 'M0', - 'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC_LO', 'scc': 'SCC', 'src_scc': 'SCC'} - -def _op2dsl(op: str) -> str: - op = op.strip() - neg = op.startswith('-') and not (op[1:2].isdigit() or (len(op) > 2 and op[1] == '0' and op[2] in 'xX')) - if neg: op = op[1:] - if op.startswith('neg(') and op.endswith(')'): neg = True; op = op[4:-1] - abs_ = (op.startswith('|') and op.endswith('|')) or (op.startswith('abs(') and op.endswith(')')) - if abs_: op = op[1:-1] if op.startswith('|') else op[4:-1] - hi = ".h" if op.endswith('.h') else ".l" if op.endswith('.l') else "" - if hi: op = op[:-2] - lo = op.lower() - def wrap(b): return f"{'-' if neg else ''}abs({b}){hi}" if abs_ else f"-{b}{hi}" if neg else f"{b}{hi}" - if lo in SPEC_DSL: return wrap(SPEC_DSL[lo]) - if op in FLOATS: return wrap(op) - rp = {'s': 's', 'v': 'v', 't': 'ttmp', 'ttmp': 'ttmp'} - if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', lo): return wrap(f"{rp[m.group(1)]}[{m.group(2)}:{m.group(3)}]") - if m := re.match(r'^([svt](?:tmp)?)(\d+)$', lo): return wrap(f"{rp[m.group(1)]}[{m.group(2)}]") - if re.match(r'^-?\d+$|^-?0x[0-9a-fA-F]+$', op): return f"SrcMod({op}, neg={neg}, abs_={abs_})" if neg or abs_ else op - return wrap(op) - -def _parse_ops(s: str) -> list[str]: - ops, cur, depth, pipe = [], "", 0, False - for c in s: - if c in '[(': depth += 1 - elif c in '])': depth -= 1 - elif c == '|': pipe = not pipe - if c == ',' and depth == 0 and not pipe: ops.append(cur.strip()); cur = "" - else: cur += c - if cur.strip(): ops.append(cur.strip()) - return ops - -def _extract(text: str, pat: str, flags=re.I): - if m := re.search(pat, text, flags): return m, text[:m.start()] + ' ' + text[m.end():] - return None, text - -# Instruction aliases: LLVM uses different names for some instructions -_ALIASES = { - 'v_cmp_tru_f16': 'v_cmp_t_f16', 'v_cmp_tru_f32': 'v_cmp_t_f32', 'v_cmp_tru_f64': 'v_cmp_t_f64', - 'v_cmpx_tru_f16': 'v_cmpx_t_f16', 'v_cmpx_tru_f32': 'v_cmpx_t_f32', 'v_cmpx_tru_f64': 'v_cmpx_t_f64', - 'v_cvt_flr_i32_f32': 'v_cvt_floor_i32_f32', 'v_cvt_rpi_i32_f32': 'v_cvt_nearest_i32_f32', - 'v_ffbh_i32': 'v_cls_i32', 'v_ffbh_u32': 'v_clz_i32_u32', 'v_ffbl_b32': 'v_ctz_i32_b32', - 'v_cvt_pkrtz_f16_f32': 'v_cvt_pk_rtz_f16_f32', 'v_fmac_legacy_f32': 'v_fmac_dx9_zero_f32', 'v_mul_legacy_f32': 'v_mul_dx9_zero_f32', - 's_load_dword': 's_load_b32', 's_load_dwordx2': 's_load_b64', 's_load_dwordx4': 's_load_b128', - 's_load_dwordx8': 's_load_b256', 's_load_dwordx16': 's_load_b512', - 's_buffer_load_dword': 's_buffer_load_b32', 's_buffer_load_dwordx2': 's_buffer_load_b64', - 's_buffer_load_dwordx4': 's_buffer_load_b128', 's_buffer_load_dwordx8': 's_buffer_load_b256', - 's_buffer_load_dwordx16': 's_buffer_load_b512', - 'v_cvt_pknorm_i16_f16': 'v_cvt_pk_norm_i16_f16', 'v_cvt_pknorm_u16_f16': 'v_cvt_pk_norm_u16_f16', - 'v_add3_nc_u32': 'v_add3_u32', 'v_xor_add_u32': 'v_xad_u32', - 'v_interp_p2_new_f32': 'v_interp_p2_f32', - 's_ff1_i32_b32': 's_ctz_i32_b32', 's_ff1_i32_b64': 's_ctz_i32_b64', - 's_flbit_i32_b32': 's_clz_i32_u32', 's_flbit_i32_b64': 's_clz_i32_u64', 's_flbit_i32': 's_cls_i32', 's_flbit_i32_i64': 's_cls_i32_i64', - 's_andn1_saveexec_b32': 's_and_not0_saveexec_b32', 's_andn1_saveexec_b64': 's_and_not0_saveexec_b64', - 's_andn1_wrexec_b32': 's_and_not0_wrexec_b32', 's_andn1_wrexec_b64': 's_and_not0_wrexec_b64', - 's_andn2_saveexec_b32': 's_and_not1_saveexec_b32', 's_andn2_saveexec_b64': 's_and_not1_saveexec_b64', - 's_andn2_wrexec_b32': 's_and_not1_wrexec_b32', 's_andn2_wrexec_b64': 's_and_not1_wrexec_b64', - 's_orn1_saveexec_b32': 's_or_not0_saveexec_b32', 's_orn1_saveexec_b64': 's_or_not0_saveexec_b64', - 's_orn2_saveexec_b32': 's_or_not1_saveexec_b32', 's_orn2_saveexec_b64': 's_or_not1_saveexec_b64', - 's_andn2_b32': 's_and_not1_b32', 's_andn2_b64': 's_and_not1_b64', - 's_orn2_b32': 's_or_not1_b32', 's_orn2_b64': 's_or_not1_b64', - 'v_dot2c_f32_f16': 'v_dot2acc_f32_f16', - 'v_fma_legacy_f32': 'v_fma_dx9_zero_f32', - 'ds_read_b32': 'ds_load_b32', 'ds_read_b64': 'ds_load_b64', 'ds_read_b96': 'ds_load_b96', 'ds_read_b128': 'ds_load_b128', - 'ds_read_i8': 'ds_load_i8', 'ds_read_u8': 'ds_load_u8', 'ds_read_i16': 'ds_load_i16', 'ds_read_u16': 'ds_load_u16', - 'ds_read_i8_d16': 'ds_load_i8_d16', 'ds_read_u8_d16': 'ds_load_u8_d16', 'ds_read_i8_d16_hi': 'ds_load_i8_d16_hi', 'ds_read_u8_d16_hi': 'ds_load_u8_d16_hi', - 'ds_read_u16_d16': 'ds_load_u16_d16', 'ds_read_u16_d16_hi': 'ds_load_u16_d16_hi', - 'ds_read2_b32': 'ds_load_2addr_b32', 'ds_read2_b64': 'ds_load_2addr_b64', - 'ds_read2st64_b32': 'ds_load_2addr_stride64_b32', 'ds_read2st64_b64': 'ds_load_2addr_stride64_b64', - 'ds_read_addtid_b32': 'ds_load_addtid_b32', 'ds_write_addtid_b32': 'ds_store_addtid_b32', - 'ds_write_b32': 'ds_store_b32', 'ds_write_b64': 'ds_store_b64', 'ds_write_b96': 'ds_store_b96', 'ds_write_b128': 'ds_store_b128', - 'ds_write_b8': 'ds_store_b8', 'ds_write_b16': 'ds_store_b16', - 'ds_write_b8_d16_hi': 'ds_store_b8_d16_hi', 'ds_write_b16_d16_hi': 'ds_store_b16_d16_hi', - 'ds_write2_b32': 'ds_store_2addr_b32', 'ds_write2_b64': 'ds_store_2addr_b64', - 'ds_write2st64_b32': 'ds_store_2addr_stride64_b32', 'ds_write2st64_b64': 'ds_store_2addr_stride64_b64', - 'ds_wrxchg_rtn_b32': 'ds_storexchg_rtn_b32', 'ds_wrxchg_rtn_b64': 'ds_storexchg_rtn_b64', - 'ds_wrxchg2_rtn_b32': 'ds_storexchg_2addr_rtn_b32', 'ds_wrxchg2_rtn_b64': 'ds_storexchg_2addr_rtn_b64', - 'ds_wrxchg2st64_rtn_b32': 'ds_storexchg_2addr_stride64_rtn_b32', 'ds_wrxchg2st64_rtn_b64': 'ds_storexchg_2addr_stride64_rtn_b64', -} - -def _apply_alias(text: str) -> str: - mn = text.split()[0].lower() if ' ' in text else text.lower().rstrip('_') - for m in (mn, mn.removesuffix('_e32'), mn.removesuffix('_e64')): - if m in _ALIASES: return _ALIASES[m] + text[len(m):] - return text - -def _has(op: str, *subs) -> bool: return any(s in op for s in subs) - -def get_dsl(text: str, arch: str = "rdna3") -> str: - text, kw = _apply_alias(text.strip()), [] - # Extract modifiers - for pat, val in [(r'\s+mul:2(?:\s|$)', 1), (r'\s+mul:4(?:\s|$)', 2), (r'\s+div:2(?:\s|$)', 3)]: - if (m := _extract(text, pat))[0]: kw.append(f'omod={val}'); text = m[1]; break - clamp_found = False - if (m := _extract(text, r'\s+clamp(?:\s|$)'))[0]: clamp_found = True; text = m[1] - opsel, m, text = None, *_extract(text, r'\s+op_sel:\[([^\]]+)\]') - if m: - bits, mn = [int(x.strip()) for x in m.group(1).split(',')], text.split()[0].lower() - is3p = mn.startswith(('v_pk_', 'v_wmma_', 'v_dot', 'v_fma_mix')) - opsel = (bits[0] | (bits[1] << 1) | (bits[2] << 2)) if len(bits) == 3 and is3p else \ - (bits[0] | (bits[1] << 1) | (bits[2] << 3)) if len(bits) == 3 else sum(b << i for i, b in enumerate(bits)) - opsel_hi_val, m, text = None, *_extract(text, r'\s+op_sel_hi:\[([^\]]+)\]') - if m: opsel_hi_val = [int(x.strip()) for x in m.group(1).split(',')] - m, text = _extract(text, r'\s+wait_exp:(\d+)'); waitexp = m.group(1) if m else None - m, text = _extract(text, r'\s+offset:(0x[0-9a-fA-F]+|-?\d+)'); off_val = m.group(1) if m else None - m, text = _extract(text, r'\s+dlc(?:\s|$)'); dlc = 1 if m else None - m, text = _extract(text, r'\s+glc(?:\s|$)'); glc = 1 if m else None - m, text = _extract(text, r'\s+slc(?:\s|$)'); slc = 1 if m else None - m, text = _extract(text, r'\s+tfe(?:\s|$)'); tfe = 1 if m else None - m, text = _extract(text, r'\s+offen(?:\s|$)'); offen = 1 if m else None - m, text = _extract(text, r'\s+idxen(?:\s|$)'); idxen = 1 if m else None - m, text = _extract(text, r'\s+format:\[([^\]]+)\]'); fmt_val = m.group(1) if m else None - m, text = _extract(text, r'\s+format:(\d+)'); fmt_val = m.group(1) if m and not fmt_val else fmt_val - m, text = _extract(text, r'\s+neg_lo:\[([^\]]+)\]'); neg_lo = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None - m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None - m, text = _extract(text, r'\s+byte_sel:(\d+)'); byte_sel = int(m.group(1)) if m else None - m, text = _extract(text, r'\s+offset0:(\d+)'); ds_off0 = int(m.group(1)) if m else None - m, text = _extract(text, r'\s+offset1:(\d+)'); ds_off1 = int(m.group(1)) if m else None - m, text = _extract(text, r'\s+index_key:(\d+)'); index_key = int(m.group(1)) if m else None - if waitexp: kw.append(f'waitexp={waitexp}') - if byte_sel is not None: - if opsel is None: opsel = 0 - opsel |= (byte_sel << 2) - if ds_off0 is not None: kw.append(f'offset0={ds_off0}') - if ds_off1 is not None: kw.append(f'offset1={ds_off1}') - if index_key is not None: kw.append(f'opsel={index_key}') - - parts = text.replace(',', ' ').split() - if not parts: raise ValueError("empty instruction") - mn, op_str = parts[0].lower(), text[len(parts[0]):].strip() - ops, args = _parse_ops(op_str), [_op2dsl(o) for o in _parse_ops(op_str)] - - # s_waitcnt - if mn == 's_waitcnt': - vm, exp, lgkm = 0x3f, 0x7, 0x3f - for p in op_str.replace(',', ' ').split(): - if m := re.match(r'vmcnt\((\d+)\)', p): vm = int(m.group(1)) - elif m := re.match(r'expcnt\((\d+)\)', p): exp = int(m.group(1)) - elif m := re.match(r'lgkmcnt\((\d+)\)', p): lgkm = int(m.group(1)) - elif re.match(r'^0x[0-9a-f]+$|^\d+$', p): return f"s_waitcnt(simm16={int(p, 0)})" - return f"s_waitcnt(simm16={waitcnt(vm, exp, lgkm)})" - - # VOPD - if '::' in text: - xp, yp = text.split('::') - xps, yps = xp.strip().replace(',', ' ').split(), yp.strip().replace(',', ' ').split() - xo, yo = [_op2dsl(p) for p in xps[1:]], [_op2dsl(p) for p in yps[1:]] - vdx, sx0, vsx1 = xo[0], xo[1] if len(xo) > 1 else '0', xo[2] if len(xo) > 2 else 'v[0]' - vdy, sy0, vsy1 = yo[0], yo[1] if len(yo) > 1 else '0', yo[2] if len(yo) > 2 else 'v[0]' - lit = xo[3] if 'fmaak' in xps[0].lower() and len(xo) > 3 else yo[3] if 'fmaak' in yps[0].lower() and len(yo) > 3 else None - if 'fmamk' in xps[0].lower() and len(xo) > 3: lit, vsx1 = xo[2], xo[3] - elif 'fmamk' in yps[0].lower() and len(yo) > 3: lit, vsy1 = yo[2], yo[3] - return f"VOPD(VOPDOp.{xps[0].upper()}, VOPDOp.{yps[0].upper()}, vdstx={vdx}, vdsty={vdy}, srcx0={sx0}, vsrcx1={vsx1}, srcy0={sy0}, vsrcy1={vsy1}{f', literal={lit}' if lit else ''})" - - # Special instructions - if mn == 's_setreg_imm32_b32': raise ValueError(f"unsupported: {mn}") - sop1_no_dest = ('s_alloc_vgpr', 's_barrier_init', 's_barrier_join', 's_barrier_signal', 's_barrier_signal_isfirst', 's_sleep_var') - if mn in sop1_no_dest: - return f"{mn}(sdst=RawImm(128), ssrc0={args[0]})" - if mn in ('s_setpc_b64', 's_rfe_b64'): return f"{mn}(ssrc0={args[0]})" - if mn in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'): return f"{mn}(sdst={args[0]}, ssrc0=RawImm({args[1].strip()}))" - if mn == 's_version': return f"{mn}(simm16={args[0]})" - if mn == 's_setreg_b32': return f"{mn}(simm16={args[0]}, sdst={args[1]})" - - # Export instructions (RDNA4 VEXPORT) - if mn == 'export': - target_map = {**{f'mrt{i}': i for i in range(8)}, 'mrtz': 8, **{f'pos{i}': 12+i for i in range(4)}} - m, exp_str = _extract(op_str, r'\s+done(?:\s|$)') - done_val = 1 if m else 0 - exp_parts = exp_str.replace(',', ' ').split() - target_name = exp_parts[0].lower().strip() - target = target_map.get(target_name, 0) - vsrcs, en = [], 0 - for i, o in enumerate(exp_parts[1:5]): - o = o.strip().lower() - if o == 'off': vsrcs.append('v[0]') - else: vsrcs.append(_op2dsl(o)); en |= (1 << i) - return f"VEXPORT(target={target}, en={en}, vsrc0={vsrcs[0]}, vsrc1={vsrcs[1]}, vsrc2={vsrcs[2]}, vsrc3={vsrcs[3]}, done={done_val})" - - # SMEM - if mn in SMEM_OPS: - gs, ds = ", glc=1" if glc else "", ", dlc=1" if dlc else "" - off_field = "ioffset" if arch == "rdna4" else "offset" - th_s, scope_s, smem_str = "", "", op_str - if arch == "rdna4": - m, smem_str = _extract(op_str, r'\s+th:TH_(\w+)') - th_val = {'LOAD_RT': 0, 'LOAD_NT': 1, 'LOAD_HT': 2, 'LOAD_LU': 3, 'STORE_RT': 0, 'STORE_NT': 1, 'STORE_HT': 2, 'STORE_LU': 3}.get(m.group(1), 0) if m else None - m, smem_str = _extract(smem_str, r'\s+scope:SCOPE_(\w+)') - scope_val = {'CU': 0, 'SE': 1, 'DEV': 2, 'SYS': 3}.get(m.group(1), 0) if m else None - if scope_val is None: - m, smem_str = _extract(smem_str, r'\s+scope:(0?x?[0-9a-fA-F]+)') - scope_val = int(m.group(1), 0) if m else None - th_s = f", th={th_val}" if th_val else "" - scope_s = f", scope={scope_val}" if scope_val else "" - smem_ops = _parse_ops(smem_str) - smem_args = [_op2dsl(o) for o in smem_ops] - if len(smem_ops) >= 3 and re.match(r'^-?[0-9]|^-?0x', smem_ops[2].strip().lower()): - return f"{mn}(sdata={smem_args[0]}, sbase={smem_args[1]}, {off_field}={smem_ops[2].strip()}, soffset=RawImm(124){gs}{ds}{th_s}{scope_s})" - if off_val and len(smem_ops) >= 3: return f"{mn}(sdata={smem_args[0]}, sbase={smem_args[1]}, {off_field}={off_val}, soffset={smem_args[2]}{gs}{ds}{th_s}{scope_s})" - if len(smem_ops) >= 3: return f"{mn}(sdata={smem_args[0]}, sbase={smem_args[1]}, soffset={smem_args[2]}{gs}{ds}{th_s}{scope_s})" - - # Buffer (MUBUF/MTBUF/VBUFFER) instructions - if mn.startswith(('buffer_', 'tbuffer_')): - is_tbuf = mn.startswith('tbuffer_') - fmt_num = None - if fmt_val is not None: - if fmt_val.isdigit(): fmt_num = int(fmt_val) - else: fmt_num = BUF_FMT.get(fmt_val.replace(' ', '')) or _parse_buf_fmt_combo(fmt_val) - if mn in ('buffer_gl0_inv', 'buffer_gl1_inv', 'buffer_wbl2', 'buffer_inv'): return f"{mn}()" - if arch == "rdna4": - m, buf_text = _extract(op_str, r'\s+th:TH_(\w+)') - th_val = {'LOAD_RT': 0, 'LOAD_NT': 1, 'LOAD_HT': 2, 'LOAD_BYPASS': 3, 'LOAD_LU': 4, 'LOAD_RT_NT': 5, 'LOAD_NT_HT': 6, 'LOAD_RT_WB': 7, - 'STORE_RT': 0, 'STORE_NT': 1, 'STORE_HT': 2, 'STORE_BYPASS': 3, 'STORE_LU': 4, 'STORE_RT_NT': 5, 'STORE_NT_HT': 6, - 'ATOMIC_RT': 0, 'ATOMIC_NT': 1, 'ATOMIC_RETURN': 1, 'ATOMIC_RT_RETURN': 1, 'ATOMIC_NT_RETURN': 3, 'ATOMIC_CASCADE_RT': 6, 'ATOMIC_CASCADE_NT': 6}.get(m.group(1), 0) if m else 0 - m, buf_text = _extract(buf_text, r'\s+scope:SCOPE_(\w+)') - scope_val = {'CU': 0, 'SE': 1, 'DEV': 2, 'SYS': 3}.get(m.group(1), 0) if m else 0 - buf_ops = _parse_ops(buf_text) - buf_args = [_op2dsl(o) for o in buf_ops] - vbuf_mods = "".join([f", ioffset={off_val}" if off_val else "", ", offen=1" if offen else "", ", idxen=1" if idxen else "", - f", th={th_val}" if th_val else "", f", scope={scope_val}" if scope_val else "", - ", tfe=1" if tfe else ""]) - if is_tbuf and fmt_num is not None: vbuf_mods = f", format={fmt_num}" + vbuf_mods - elif is_tbuf: vbuf_mods = ", format=1" + vbuf_mods - else: vbuf_mods = ", format=1" + vbuf_mods - vaddr_idx = 1 - if len(buf_ops) > vaddr_idx and buf_ops[vaddr_idx].strip().lower() == 'off': vaddr_val = "v[0]" - else: vaddr_val = buf_args[vaddr_idx] if len(buf_args) > vaddr_idx else "v[0]" - rsrc_idx, soff_idx = (2, 3) if len(buf_ops) > 1 else (1, 2) - rsrc_raw = buf_ops[rsrc_idx].strip() if len(buf_ops) > rsrc_idx else "s[0:3]" - if m := re.match(r's\[(\d+):\d+\]', rsrc_raw.lower()): rsrc_val = m.group(1) - elif m := re.match(r's(\d+)', rsrc_raw.lower()): rsrc_val = m.group(1) - elif m := re.match(r'ttmp\[(\d+):\d+\]', rsrc_raw.lower()): rsrc_val = str(108 + int(m.group(1))) - elif m := re.match(r'ttmp(\d+)', rsrc_raw.lower()): rsrc_val = str(108 + int(m.group(1))) - else: rsrc_val = "0" - soff_raw = buf_ops[soff_idx].strip() if len(buf_ops) > soff_idx else "0" - soff_lower = soff_raw.lower() - if soff_lower == 'm0': soff_val = "RawImm(125)" - elif soff_lower in ('null', 'off'): soff_val = "RawImm(124)" - elif m := re.match(r's(\d+)', soff_lower): soff_val = f"RawImm({m.group(1)})" - else: soff_val = f"RawImm({soff_raw})" - return f"{mn}(vdata={buf_args[0]}, vaddr={vaddr_val}, rsrc={rsrc_val}, soffset={soff_val}{vbuf_mods})" - buf_mods = "".join([f", offset={off_val}" if off_val else "", ", glc=1" if glc else "", ", dlc=1" if dlc else "", - ", slc=1" if slc else "", ", tfe=1" if tfe else "", ", offen=1" if offen else "", ", idxen=1" if idxen else ""]) - if is_tbuf and fmt_num is not None: buf_mods = f", format={fmt_num}" + buf_mods - vaddr_idx = 1 - if len(ops) > vaddr_idx and ops[vaddr_idx].strip().lower() == 'off': vaddr_val = "v[0]" - else: vaddr_val = args[vaddr_idx] if len(args) > vaddr_idx else "v[0]" - srsrc_idx, soff_idx = (2, 3) if len(ops) > 1 else (1, 2) - srsrc_val = args[srsrc_idx] if len(args) > srsrc_idx else "s[0:3]" - soff_val = args[soff_idx] if len(args) > soff_idx else "0" - return f"{mn}(vdata={args[0]}, vaddr={vaddr_val}, srsrc={srsrc_val}, soffset={soff_val}{buf_mods})" - - # FLAT/GLOBAL/SCRATCH load/store/atomic - def _saddr(a): return 'RawImm(124)' if a in ('OFF', 'NULL') else a - flat_mods = f"{f', offset={off_val}' if off_val else ''}{', glc=1' if glc else ''}{', slc=1' if slc else ''}{', dlc=1' if dlc else ''}" - for pre, flds in [('flat_load','vdst,addr,saddr'), ('global_load','vdst,addr,saddr'), ('scratch_load','vdst,addr,saddr'), - ('flat_store','addr,data,saddr'), ('global_store','addr,data,saddr'), ('scratch_store','addr,data,saddr')]: - if mn.startswith(pre) and len(args) >= 2: - f0, f1, f2 = flds.split(',') - return f"{mn}({f0}={args[0]}, {f1}={args[1]}{f', {f2}={_saddr(args[2])}' if len(args) >= 3 else ', saddr=RawImm(124)'}{flat_mods})" - for pre in ('flat_atomic', 'global_atomic', 'scratch_atomic'): - if mn.startswith(pre): - if glc and len(args) >= 3: return f"{mn}(vdst={args[0]}, addr={args[1]}, data={args[2]}{f', saddr={_saddr(args[3])}' if len(args) >= 4 else ', saddr=RawImm(124)'}{flat_mods})" - if len(args) >= 2: return f"{mn}(addr={args[0]}, data={args[1]}{f', saddr={_saddr(args[2])}' if len(args) >= 3 else ', saddr=RawImm(124)'}{flat_mods})" - - # DS instructions - if mn.startswith('ds_'): - if ds_off0 is not None or ds_off1 is not None: - off0, off1 = str(ds_off0 or 0), str(ds_off1 or 0) - elif off_val: - off0, off1 = str(int(off_val, 0) & 0xff), str((int(off_val, 0) >> 8) & 0xff) - else: - off0, off1 = "0", "0" - gds_s = ", gds=1" if 'gds' in text.lower().split()[-1:] else "" - off_kw = f", offset0={off0}, offset1={off1}{gds_s}" - if mn == 'ds_nop' or mn in ('ds_gws_sema_v', 'ds_gws_sema_p', 'ds_gws_sema_release_all'): return f"{mn}({off_kw.lstrip(', ')})" - if 'gws_' in mn: return f"{mn}(addr={args[0]}{off_kw})" - if 'consume' in mn or 'append' in mn: return f"{mn}(vdst={args[0]}{off_kw})" - if 'gs_reg' in mn: return f"{mn}(vdst={args[0]}, data0={args[1]}{off_kw})" - if '2addr' in mn: - if 'load' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})" - if 'store' in mn and 'xchg' not in mn: return f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})" - return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" - if 'load' in mn: return f"{mn}(vdst={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})" - if 'store' in mn and not _has(mn, 'cmp', 'xchg'): - return f"{mn}(data0={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})" - if 'swizzle' in mn or 'ordered_count' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})" - if 'permute' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" - if 'bvh' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" - if 'condxchg' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" - if _has(mn, 'cmpstore', 'mskor', 'wrap'): - return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})" - return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})" - - # v_fmaak/v_fmamk literal extraction - lit_s = "" - mn_base = mn.replace('_e32', '').replace('_e64', '') - if mn_base in ('v_fmaak_f32', 'v_fmaak_f16') and len(args) == 4: lit_s, args = f", literal={args[3].strip()}", args[:3] - elif mn_base in ('v_fmamk_f32', 'v_fmamk_f16') and len(args) == 4: lit_s, args = f", literal={args[2].strip()}", [args[0], args[1], args[3]] - elif mn_base in ('s_fmaak_f32',) and len(args) == 4: lit_s, args = f", literal={args[3].strip()}", args[:3] - elif mn_base in ('s_fmamk_f32',) and len(args) == 4: lit_s, args = f", literal={args[2].strip()}", [args[0], args[1], args[3]] - elif mn in ('v_cndmask_b32', 'v_cndmask_b32_e32') and len(args) == 4 and ops[3].strip().lower() in ('vcc_lo', 'vcc'): - mn, args = 'v_cndmask_b32_e32', args[:3] - - _SGPR_NAMES = {'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'null': 124, 'm0': 125, 'exec_lo': 126, 'exec_hi': 127} - vcc_ops = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'} - if mn.replace('_e32', '') in vcc_ops and len(args) >= 5: - carry_in = ops[4].strip().lower() if len(ops) > 4 else 'vcc_lo' - carry_out = ops[1].strip().lower() if len(ops) > 1 else 'vcc_lo' - if carry_in in ('vcc_lo', 'vcc') and carry_out in ('vcc_lo', 'vcc'): - mn, args = mn.replace('_e32', '') + '_e32', [args[0], args[2], args[3]] - else: - mn_base = mn.replace('_e32', '').replace('_e64', '') - sdst = _SGPR_NAMES.get(carry_out, 124) if carry_out in _SGPR_NAMES else (int(carry_out[1:]) if carry_out.startswith('s') and carry_out[1:].isdigit() else 124) - src2 = _SGPR_NAMES.get(carry_in, 0) if carry_in in _SGPR_NAMES else (int(carry_in[1:]) if carry_in.startswith('s') and carry_in[1:].isdigit() else 0) - return f"{mn_base}(vdst={args[0]}, sdst=RawImm({sdst}), src0={args[2]}, src1={args[3]}, src2=RawImm({src2}))" - if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '') - if mn.startswith('v_cmp') and not mn.endswith('_e64') and len(args) >= 3 and ops[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): args = args[1:] - if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2: args = ['RawImm(126)'] + args - if ((mn.startswith('v_cmp') and 'cmpx' not in mn and mn.endswith('_e64')) or mn.startswith('v_s_') or mn in ('v_readlane_b32', 'v_readfirstlane_b32')) and len(args) >= 1: - dst = ops[0].strip().lower() - if dst.startswith('s') and dst[1:].isdigit(): args[0] = f'RawImm({int(dst[1:])})' - elif dst.startswith('s[') and ':' in dst: args[0] = f'RawImm({int(dst[2:].split(":")[0])})' - elif dst.startswith('ttmp') and dst[4:].isdigit(): args[0] = f'RawImm({108 + int(dst[4:])})' - elif dst.startswith('ttmp[') and ':' in dst: args[0] = f'RawImm({108 + int(dst[5:].split(":")[0])})' - elif dst in _SGPR_NAMES: args[0] = f'RawImm({_SGPR_NAMES[dst]})' - - fn = mn.replace('.', '_') - if opsel is not None: args = [re.sub(r'\.[hl]$', '', a) for a in args] - - # v_fma_mix*: extract inline neg/abs modifiers - if 'fma_mix' in mn and neg_lo is None and neg_hi is None: - inline_neg, inline_abs, clean_args = 0, 0, [args[0]] - for i, op in enumerate(ops[1:4]): - op = op.strip() - neg = op.startswith('-') and not (op[1:2].isdigit() or (len(op) > 2 and op[1] == '0' and op[2] in 'xX')) - if neg: op = op[1:] - abs_ = op.startswith('|') and op.endswith('|') - if abs_: op = op[1:-1] - if neg: inline_neg |= (1 << i) - if abs_: inline_abs |= (1 << i) - clean_args.append(_op2dsl(op)) - args = clean_args + args[4:] - if inline_neg: neg_lo = inline_neg - if inline_abs: neg_hi = inline_abs - - all_kw = list(kw) - if lit_s: all_kw.append(lit_s.lstrip(', ')) - if opsel is not None: all_kw.append(f'opsel={opsel}') - if neg_lo is not None: all_kw.append(f'neg={neg_lo}') - if neg_hi is not None: all_kw.append(f'neg_hi={neg_hi}') - if 'bvh' in mn and 'intersect_ray' in mn: all_kw.extend(['dmask=15', 'unrm=1', 'r128=1']) - vop3p_ops = {'v_pk_', 'v_dot2', 'v_dot4', 'v_dot8', 'v_wmma', 'v_swmmac'} - is_vop3p = any(mn.startswith(p) for p in vop3p_ops) - is_fma_mix = 'fma_mix' in mn - if opsel_hi_val is not None: - opsel_hi_enc = opsel_hi_val[0] | (opsel_hi_val[1] << 1) if len(opsel_hi_val) >= 2 else opsel_hi_val[0] - opsel_hi2_enc = opsel_hi_val[2] if len(opsel_hi_val) >= 3 else (0 if is_fma_mix else 1) - all_kw.extend([f'opsel_hi={opsel_hi_enc}', f'opsel_hi2={opsel_hi2_enc}']) - elif is_vop3p and not is_fma_mix: - all_kw.extend(['opsel_hi=3', 'opsel_hi2=1']) - if clamp_found: - if arch == 'rdna4': all_kw.append('cm=1') - else: all_kw.append('clmp=1') - - a_str, kw_str = ', '.join(args), ', '.join(all_kw) - return f"{fn}({a_str}, {kw_str})" if kw_str and a_str else f"{fn}({kw_str})" if kw_str else f"{fn}({a_str})" - -def _hwreg(id_, offset=0, size=32): return id_ | (offset << 6) | ((size - 1) << 11) -def _sendmsg(id_, op=0, stream=0): return id_ | (op << 4) | (stream << 8) - -_HWREG_NAMES = {'HW_REG_MODE': 1, 'HW_REG_STATUS': 2, 'HW_REG_TRAPSTS': 3, 'HW_REG_HW_ID': 4, 'HW_REG_GPR_ALLOC': 5, - 'HW_REG_LDS_ALLOC': 6, 'HW_REG_IB_STS': 7, 'HW_REG_PC_LO': 8, 'HW_REG_PC_HI': 9, 'HW_REG_INST_DW0': 10, 'HW_REG_INST_DW1': 11, - 'HW_REG_IB_DBG0': 12, 'HW_REG_IB_DBG1': 13, 'HW_REG_FLUSH_IB': 14, 'HW_REG_SH_MEM_BASES': 15, 'HW_REG_SQ_SHADER_TBA_LO': 16, - 'HW_REG_SQ_SHADER_TBA_HI': 17, 'HW_REG_SQ_SHADER_TMA_LO': 18, 'HW_REG_SQ_SHADER_TMA_HI': 19, 'HW_REG_FLAT_SCR_LO': 20, - 'HW_REG_FLAT_SCR_HI': 21, 'HW_REG_XNACK_MASK': 22, 'HW_REG_HW_ID1': 23, 'HW_REG_HW_ID2': 24, 'HW_REG_POPS_PACKER': 25, - 'HW_REG_PERF_SNAPSHOT_DATA': 26, 'HW_REG_PERF_SNAPSHOT_PC_LO': 27, 'HW_REG_PERF_SNAPSHOT_PC_HI': 28, 'HW_REG_SHADER_CYCLES': 29, - 'HW_REG_SHADER_CYCLES_HI': 30, 'HW_REG_WAVE_MODE': 31, 'HW_REG_WAVE_SCRATCH_BASE': 32} -_HWREG_NAMES_RDNA4 = {v: k for k, v in HWREG_RDNA4.items()} -_SENDMSG_NAMES = {'MSG_INTERRUPT': 1, 'MSG_GS': 2, 'MSG_GS_DONE': 3, 'MSG_SAVEWAVE': 4, 'MSG_STALL_WAVE_GEN': 5, - 'MSG_HALT_WAVES': 6, 'MSG_ORDERED_PS_DONE': 7, 'MSG_EARLY_PRIM_DEALLOC': 8, 'MSG_GS_ALLOC_REQ': 9, 'MSG_GET_DOORBELL': 10, - 'MSG_GET_DDID': 11, 'MSG_HS_TESSFACTOR': 2, 'MSG_DEALLOC_VGPRS': 10, 'MSG_RTN_GET_DOORBELL': 128, 'MSG_RTN_GET_DDID': 129, - 'MSG_RTN_GET_TMA': 130, 'MSG_RTN_GET_REALTIME': 131, 'MSG_RTN_SAVE_WAVE': 132, 'MSG_RTN_GET_TBA': 133, - 'MSG_RTN_GET_TBA_TO_PC': 134, 'MSG_RTN_GET_SE_AID_ID': 135} - -def asm(text: str, arch: str = "rdna3") -> Inst: - dsl = get_dsl(text, arch) - if arch == "rdna4": - ns = {n: getattr(rdna4_ins, n) for n in dir(rdna4_ins) if not n.startswith('_')} - hwreg_names = _HWREG_NAMES_RDNA4 - else: - ns = {n: getattr(ins, n) for n in dir(ins) if not n.startswith('_')} - hwreg_names = _HWREG_NAMES - def hwreg(id_, offset=0, size=32): return _hwreg(hwreg_names.get(id_, id_) if isinstance(id_, str) else id_, offset, size) - def sendmsg(id_, op=0, stream=0): return _sendmsg(_SENDMSG_NAMES.get(id_, id_) if isinstance(id_, str) else id_, op, stream) - ns.update({'s': s, 'v': v, 'ttmp': ttmp, 'abs': abs, 'RawImm': RawImm, 'SrcMod': SrcMod, 'VGPR': VGPR, 'SGPR': SGPR, 'TTMP': TTMP, - 'VCC_LO': VCC_LO, 'VCC_HI': VCC_HI, 'VCC': VCC, 'EXEC_LO': EXEC_LO, 'EXEC_HI': EXEC_HI, 'EXEC': EXEC, 'SCC': SCC, 'M0': M0, 'NULL': NULL, 'OFF': OFF, - 'hwreg': hwreg, 'sendmsg': sendmsg, **{k: k for k in hwreg_names}, **{k: k for k in _SENDMSG_NAMES}}) - try: return eval(dsl, ns) - except NameError: - if m := re.match(r'^(v_\w+)(\(.*\))$', dsl): return eval(f"{m.group(1)}_e32{m.group(2)}", ns) - raise diff --git a/extra/assembly/amd/decode.py b/extra/assembly/amd/decode.py index 0a9c9b468a..a2efe35433 100644 --- a/extra/assembly/amd/decode.py +++ b/extra/assembly/amd/decode.py @@ -2,14 +2,13 @@ from __future__ import annotations from extra.assembly.amd.dsl import Inst, FixedBitField from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP1_SDST, VOP2, VOP3, VOP3_SDST, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, - SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP) + SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT) from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP1_SDST as R4_VOP1_SDST, VOP2 as R4_VOP2, VOP3 as R4_VOP3, VOP3_SDST as R4_VOP3_SDST, VOP3SD as R4_VOP3SD, VOP3P as R4_VOP3P, VOPC as R4_VOPC, VOPD as R4_VOPD, VINTERP as R4_VINTERP, SOP1 as R4_SOP1, SOP2 as R4_SOP2, SOPC as R4_SOPC, SOPK as R4_SOPK, SOPP as R4_SOPP, - SMEM as R4_SMEM, DS as R4_DS, VBUFFER as R4_VBUFFER, VEXPORT as R4_VEXPORT) + SMEM as R4_SMEM, DS as R4_DS) from extra.assembly.amd.autogen.cdna.ins import (VOP1 as C_VOP1, VOP2 as C_VOP2, VOPC as C_VOPC, VOP3A, VOP3B, VOP3P as C_VOP3P, - SOP1 as C_SOP1, SOP2 as C_SOP2, SOPC as C_SOPC, SOPK as C_SOPK, SOPP as C_SOPP, SMEM as C_SMEM, DS as C_DS, - FLAT as C_FLAT, MUBUF as C_MUBUF, MTBUF as C_MTBUF) + SOP1 as C_SOP1, SOP2 as C_SOP2, SOPC as C_SOPC, SOPK as C_SOPK, SOPP as C_SOPP, SMEM as C_SMEM, DS as C_DS, FLAT as C_FLAT) def _matches_encoding(word: int, cls: type[Inst]) -> bool: """Check if word matches the encoding pattern of an instruction class.""" @@ -19,12 +18,12 @@ def _matches_encoding(word: int, cls: type[Inst]) -> bool: return ((word >> bf.lo) & bf.mask) == bf.default # Order matters: more specific encodings first, VOP2 last (it's a catch-all for bit31=0) -_RDNA_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP] +_RDNA_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, SMEM] _RDNA_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2] # SOP2/VOP2 are catch-alls -_CDNA_FORMATS_64 = [C_VOP3P, VOP3A, C_DS, C_FLAT, C_MUBUF, C_MTBUF, C_SMEM] +_CDNA_FORMATS_64 = [C_VOP3P, VOP3A, C_DS, C_FLAT, C_SMEM] _CDNA_FORMATS_32 = [C_SOP1, C_SOPC, C_SOPP, C_SOPK, C_VOPC, C_VOP1, C_SOP2, C_VOP2] _CDNA_VOP3B_OPS = {281, 282, 283, 284, 285, 286, 480, 481, 488, 489} # VOP3B opcodes -_RDNA4_FORMATS_64 = [R4_VOPD, R4_VOP3P, R4_VINTERP, R4_VOP3, R4_DS, R4_VBUFFER, R4_SMEM, R4_VEXPORT] +_RDNA4_FORMATS_64 = [R4_VOPD, R4_VOP3P, R4_VINTERP, R4_VOP3, R4_DS, R4_SMEM] _RDNA4_FORMATS_32 = [R4_SOP1, R4_SOPC, R4_SOPP, R4_SOPK, R4_VOPC, R4_VOP1, R4_SOP2, R4_VOP2] _RDNA4_VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770} _RDNA3_VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770} diff --git a/extra/assembly/amd/disasm.py b/extra/assembly/amd/disasm.py index ac057d03b7..760d1b7ef9 100644 --- a/extra/assembly/amd/disasm.py +++ b/extra/assembly/amd/disasm.py @@ -84,12 +84,12 @@ def _swmmac_regs(name: str) -> tuple[int, int, int, int]: # IMPORTS # ═══════════════════════════════════════════════════════════════════════════════ -from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP1_SDST, VOP2, VOP3, VOP3_SDST, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, GLOBAL, SCRATCH, MUBUF, MTBUF, MIMG, EXP, - VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp) +from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP1_SDST, VOP2, VOP3, VOP3_SDST, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, GLOBAL, SCRATCH, + VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp) from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP1_SDST as R4_VOP1_SDST, VOP2 as R4_VOP2, VOP3 as R4_VOP3, VOP3_SDST as R4_VOP3_SDST, VOP3SD as R4_VOP3SD, VOP3P as R4_VOP3P, VOPC as R4_VOPC, VOPD as R4_VOPD, VINTERP as R4_VINTERP, SOP1 as R4_SOP1, SOP2 as R4_SOP2, SOPC as R4_SOPC, SOPK as R4_SOPK, SOPP as R4_SOPP, - SMEM as R4_SMEM, DS as R4_DS, VBUFFER as R4_VBUFFER, VEXPORT as R4_VEXPORT, VOPDOp as R4_VOPDOp) -from extra.assembly.amd.autogen.cdna.ins import FLAT as C_FLAT, MUBUF as C_MUBUF, MTBUF as C_MTBUF + SMEM as R4_SMEM, DS as R4_DS, VOPDOp as R4_VOPDOp) +from extra.assembly.amd.autogen.cdna.ins import FLAT as C_FLAT def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__ @@ -491,72 +491,6 @@ def _disasm_vop3p(inst: VOP3P) -> str: ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if clamp else []) return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}" -def _disasm_buf(inst: MUBUF | MTBUF) -> str: - name, cdna = inst.op_name.lower(), _is_cdna(inst) - acc = getattr(inst, 'acc', 0) - reg_fn = _areg if acc else _vreg - if cdna and name in ('buffer_wbl2', 'buffer_inv'): return name - if not cdna and inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name - w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \ - ((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \ - {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], 1) - if hasattr(inst, 'tfe') and inst.tfe: w += 1 - vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else _vreg(inst.vaddr) if inst.offen or inst.idxen else "off" - srsrc = _sreg_or_ttmp(_unwrap(inst.srsrc), 4) - is_mtbuf = isinstance(inst, MTBUF) or isinstance(inst, C_MTBUF) - if is_mtbuf: - dfmt, nfmt = inst.format & 0xf, (inst.format >> 4) & 0x7 - if acc: fmt_s = f" dfmt:{dfmt}, nfmt:{nfmt}," - elif not cdna: fmt_s = f" format:{inst.format}" if inst.format else "" - else: - dfmt_names = ['INVALID', '8', '16', '8_8', '32', '16_16', '10_11_11', '11_11_10', '10_10_10_2', '2_10_10_10', '8_8_8_8', '32_32', '16_16_16_16', '32_32_32', '32_32_32_32', 'RESERVED_15'] - nfmt_names = ['UNORM', 'SNORM', 'USCALED', 'SSCALED', 'UINT', 'SINT', 'RESERVED_6', 'FLOAT'] - if dfmt == 1 and nfmt == 0: fmt_s = "" - elif nfmt == 0: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]}]" - elif dfmt == 1: fmt_s = f" format:[BUF_NUM_FORMAT_{nfmt_names[nfmt]}]" - else: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]},BUF_NUM_FORMAT_{nfmt_names[nfmt]}]" - else: fmt_s = "" - if cdna: mods = [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"glc"),(inst.nt,"slc"),(inst.sc1,"sc1")] if c] - else: mods = [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c] - soffset_s = decode_src(inst.soffset, cdna) - if cdna and not acc and is_mtbuf: return f"{name} {reg_fn(inst.vdata, w)}, {vaddr}, {srsrc}, {soffset_s}{fmt_s}{' ' + ' '.join(mods) if mods else ''}" - return f"{name} {reg_fn(inst.vdata, w)}, {vaddr}, {srsrc},{fmt_s} {soffset_s}{' ' + ' '.join(mods) if mods else ''}" - -def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int: - base = [1, 2, 3, 3, 2, 3, 3, 4][dim] - grad = [1, 2, 3, 2, 1, 2, 2, 2][dim] - if 'get_resinfo' in name: return 1 - packed, unpacked = 0, 0 - if '_mip' in name: packed += 1 - elif 'sample' in name or 'gather' in name: - if '_o' in name: unpacked += 1 - if re.search(r'_c(_|$)', name): unpacked += 1 - if '_d' in name: unpacked += (grad + 1) & ~1 if '_g16' in name else grad*2 - if '_b' in name: unpacked += 1 - if '_l' in name and '_cl' not in name and '_lz' not in name: packed += 1 - if '_cl' in name: packed += 1 - return (base + packed + 1) // 2 + unpacked if a16 else base + packed + unpacked - -def _disasm_mimg(inst: MIMG) -> str: - name = inst.op_name.lower() - srsrc_str = inst.srsrc.fmt() - if 'bvh' in name: - vaddr = (9 if '64' in name else 8) if inst.a16 else (12 if '64' in name else 11) - return f"{name} {inst.vdata.fmt(sz=4)}, {inst.vaddr.fmt(sz=vaddr)}, {inst.srsrc.fmt(sz=4)}{' a16' if inst.a16 else ''}" - vdata = 4 if 'gather4' in name or 'msaa_load' in name else (bin(inst.dmask).count('1') or 1) - if inst.d16: vdata = (vdata + 1) // 2 - if inst.tfe: vdata += 1 - dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array'] - dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"dim_{inst.dim}" - vaddr = _mimg_vaddr_width(name, inst.dim, inst.a16) - mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask and (inst.dmask != 15 or 'atomic' in name) else [] - mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}") - for flag, mod in [(inst.unrm,"unorm"),(inst.glc,"glc"),(inst.slc,"slc"),(inst.dlc,"dlc"),(inst.r128,"r128"), - (inst.a16,"a16"),(inst.tfe,"tfe"),(inst.lwe,"lwe"),(inst.d16,"d16")]: - if flag: mods.append(mod) - ssamp_str = f", {inst.ssamp.fmt()}" if 'sample' in name or 'gather' in name or 'get_lod' in name else "" - return f"{name} {inst.vdata.fmt(sz=vdata)}, {inst.vaddr.fmt(sz=vaddr)}, {srsrc_str}{ssamp_str} {' '.join(mods)}" - def _disasm_sop1(inst: SOP1) -> str: op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst) # Use get_field_bits for register sizes @@ -618,46 +552,14 @@ def _disasm_vinterp(inst: VINTERP) -> str: mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp")) return f"{inst.op_name.lower()} {inst.vdst.fmt()}, {_lit(inst, inst.src0, inst.neg & 1)}, {_lit(inst, inst.src1, inst.neg & 2)}, {_lit(inst, inst.src2, inst.neg & 4)}" + (" " + mods if mods else "") -EXP_TARGETS = {0: 'mrt0', 1: 'mrt1', 2: 'mrt2', 3: 'mrt3', 4: 'mrt4', 5: 'mrt5', 6: 'mrt6', 7: 'mrt7', - 8: 'mrtz', 9: 'null', 12: 'pos0', 13: 'pos1', 14: 'pos2', 15: 'pos3', 16: 'pos4', - 32: 'param0', 33: 'param1', 34: 'param2', 35: 'param3', 36: 'param4', 37: 'param5'} -def _disasm_vexport(inst) -> str: - tgt = EXP_TARGETS.get(inst.target, f'{inst.target}') - srcs = [getattr(inst, f"vsrc{i}").fmt() if inst.en & (1 << i) else 'off' for i in range(4)] - mods = _mods((inst.done, "done"), (inst.row, "row_en")) - return f"export {tgt} {', '.join(srcs)}" + (" " + mods if mods else "") - -def _disasm_vbuffer(inst) -> str: - name = inst.op_name.lower().replace('buffer_', 'buffer_').replace('tbuffer_', 'tbuffer_') - w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \ - ((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \ - {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], inst.canonical_op_regs['d']) - if getattr(inst, 'tfe', 0): w += 1 - vdata = inst.vdata.fmt(sz=w) - vaddr = inst.vaddr.fmt(sz=2) if inst.offen and inst.idxen else (inst.vaddr.fmt() if inst.offen or inst.idxen else 'off') - srsrc = f'ttmp[{inst.rsrc - 108}:{inst.rsrc - 108 + 3}]' if inst.rsrc >= 108 else f's[{inst.rsrc}:{inst.rsrc + 3}]' - soff = decode_src(inst.soffset.offset) if inst.soffset.offset >= 106 else f's{inst.soffset.offset}' - fmt = getattr(inst, 'format', 0) - from extra.assembly.amd.asm import BUF_FMT - fmt_names = {v: k for k, v in BUF_FMT.items()} - fmt_s = f" format:[{fmt_names[fmt]}]" if fmt > 1 and fmt in fmt_names else (f" format:{fmt}" if fmt > 1 else "") - if 'atomic' in name: th_names = {1: 'TH_ATOMIC_RETURN', 6: 'TH_ATOMIC_CASCADE_NT'} - elif 'store' in name: th_names = {3: 'TH_STORE_BYPASS', 6: 'TH_STORE_NT_HT'} - else: th_names = {3: 'TH_LOAD_BYPASS', 6: 'TH_LOAD_NT_HT'} - scope_names = {1: 'SCOPE_SE', 2: 'SCOPE_DEV', 3: 'SCOPE_SYS'} - mods = _mods((inst.idxen, "idxen"), (inst.offen, "offen"), (inst.ioffset, f"offset:{inst.ioffset}"), - (inst.th in th_names, f"th:{th_names.get(inst.th, '')}"), (inst.scope in scope_names, f"scope:{scope_names.get(inst.scope, '')}")) - return f"{name} {vdata}, {vaddr}, {srsrc}, {soff}{fmt_s}" + (" " + mods if mods else "") - DISASM_HANDLERS: dict[type, callable] = { VOP1: _disasm_vop1, VOP1_SDST: _disasm_vop1, VOP2: _disasm_vop2, VOPC: _disasm_vopc, VOP3: _disasm_vop3, VOP3_SDST: _disasm_vop3, VOP3SD: _disasm_vop3sd, VOPD: _disasm_vopd, VOP3P: _disasm_vop3p, VINTERP: _disasm_vinterp, SOPP: _disasm_sopp, SMEM: _disasm_smem, DS: _disasm_ds, FLAT: _disasm_flat, GLOBAL: _disasm_flat, SCRATCH: _disasm_flat, - MUBUF: _disasm_buf, MTBUF: _disasm_buf, MIMG: _disasm_mimg, SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk, + SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk, # RDNA4 R4_VOP1: _disasm_vop1, R4_VOP1_SDST: _disasm_vop1, R4_VOP2: _disasm_vop2, R4_VOPC: _disasm_vopc, R4_VOP3: _disasm_vop3, R4_VOP3_SDST: _disasm_vop3, R4_VOP3SD: _disasm_vop3sd, R4_VOPD: _disasm_vopd, R4_VOP3P: _disasm_vop3p, R4_VINTERP: _disasm_vinterp, R4_SOPP: _disasm_sopp, R4_SMEM: _disasm_smem, - R4_DS: _disasm_ds, R4_SOP1: _disasm_sop1, R4_SOP2: _disasm_sop2, R4_SOPC: _disasm_sopc, R4_SOPK: _disasm_sopk, - R4_VEXPORT: _disasm_vexport, R4_VBUFFER: _disasm_vbuffer} + R4_DS: _disasm_ds, R4_SOP1: _disasm_sop1, R4_SOP2: _disasm_sop2, R4_SOPC: _disasm_sopc, R4_SOPK: _disasm_sopk} def disasm(inst: Inst) -> str: return DISASM_HANDLERS[type(inst)](inst) @@ -668,7 +570,7 @@ def disasm(inst: Inst) -> str: return DISASM_HANDLERS[type(inst)](inst) try: from extra.assembly.amd.autogen.cdna.ins import (VOP1 as CDNA_VOP1, VOP2 as CDNA_VOP2, VOPC as CDNA_VOPC, VOP3A, VOP3B, VOP3P as CDNA_VOP3P, SOP1 as CDNA_SOP1, SOP2 as CDNA_SOP2, SOPC as CDNA_SOPC, SOPK as CDNA_SOPK, SOPP as CDNA_SOPP, SMEM as CDNA_SMEM, DS as CDNA_DS, - FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, VOP1Op as CDNA_VOP1Op, VOP2Op as CDNA_VOP2Op, VOPCOp as CDNA_VOPCOp) + FLAT as CDNA_FLAT, VOP1Op as CDNA_VOP1Op, VOP2Op as CDNA_VOP2Op, VOPCOp as CDNA_VOPCOp) def _cdna_src(inst, v, neg, abs_=0, n=1): s = _lit(inst, v) if v == 255 else _fmt_src(v, n, cdna=True) @@ -678,8 +580,7 @@ try: _CDNA_VOP3_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32'} def _disasm_vop3a(inst) -> str: - op_val = inst._values.get('op', 0) - if hasattr(op_val, 'value'): op_val = op_val.value + op_val = inst.op.value if hasattr(inst.op, 'value') else inst.op name = inst.op_name.lower() or f'vop3a_op_{op_val}' n = inst.num_srcs() or _num_srcs(inst) cl, om = " clamp" if inst.clmp else "", _omod(inst.omod) @@ -711,8 +612,7 @@ try: return f"{name}{suf} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {s0}, {s1}{cl}{om}" def _disasm_vop3b(inst) -> str: - op_val = inst._values.get('op', 0) - if hasattr(op_val, 'value'): op_val = op_val.value + op_val = inst.op.value if hasattr(inst.op, 'value') else inst.op name = inst.op_name.lower() or f'vop3b_op_{op_val}' n = inst.num_srcs() or _num_srcs(inst) regs = inst.canonical_op_regs @@ -728,18 +628,44 @@ try: return f"{name}{suf} {dst}, {sdst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {sdst}, {s0}, {s1}{cl}{om}" def _disasm_cdna_vop3p(inst) -> str: - name, n, is_mfma = inst.op_name.lower(), inst.num_srcs() or 2, 'mfma' in inst.op_name.lower() or 'smfmac' in inst.op_name.lower() + name, n = inst.op_name.lower(), inst.num_srcs() or 2 + is_mfma = 'mfma' in name or 'smfmac' in name + is_accvgpr = 'accvgpr' in name get_src = lambda v, sc: _lit(inst, v) if v == 255 else _fmt_src(v, sc, cdna=True) - if is_mfma: sc = 2 if 'iu4' in name else 4 if 'iu8' in name or 'i4' in name else 8 if 'f16' in name or 'bf16' in name else 4; src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 16), _vreg(inst.vdst, 16) - else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), _vreg(inst.vdst) - opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2) - mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if n == 3 else 3) else []) + \ + + # Handle accvgpr read/write (accumulator register operations) + if is_accvgpr: + src0_off = _unwrap(inst.src0) + vdst_off = _vi(inst.vdst) + if 'read' in name: + # v_accvgpr_read_b32 vN, aM - reads from accumulator to VGPR + return f"{name}_b32 v{vdst_off}, a{src0_off - 256 if src0_off >= 256 else src0_off}" + if 'write' in name: + # v_accvgpr_write_b32 aM, src - writes to accumulator from source + src = _lit(inst, inst.src0) if src0_off == 255 else (f"v{src0_off - 256}" if src0_off >= 256 else decode_src(src0_off, cdna=True)) + return f"{name}_b32 a{vdst_off}, {src}" + + # Handle MFMA instructions with accumulator destinations + if is_mfma: + sc = 2 if 'iu4' in name else 4 if 'iu8' in name or 'i4' in name else 8 if 'f16' in name or 'bf16' in name else 4 + src0, src1, src2 = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 16) + dst = _areg(inst.vdst, 16) # MFMA uses accumulator registers + opsel_hi = inst.opsel_hi + mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != 3 else []) + \ + ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else []) + return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" + + # Standard VOP3P instructions + src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), _vreg(inst.vdst) + opsel_hi = inst.opsel_hi # CDNA VOP3P only has 2 bits for opsel_hi (no opsel_hi2) + opsel_hi_default = 3 # CDNA default is 0b11 (2 bits), not 0b111 like RDNA + mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != opsel_hi_default else []) + \ ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else []) return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}" DISASM_HANDLERS.update({CDNA_VOP1: _disasm_vop1, CDNA_VOP2: _disasm_vop2, CDNA_VOPC: _disasm_vopc, CDNA_SOP1: _disasm_sop1, CDNA_SOP2: _disasm_sop2, CDNA_SOPC: _disasm_sopc, CDNA_SOPK: _disasm_sopk, CDNA_SOPP: _disasm_sopp, - CDNA_SMEM: _disasm_smem, CDNA_DS: _disasm_ds, CDNA_FLAT: _disasm_flat, CDNA_MUBUF: _disasm_buf, CDNA_MTBUF: _disasm_buf, + CDNA_SMEM: _disasm_smem, CDNA_DS: _disasm_ds, CDNA_FLAT: _disasm_flat, VOP3A: _disasm_vop3a, VOP3B: _disasm_vop3b, CDNA_VOP3P: _disasm_cdna_vop3p}) except ImportError: pass diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index 9ea4920451..3d908efb3c 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -335,7 +335,7 @@ class Inst: def _size(cls) -> int: return cls._base_size def size(self) -> int: return self._base_size + (4 if self._literal is not None else 0) def disasm(self) -> str: - from extra.assembly.amd.asm import disasm + from extra.assembly.amd.disasm import disasm return disasm(self) def to_bytes(self) -> bytes: diff --git a/extra/assembly/amd/test/test_formats.py b/extra/assembly/amd/test/test_formats.py index 334f045f85..200bae7140 100644 --- a/extra/assembly/amd/test/test_formats.py +++ b/extra/assembly/amd/test/test_formats.py @@ -1,178 +1,14 @@ #!/usr/bin/env python3 -"""Test MUBUF, MTBUF, MIMG, EXP, DS formats against LLVM.""" +"""Test DS and other compute-relevant instruction formats. + +Note: Graphics-only formats (EXP, MUBUF, MTBUF, MIMG) are not supported - use GLOBAL/FLAT for memory access in compute. +""" import unittest from extra.assembly.amd.autogen.rdna3.ins import * from extra.assembly.amd.dsl import VCC_HI, EXEC_LO, NULL OFF = NULL # OFF is alias for NULL from extra.assembly.amd.decode import detect_format -class TestMUBUF(unittest.TestCase): - """Test MUBUF (buffer) instructions.""" - - def test_buffer_load_b32_basic(self): - # buffer_load_b32 v5, off, s[8:11], s3 offset:4095 - # GFX11: encoding: [0xff,0x0f,0x50,0xe0,0x00,0x05,0x02,0x03] - inst = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095) - self.assertEqual(inst.to_bytes(), bytes([0xff,0x0f,0x50,0xe0,0x00,0x05,0x02,0x03])) - - def test_buffer_load_b32_idxen(self): - # buffer_load_b32 v5, v0, s[8:11], s3 idxen offset:4095 - # GFX11: encoding: [0xff,0x0f,0x50,0xe0,0x00,0x05,0x82,0x03] - inst = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, idxen=1) - self.assertEqual(inst.to_bytes(), bytes([0xff,0x0f,0x50,0xe0,0x00,0x05,0x82,0x03])) - - def test_buffer_load_b32_offen(self): - # buffer_load_b32 v5, v0, s[8:11], s3 offen offset:4095 - # GFX11: encoding: [0xff,0x0f,0x50,0xe0,0x00,0x05,0x42,0x03] - inst = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, offen=1) - self.assertEqual(inst.to_bytes(), bytes([0xff,0x0f,0x50,0xe0,0x00,0x05,0x42,0x03])) - - def test_buffer_load_b32_glc(self): - # buffer_load_b32 v5, off, s[8:11], s3 offset:4095 glc - # GFX11: encoding: [0xff,0x4f,0x50,0xe0,0x00,0x05,0x02,0x03] - inst = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, glc=1) - self.assertEqual(inst.to_bytes(), bytes([0xff,0x4f,0x50,0xe0,0x00,0x05,0x02,0x03])) - - def test_buffer_load_b32_slc(self): - # buffer_load_b32 v5, off, s[8:11], s3 offset:4095 slc - # GFX11: encoding: [0xff,0x1f,0x50,0xe0,0x00,0x05,0x02,0x03] - inst = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, slc=1) - self.assertEqual(inst.to_bytes(), bytes([0xff,0x1f,0x50,0xe0,0x00,0x05,0x02,0x03])) - - def test_buffer_load_b32_dlc(self): - # buffer_load_b32 v5, off, s[8:11], s3 offset:4095 dlc - # GFX11: encoding: [0xff,0x2f,0x50,0xe0,0x00,0x05,0x02,0x03] - inst = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, dlc=1) - self.assertEqual(inst.to_bytes(), bytes([0xff,0x2f,0x50,0xe0,0x00,0x05,0x02,0x03])) - - def test_buffer_load_b32_all_flags(self): - # buffer_load_b32 v5, off, s[8:11], s3 offset:4095 glc slc dlc - # GFX11: encoding: [0xff,0x7f,0x50,0xe0,0x00,0x05,0x02,0x03] - inst = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, glc=1, slc=1, dlc=1) - self.assertEqual(inst.to_bytes(), bytes([0xff,0x7f,0x50,0xe0,0x00,0x05,0x02,0x03])) - - def test_buffer_store_b32(self): - # buffer_store_b32 v1, off, s[12:15], s4 offset:4095 - # GFX11: encoding: [0xff,0x0f,0x68,0xe0,0x00,0x01,0x03,0x04] - inst = buffer_store_b32(vdata=v[1], vaddr=v[0], srsrc=s[12:15], soffset=s[4], offset=4095) - self.assertEqual(inst.to_bytes(), bytes([0xff,0x0f,0x68,0xe0,0x00,0x01,0x03,0x04])) - - def test_buffer_load_b64(self): - # buffer_load_b64 v[5:6], off, s[8:11], s3 offset:4095 - # GFX11: encoding: [0xff,0x0f,0x54,0xe0,0x00,0x05,0x02,0x03] - inst = buffer_load_b64(vdata=v[5:6], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095) - self.assertEqual(inst.to_bytes(), bytes([0xff,0x0f,0x54,0xe0,0x00,0x05,0x02,0x03])) - - def test_buffer_load_soffset_m0(self): - # buffer_load_b32 v5, off, s[8:11], m0 offset:4095 - # GFX11: encoding: [0xff,0x0f,0x50,0xe0,0x00,0x05,0x02,0x7d] - inst = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=M0, offset=4095) - self.assertEqual(inst.to_bytes(), bytes([0xff,0x0f,0x50,0xe0,0x00,0x05,0x02,0x7d])) - - def test_buffer_load_soffset_inline_const(self): - # buffer_load_b32 v5, off, s[8:11], 0 offset:4095 - # GFX11: encoding: [0xff,0x0f,0x50,0xe0,0x00,0x05,0x02,0x80] - inst = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=0, offset=4095) - self.assertEqual(inst.to_bytes(), bytes([0xff,0x0f,0x50,0xe0,0x00,0x05,0x02,0x80])) - - def test_buffer_disasm_roundtrip(self): - inst = buffer_load_b32(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, glc=1) - decoded = MUBUF.from_bytes(inst.to_bytes()) - self.assertEqual(decoded.to_bytes(), inst.to_bytes()) - - -class TestMTBUF(unittest.TestCase): - """Test MTBUF (typed buffer) instructions.""" - - def test_tbuffer_load_format_x(self): - # tbuffer_load_format_x v5, off, s[8:11], s3 format:[BUF_FMT_32_FLOAT] offset:4095 - # BUF_FMT_32_FLOAT = 22 - # GFX11: encoding: [0xff,0x0f,0xb0,0xe8,0x00,0x05,0x02,0x03] - inst = tbuffer_load_format_x(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, format=22) - self.assertEqual(inst.to_bytes(), bytes([0xff,0x0f,0xb0,0xe8,0x00,0x05,0x02,0x03])) - - def test_tbuffer_store_format_x(self): - # tbuffer_store_format_x v5, off, s[8:11], s3 format:[BUF_FMT_32_FLOAT] offset:4095 - # BUF_FMT_32_FLOAT = 22 - # GFX11: encoding: [0xff,0x0f,0xb2,0xe8,0x00,0x05,0x02,0x03] - inst = tbuffer_store_format_x(vdata=v[5], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, format=22) - self.assertEqual(inst.to_bytes(), bytes([0xff,0x0f,0xb2,0xe8,0x00,0x05,0x02,0x03])) - - def test_tbuffer_load_format_xy(self): - # tbuffer_load_format_xy v[5:6], off, s[8:11], s3 format:[BUF_FMT_32_32_FLOAT] offset:4095 - # BUF_FMT_32_32_FLOAT = 50 - # GFX11: encoding: [0xff,0x8f,0x90,0xe9,0x00,0x05,0x02,0x03] - inst = tbuffer_load_format_xy(vdata=v[5:6], vaddr=v[0], srsrc=s[8:11], soffset=s[3], offset=4095, format=50) - self.assertEqual(inst.to_bytes(), bytes([0xff,0x8f,0x90,0xe9,0x00,0x05,0x02,0x03])) - - -class TestMIMG(unittest.TestCase): - """Test MIMG (image) instructions.""" - - def test_image_load_2d(self): - # image_load v[0:3], v[4:7], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D - # GFX11: encoding: [0x04,0x0f,0x00,0xf0,0x04,0x00,0x00,0x00] - inst = image_load(vdata=v[0:3], vaddr=v[4:7], srsrc=s[0:7], dmask=0xf, dim=1) # dim=1 is SQ_RSRC_IMG_2D - self.assertEqual(inst.to_bytes(), bytes([0x04,0x0f,0x00,0xf0,0x04,0x00,0x00,0x00])) - - def test_image_store_2d(self): - # image_store v[0:3], v[4:7], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D - # GFX11: encoding: [0x04,0x0f,0x18,0xf0,0x04,0x00,0x00,0x00] - inst = image_store(vdata=v[0:3], vaddr=v[4:7], srsrc=s[0:7], dmask=0xf, dim=1) - self.assertEqual(inst.to_bytes(), bytes([0x04,0x0f,0x18,0xf0,0x04,0x00,0x00,0x00])) - - def test_image_load_1d(self): - # image_load v[0:3], v[4:7], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_1D - # GFX11: encoding: [0x00,0x0f,0x00,0xf0,0x04,0x00,0x00,0x00] - inst = image_load(vdata=v[0:3], vaddr=v[4:7], srsrc=s[0:7], dmask=0xf, dim=0) # dim=0 is SQ_RSRC_IMG_1D - self.assertEqual(inst.to_bytes(), bytes([0x00,0x0f,0x00,0xf0,0x04,0x00,0x00,0x00])) - - def test_image_sample(self): - # image_sample v[0:3], v[4:6], s[0:7], s[8:11] dmask:0xf dim:SQ_RSRC_IMG_2D - # GFX11: encoding: [0x04,0x0f,0x6c,0xf0,0x04,0x00,0x00,0x08] - inst = image_sample(vdata=v[0:3], vaddr=v[4:6], srsrc=s[0:7], ssamp=s[8:11], dmask=0xf, dim=1) - self.assertEqual(inst.to_bytes(), bytes([0x04,0x0f,0x6c,0xf0,0x04,0x00,0x00,0x08])) - - def test_image_load_d16(self): - # image_load v[0:3], v[4:7], s[0:7] dmask:0xf dim:SQ_RSRC_IMG_2D d16 - # GFX11: encoding: [0x04,0x0f,0x02,0xf0,0x04,0x00,0x00,0x00] - inst = image_load(vdata=v[0:3], vaddr=v[4:7], srsrc=s[0:7], dmask=0xf, dim=1, d16=1) - self.assertEqual(inst.to_bytes(), bytes([0x04,0x0f,0x02,0xf0,0x04,0x00,0x00,0x00])) - - -class TestEXP(unittest.TestCase): - """Test EXP (export) instructions.""" - - def test_exp_mrt0(self): - # exp mrt0 v0, v1, v2, v3 - # GFX11: encoding: [0x0f,0x00,0x00,0xf8,0x00,0x01,0x02,0x03] - inst = EXP(en=0xf, target=0, vsrc0=v[0], vsrc1=v[1], vsrc2=v[2], vsrc3=v[3]) - self.assertEqual(inst.to_bytes(), bytes([0x0f,0x00,0x00,0xf8,0x00,0x01,0x02,0x03])) - - def test_exp_mrtz(self): - # exp mrtz v4, v3, v2, v1 - # GFX11: encoding: [0x8f,0x00,0x00,0xf8,0x04,0x03,0x02,0x01] - inst = EXP(en=0xf, target=8, vsrc0=v[4], vsrc1=v[3], vsrc2=v[2], vsrc3=v[1]) - self.assertEqual(inst.to_bytes(), bytes([0x8f,0x00,0x00,0xf8,0x04,0x03,0x02,0x01])) - - def test_exp_mrtz_done(self): - # exp mrtz v4, v3, v2, v1 done - # GFX11: encoding: [0x8f,0x08,0x00,0xf8,0x04,0x03,0x02,0x01] - inst = EXP(en=0xf, target=8, vsrc0=v[4], vsrc1=v[3], vsrc2=v[2], vsrc3=v[3], done=1) - self.assertEqual(inst.to_bytes(), bytes([0x8f,0x08,0x00,0xf8,0x04,0x03,0x02,0x03])) - - def test_exp_partial_mask(self): - # exp mrt0 v0, v1, off, off (en=0x3, only first two components) - # GFX11: encoding: [0x03,0x00,0x00,0xf8,0x00,0x01,0x00,0x00] - inst = EXP(en=0x3, target=0, vsrc0=v[0], vsrc1=v[1], vsrc2=v[0], vsrc3=v[0]) - self.assertEqual(inst.to_bytes(), bytes([0x03,0x00,0x00,0xf8,0x00,0x01,0x00,0x00])) - - def test_exp_row_en(self): - # exp mrtz v4, v3, v2, v1 row_en - # GFX11: encoding: [0x8f,0x20,0x00,0xf8,0x04,0x03,0x02,0x01] - inst = EXP(en=0xf, target=8, vsrc0=v[4], vsrc1=v[3], vsrc2=v[2], vsrc3=v[1], row=1) - self.assertEqual(inst.to_bytes(), bytes([0x8f,0x20,0x00,0xf8,0x04,0x03,0x02,0x01])) - class TestDS(unittest.TestCase): """Test DS (data share / LDS) instructions.""" @@ -380,18 +216,6 @@ class TestDetectFormat(unittest.TestCase): self.assertEqual(detect_format(global_load_b32(vdst=v[0], addr=v[1:2], saddr=NULL).to_bytes()), FLAT) self.assertEqual(detect_format(global_store_b32(addr=v[0:1], data=v[2], saddr=NULL).to_bytes()), FLAT) - def test_detect_mubuf(self): - self.assertEqual(detect_format(buffer_load_b32(v[0], v[1], s[0:3], s[5]).to_bytes()), MUBUF) - - def test_detect_mtbuf(self): - self.assertEqual(detect_format(tbuffer_load_format_x(v[0], v[1], s[0:3], s[5], format=22).to_bytes()), MTBUF) - - def test_detect_mimg(self): - self.assertEqual(detect_format(image_load(v[0:3], v[4:7], s[0:7], dmask=0xf, dim=1).to_bytes()), MIMG) - - def test_detect_exp(self): - self.assertEqual(detect_format(EXP(en=0xf, target=0, vsrc0=v[0], vsrc1=v[1], vsrc2=v[2], vsrc3=v[3]).to_bytes()), EXP) - def test_detect_vopd(self): inst = VOPD(VOPDOp.V_DUAL_MOV_B32, VOPDOp.V_DUAL_MOV_B32, vdstx=v[0], vdsty=v[1], srcx0=0, srcy0=0) self.assertEqual(detect_format(inst.to_bytes()), VOPD) diff --git a/extra/assembly/amd/test/test_handwritten.py b/extra/assembly/amd/test/test_handwritten.py index 6e12679ee4..40916349af 100644 --- a/extra/assembly/amd/test/test_handwritten.py +++ b/extra/assembly/amd/test/test_handwritten.py @@ -4,7 +4,6 @@ import unittest, struct from extra.assembly.amd.autogen.rdna3.ins import * from extra.assembly.amd.dsl import Inst -from extra.assembly.amd.asm import asm from extra.assembly.amd.test.test_roundtrip import compile_asm class TestIntegration(unittest.TestCase): @@ -13,12 +12,9 @@ class TestIntegration(unittest.TestCase): if not hasattr(self, 'inst'): return b = self.inst.to_bytes() st = self.inst.disasm() - reasm = asm(st) - desc = f"{st:25s} {self.inst} {b!r} {reasm}" + # Test that the instruction can be compiled by LLVM and produces the same bytes + desc = f"{st:25s} {self.inst} {b!r}" self.assertEqual(b, compile_asm(st), desc) - # TODO: this compare should work for valid things - #self.assertEqual(self.inst, reasm) - self.assertEqual(repr(self.inst), repr(reasm)) print(desc) def test_wmma(self): diff --git a/extra/assembly/amd/test/test_integration.py b/extra/assembly/amd/test/test_integration.py index b42cccf596..ff0cc44903 100644 --- a/extra/assembly/amd/test/test_integration.py +++ b/extra/assembly/amd/test/test_integration.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 """Integration test: round-trip RDNA3 assembly through AMD toolchain.""" -import unittest, re, io, sys, subprocess +import unittest, io, sys from extra.assembly.amd.autogen.rdna3.ins import * -from extra.assembly.amd.asm import waitcnt, asm -from extra.assembly.amd.test.helpers import get_llvm_mc + +def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int: + return (expcnt & 0x7) | ((lgkmcnt & 0x3f) << 4) | ((vmcnt & 0x3f) << 10) def disassemble(lib: bytes, arch: str = "gfx1100") -> str: """Disassemble ELF binary using tinygrad's compiler, return raw output.""" @@ -40,7 +41,7 @@ def assemble_and_disassemble(instructions: list, arch: str = "gfx1100") -> list[ return parse_disassembly(disassemble(lib, arch)) class TestIntegration(unittest.TestCase): - """Test our assembler output matches LLVM disassembly.""" + """Test our DSL output matches LLVM disassembly.""" def test_simple_sop1(self): """Test SOP1 instructions round-trip.""" @@ -148,79 +149,6 @@ class TestIntegration(unittest.TestCase): return self.fail("Could not find s_mov_b32 in disassembly") -class TestAsm(unittest.TestCase): - """Test asm() string parsing.""" - - def test_asm_basic(self): - """Test basic instruction parsing.""" - inst = asm('s_mov_b32 s0, s1') - self.assertEqual(inst.to_bytes(), s_mov_b32(s[0], s[1]).to_bytes()) - - def test_asm_with_immediates(self): - """Test parsing with immediate values.""" - inst = asm('s_add_u32 s0, s1, 10') - self.assertEqual(inst.to_bytes(), s_add_u32(s[0], s[1], 10).to_bytes()) - - def test_asm_float_const(self): - """Test parsing float constants.""" - inst = asm('v_mul_f32_e32 v0, 1.0, v1') - self.assertEqual(inst.to_bytes(), v_mul_f32_e32(v[0], 1.0, v[1]).to_bytes()) - - def test_asm_hex_immediate(self): - """Test parsing hex immediates.""" - inst = asm('s_waitcnt 0xfc07') - self.assertEqual(inst.to_bytes(), s_waitcnt(simm16=0xfc07).to_bytes()) - - def test_asm_special_regs(self): - """Test parsing special registers.""" - inst = asm('s_mov_b32 s0, vcc_lo') - self.assertEqual(inst.to_bytes(), s_mov_b32(s[0], VCC_LO).to_bytes()) - - def test_asm_register_range(self): - """Test parsing register ranges.""" - inst = asm('s_load_b128 s[4:7], s[0:1], null') - self.assertEqual(inst.to_bytes(), s_load_b128(s[4:7], s[0:1], NULL).to_bytes()) - - def test_asm_matches_llvm(self): - """Test asm() output matches LLVM assembler.""" - from tinygrad.runtime.support.compiler_amd import HIPCompiler - compiler = HIPCompiler('gfx1100') - - def get_llvm_bytes(instr: str) -> bytes: - src = f'.text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n{instr}\n' - lib = compiler.compile(src) - raw = disassemble(lib) - for line in raw.splitlines(): - if instr.split()[0] in line and '//' in line: - hex_str = line.split('//')[1].strip().split(':')[1].strip() - return bytes.fromhex(hex_str)[::-1] - return b'' - - tests = ['s_mov_b32 s0, s1', 's_endpgm', 'v_add_f32_e32 v0, v1, v2'] - for t in tests: - self.assertEqual(asm(t).to_bytes(), get_llvm_bytes(t), f"mismatch for: {t}") - - def test_asm_vop3_modifiers(self): - """Test asm() with VOP3 modifiers (neg, abs, clamp).""" - def get_llvm_encoding(instr: str) -> str: - result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-show-encoding'], - input=instr, capture_output=True, text=True) - if m := re.search(r'encoding:\s*\[(.*?)\]', result.stdout): - return m.group(1).replace('0x','').replace(',','').replace(' ','') - return '' - - tests = [ - 'v_fma_f32 v0, -v1, v2, v3', # neg on src0 - 'v_fma_f32 v0, v1, |v2|, v3', # abs on src1 - 'v_fma_f32 v0, v1, v2, v3 clamp', # clamp - 'v_fma_f32 v0, -v1, |v2|, v3 clamp', # all modifiers - 'v_fma_f32 v0, -|v1|, v2, v3', # neg+abs on same operand - ] - for t in tests: - our_hex = asm(t).to_bytes().hex() - llvm_hex = get_llvm_encoding(t) - self.assertEqual(our_hex, llvm_hex, f"mismatch for: {t}") - class TestTinygradIntegration(unittest.TestCase): """Test that we can parse disassembled tinygrad kernels.""" diff --git a/extra/assembly/amd/test/test_llvm.py b/extra/assembly/amd/test/test_llvm.py index eb314a046d..dd700d66ca 100644 --- a/extra/assembly/amd/test/test_llvm.py +++ b/extra/assembly/amd/test/test_llvm.py @@ -1,37 +1,45 @@ #!/usr/bin/env python3 -"""Test AMD assembler/disassembler against LLVM test vectors.""" +"""Test AMD assembler/disassembler against LLVM test vectors. + +Only compute-relevant instruction formats are tested. Graphics-only formats not supported: + - MUBUF/MTBUF: buffer instructions with resource descriptors (use GLOBAL/FLAT instead) + - MIMG: image/texture instructions + - EXP/VEXPORT: export instructions for pixel/vertex output + - VIMAGE/VSAMPLE: image sampling instructions (RDNA4) + - VBUFFER: buffer instructions (RDNA4) +""" import unittest, re, subprocess, functools from tinygrad.helpers import fetch -from extra.assembly.amd.asm import asm, disasm +from extra.assembly.amd.disasm import disasm from extra.assembly.amd.decode import decode_inst, detect_format from extra.assembly.amd.test.helpers import get_llvm_mc LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/llvmorg-21.1.0/llvm/test/MC/AMDGPU" +# RDNA3 (gfx11) test files for compute instructions +# Excluded: gfx11_asm_mubuf.s, gfx11_asm_mtbuf.s, gfx11_asm_mimg.s, gfx11_asm_mubuf_alias.s, gfx11_asm_mtbuf_alias.s (graphics-only) RDNA_FILES = ['gfx11_asm_sop1.s', 'gfx11_asm_sop2.s', 'gfx11_asm_sopp.s', 'gfx11_asm_sopk.s', 'gfx11_asm_sopc.s', 'gfx11_asm_vop1.s', 'gfx11_asm_vop2.s', 'gfx11_asm_vopc.s', 'gfx11_asm_vop3.s', 'gfx11_asm_vop3p.s', 'gfx11_asm_vinterp.s', 'gfx11_asm_vopd.s', 'gfx11_asm_vopcx.s', 'gfx11_asm_vop3_from_vop1.s', 'gfx11_asm_vop3_from_vop2.s', 'gfx11_asm_vop3_from_vopc.s', - 'gfx11_asm_vop3_from_vopcx.s', 'gfx11_asm_ds.s', 'gfx11_asm_smem.s', 'gfx11_asm_flat.s', 'gfx11_asm_mubuf.s', 'gfx11_asm_mtbuf.s', - 'gfx11_asm_mimg.s', 'gfx11_asm_wmma.s', 'gfx11_asm_vop3_features.s', 'gfx11_asm_vop3p_features.s', 'gfx11_asm_vopd_features.s', + 'gfx11_asm_vop3_from_vopcx.s', 'gfx11_asm_ds.s', 'gfx11_asm_smem.s', 'gfx11_asm_flat.s', + 'gfx11_asm_wmma.s', 'gfx11_asm_vop3_features.s', 'gfx11_asm_vop3p_features.s', 'gfx11_asm_vopd_features.s', 'gfx11_asm_vop3_alias.s', 'gfx11_asm_vop3p_alias.s', 'gfx11_asm_vopc_alias.s', 'gfx11_asm_vopcx_alias.s', 'gfx11_asm_vinterp_alias.s', - 'gfx11_asm_smem_alias.s', 'gfx11_asm_mubuf_alias.s', 'gfx11_asm_mtbuf_alias.s'] -# CDNA test files - includes gfx9 files for shared instructions, plus gfx90a/gfx942 specific files -# gfx90a_ldst_acc.s has MIMG mixed in, filtered via is_mimg check + 'gfx11_asm_smem_alias.s'] +# CDNA (gfx9/gfx90a/gfx942) test files for compute instructions +# Excluded: gfx9_asm_mubuf.s, gfx9_asm_mtbuf.s, gfx90a_ldst_acc.s (has MIMG mixed in) CDNA_FILES = ['gfx9_asm_sop1.s', 'gfx9_asm_sop2.s', 'gfx9_asm_sopp.s', 'gfx9_asm_sopk.s', 'gfx9_asm_sopc.s', 'gfx9_asm_vop1.s', 'gfx9_asm_vop2.s', 'gfx9_asm_vopc.s', 'gfx9_asm_vop3.s', 'gfx9_asm_vop3p.s', - 'gfx9_asm_ds.s', 'gfx9_asm_flat.s', 'gfx9_asm_smem.s', 'gfx9_asm_mubuf.s', 'gfx9_asm_mtbuf.s', - 'gfx90a_ldst_acc.s', 'gfx90a_asm_features.s', 'flat-scratch-gfx942.s', 'gfx942_asm_features.s', + 'gfx9_asm_ds.s', 'gfx9_asm_flat.s', 'gfx9_asm_smem.s', + 'gfx90a_asm_features.s', 'flat-scratch-gfx942.s', 'gfx942_asm_features.s', 'mai-gfx90a.s', 'mai-gfx942.s'] -# RDNA4 (gfx12) test files - excludes alias/err/fake16/dpp files, and vimage/vsample (not supported) -# NOTE: vflat/vdsdir excluded - not implemented; features.s has mixed formats +# RDNA4 (gfx12) test files for compute instructions +# Excluded: gfx12_asm_vbuffer_mubuf.s, gfx12_asm_vbuffer_mtbuf.s, gfx12_asm_exp.s (graphics-only) RDNA4_FILES = ['gfx12_asm_sop1.s', 'gfx12_asm_sop2.s', 'gfx12_asm_sopp.s', 'gfx12_asm_sopk.s', 'gfx12_asm_sopc.s', 'gfx12_asm_vop1.s', 'gfx12_asm_vop2.s', 'gfx12_asm_vopc.s', 'gfx12_asm_vopcx.s', 'gfx12_asm_vop3.s', 'gfx12_asm_vop3c.s', 'gfx12_asm_vop3cx.s', 'gfx12_asm_vop3p.s', 'gfx12_asm_vop3_from_vop1.s', 'gfx12_asm_vop3_from_vop2.s', 'gfx12_asm_vop3p_features.s', 'gfx12_asm_vopd.s', 'gfx12_asm_vopd_features.s', 'gfx12_asm_ds.s', 'gfx12_asm_smem.s', - 'gfx12_asm_vbuffer_mubuf.s', 'gfx12_asm_vbuffer_mtbuf.s', 'gfx12_asm_wmma_w32.s', 'gfx12_asm_exp.s'] - -def _is_mimg(data: bytes) -> bool: return (int.from_bytes(data[:4], 'little') >> 26) & 0x3f == 0b111100 + 'gfx12_asm_wmma_w32.s'] def _parse_llvm_tests(text: str, pattern: str) -> list[tuple[str, bytes]]: tests = [] @@ -53,15 +61,16 @@ def _parse_llvm_tests(text: str, pattern: str) -> list[tuple[str, bytes]]: def _get_tests(f: str, arch: str) -> list[tuple[str, bytes]]: text = fetch(f"{LLVM_BASE}/{f}").read_bytes().decode('utf-8', errors='ignore') if arch == "rdna3": - tests = _parse_llvm_tests(text, r'(?:GFX11|W32|W64)') + # Match GFX11 and W32 only (wavefront32 mode) + tests = _parse_llvm_tests(text, r'(?:GFX11|W32)') elif arch == "rdna4": - # Match GFX12 but not GFX1250 (which has different lit64 encoding) - tests = _parse_llvm_tests(text, r'(?:GFX12(?!50)|W32|W64)') + # Match GFX12 (but not GFX1250) and W32 only (wavefront32 mode) + tests = _parse_llvm_tests(text, r'(?:GFX12(?!50)|W32)') elif 'gfx90a' in f or 'gfx942' in f: tests = _parse_llvm_tests(text, r'(?:GFX90A|GFX942)') else: tests = _parse_llvm_tests(text, r'(?:VI9|GFX9|CHECK)') - return [(a, d) for a, d in tests if not _is_mimg(d)] if arch == "cdna" else tests + return tests def _compile_asm_batch(instrs: list[str], arch: str = "rdna3") -> list[bytes]: if not instrs: return [] @@ -87,14 +96,6 @@ def _make_test(f: str, arch: str, test_type: str): print(f"{name}: {passed} passed, {skipped} skipped") if arch in ("rdna3", "rdna4"): self.assertEqual(skipped, 0, f"{name}: {skipped} tests skipped, expected 0") - elif test_type == "asm": - passed, skipped = 0, 0 - for asm_text, expected in tests: - try: - self.assertEqual(asm(asm_text, arch).to_bytes(), expected) - passed += 1 - except: skipped += 1 - print(f"{name}: {passed} passed, {skipped} skipped") elif test_type == "disasm": to_test = [] for _, data in tests: @@ -114,14 +115,12 @@ class TestLLVM(unittest.TestCase): pass for f in RDNA_FILES: setattr(TestLLVM, f"test_rdna3_roundtrip_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "roundtrip")) - setattr(TestLLVM, f"test_rdna3_asm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "asm")) setattr(TestLLVM, f"test_rdna3_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "disasm")) for f in CDNA_FILES: setattr(TestLLVM, f"test_cdna_roundtrip_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "cdna", "roundtrip")) setattr(TestLLVM, f"test_cdna_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "cdna", "disasm")) for f in RDNA4_FILES: setattr(TestLLVM, f"test_rdna4_roundtrip_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna4", "roundtrip")) - setattr(TestLLVM, f"test_rdna4_asm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna4", "asm")) setattr(TestLLVM, f"test_rdna4_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna4", "disasm")) if __name__ == "__main__": diff --git a/extra/assembly/amd/test/test_roundtrip.py b/extra/assembly/amd/test/test_roundtrip.py index 64bbe03715..d306629377 100644 --- a/extra/assembly/amd/test/test_roundtrip.py +++ b/extra/assembly/amd/test/test_roundtrip.py @@ -2,7 +2,6 @@ """Roundtrip tests: generate tinygrad kernels, decode instructions, re-encode, verify match.""" import unittest, io, sys, re, subprocess, os from extra.assembly.amd.dsl import Inst -from extra.assembly.amd.asm import asm from extra.assembly.amd.decode import decode_inst, detect_format from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump @@ -86,8 +85,8 @@ class TestTinygradKernelRoundtrip(unittest.TestCase): def _test_kernel_roundtrip(self, op_fn): """Generate kernel from op_fn, test: 1. decode -> reencode matches original bytes - 2. asm(disasm()) matches LLVM output - 3. our disasm() matches LLVM's disassembly string exactly + 2. disasm() -> LLVM asm -> bytes matches original (validates disasm correctness) + 3. our disasm() matches LLVM's disassembly string (informational) """ arch = self.arch mcpu, mattr = ARCH_CONFIG[arch] @@ -130,19 +129,19 @@ class TestTinygradKernelRoundtrip(unittest.TestCase): offset += size # Collect disasm strings for batched LLVM calls - skip unknown opcodes (op_X) that LLVM can't compile - asm_test_instrs: list[tuple[int, str]] = [] # (idx, our_disasm) for asm test + asm_test_instrs: list[tuple[int, str, bytes]] = [] # (idx, our_disasm, orig_bytes) for asm test disasm_test_instrs: list[tuple[int, str]] = [] # (idx, our_disasm) for disasm comparison test for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs): if our_disasm is None: continue - # Skip unknown opcodes and malformed instructions for both tests + # Skip unknown opcodes and malformed instructions if our_disasm.startswith('op_') or re.search(r', \d+, \d+, \d+,', our_disasm): continue - asm_test_instrs.append((idx, our_disasm)) + asm_test_instrs.append((idx, our_disasm, orig_bytes)) disasm_test_instrs.append((idx, our_disasm)) - # Batch compile for asm test - asm_llvm_results = compile_asm_batch([d for _, d in asm_test_instrs], arch) - asm_llvm_map = {idx: result for (idx, _), result in zip(asm_test_instrs, asm_llvm_results)} + # Batch compile for asm test (our disasm -> LLVM asm -> bytes) + asm_llvm_results = compile_asm_batch([d for _, d, _ in asm_test_instrs], arch) + asm_llvm_map = {idx: (result, orig) for (idx, _, orig), result in zip(asm_test_instrs, asm_llvm_results)} # Batch compile+disasm for disasm comparison test disasm_llvm_results = compile_and_disasm_batch([d for _, d in disasm_test_instrs], arch) @@ -166,20 +165,16 @@ class TestTinygradKernelRoundtrip(unittest.TestCase): decode_failed += 1 decode_failures.append(f"K{ki}@{offset}: {our_disasm}: {decode_err}") - # Asm test + # Asm test: our disasm -> LLVM asm -> compare bytes with original if our_disasm is None: asm_skipped += 1 elif idx in asm_llvm_map: - llvm_bytes = asm_llvm_map[idx] - try: - our_bytes = asm(our_disasm).to_bytes() - if our_bytes[:len(llvm_bytes)] == llvm_bytes: - asm_passed += 1 - else: - asm_failed += 1 - asm_failures.append(f"K{ki}@{offset}: '{our_disasm}': ours={our_bytes[:len(llvm_bytes)].hex()} llvm={llvm_bytes.hex()}") - except Exception: - asm_skipped += 1 + llvm_bytes, orig = asm_llvm_map[idx] + if llvm_bytes == orig[:len(llvm_bytes)]: + asm_passed += 1 + else: + asm_failed += 1 + asm_failures.append(f"K{ki}@{offset}: '{our_disasm}': llvm={llvm_bytes.hex()} orig={orig[:len(llvm_bytes)].hex()}") else: asm_skipped += 1 @@ -197,7 +192,7 @@ class TestTinygradKernelRoundtrip(unittest.TestCase): disasm_skipped += 1 print(f"[{arch}] decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped") - print(f"[{arch}] asm vs llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped") + print(f"[{arch}] asm via llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped") print(f"[{arch}] disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped") self.assertEqual(decode_failed, 0, f"Decode failures:\n" + "\n".join(decode_failures[:20])) self.assertEqual(asm_failed, 0, f"Asm failures:\n" + "\n".join(asm_failures[:20])) @@ -248,10 +243,10 @@ class TestTinygradKernelRoundtrip(unittest.TestCase): # Fused ops def test_fma(self): self._test_kernel_roundtrip(lambda T: (T([1.0, 2.0]) * T([3.0, 4.0]) + T([5.0, 6.0]))) -@unittest.skip("no asm support for RDNA4") +@unittest.skip("RDNA4 decode roundtrip not yet supported") class TestTinygradKernelRoundtripRDNA4(TestTinygradKernelRoundtrip): arch = 'rdna4' -@unittest.skip("no asm support for CDNA") +@unittest.skip("CDNA decode roundtrip not yet supported") class TestTinygradKernelRoundtripCDNA(TestTinygradKernelRoundtrip): arch = 'cdna' if __name__ == "__main__": diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 6c689c89aa..34b448980e 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -425,10 +425,22 @@ def parse_branch(asm:str) -> int|None: return (x - 0x10000 if x & 0x8000 else x)*4 return None +def _op2dsl(op: str) -> str: + """Convert LLVM asm operand (s0, s[0:1], v0) to DSL format (s[0], s[0:1], v[0]).""" + import re + op = op.strip() + lo = op.lower() + SPEC_DSL = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC', 'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC', + 'scc': 'SCC', 'm0': 'M0', 'null': 'NULL', 'off': 'OFF'} + if lo in SPEC_DSL: return SPEC_DSL[lo] + rp = {'s': 's', 'v': 'v', 't': 'ttmp', 'ttmp': 'ttmp'} + if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', lo): return f"{rp[m.group(1)]}[{m.group(2)}:{m.group(3)}]" + if m := re.match(r'^([svt](?:tmp)?)(\d+)$', lo): return f"{rp[m.group(1)]}[{m.group(2)}]" + return op + def amdgpu_tokenize(st:str) -> list[str]: try: from extra.assembly.amd.dsl import s, v, Reg, VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF - from extra.assembly.amd.asm import _op2dsl dsl = eval(_op2dsl(st), {'s':s, 'v':v, 'VCC_LO':VCC_LO, 'VCC_HI':VCC_HI, 'VCC':VCC, 'EXEC_LO':EXEC_LO, 'EXEC_HI':EXEC_HI, 'EXEC':EXEC, 'SCC':SCC, 'M0':M0, 'NULL':NULL, 'OFF':OFF}) return [f"{type(dsl).__name__[0].lower()}{dsl.offset + i}" for i in range(dsl.sz)] if isinstance(dsl, Reg) else [st]