mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
clean up wave_size access
This commit is contained in:
@@ -4,7 +4,8 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from tinygrad import Device
|
||||
|
||||
from test.mockgpu.amd.emu import WaveState, _decode_at, WAVE_SIZE, VCC_LO, EXEC_LO, SCC
|
||||
from test.mockgpu.amd.emu import WaveState, _decode_at, VCC_LO, EXEC_LO, SCC
|
||||
WAVE_SIZE = 32
|
||||
from tinygrad.renderer.amd import decode_inst
|
||||
import tinygrad
|
||||
REMU_PATH = Path(tinygrad.__file__).parent.parent / "extra/remu/target/release/libremu.so"
|
||||
|
||||
@@ -231,7 +231,6 @@ VOPD_TO_VOP2 = {
|
||||
ir4.VOPDOp.V_DUAL_FMAAK_F32: ir3.VOP2Op.V_FMAAK_F32_E32, ir4.VOPDOp.V_DUAL_FMAMK_F32: ir3.VOP2Op.V_FMAMK_F32_E32,
|
||||
}
|
||||
def _wave_size(arch: str) -> int: return 64 if arch.startswith("cdna") else 32
|
||||
WAVE_SIZE = 32 # default wave size for RDNA (exported for test_compare_emulators)
|
||||
# Special registers stored after inline constants (256-259)
|
||||
PC_LO_IDX, PC_HI_IDX, SCRATCH_STRIDE_IDX = 256, 257, 259
|
||||
# SGPR buffer: 0-127 = SGPRs, 128-255 = inline constants, 256-259 = special registers
|
||||
|
||||
@@ -554,7 +554,7 @@ class Parser:
|
||||
self.eat('RBRACKET')
|
||||
vgpr = self.vars.get('_vgpr')
|
||||
if vgpr is None: return _u32(0)
|
||||
ws = self.vars.get('_wave_size', 32)
|
||||
ws = self.vars['_wave_size']
|
||||
return vgpr.index(_to_u32(reg) * _u32(ws) + _to_u32(lane), ptr=True).load()
|
||||
if self.try_eat('LPAREN'):
|
||||
args = self._parse_args()
|
||||
@@ -567,8 +567,8 @@ class Parser:
|
||||
if name == 'OVERFLOW_F32': return _const(dtypes.uint32, 0x7F7FFFFF).bitcast(dtypes.float32)
|
||||
if name == 'UNDERFLOW_F64': return _const(dtypes.uint64, 1).bitcast(dtypes.float64)
|
||||
if name == 'OVERFLOW_F64': return _const(dtypes.uint64, 0x7FEFFFFFFFFFFFFF).bitcast(dtypes.float64)
|
||||
if name == 'WAVE32': return _const(dtypes.bool, self.vars.get('_wave_size', 32) <= 32)
|
||||
if name == 'WAVE64': return _const(dtypes.bool, self.vars.get('_wave_size', 32) > 32)
|
||||
if name == 'WAVE32': return _const(dtypes.bool, self.vars['_wave_size'] <= 32)
|
||||
if name == 'WAVE64': return _const(dtypes.bool, self.vars['_wave_size'] > 32)
|
||||
if name == 'WAVE_MODE' and self.try_eat('DOT') and self.try_eat_val('IEEE', 'IDENT'): return _u32(1)
|
||||
if self.try_eat('LBRACE'):
|
||||
idx = self.eat('NUM').val
|
||||
@@ -580,7 +580,7 @@ class Parser:
|
||||
self.eat('RBRACKET')
|
||||
vgpr = self.vars.get('_vgpr')
|
||||
if vgpr is None: return _u32(0)
|
||||
ws = self.vars.get('_wave_size', 32)
|
||||
ws = self.vars['_wave_size']
|
||||
return vgpr.index(_to_u32(reg) * _u32(ws) + _u32(int(idx)), ptr=True).load()
|
||||
elem = self.vars.get(f'{name}@{idx}', self.vars.get(f'{name}{idx}'))
|
||||
if elem is None:
|
||||
@@ -1062,7 +1062,7 @@ def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dic
|
||||
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
|
||||
ln = parse_tokens(lane_toks, env, funcs)
|
||||
rg, val = parse_tokens(reg_toks, env, funcs), parse_tokens(toks[j:], env, funcs)
|
||||
ws = env.get('_wave_size', 32)
|
||||
ws = env['_wave_size']
|
||||
vgpr_idx = _to_u32(rg) * _u32(ws) + _to_u32(ln)
|
||||
if assigns is not None:
|
||||
assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}][{hi_val}:{lo_val}]', (vgpr_idx, val, hi_val, lo_val)))
|
||||
@@ -1073,7 +1073,7 @@ def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dic
|
||||
ln = parse_tokens(lane_toks, env, funcs)
|
||||
rg, val = parse_tokens(reg_toks, env, funcs), parse_tokens(toks[j:], env, funcs)
|
||||
if assigns is not None:
|
||||
ws = env.get('_wave_size', 32)
|
||||
ws = env['_wave_size']
|
||||
assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}]', (_to_u32(rg) * _u32(ws) + _to_u32(ln), val)))
|
||||
i += 1
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user