mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
renderer/amd: add cdna emulator (#14721)
* renderer/amd: add cdna emulator * fixes * no predecode * no early * REMU_PATH * delete that * round * Fix cache invalidation check in _compile_smem
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -708,6 +708,8 @@ jobs:
|
||||
run: SKIP_SLOW_TEST=1 AMD_LLVM=0 pytest -n=auto test/backend/test_ops.py -k "test_sparse_categorical_crossentropy or test_tril or test_nonzero or test_softmax_argmax" --durations 20
|
||||
- name: Run RDNA4 emulator tests
|
||||
run: MOCKGPU_ARCH=rdna4 python -m pytest test/test_tiny.py -v --durations 20
|
||||
- name: Run CDNA4 emulator tests
|
||||
run: AMD_LLVM=1 MOCKGPU_ARCH=cdna4 python -m pytest test/test_tiny.py -v --durations 20
|
||||
|
||||
testnvidia:
|
||||
strategy:
|
||||
|
||||
@@ -1,267 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Benchmark comparing Python vs Rust RDNA3 emulators on real tinygrad kernels."""
|
||||
import ctypes, time, os
|
||||
from pathlib import Path
|
||||
|
||||
from tinygrad.renderer.amd.emu import run_asm as python_run_asm, decode_program
|
||||
from tinygrad.renderer.amd import decode_inst
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import SOPP, SOPPOp
|
||||
|
||||
import tinygrad
|
||||
EXTRA_DIR = Path(tinygrad.__file__).parent.parent / "extra"
|
||||
REMU_PATH = EXTRA_DIR / "remu/target/release/libremu.so"
|
||||
if not REMU_PATH.exists():
|
||||
REMU_PATH = EXTRA_DIR / "remu/target/release/libremu.dylib"
|
||||
|
||||
def get_rust_remu():
|
||||
"""Load the Rust libremu shared library."""
|
||||
if not REMU_PATH.exists(): return None
|
||||
remu = ctypes.CDLL(str(REMU_PATH))
|
||||
remu.run_asm.restype = ctypes.c_int32
|
||||
remu.run_asm.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32,
|
||||
ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_void_p]
|
||||
return remu
|
||||
|
||||
def count_instructions(kernel: bytes) -> int:
|
||||
"""Count instructions in a kernel."""
|
||||
return len(decode_program(kernel))
|
||||
|
||||
def setup_buffers(buf_sizes: list[int], init_data: dict[int, bytes] | None = None):
|
||||
"""Allocate buffers and return args pointer + valid ranges."""
|
||||
if init_data is None: init_data = {}
|
||||
buffers = []
|
||||
for i, size in enumerate(buf_sizes):
|
||||
padded = ((size + 15) // 16) * 16 + 16
|
||||
data = init_data.get(i, b'\x00' * padded)
|
||||
data_list = list(data) + [0] * (padded - len(data))
|
||||
buf = (ctypes.c_uint8 * padded)(*data_list[:padded])
|
||||
buffers.append(buf)
|
||||
args = (ctypes.c_uint64 * len(buffers))(*[ctypes.addressof(b) for b in buffers])
|
||||
args_ptr = ctypes.addressof(args)
|
||||
ranges = {(ctypes.addressof(b), len(b)) for b in buffers}
|
||||
ranges.add((args_ptr, ctypes.sizeof(args)))
|
||||
return buffers, args, args_ptr, ranges
|
||||
|
||||
def benchmark_emulator(name: str, run_fn, kernel: bytes, global_size, local_size, args_ptr, rsrc2: int, iterations: int = 5):
|
||||
"""Benchmark an emulator and return average time."""
|
||||
gx, gy, gz = global_size
|
||||
lx, ly, lz = local_size
|
||||
kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel)
|
||||
lib_ptr = ctypes.addressof(kernel_buf)
|
||||
|
||||
# Warmup
|
||||
run_fn(lib_ptr, len(kernel), gx, gy, gz, lx, ly, lz, args_ptr, rsrc2)
|
||||
|
||||
# Timed runs
|
||||
times = []
|
||||
for _ in range(iterations):
|
||||
start = time.perf_counter()
|
||||
result = run_fn(lib_ptr, len(kernel), gx, gy, gz, lx, ly, lz, args_ptr, rsrc2)
|
||||
end = time.perf_counter()
|
||||
if result != 0:
|
||||
print(f" {name} returned error: {result}")
|
||||
return None
|
||||
times.append(end - start)
|
||||
|
||||
return sum(times) / len(times)
|
||||
|
||||
def profile_instructions(kernel: bytes):
|
||||
"""Profile individual instruction compile times."""
|
||||
from tinygrad.renderer.amd.emu import _get_runner, _canonical_runner_cache
|
||||
from tinygrad.helpers import Context
|
||||
_get_runner.cache_clear()
|
||||
_canonical_runner_cache.clear()
|
||||
|
||||
results = []
|
||||
i = 0
|
||||
while i < len(kernel):
|
||||
inst = decode_inst(kernel[i:])
|
||||
if isinstance(inst, SOPP) and inst.op == SOPPOp.S_CODE_END: break
|
||||
inst_bytes = bytes(kernel[i:i + inst.size() + 4])
|
||||
try: inst_str = repr(inst)
|
||||
except Exception: inst_str = f"<{type(inst).__name__}>"
|
||||
|
||||
# Time the full compile (sink + render + compile)
|
||||
start = time.perf_counter()
|
||||
with Context(CCACHE=0):
|
||||
runner, is_new = _get_runner(inst_bytes)
|
||||
compile_time = time.perf_counter() - start
|
||||
|
||||
results.append({
|
||||
'inst_str': inst_str + ('' if is_new else ' [CACHED]'),
|
||||
'compile_ms': compile_time * 1000 if is_new else 0,
|
||||
})
|
||||
i += inst.size()
|
||||
|
||||
return sorted(results, key=lambda x: x['compile_ms'], reverse=True)
|
||||
|
||||
def benchmark_python_split(kernel: bytes, global_size, local_size, args_ptr, rsrc2: int, iterations: int = 5):
|
||||
"""Benchmark Python emulator with compile and execution times."""
|
||||
from tinygrad.renderer.amd.emu import _get_runner, _canonical_runner_cache
|
||||
from tinygrad.helpers import Context
|
||||
_get_runner.cache_clear()
|
||||
_canonical_runner_cache.clear()
|
||||
decode_program.cache_clear()
|
||||
|
||||
# Measure compile time (decode_program builds sinks, renders, and compiles)
|
||||
compile_start = time.perf_counter()
|
||||
with Context(CCACHE=0):
|
||||
program = decode_program(kernel)
|
||||
compile_time = time.perf_counter() - compile_start
|
||||
n_compiled = len(_canonical_runner_cache)
|
||||
|
||||
# Execution time
|
||||
exec_time = benchmark_emulator("Python", python_run_asm, kernel, global_size, local_size, args_ptr, rsrc2, iterations)
|
||||
return compile_time, exec_time, len(program), n_compiled
|
||||
|
||||
def get_tinygrad_kernel(op_name: str) -> tuple[bytes, tuple, tuple, list[int], dict[int, bytes], int] | None:
|
||||
"""Get a real tinygrad kernel by operation name. Returns (code, global_size, local_size, buf_sizes, buf_data, rsrc2)."""
|
||||
try:
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
from tinygrad.runtime.autogen import hsa
|
||||
import numpy as np
|
||||
np.random.seed(42)
|
||||
|
||||
ops = {
|
||||
"add": lambda: Tensor.empty(1024) + Tensor.empty(1024),
|
||||
"mul": lambda: Tensor.empty(1024) * Tensor.empty(1024),
|
||||
"matmul_small": lambda: Tensor.empty(16, 16) @ Tensor.empty(16, 16),
|
||||
"matmul_medium": lambda: Tensor.empty(64, 64) @ Tensor.empty(64, 64),
|
||||
"reduce_sum": lambda: Tensor.empty(4096).sum(),
|
||||
"reduce_max": lambda: Tensor.empty(4096).max(),
|
||||
"softmax": lambda: Tensor.empty(256).softmax(),
|
||||
"layernorm": lambda: Tensor.empty(32, 64).layernorm(),
|
||||
"conv2d": lambda: Tensor.empty(1, 4, 16, 16).conv2d(Tensor.empty(4, 4, 3, 3)),
|
||||
"gelu": lambda: Tensor.empty(1024).gelu(),
|
||||
"exp": lambda: Tensor.empty(1024).exp(),
|
||||
"sin": lambda: Tensor.empty(1024).sin(),
|
||||
}
|
||||
|
||||
if op_name not in ops: return None
|
||||
out = ops[op_name]()
|
||||
sched = out.schedule()
|
||||
|
||||
for ei in sched:
|
||||
lowered = ei.lower()
|
||||
if ei.ast.op.name == 'SINK' and lowered.prg and lowered.prg.p.lib:
|
||||
lib = bytes(lowered.prg.p.lib)
|
||||
image = memoryview(bytearray(lib))
|
||||
_, sections, _ = elf_loader(lib)
|
||||
rodata_entry = next((sh.header.sh_addr for sh in sections if sh.name == ".rodata"), -1)
|
||||
for sec in sections:
|
||||
if sec.name == '.text':
|
||||
buf_sizes = [b.nbytes for b in lowered.bufs]
|
||||
# Get initial data from numpy arrays if available
|
||||
buf_data = {}
|
||||
for i, buf in enumerate(lowered.bufs):
|
||||
if hasattr(buf, 'base') and buf.base is not None and hasattr(buf.base, '_buf'):
|
||||
try: buf_data[i] = bytes(buf.base._buf)
|
||||
except Exception: pass
|
||||
# Extract rsrc2 from ELF (same as ops_amd.py)
|
||||
group_segment_size = image[rodata_entry:rodata_entry+4].cast("I")[0]
|
||||
lds_size = ((group_segment_size + 511) // 512) & 0x1FF
|
||||
code = hsa.amd_kernel_code_t.from_buffer_copy(bytes(image[rodata_entry:rodata_entry+256]) + b'\x00'*256)
|
||||
rsrc2 = code.compute_pgm_rsrc2 | (lds_size << 15)
|
||||
return (bytes(sec.content), tuple(lowered.prg.p.global_size), tuple(lowered.prg.p.local_size), buf_sizes, buf_data, rsrc2)
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f" Error getting kernel: {e}")
|
||||
return None
|
||||
|
||||
TINYGRAD_TESTS = ["add", "mul", "reduce_sum", "softmax", "exp", "sin", "gelu", "matmul_small"]
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Benchmark RDNA3 emulators")
|
||||
parser.add_argument("--iterations", type=int, default=3, help="Number of iterations per benchmark")
|
||||
parser.add_argument("--profile", type=str, default=None, help="Profile instructions for a specific kernel (e.g. 'sin')")
|
||||
parser.add_argument("--top", type=int, default=20, help="Number of top instructions to show in profile")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Profile mode: show individual instruction timing
|
||||
if args.profile:
|
||||
kernel_info = get_tinygrad_kernel(args.profile)
|
||||
if kernel_info is None:
|
||||
print(f"Failed to get kernel for '{args.profile}'")
|
||||
return
|
||||
kernel = kernel_info[0]
|
||||
print(f"Profiling instructions for '{args.profile}' kernel...")
|
||||
print("=" * 110)
|
||||
results = profile_instructions(kernel)
|
||||
print(f"{'Instruction':<90} {'Compile(ms)':>12}")
|
||||
print("-" * 110)
|
||||
for r in results[:args.top]:
|
||||
inst = r['inst_str'][:87] + "..." if len(r['inst_str']) > 90 else r['inst_str']
|
||||
print(f"{inst:<90} {r['compile_ms']:>12.3f}")
|
||||
print("-" * 110)
|
||||
total = sum(r['compile_ms'] for r in results)
|
||||
print(f"{'TOTAL':<90} {total:>12.3f}")
|
||||
return
|
||||
|
||||
rust_remu = get_rust_remu()
|
||||
if rust_remu is None:
|
||||
print("Rust libremu not found. Build with: cargo build --release --manifest-path extra/remu/Cargo.toml")
|
||||
print("Running Python-only benchmarks...\n")
|
||||
|
||||
print("=" * 90)
|
||||
print("RDNA3 Emulator Benchmark: Python vs Rust")
|
||||
print("=" * 90)
|
||||
|
||||
results = []
|
||||
|
||||
print("\n[TINYGRAD KERNELS]")
|
||||
print("-" * 90)
|
||||
|
||||
for op_name in TINYGRAD_TESTS:
|
||||
print(f"\n{op_name}:", end=" ", flush=True)
|
||||
kernel_info = get_tinygrad_kernel(op_name)
|
||||
if kernel_info is None:
|
||||
print("failed to compile")
|
||||
continue
|
||||
|
||||
kernel, global_size, local_size, buf_sizes, buf_data, rsrc2 = kernel_info
|
||||
buffers, args_arr, args_ptr, ranges = setup_buffers(buf_sizes, buf_data)
|
||||
|
||||
# Benchmark Python emulator (must be first to measure compile time before cache is populated)
|
||||
py_compile, py_exec, n_insts, n_compiled = benchmark_python_split(kernel, global_size, local_size, args_ptr, rsrc2, args.iterations)
|
||||
|
||||
n_workgroups = global_size[0] * global_size[1] * global_size[2]
|
||||
n_threads = local_size[0] * local_size[1] * local_size[2]
|
||||
total_work = n_insts * n_workgroups * n_threads
|
||||
|
||||
print(f"{n_insts} insts ({n_compiled} unique) × {n_workgroups} WGs × {n_threads} threads = {total_work:,} ops")
|
||||
rust_time = benchmark_emulator("Rust", rust_remu.run_asm, kernel, global_size, local_size,
|
||||
args_ptr, rsrc2, args.iterations) if rust_remu else None
|
||||
|
||||
if py_compile is not None:
|
||||
py_exec_rate = total_work / py_exec / 1e6
|
||||
print(f" Compile: {py_compile*1000:8.3f} ms ({n_compiled} unique)")
|
||||
print(f" Exec: {py_exec*1000:8.3f} ms ({py_exec_rate:7.2f} M ops/s)")
|
||||
if rust_time:
|
||||
rust_rate = total_work / rust_time / 1e6
|
||||
speedup = py_exec / rust_time if py_exec else 0
|
||||
print(f" Rust: {rust_time*1000:8.3f} ms ({rust_rate:7.2f} M ops/s) [{speedup:.1f}x faster]")
|
||||
|
||||
results.append((op_name, n_insts, n_compiled, n_workgroups, py_compile, py_exec, rust_time))
|
||||
|
||||
# Summary table
|
||||
print("\n" + "=" * 110)
|
||||
print("SUMMARY")
|
||||
print("=" * 110)
|
||||
print(f"{'Name':<16} {'Insts':<6} {'Unique':<6} {'WGs':<5} {'Compile (ms)':<14} {'Exec (ms)':<12} {'Rust (ms)':<12} {'Speedup':<10}")
|
||||
print("-" * 110)
|
||||
|
||||
for name, n_insts, n_compiled, n_wgs, py_compile, py_exec, rust_time in results:
|
||||
compile_ms = f"{py_compile*1000:.3f}" if py_compile else "error"
|
||||
exec_ms = f"{py_exec*1000:.3f}" if py_exec else "error"
|
||||
if rust_time:
|
||||
rust_ms = f"{rust_time*1000:.3f}"
|
||||
speedup = f"{py_exec/rust_time:.1f}x" if py_exec else "N/A"
|
||||
else:
|
||||
rust_ms, speedup = "N/A", "N/A"
|
||||
print(f"{name:<16} {n_insts:<6} {n_compiled:<6} {n_wgs:<5} {compile_ms:<14} {exec_ms:<12} {rust_ms:<12} {speedup:<10}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ["AMD"] = "1"
|
||||
main()
|
||||
@@ -1,12 +1,15 @@
|
||||
# Test to compare Python and Rust RDNA3 emulators by running real tinygrad kernels
|
||||
import unittest, ctypes
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from tinygrad import Device
|
||||
|
||||
from tinygrad.renderer.amd.emu import WaveState, decode_program, WAVE_SIZE, VCC_LO, EXEC_LO, SCC
|
||||
from tinygrad.renderer.amd.emu import WaveState, _decode_at, WAVE_SIZE, VCC_LO, EXEC_LO, SCC
|
||||
from tinygrad.renderer.amd import decode_inst
|
||||
from test.amd.helpers import KernelInfo
|
||||
from test.amd.bench_emu import REMU_PATH
|
||||
import tinygrad
|
||||
REMU_PATH = Path(tinygrad.__file__).parent.parent / "extra/remu/target/release/libremu.so"
|
||||
if not REMU_PATH.exists(): REMU_PATH = Path(tinygrad.__file__).parent.parent / "extra/remu/target/release/libremu.dylib"
|
||||
|
||||
def set_valid_mem_ranges(ranges): pass # emu2 doesn't need this
|
||||
|
||||
@@ -89,7 +92,7 @@ class RustEmulator:
|
||||
class PythonEmulator:
|
||||
def __init__(self):
|
||||
self.state: WaveState | None = None
|
||||
self.program: dict | None = None
|
||||
self.program: dict[int, tuple] = {} # lazily populated: pc -> (name, fxn, globals)
|
||||
self.vmem_buf = None
|
||||
self.lds_buf = None
|
||||
self.kernel_buf = None # Keep kernel bytes alive
|
||||
@@ -99,27 +102,29 @@ class PythonEmulator:
|
||||
import ctypes
|
||||
from tinygrad.device import Buffer, BufferSpec
|
||||
from tinygrad.dtype import dtypes
|
||||
# Store kernel in a ctypes buffer so generic instructions can read from vmem at actual PC address
|
||||
# Store kernel in a ctypes buffer so _decode_at can read from memory at actual PC address
|
||||
self.kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel)
|
||||
self.lib_addr = ctypes.addressof(self.kernel_buf)
|
||||
# Remap program dict to use actual addresses (like run_asm does)
|
||||
program_raw = decode_program(kernel)
|
||||
self.program = {self.lib_addr + offset: val for offset, val in program_raw.items()}
|
||||
self.program = {}
|
||||
self.state = WaveState(n_lanes)
|
||||
self.state.pc = self.lib_addr # Set PC to code base address
|
||||
self.vmem_buf = Buffer('CPU', 1 << 40, dtypes.uint32, options=BufferSpec(external_ptr=0)).ensure_allocated()
|
||||
self.lds_buf = Buffer('CPU', 65536 // 4, dtypes.uint32).ensure_allocated()
|
||||
|
||||
def _ensure_decoded(self, pc: int):
|
||||
if pc not in self.program:
|
||||
runner = _decode_at(pc, "rdna3")
|
||||
self.program[pc] = (runner.p.function_name, runner._prg.fxn, runner.p.globals)
|
||||
|
||||
def step(self) -> int:
|
||||
import ctypes
|
||||
assert self.program is not None and self.state is not None
|
||||
assert self.state is not None
|
||||
pc = self.state.pc
|
||||
if pc == 0xFFFFFFFFFFFFFFFF or pc not in self.program: return -1
|
||||
name, fxn, globals_list, _runner = self.program[pc]
|
||||
if fxn is None: return 1 # unsupported instruction
|
||||
if pc == 0xFFFFFFFFFFFFFFFF: return -1
|
||||
self._ensure_decoded(pc)
|
||||
name, fxn, globals_list = self.program[pc]
|
||||
buf_addrs = {0: self.state.sgpr_buf._buf.va_addr, 1: self.state.vgpr_buf._buf.va_addr, # type: ignore[union-attr]
|
||||
2: self.vmem_buf._buf.va_addr, 3: self.lds_buf._buf.va_addr} # type: ignore[union-attr]
|
||||
# Direct ctypes call - bypasses HCQ overhead
|
||||
fxn(*[ctypes.c_uint64(buf_addrs[g]) for g in globals_list], ctypes.c_int32(0))
|
||||
return -1 if self.state.pc == 0xFFFFFFFFFFFFFFFF else 0
|
||||
|
||||
@@ -140,7 +145,7 @@ class PythonEmulator:
|
||||
exec_mask=sgpr[EXEC_LO.offset], sgpr=sgpr, vgpr=vgpr)
|
||||
|
||||
def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: tuple[int, int, int],
|
||||
local_size: tuple[int, int, int], program, max_steps: int, debug: bool, trace_len: int,
|
||||
local_size: tuple[int, int, int], max_steps: int, debug: bool, trace_len: int,
|
||||
kernel_idx: int = 0, max_workgroups: int = 8) -> tuple[bool, str, int]:
|
||||
"""Run a single kernel through both emulators. Returns (success, message, total_steps)."""
|
||||
gx, gy, gz = global_size
|
||||
@@ -181,9 +186,9 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
|
||||
rust_before = rust.get_snapshot()
|
||||
python_before = python.get_snapshot()
|
||||
|
||||
assert python.program is not None
|
||||
inst_info = python.program.get(python.lib_addr + python_before.pc * 4) # Convert word offset to actual address
|
||||
inst_hex_name = inst_info[0] if inst_info else f"unknown at PC={python_before.pc}"
|
||||
pc_addr = python.lib_addr + python_before.pc * 4 # Convert word offset to actual address
|
||||
python._ensure_decoded(pc_addr)
|
||||
inst_hex_name = python.program[pc_addr][0]
|
||||
# Decode the instruction to get mnemonic for sync_after checks
|
||||
try:
|
||||
# Format is mnemonic_hexbytes, e.g. v_exp_f32_e32_014b027e -> hex is 014b027e
|
||||
@@ -310,12 +315,11 @@ def compare_emulators_multi_kernel(kernels: list[KernelInfo], buf_pool: dict[int
|
||||
kernel_ranges = ranges | {(args_ptr, ctypes.sizeof(args))}
|
||||
set_valid_mem_ranges(kernel_ranges)
|
||||
|
||||
program = decode_program(kernel.code)
|
||||
n_lanes = kernel.local_size[0] * kernel.local_size[1] * kernel.local_size[2]
|
||||
|
||||
ok, msg, steps = run_single_kernel(
|
||||
kernel.code, min(n_lanes, 32), args_ptr, kernel.global_size,
|
||||
kernel.local_size, program, max_steps, debug, trace_len, ki
|
||||
kernel.local_size, max_steps, debug, trace_len, ki
|
||||
)
|
||||
total_steps += steps
|
||||
if not ok:
|
||||
@@ -341,9 +345,8 @@ def compare_emulators_with_memory(kernel: bytes, n_lanes: int, buf_sizes: list,
|
||||
ranges.add((args_ptr, ctypes.sizeof(args)))
|
||||
set_valid_mem_ranges(ranges)
|
||||
|
||||
program = decode_program(kernel)
|
||||
# Legacy wrapper assumes local_size = (n_lanes, 1, 1)
|
||||
ok, msg, _ = run_single_kernel(kernel, n_lanes, args_ptr, global_size, (n_lanes, 1, 1), program, max_steps, debug, trace_len)
|
||||
ok, msg, _ = run_single_kernel(kernel, n_lanes, args_ptr, global_size, (n_lanes, 1, 1), max_steps, debug, trace_len)
|
||||
return ok, msg
|
||||
|
||||
def get_kernels_from_tinygrad(op_fn) -> tuple[list[KernelInfo], dict[int, int], dict[int, bytes]]:
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
import unittest, ctypes
|
||||
from tinygrad.runtime.autogen.amd.rdna4 import ins as ir4
|
||||
from tinygrad.renderer.amd.dsl import v, s
|
||||
from tinygrad.renderer.amd.emu import WaveState, decode_program
|
||||
from tinygrad.device import Buffer, BufferSpec
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
class TestRDNA4Emu(unittest.TestCase):
|
||||
def _run(self, insts: list, sgprs: dict[int, int] | None = None, vgprs: dict[tuple[int, int], int] | None = None) -> WaveState:
|
||||
"""Run instructions and return final WaveState."""
|
||||
# Add S_ENDPGM if not present
|
||||
if not any(isinstance(i, ir4.SOPP) and i.op == ir4.SOPPOp.S_ENDPGM for i in insts):
|
||||
insts = list(insts) + [ir4.SOPP(ir4.SOPPOp.S_ENDPGM, simm=0)]
|
||||
|
||||
# Assemble and decode
|
||||
code = b''.join(i.to_bytes() for i in insts)
|
||||
code_buf = (ctypes.c_uint8 * len(code)).from_buffer_copy(code)
|
||||
code_addr = ctypes.addressof(code_buf)
|
||||
program_raw = decode_program(code, "rdna4")
|
||||
program = {code_addr + offset: val for offset, val in program_raw.items()}
|
||||
|
||||
# Setup wave state
|
||||
st = WaveState(n_lanes=1)
|
||||
st.pc = code_addr
|
||||
for idx, val in (sgprs or {}).items(): st._write_sgpr(idx, val)
|
||||
for (reg, lane), val in (vgprs or {}).items(): st._write_vgpr(reg, lane, val)
|
||||
|
||||
# Setup vmem buffer with external_ptr=0 (maps to address 0, allows any pointer access)
|
||||
vmem_buf = Buffer('CPU', 1 << 40, dtypes.uint32, options=BufferSpec(external_ptr=0)).ensure_allocated()
|
||||
|
||||
# Execute
|
||||
c_bufs = [ctypes.c_uint64(st.sgpr_buf._buf.va_addr), ctypes.c_uint64(st.vgpr_buf._buf.va_addr),
|
||||
ctypes.c_uint64(vmem_buf._buf.va_addr), ctypes.c_uint64(0), ctypes.c_uint64(0)]
|
||||
for _ in range(100):
|
||||
if (pc := st.pc) == 0xFFFFFFFFFFFFFFFF or pc not in program: break
|
||||
_, fxn, globals_list, _ = program[pc]
|
||||
fxn(*[c_bufs[g] for g in globals_list])
|
||||
return st
|
||||
|
||||
def test_vopd_dual_mov(self):
|
||||
"""Test VOPD with two V_DUAL_MOV_B32 operations: v[1]=s[1], v[2]=s[2]."""
|
||||
insts = [ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
|
||||
vdstx=v[1], vdsty=v[2], srcx0=s[1], srcy0=s[2], vsrcx1=v[0], vsrcy1=v[0])]
|
||||
st = self._run(insts, sgprs={1: 0x40e00000, 2: 0x41100000}) # 7.0f, 9.0f
|
||||
self.assertEqual(st._read_vgpr(1, 0), 0x40e00000) # v[1] = 7.0
|
||||
self.assertEqual(st._read_vgpr(2, 0), 0x41100000) # v[2] = 9.0
|
||||
|
||||
def test_vopd_dual_mov_after_other_vopd(self):
|
||||
"""Test VOPD reuse: first VOPD(v[3]=0, v[0]=?), then VOPD(v[1]=s[1], v[2]=s[2])."""
|
||||
# This matches the BEAM kernel sequence that fails
|
||||
insts = [
|
||||
ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
|
||||
vdstx=v[3], vdsty=v[0], srcx0=0, srcy0=s[0], vsrcx1=v[0], vsrcy1=v[0]), # v[3]=0, v[0]=s[0]
|
||||
ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
|
||||
vdstx=v[1], vdsty=v[2], srcx0=s[1], srcy0=s[2], vsrcx1=v[0], vsrcy1=v[0]), # v[1]=s[1], v[2]=s[2]
|
||||
]
|
||||
st = self._run(insts, sgprs={0: 0x40a00000, 1: 0x40e00000, 2: 0x41100000}) # 5.0f, 7.0f, 9.0f
|
||||
self.assertEqual(st._read_vgpr(1, 0), 0x40e00000) # v[1] = 7.0
|
||||
self.assertEqual(st._read_vgpr(2, 0), 0x41100000) # v[2] = 9.0
|
||||
|
||||
def test_vopd_with_s_add_f32_sequence(self):
|
||||
"""Test full BEAM kernel sequence: s_add_f32 then VOPD."""
|
||||
# This is the exact sequence from the failing BEAM kernel
|
||||
insts = [
|
||||
ir4.SOP2(ir4.SOP2Op.S_ADD_F32, sdst=s[0], ssrc0=s[0], ssrc1=s[8]), # s[0] = s[0] + s[8]
|
||||
ir4.SOP2(ir4.SOP2Op.S_ADD_F32, sdst=s[1], ssrc0=s[1], ssrc1=s[9]), # s[1] = s[1] + s[9]
|
||||
ir4.SOP2(ir4.SOP2Op.S_ADD_F32, sdst=s[2], ssrc0=s[2], ssrc1=s[10]), # s[2] = s[2] + s[10]
|
||||
ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
|
||||
vdstx=v[3], vdsty=v[0], srcx0=0, srcy0=s[0], vsrcx1=v[0], vsrcy1=v[0]),
|
||||
ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
|
||||
vdstx=v[1], vdsty=v[2], srcx0=s[1], srcy0=s[2], vsrcx1=v[0], vsrcy1=v[0]),
|
||||
]
|
||||
# Input: s[0:2] = [1,2,3], s[8:10] = [4,5,6]
|
||||
# After s_add_f32: s[0:2] = [5,7,9]
|
||||
st = self._run(insts, sgprs={0: 0x3f800000, 1: 0x40000000, 2: 0x40400000, # 1.0, 2.0, 3.0
|
||||
8: 0x40800000, 9: 0x40a00000, 10: 0x40c00000}) # 4.0, 5.0, 6.0
|
||||
self.assertEqual(st._read_vgpr(1, 0), 0x40e00000) # v[1] = 7.0
|
||||
self.assertEqual(st._read_vgpr(2, 0), 0x41100000) # v[2] = 9.0
|
||||
|
||||
def test_s_mov_b32_then_vopd(self):
|
||||
"""Test s_mov_b32 followed by VOPD - simulates BEAM kernel sequence."""
|
||||
# Use s_mov_b32 with SGPR source (copy from pre-initialized SGPRs)
|
||||
# s[10:12] will have values set by test harness, copy to s[0:2], then VOPD to VGPRs
|
||||
insts = [
|
||||
ir4.SOP1(ir4.SOP1Op.S_MOV_B32, sdst=s[0], ssrc0=s[10]), # s[0] = s[10]
|
||||
ir4.SOP1(ir4.SOP1Op.S_MOV_B32, sdst=s[1], ssrc0=s[11]), # s[1] = s[11]
|
||||
ir4.SOP1(ir4.SOP1Op.S_MOV_B32, sdst=s[2], ssrc0=s[12]), # s[2] = s[12]
|
||||
ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
|
||||
vdstx=v[1], vdsty=v[2], srcx0=s[1], srcy0=s[2], vsrcx1=v[0], vsrcy1=v[0]),
|
||||
]
|
||||
st = self._run(insts, sgprs={10: 0x40a00000, 11: 0x40e00000, 12: 0x41100000}) # 5.0, 7.0, 9.0
|
||||
self.assertEqual(st._read_vgpr(1, 0), 0x40e00000) # v[1] = 7.0
|
||||
self.assertEqual(st._read_vgpr(2, 0), 0x41100000) # v[2] = 9.0
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -90,9 +90,9 @@ class AMDDriver(VirtDriver):
|
||||
def _prepare_gpu(self, gpu_id):
|
||||
self.doorbells[gpu_id] = memoryview(bytearray(0x2000))
|
||||
self.gpus[gpu_id] = AMDGPU(gpu_id)
|
||||
# IP versions: rdna3 = GC 11.0.0, NBIF 4.3.0; rdna4 = GC 12.0.0, NBIF 6.3.1
|
||||
ip_versions = {"rdna3": {"gc": (11, 0, 0), "sdma": (6, 0, 0), "nbif": (4, 3, 0)},
|
||||
"rdna4": {"gc": (12, 0, 0), "sdma": (6, 0, 0), "nbif": (6, 3, 1)}}[MOCKGPU_ARCH]
|
||||
"rdna4": {"gc": (12, 0, 0), "sdma": (6, 0, 0), "nbif": (6, 3, 1)},
|
||||
"cdna4": {"gc": (9, 5, 0), "sdma": (4, 4, 5), "nbif": (7, 9, 0)}}[MOCKGPU_ARCH]
|
||||
def ip_discovery_files(hwid, ver, base_addr):
|
||||
p = f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{hwid}/0'
|
||||
return [VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{hwid}', functools.partial(DirFileDesc, child_names=['0'])),
|
||||
|
||||
@@ -5,7 +5,7 @@ from tinygrad.helpers import getbits, to_mv, getenv
|
||||
from tinygrad.runtime.support import c
|
||||
|
||||
MOCKGPU_ARCH = getenv("MOCKGPU_ARCH", "rdna3")
|
||||
GFX_TARGET_VERSION = {"rdna3": 110000, "rdna4": 120000}[MOCKGPU_ARCH]
|
||||
GFX_TARGET_VERSION = {"rdna3": 110000, "rdna4": 120000, "cdna4": 90500}[MOCKGPU_ARCH]
|
||||
import tinygrad.runtime.autogen.amd_gpu as amd_gpu, tinygrad.runtime.autogen.am.pm4_nv as pm4
|
||||
|
||||
SDMA_MAX_COPY_SIZE = 0x400000
|
||||
@@ -106,8 +106,8 @@ class PM4Executor(AMDQueue):
|
||||
return (self.rptr[0] - prev_rptr) + executed_in_ib
|
||||
|
||||
def _exec_acquire_mem(self, n):
|
||||
assert n == 6
|
||||
for _ in range(7): self._next_dword() # TODO: implement
|
||||
assert n in (5, 6)
|
||||
for _ in range(n + 1): self._next_dword() # TODO: implement
|
||||
|
||||
def _exec_release_mem(self, n):
|
||||
assert n == 6
|
||||
@@ -184,6 +184,12 @@ class PM4Executor(AMDQueue):
|
||||
args_addr = self.gpu.regs[regCOMPUTE_USER_DATA_0] + (self.gpu.regs[regCOMPUTE_USER_DATA_0 + 1] << 32)
|
||||
lc = [self.gpu.regs[i] for i in range(regCOMPUTE_NUM_THREAD_X, regCOMPUTE_NUM_THREAD_X+3)]
|
||||
rsrc2 = self.gpu.regs[regCOMPUTE_PGM_RSRC2]
|
||||
# Read all user data registers (hardware loads these directly into s[0:N])
|
||||
user_sgpr_count = (rsrc2 >> 1) & 0x1F # USER_SGPR_COUNT is bits 1:5
|
||||
user_data = []
|
||||
for i in range(user_sgpr_count):
|
||||
try: user_data.append(self.gpu.regs[regCOMPUTE_USER_DATA_0 + i])
|
||||
except KeyError: user_data.append(0)
|
||||
|
||||
prg_sz = 0
|
||||
for st,sz in self.gpu.mapped_ranges:
|
||||
@@ -197,11 +203,12 @@ class PM4Executor(AMDQueue):
|
||||
scratch_size = wavesize * 4 # This gives the scratch size per thread (lane)
|
||||
|
||||
assert prg_sz > 0, "Invalid prg ptr (not found in mapped ranges)"
|
||||
# Pass valid memory ranges, rsrc2, scratch_size and arch to Python emulator
|
||||
# Pass valid memory ranges, rsrc2, scratch_size, arch, and user data registers to Python emulator
|
||||
if hasattr(remu, 'valid_mem_ranges'): remu.valid_mem_ranges = self.gpu.mapped_ranges
|
||||
if hasattr(remu, 'rsrc2'): remu.rsrc2 = rsrc2
|
||||
if hasattr(remu, 'scratch_size'): remu.scratch_size = scratch_size
|
||||
if hasattr(remu, 'arch'): remu.arch = self.gpu.arch
|
||||
if hasattr(remu, 'user_data'): remu.user_data = user_data
|
||||
err = remu.run_asm(prg_addr, prg_sz, *gl, *lc, args_addr)
|
||||
if err != 0: raise RuntimeError("remu does not support the new instruction introduced in this kernel")
|
||||
|
||||
@@ -318,7 +325,7 @@ class AMDGPU(VirtGPU):
|
||||
self.regs = AMDGPURegisters()
|
||||
self.mapped_ranges = set()
|
||||
self.queues = []
|
||||
self.arch = MOCKGPU_ARCH
|
||||
self.arch = "cdna" if MOCKGPU_ARCH == "cdna4" else MOCKGPU_ARCH
|
||||
|
||||
def map_range(self, vaddr, size): self.mapped_ranges.add((vaddr, size))
|
||||
def unmap_range(self, vaddr, size): self.mapped_ranges.remove((vaddr, size))
|
||||
@@ -329,7 +336,7 @@ class AMDGPU(VirtGPU):
|
||||
self.queues.append(SDMAExecutor(self, base, size, rptr, wptr))
|
||||
return len(self.queues) - 1
|
||||
|
||||
gpu_props = """cpu_cores_count 0
|
||||
_gpu_props_rdna = """cpu_cores_count 0
|
||||
simd_count 192
|
||||
mem_banks_count 1
|
||||
caches_count 206
|
||||
@@ -367,3 +374,44 @@ sdma_fw_version 20
|
||||
unique_id 11673270660693242239
|
||||
num_xcc 1
|
||||
max_engine_clk_ccompute 2400"""
|
||||
|
||||
_gpu_props_cdna = """cpu_cores_count 0
|
||||
simd_count 304
|
||||
mem_banks_count 1
|
||||
caches_count 206
|
||||
io_links_count 1
|
||||
p2p_links_count 5
|
||||
cpu_core_id_base 0
|
||||
simd_id_base 2147488032
|
||||
max_waves_per_simd 16
|
||||
lds_size_in_kb 128
|
||||
gds_size_in_kb 0
|
||||
num_gws 64
|
||||
wave_front_size 64
|
||||
array_count 16
|
||||
simd_arrays_per_engine 4
|
||||
cu_per_simd_array 19
|
||||
simd_per_cu 2
|
||||
max_slots_scratch_cu 32
|
||||
gfx_target_version {gfx_target_version}
|
||||
vendor_id 4098
|
||||
device_id 29772
|
||||
location_id 34304
|
||||
domain 0
|
||||
drm_render_minor {drm_render_minor}
|
||||
hive_id 0
|
||||
num_sdma_engines 2
|
||||
num_sdma_xgmi_engines 0
|
||||
num_sdma_queues_per_engine 6
|
||||
num_cp_queues 8
|
||||
max_engine_clk_fcompute 2100
|
||||
local_mem_size 0
|
||||
fw_version 2140
|
||||
capability 671588992
|
||||
debug_prop 1495
|
||||
sdma_fw_version 20
|
||||
unique_id 11673270660693242239
|
||||
num_xcc 1
|
||||
max_engine_clk_ccompute 2100"""
|
||||
|
||||
gpu_props = _gpu_props_cdna if MOCKGPU_ARCH == "cdna4" else _gpu_props_rdna
|
||||
|
||||
@@ -21,10 +21,11 @@ class PythonRemu:
|
||||
rsrc2: int = 0x19c # Default: USER_SGPR_COUNT=14, enable X and Y workgroup IDs
|
||||
scratch_size: int = 0 # private_segment_fixed_size from kernel descriptor
|
||||
arch: str = "rdna3" # Architecture: rdna3 or rdna4
|
||||
user_data: list[int] = [] # All COMPUTE_USER_DATA registers (loaded into s[0:N])
|
||||
|
||||
def run_asm(self, lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int) -> int:
|
||||
from tinygrad.renderer.amd.emu import run_asm
|
||||
return run_asm(lib, lib_sz, gx, gy, gz, lx, ly, lz, args_ptr, self.rsrc2, self.scratch_size, self.arch)
|
||||
return run_asm(lib, lib_sz, gx, gy, gz, lx, ly, lz, args_ptr, self.rsrc2, self.scratch_size, self.arch, self.user_data)
|
||||
|
||||
def _try_dlopen_remu():
|
||||
# Use Python emulator only if PYTHON_REMU=1
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
# arg=4: scratch - per-lane scratch memory
|
||||
from __future__ import annotations
|
||||
import ctypes, functools, re, platform, subprocess, tempfile
|
||||
from typing import Any, Callable
|
||||
from typing import Callable
|
||||
|
||||
# Set/restore DAZ+FTZ (denormals-are-zero + flush-to-zero) to match RDNA3 default float mode
|
||||
# x86: MXCSR bits DAZ(6)+FTZ(15), ARM64: FPCR bit FZ(24)
|
||||
@@ -61,8 +61,10 @@ from tinygrad.engine.realize import get_runner
|
||||
from tinygrad.renderer.amd import decode_inst
|
||||
from tinygrad.runtime.autogen.amd.rdna3.str_pcode import PCODE as PCODE_RDNA3
|
||||
from tinygrad.runtime.autogen.amd.rdna4.str_pcode import PCODE as PCODE_RDNA4
|
||||
from tinygrad.runtime.autogen.amd.cdna.str_pcode import PCODE as PCODE_CDNA
|
||||
from tinygrad.runtime.autogen.amd.rdna3 import ins as ir3
|
||||
from tinygrad.runtime.autogen.amd.rdna4 import ins as ir4
|
||||
from tinygrad.runtime.autogen.amd.cdna import ins as irc
|
||||
from tinygrad.renderer.amd.dsl import VCC_LO, EXEC_LO, SCC, ttmp
|
||||
from tinygrad.runtime.autogen.amd.common import Fmt, OpType
|
||||
from tinygrad.renderer.amd.pcode import parse_block, _FUNCS
|
||||
@@ -160,7 +162,7 @@ _pcode_fixes = {
|
||||
|
||||
def _get_pcode_dict(op) -> dict:
|
||||
"""Return the PCODE dictionary for the given opcode based on its architecture."""
|
||||
return PCODE_RDNA4 if 'rdna4' in type(op).__module__ else PCODE_RDNA3
|
||||
return PCODE_CDNA if 'cdna' in type(op).__module__ else PCODE_RDNA4 if 'rdna4' in type(op).__module__ else PCODE_RDNA3
|
||||
|
||||
# Pcode parser
|
||||
@functools.cache
|
||||
@@ -465,8 +467,8 @@ class _Ctx:
|
||||
pcode = get_pcode(op)
|
||||
vcc_reg = sdst_reg if sdst_reg is not None else VCC_LO.offset
|
||||
if 'VCC' not in srcs: srcs['VCC'] = self.rsgpr_dyn(_c(vcc_reg))
|
||||
srcs.update({'EXEC': exec_mask, 'SCC': self.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane,
|
||||
'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0)}) # rounding mode: 0=RNE, RTZ constant
|
||||
srcs.update({'EXEC': exec_mask, 'SCC': self.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane, 'VDST': vdst_reg,
|
||||
'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0), 'ROUND_NEAREST_EVEN': _c(0)}) # rounding mode constants
|
||||
_, assigns = parse_pcode(pcode, srcs)
|
||||
|
||||
# For integer ops with clamp, compute overflow using wide arithmetic
|
||||
@@ -543,10 +545,11 @@ class _Ctx:
|
||||
|
||||
def _compile_sopp(inst: ir3.SOPP | ir4.SOPP, ctx: _Ctx) -> UOp:
|
||||
simm16 = ctx.inst_field_signed(type(inst).simm16).cast(dtypes.int16)
|
||||
if inst.op in (ir3.SOPPOp.S_ENDPGM, ir4.SOPPOp.S_ENDPGM):
|
||||
if inst.op in (ir3.SOPPOp.S_ENDPGM, ir4.SOPPOp.S_ENDPGM, irc.SOPPOp.S_ENDPGM):
|
||||
return UOp.sink(ctx.wsgpr_dyn(_c(PC_LO_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF)),
|
||||
ctx.wsgpr_dyn(_c(PC_HI_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF)))
|
||||
if inst.op in (ir3.SOPPOp.S_NOP, ir4.SOPPOp.S_NOP): return UOp.sink(*ctx.inc_pc()) # S_NOP is a no-op
|
||||
# S_NOP and S_WAITCNT are no-ops in emulator (no pipeline/cache to wait on)
|
||||
if inst.op in (ir3.SOPPOp.S_NOP, ir4.SOPPOp.S_NOP, irc.SOPPOp.S_NOP, irc.SOPPOp.S_WAITCNT): return UOp.sink(*ctx.inc_pc())
|
||||
# NOTE: we ignore SOPPs without PCODE
|
||||
if inst.op in _get_pcode_dict(inst.op):
|
||||
pcode = get_pcode(inst.op)
|
||||
@@ -562,10 +565,7 @@ def _compile_sopp(inst: ir3.SOPP | ir4.SOPP, ctx: _Ctx) -> UOp:
|
||||
|
||||
def _compile_smem(inst: ir3.SMEM | ir4.SMEM, ctx: _Ctx) -> UOp:
|
||||
# Cache invalidation instructions are no-ops in the emulator (we don't model caches)
|
||||
cache_inv_ops = [ir3.SMEMOp.S_GL1_INV, ir3.SMEMOp.S_DCACHE_INV, ir4.SMEMOp.S_DCACHE_INV]
|
||||
if hasattr(ir4.SMEMOp, 'S_GL1_INV'): cache_inv_ops.append(ir4.SMEMOp.S_GL1_INV)
|
||||
if inst.op in cache_inv_ops:
|
||||
return UOp.sink(*ctx.inc_pc())
|
||||
if '_INV' in inst.op.name: return UOp.sink(*ctx.inc_pc())
|
||||
# Dynamic sbase field (bits 5:0) - SGPR pair, field value * 2 = register offset
|
||||
sbase = ctx.inst_field(type(inst).sbase) * _c(2)
|
||||
# Dynamic sdata field (bits 12:6) - destination SGPR
|
||||
@@ -573,34 +573,44 @@ def _compile_smem(inst: ir3.SMEM | ir4.SMEM, ctx: _Ctx) -> UOp:
|
||||
# RDNA4 uses 'ioffset', RDNA3 uses 'offset' - use type(inst) to get correct field
|
||||
offset_field = type(inst).ioffset if hasattr(type(inst), 'ioffset') else type(inst).offset # type: ignore[union-attr]
|
||||
offset = ctx.inst_field_signed(offset_field) # signed immediate
|
||||
# Dynamic soffset field - SGPR for additional offset (NULL=124 reads as 0)
|
||||
soffset = ctx.inst_field(type(inst).soffset)
|
||||
addr = _u64(ctx.rsgpr_dyn(sbase), ctx.rsgpr_dyn(sbase + _c(1))) + offset.cast(dtypes.uint64) + ctx.rsgpr_dyn(soffset).cast(dtypes.uint64)
|
||||
# Dynamic soffset field - SGPR for additional offset (NULL=124 reads as 0, CDNA soffset_en=0 means no soffset)
|
||||
soffset_val = _c(0).cast(dtypes.uint64)
|
||||
if not (isinstance(inst, irc.SMEM) and not inst.soffset_en):
|
||||
soffset_val = ctx.inst_field(type(inst).soffset)
|
||||
soffset_val = ctx.rsgpr_dyn(soffset_val).cast(dtypes.uint64)
|
||||
addr = _u64(ctx.rsgpr_dyn(sbase), ctx.rsgpr_dyn(sbase + _c(1))) + offset.cast(dtypes.uint64) + soffset_val
|
||||
_SMEM_NDWORDS = {ir3.SMEMOp.S_LOAD_B32: 1, ir3.SMEMOp.S_LOAD_B64: 2, ir3.SMEMOp.S_LOAD_B128: 4,
|
||||
ir3.SMEMOp.S_LOAD_B256: 8, ir3.SMEMOp.S_LOAD_B512: 16, ir4.SMEMOp.S_LOAD_B32: 1, ir4.SMEMOp.S_LOAD_B64: 2,
|
||||
ir4.SMEMOp.S_LOAD_B96: 3, ir4.SMEMOp.S_LOAD_B128: 4, ir4.SMEMOp.S_LOAD_B256: 8, ir4.SMEMOp.S_LOAD_B512: 16}
|
||||
ir4.SMEMOp.S_LOAD_B96: 3, ir4.SMEMOp.S_LOAD_B128: 4, ir4.SMEMOp.S_LOAD_B256: 8, ir4.SMEMOp.S_LOAD_B512: 16,
|
||||
irc.SMEMOp.S_LOAD_DWORD: 1, irc.SMEMOp.S_LOAD_DWORDX2: 2, irc.SMEMOp.S_LOAD_DWORDX4: 4,
|
||||
irc.SMEMOp.S_LOAD_DWORDX8: 8, irc.SMEMOp.S_LOAD_DWORDX16: 16}
|
||||
ndwords = _SMEM_NDWORDS[inst.op]
|
||||
stores = [ctx.wsgpr_dyn(sdata_reg + _c(i), ctx.vmem.index((addr + UOp.const(dtypes.uint64, i * 4) >> UOp.const(dtypes.uint64, 2)).cast(dtypes.int)))
|
||||
for i in range(ndwords)]
|
||||
return UOp.sink(*stores, *ctx.inc_pc())
|
||||
|
||||
def _compile_sop(inst: ir3.SOP1 | ir3.SOP2 | ir3.SOPC | ir3.SOPK | ir4.SOP1 | ir4.SOP2 | ir4.SOPC | ir4.SOPK, ctx: _Ctx) -> UOp:
|
||||
def _compile_sop(inst: ir3.SOP1|ir3.SOP2|ir3.SOPC|ir3.SOPK|ir4.SOP1|ir4.SOP2|ir4.SOPC|ir4.SOPK|irc.SOP1|irc.SOP2|irc.SOPC|irc.SOPK, ctx: _Ctx) -> UOp:
|
||||
bits = inst.canonical_op_bits
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr]
|
||||
|
||||
if isinstance(inst, (ir3.SOPK, ir4.SOPK)):
|
||||
if isinstance(inst, (ir3.SOPK, ir4.SOPK, irc.SOPK)):
|
||||
sdst_off = ctx.inst_field(type(inst).sdst)
|
||||
simm16 = ctx.inst_field(type(inst).simm16)
|
||||
# Sign-extend simm16
|
||||
simm16_sext = simm16.cast(dtypes.int16).cast(dtypes.int32)
|
||||
srcs = {'S0': ctx.rsgpr_dyn(sdst_off), 'SIMM16': simm16_sext, 'D0': ctx.rsgpr_dyn(sdst_off)}
|
||||
# CDNA pcode uses S0 for the immediate in MOVK/MULK/ADDK/CMOVK (where RDNA uses SIMM16),
|
||||
# but S0 = register for CMPK/SETREG. S1 is always the immediate for CDNA CMPK ops.
|
||||
op_name = inst.op.name if hasattr(inst.op, 'name') else ''
|
||||
s0_is_imm = isinstance(inst, irc.SOPK) and 'CMPK' not in op_name and 'SETREG' not in op_name
|
||||
s0_val = simm16_sext if s0_is_imm else ctx.rsgpr_dyn(sdst_off)
|
||||
srcs = {'S0': s0_val, 'SIMM16': simm16_sext, 'S1': simm16_sext, 'D0': ctx.rsgpr_dyn(sdst_off)}
|
||||
dst_off, dst_size = sdst_off, 1
|
||||
elif isinstance(inst, (ir3.SOP1, ir4.SOP1)):
|
||||
elif isinstance(inst, (ir3.SOP1, ir4.SOP1, irc.SOP1)):
|
||||
sdst_off = ctx.inst_field(type(inst).sdst)
|
||||
ssrc0_off = ctx.inst_field(type(inst).ssrc0)
|
||||
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal)}
|
||||
dst_off, dst_size = sdst_off, bits['d'] // 32
|
||||
elif isinstance(inst, (ir3.SOP2, ir4.SOP2)):
|
||||
elif isinstance(inst, (ir3.SOP2, ir4.SOP2, irc.SOP2)):
|
||||
sdst_off = ctx.inst_field(type(inst).sdst)
|
||||
ssrc0_off = ctx.inst_field(type(inst).ssrc0)
|
||||
ssrc1_off = ctx.inst_field(type(inst).ssrc1)
|
||||
@@ -608,7 +618,7 @@ def _compile_sop(inst: ir3.SOP1 | ir3.SOP2 | ir3.SOPC | ir3.SOPK | ir4.SOP1 | ir
|
||||
'S1': ctx.rsrc_dyn(ssrc1_off, None, bits['s1'], literal)}
|
||||
if literal is not None: srcs['SIMM32'] = literal
|
||||
dst_off, dst_size = sdst_off, bits['d'] // 32
|
||||
elif isinstance(inst, (ir3.SOPC, ir4.SOPC)):
|
||||
elif isinstance(inst, (ir3.SOPC, ir4.SOPC, irc.SOPC)):
|
||||
ssrc0_off = ctx.inst_field(type(inst).ssrc0)
|
||||
ssrc1_off = ctx.inst_field(type(inst).ssrc1)
|
||||
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal),
|
||||
@@ -619,7 +629,7 @@ def _compile_sop(inst: ir3.SOP1 | ir3.SOP2 | ir3.SOPC | ir3.SOPK | ir4.SOP1 | ir
|
||||
|
||||
return ctx.compile_sop_pcode(inst.op, srcs, dst_off, dst_size)
|
||||
|
||||
def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VOP1_SDST | ir4.VOP2, ctx: _Ctx) -> UOp:
|
||||
def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VOP1_SDST | ir4.VOP2 | irc.VOP1 | irc.VOP2, ctx: _Ctx) -> UOp:
|
||||
op_name = _op_name(inst)
|
||||
if op_name in ('V_READFIRSTLANE_B32_E32', 'V_PERMLANE64_B32_E32'): return ctx.compile_lane_pcode(inst.op, inst)
|
||||
lane, exec_mask, bits = ctx.range(), ctx.rsgpr_dyn(_c(EXEC_LO.offset)), inst.canonical_op_bits
|
||||
@@ -628,7 +638,7 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VO
|
||||
write_hi_half = bits['d'] == 16 and (vdst_reg >= _c(128))
|
||||
if isinstance(write_hi_half, UOp): vdst_reg = write_hi_half.where(vdst_reg - _c(128), vdst_reg)
|
||||
elif write_hi_half: vdst_reg -= 128
|
||||
if isinstance(inst, (ir3.VOP1, ir4.VOP1)):
|
||||
if isinstance(inst, (ir3.VOP1, ir4.VOP1, irc.VOP1)):
|
||||
# Handle VOP1 hi-half source operand (src0 >= v[128] for 16-bit ops)
|
||||
src0_off = ctx.inst_field(type(inst).src0)
|
||||
s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal)
|
||||
@@ -654,12 +664,13 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VO
|
||||
s0 = src0_hi.where(_hi16(ctx.rvgpr_dyn(src0_reg, lane)), s0)
|
||||
srcs = {'S0': s0, 'S1': s1, 'D0': d0}
|
||||
if inst.op in (ir3.VOP2Op.V_FMAAK_F32_E32, ir3.VOP2Op.V_FMAMK_F32_E32, ir3.VOP2Op.V_FMAAK_F16_E32,
|
||||
ir3.VOP2Op.V_FMAMK_F16_E32):
|
||||
ir3.VOP2Op.V_FMAMK_F16_E32, irc.VOP2Op.V_FMAAK_F32_E32, irc.VOP2Op.V_FMAMK_F32_E32):
|
||||
assert literal is not None
|
||||
srcs['SIMM32'] = literal
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=write_hi_half)
|
||||
|
||||
def _compile_vopc(inst: ir3.VOPC | ir3.VOP3 | ir4.VOPC | ir4.VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp:
|
||||
def _compile_vopc(inst: ir3.VOPC|ir3.VOP3|ir4.VOPC|ir4.VOP3|irc.VOPC|irc.VOP3, ctx: _Ctx,
|
||||
opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp:
|
||||
exec_mask, op_name, bits = ctx.rsgpr_dyn(_c(EXEC_LO.offset)), _op_name(inst), inst.canonical_op_bits
|
||||
is_cmpx, is_vopc = 'CMPX' in op_name, hasattr(inst, 'vsrc1') # is_vopc: e32 vs e64
|
||||
|
||||
@@ -707,7 +718,7 @@ def _compile_vopc(inst: ir3.VOPC | ir3.VOP3 | ir4.VOPC | ir4.VOP3, ctx: _Ctx, op
|
||||
stores = [ctx.wsgpr_dyn(dst_off, new_result)] if not is_vopc else [ctx.wsgpr_dyn(_c(VCC_LO.offset), new_result)]
|
||||
return UOp.sink(*stores, *ctx.inc_pc())
|
||||
|
||||
def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3, ctx: _Ctx) -> UOp:
|
||||
def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3 | irc.VOP3, ctx: _Ctx) -> UOp:
|
||||
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
|
||||
bits = inst.canonical_op_bits
|
||||
opsel, op_name = getattr(inst, 'opsel', 0) or 0, _op_name(inst)
|
||||
@@ -741,13 +752,13 @@ def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3, ctx: _Ctx) -> UOp:
|
||||
src1 = _apply_src_mods(src1, 1, abs_bits, neg_bits, bits['s1'])
|
||||
src2 = _apply_src_mods(src2, 2, abs_bits, neg_bits, bits['s2'])
|
||||
srcs = {'S0': src0, 'S1': src1, 'S2': src2}
|
||||
if inst.op in (ir3.VOP3Op.V_CNDMASK_B32_E64, ir3.VOP3Op.V_CNDMASK_B16) and src2 is not None: srcs['VCC'] = src2
|
||||
if inst.op in (ir3.VOP3Op.V_CNDMASK_B32_E64, ir3.VOP3Op.V_CNDMASK_B16, irc.VOP3Op.V_CNDMASK_B32_E64) and src2 is not None: srcs['VCC'] = src2
|
||||
# FMAC instructions need D0 (accumulator) from destination register
|
||||
if 'FMAC' in op_name: srcs['D0'] = ctx.rvgpr_dyn(vdst_reg, lane)
|
||||
opsel_dst_hi = bool(opsel & 0b1000) and bits['d'] == 16
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=opsel_dst_hi, clmp=getattr(inst, 'clmp', 0))
|
||||
|
||||
def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD, ctx: _Ctx) -> UOp:
|
||||
def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD | irc.VOP3SD, ctx: _Ctx) -> UOp:
|
||||
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
|
||||
bits, pcode, ops = inst.canonical_op_bits, get_pcode(inst.op), inst.canonical_operands
|
||||
|
||||
@@ -806,7 +817,7 @@ def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD, ctx: _Ctx) -> UOp:
|
||||
else:
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, sdst_reg=inst.sdst.offset)
|
||||
|
||||
def _compile_wmma(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp:
|
||||
def _compile_wmma(inst: ir3.VOP3P | ir4.VOP3P | irc.VOP3P, ctx: _Ctx) -> UOp:
|
||||
op_name = _op_name(inst)
|
||||
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
|
||||
vdst_reg = ctx.inst_field(type(inst).vdst)
|
||||
@@ -839,14 +850,15 @@ def _compile_wmma(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp:
|
||||
stores = [ctx.wvgpr_dyn(vdst_reg + _c(i // 32), UOp.const(dtypes.int, i % 32), mat_d[i].bitcast(dtypes.uint32), exec_mask) for i in range(256)]
|
||||
return UOp.sink(*stores, *ctx.inc_pc())
|
||||
|
||||
def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp:
|
||||
def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P | irc.VOP3P, ctx: _Ctx) -> UOp:
|
||||
op_name = _op_name(inst)
|
||||
if 'WMMA' in op_name and ('16X16X16_F16' in op_name or '16X16X16_BF16' in op_name): return _compile_wmma(inst, ctx)
|
||||
|
||||
lane = ctx.range()
|
||||
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
|
||||
vdst_reg = ctx.inst_field(type(inst).vdst)
|
||||
do_cast = any(x in op_name for x in ('F16', 'F32', 'BF16')) and 'IU' not in op_name
|
||||
is_pk_f32 = 'PK' in op_name and 'F32' in op_name and 'MOV' not in op_name # CDNA packed F32 ops
|
||||
do_cast = any(x in op_name for x in ('F16', 'F32', 'BF16')) and 'IU' not in op_name and not is_pk_f32
|
||||
src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, 16, do_cast=do_cast)
|
||||
src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, 16, do_cast=do_cast)
|
||||
src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, 16, do_cast=do_cast)
|
||||
@@ -854,7 +866,30 @@ def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp:
|
||||
opsel_hi2 = getattr(inst, 'opsel_hi2', 1) if getattr(inst, 'opsel_hi2', 1) is not None else 1
|
||||
neg, neg_hi = getattr(inst, 'neg', 0) or 0, getattr(inst, 'neg_hi', 0) or 0
|
||||
|
||||
if 'FMA_MIX' in op_name:
|
||||
if is_pk_f32:
|
||||
# CDNA packed F32: read 32-bit sources, build 64-bit packed values using opsel.
|
||||
# For VGPRs: opsel selects between v[reg] (0) and v[reg+1] (1) for each half.
|
||||
# For SGPR pairs (off < 128): s[N] = lo float32, s[N+1] = hi float32.
|
||||
# For inline constants (128 <= off < 256): broadcast same value to both halves.
|
||||
src_offs = [ctx.inst_field(type(inst).src0), ctx.inst_field(type(inst).src1), ctx.inst_field(type(inst).src2)]
|
||||
def build_pk_f32(src_lo: UOp, src_off: UOp, opsel_lo: int, opsel_hi_bit: int, neg_lo: int, neg_hi_bit: int) -> UOp:
|
||||
is_vgpr = src_off >= _c(256)
|
||||
vgpr_lo = ctx.rvgpr_dyn(src_off - _c(256), lane) if lane is not None else _c(0)
|
||||
vgpr_hi = ctx.rvgpr_dyn(src_off - _c(256) + _c(1), lane) if lane is not None else _c(0)
|
||||
# For SGPR pairs, opsel selects between s[N] (0) and s[N+1] (1); inline constants always broadcast.
|
||||
is_sgpr_pair = src_off < _c(128)
|
||||
sgpr_hi = ctx.rsgpr_dyn(src_off + _c(1), is_sgpr_pair)
|
||||
scalar_lo_sel = src_lo if not opsel_lo else is_sgpr_pair.where(sgpr_hi, src_lo)
|
||||
scalar_hi_sel = src_lo if not opsel_hi_bit else is_sgpr_pair.where(sgpr_hi, src_lo)
|
||||
lo = is_vgpr.where(vgpr_hi if opsel_lo else vgpr_lo, scalar_lo_sel)
|
||||
hi = is_vgpr.where(vgpr_hi if opsel_hi_bit else vgpr_lo, scalar_hi_sel)
|
||||
if neg_lo: lo = lo ^ UOp.const(dtypes.uint32, 0x80000000)
|
||||
if neg_hi_bit: hi = hi ^ UOp.const(dtypes.uint32, 0x80000000)
|
||||
return _u64(lo, hi)
|
||||
srcs = {'S0': build_pk_f32(src0, src_offs[0], opsel & 1, opsel_hi & 1, neg & 1, neg_hi & 1),
|
||||
'S1': build_pk_f32(src1, src_offs[1], opsel & 2, opsel_hi & 2, neg & 2, neg_hi & 2),
|
||||
'S2': build_pk_f32(src2, src_offs[2], opsel & 4, 1 if opsel_hi2 else 0, neg & 4, neg_hi & 4)}
|
||||
elif 'FMA_MIX' in op_name:
|
||||
combined_opsel_hi = (opsel_hi & 0x3) | ((opsel_hi2 & 0x1) << 2)
|
||||
# For FMA_MIX: neg_hi is ABS (not neg!), neg is actual negation
|
||||
def apply_abs(v, bit, opsel_hi_bit, opsel_bit):
|
||||
@@ -924,13 +959,18 @@ def _compile_vopd(inst: ir3.VOPD | ir4.VOPD, ctx: _Ctx) -> UOp:
|
||||
if dest.startswith('D0'): all_stores.append(ctx.wvgpr_dyn(vdst_reg, lane, _val_to_u32(val), exec_mask, after=srcy1))
|
||||
return UOp.sink(UOp.group(*all_stores).end(lane), *ctx.inc_pc())
|
||||
|
||||
def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS | ir4.VFLAT | ir4.VGLOBAL | ir4.VSCRATCH, ctx: _Ctx) -> UOp:
|
||||
def _compile_mem_op(inst: ir3.DS|ir3.FLAT|ir3.GLOBAL|ir3.SCRATCH|ir4.DS|ir4.VFLAT|ir4.VGLOBAL|ir4.VSCRATCH
|
||||
|irc.DS|irc.FLAT|irc.GLOBAL|irc.SCRATCH, ctx: _Ctx) -> UOp:
|
||||
"""Unified memory operation compiler for DS, FLAT, GLOBAL, SCRATCH."""
|
||||
exec_mask, op_name = ctx.rsgpr_dyn(_c(EXEC_LO.offset)), _op_name(inst)
|
||||
pcode = get_pcode(inst.op)
|
||||
# CDNA pcode uses CalcGlobalAddr/CalcDsAddr to compute address from raw components, but make_addr already handles this.
|
||||
# Strip the addr computation line and use pre-computed ADDR directly (rename 'addr' -> 'ADDR' in remaining pcode).
|
||||
if isinstance(inst, (irc.GLOBAL, irc.FLAT, irc.SCRATCH, irc.DS)) and 'Calc' in pcode and 'Addr' in pcode:
|
||||
pcode = re.sub(r'addr\s*=\s*Calc\w+Addr\([^)]*\)\s*;?\n?', '', pcode).replace('MEM[addr', 'MEM[ADDR')
|
||||
|
||||
is_lds = isinstance(inst, (ir3.DS, ir4.DS))
|
||||
is_scratch = isinstance(inst, (ir3.SCRATCH, ir4.VSCRATCH))
|
||||
is_lds = isinstance(inst, (ir3.DS, ir4.DS, irc.DS))
|
||||
is_scratch = isinstance(inst, (ir3.SCRATCH, ir4.VSCRATCH, irc.SCRATCH))
|
||||
mem = ctx.lds if is_lds else ctx.scratch if is_scratch else ctx.vmem
|
||||
addr_shift = UOp.const(dtypes.uint32 if is_lds else dtypes.uint64, 2)
|
||||
|
||||
@@ -1038,7 +1078,7 @@ def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS
|
||||
if 'STORE' in op_name and data_bits_mem >= 64:
|
||||
vdata = vdata | (ctx.rvgpr_dyn(vdata_reg + _c(1), lane).cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
|
||||
srcs = {'ADDR': addr, 'VDATA': vdata, '_vmem': mem, '_active': active,
|
||||
'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base}
|
||||
'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base, 'SADDR': saddr_base, 'OFFSET': offset}
|
||||
for i in range(data_bits_mem // 32):
|
||||
srcs[f'VDATA{i}'] = ctx.rvgpr_dyn(vdata_reg + _c(i), lane) if 'STORE' in op_name else UOp.const(dtypes.uint32, 0)
|
||||
return srcs
|
||||
@@ -1075,7 +1115,7 @@ def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS
|
||||
return UOp.sink(*ended, *ctx.inc_pc())
|
||||
|
||||
# Standard path: single lane range
|
||||
writes_return_data = '_RTN' in op_name or (is_lds and op_name.startswith('DS_LOAD')) or bool(is_atomic and glc)
|
||||
writes_return_data = '_RTN' in op_name or (is_lds and (op_name.startswith('DS_LOAD') or op_name.startswith('DS_READ'))) or bool(is_atomic and glc)
|
||||
lane = ctx.range()
|
||||
active = _lane_active(exec_mask, lane)
|
||||
pcode_vars, assigns = parse_pcode(pcode, make_srcs(lane))
|
||||
@@ -1099,6 +1139,11 @@ _INST_HANDLERS: dict[type, Callable[..., UOp]] = {
|
||||
ir4.VOP1: _compile_vop12, ir4.VOP1_SDST: _compile_vop12, ir4.VOP2: _compile_vop12, ir4.VOPC: _compile_vopc, ir4.VOP3: _compile_vop3,
|
||||
ir4.VOP3_SDST: _compile_vop3, ir4.VOP3SD: _compile_vop3sd, ir4.VOP3P: _compile_vop3p, ir4.VOPD: _compile_vopd,
|
||||
ir4.DS: _compile_mem_op, ir4.VFLAT: _compile_mem_op, ir4.VGLOBAL: _compile_mem_op, ir4.VSCRATCH: _compile_mem_op,
|
||||
# CDNA instruction classes
|
||||
irc.SOPP: _compile_sopp, irc.SMEM: _compile_smem, irc.SOP1: _compile_sop, irc.SOP2: _compile_sop, irc.SOPC: _compile_sop, irc.SOPK: _compile_sop,
|
||||
irc.VOP1: _compile_vop12, irc.VOP2: _compile_vop12, irc.VOPC: _compile_vopc, irc.VOP3: _compile_vop3,
|
||||
irc.VOP3_SDST: _compile_vop3, irc.VOP3SD: _compile_vop3sd, irc.VOP3P: _compile_vop3p,
|
||||
irc.DS: _compile_mem_op, irc.FLAT: _compile_mem_op, irc.GLOBAL: _compile_mem_op, irc.SCRATCH: _compile_mem_op,
|
||||
}
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
@@ -1116,7 +1161,7 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
|
||||
|
||||
# Check if instruction matches any cached canonical pattern
|
||||
for base, mask, size, runner in _canonical_runner_cache:
|
||||
if inst_size == size and (inst_int & mask) == base: return runner, False
|
||||
if inst_size == size and (inst_int & mask) == base: return runner
|
||||
|
||||
# Look up handler by type, falling back to base classes for _LIT variants
|
||||
handler = _INST_HANDLERS.get(type(inst))
|
||||
@@ -1136,30 +1181,17 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
|
||||
with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0, EMULATED_DTYPES=""):
|
||||
runner = get_runner('CPU', sink)
|
||||
_canonical_runner_cache.append((base, mask, size, runner))
|
||||
return runner, True
|
||||
return runner
|
||||
|
||||
@functools.cache
|
||||
def decode_program(data: bytes, arch: str = "rdna3") -> dict[int, tuple[str, Callable, list[int], Any]]:
|
||||
"""Decode program to {pc: (name, fxn, globals, runner)}."""
|
||||
result: dict[int, tuple[str, Callable, list[int], Any]] = {}
|
||||
i = 0
|
||||
while i < len(data):
|
||||
inst = decode_inst(data[i:], arch)
|
||||
if hasattr(inst, 'op') and inst.op in (ir3.SOPPOp.S_CODE_END, ir4.SOPPOp.S_CODE_END): break
|
||||
try:
|
||||
runner, is_new = _get_runner(bytes(data[i:i + inst.size() + 4]), arch)
|
||||
if DEBUG >= 3:
|
||||
try: inst_str = repr(inst)
|
||||
except Exception: inst_str = f"<{type(inst).__name__} at PC={i}>"
|
||||
msg = f"[emu] PC={i}: {inst_str}"
|
||||
print(colored(msg, 'green') if is_new else msg)
|
||||
result[i] = (runner.p.function_name, runner._prg.fxn, runner.p.globals, runner)
|
||||
except Exception as e:
|
||||
try: inst_str = repr(inst)
|
||||
except Exception: inst_str = f"<{type(inst).__name__}>"
|
||||
raise RuntimeError(f"[emu] Failed to compile PC={i} {inst_str}: {type(e).__name__}: {e}") from e
|
||||
i += inst.size()
|
||||
return result
|
||||
def _decode_at(pc: int, arch: str):
|
||||
"""Decode and compile instruction at absolute address pc. Returns CompiledRunner."""
|
||||
inst_bytes = bytes((ctypes.c_char * 16).from_address(pc).raw)
|
||||
inst = decode_inst(inst_bytes, arch)
|
||||
try: return _get_runner(bytes(inst_bytes[:inst.size() + 4]), arch)
|
||||
except Exception as e:
|
||||
try: inst_str = repr(inst)
|
||||
except Exception: inst_str = f"<{type(inst).__name__}>"
|
||||
raise RuntimeError(f"[emu] Failed to compile {inst_str}: {type(e).__name__}: {e}") from e
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# WAVE STATE
|
||||
@@ -1206,10 +1238,9 @@ class WaveState:
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int = 0x19c,
|
||||
scratch_size: int = 0, arch: str = "rdna3") -> int:
|
||||
scratch_size: int = 0, arch: str = "rdna3", user_data: list[int]|None = None) -> int:
|
||||
"""Execute AMD assembly program. scratch_size is private_segment_fixed_size from kernel descriptor (per-lane)."""
|
||||
program_raw = decode_program(bytes((ctypes.c_char * lib_sz).from_address(lib).raw), arch)
|
||||
program = {lib + offset: val for offset, val in program_raw.items()} # Remap to actual addresses
|
||||
program: dict[int, tuple[Callable, list[int]]] = {} # lazily populated: pc -> (fxn, globals) extracted from runner
|
||||
lds_size = ((rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE_SHIFT) * 512
|
||||
total_threads = lx * ly * lz
|
||||
|
||||
@@ -1226,8 +1257,12 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
|
||||
for wave_start in range(0, total_threads, WAVE_SIZE):
|
||||
n_lanes, st = min(WAVE_SIZE, total_threads - wave_start), WaveState(min(WAVE_SIZE, total_threads - wave_start))
|
||||
st.pc = lib # Set PC to code base address
|
||||
st._write_sgpr(0, args_ptr & MASK32)
|
||||
st._write_sgpr(1, (args_ptr >> 32) & MASK32)
|
||||
# Initialize user SGPRs: hardware loads COMPUTE_USER_DATA registers directly into s[0:N]
|
||||
if user_data:
|
||||
for i, val in enumerate(user_data): st._write_sgpr(i, val)
|
||||
else:
|
||||
st._write_sgpr(0, args_ptr & MASK32)
|
||||
st._write_sgpr(1, (args_ptr >> 32) & MASK32)
|
||||
|
||||
# Workgroup IDs in SGPRs after user SGPRs
|
||||
sgpr_idx = (rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT_SHIFT
|
||||
@@ -1255,13 +1290,16 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
|
||||
ctypes.c_uint64(vmem_buf._buf.va_addr), ctypes.c_uint64(lds_buf._buf.va_addr),
|
||||
ctypes.c_uint64(scratch_buf._buf.va_addr if scratch_buf else 0)]
|
||||
for inst_count in range(1_000_000):
|
||||
if (pc := st.pc) == 0xFFFFFFFFFFFFFFFF or pc not in program: break
|
||||
name, fxn, globals_list, _ = program[pc]
|
||||
assert fxn is not None, f"[emu] No fxn for {name} at PC={pc}"
|
||||
assert 4 not in globals_list or scratch_buf, f"SCRATCH instruction {name} but scratch_size=0"
|
||||
if DEBUG >= 6:
|
||||
inst = decode_inst(bytes((ctypes.c_char * 12).from_address(pc).raw), arch)
|
||||
print(f"[emu] exec PC={pc:X}: {inst!r}")
|
||||
if (pc := st.pc) == 0xFFFFFFFFFFFFFFFF: break
|
||||
if pc not in program:
|
||||
prev_len = len(_canonical_runner_cache)
|
||||
runner = _decode_at(pc, arch)
|
||||
program[pc] = (runner._prg.fxn, runner.p.globals)
|
||||
if DEBUG >= 3:
|
||||
inst = decode_inst(bytes((ctypes.c_char * 16).from_address(pc).raw), arch)
|
||||
msg = f"[emu] PC={pc - lib}: {inst!r}"
|
||||
print(colored(msg, 'green') if len(_canonical_runner_cache) > prev_len else msg)
|
||||
fxn, globals_list = program[pc]
|
||||
fxn(*[c_bufs[g] for g in globals_list])
|
||||
else: raise RuntimeError("exceeded 1M instructions, likely infinite loop")
|
||||
return 0
|
||||
|
||||
@@ -40,7 +40,10 @@ def _bitreverse(v: UOp, bits: int) -> UOp:
|
||||
|
||||
def _extract_bits(val: UOp, hi: int, lo: int) -> UOp:
|
||||
dt = dtypes.uint64 if val.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32
|
||||
return ((val >> _const(dt, lo)) if lo > 0 else val) & _const(val.dtype, (1 << (hi - lo + 1)) - 1)
|
||||
result = ((val >> _const(dt, lo)) if lo > 0 else val) & _const(val.dtype, (1 << (hi - lo + 1)) - 1)
|
||||
# Downcast to uint32 when extracting <=32 bits from a 64-bit value, so .f32 bitcast works correctly
|
||||
if dt == dtypes.uint64 and (hi - lo + 1) <= 32: result = result.cast(dtypes.uint32)
|
||||
return result
|
||||
|
||||
def _set_bit(old, pos, val):
|
||||
mask = _u32(1) << pos
|
||||
@@ -554,7 +557,9 @@ class Parser:
|
||||
self.eat('LBRACKET')
|
||||
self.eat_val('laneId', 'IDENT')
|
||||
self.eat('RBRACKET')
|
||||
result = (base >> _to_u32(self.vars['laneId'])) & _u32(1)
|
||||
lane = self.vars['laneId']
|
||||
shift = lane.cast(base.dtype) if base.dtype != dtypes.uint32 else _to_u32(lane)
|
||||
result = (base >> shift) & _const(base.dtype, 1)
|
||||
if self.try_eat('DOT'):
|
||||
dt_name = self.eat('IDENT').val
|
||||
return result.cast(DTYPES.get(dt_name, dtypes.uint32))
|
||||
@@ -806,6 +811,12 @@ def _subst_loop_var(line: str, loop_var: str, val: int) -> str:
|
||||
|
||||
def _set_bits(old: UOp, val: UOp, width: int, offset: int) -> UOp:
|
||||
"""Set bits [offset:offset+width) in old to val, masking and shifting appropriately."""
|
||||
is64 = old.dtype in (dtypes.uint64, dtypes.int64) or offset + width > 32
|
||||
if is64:
|
||||
old = old.cast(dtypes.uint64) if old.dtype != dtypes.uint64 else old
|
||||
mask = _u64(((1 << width) - 1) << offset)
|
||||
v = (val.cast(dtypes.uint64) if val.dtype != dtypes.uint64 else val) & _u64((1 << width) - 1)
|
||||
return (old & (mask ^ _u64(0xFFFFFFFFFFFFFFFF))) | (v << _u64(offset))
|
||||
mask = _u32(((1 << width) - 1) << offset)
|
||||
v = (val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val) & _u32((1 << width) - 1)
|
||||
return (old & (mask ^ _u32(0xFFFFFFFF))) | (v << _u32(offset))
|
||||
|
||||
Reference in New Issue
Block a user