From 1202ff5787332ecfa033d0a20e2686d4116a2ec3 Mon Sep 17 00:00:00 2001 From: qazal Date: Tue, 17 Mar 2026 19:18:27 +0000 Subject: [PATCH] clean up wave_size access --- test/amd/test_compare_emulators.py | 3 ++- test/mockgpu/amd/emu.py | 1 - test/mockgpu/amd/pcode.py | 12 ++++++------ 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/amd/test_compare_emulators.py b/test/amd/test_compare_emulators.py index 079545d095..e0fedadc3b 100644 --- a/test/amd/test_compare_emulators.py +++ b/test/amd/test_compare_emulators.py @@ -4,7 +4,8 @@ from dataclasses import dataclass from pathlib import Path from tinygrad import Device -from test.mockgpu.amd.emu import WaveState, _decode_at, WAVE_SIZE, VCC_LO, EXEC_LO, SCC +from test.mockgpu.amd.emu import WaveState, _decode_at, VCC_LO, EXEC_LO, SCC +WAVE_SIZE = 32 from tinygrad.renderer.amd import decode_inst import tinygrad REMU_PATH = Path(tinygrad.__file__).parent.parent / "extra/remu/target/release/libremu.so" diff --git a/test/mockgpu/amd/emu.py b/test/mockgpu/amd/emu.py index f68b111c9a..2ba667e68f 100644 --- a/test/mockgpu/amd/emu.py +++ b/test/mockgpu/amd/emu.py @@ -231,7 +231,6 @@ VOPD_TO_VOP2 = { ir4.VOPDOp.V_DUAL_FMAAK_F32: ir3.VOP2Op.V_FMAAK_F32_E32, ir4.VOPDOp.V_DUAL_FMAMK_F32: ir3.VOP2Op.V_FMAMK_F32_E32, } def _wave_size(arch: str) -> int: return 64 if arch.startswith("cdna") else 32 -WAVE_SIZE = 32 # default wave size for RDNA (exported for test_compare_emulators) # Special registers stored after inline constants (256-259) PC_LO_IDX, PC_HI_IDX, SCRATCH_STRIDE_IDX = 256, 257, 259 # SGPR buffer: 0-127 = SGPRs, 128-255 = inline constants, 256-259 = special registers diff --git a/test/mockgpu/amd/pcode.py b/test/mockgpu/amd/pcode.py index 0fccebbfbc..84bda157a0 100644 --- a/test/mockgpu/amd/pcode.py +++ b/test/mockgpu/amd/pcode.py @@ -554,7 +554,7 @@ class Parser: self.eat('RBRACKET') vgpr = self.vars.get('_vgpr') if vgpr is None: return _u32(0) - ws = self.vars.get('_wave_size', 32) + ws = self.vars['_wave_size'] return vgpr.index(_to_u32(reg) * _u32(ws) + _to_u32(lane), ptr=True).load() if self.try_eat('LPAREN'): args = self._parse_args() @@ -567,8 +567,8 @@ class Parser: if name == 'OVERFLOW_F32': return _const(dtypes.uint32, 0x7F7FFFFF).bitcast(dtypes.float32) if name == 'UNDERFLOW_F64': return _const(dtypes.uint64, 1).bitcast(dtypes.float64) if name == 'OVERFLOW_F64': return _const(dtypes.uint64, 0x7FEFFFFFFFFFFFFF).bitcast(dtypes.float64) - if name == 'WAVE32': return _const(dtypes.bool, self.vars.get('_wave_size', 32) <= 32) - if name == 'WAVE64': return _const(dtypes.bool, self.vars.get('_wave_size', 32) > 32) + if name == 'WAVE32': return _const(dtypes.bool, self.vars['_wave_size'] <= 32) + if name == 'WAVE64': return _const(dtypes.bool, self.vars['_wave_size'] > 32) if name == 'WAVE_MODE' and self.try_eat('DOT') and self.try_eat_val('IEEE', 'IDENT'): return _u32(1) if self.try_eat('LBRACE'): idx = self.eat('NUM').val @@ -580,7 +580,7 @@ class Parser: self.eat('RBRACKET') vgpr = self.vars.get('_vgpr') if vgpr is None: return _u32(0) - ws = self.vars.get('_wave_size', 32) + ws = self.vars['_wave_size'] return vgpr.index(_to_u32(reg) * _u32(ws) + _u32(int(idx)), ptr=True).load() elem = self.vars.get(f'{name}@{idx}', self.vars.get(f'{name}{idx}')) if elem is None: @@ -1062,7 +1062,7 @@ def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dic if j < len(toks) and toks[j].type == 'EQUALS': j += 1 ln = parse_tokens(lane_toks, env, funcs) rg, val = parse_tokens(reg_toks, env, funcs), parse_tokens(toks[j:], env, funcs) - ws = env.get('_wave_size', 32) + ws = env['_wave_size'] vgpr_idx = _to_u32(rg) * _u32(ws) + _to_u32(ln) if assigns is not None: assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}][{hi_val}:{lo_val}]', (vgpr_idx, val, hi_val, lo_val))) @@ -1073,7 +1073,7 @@ def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dic ln = parse_tokens(lane_toks, env, funcs) rg, val = parse_tokens(reg_toks, env, funcs), parse_tokens(toks[j:], env, funcs) if assigns is not None: - ws = env.get('_wave_size', 32) + ws = env['_wave_size'] assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}]', (_to_u32(rg) * _u32(ws) + _to_u32(ln), val))) i += 1 continue