clean up wave_size access

This commit is contained in:
qazal
2026-03-17 19:18:27 +00:00
parent f53e4e72f1
commit 1202ff5787
3 changed files with 8 additions and 8 deletions

View File

@@ -4,7 +4,8 @@ from dataclasses import dataclass
from pathlib import Path
from tinygrad import Device
from test.mockgpu.amd.emu import WaveState, _decode_at, WAVE_SIZE, VCC_LO, EXEC_LO, SCC
from test.mockgpu.amd.emu import WaveState, _decode_at, VCC_LO, EXEC_LO, SCC
WAVE_SIZE = 32
from tinygrad.renderer.amd import decode_inst
import tinygrad
REMU_PATH = Path(tinygrad.__file__).parent.parent / "extra/remu/target/release/libremu.so"

View File

@@ -231,7 +231,6 @@ VOPD_TO_VOP2 = {
ir4.VOPDOp.V_DUAL_FMAAK_F32: ir3.VOP2Op.V_FMAAK_F32_E32, ir4.VOPDOp.V_DUAL_FMAMK_F32: ir3.VOP2Op.V_FMAMK_F32_E32,
}
def _wave_size(arch: str) -> int: return 64 if arch.startswith("cdna") else 32
WAVE_SIZE = 32 # default wave size for RDNA (exported for test_compare_emulators)
# Special registers stored after inline constants (256-259)
PC_LO_IDX, PC_HI_IDX, SCRATCH_STRIDE_IDX = 256, 257, 259
# SGPR buffer: 0-127 = SGPRs, 128-255 = inline constants, 256-259 = special registers

View File

@@ -554,7 +554,7 @@ class Parser:
self.eat('RBRACKET')
vgpr = self.vars.get('_vgpr')
if vgpr is None: return _u32(0)
ws = self.vars.get('_wave_size', 32)
ws = self.vars['_wave_size']
return vgpr.index(_to_u32(reg) * _u32(ws) + _to_u32(lane), ptr=True).load()
if self.try_eat('LPAREN'):
args = self._parse_args()
@@ -567,8 +567,8 @@ class Parser:
if name == 'OVERFLOW_F32': return _const(dtypes.uint32, 0x7F7FFFFF).bitcast(dtypes.float32)
if name == 'UNDERFLOW_F64': return _const(dtypes.uint64, 1).bitcast(dtypes.float64)
if name == 'OVERFLOW_F64': return _const(dtypes.uint64, 0x7FEFFFFFFFFFFFFF).bitcast(dtypes.float64)
if name == 'WAVE32': return _const(dtypes.bool, self.vars.get('_wave_size', 32) <= 32)
if name == 'WAVE64': return _const(dtypes.bool, self.vars.get('_wave_size', 32) > 32)
if name == 'WAVE32': return _const(dtypes.bool, self.vars['_wave_size'] <= 32)
if name == 'WAVE64': return _const(dtypes.bool, self.vars['_wave_size'] > 32)
if name == 'WAVE_MODE' and self.try_eat('DOT') and self.try_eat_val('IEEE', 'IDENT'): return _u32(1)
if self.try_eat('LBRACE'):
idx = self.eat('NUM').val
@@ -580,7 +580,7 @@ class Parser:
self.eat('RBRACKET')
vgpr = self.vars.get('_vgpr')
if vgpr is None: return _u32(0)
ws = self.vars.get('_wave_size', 32)
ws = self.vars['_wave_size']
return vgpr.index(_to_u32(reg) * _u32(ws) + _u32(int(idx)), ptr=True).load()
elem = self.vars.get(f'{name}@{idx}', self.vars.get(f'{name}{idx}'))
if elem is None:
@@ -1062,7 +1062,7 @@ def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dic
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
ln = parse_tokens(lane_toks, env, funcs)
rg, val = parse_tokens(reg_toks, env, funcs), parse_tokens(toks[j:], env, funcs)
ws = env.get('_wave_size', 32)
ws = env['_wave_size']
vgpr_idx = _to_u32(rg) * _u32(ws) + _to_u32(ln)
if assigns is not None:
assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}][{hi_val}:{lo_val}]', (vgpr_idx, val, hi_val, lo_val)))
@@ -1073,7 +1073,7 @@ def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dic
ln = parse_tokens(lane_toks, env, funcs)
rg, val = parse_tokens(reg_toks, env, funcs), parse_tokens(toks[j:], env, funcs)
if assigns is not None:
ws = env.get('_wave_size', 32)
ws = env['_wave_size']
assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}]', (_to_u32(rg) * _u32(ws) + _to_u32(ln), val)))
i += 1
continue