diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ede5a345cb..9149f22f85 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -708,6 +708,8 @@ jobs: run: SKIP_SLOW_TEST=1 AMD_LLVM=0 pytest -n=auto test/backend/test_ops.py -k "test_sparse_categorical_crossentropy or test_tril or test_nonzero or test_softmax_argmax" --durations 20 - name: Run RDNA4 emulator tests run: MOCKGPU_ARCH=rdna4 python -m pytest test/test_tiny.py -v --durations 20 + - name: Run CDNA4 emulator tests + run: AMD_LLVM=1 MOCKGPU_ARCH=cdna4 python -m pytest test/test_tiny.py -v --durations 20 testnvidia: strategy: diff --git a/test/amd/bench_emu.py b/test/amd/bench_emu.py deleted file mode 100644 index 6c1a00c9eb..0000000000 --- a/test/amd/bench_emu.py +++ /dev/null @@ -1,267 +0,0 @@ -#!/usr/bin/env python3 -"""Benchmark comparing Python vs Rust RDNA3 emulators on real tinygrad kernels.""" -import ctypes, time, os -from pathlib import Path - -from tinygrad.renderer.amd.emu import run_asm as python_run_asm, decode_program -from tinygrad.renderer.amd import decode_inst -from tinygrad.runtime.autogen.amd.rdna3.ins import SOPP, SOPPOp - -import tinygrad -EXTRA_DIR = Path(tinygrad.__file__).parent.parent / "extra" -REMU_PATH = EXTRA_DIR / "remu/target/release/libremu.so" -if not REMU_PATH.exists(): - REMU_PATH = EXTRA_DIR / "remu/target/release/libremu.dylib" - -def get_rust_remu(): - """Load the Rust libremu shared library.""" - if not REMU_PATH.exists(): return None - remu = ctypes.CDLL(str(REMU_PATH)) - remu.run_asm.restype = ctypes.c_int32 - remu.run_asm.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, - ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_void_p] - return remu - -def count_instructions(kernel: bytes) -> int: - """Count instructions in a kernel.""" - return len(decode_program(kernel)) - -def setup_buffers(buf_sizes: list[int], init_data: dict[int, bytes] | None = None): - """Allocate buffers and return args pointer + valid ranges.""" - if init_data is None: init_data = {} - buffers = [] - for i, size in enumerate(buf_sizes): - padded = ((size + 15) // 16) * 16 + 16 - data = init_data.get(i, b'\x00' * padded) - data_list = list(data) + [0] * (padded - len(data)) - buf = (ctypes.c_uint8 * padded)(*data_list[:padded]) - buffers.append(buf) - args = (ctypes.c_uint64 * len(buffers))(*[ctypes.addressof(b) for b in buffers]) - args_ptr = ctypes.addressof(args) - ranges = {(ctypes.addressof(b), len(b)) for b in buffers} - ranges.add((args_ptr, ctypes.sizeof(args))) - return buffers, args, args_ptr, ranges - -def benchmark_emulator(name: str, run_fn, kernel: bytes, global_size, local_size, args_ptr, rsrc2: int, iterations: int = 5): - """Benchmark an emulator and return average time.""" - gx, gy, gz = global_size - lx, ly, lz = local_size - kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel) - lib_ptr = ctypes.addressof(kernel_buf) - - # Warmup - run_fn(lib_ptr, len(kernel), gx, gy, gz, lx, ly, lz, args_ptr, rsrc2) - - # Timed runs - times = [] - for _ in range(iterations): - start = time.perf_counter() - result = run_fn(lib_ptr, len(kernel), gx, gy, gz, lx, ly, lz, args_ptr, rsrc2) - end = time.perf_counter() - if result != 0: - print(f" {name} returned error: {result}") - return None - times.append(end - start) - - return sum(times) / len(times) - -def profile_instructions(kernel: bytes): - """Profile individual instruction compile times.""" - from tinygrad.renderer.amd.emu import _get_runner, _canonical_runner_cache - from tinygrad.helpers import Context - _get_runner.cache_clear() - _canonical_runner_cache.clear() - - results = [] - i = 0 - while i < len(kernel): - inst = decode_inst(kernel[i:]) - if isinstance(inst, SOPP) and inst.op == SOPPOp.S_CODE_END: break - inst_bytes = bytes(kernel[i:i + inst.size() + 4]) - try: inst_str = repr(inst) - except Exception: inst_str = f"<{type(inst).__name__}>" - - # Time the full compile (sink + render + compile) - start = time.perf_counter() - with Context(CCACHE=0): - runner, is_new = _get_runner(inst_bytes) - compile_time = time.perf_counter() - start - - results.append({ - 'inst_str': inst_str + ('' if is_new else ' [CACHED]'), - 'compile_ms': compile_time * 1000 if is_new else 0, - }) - i += inst.size() - - return sorted(results, key=lambda x: x['compile_ms'], reverse=True) - -def benchmark_python_split(kernel: bytes, global_size, local_size, args_ptr, rsrc2: int, iterations: int = 5): - """Benchmark Python emulator with compile and execution times.""" - from tinygrad.renderer.amd.emu import _get_runner, _canonical_runner_cache - from tinygrad.helpers import Context - _get_runner.cache_clear() - _canonical_runner_cache.clear() - decode_program.cache_clear() - - # Measure compile time (decode_program builds sinks, renders, and compiles) - compile_start = time.perf_counter() - with Context(CCACHE=0): - program = decode_program(kernel) - compile_time = time.perf_counter() - compile_start - n_compiled = len(_canonical_runner_cache) - - # Execution time - exec_time = benchmark_emulator("Python", python_run_asm, kernel, global_size, local_size, args_ptr, rsrc2, iterations) - return compile_time, exec_time, len(program), n_compiled - -def get_tinygrad_kernel(op_name: str) -> tuple[bytes, tuple, tuple, list[int], dict[int, bytes], int] | None: - """Get a real tinygrad kernel by operation name. Returns (code, global_size, local_size, buf_sizes, buf_data, rsrc2).""" - try: - from tinygrad import Tensor - from tinygrad.runtime.support.elf import elf_loader - from tinygrad.runtime.autogen import hsa - import numpy as np - np.random.seed(42) - - ops = { - "add": lambda: Tensor.empty(1024) + Tensor.empty(1024), - "mul": lambda: Tensor.empty(1024) * Tensor.empty(1024), - "matmul_small": lambda: Tensor.empty(16, 16) @ Tensor.empty(16, 16), - "matmul_medium": lambda: Tensor.empty(64, 64) @ Tensor.empty(64, 64), - "reduce_sum": lambda: Tensor.empty(4096).sum(), - "reduce_max": lambda: Tensor.empty(4096).max(), - "softmax": lambda: Tensor.empty(256).softmax(), - "layernorm": lambda: Tensor.empty(32, 64).layernorm(), - "conv2d": lambda: Tensor.empty(1, 4, 16, 16).conv2d(Tensor.empty(4, 4, 3, 3)), - "gelu": lambda: Tensor.empty(1024).gelu(), - "exp": lambda: Tensor.empty(1024).exp(), - "sin": lambda: Tensor.empty(1024).sin(), - } - - if op_name not in ops: return None - out = ops[op_name]() - sched = out.schedule() - - for ei in sched: - lowered = ei.lower() - if ei.ast.op.name == 'SINK' and lowered.prg and lowered.prg.p.lib: - lib = bytes(lowered.prg.p.lib) - image = memoryview(bytearray(lib)) - _, sections, _ = elf_loader(lib) - rodata_entry = next((sh.header.sh_addr for sh in sections if sh.name == ".rodata"), -1) - for sec in sections: - if sec.name == '.text': - buf_sizes = [b.nbytes for b in lowered.bufs] - # Get initial data from numpy arrays if available - buf_data = {} - for i, buf in enumerate(lowered.bufs): - if hasattr(buf, 'base') and buf.base is not None and hasattr(buf.base, '_buf'): - try: buf_data[i] = bytes(buf.base._buf) - except Exception: pass - # Extract rsrc2 from ELF (same as ops_amd.py) - group_segment_size = image[rodata_entry:rodata_entry+4].cast("I")[0] - lds_size = ((group_segment_size + 511) // 512) & 0x1FF - code = hsa.amd_kernel_code_t.from_buffer_copy(bytes(image[rodata_entry:rodata_entry+256]) + b'\x00'*256) - rsrc2 = code.compute_pgm_rsrc2 | (lds_size << 15) - return (bytes(sec.content), tuple(lowered.prg.p.global_size), tuple(lowered.prg.p.local_size), buf_sizes, buf_data, rsrc2) - return None - except Exception as e: - print(f" Error getting kernel: {e}") - return None - -TINYGRAD_TESTS = ["add", "mul", "reduce_sum", "softmax", "exp", "sin", "gelu", "matmul_small"] - -def main(): - import argparse - parser = argparse.ArgumentParser(description="Benchmark RDNA3 emulators") - parser.add_argument("--iterations", type=int, default=3, help="Number of iterations per benchmark") - parser.add_argument("--profile", type=str, default=None, help="Profile instructions for a specific kernel (e.g. 'sin')") - parser.add_argument("--top", type=int, default=20, help="Number of top instructions to show in profile") - args = parser.parse_args() - - # Profile mode: show individual instruction timing - if args.profile: - kernel_info = get_tinygrad_kernel(args.profile) - if kernel_info is None: - print(f"Failed to get kernel for '{args.profile}'") - return - kernel = kernel_info[0] - print(f"Profiling instructions for '{args.profile}' kernel...") - print("=" * 110) - results = profile_instructions(kernel) - print(f"{'Instruction':<90} {'Compile(ms)':>12}") - print("-" * 110) - for r in results[:args.top]: - inst = r['inst_str'][:87] + "..." if len(r['inst_str']) > 90 else r['inst_str'] - print(f"{inst:<90} {r['compile_ms']:>12.3f}") - print("-" * 110) - total = sum(r['compile_ms'] for r in results) - print(f"{'TOTAL':<90} {total:>12.3f}") - return - - rust_remu = get_rust_remu() - if rust_remu is None: - print("Rust libremu not found. Build with: cargo build --release --manifest-path extra/remu/Cargo.toml") - print("Running Python-only benchmarks...\n") - - print("=" * 90) - print("RDNA3 Emulator Benchmark: Python vs Rust") - print("=" * 90) - - results = [] - - print("\n[TINYGRAD KERNELS]") - print("-" * 90) - - for op_name in TINYGRAD_TESTS: - print(f"\n{op_name}:", end=" ", flush=True) - kernel_info = get_tinygrad_kernel(op_name) - if kernel_info is None: - print("failed to compile") - continue - - kernel, global_size, local_size, buf_sizes, buf_data, rsrc2 = kernel_info - buffers, args_arr, args_ptr, ranges = setup_buffers(buf_sizes, buf_data) - - # Benchmark Python emulator (must be first to measure compile time before cache is populated) - py_compile, py_exec, n_insts, n_compiled = benchmark_python_split(kernel, global_size, local_size, args_ptr, rsrc2, args.iterations) - - n_workgroups = global_size[0] * global_size[1] * global_size[2] - n_threads = local_size[0] * local_size[1] * local_size[2] - total_work = n_insts * n_workgroups * n_threads - - print(f"{n_insts} insts ({n_compiled} unique) × {n_workgroups} WGs × {n_threads} threads = {total_work:,} ops") - rust_time = benchmark_emulator("Rust", rust_remu.run_asm, kernel, global_size, local_size, - args_ptr, rsrc2, args.iterations) if rust_remu else None - - if py_compile is not None: - py_exec_rate = total_work / py_exec / 1e6 - print(f" Compile: {py_compile*1000:8.3f} ms ({n_compiled} unique)") - print(f" Exec: {py_exec*1000:8.3f} ms ({py_exec_rate:7.2f} M ops/s)") - if rust_time: - rust_rate = total_work / rust_time / 1e6 - speedup = py_exec / rust_time if py_exec else 0 - print(f" Rust: {rust_time*1000:8.3f} ms ({rust_rate:7.2f} M ops/s) [{speedup:.1f}x faster]") - - results.append((op_name, n_insts, n_compiled, n_workgroups, py_compile, py_exec, rust_time)) - - # Summary table - print("\n" + "=" * 110) - print("SUMMARY") - print("=" * 110) - print(f"{'Name':<16} {'Insts':<6} {'Unique':<6} {'WGs':<5} {'Compile (ms)':<14} {'Exec (ms)':<12} {'Rust (ms)':<12} {'Speedup':<10}") - print("-" * 110) - - for name, n_insts, n_compiled, n_wgs, py_compile, py_exec, rust_time in results: - compile_ms = f"{py_compile*1000:.3f}" if py_compile else "error" - exec_ms = f"{py_exec*1000:.3f}" if py_exec else "error" - if rust_time: - rust_ms = f"{rust_time*1000:.3f}" - speedup = f"{py_exec/rust_time:.1f}x" if py_exec else "N/A" - else: - rust_ms, speedup = "N/A", "N/A" - print(f"{name:<16} {n_insts:<6} {n_compiled:<6} {n_wgs:<5} {compile_ms:<14} {exec_ms:<12} {rust_ms:<12} {speedup:<10}") - -if __name__ == "__main__": - os.environ["AMD"] = "1" - main() diff --git a/test/amd/test_compare_emulators.py b/test/amd/test_compare_emulators.py index 78b6e77213..3d81cfb4c9 100644 --- a/test/amd/test_compare_emulators.py +++ b/test/amd/test_compare_emulators.py @@ -1,12 +1,15 @@ # Test to compare Python and Rust RDNA3 emulators by running real tinygrad kernels import unittest, ctypes from dataclasses import dataclass +from pathlib import Path from tinygrad import Device -from tinygrad.renderer.amd.emu import WaveState, decode_program, WAVE_SIZE, VCC_LO, EXEC_LO, SCC +from tinygrad.renderer.amd.emu import WaveState, _decode_at, WAVE_SIZE, VCC_LO, EXEC_LO, SCC from tinygrad.renderer.amd import decode_inst from test.amd.helpers import KernelInfo -from test.amd.bench_emu import REMU_PATH +import tinygrad +REMU_PATH = Path(tinygrad.__file__).parent.parent / "extra/remu/target/release/libremu.so" +if not REMU_PATH.exists(): REMU_PATH = Path(tinygrad.__file__).parent.parent / "extra/remu/target/release/libremu.dylib" def set_valid_mem_ranges(ranges): pass # emu2 doesn't need this @@ -89,7 +92,7 @@ class RustEmulator: class PythonEmulator: def __init__(self): self.state: WaveState | None = None - self.program: dict | None = None + self.program: dict[int, tuple] = {} # lazily populated: pc -> (name, fxn, globals) self.vmem_buf = None self.lds_buf = None self.kernel_buf = None # Keep kernel bytes alive @@ -99,27 +102,29 @@ class PythonEmulator: import ctypes from tinygrad.device import Buffer, BufferSpec from tinygrad.dtype import dtypes - # Store kernel in a ctypes buffer so generic instructions can read from vmem at actual PC address + # Store kernel in a ctypes buffer so _decode_at can read from memory at actual PC address self.kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel) self.lib_addr = ctypes.addressof(self.kernel_buf) - # Remap program dict to use actual addresses (like run_asm does) - program_raw = decode_program(kernel) - self.program = {self.lib_addr + offset: val for offset, val in program_raw.items()} + self.program = {} self.state = WaveState(n_lanes) self.state.pc = self.lib_addr # Set PC to code base address self.vmem_buf = Buffer('CPU', 1 << 40, dtypes.uint32, options=BufferSpec(external_ptr=0)).ensure_allocated() self.lds_buf = Buffer('CPU', 65536 // 4, dtypes.uint32).ensure_allocated() + def _ensure_decoded(self, pc: int): + if pc not in self.program: + runner = _decode_at(pc, "rdna3") + self.program[pc] = (runner.p.function_name, runner._prg.fxn, runner.p.globals) + def step(self) -> int: import ctypes - assert self.program is not None and self.state is not None + assert self.state is not None pc = self.state.pc - if pc == 0xFFFFFFFFFFFFFFFF or pc not in self.program: return -1 - name, fxn, globals_list, _runner = self.program[pc] - if fxn is None: return 1 # unsupported instruction + if pc == 0xFFFFFFFFFFFFFFFF: return -1 + self._ensure_decoded(pc) + name, fxn, globals_list = self.program[pc] buf_addrs = {0: self.state.sgpr_buf._buf.va_addr, 1: self.state.vgpr_buf._buf.va_addr, # type: ignore[union-attr] 2: self.vmem_buf._buf.va_addr, 3: self.lds_buf._buf.va_addr} # type: ignore[union-attr] - # Direct ctypes call - bypasses HCQ overhead fxn(*[ctypes.c_uint64(buf_addrs[g]) for g in globals_list], ctypes.c_int32(0)) return -1 if self.state.pc == 0xFFFFFFFFFFFFFFFF else 0 @@ -140,7 +145,7 @@ class PythonEmulator: exec_mask=sgpr[EXEC_LO.offset], sgpr=sgpr, vgpr=vgpr) def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: tuple[int, int, int], - local_size: tuple[int, int, int], program, max_steps: int, debug: bool, trace_len: int, + local_size: tuple[int, int, int], max_steps: int, debug: bool, trace_len: int, kernel_idx: int = 0, max_workgroups: int = 8) -> tuple[bool, str, int]: """Run a single kernel through both emulators. Returns (success, message, total_steps).""" gx, gy, gz = global_size @@ -181,9 +186,9 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t rust_before = rust.get_snapshot() python_before = python.get_snapshot() - assert python.program is not None - inst_info = python.program.get(python.lib_addr + python_before.pc * 4) # Convert word offset to actual address - inst_hex_name = inst_info[0] if inst_info else f"unknown at PC={python_before.pc}" + pc_addr = python.lib_addr + python_before.pc * 4 # Convert word offset to actual address + python._ensure_decoded(pc_addr) + inst_hex_name = python.program[pc_addr][0] # Decode the instruction to get mnemonic for sync_after checks try: # Format is mnemonic_hexbytes, e.g. v_exp_f32_e32_014b027e -> hex is 014b027e @@ -310,12 +315,11 @@ def compare_emulators_multi_kernel(kernels: list[KernelInfo], buf_pool: dict[int kernel_ranges = ranges | {(args_ptr, ctypes.sizeof(args))} set_valid_mem_ranges(kernel_ranges) - program = decode_program(kernel.code) n_lanes = kernel.local_size[0] * kernel.local_size[1] * kernel.local_size[2] ok, msg, steps = run_single_kernel( kernel.code, min(n_lanes, 32), args_ptr, kernel.global_size, - kernel.local_size, program, max_steps, debug, trace_len, ki + kernel.local_size, max_steps, debug, trace_len, ki ) total_steps += steps if not ok: @@ -341,9 +345,8 @@ def compare_emulators_with_memory(kernel: bytes, n_lanes: int, buf_sizes: list, ranges.add((args_ptr, ctypes.sizeof(args))) set_valid_mem_ranges(ranges) - program = decode_program(kernel) # Legacy wrapper assumes local_size = (n_lanes, 1, 1) - ok, msg, _ = run_single_kernel(kernel, n_lanes, args_ptr, global_size, (n_lanes, 1, 1), program, max_steps, debug, trace_len) + ok, msg, _ = run_single_kernel(kernel, n_lanes, args_ptr, global_size, (n_lanes, 1, 1), max_steps, debug, trace_len) return ok, msg def get_kernels_from_tinygrad(op_fn) -> tuple[list[KernelInfo], dict[int, int], dict[int, bytes]]: diff --git a/test/amd/test_rdna4_emu.py b/test/amd/test_rdna4_emu.py deleted file mode 100644 index 7ef5666dab..0000000000 --- a/test/amd/test_rdna4_emu.py +++ /dev/null @@ -1,96 +0,0 @@ -import unittest, ctypes -from tinygrad.runtime.autogen.amd.rdna4 import ins as ir4 -from tinygrad.renderer.amd.dsl import v, s -from tinygrad.renderer.amd.emu import WaveState, decode_program -from tinygrad.device import Buffer, BufferSpec -from tinygrad.dtype import dtypes - -class TestRDNA4Emu(unittest.TestCase): - def _run(self, insts: list, sgprs: dict[int, int] | None = None, vgprs: dict[tuple[int, int], int] | None = None) -> WaveState: - """Run instructions and return final WaveState.""" - # Add S_ENDPGM if not present - if not any(isinstance(i, ir4.SOPP) and i.op == ir4.SOPPOp.S_ENDPGM for i in insts): - insts = list(insts) + [ir4.SOPP(ir4.SOPPOp.S_ENDPGM, simm=0)] - - # Assemble and decode - code = b''.join(i.to_bytes() for i in insts) - code_buf = (ctypes.c_uint8 * len(code)).from_buffer_copy(code) - code_addr = ctypes.addressof(code_buf) - program_raw = decode_program(code, "rdna4") - program = {code_addr + offset: val for offset, val in program_raw.items()} - - # Setup wave state - st = WaveState(n_lanes=1) - st.pc = code_addr - for idx, val in (sgprs or {}).items(): st._write_sgpr(idx, val) - for (reg, lane), val in (vgprs or {}).items(): st._write_vgpr(reg, lane, val) - - # Setup vmem buffer with external_ptr=0 (maps to address 0, allows any pointer access) - vmem_buf = Buffer('CPU', 1 << 40, dtypes.uint32, options=BufferSpec(external_ptr=0)).ensure_allocated() - - # Execute - c_bufs = [ctypes.c_uint64(st.sgpr_buf._buf.va_addr), ctypes.c_uint64(st.vgpr_buf._buf.va_addr), - ctypes.c_uint64(vmem_buf._buf.va_addr), ctypes.c_uint64(0), ctypes.c_uint64(0)] - for _ in range(100): - if (pc := st.pc) == 0xFFFFFFFFFFFFFFFF or pc not in program: break - _, fxn, globals_list, _ = program[pc] - fxn(*[c_bufs[g] for g in globals_list]) - return st - - def test_vopd_dual_mov(self): - """Test VOPD with two V_DUAL_MOV_B32 operations: v[1]=s[1], v[2]=s[2].""" - insts = [ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32, - vdstx=v[1], vdsty=v[2], srcx0=s[1], srcy0=s[2], vsrcx1=v[0], vsrcy1=v[0])] - st = self._run(insts, sgprs={1: 0x40e00000, 2: 0x41100000}) # 7.0f, 9.0f - self.assertEqual(st._read_vgpr(1, 0), 0x40e00000) # v[1] = 7.0 - self.assertEqual(st._read_vgpr(2, 0), 0x41100000) # v[2] = 9.0 - - def test_vopd_dual_mov_after_other_vopd(self): - """Test VOPD reuse: first VOPD(v[3]=0, v[0]=?), then VOPD(v[1]=s[1], v[2]=s[2]).""" - # This matches the BEAM kernel sequence that fails - insts = [ - ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32, - vdstx=v[3], vdsty=v[0], srcx0=0, srcy0=s[0], vsrcx1=v[0], vsrcy1=v[0]), # v[3]=0, v[0]=s[0] - ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32, - vdstx=v[1], vdsty=v[2], srcx0=s[1], srcy0=s[2], vsrcx1=v[0], vsrcy1=v[0]), # v[1]=s[1], v[2]=s[2] - ] - st = self._run(insts, sgprs={0: 0x40a00000, 1: 0x40e00000, 2: 0x41100000}) # 5.0f, 7.0f, 9.0f - self.assertEqual(st._read_vgpr(1, 0), 0x40e00000) # v[1] = 7.0 - self.assertEqual(st._read_vgpr(2, 0), 0x41100000) # v[2] = 9.0 - - def test_vopd_with_s_add_f32_sequence(self): - """Test full BEAM kernel sequence: s_add_f32 then VOPD.""" - # This is the exact sequence from the failing BEAM kernel - insts = [ - ir4.SOP2(ir4.SOP2Op.S_ADD_F32, sdst=s[0], ssrc0=s[0], ssrc1=s[8]), # s[0] = s[0] + s[8] - ir4.SOP2(ir4.SOP2Op.S_ADD_F32, sdst=s[1], ssrc0=s[1], ssrc1=s[9]), # s[1] = s[1] + s[9] - ir4.SOP2(ir4.SOP2Op.S_ADD_F32, sdst=s[2], ssrc0=s[2], ssrc1=s[10]), # s[2] = s[2] + s[10] - ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32, - vdstx=v[3], vdsty=v[0], srcx0=0, srcy0=s[0], vsrcx1=v[0], vsrcy1=v[0]), - ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32, - vdstx=v[1], vdsty=v[2], srcx0=s[1], srcy0=s[2], vsrcx1=v[0], vsrcy1=v[0]), - ] - # Input: s[0:2] = [1,2,3], s[8:10] = [4,5,6] - # After s_add_f32: s[0:2] = [5,7,9] - st = self._run(insts, sgprs={0: 0x3f800000, 1: 0x40000000, 2: 0x40400000, # 1.0, 2.0, 3.0 - 8: 0x40800000, 9: 0x40a00000, 10: 0x40c00000}) # 4.0, 5.0, 6.0 - self.assertEqual(st._read_vgpr(1, 0), 0x40e00000) # v[1] = 7.0 - self.assertEqual(st._read_vgpr(2, 0), 0x41100000) # v[2] = 9.0 - - def test_s_mov_b32_then_vopd(self): - """Test s_mov_b32 followed by VOPD - simulates BEAM kernel sequence.""" - # Use s_mov_b32 with SGPR source (copy from pre-initialized SGPRs) - # s[10:12] will have values set by test harness, copy to s[0:2], then VOPD to VGPRs - insts = [ - ir4.SOP1(ir4.SOP1Op.S_MOV_B32, sdst=s[0], ssrc0=s[10]), # s[0] = s[10] - ir4.SOP1(ir4.SOP1Op.S_MOV_B32, sdst=s[1], ssrc0=s[11]), # s[1] = s[11] - ir4.SOP1(ir4.SOP1Op.S_MOV_B32, sdst=s[2], ssrc0=s[12]), # s[2] = s[12] - ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32, - vdstx=v[1], vdsty=v[2], srcx0=s[1], srcy0=s[2], vsrcx1=v[0], vsrcy1=v[0]), - ] - st = self._run(insts, sgprs={10: 0x40a00000, 11: 0x40e00000, 12: 0x41100000}) # 5.0, 7.0, 9.0 - self.assertEqual(st._read_vgpr(1, 0), 0x40e00000) # v[1] = 7.0 - self.assertEqual(st._read_vgpr(2, 0), 0x41100000) # v[2] = 9.0 - -if __name__ == '__main__': - unittest.main() diff --git a/test/mockgpu/amd/amddriver.py b/test/mockgpu/amd/amddriver.py index b5801f823c..d58a9f4a3e 100644 --- a/test/mockgpu/amd/amddriver.py +++ b/test/mockgpu/amd/amddriver.py @@ -90,9 +90,9 @@ class AMDDriver(VirtDriver): def _prepare_gpu(self, gpu_id): self.doorbells[gpu_id] = memoryview(bytearray(0x2000)) self.gpus[gpu_id] = AMDGPU(gpu_id) - # IP versions: rdna3 = GC 11.0.0, NBIF 4.3.0; rdna4 = GC 12.0.0, NBIF 6.3.1 ip_versions = {"rdna3": {"gc": (11, 0, 0), "sdma": (6, 0, 0), "nbif": (4, 3, 0)}, - "rdna4": {"gc": (12, 0, 0), "sdma": (6, 0, 0), "nbif": (6, 3, 1)}}[MOCKGPU_ARCH] + "rdna4": {"gc": (12, 0, 0), "sdma": (6, 0, 0), "nbif": (6, 3, 1)}, + "cdna4": {"gc": (9, 5, 0), "sdma": (4, 4, 5), "nbif": (7, 9, 0)}}[MOCKGPU_ARCH] def ip_discovery_files(hwid, ver, base_addr): p = f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{hwid}/0' return [VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{hwid}', functools.partial(DirFileDesc, child_names=['0'])), diff --git a/test/mockgpu/amd/amdgpu.py b/test/mockgpu/amd/amdgpu.py index f1c752c06c..6a15392a72 100644 --- a/test/mockgpu/amd/amdgpu.py +++ b/test/mockgpu/amd/amdgpu.py @@ -5,7 +5,7 @@ from tinygrad.helpers import getbits, to_mv, getenv from tinygrad.runtime.support import c MOCKGPU_ARCH = getenv("MOCKGPU_ARCH", "rdna3") -GFX_TARGET_VERSION = {"rdna3": 110000, "rdna4": 120000}[MOCKGPU_ARCH] +GFX_TARGET_VERSION = {"rdna3": 110000, "rdna4": 120000, "cdna4": 90500}[MOCKGPU_ARCH] import tinygrad.runtime.autogen.amd_gpu as amd_gpu, tinygrad.runtime.autogen.am.pm4_nv as pm4 SDMA_MAX_COPY_SIZE = 0x400000 @@ -106,8 +106,8 @@ class PM4Executor(AMDQueue): return (self.rptr[0] - prev_rptr) + executed_in_ib def _exec_acquire_mem(self, n): - assert n == 6 - for _ in range(7): self._next_dword() # TODO: implement + assert n in (5, 6) + for _ in range(n + 1): self._next_dword() # TODO: implement def _exec_release_mem(self, n): assert n == 6 @@ -184,6 +184,12 @@ class PM4Executor(AMDQueue): args_addr = self.gpu.regs[regCOMPUTE_USER_DATA_0] + (self.gpu.regs[regCOMPUTE_USER_DATA_0 + 1] << 32) lc = [self.gpu.regs[i] for i in range(regCOMPUTE_NUM_THREAD_X, regCOMPUTE_NUM_THREAD_X+3)] rsrc2 = self.gpu.regs[regCOMPUTE_PGM_RSRC2] + # Read all user data registers (hardware loads these directly into s[0:N]) + user_sgpr_count = (rsrc2 >> 1) & 0x1F # USER_SGPR_COUNT is bits 1:5 + user_data = [] + for i in range(user_sgpr_count): + try: user_data.append(self.gpu.regs[regCOMPUTE_USER_DATA_0 + i]) + except KeyError: user_data.append(0) prg_sz = 0 for st,sz in self.gpu.mapped_ranges: @@ -197,11 +203,12 @@ class PM4Executor(AMDQueue): scratch_size = wavesize * 4 # This gives the scratch size per thread (lane) assert prg_sz > 0, "Invalid prg ptr (not found in mapped ranges)" - # Pass valid memory ranges, rsrc2, scratch_size and arch to Python emulator + # Pass valid memory ranges, rsrc2, scratch_size, arch, and user data registers to Python emulator if hasattr(remu, 'valid_mem_ranges'): remu.valid_mem_ranges = self.gpu.mapped_ranges if hasattr(remu, 'rsrc2'): remu.rsrc2 = rsrc2 if hasattr(remu, 'scratch_size'): remu.scratch_size = scratch_size if hasattr(remu, 'arch'): remu.arch = self.gpu.arch + if hasattr(remu, 'user_data'): remu.user_data = user_data err = remu.run_asm(prg_addr, prg_sz, *gl, *lc, args_addr) if err != 0: raise RuntimeError("remu does not support the new instruction introduced in this kernel") @@ -318,7 +325,7 @@ class AMDGPU(VirtGPU): self.regs = AMDGPURegisters() self.mapped_ranges = set() self.queues = [] - self.arch = MOCKGPU_ARCH + self.arch = "cdna" if MOCKGPU_ARCH == "cdna4" else MOCKGPU_ARCH def map_range(self, vaddr, size): self.mapped_ranges.add((vaddr, size)) def unmap_range(self, vaddr, size): self.mapped_ranges.remove((vaddr, size)) @@ -329,7 +336,7 @@ class AMDGPU(VirtGPU): self.queues.append(SDMAExecutor(self, base, size, rptr, wptr)) return len(self.queues) - 1 -gpu_props = """cpu_cores_count 0 +_gpu_props_rdna = """cpu_cores_count 0 simd_count 192 mem_banks_count 1 caches_count 206 @@ -367,3 +374,44 @@ sdma_fw_version 20 unique_id 11673270660693242239 num_xcc 1 max_engine_clk_ccompute 2400""" + +_gpu_props_cdna = """cpu_cores_count 0 +simd_count 304 +mem_banks_count 1 +caches_count 206 +io_links_count 1 +p2p_links_count 5 +cpu_core_id_base 0 +simd_id_base 2147488032 +max_waves_per_simd 16 +lds_size_in_kb 128 +gds_size_in_kb 0 +num_gws 64 +wave_front_size 64 +array_count 16 +simd_arrays_per_engine 4 +cu_per_simd_array 19 +simd_per_cu 2 +max_slots_scratch_cu 32 +gfx_target_version {gfx_target_version} +vendor_id 4098 +device_id 29772 +location_id 34304 +domain 0 +drm_render_minor {drm_render_minor} +hive_id 0 +num_sdma_engines 2 +num_sdma_xgmi_engines 0 +num_sdma_queues_per_engine 6 +num_cp_queues 8 +max_engine_clk_fcompute 2100 +local_mem_size 0 +fw_version 2140 +capability 671588992 +debug_prop 1495 +sdma_fw_version 20 +unique_id 11673270660693242239 +num_xcc 1 +max_engine_clk_ccompute 2100""" + +gpu_props = _gpu_props_cdna if MOCKGPU_ARCH == "cdna4" else _gpu_props_rdna diff --git a/test/mockgpu/helpers.py b/test/mockgpu/helpers.py index dbad50f0d4..d135a92522 100644 --- a/test/mockgpu/helpers.py +++ b/test/mockgpu/helpers.py @@ -21,10 +21,11 @@ class PythonRemu: rsrc2: int = 0x19c # Default: USER_SGPR_COUNT=14, enable X and Y workgroup IDs scratch_size: int = 0 # private_segment_fixed_size from kernel descriptor arch: str = "rdna3" # Architecture: rdna3 or rdna4 + user_data: list[int] = [] # All COMPUTE_USER_DATA registers (loaded into s[0:N]) def run_asm(self, lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int) -> int: from tinygrad.renderer.amd.emu import run_asm - return run_asm(lib, lib_sz, gx, gy, gz, lx, ly, lz, args_ptr, self.rsrc2, self.scratch_size, self.arch) + return run_asm(lib, lib_sz, gx, gy, gz, lx, ly, lz, args_ptr, self.rsrc2, self.scratch_size, self.arch, self.user_data) def _try_dlopen_remu(): # Use Python emulator only if PYTHON_REMU=1 diff --git a/tinygrad/renderer/amd/emu.py b/tinygrad/renderer/amd/emu.py index 08637cf578..efd6c14fc1 100644 --- a/tinygrad/renderer/amd/emu.py +++ b/tinygrad/renderer/amd/emu.py @@ -7,7 +7,7 @@ # arg=4: scratch - per-lane scratch memory from __future__ import annotations import ctypes, functools, re, platform, subprocess, tempfile -from typing import Any, Callable +from typing import Callable # Set/restore DAZ+FTZ (denormals-are-zero + flush-to-zero) to match RDNA3 default float mode # x86: MXCSR bits DAZ(6)+FTZ(15), ARM64: FPCR bit FZ(24) @@ -61,8 +61,10 @@ from tinygrad.engine.realize import get_runner from tinygrad.renderer.amd import decode_inst from tinygrad.runtime.autogen.amd.rdna3.str_pcode import PCODE as PCODE_RDNA3 from tinygrad.runtime.autogen.amd.rdna4.str_pcode import PCODE as PCODE_RDNA4 +from tinygrad.runtime.autogen.amd.cdna.str_pcode import PCODE as PCODE_CDNA from tinygrad.runtime.autogen.amd.rdna3 import ins as ir3 from tinygrad.runtime.autogen.amd.rdna4 import ins as ir4 +from tinygrad.runtime.autogen.amd.cdna import ins as irc from tinygrad.renderer.amd.dsl import VCC_LO, EXEC_LO, SCC, ttmp from tinygrad.runtime.autogen.amd.common import Fmt, OpType from tinygrad.renderer.amd.pcode import parse_block, _FUNCS @@ -160,7 +162,7 @@ _pcode_fixes = { def _get_pcode_dict(op) -> dict: """Return the PCODE dictionary for the given opcode based on its architecture.""" - return PCODE_RDNA4 if 'rdna4' in type(op).__module__ else PCODE_RDNA3 + return PCODE_CDNA if 'cdna' in type(op).__module__ else PCODE_RDNA4 if 'rdna4' in type(op).__module__ else PCODE_RDNA3 # Pcode parser @functools.cache @@ -465,8 +467,8 @@ class _Ctx: pcode = get_pcode(op) vcc_reg = sdst_reg if sdst_reg is not None else VCC_LO.offset if 'VCC' not in srcs: srcs['VCC'] = self.rsgpr_dyn(_c(vcc_reg)) - srcs.update({'EXEC': exec_mask, 'SCC': self.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane, - 'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0)}) # rounding mode: 0=RNE, RTZ constant + srcs.update({'EXEC': exec_mask, 'SCC': self.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane, 'VDST': vdst_reg, + 'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0), 'ROUND_NEAREST_EVEN': _c(0)}) # rounding mode constants _, assigns = parse_pcode(pcode, srcs) # For integer ops with clamp, compute overflow using wide arithmetic @@ -543,10 +545,11 @@ class _Ctx: def _compile_sopp(inst: ir3.SOPP | ir4.SOPP, ctx: _Ctx) -> UOp: simm16 = ctx.inst_field_signed(type(inst).simm16).cast(dtypes.int16) - if inst.op in (ir3.SOPPOp.S_ENDPGM, ir4.SOPPOp.S_ENDPGM): + if inst.op in (ir3.SOPPOp.S_ENDPGM, ir4.SOPPOp.S_ENDPGM, irc.SOPPOp.S_ENDPGM): return UOp.sink(ctx.wsgpr_dyn(_c(PC_LO_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF)), ctx.wsgpr_dyn(_c(PC_HI_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF))) - if inst.op in (ir3.SOPPOp.S_NOP, ir4.SOPPOp.S_NOP): return UOp.sink(*ctx.inc_pc()) # S_NOP is a no-op + # S_NOP and S_WAITCNT are no-ops in emulator (no pipeline/cache to wait on) + if inst.op in (ir3.SOPPOp.S_NOP, ir4.SOPPOp.S_NOP, irc.SOPPOp.S_NOP, irc.SOPPOp.S_WAITCNT): return UOp.sink(*ctx.inc_pc()) # NOTE: we ignore SOPPs without PCODE if inst.op in _get_pcode_dict(inst.op): pcode = get_pcode(inst.op) @@ -562,10 +565,7 @@ def _compile_sopp(inst: ir3.SOPP | ir4.SOPP, ctx: _Ctx) -> UOp: def _compile_smem(inst: ir3.SMEM | ir4.SMEM, ctx: _Ctx) -> UOp: # Cache invalidation instructions are no-ops in the emulator (we don't model caches) - cache_inv_ops = [ir3.SMEMOp.S_GL1_INV, ir3.SMEMOp.S_DCACHE_INV, ir4.SMEMOp.S_DCACHE_INV] - if hasattr(ir4.SMEMOp, 'S_GL1_INV'): cache_inv_ops.append(ir4.SMEMOp.S_GL1_INV) - if inst.op in cache_inv_ops: - return UOp.sink(*ctx.inc_pc()) + if '_INV' in inst.op.name: return UOp.sink(*ctx.inc_pc()) # Dynamic sbase field (bits 5:0) - SGPR pair, field value * 2 = register offset sbase = ctx.inst_field(type(inst).sbase) * _c(2) # Dynamic sdata field (bits 12:6) - destination SGPR @@ -573,34 +573,44 @@ def _compile_smem(inst: ir3.SMEM | ir4.SMEM, ctx: _Ctx) -> UOp: # RDNA4 uses 'ioffset', RDNA3 uses 'offset' - use type(inst) to get correct field offset_field = type(inst).ioffset if hasattr(type(inst), 'ioffset') else type(inst).offset # type: ignore[union-attr] offset = ctx.inst_field_signed(offset_field) # signed immediate - # Dynamic soffset field - SGPR for additional offset (NULL=124 reads as 0) - soffset = ctx.inst_field(type(inst).soffset) - addr = _u64(ctx.rsgpr_dyn(sbase), ctx.rsgpr_dyn(sbase + _c(1))) + offset.cast(dtypes.uint64) + ctx.rsgpr_dyn(soffset).cast(dtypes.uint64) + # Dynamic soffset field - SGPR for additional offset (NULL=124 reads as 0, CDNA soffset_en=0 means no soffset) + soffset_val = _c(0).cast(dtypes.uint64) + if not (isinstance(inst, irc.SMEM) and not inst.soffset_en): + soffset_val = ctx.inst_field(type(inst).soffset) + soffset_val = ctx.rsgpr_dyn(soffset_val).cast(dtypes.uint64) + addr = _u64(ctx.rsgpr_dyn(sbase), ctx.rsgpr_dyn(sbase + _c(1))) + offset.cast(dtypes.uint64) + soffset_val _SMEM_NDWORDS = {ir3.SMEMOp.S_LOAD_B32: 1, ir3.SMEMOp.S_LOAD_B64: 2, ir3.SMEMOp.S_LOAD_B128: 4, ir3.SMEMOp.S_LOAD_B256: 8, ir3.SMEMOp.S_LOAD_B512: 16, ir4.SMEMOp.S_LOAD_B32: 1, ir4.SMEMOp.S_LOAD_B64: 2, - ir4.SMEMOp.S_LOAD_B96: 3, ir4.SMEMOp.S_LOAD_B128: 4, ir4.SMEMOp.S_LOAD_B256: 8, ir4.SMEMOp.S_LOAD_B512: 16} + ir4.SMEMOp.S_LOAD_B96: 3, ir4.SMEMOp.S_LOAD_B128: 4, ir4.SMEMOp.S_LOAD_B256: 8, ir4.SMEMOp.S_LOAD_B512: 16, + irc.SMEMOp.S_LOAD_DWORD: 1, irc.SMEMOp.S_LOAD_DWORDX2: 2, irc.SMEMOp.S_LOAD_DWORDX4: 4, + irc.SMEMOp.S_LOAD_DWORDX8: 8, irc.SMEMOp.S_LOAD_DWORDX16: 16} ndwords = _SMEM_NDWORDS[inst.op] stores = [ctx.wsgpr_dyn(sdata_reg + _c(i), ctx.vmem.index((addr + UOp.const(dtypes.uint64, i * 4) >> UOp.const(dtypes.uint64, 2)).cast(dtypes.int))) for i in range(ndwords)] return UOp.sink(*stores, *ctx.inc_pc()) -def _compile_sop(inst: ir3.SOP1 | ir3.SOP2 | ir3.SOPC | ir3.SOPK | ir4.SOP1 | ir4.SOP2 | ir4.SOPC | ir4.SOPK, ctx: _Ctx) -> UOp: +def _compile_sop(inst: ir3.SOP1|ir3.SOP2|ir3.SOPC|ir3.SOPK|ir4.SOP1|ir4.SOP2|ir4.SOPC|ir4.SOPK|irc.SOP1|irc.SOP2|irc.SOPC|irc.SOPK, ctx: _Ctx) -> UOp: bits = inst.canonical_op_bits literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr] - if isinstance(inst, (ir3.SOPK, ir4.SOPK)): + if isinstance(inst, (ir3.SOPK, ir4.SOPK, irc.SOPK)): sdst_off = ctx.inst_field(type(inst).sdst) simm16 = ctx.inst_field(type(inst).simm16) # Sign-extend simm16 simm16_sext = simm16.cast(dtypes.int16).cast(dtypes.int32) - srcs = {'S0': ctx.rsgpr_dyn(sdst_off), 'SIMM16': simm16_sext, 'D0': ctx.rsgpr_dyn(sdst_off)} + # CDNA pcode uses S0 for the immediate in MOVK/MULK/ADDK/CMOVK (where RDNA uses SIMM16), + # but S0 = register for CMPK/SETREG. S1 is always the immediate for CDNA CMPK ops. + op_name = inst.op.name if hasattr(inst.op, 'name') else '' + s0_is_imm = isinstance(inst, irc.SOPK) and 'CMPK' not in op_name and 'SETREG' not in op_name + s0_val = simm16_sext if s0_is_imm else ctx.rsgpr_dyn(sdst_off) + srcs = {'S0': s0_val, 'SIMM16': simm16_sext, 'S1': simm16_sext, 'D0': ctx.rsgpr_dyn(sdst_off)} dst_off, dst_size = sdst_off, 1 - elif isinstance(inst, (ir3.SOP1, ir4.SOP1)): + elif isinstance(inst, (ir3.SOP1, ir4.SOP1, irc.SOP1)): sdst_off = ctx.inst_field(type(inst).sdst) ssrc0_off = ctx.inst_field(type(inst).ssrc0) srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal)} dst_off, dst_size = sdst_off, bits['d'] // 32 - elif isinstance(inst, (ir3.SOP2, ir4.SOP2)): + elif isinstance(inst, (ir3.SOP2, ir4.SOP2, irc.SOP2)): sdst_off = ctx.inst_field(type(inst).sdst) ssrc0_off = ctx.inst_field(type(inst).ssrc0) ssrc1_off = ctx.inst_field(type(inst).ssrc1) @@ -608,7 +618,7 @@ def _compile_sop(inst: ir3.SOP1 | ir3.SOP2 | ir3.SOPC | ir3.SOPK | ir4.SOP1 | ir 'S1': ctx.rsrc_dyn(ssrc1_off, None, bits['s1'], literal)} if literal is not None: srcs['SIMM32'] = literal dst_off, dst_size = sdst_off, bits['d'] // 32 - elif isinstance(inst, (ir3.SOPC, ir4.SOPC)): + elif isinstance(inst, (ir3.SOPC, ir4.SOPC, irc.SOPC)): ssrc0_off = ctx.inst_field(type(inst).ssrc0) ssrc1_off = ctx.inst_field(type(inst).ssrc1) srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal), @@ -619,7 +629,7 @@ def _compile_sop(inst: ir3.SOP1 | ir3.SOP2 | ir3.SOPC | ir3.SOPK | ir4.SOP1 | ir return ctx.compile_sop_pcode(inst.op, srcs, dst_off, dst_size) -def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VOP1_SDST | ir4.VOP2, ctx: _Ctx) -> UOp: +def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VOP1_SDST | ir4.VOP2 | irc.VOP1 | irc.VOP2, ctx: _Ctx) -> UOp: op_name = _op_name(inst) if op_name in ('V_READFIRSTLANE_B32_E32', 'V_PERMLANE64_B32_E32'): return ctx.compile_lane_pcode(inst.op, inst) lane, exec_mask, bits = ctx.range(), ctx.rsgpr_dyn(_c(EXEC_LO.offset)), inst.canonical_op_bits @@ -628,7 +638,7 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VO write_hi_half = bits['d'] == 16 and (vdst_reg >= _c(128)) if isinstance(write_hi_half, UOp): vdst_reg = write_hi_half.where(vdst_reg - _c(128), vdst_reg) elif write_hi_half: vdst_reg -= 128 - if isinstance(inst, (ir3.VOP1, ir4.VOP1)): + if isinstance(inst, (ir3.VOP1, ir4.VOP1, irc.VOP1)): # Handle VOP1 hi-half source operand (src0 >= v[128] for 16-bit ops) src0_off = ctx.inst_field(type(inst).src0) s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal) @@ -654,12 +664,13 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VO s0 = src0_hi.where(_hi16(ctx.rvgpr_dyn(src0_reg, lane)), s0) srcs = {'S0': s0, 'S1': s1, 'D0': d0} if inst.op in (ir3.VOP2Op.V_FMAAK_F32_E32, ir3.VOP2Op.V_FMAMK_F32_E32, ir3.VOP2Op.V_FMAAK_F16_E32, - ir3.VOP2Op.V_FMAMK_F16_E32): + ir3.VOP2Op.V_FMAMK_F16_E32, irc.VOP2Op.V_FMAAK_F32_E32, irc.VOP2Op.V_FMAMK_F32_E32): assert literal is not None srcs['SIMM32'] = literal return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=write_hi_half) -def _compile_vopc(inst: ir3.VOPC | ir3.VOP3 | ir4.VOPC | ir4.VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp: +def _compile_vopc(inst: ir3.VOPC|ir3.VOP3|ir4.VOPC|ir4.VOP3|irc.VOPC|irc.VOP3, ctx: _Ctx, + opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp: exec_mask, op_name, bits = ctx.rsgpr_dyn(_c(EXEC_LO.offset)), _op_name(inst), inst.canonical_op_bits is_cmpx, is_vopc = 'CMPX' in op_name, hasattr(inst, 'vsrc1') # is_vopc: e32 vs e64 @@ -707,7 +718,7 @@ def _compile_vopc(inst: ir3.VOPC | ir3.VOP3 | ir4.VOPC | ir4.VOP3, ctx: _Ctx, op stores = [ctx.wsgpr_dyn(dst_off, new_result)] if not is_vopc else [ctx.wsgpr_dyn(_c(VCC_LO.offset), new_result)] return UOp.sink(*stores, *ctx.inc_pc()) -def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3, ctx: _Ctx) -> UOp: +def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3 | irc.VOP3, ctx: _Ctx) -> UOp: exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset)) bits = inst.canonical_op_bits opsel, op_name = getattr(inst, 'opsel', 0) or 0, _op_name(inst) @@ -741,13 +752,13 @@ def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3, ctx: _Ctx) -> UOp: src1 = _apply_src_mods(src1, 1, abs_bits, neg_bits, bits['s1']) src2 = _apply_src_mods(src2, 2, abs_bits, neg_bits, bits['s2']) srcs = {'S0': src0, 'S1': src1, 'S2': src2} - if inst.op in (ir3.VOP3Op.V_CNDMASK_B32_E64, ir3.VOP3Op.V_CNDMASK_B16) and src2 is not None: srcs['VCC'] = src2 + if inst.op in (ir3.VOP3Op.V_CNDMASK_B32_E64, ir3.VOP3Op.V_CNDMASK_B16, irc.VOP3Op.V_CNDMASK_B32_E64) and src2 is not None: srcs['VCC'] = src2 # FMAC instructions need D0 (accumulator) from destination register if 'FMAC' in op_name: srcs['D0'] = ctx.rvgpr_dyn(vdst_reg, lane) opsel_dst_hi = bool(opsel & 0b1000) and bits['d'] == 16 return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=opsel_dst_hi, clmp=getattr(inst, 'clmp', 0)) -def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD, ctx: _Ctx) -> UOp: +def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD | irc.VOP3SD, ctx: _Ctx) -> UOp: exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset)) bits, pcode, ops = inst.canonical_op_bits, get_pcode(inst.op), inst.canonical_operands @@ -806,7 +817,7 @@ def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD, ctx: _Ctx) -> UOp: else: return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, sdst_reg=inst.sdst.offset) -def _compile_wmma(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp: +def _compile_wmma(inst: ir3.VOP3P | ir4.VOP3P | irc.VOP3P, ctx: _Ctx) -> UOp: op_name = _op_name(inst) exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset)) vdst_reg = ctx.inst_field(type(inst).vdst) @@ -839,14 +850,15 @@ def _compile_wmma(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp: stores = [ctx.wvgpr_dyn(vdst_reg + _c(i // 32), UOp.const(dtypes.int, i % 32), mat_d[i].bitcast(dtypes.uint32), exec_mask) for i in range(256)] return UOp.sink(*stores, *ctx.inc_pc()) -def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp: +def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P | irc.VOP3P, ctx: _Ctx) -> UOp: op_name = _op_name(inst) if 'WMMA' in op_name and ('16X16X16_F16' in op_name or '16X16X16_BF16' in op_name): return _compile_wmma(inst, ctx) lane = ctx.range() exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset)) vdst_reg = ctx.inst_field(type(inst).vdst) - do_cast = any(x in op_name for x in ('F16', 'F32', 'BF16')) and 'IU' not in op_name + is_pk_f32 = 'PK' in op_name and 'F32' in op_name and 'MOV' not in op_name # CDNA packed F32 ops + do_cast = any(x in op_name for x in ('F16', 'F32', 'BF16')) and 'IU' not in op_name and not is_pk_f32 src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, 16, do_cast=do_cast) src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, 16, do_cast=do_cast) src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, 16, do_cast=do_cast) @@ -854,7 +866,30 @@ def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp: opsel_hi2 = getattr(inst, 'opsel_hi2', 1) if getattr(inst, 'opsel_hi2', 1) is not None else 1 neg, neg_hi = getattr(inst, 'neg', 0) or 0, getattr(inst, 'neg_hi', 0) or 0 - if 'FMA_MIX' in op_name: + if is_pk_f32: + # CDNA packed F32: read 32-bit sources, build 64-bit packed values using opsel. + # For VGPRs: opsel selects between v[reg] (0) and v[reg+1] (1) for each half. + # For SGPR pairs (off < 128): s[N] = lo float32, s[N+1] = hi float32. + # For inline constants (128 <= off < 256): broadcast same value to both halves. + src_offs = [ctx.inst_field(type(inst).src0), ctx.inst_field(type(inst).src1), ctx.inst_field(type(inst).src2)] + def build_pk_f32(src_lo: UOp, src_off: UOp, opsel_lo: int, opsel_hi_bit: int, neg_lo: int, neg_hi_bit: int) -> UOp: + is_vgpr = src_off >= _c(256) + vgpr_lo = ctx.rvgpr_dyn(src_off - _c(256), lane) if lane is not None else _c(0) + vgpr_hi = ctx.rvgpr_dyn(src_off - _c(256) + _c(1), lane) if lane is not None else _c(0) + # For SGPR pairs, opsel selects between s[N] (0) and s[N+1] (1); inline constants always broadcast. + is_sgpr_pair = src_off < _c(128) + sgpr_hi = ctx.rsgpr_dyn(src_off + _c(1), is_sgpr_pair) + scalar_lo_sel = src_lo if not opsel_lo else is_sgpr_pair.where(sgpr_hi, src_lo) + scalar_hi_sel = src_lo if not opsel_hi_bit else is_sgpr_pair.where(sgpr_hi, src_lo) + lo = is_vgpr.where(vgpr_hi if opsel_lo else vgpr_lo, scalar_lo_sel) + hi = is_vgpr.where(vgpr_hi if opsel_hi_bit else vgpr_lo, scalar_hi_sel) + if neg_lo: lo = lo ^ UOp.const(dtypes.uint32, 0x80000000) + if neg_hi_bit: hi = hi ^ UOp.const(dtypes.uint32, 0x80000000) + return _u64(lo, hi) + srcs = {'S0': build_pk_f32(src0, src_offs[0], opsel & 1, opsel_hi & 1, neg & 1, neg_hi & 1), + 'S1': build_pk_f32(src1, src_offs[1], opsel & 2, opsel_hi & 2, neg & 2, neg_hi & 2), + 'S2': build_pk_f32(src2, src_offs[2], opsel & 4, 1 if opsel_hi2 else 0, neg & 4, neg_hi & 4)} + elif 'FMA_MIX' in op_name: combined_opsel_hi = (opsel_hi & 0x3) | ((opsel_hi2 & 0x1) << 2) # For FMA_MIX: neg_hi is ABS (not neg!), neg is actual negation def apply_abs(v, bit, opsel_hi_bit, opsel_bit): @@ -924,13 +959,18 @@ def _compile_vopd(inst: ir3.VOPD | ir4.VOPD, ctx: _Ctx) -> UOp: if dest.startswith('D0'): all_stores.append(ctx.wvgpr_dyn(vdst_reg, lane, _val_to_u32(val), exec_mask, after=srcy1)) return UOp.sink(UOp.group(*all_stores).end(lane), *ctx.inc_pc()) -def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS | ir4.VFLAT | ir4.VGLOBAL | ir4.VSCRATCH, ctx: _Ctx) -> UOp: +def _compile_mem_op(inst: ir3.DS|ir3.FLAT|ir3.GLOBAL|ir3.SCRATCH|ir4.DS|ir4.VFLAT|ir4.VGLOBAL|ir4.VSCRATCH + |irc.DS|irc.FLAT|irc.GLOBAL|irc.SCRATCH, ctx: _Ctx) -> UOp: """Unified memory operation compiler for DS, FLAT, GLOBAL, SCRATCH.""" exec_mask, op_name = ctx.rsgpr_dyn(_c(EXEC_LO.offset)), _op_name(inst) pcode = get_pcode(inst.op) + # CDNA pcode uses CalcGlobalAddr/CalcDsAddr to compute address from raw components, but make_addr already handles this. + # Strip the addr computation line and use pre-computed ADDR directly (rename 'addr' -> 'ADDR' in remaining pcode). + if isinstance(inst, (irc.GLOBAL, irc.FLAT, irc.SCRATCH, irc.DS)) and 'Calc' in pcode and 'Addr' in pcode: + pcode = re.sub(r'addr\s*=\s*Calc\w+Addr\([^)]*\)\s*;?\n?', '', pcode).replace('MEM[addr', 'MEM[ADDR') - is_lds = isinstance(inst, (ir3.DS, ir4.DS)) - is_scratch = isinstance(inst, (ir3.SCRATCH, ir4.VSCRATCH)) + is_lds = isinstance(inst, (ir3.DS, ir4.DS, irc.DS)) + is_scratch = isinstance(inst, (ir3.SCRATCH, ir4.VSCRATCH, irc.SCRATCH)) mem = ctx.lds if is_lds else ctx.scratch if is_scratch else ctx.vmem addr_shift = UOp.const(dtypes.uint32 if is_lds else dtypes.uint64, 2) @@ -1038,7 +1078,7 @@ def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS if 'STORE' in op_name and data_bits_mem >= 64: vdata = vdata | (ctx.rvgpr_dyn(vdata_reg + _c(1), lane).cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32)) srcs = {'ADDR': addr, 'VDATA': vdata, '_vmem': mem, '_active': active, - 'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base} + 'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base, 'SADDR': saddr_base, 'OFFSET': offset} for i in range(data_bits_mem // 32): srcs[f'VDATA{i}'] = ctx.rvgpr_dyn(vdata_reg + _c(i), lane) if 'STORE' in op_name else UOp.const(dtypes.uint32, 0) return srcs @@ -1075,7 +1115,7 @@ def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS return UOp.sink(*ended, *ctx.inc_pc()) # Standard path: single lane range - writes_return_data = '_RTN' in op_name or (is_lds and op_name.startswith('DS_LOAD')) or bool(is_atomic and glc) + writes_return_data = '_RTN' in op_name or (is_lds and (op_name.startswith('DS_LOAD') or op_name.startswith('DS_READ'))) or bool(is_atomic and glc) lane = ctx.range() active = _lane_active(exec_mask, lane) pcode_vars, assigns = parse_pcode(pcode, make_srcs(lane)) @@ -1099,6 +1139,11 @@ _INST_HANDLERS: dict[type, Callable[..., UOp]] = { ir4.VOP1: _compile_vop12, ir4.VOP1_SDST: _compile_vop12, ir4.VOP2: _compile_vop12, ir4.VOPC: _compile_vopc, ir4.VOP3: _compile_vop3, ir4.VOP3_SDST: _compile_vop3, ir4.VOP3SD: _compile_vop3sd, ir4.VOP3P: _compile_vop3p, ir4.VOPD: _compile_vopd, ir4.DS: _compile_mem_op, ir4.VFLAT: _compile_mem_op, ir4.VGLOBAL: _compile_mem_op, ir4.VSCRATCH: _compile_mem_op, + # CDNA instruction classes + irc.SOPP: _compile_sopp, irc.SMEM: _compile_smem, irc.SOP1: _compile_sop, irc.SOP2: _compile_sop, irc.SOPC: _compile_sop, irc.SOPK: _compile_sop, + irc.VOP1: _compile_vop12, irc.VOP2: _compile_vop12, irc.VOPC: _compile_vopc, irc.VOP3: _compile_vop3, + irc.VOP3_SDST: _compile_vop3, irc.VOP3SD: _compile_vop3sd, irc.VOP3P: _compile_vop3p, + irc.DS: _compile_mem_op, irc.FLAT: _compile_mem_op, irc.GLOBAL: _compile_mem_op, irc.SCRATCH: _compile_mem_op, } # ═══════════════════════════════════════════════════════════════════════════════ @@ -1116,7 +1161,7 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"): # Check if instruction matches any cached canonical pattern for base, mask, size, runner in _canonical_runner_cache: - if inst_size == size and (inst_int & mask) == base: return runner, False + if inst_size == size and (inst_int & mask) == base: return runner # Look up handler by type, falling back to base classes for _LIT variants handler = _INST_HANDLERS.get(type(inst)) @@ -1136,30 +1181,17 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"): with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0, EMULATED_DTYPES=""): runner = get_runner('CPU', sink) _canonical_runner_cache.append((base, mask, size, runner)) - return runner, True + return runner -@functools.cache -def decode_program(data: bytes, arch: str = "rdna3") -> dict[int, tuple[str, Callable, list[int], Any]]: - """Decode program to {pc: (name, fxn, globals, runner)}.""" - result: dict[int, tuple[str, Callable, list[int], Any]] = {} - i = 0 - while i < len(data): - inst = decode_inst(data[i:], arch) - if hasattr(inst, 'op') and inst.op in (ir3.SOPPOp.S_CODE_END, ir4.SOPPOp.S_CODE_END): break - try: - runner, is_new = _get_runner(bytes(data[i:i + inst.size() + 4]), arch) - if DEBUG >= 3: - try: inst_str = repr(inst) - except Exception: inst_str = f"<{type(inst).__name__} at PC={i}>" - msg = f"[emu] PC={i}: {inst_str}" - print(colored(msg, 'green') if is_new else msg) - result[i] = (runner.p.function_name, runner._prg.fxn, runner.p.globals, runner) - except Exception as e: - try: inst_str = repr(inst) - except Exception: inst_str = f"<{type(inst).__name__}>" - raise RuntimeError(f"[emu] Failed to compile PC={i} {inst_str}: {type(e).__name__}: {e}") from e - i += inst.size() - return result +def _decode_at(pc: int, arch: str): + """Decode and compile instruction at absolute address pc. Returns CompiledRunner.""" + inst_bytes = bytes((ctypes.c_char * 16).from_address(pc).raw) + inst = decode_inst(inst_bytes, arch) + try: return _get_runner(bytes(inst_bytes[:inst.size() + 4]), arch) + except Exception as e: + try: inst_str = repr(inst) + except Exception: inst_str = f"<{type(inst).__name__}>" + raise RuntimeError(f"[emu] Failed to compile {inst_str}: {type(e).__name__}: {e}") from e # ═══════════════════════════════════════════════════════════════════════════════ # WAVE STATE @@ -1206,10 +1238,9 @@ class WaveState: # ═══════════════════════════════════════════════════════════════════════════════ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int = 0x19c, - scratch_size: int = 0, arch: str = "rdna3") -> int: + scratch_size: int = 0, arch: str = "rdna3", user_data: list[int]|None = None) -> int: """Execute AMD assembly program. scratch_size is private_segment_fixed_size from kernel descriptor (per-lane).""" - program_raw = decode_program(bytes((ctypes.c_char * lib_sz).from_address(lib).raw), arch) - program = {lib + offset: val for offset, val in program_raw.items()} # Remap to actual addresses + program: dict[int, tuple[Callable, list[int]]] = {} # lazily populated: pc -> (fxn, globals) extracted from runner lds_size = ((rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE_SHIFT) * 512 total_threads = lx * ly * lz @@ -1226,8 +1257,12 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, for wave_start in range(0, total_threads, WAVE_SIZE): n_lanes, st = min(WAVE_SIZE, total_threads - wave_start), WaveState(min(WAVE_SIZE, total_threads - wave_start)) st.pc = lib # Set PC to code base address - st._write_sgpr(0, args_ptr & MASK32) - st._write_sgpr(1, (args_ptr >> 32) & MASK32) + # Initialize user SGPRs: hardware loads COMPUTE_USER_DATA registers directly into s[0:N] + if user_data: + for i, val in enumerate(user_data): st._write_sgpr(i, val) + else: + st._write_sgpr(0, args_ptr & MASK32) + st._write_sgpr(1, (args_ptr >> 32) & MASK32) # Workgroup IDs in SGPRs after user SGPRs sgpr_idx = (rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT_SHIFT @@ -1255,13 +1290,16 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, ctypes.c_uint64(vmem_buf._buf.va_addr), ctypes.c_uint64(lds_buf._buf.va_addr), ctypes.c_uint64(scratch_buf._buf.va_addr if scratch_buf else 0)] for inst_count in range(1_000_000): - if (pc := st.pc) == 0xFFFFFFFFFFFFFFFF or pc not in program: break - name, fxn, globals_list, _ = program[pc] - assert fxn is not None, f"[emu] No fxn for {name} at PC={pc}" - assert 4 not in globals_list or scratch_buf, f"SCRATCH instruction {name} but scratch_size=0" - if DEBUG >= 6: - inst = decode_inst(bytes((ctypes.c_char * 12).from_address(pc).raw), arch) - print(f"[emu] exec PC={pc:X}: {inst!r}") + if (pc := st.pc) == 0xFFFFFFFFFFFFFFFF: break + if pc not in program: + prev_len = len(_canonical_runner_cache) + runner = _decode_at(pc, arch) + program[pc] = (runner._prg.fxn, runner.p.globals) + if DEBUG >= 3: + inst = decode_inst(bytes((ctypes.c_char * 16).from_address(pc).raw), arch) + msg = f"[emu] PC={pc - lib}: {inst!r}" + print(colored(msg, 'green') if len(_canonical_runner_cache) > prev_len else msg) + fxn, globals_list = program[pc] fxn(*[c_bufs[g] for g in globals_list]) else: raise RuntimeError("exceeded 1M instructions, likely infinite loop") return 0 diff --git a/tinygrad/renderer/amd/pcode.py b/tinygrad/renderer/amd/pcode.py index dba5d082a4..52f58acdce 100644 --- a/tinygrad/renderer/amd/pcode.py +++ b/tinygrad/renderer/amd/pcode.py @@ -40,7 +40,10 @@ def _bitreverse(v: UOp, bits: int) -> UOp: def _extract_bits(val: UOp, hi: int, lo: int) -> UOp: dt = dtypes.uint64 if val.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32 - return ((val >> _const(dt, lo)) if lo > 0 else val) & _const(val.dtype, (1 << (hi - lo + 1)) - 1) + result = ((val >> _const(dt, lo)) if lo > 0 else val) & _const(val.dtype, (1 << (hi - lo + 1)) - 1) + # Downcast to uint32 when extracting <=32 bits from a 64-bit value, so .f32 bitcast works correctly + if dt == dtypes.uint64 and (hi - lo + 1) <= 32: result = result.cast(dtypes.uint32) + return result def _set_bit(old, pos, val): mask = _u32(1) << pos @@ -554,7 +557,9 @@ class Parser: self.eat('LBRACKET') self.eat_val('laneId', 'IDENT') self.eat('RBRACKET') - result = (base >> _to_u32(self.vars['laneId'])) & _u32(1) + lane = self.vars['laneId'] + shift = lane.cast(base.dtype) if base.dtype != dtypes.uint32 else _to_u32(lane) + result = (base >> shift) & _const(base.dtype, 1) if self.try_eat('DOT'): dt_name = self.eat('IDENT').val return result.cast(DTYPES.get(dt_name, dtypes.uint32)) @@ -806,6 +811,12 @@ def _subst_loop_var(line: str, loop_var: str, val: int) -> str: def _set_bits(old: UOp, val: UOp, width: int, offset: int) -> UOp: """Set bits [offset:offset+width) in old to val, masking and shifting appropriately.""" + is64 = old.dtype in (dtypes.uint64, dtypes.int64) or offset + width > 32 + if is64: + old = old.cast(dtypes.uint64) if old.dtype != dtypes.uint64 else old + mask = _u64(((1 << width) - 1) << offset) + v = (val.cast(dtypes.uint64) if val.dtype != dtypes.uint64 else val) & _u64((1 << width) - 1) + return (old & (mask ^ _u64(0xFFFFFFFFFFFFFFFF))) | (v << _u64(offset)) mask = _u32(((1 << width) - 1) << offset) v = (val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val) & _u32((1 << width) - 1) return (old & (mask ^ _u32(0xFFFFFFFF))) | (v << _u32(offset))