renderer/amd: add cdna emulator (#14721)

* renderer/amd: add cdna emulator

* fixes

* no predecode

* no early

* REMU_PATH

* delete that

* round

* Fix cache invalidation check in _compile_smem
This commit is contained in:
George Hotz
2026-02-13 16:06:58 +08:00
committed by GitHub
parent 08a555c875
commit 5289b4e882
9 changed files with 206 additions and 466 deletions

View File

@@ -708,6 +708,8 @@ jobs:
run: SKIP_SLOW_TEST=1 AMD_LLVM=0 pytest -n=auto test/backend/test_ops.py -k "test_sparse_categorical_crossentropy or test_tril or test_nonzero or test_softmax_argmax" --durations 20
- name: Run RDNA4 emulator tests
run: MOCKGPU_ARCH=rdna4 python -m pytest test/test_tiny.py -v --durations 20
- name: Run CDNA4 emulator tests
run: AMD_LLVM=1 MOCKGPU_ARCH=cdna4 python -m pytest test/test_tiny.py -v --durations 20
testnvidia:
strategy:

View File

@@ -1,267 +0,0 @@
#!/usr/bin/env python3
"""Benchmark comparing Python vs Rust RDNA3 emulators on real tinygrad kernels."""
import ctypes, time, os
from pathlib import Path
from tinygrad.renderer.amd.emu import run_asm as python_run_asm, decode_program
from tinygrad.renderer.amd import decode_inst
from tinygrad.runtime.autogen.amd.rdna3.ins import SOPP, SOPPOp
import tinygrad
EXTRA_DIR = Path(tinygrad.__file__).parent.parent / "extra"
REMU_PATH = EXTRA_DIR / "remu/target/release/libremu.so"
if not REMU_PATH.exists():
REMU_PATH = EXTRA_DIR / "remu/target/release/libremu.dylib"
def get_rust_remu():
"""Load the Rust libremu shared library."""
if not REMU_PATH.exists(): return None
remu = ctypes.CDLL(str(REMU_PATH))
remu.run_asm.restype = ctypes.c_int32
remu.run_asm.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32,
ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_void_p]
return remu
def count_instructions(kernel: bytes) -> int:
"""Count instructions in a kernel."""
return len(decode_program(kernel))
def setup_buffers(buf_sizes: list[int], init_data: dict[int, bytes] | None = None):
"""Allocate buffers and return args pointer + valid ranges."""
if init_data is None: init_data = {}
buffers = []
for i, size in enumerate(buf_sizes):
padded = ((size + 15) // 16) * 16 + 16
data = init_data.get(i, b'\x00' * padded)
data_list = list(data) + [0] * (padded - len(data))
buf = (ctypes.c_uint8 * padded)(*data_list[:padded])
buffers.append(buf)
args = (ctypes.c_uint64 * len(buffers))(*[ctypes.addressof(b) for b in buffers])
args_ptr = ctypes.addressof(args)
ranges = {(ctypes.addressof(b), len(b)) for b in buffers}
ranges.add((args_ptr, ctypes.sizeof(args)))
return buffers, args, args_ptr, ranges
def benchmark_emulator(name: str, run_fn, kernel: bytes, global_size, local_size, args_ptr, rsrc2: int, iterations: int = 5):
"""Benchmark an emulator and return average time."""
gx, gy, gz = global_size
lx, ly, lz = local_size
kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel)
lib_ptr = ctypes.addressof(kernel_buf)
# Warmup
run_fn(lib_ptr, len(kernel), gx, gy, gz, lx, ly, lz, args_ptr, rsrc2)
# Timed runs
times = []
for _ in range(iterations):
start = time.perf_counter()
result = run_fn(lib_ptr, len(kernel), gx, gy, gz, lx, ly, lz, args_ptr, rsrc2)
end = time.perf_counter()
if result != 0:
print(f" {name} returned error: {result}")
return None
times.append(end - start)
return sum(times) / len(times)
def profile_instructions(kernel: bytes):
"""Profile individual instruction compile times."""
from tinygrad.renderer.amd.emu import _get_runner, _canonical_runner_cache
from tinygrad.helpers import Context
_get_runner.cache_clear()
_canonical_runner_cache.clear()
results = []
i = 0
while i < len(kernel):
inst = decode_inst(kernel[i:])
if isinstance(inst, SOPP) and inst.op == SOPPOp.S_CODE_END: break
inst_bytes = bytes(kernel[i:i + inst.size() + 4])
try: inst_str = repr(inst)
except Exception: inst_str = f"<{type(inst).__name__}>"
# Time the full compile (sink + render + compile)
start = time.perf_counter()
with Context(CCACHE=0):
runner, is_new = _get_runner(inst_bytes)
compile_time = time.perf_counter() - start
results.append({
'inst_str': inst_str + ('' if is_new else ' [CACHED]'),
'compile_ms': compile_time * 1000 if is_new else 0,
})
i += inst.size()
return sorted(results, key=lambda x: x['compile_ms'], reverse=True)
def benchmark_python_split(kernel: bytes, global_size, local_size, args_ptr, rsrc2: int, iterations: int = 5):
"""Benchmark Python emulator with compile and execution times."""
from tinygrad.renderer.amd.emu import _get_runner, _canonical_runner_cache
from tinygrad.helpers import Context
_get_runner.cache_clear()
_canonical_runner_cache.clear()
decode_program.cache_clear()
# Measure compile time (decode_program builds sinks, renders, and compiles)
compile_start = time.perf_counter()
with Context(CCACHE=0):
program = decode_program(kernel)
compile_time = time.perf_counter() - compile_start
n_compiled = len(_canonical_runner_cache)
# Execution time
exec_time = benchmark_emulator("Python", python_run_asm, kernel, global_size, local_size, args_ptr, rsrc2, iterations)
return compile_time, exec_time, len(program), n_compiled
def get_tinygrad_kernel(op_name: str) -> tuple[bytes, tuple, tuple, list[int], dict[int, bytes], int] | None:
"""Get a real tinygrad kernel by operation name. Returns (code, global_size, local_size, buf_sizes, buf_data, rsrc2)."""
try:
from tinygrad import Tensor
from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.autogen import hsa
import numpy as np
np.random.seed(42)
ops = {
"add": lambda: Tensor.empty(1024) + Tensor.empty(1024),
"mul": lambda: Tensor.empty(1024) * Tensor.empty(1024),
"matmul_small": lambda: Tensor.empty(16, 16) @ Tensor.empty(16, 16),
"matmul_medium": lambda: Tensor.empty(64, 64) @ Tensor.empty(64, 64),
"reduce_sum": lambda: Tensor.empty(4096).sum(),
"reduce_max": lambda: Tensor.empty(4096).max(),
"softmax": lambda: Tensor.empty(256).softmax(),
"layernorm": lambda: Tensor.empty(32, 64).layernorm(),
"conv2d": lambda: Tensor.empty(1, 4, 16, 16).conv2d(Tensor.empty(4, 4, 3, 3)),
"gelu": lambda: Tensor.empty(1024).gelu(),
"exp": lambda: Tensor.empty(1024).exp(),
"sin": lambda: Tensor.empty(1024).sin(),
}
if op_name not in ops: return None
out = ops[op_name]()
sched = out.schedule()
for ei in sched:
lowered = ei.lower()
if ei.ast.op.name == 'SINK' and lowered.prg and lowered.prg.p.lib:
lib = bytes(lowered.prg.p.lib)
image = memoryview(bytearray(lib))
_, sections, _ = elf_loader(lib)
rodata_entry = next((sh.header.sh_addr for sh in sections if sh.name == ".rodata"), -1)
for sec in sections:
if sec.name == '.text':
buf_sizes = [b.nbytes for b in lowered.bufs]
# Get initial data from numpy arrays if available
buf_data = {}
for i, buf in enumerate(lowered.bufs):
if hasattr(buf, 'base') and buf.base is not None and hasattr(buf.base, '_buf'):
try: buf_data[i] = bytes(buf.base._buf)
except Exception: pass
# Extract rsrc2 from ELF (same as ops_amd.py)
group_segment_size = image[rodata_entry:rodata_entry+4].cast("I")[0]
lds_size = ((group_segment_size + 511) // 512) & 0x1FF
code = hsa.amd_kernel_code_t.from_buffer_copy(bytes(image[rodata_entry:rodata_entry+256]) + b'\x00'*256)
rsrc2 = code.compute_pgm_rsrc2 | (lds_size << 15)
return (bytes(sec.content), tuple(lowered.prg.p.global_size), tuple(lowered.prg.p.local_size), buf_sizes, buf_data, rsrc2)
return None
except Exception as e:
print(f" Error getting kernel: {e}")
return None
TINYGRAD_TESTS = ["add", "mul", "reduce_sum", "softmax", "exp", "sin", "gelu", "matmul_small"]
def main():
import argparse
parser = argparse.ArgumentParser(description="Benchmark RDNA3 emulators")
parser.add_argument("--iterations", type=int, default=3, help="Number of iterations per benchmark")
parser.add_argument("--profile", type=str, default=None, help="Profile instructions for a specific kernel (e.g. 'sin')")
parser.add_argument("--top", type=int, default=20, help="Number of top instructions to show in profile")
args = parser.parse_args()
# Profile mode: show individual instruction timing
if args.profile:
kernel_info = get_tinygrad_kernel(args.profile)
if kernel_info is None:
print(f"Failed to get kernel for '{args.profile}'")
return
kernel = kernel_info[0]
print(f"Profiling instructions for '{args.profile}' kernel...")
print("=" * 110)
results = profile_instructions(kernel)
print(f"{'Instruction':<90} {'Compile(ms)':>12}")
print("-" * 110)
for r in results[:args.top]:
inst = r['inst_str'][:87] + "..." if len(r['inst_str']) > 90 else r['inst_str']
print(f"{inst:<90} {r['compile_ms']:>12.3f}")
print("-" * 110)
total = sum(r['compile_ms'] for r in results)
print(f"{'TOTAL':<90} {total:>12.3f}")
return
rust_remu = get_rust_remu()
if rust_remu is None:
print("Rust libremu not found. Build with: cargo build --release --manifest-path extra/remu/Cargo.toml")
print("Running Python-only benchmarks...\n")
print("=" * 90)
print("RDNA3 Emulator Benchmark: Python vs Rust")
print("=" * 90)
results = []
print("\n[TINYGRAD KERNELS]")
print("-" * 90)
for op_name in TINYGRAD_TESTS:
print(f"\n{op_name}:", end=" ", flush=True)
kernel_info = get_tinygrad_kernel(op_name)
if kernel_info is None:
print("failed to compile")
continue
kernel, global_size, local_size, buf_sizes, buf_data, rsrc2 = kernel_info
buffers, args_arr, args_ptr, ranges = setup_buffers(buf_sizes, buf_data)
# Benchmark Python emulator (must be first to measure compile time before cache is populated)
py_compile, py_exec, n_insts, n_compiled = benchmark_python_split(kernel, global_size, local_size, args_ptr, rsrc2, args.iterations)
n_workgroups = global_size[0] * global_size[1] * global_size[2]
n_threads = local_size[0] * local_size[1] * local_size[2]
total_work = n_insts * n_workgroups * n_threads
print(f"{n_insts} insts ({n_compiled} unique) × {n_workgroups} WGs × {n_threads} threads = {total_work:,} ops")
rust_time = benchmark_emulator("Rust", rust_remu.run_asm, kernel, global_size, local_size,
args_ptr, rsrc2, args.iterations) if rust_remu else None
if py_compile is not None:
py_exec_rate = total_work / py_exec / 1e6
print(f" Compile: {py_compile*1000:8.3f} ms ({n_compiled} unique)")
print(f" Exec: {py_exec*1000:8.3f} ms ({py_exec_rate:7.2f} M ops/s)")
if rust_time:
rust_rate = total_work / rust_time / 1e6
speedup = py_exec / rust_time if py_exec else 0
print(f" Rust: {rust_time*1000:8.3f} ms ({rust_rate:7.2f} M ops/s) [{speedup:.1f}x faster]")
results.append((op_name, n_insts, n_compiled, n_workgroups, py_compile, py_exec, rust_time))
# Summary table
print("\n" + "=" * 110)
print("SUMMARY")
print("=" * 110)
print(f"{'Name':<16} {'Insts':<6} {'Unique':<6} {'WGs':<5} {'Compile (ms)':<14} {'Exec (ms)':<12} {'Rust (ms)':<12} {'Speedup':<10}")
print("-" * 110)
for name, n_insts, n_compiled, n_wgs, py_compile, py_exec, rust_time in results:
compile_ms = f"{py_compile*1000:.3f}" if py_compile else "error"
exec_ms = f"{py_exec*1000:.3f}" if py_exec else "error"
if rust_time:
rust_ms = f"{rust_time*1000:.3f}"
speedup = f"{py_exec/rust_time:.1f}x" if py_exec else "N/A"
else:
rust_ms, speedup = "N/A", "N/A"
print(f"{name:<16} {n_insts:<6} {n_compiled:<6} {n_wgs:<5} {compile_ms:<14} {exec_ms:<12} {rust_ms:<12} {speedup:<10}")
if __name__ == "__main__":
os.environ["AMD"] = "1"
main()

View File

@@ -1,12 +1,15 @@
# Test to compare Python and Rust RDNA3 emulators by running real tinygrad kernels
import unittest, ctypes
from dataclasses import dataclass
from pathlib import Path
from tinygrad import Device
from tinygrad.renderer.amd.emu import WaveState, decode_program, WAVE_SIZE, VCC_LO, EXEC_LO, SCC
from tinygrad.renderer.amd.emu import WaveState, _decode_at, WAVE_SIZE, VCC_LO, EXEC_LO, SCC
from tinygrad.renderer.amd import decode_inst
from test.amd.helpers import KernelInfo
from test.amd.bench_emu import REMU_PATH
import tinygrad
REMU_PATH = Path(tinygrad.__file__).parent.parent / "extra/remu/target/release/libremu.so"
if not REMU_PATH.exists(): REMU_PATH = Path(tinygrad.__file__).parent.parent / "extra/remu/target/release/libremu.dylib"
def set_valid_mem_ranges(ranges): pass # emu2 doesn't need this
@@ -89,7 +92,7 @@ class RustEmulator:
class PythonEmulator:
def __init__(self):
self.state: WaveState | None = None
self.program: dict | None = None
self.program: dict[int, tuple] = {} # lazily populated: pc -> (name, fxn, globals)
self.vmem_buf = None
self.lds_buf = None
self.kernel_buf = None # Keep kernel bytes alive
@@ -99,27 +102,29 @@ class PythonEmulator:
import ctypes
from tinygrad.device import Buffer, BufferSpec
from tinygrad.dtype import dtypes
# Store kernel in a ctypes buffer so generic instructions can read from vmem at actual PC address
# Store kernel in a ctypes buffer so _decode_at can read from memory at actual PC address
self.kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel)
self.lib_addr = ctypes.addressof(self.kernel_buf)
# Remap program dict to use actual addresses (like run_asm does)
program_raw = decode_program(kernel)
self.program = {self.lib_addr + offset: val for offset, val in program_raw.items()}
self.program = {}
self.state = WaveState(n_lanes)
self.state.pc = self.lib_addr # Set PC to code base address
self.vmem_buf = Buffer('CPU', 1 << 40, dtypes.uint32, options=BufferSpec(external_ptr=0)).ensure_allocated()
self.lds_buf = Buffer('CPU', 65536 // 4, dtypes.uint32).ensure_allocated()
def _ensure_decoded(self, pc: int):
if pc not in self.program:
runner = _decode_at(pc, "rdna3")
self.program[pc] = (runner.p.function_name, runner._prg.fxn, runner.p.globals)
def step(self) -> int:
import ctypes
assert self.program is not None and self.state is not None
assert self.state is not None
pc = self.state.pc
if pc == 0xFFFFFFFFFFFFFFFF or pc not in self.program: return -1
name, fxn, globals_list, _runner = self.program[pc]
if fxn is None: return 1 # unsupported instruction
if pc == 0xFFFFFFFFFFFFFFFF: return -1
self._ensure_decoded(pc)
name, fxn, globals_list = self.program[pc]
buf_addrs = {0: self.state.sgpr_buf._buf.va_addr, 1: self.state.vgpr_buf._buf.va_addr, # type: ignore[union-attr]
2: self.vmem_buf._buf.va_addr, 3: self.lds_buf._buf.va_addr} # type: ignore[union-attr]
# Direct ctypes call - bypasses HCQ overhead
fxn(*[ctypes.c_uint64(buf_addrs[g]) for g in globals_list], ctypes.c_int32(0))
return -1 if self.state.pc == 0xFFFFFFFFFFFFFFFF else 0
@@ -140,7 +145,7 @@ class PythonEmulator:
exec_mask=sgpr[EXEC_LO.offset], sgpr=sgpr, vgpr=vgpr)
def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: tuple[int, int, int],
local_size: tuple[int, int, int], program, max_steps: int, debug: bool, trace_len: int,
local_size: tuple[int, int, int], max_steps: int, debug: bool, trace_len: int,
kernel_idx: int = 0, max_workgroups: int = 8) -> tuple[bool, str, int]:
"""Run a single kernel through both emulators. Returns (success, message, total_steps)."""
gx, gy, gz = global_size
@@ -181,9 +186,9 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
rust_before = rust.get_snapshot()
python_before = python.get_snapshot()
assert python.program is not None
inst_info = python.program.get(python.lib_addr + python_before.pc * 4) # Convert word offset to actual address
inst_hex_name = inst_info[0] if inst_info else f"unknown at PC={python_before.pc}"
pc_addr = python.lib_addr + python_before.pc * 4 # Convert word offset to actual address
python._ensure_decoded(pc_addr)
inst_hex_name = python.program[pc_addr][0]
# Decode the instruction to get mnemonic for sync_after checks
try:
# Format is mnemonic_hexbytes, e.g. v_exp_f32_e32_014b027e -> hex is 014b027e
@@ -310,12 +315,11 @@ def compare_emulators_multi_kernel(kernels: list[KernelInfo], buf_pool: dict[int
kernel_ranges = ranges | {(args_ptr, ctypes.sizeof(args))}
set_valid_mem_ranges(kernel_ranges)
program = decode_program(kernel.code)
n_lanes = kernel.local_size[0] * kernel.local_size[1] * kernel.local_size[2]
ok, msg, steps = run_single_kernel(
kernel.code, min(n_lanes, 32), args_ptr, kernel.global_size,
kernel.local_size, program, max_steps, debug, trace_len, ki
kernel.local_size, max_steps, debug, trace_len, ki
)
total_steps += steps
if not ok:
@@ -341,9 +345,8 @@ def compare_emulators_with_memory(kernel: bytes, n_lanes: int, buf_sizes: list,
ranges.add((args_ptr, ctypes.sizeof(args)))
set_valid_mem_ranges(ranges)
program = decode_program(kernel)
# Legacy wrapper assumes local_size = (n_lanes, 1, 1)
ok, msg, _ = run_single_kernel(kernel, n_lanes, args_ptr, global_size, (n_lanes, 1, 1), program, max_steps, debug, trace_len)
ok, msg, _ = run_single_kernel(kernel, n_lanes, args_ptr, global_size, (n_lanes, 1, 1), max_steps, debug, trace_len)
return ok, msg
def get_kernels_from_tinygrad(op_fn) -> tuple[list[KernelInfo], dict[int, int], dict[int, bytes]]:

View File

@@ -1,96 +0,0 @@
import unittest, ctypes
from tinygrad.runtime.autogen.amd.rdna4 import ins as ir4
from tinygrad.renderer.amd.dsl import v, s
from tinygrad.renderer.amd.emu import WaveState, decode_program
from tinygrad.device import Buffer, BufferSpec
from tinygrad.dtype import dtypes
class TestRDNA4Emu(unittest.TestCase):
def _run(self, insts: list, sgprs: dict[int, int] | None = None, vgprs: dict[tuple[int, int], int] | None = None) -> WaveState:
"""Run instructions and return final WaveState."""
# Add S_ENDPGM if not present
if not any(isinstance(i, ir4.SOPP) and i.op == ir4.SOPPOp.S_ENDPGM for i in insts):
insts = list(insts) + [ir4.SOPP(ir4.SOPPOp.S_ENDPGM, simm=0)]
# Assemble and decode
code = b''.join(i.to_bytes() for i in insts)
code_buf = (ctypes.c_uint8 * len(code)).from_buffer_copy(code)
code_addr = ctypes.addressof(code_buf)
program_raw = decode_program(code, "rdna4")
program = {code_addr + offset: val for offset, val in program_raw.items()}
# Setup wave state
st = WaveState(n_lanes=1)
st.pc = code_addr
for idx, val in (sgprs or {}).items(): st._write_sgpr(idx, val)
for (reg, lane), val in (vgprs or {}).items(): st._write_vgpr(reg, lane, val)
# Setup vmem buffer with external_ptr=0 (maps to address 0, allows any pointer access)
vmem_buf = Buffer('CPU', 1 << 40, dtypes.uint32, options=BufferSpec(external_ptr=0)).ensure_allocated()
# Execute
c_bufs = [ctypes.c_uint64(st.sgpr_buf._buf.va_addr), ctypes.c_uint64(st.vgpr_buf._buf.va_addr),
ctypes.c_uint64(vmem_buf._buf.va_addr), ctypes.c_uint64(0), ctypes.c_uint64(0)]
for _ in range(100):
if (pc := st.pc) == 0xFFFFFFFFFFFFFFFF or pc not in program: break
_, fxn, globals_list, _ = program[pc]
fxn(*[c_bufs[g] for g in globals_list])
return st
def test_vopd_dual_mov(self):
"""Test VOPD with two V_DUAL_MOV_B32 operations: v[1]=s[1], v[2]=s[2]."""
insts = [ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
vdstx=v[1], vdsty=v[2], srcx0=s[1], srcy0=s[2], vsrcx1=v[0], vsrcy1=v[0])]
st = self._run(insts, sgprs={1: 0x40e00000, 2: 0x41100000}) # 7.0f, 9.0f
self.assertEqual(st._read_vgpr(1, 0), 0x40e00000) # v[1] = 7.0
self.assertEqual(st._read_vgpr(2, 0), 0x41100000) # v[2] = 9.0
def test_vopd_dual_mov_after_other_vopd(self):
"""Test VOPD reuse: first VOPD(v[3]=0, v[0]=?), then VOPD(v[1]=s[1], v[2]=s[2])."""
# This matches the BEAM kernel sequence that fails
insts = [
ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
vdstx=v[3], vdsty=v[0], srcx0=0, srcy0=s[0], vsrcx1=v[0], vsrcy1=v[0]), # v[3]=0, v[0]=s[0]
ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
vdstx=v[1], vdsty=v[2], srcx0=s[1], srcy0=s[2], vsrcx1=v[0], vsrcy1=v[0]), # v[1]=s[1], v[2]=s[2]
]
st = self._run(insts, sgprs={0: 0x40a00000, 1: 0x40e00000, 2: 0x41100000}) # 5.0f, 7.0f, 9.0f
self.assertEqual(st._read_vgpr(1, 0), 0x40e00000) # v[1] = 7.0
self.assertEqual(st._read_vgpr(2, 0), 0x41100000) # v[2] = 9.0
def test_vopd_with_s_add_f32_sequence(self):
"""Test full BEAM kernel sequence: s_add_f32 then VOPD."""
# This is the exact sequence from the failing BEAM kernel
insts = [
ir4.SOP2(ir4.SOP2Op.S_ADD_F32, sdst=s[0], ssrc0=s[0], ssrc1=s[8]), # s[0] = s[0] + s[8]
ir4.SOP2(ir4.SOP2Op.S_ADD_F32, sdst=s[1], ssrc0=s[1], ssrc1=s[9]), # s[1] = s[1] + s[9]
ir4.SOP2(ir4.SOP2Op.S_ADD_F32, sdst=s[2], ssrc0=s[2], ssrc1=s[10]), # s[2] = s[2] + s[10]
ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
vdstx=v[3], vdsty=v[0], srcx0=0, srcy0=s[0], vsrcx1=v[0], vsrcy1=v[0]),
ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
vdstx=v[1], vdsty=v[2], srcx0=s[1], srcy0=s[2], vsrcx1=v[0], vsrcy1=v[0]),
]
# Input: s[0:2] = [1,2,3], s[8:10] = [4,5,6]
# After s_add_f32: s[0:2] = [5,7,9]
st = self._run(insts, sgprs={0: 0x3f800000, 1: 0x40000000, 2: 0x40400000, # 1.0, 2.0, 3.0
8: 0x40800000, 9: 0x40a00000, 10: 0x40c00000}) # 4.0, 5.0, 6.0
self.assertEqual(st._read_vgpr(1, 0), 0x40e00000) # v[1] = 7.0
self.assertEqual(st._read_vgpr(2, 0), 0x41100000) # v[2] = 9.0
def test_s_mov_b32_then_vopd(self):
"""Test s_mov_b32 followed by VOPD - simulates BEAM kernel sequence."""
# Use s_mov_b32 with SGPR source (copy from pre-initialized SGPRs)
# s[10:12] will have values set by test harness, copy to s[0:2], then VOPD to VGPRs
insts = [
ir4.SOP1(ir4.SOP1Op.S_MOV_B32, sdst=s[0], ssrc0=s[10]), # s[0] = s[10]
ir4.SOP1(ir4.SOP1Op.S_MOV_B32, sdst=s[1], ssrc0=s[11]), # s[1] = s[11]
ir4.SOP1(ir4.SOP1Op.S_MOV_B32, sdst=s[2], ssrc0=s[12]), # s[2] = s[12]
ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
vdstx=v[1], vdsty=v[2], srcx0=s[1], srcy0=s[2], vsrcx1=v[0], vsrcy1=v[0]),
]
st = self._run(insts, sgprs={10: 0x40a00000, 11: 0x40e00000, 12: 0x41100000}) # 5.0, 7.0, 9.0
self.assertEqual(st._read_vgpr(1, 0), 0x40e00000) # v[1] = 7.0
self.assertEqual(st._read_vgpr(2, 0), 0x41100000) # v[2] = 9.0
if __name__ == '__main__':
unittest.main()

View File

@@ -90,9 +90,9 @@ class AMDDriver(VirtDriver):
def _prepare_gpu(self, gpu_id):
self.doorbells[gpu_id] = memoryview(bytearray(0x2000))
self.gpus[gpu_id] = AMDGPU(gpu_id)
# IP versions: rdna3 = GC 11.0.0, NBIF 4.3.0; rdna4 = GC 12.0.0, NBIF 6.3.1
ip_versions = {"rdna3": {"gc": (11, 0, 0), "sdma": (6, 0, 0), "nbif": (4, 3, 0)},
"rdna4": {"gc": (12, 0, 0), "sdma": (6, 0, 0), "nbif": (6, 3, 1)}}[MOCKGPU_ARCH]
"rdna4": {"gc": (12, 0, 0), "sdma": (6, 0, 0), "nbif": (6, 3, 1)},
"cdna4": {"gc": (9, 5, 0), "sdma": (4, 4, 5), "nbif": (7, 9, 0)}}[MOCKGPU_ARCH]
def ip_discovery_files(hwid, ver, base_addr):
p = f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{hwid}/0'
return [VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{hwid}', functools.partial(DirFileDesc, child_names=['0'])),

View File

@@ -5,7 +5,7 @@ from tinygrad.helpers import getbits, to_mv, getenv
from tinygrad.runtime.support import c
MOCKGPU_ARCH = getenv("MOCKGPU_ARCH", "rdna3")
GFX_TARGET_VERSION = {"rdna3": 110000, "rdna4": 120000}[MOCKGPU_ARCH]
GFX_TARGET_VERSION = {"rdna3": 110000, "rdna4": 120000, "cdna4": 90500}[MOCKGPU_ARCH]
import tinygrad.runtime.autogen.amd_gpu as amd_gpu, tinygrad.runtime.autogen.am.pm4_nv as pm4
SDMA_MAX_COPY_SIZE = 0x400000
@@ -106,8 +106,8 @@ class PM4Executor(AMDQueue):
return (self.rptr[0] - prev_rptr) + executed_in_ib
def _exec_acquire_mem(self, n):
assert n == 6
for _ in range(7): self._next_dword() # TODO: implement
assert n in (5, 6)
for _ in range(n + 1): self._next_dword() # TODO: implement
def _exec_release_mem(self, n):
assert n == 6
@@ -184,6 +184,12 @@ class PM4Executor(AMDQueue):
args_addr = self.gpu.regs[regCOMPUTE_USER_DATA_0] + (self.gpu.regs[regCOMPUTE_USER_DATA_0 + 1] << 32)
lc = [self.gpu.regs[i] for i in range(regCOMPUTE_NUM_THREAD_X, regCOMPUTE_NUM_THREAD_X+3)]
rsrc2 = self.gpu.regs[regCOMPUTE_PGM_RSRC2]
# Read all user data registers (hardware loads these directly into s[0:N])
user_sgpr_count = (rsrc2 >> 1) & 0x1F # USER_SGPR_COUNT is bits 1:5
user_data = []
for i in range(user_sgpr_count):
try: user_data.append(self.gpu.regs[regCOMPUTE_USER_DATA_0 + i])
except KeyError: user_data.append(0)
prg_sz = 0
for st,sz in self.gpu.mapped_ranges:
@@ -197,11 +203,12 @@ class PM4Executor(AMDQueue):
scratch_size = wavesize * 4 # This gives the scratch size per thread (lane)
assert prg_sz > 0, "Invalid prg ptr (not found in mapped ranges)"
# Pass valid memory ranges, rsrc2, scratch_size and arch to Python emulator
# Pass valid memory ranges, rsrc2, scratch_size, arch, and user data registers to Python emulator
if hasattr(remu, 'valid_mem_ranges'): remu.valid_mem_ranges = self.gpu.mapped_ranges
if hasattr(remu, 'rsrc2'): remu.rsrc2 = rsrc2
if hasattr(remu, 'scratch_size'): remu.scratch_size = scratch_size
if hasattr(remu, 'arch'): remu.arch = self.gpu.arch
if hasattr(remu, 'user_data'): remu.user_data = user_data
err = remu.run_asm(prg_addr, prg_sz, *gl, *lc, args_addr)
if err != 0: raise RuntimeError("remu does not support the new instruction introduced in this kernel")
@@ -318,7 +325,7 @@ class AMDGPU(VirtGPU):
self.regs = AMDGPURegisters()
self.mapped_ranges = set()
self.queues = []
self.arch = MOCKGPU_ARCH
self.arch = "cdna" if MOCKGPU_ARCH == "cdna4" else MOCKGPU_ARCH
def map_range(self, vaddr, size): self.mapped_ranges.add((vaddr, size))
def unmap_range(self, vaddr, size): self.mapped_ranges.remove((vaddr, size))
@@ -329,7 +336,7 @@ class AMDGPU(VirtGPU):
self.queues.append(SDMAExecutor(self, base, size, rptr, wptr))
return len(self.queues) - 1
gpu_props = """cpu_cores_count 0
_gpu_props_rdna = """cpu_cores_count 0
simd_count 192
mem_banks_count 1
caches_count 206
@@ -367,3 +374,44 @@ sdma_fw_version 20
unique_id 11673270660693242239
num_xcc 1
max_engine_clk_ccompute 2400"""
_gpu_props_cdna = """cpu_cores_count 0
simd_count 304
mem_banks_count 1
caches_count 206
io_links_count 1
p2p_links_count 5
cpu_core_id_base 0
simd_id_base 2147488032
max_waves_per_simd 16
lds_size_in_kb 128
gds_size_in_kb 0
num_gws 64
wave_front_size 64
array_count 16
simd_arrays_per_engine 4
cu_per_simd_array 19
simd_per_cu 2
max_slots_scratch_cu 32
gfx_target_version {gfx_target_version}
vendor_id 4098
device_id 29772
location_id 34304
domain 0
drm_render_minor {drm_render_minor}
hive_id 0
num_sdma_engines 2
num_sdma_xgmi_engines 0
num_sdma_queues_per_engine 6
num_cp_queues 8
max_engine_clk_fcompute 2100
local_mem_size 0
fw_version 2140
capability 671588992
debug_prop 1495
sdma_fw_version 20
unique_id 11673270660693242239
num_xcc 1
max_engine_clk_ccompute 2100"""
gpu_props = _gpu_props_cdna if MOCKGPU_ARCH == "cdna4" else _gpu_props_rdna

View File

@@ -21,10 +21,11 @@ class PythonRemu:
rsrc2: int = 0x19c # Default: USER_SGPR_COUNT=14, enable X and Y workgroup IDs
scratch_size: int = 0 # private_segment_fixed_size from kernel descriptor
arch: str = "rdna3" # Architecture: rdna3 or rdna4
user_data: list[int] = [] # All COMPUTE_USER_DATA registers (loaded into s[0:N])
def run_asm(self, lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int) -> int:
from tinygrad.renderer.amd.emu import run_asm
return run_asm(lib, lib_sz, gx, gy, gz, lx, ly, lz, args_ptr, self.rsrc2, self.scratch_size, self.arch)
return run_asm(lib, lib_sz, gx, gy, gz, lx, ly, lz, args_ptr, self.rsrc2, self.scratch_size, self.arch, self.user_data)
def _try_dlopen_remu():
# Use Python emulator only if PYTHON_REMU=1

View File

@@ -7,7 +7,7 @@
# arg=4: scratch - per-lane scratch memory
from __future__ import annotations
import ctypes, functools, re, platform, subprocess, tempfile
from typing import Any, Callable
from typing import Callable
# Set/restore DAZ+FTZ (denormals-are-zero + flush-to-zero) to match RDNA3 default float mode
# x86: MXCSR bits DAZ(6)+FTZ(15), ARM64: FPCR bit FZ(24)
@@ -61,8 +61,10 @@ from tinygrad.engine.realize import get_runner
from tinygrad.renderer.amd import decode_inst
from tinygrad.runtime.autogen.amd.rdna3.str_pcode import PCODE as PCODE_RDNA3
from tinygrad.runtime.autogen.amd.rdna4.str_pcode import PCODE as PCODE_RDNA4
from tinygrad.runtime.autogen.amd.cdna.str_pcode import PCODE as PCODE_CDNA
from tinygrad.runtime.autogen.amd.rdna3 import ins as ir3
from tinygrad.runtime.autogen.amd.rdna4 import ins as ir4
from tinygrad.runtime.autogen.amd.cdna import ins as irc
from tinygrad.renderer.amd.dsl import VCC_LO, EXEC_LO, SCC, ttmp
from tinygrad.runtime.autogen.amd.common import Fmt, OpType
from tinygrad.renderer.amd.pcode import parse_block, _FUNCS
@@ -160,7 +162,7 @@ _pcode_fixes = {
def _get_pcode_dict(op) -> dict:
"""Return the PCODE dictionary for the given opcode based on its architecture."""
return PCODE_RDNA4 if 'rdna4' in type(op).__module__ else PCODE_RDNA3
return PCODE_CDNA if 'cdna' in type(op).__module__ else PCODE_RDNA4 if 'rdna4' in type(op).__module__ else PCODE_RDNA3
# Pcode parser
@functools.cache
@@ -465,8 +467,8 @@ class _Ctx:
pcode = get_pcode(op)
vcc_reg = sdst_reg if sdst_reg is not None else VCC_LO.offset
if 'VCC' not in srcs: srcs['VCC'] = self.rsgpr_dyn(_c(vcc_reg))
srcs.update({'EXEC': exec_mask, 'SCC': self.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane,
'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0)}) # rounding mode: 0=RNE, RTZ constant
srcs.update({'EXEC': exec_mask, 'SCC': self.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane, 'VDST': vdst_reg,
'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0), 'ROUND_NEAREST_EVEN': _c(0)}) # rounding mode constants
_, assigns = parse_pcode(pcode, srcs)
# For integer ops with clamp, compute overflow using wide arithmetic
@@ -543,10 +545,11 @@ class _Ctx:
def _compile_sopp(inst: ir3.SOPP | ir4.SOPP, ctx: _Ctx) -> UOp:
simm16 = ctx.inst_field_signed(type(inst).simm16).cast(dtypes.int16)
if inst.op in (ir3.SOPPOp.S_ENDPGM, ir4.SOPPOp.S_ENDPGM):
if inst.op in (ir3.SOPPOp.S_ENDPGM, ir4.SOPPOp.S_ENDPGM, irc.SOPPOp.S_ENDPGM):
return UOp.sink(ctx.wsgpr_dyn(_c(PC_LO_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF)),
ctx.wsgpr_dyn(_c(PC_HI_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF)))
if inst.op in (ir3.SOPPOp.S_NOP, ir4.SOPPOp.S_NOP): return UOp.sink(*ctx.inc_pc()) # S_NOP is a no-op
# S_NOP and S_WAITCNT are no-ops in emulator (no pipeline/cache to wait on)
if inst.op in (ir3.SOPPOp.S_NOP, ir4.SOPPOp.S_NOP, irc.SOPPOp.S_NOP, irc.SOPPOp.S_WAITCNT): return UOp.sink(*ctx.inc_pc())
# NOTE: we ignore SOPPs without PCODE
if inst.op in _get_pcode_dict(inst.op):
pcode = get_pcode(inst.op)
@@ -562,10 +565,7 @@ def _compile_sopp(inst: ir3.SOPP | ir4.SOPP, ctx: _Ctx) -> UOp:
def _compile_smem(inst: ir3.SMEM | ir4.SMEM, ctx: _Ctx) -> UOp:
# Cache invalidation instructions are no-ops in the emulator (we don't model caches)
cache_inv_ops = [ir3.SMEMOp.S_GL1_INV, ir3.SMEMOp.S_DCACHE_INV, ir4.SMEMOp.S_DCACHE_INV]
if hasattr(ir4.SMEMOp, 'S_GL1_INV'): cache_inv_ops.append(ir4.SMEMOp.S_GL1_INV)
if inst.op in cache_inv_ops:
return UOp.sink(*ctx.inc_pc())
if '_INV' in inst.op.name: return UOp.sink(*ctx.inc_pc())
# Dynamic sbase field (bits 5:0) - SGPR pair, field value * 2 = register offset
sbase = ctx.inst_field(type(inst).sbase) * _c(2)
# Dynamic sdata field (bits 12:6) - destination SGPR
@@ -573,34 +573,44 @@ def _compile_smem(inst: ir3.SMEM | ir4.SMEM, ctx: _Ctx) -> UOp:
# RDNA4 uses 'ioffset', RDNA3 uses 'offset' - use type(inst) to get correct field
offset_field = type(inst).ioffset if hasattr(type(inst), 'ioffset') else type(inst).offset # type: ignore[union-attr]
offset = ctx.inst_field_signed(offset_field) # signed immediate
# Dynamic soffset field - SGPR for additional offset (NULL=124 reads as 0)
soffset = ctx.inst_field(type(inst).soffset)
addr = _u64(ctx.rsgpr_dyn(sbase), ctx.rsgpr_dyn(sbase + _c(1))) + offset.cast(dtypes.uint64) + ctx.rsgpr_dyn(soffset).cast(dtypes.uint64)
# Dynamic soffset field - SGPR for additional offset (NULL=124 reads as 0, CDNA soffset_en=0 means no soffset)
soffset_val = _c(0).cast(dtypes.uint64)
if not (isinstance(inst, irc.SMEM) and not inst.soffset_en):
soffset_val = ctx.inst_field(type(inst).soffset)
soffset_val = ctx.rsgpr_dyn(soffset_val).cast(dtypes.uint64)
addr = _u64(ctx.rsgpr_dyn(sbase), ctx.rsgpr_dyn(sbase + _c(1))) + offset.cast(dtypes.uint64) + soffset_val
_SMEM_NDWORDS = {ir3.SMEMOp.S_LOAD_B32: 1, ir3.SMEMOp.S_LOAD_B64: 2, ir3.SMEMOp.S_LOAD_B128: 4,
ir3.SMEMOp.S_LOAD_B256: 8, ir3.SMEMOp.S_LOAD_B512: 16, ir4.SMEMOp.S_LOAD_B32: 1, ir4.SMEMOp.S_LOAD_B64: 2,
ir4.SMEMOp.S_LOAD_B96: 3, ir4.SMEMOp.S_LOAD_B128: 4, ir4.SMEMOp.S_LOAD_B256: 8, ir4.SMEMOp.S_LOAD_B512: 16}
ir4.SMEMOp.S_LOAD_B96: 3, ir4.SMEMOp.S_LOAD_B128: 4, ir4.SMEMOp.S_LOAD_B256: 8, ir4.SMEMOp.S_LOAD_B512: 16,
irc.SMEMOp.S_LOAD_DWORD: 1, irc.SMEMOp.S_LOAD_DWORDX2: 2, irc.SMEMOp.S_LOAD_DWORDX4: 4,
irc.SMEMOp.S_LOAD_DWORDX8: 8, irc.SMEMOp.S_LOAD_DWORDX16: 16}
ndwords = _SMEM_NDWORDS[inst.op]
stores = [ctx.wsgpr_dyn(sdata_reg + _c(i), ctx.vmem.index((addr + UOp.const(dtypes.uint64, i * 4) >> UOp.const(dtypes.uint64, 2)).cast(dtypes.int)))
for i in range(ndwords)]
return UOp.sink(*stores, *ctx.inc_pc())
def _compile_sop(inst: ir3.SOP1 | ir3.SOP2 | ir3.SOPC | ir3.SOPK | ir4.SOP1 | ir4.SOP2 | ir4.SOPC | ir4.SOPK, ctx: _Ctx) -> UOp:
def _compile_sop(inst: ir3.SOP1|ir3.SOP2|ir3.SOPC|ir3.SOPK|ir4.SOP1|ir4.SOP2|ir4.SOPC|ir4.SOPK|irc.SOP1|irc.SOP2|irc.SOPC|irc.SOPK, ctx: _Ctx) -> UOp:
bits = inst.canonical_op_bits
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr]
if isinstance(inst, (ir3.SOPK, ir4.SOPK)):
if isinstance(inst, (ir3.SOPK, ir4.SOPK, irc.SOPK)):
sdst_off = ctx.inst_field(type(inst).sdst)
simm16 = ctx.inst_field(type(inst).simm16)
# Sign-extend simm16
simm16_sext = simm16.cast(dtypes.int16).cast(dtypes.int32)
srcs = {'S0': ctx.rsgpr_dyn(sdst_off), 'SIMM16': simm16_sext, 'D0': ctx.rsgpr_dyn(sdst_off)}
# CDNA pcode uses S0 for the immediate in MOVK/MULK/ADDK/CMOVK (where RDNA uses SIMM16),
# but S0 = register for CMPK/SETREG. S1 is always the immediate for CDNA CMPK ops.
op_name = inst.op.name if hasattr(inst.op, 'name') else ''
s0_is_imm = isinstance(inst, irc.SOPK) and 'CMPK' not in op_name and 'SETREG' not in op_name
s0_val = simm16_sext if s0_is_imm else ctx.rsgpr_dyn(sdst_off)
srcs = {'S0': s0_val, 'SIMM16': simm16_sext, 'S1': simm16_sext, 'D0': ctx.rsgpr_dyn(sdst_off)}
dst_off, dst_size = sdst_off, 1
elif isinstance(inst, (ir3.SOP1, ir4.SOP1)):
elif isinstance(inst, (ir3.SOP1, ir4.SOP1, irc.SOP1)):
sdst_off = ctx.inst_field(type(inst).sdst)
ssrc0_off = ctx.inst_field(type(inst).ssrc0)
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal)}
dst_off, dst_size = sdst_off, bits['d'] // 32
elif isinstance(inst, (ir3.SOP2, ir4.SOP2)):
elif isinstance(inst, (ir3.SOP2, ir4.SOP2, irc.SOP2)):
sdst_off = ctx.inst_field(type(inst).sdst)
ssrc0_off = ctx.inst_field(type(inst).ssrc0)
ssrc1_off = ctx.inst_field(type(inst).ssrc1)
@@ -608,7 +618,7 @@ def _compile_sop(inst: ir3.SOP1 | ir3.SOP2 | ir3.SOPC | ir3.SOPK | ir4.SOP1 | ir
'S1': ctx.rsrc_dyn(ssrc1_off, None, bits['s1'], literal)}
if literal is not None: srcs['SIMM32'] = literal
dst_off, dst_size = sdst_off, bits['d'] // 32
elif isinstance(inst, (ir3.SOPC, ir4.SOPC)):
elif isinstance(inst, (ir3.SOPC, ir4.SOPC, irc.SOPC)):
ssrc0_off = ctx.inst_field(type(inst).ssrc0)
ssrc1_off = ctx.inst_field(type(inst).ssrc1)
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal),
@@ -619,7 +629,7 @@ def _compile_sop(inst: ir3.SOP1 | ir3.SOP2 | ir3.SOPC | ir3.SOPK | ir4.SOP1 | ir
return ctx.compile_sop_pcode(inst.op, srcs, dst_off, dst_size)
def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VOP1_SDST | ir4.VOP2, ctx: _Ctx) -> UOp:
def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VOP1_SDST | ir4.VOP2 | irc.VOP1 | irc.VOP2, ctx: _Ctx) -> UOp:
op_name = _op_name(inst)
if op_name in ('V_READFIRSTLANE_B32_E32', 'V_PERMLANE64_B32_E32'): return ctx.compile_lane_pcode(inst.op, inst)
lane, exec_mask, bits = ctx.range(), ctx.rsgpr_dyn(_c(EXEC_LO.offset)), inst.canonical_op_bits
@@ -628,7 +638,7 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VO
write_hi_half = bits['d'] == 16 and (vdst_reg >= _c(128))
if isinstance(write_hi_half, UOp): vdst_reg = write_hi_half.where(vdst_reg - _c(128), vdst_reg)
elif write_hi_half: vdst_reg -= 128
if isinstance(inst, (ir3.VOP1, ir4.VOP1)):
if isinstance(inst, (ir3.VOP1, ir4.VOP1, irc.VOP1)):
# Handle VOP1 hi-half source operand (src0 >= v[128] for 16-bit ops)
src0_off = ctx.inst_field(type(inst).src0)
s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal)
@@ -654,12 +664,13 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VO
s0 = src0_hi.where(_hi16(ctx.rvgpr_dyn(src0_reg, lane)), s0)
srcs = {'S0': s0, 'S1': s1, 'D0': d0}
if inst.op in (ir3.VOP2Op.V_FMAAK_F32_E32, ir3.VOP2Op.V_FMAMK_F32_E32, ir3.VOP2Op.V_FMAAK_F16_E32,
ir3.VOP2Op.V_FMAMK_F16_E32):
ir3.VOP2Op.V_FMAMK_F16_E32, irc.VOP2Op.V_FMAAK_F32_E32, irc.VOP2Op.V_FMAMK_F32_E32):
assert literal is not None
srcs['SIMM32'] = literal
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=write_hi_half)
def _compile_vopc(inst: ir3.VOPC | ir3.VOP3 | ir4.VOPC | ir4.VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp:
def _compile_vopc(inst: ir3.VOPC|ir3.VOP3|ir4.VOPC|ir4.VOP3|irc.VOPC|irc.VOP3, ctx: _Ctx,
opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp:
exec_mask, op_name, bits = ctx.rsgpr_dyn(_c(EXEC_LO.offset)), _op_name(inst), inst.canonical_op_bits
is_cmpx, is_vopc = 'CMPX' in op_name, hasattr(inst, 'vsrc1') # is_vopc: e32 vs e64
@@ -707,7 +718,7 @@ def _compile_vopc(inst: ir3.VOPC | ir3.VOP3 | ir4.VOPC | ir4.VOP3, ctx: _Ctx, op
stores = [ctx.wsgpr_dyn(dst_off, new_result)] if not is_vopc else [ctx.wsgpr_dyn(_c(VCC_LO.offset), new_result)]
return UOp.sink(*stores, *ctx.inc_pc())
def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3, ctx: _Ctx) -> UOp:
def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3 | irc.VOP3, ctx: _Ctx) -> UOp:
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
bits = inst.canonical_op_bits
opsel, op_name = getattr(inst, 'opsel', 0) or 0, _op_name(inst)
@@ -741,13 +752,13 @@ def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3, ctx: _Ctx) -> UOp:
src1 = _apply_src_mods(src1, 1, abs_bits, neg_bits, bits['s1'])
src2 = _apply_src_mods(src2, 2, abs_bits, neg_bits, bits['s2'])
srcs = {'S0': src0, 'S1': src1, 'S2': src2}
if inst.op in (ir3.VOP3Op.V_CNDMASK_B32_E64, ir3.VOP3Op.V_CNDMASK_B16) and src2 is not None: srcs['VCC'] = src2
if inst.op in (ir3.VOP3Op.V_CNDMASK_B32_E64, ir3.VOP3Op.V_CNDMASK_B16, irc.VOP3Op.V_CNDMASK_B32_E64) and src2 is not None: srcs['VCC'] = src2
# FMAC instructions need D0 (accumulator) from destination register
if 'FMAC' in op_name: srcs['D0'] = ctx.rvgpr_dyn(vdst_reg, lane)
opsel_dst_hi = bool(opsel & 0b1000) and bits['d'] == 16
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=opsel_dst_hi, clmp=getattr(inst, 'clmp', 0))
def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD, ctx: _Ctx) -> UOp:
def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD | irc.VOP3SD, ctx: _Ctx) -> UOp:
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
bits, pcode, ops = inst.canonical_op_bits, get_pcode(inst.op), inst.canonical_operands
@@ -806,7 +817,7 @@ def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD, ctx: _Ctx) -> UOp:
else:
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, sdst_reg=inst.sdst.offset)
def _compile_wmma(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp:
def _compile_wmma(inst: ir3.VOP3P | ir4.VOP3P | irc.VOP3P, ctx: _Ctx) -> UOp:
op_name = _op_name(inst)
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
vdst_reg = ctx.inst_field(type(inst).vdst)
@@ -839,14 +850,15 @@ def _compile_wmma(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp:
stores = [ctx.wvgpr_dyn(vdst_reg + _c(i // 32), UOp.const(dtypes.int, i % 32), mat_d[i].bitcast(dtypes.uint32), exec_mask) for i in range(256)]
return UOp.sink(*stores, *ctx.inc_pc())
def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp:
def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P | irc.VOP3P, ctx: _Ctx) -> UOp:
op_name = _op_name(inst)
if 'WMMA' in op_name and ('16X16X16_F16' in op_name or '16X16X16_BF16' in op_name): return _compile_wmma(inst, ctx)
lane = ctx.range()
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
vdst_reg = ctx.inst_field(type(inst).vdst)
do_cast = any(x in op_name for x in ('F16', 'F32', 'BF16')) and 'IU' not in op_name
is_pk_f32 = 'PK' in op_name and 'F32' in op_name and 'MOV' not in op_name # CDNA packed F32 ops
do_cast = any(x in op_name for x in ('F16', 'F32', 'BF16')) and 'IU' not in op_name and not is_pk_f32
src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, 16, do_cast=do_cast)
src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, 16, do_cast=do_cast)
src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, 16, do_cast=do_cast)
@@ -854,7 +866,30 @@ def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp:
opsel_hi2 = getattr(inst, 'opsel_hi2', 1) if getattr(inst, 'opsel_hi2', 1) is not None else 1
neg, neg_hi = getattr(inst, 'neg', 0) or 0, getattr(inst, 'neg_hi', 0) or 0
if 'FMA_MIX' in op_name:
if is_pk_f32:
# CDNA packed F32: read 32-bit sources, build 64-bit packed values using opsel.
# For VGPRs: opsel selects between v[reg] (0) and v[reg+1] (1) for each half.
# For SGPR pairs (off < 128): s[N] = lo float32, s[N+1] = hi float32.
# For inline constants (128 <= off < 256): broadcast same value to both halves.
src_offs = [ctx.inst_field(type(inst).src0), ctx.inst_field(type(inst).src1), ctx.inst_field(type(inst).src2)]
def build_pk_f32(src_lo: UOp, src_off: UOp, opsel_lo: int, opsel_hi_bit: int, neg_lo: int, neg_hi_bit: int) -> UOp:
is_vgpr = src_off >= _c(256)
vgpr_lo = ctx.rvgpr_dyn(src_off - _c(256), lane) if lane is not None else _c(0)
vgpr_hi = ctx.rvgpr_dyn(src_off - _c(256) + _c(1), lane) if lane is not None else _c(0)
# For SGPR pairs, opsel selects between s[N] (0) and s[N+1] (1); inline constants always broadcast.
is_sgpr_pair = src_off < _c(128)
sgpr_hi = ctx.rsgpr_dyn(src_off + _c(1), is_sgpr_pair)
scalar_lo_sel = src_lo if not opsel_lo else is_sgpr_pair.where(sgpr_hi, src_lo)
scalar_hi_sel = src_lo if not opsel_hi_bit else is_sgpr_pair.where(sgpr_hi, src_lo)
lo = is_vgpr.where(vgpr_hi if opsel_lo else vgpr_lo, scalar_lo_sel)
hi = is_vgpr.where(vgpr_hi if opsel_hi_bit else vgpr_lo, scalar_hi_sel)
if neg_lo: lo = lo ^ UOp.const(dtypes.uint32, 0x80000000)
if neg_hi_bit: hi = hi ^ UOp.const(dtypes.uint32, 0x80000000)
return _u64(lo, hi)
srcs = {'S0': build_pk_f32(src0, src_offs[0], opsel & 1, opsel_hi & 1, neg & 1, neg_hi & 1),
'S1': build_pk_f32(src1, src_offs[1], opsel & 2, opsel_hi & 2, neg & 2, neg_hi & 2),
'S2': build_pk_f32(src2, src_offs[2], opsel & 4, 1 if opsel_hi2 else 0, neg & 4, neg_hi & 4)}
elif 'FMA_MIX' in op_name:
combined_opsel_hi = (opsel_hi & 0x3) | ((opsel_hi2 & 0x1) << 2)
# For FMA_MIX: neg_hi is ABS (not neg!), neg is actual negation
def apply_abs(v, bit, opsel_hi_bit, opsel_bit):
@@ -924,13 +959,18 @@ def _compile_vopd(inst: ir3.VOPD | ir4.VOPD, ctx: _Ctx) -> UOp:
if dest.startswith('D0'): all_stores.append(ctx.wvgpr_dyn(vdst_reg, lane, _val_to_u32(val), exec_mask, after=srcy1))
return UOp.sink(UOp.group(*all_stores).end(lane), *ctx.inc_pc())
def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS | ir4.VFLAT | ir4.VGLOBAL | ir4.VSCRATCH, ctx: _Ctx) -> UOp:
def _compile_mem_op(inst: ir3.DS|ir3.FLAT|ir3.GLOBAL|ir3.SCRATCH|ir4.DS|ir4.VFLAT|ir4.VGLOBAL|ir4.VSCRATCH
|irc.DS|irc.FLAT|irc.GLOBAL|irc.SCRATCH, ctx: _Ctx) -> UOp:
"""Unified memory operation compiler for DS, FLAT, GLOBAL, SCRATCH."""
exec_mask, op_name = ctx.rsgpr_dyn(_c(EXEC_LO.offset)), _op_name(inst)
pcode = get_pcode(inst.op)
# CDNA pcode uses CalcGlobalAddr/CalcDsAddr to compute address from raw components, but make_addr already handles this.
# Strip the addr computation line and use pre-computed ADDR directly (rename 'addr' -> 'ADDR' in remaining pcode).
if isinstance(inst, (irc.GLOBAL, irc.FLAT, irc.SCRATCH, irc.DS)) and 'Calc' in pcode and 'Addr' in pcode:
pcode = re.sub(r'addr\s*=\s*Calc\w+Addr\([^)]*\)\s*;?\n?', '', pcode).replace('MEM[addr', 'MEM[ADDR')
is_lds = isinstance(inst, (ir3.DS, ir4.DS))
is_scratch = isinstance(inst, (ir3.SCRATCH, ir4.VSCRATCH))
is_lds = isinstance(inst, (ir3.DS, ir4.DS, irc.DS))
is_scratch = isinstance(inst, (ir3.SCRATCH, ir4.VSCRATCH, irc.SCRATCH))
mem = ctx.lds if is_lds else ctx.scratch if is_scratch else ctx.vmem
addr_shift = UOp.const(dtypes.uint32 if is_lds else dtypes.uint64, 2)
@@ -1038,7 +1078,7 @@ def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS
if 'STORE' in op_name and data_bits_mem >= 64:
vdata = vdata | (ctx.rvgpr_dyn(vdata_reg + _c(1), lane).cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
srcs = {'ADDR': addr, 'VDATA': vdata, '_vmem': mem, '_active': active,
'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base}
'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base, 'SADDR': saddr_base, 'OFFSET': offset}
for i in range(data_bits_mem // 32):
srcs[f'VDATA{i}'] = ctx.rvgpr_dyn(vdata_reg + _c(i), lane) if 'STORE' in op_name else UOp.const(dtypes.uint32, 0)
return srcs
@@ -1075,7 +1115,7 @@ def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS
return UOp.sink(*ended, *ctx.inc_pc())
# Standard path: single lane range
writes_return_data = '_RTN' in op_name or (is_lds and op_name.startswith('DS_LOAD')) or bool(is_atomic and glc)
writes_return_data = '_RTN' in op_name or (is_lds and (op_name.startswith('DS_LOAD') or op_name.startswith('DS_READ'))) or bool(is_atomic and glc)
lane = ctx.range()
active = _lane_active(exec_mask, lane)
pcode_vars, assigns = parse_pcode(pcode, make_srcs(lane))
@@ -1099,6 +1139,11 @@ _INST_HANDLERS: dict[type, Callable[..., UOp]] = {
ir4.VOP1: _compile_vop12, ir4.VOP1_SDST: _compile_vop12, ir4.VOP2: _compile_vop12, ir4.VOPC: _compile_vopc, ir4.VOP3: _compile_vop3,
ir4.VOP3_SDST: _compile_vop3, ir4.VOP3SD: _compile_vop3sd, ir4.VOP3P: _compile_vop3p, ir4.VOPD: _compile_vopd,
ir4.DS: _compile_mem_op, ir4.VFLAT: _compile_mem_op, ir4.VGLOBAL: _compile_mem_op, ir4.VSCRATCH: _compile_mem_op,
# CDNA instruction classes
irc.SOPP: _compile_sopp, irc.SMEM: _compile_smem, irc.SOP1: _compile_sop, irc.SOP2: _compile_sop, irc.SOPC: _compile_sop, irc.SOPK: _compile_sop,
irc.VOP1: _compile_vop12, irc.VOP2: _compile_vop12, irc.VOPC: _compile_vopc, irc.VOP3: _compile_vop3,
irc.VOP3_SDST: _compile_vop3, irc.VOP3SD: _compile_vop3sd, irc.VOP3P: _compile_vop3p,
irc.DS: _compile_mem_op, irc.FLAT: _compile_mem_op, irc.GLOBAL: _compile_mem_op, irc.SCRATCH: _compile_mem_op,
}
# ═══════════════════════════════════════════════════════════════════════════════
@@ -1116,7 +1161,7 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
# Check if instruction matches any cached canonical pattern
for base, mask, size, runner in _canonical_runner_cache:
if inst_size == size and (inst_int & mask) == base: return runner, False
if inst_size == size and (inst_int & mask) == base: return runner
# Look up handler by type, falling back to base classes for _LIT variants
handler = _INST_HANDLERS.get(type(inst))
@@ -1136,30 +1181,17 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0, EMULATED_DTYPES=""):
runner = get_runner('CPU', sink)
_canonical_runner_cache.append((base, mask, size, runner))
return runner, True
return runner
@functools.cache
def decode_program(data: bytes, arch: str = "rdna3") -> dict[int, tuple[str, Callable, list[int], Any]]:
"""Decode program to {pc: (name, fxn, globals, runner)}."""
result: dict[int, tuple[str, Callable, list[int], Any]] = {}
i = 0
while i < len(data):
inst = decode_inst(data[i:], arch)
if hasattr(inst, 'op') and inst.op in (ir3.SOPPOp.S_CODE_END, ir4.SOPPOp.S_CODE_END): break
try:
runner, is_new = _get_runner(bytes(data[i:i + inst.size() + 4]), arch)
if DEBUG >= 3:
try: inst_str = repr(inst)
except Exception: inst_str = f"<{type(inst).__name__} at PC={i}>"
msg = f"[emu] PC={i}: {inst_str}"
print(colored(msg, 'green') if is_new else msg)
result[i] = (runner.p.function_name, runner._prg.fxn, runner.p.globals, runner)
except Exception as e:
try: inst_str = repr(inst)
except Exception: inst_str = f"<{type(inst).__name__}>"
raise RuntimeError(f"[emu] Failed to compile PC={i} {inst_str}: {type(e).__name__}: {e}") from e
i += inst.size()
return result
def _decode_at(pc: int, arch: str):
"""Decode and compile instruction at absolute address pc. Returns CompiledRunner."""
inst_bytes = bytes((ctypes.c_char * 16).from_address(pc).raw)
inst = decode_inst(inst_bytes, arch)
try: return _get_runner(bytes(inst_bytes[:inst.size() + 4]), arch)
except Exception as e:
try: inst_str = repr(inst)
except Exception: inst_str = f"<{type(inst).__name__}>"
raise RuntimeError(f"[emu] Failed to compile {inst_str}: {type(e).__name__}: {e}") from e
# ═══════════════════════════════════════════════════════════════════════════════
# WAVE STATE
@@ -1206,10 +1238,9 @@ class WaveState:
# ═══════════════════════════════════════════════════════════════════════════════
def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int = 0x19c,
scratch_size: int = 0, arch: str = "rdna3") -> int:
scratch_size: int = 0, arch: str = "rdna3", user_data: list[int]|None = None) -> int:
"""Execute AMD assembly program. scratch_size is private_segment_fixed_size from kernel descriptor (per-lane)."""
program_raw = decode_program(bytes((ctypes.c_char * lib_sz).from_address(lib).raw), arch)
program = {lib + offset: val for offset, val in program_raw.items()} # Remap to actual addresses
program: dict[int, tuple[Callable, list[int]]] = {} # lazily populated: pc -> (fxn, globals) extracted from runner
lds_size = ((rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE_SHIFT) * 512
total_threads = lx * ly * lz
@@ -1226,8 +1257,12 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
for wave_start in range(0, total_threads, WAVE_SIZE):
n_lanes, st = min(WAVE_SIZE, total_threads - wave_start), WaveState(min(WAVE_SIZE, total_threads - wave_start))
st.pc = lib # Set PC to code base address
st._write_sgpr(0, args_ptr & MASK32)
st._write_sgpr(1, (args_ptr >> 32) & MASK32)
# Initialize user SGPRs: hardware loads COMPUTE_USER_DATA registers directly into s[0:N]
if user_data:
for i, val in enumerate(user_data): st._write_sgpr(i, val)
else:
st._write_sgpr(0, args_ptr & MASK32)
st._write_sgpr(1, (args_ptr >> 32) & MASK32)
# Workgroup IDs in SGPRs after user SGPRs
sgpr_idx = (rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT_SHIFT
@@ -1255,13 +1290,16 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
ctypes.c_uint64(vmem_buf._buf.va_addr), ctypes.c_uint64(lds_buf._buf.va_addr),
ctypes.c_uint64(scratch_buf._buf.va_addr if scratch_buf else 0)]
for inst_count in range(1_000_000):
if (pc := st.pc) == 0xFFFFFFFFFFFFFFFF or pc not in program: break
name, fxn, globals_list, _ = program[pc]
assert fxn is not None, f"[emu] No fxn for {name} at PC={pc}"
assert 4 not in globals_list or scratch_buf, f"SCRATCH instruction {name} but scratch_size=0"
if DEBUG >= 6:
inst = decode_inst(bytes((ctypes.c_char * 12).from_address(pc).raw), arch)
print(f"[emu] exec PC={pc:X}: {inst!r}")
if (pc := st.pc) == 0xFFFFFFFFFFFFFFFF: break
if pc not in program:
prev_len = len(_canonical_runner_cache)
runner = _decode_at(pc, arch)
program[pc] = (runner._prg.fxn, runner.p.globals)
if DEBUG >= 3:
inst = decode_inst(bytes((ctypes.c_char * 16).from_address(pc).raw), arch)
msg = f"[emu] PC={pc - lib}: {inst!r}"
print(colored(msg, 'green') if len(_canonical_runner_cache) > prev_len else msg)
fxn, globals_list = program[pc]
fxn(*[c_bufs[g] for g in globals_list])
else: raise RuntimeError("exceeded 1M instructions, likely infinite loop")
return 0

View File

@@ -40,7 +40,10 @@ def _bitreverse(v: UOp, bits: int) -> UOp:
def _extract_bits(val: UOp, hi: int, lo: int) -> UOp:
dt = dtypes.uint64 if val.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32
return ((val >> _const(dt, lo)) if lo > 0 else val) & _const(val.dtype, (1 << (hi - lo + 1)) - 1)
result = ((val >> _const(dt, lo)) if lo > 0 else val) & _const(val.dtype, (1 << (hi - lo + 1)) - 1)
# Downcast to uint32 when extracting <=32 bits from a 64-bit value, so .f32 bitcast works correctly
if dt == dtypes.uint64 and (hi - lo + 1) <= 32: result = result.cast(dtypes.uint32)
return result
def _set_bit(old, pos, val):
mask = _u32(1) << pos
@@ -554,7 +557,9 @@ class Parser:
self.eat('LBRACKET')
self.eat_val('laneId', 'IDENT')
self.eat('RBRACKET')
result = (base >> _to_u32(self.vars['laneId'])) & _u32(1)
lane = self.vars['laneId']
shift = lane.cast(base.dtype) if base.dtype != dtypes.uint32 else _to_u32(lane)
result = (base >> shift) & _const(base.dtype, 1)
if self.try_eat('DOT'):
dt_name = self.eat('IDENT').val
return result.cast(DTYPES.get(dt_name, dtypes.uint32))
@@ -806,6 +811,12 @@ def _subst_loop_var(line: str, loop_var: str, val: int) -> str:
def _set_bits(old: UOp, val: UOp, width: int, offset: int) -> UOp:
"""Set bits [offset:offset+width) in old to val, masking and shifting appropriately."""
is64 = old.dtype in (dtypes.uint64, dtypes.int64) or offset + width > 32
if is64:
old = old.cast(dtypes.uint64) if old.dtype != dtypes.uint64 else old
mask = _u64(((1 << width) - 1) << offset)
v = (val.cast(dtypes.uint64) if val.dtype != dtypes.uint64 else val) & _u64((1 << width) - 1)
return (old & (mask ^ _u64(0xFFFFFFFFFFFFFFFF))) | (v << _u64(offset))
mask = _u32(((1 << width) - 1) << offset)
v = (val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val) & _u32((1 << width) - 1)
return (old & (mask ^ _u32(0xFFFFFFFF))) | (v << _u32(offset))