diff --git a/test/mockgpu/amd/emu.py b/test/mockgpu/amd/emu.py index 6f2e3825fd..d8cbc20bfb 100644 --- a/test/mockgpu/amd/emu.py +++ b/test/mockgpu/amd/emu.py @@ -375,7 +375,7 @@ def _mem_store(mem: UOp, addr: UOp, val: UOp, active: UOp, addr_bits: int = 32, """Conditional memory store with sub-word support. Returns list of store UOps.""" adt = dtypes.uint64 if addr_bits == 64 else dtypes.uint32 word_addr = addr >> UOp.const(adt, 2) - idx = mem.index(word_addr.cast(dtypes.int), active) + idx = mem.index(word_addr.cast(dtypes.int).valid(active)) if data_bits == 32: return [idx.store(active.where(_to_u32(val), idx))] # Sub-word store: read-modify-write with mask byte_pos = addr.cast(dtypes.uint32) & _c(3) @@ -388,7 +388,7 @@ def _mem_store(mem: UOp, addr: UOp, val: UOp, active: UOp, addr_bits: int = 32, is_cross = byte_pos.eq(_c(3)) cross_word0 = (idx & _c(0x00FFFFFF)) | ((val_u32 & _c(0xFF)) << _c(24)) store0 = idx.store(active.where(is_cross.where(cross_word0, new_word), idx)) - next_idx = mem.index((word_addr + UOp.const(adt, 1)).cast(dtypes.int), active & is_cross) + next_idx = mem.index((word_addr + UOp.const(adt, 1)).cast(dtypes.int).valid(active & is_cross)) cross_word1 = (next_idx & _c(0xFFFFFF00)) | ((val_u32 >> _c(8)) & _c(0xFF)) return [store0, next_idx.store((active & is_cross).where(cross_word1, next_idx))] @@ -398,7 +398,7 @@ def _mem_store_bytes(mem: UOp, addr: UOp, val: UOp, active: UOp, data_bits: int val_u32 = val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val for i in range(data_bits // 8): byte_val = (val_u32 >> UOp.const(dtypes.uint32, i * 8)) & UOp.const(dtypes.uint32, 0xFF) - stores.append(mem.index((addr + UOp.const(dtypes.uint64, i)).cast(dtypes.int), active).store(byte_val.cast(dtypes.uint8))) + stores.append(mem.index((addr + UOp.const(dtypes.uint64, i)).cast(dtypes.int).valid(active)).store(byte_val.cast(dtypes.uint8))) return stores def _collect_data_slices(assigns: list[tuple[str, UOp]], data_prefix: str, pcode_vars: dict | None = None, op_name: str = "") -> dict[int, UOp]: @@ -516,14 +516,14 @@ class _Ctx: # Dynamic register access (takes UOp index instead of int) def rsgpr_dyn(self, reg: UOp, valid: UOp | None = None) -> UOp: """Read SGPR with dynamic register index.""" - if valid is not None: return self.sgpr.index(reg.cast(dtypes.int), valid, ptr=True).load() + if valid is not None: return self.sgpr.index(reg.cast(dtypes.int).valid(valid), ptr=True).load() return self.sgpr.index(reg.cast(dtypes.int), ptr=True).load() def wsgpr_dyn(self, reg: UOp, val: UOp) -> UOp: """Write SGPR with dynamic register index. On RDNA, index 124 = NULL (writes discarded). On CDNA, index 124 = M0 (read/write).""" # RDNA: NULL (124) discards writes. CDNA: M0 (124) is writable. valid = None if self.wave_size == 64 else reg.ne(_c(124)) - return self.sgpr.index(reg.cast(dtypes.int), valid).store(val.cast(dtypes.uint32)) + return self.sgpr.index(reg.cast(dtypes.int).valid(valid) if valid is not None else reg.cast(dtypes.int)).store(val.cast(dtypes.uint32)) def wmask(self, reg: UOp, val: UOp) -> list[UOp]: """Write a lane mask (VCC/EXEC). Splits into lo/hi for wave64.""" @@ -540,24 +540,24 @@ class _Ctx: def rvgpr_dyn(self, reg: UOp, lane: UOp, valid: UOp | None = None) -> UOp: """Read VGPR with dynamic register index.""" idx = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int) - return self.vgpr.index(idx, valid, ptr=True).load() if valid is not None else self.vgpr.index(idx, ptr=True).load() + return self.vgpr.index(idx.valid(valid), ptr=True).load() if valid is not None else self.vgpr.index(idx, ptr=True).load() def wvgpr_dyn(self, reg: UOp, lane: UOp, val: UOp, exec_mask: UOp, after: UOp | None = None) -> UOp: """Write VGPR with dynamic register index.""" buf = self.vgpr.after(after) if after is not None else self.vgpr offset = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int) - return buf.index(offset, _lane_active(exec_mask, lane)).store(val.cast(dtypes.uint32)) + return buf.index(offset.valid(_lane_active(exec_mask, lane))).store(val.cast(dtypes.uint32)) def raccvgpr_dyn(self, reg: UOp, lane: UOp, valid: UOp | None = None) -> UOp: """Read ACCVGPR with dynamic register index (CDNA only).""" idx = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int) - return self.accvgpr.index(idx, valid, ptr=True).load() if valid is not None else self.accvgpr.index(idx, ptr=True).load() + return self.accvgpr.index(idx.valid(valid), ptr=True).load() if valid is not None else self.accvgpr.index(idx, ptr=True).load() def waccvgpr_dyn(self, reg: UOp, lane: UOp, val: UOp, exec_mask: UOp, after: UOp | None = None) -> UOp: """Write ACCVGPR with dynamic register index (CDNA only).""" buf = self.accvgpr.after(after) if after is not None else self.accvgpr offset = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int) - return buf.index(offset, _lane_active(exec_mask, lane)).store(val.cast(dtypes.uint32)) + return buf.index(offset.valid(_lane_active(exec_mask, lane))).store(val.cast(dtypes.uint32)) def rsrc_dyn(self, off: UOp, lane: UOp | None, bits: int = 32, literal: UOp | None = None, is_f64: bool = False, do_cast: bool = True) -> UOp: """Read source operand with dynamic offset. Handles SGPR/inline constants (<256), VGPR (>=256). @@ -713,7 +713,7 @@ class _Ctx: old = self.vgpr.index(val[0].cast(dtypes.int), ptr=True).load() new_val = _set_bits(old, _val_to_bits(val[1]), width, lo_bit).cast(dtypes.uint32) active = _lane_active(exec_mask, lane) - raw_stores.append(('vgpr_direct', self.vgpr.index(val[0].cast(dtypes.int), active).store(new_val))) + raw_stores.append(('vgpr_direct', self.vgpr.index(val[0].cast(dtypes.int).valid(active)).store(new_val))) continue if 'D0' in dest and '[laneId]' in dest: old_vcc = self.rmask(_c(VCC_LO.offset)) @@ -1847,7 +1847,7 @@ def _compile_mem_op(inst: ir3.DS|ir3.FLAT|ir3.GLOBAL|ir3.SCRATCH|ir4.DS|ir4.VFLA if data_bits < 32: # Sub-dword LDS write: read-modify-write within the uint32 slot word_addr = (addr >> addr_shift).cast(dtypes.int) - idx = mem.index(word_addr, active) + idx = mem.index(word_addr.valid(active)) byte_pos = addr.cast(dtypes.uint32) & _c(3) byte_shift = byte_pos * _c(8) size_mask = _c(0xFF if data_bits == 8 else 0xFFFF) @@ -2005,17 +2005,18 @@ def _compile_mubuf(inst: irc.MUBUF, ctx: _Ctx) -> UOp: word_addr = (addr + UOp.const(dtypes.uint64, i * 4)) >> UOp.const(dtypes.uint64, 2) val = in_bounds.where(mem.index(word_addr.cast(dtypes.int64), ptr=True).load(), _c(0)) lds_idx = ((lds_addr + _c(i * 4)) >> _c(2)).cast(dtypes.int) - stores.append(ctx.lds.index(lds_idx, active).store(active.where(val, ctx.lds.index(lds_idx, active)))) + lds_slot = ctx.lds.index(lds_idx.valid(active)) + stores.append(lds_slot.store(active.where(val, lds_slot))) elif is_store: for i in range(n_dwords): word_addr = (addr + UOp.const(dtypes.uint64, i * 4)) >> UOp.const(dtypes.uint64, 2) - idx = mem.index(word_addr.cast(dtypes.int64), in_bounds) + idx = mem.index(word_addr.cast(dtypes.int64).valid(in_bounds)) val = (ctx.raccvgpr_dyn if use_acc else ctx.rvgpr_dyn)(vdata + _c(i), lane) stores.append(idx.store(in_bounds.where(_to_u32(val), idx))) else: for i in range(n_dwords): word_addr = (addr + UOp.const(dtypes.uint64, i * 4)) >> UOp.const(dtypes.uint64, 2) - val = in_bounds.where(mem.index(word_addr.cast(dtypes.int64), in_bounds, ptr=True).load(), _c(0)) + val = in_bounds.where(mem.index(word_addr.cast(dtypes.int64).valid(in_bounds), ptr=True).load(), _c(0)) stores.append((ctx.waccvgpr_dyn if use_acc else ctx.wvgpr_dyn)(vdata + _c(i), lane, val, exec_mask)) return UOp.sink(UOp.group(*stores).end(lane), *ctx.inc_pc()) diff --git a/test/mockgpu/amd/pcode.py b/test/mockgpu/amd/pcode.py index c752754c09..3a97f6885f 100644 --- a/test/mockgpu/amd/pcode.py +++ b/test/mockgpu/amd/pcode.py @@ -828,28 +828,28 @@ class Parser: assert mem is not None, "memory load requires _vmem or _lds" adt = dtypes.uint64 if addr.dtype == dtypes.uint64 else dtypes.uint32 active = self.vars.get('_active') - gate = (active,) if active is not None else () + def mindex(idx:UOp, ptr=False): return mem.index(idx.valid(active) if active is not None else idx, ptr=ptr) byte_mem = mem.dtype.base == dtypes.uint8 if byte_mem: idx = addr.cast(dtypes.int) if dt in (dtypes.uint64, dtypes.int64, dtypes.float64): val = _u32(0).cast(dtypes.uint64) - for i in range(8): val = val | (mem.index(idx + _const(dtypes.int, i), *gate, ptr=True).load().cast(dtypes.uint64) << _u64(i * 8)) + for i in range(8): val = val | (mindex(idx + _const(dtypes.int, i), ptr=True).load().cast(dtypes.uint64) << _u64(i * 8)) elif dt in (dtypes.uint8, dtypes.int8): - val = mem.index(idx, *gate, ptr=True).load().cast(dt) + val = mindex(idx, ptr=True).load().cast(dt) elif dt in (dtypes.uint16, dtypes.int16, dtypes.short): - lo = mem.index(idx, *gate, ptr=True).load().cast(dtypes.uint32) - hi = mem.index(idx + _const(dtypes.int, 1), *gate, ptr=True).load().cast(dtypes.uint32) + lo = mindex(idx, ptr=True).load().cast(dtypes.uint32) + hi = mindex(idx + _const(dtypes.int, 1), ptr=True).load().cast(dtypes.uint32) val = (lo | (hi << _u32(8))).cast(dt) else: val = _u32(0) - for i in range(4): val = val | (mem.index(idx + _const(dtypes.int, i), *gate, ptr=True).load().cast(dtypes.uint32) << _u32(i * 8)) + for i in range(4): val = val | (mindex(idx + _const(dtypes.int, i), ptr=True).load().cast(dtypes.uint32) << _u32(i * 8)) else: idx = (addr >> _const(addr.dtype, 2)).cast(dtypes.int) - val = mem.index(idx, *gate) + val = mindex(idx) if dt in (dtypes.uint64, dtypes.int64, dtypes.float64): idx2 = ((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int) - val = val.cast(dtypes.uint64) | (mem.index(idx2, *gate).cast(dtypes.uint64) << _u64(32)) + val = val.cast(dtypes.uint64) | (mindex(idx2).cast(dtypes.uint64) << _u64(32)) elif dt in (dtypes.uint8, dtypes.int8): val = (val >> ((addr & _const(adt, 3)).cast(dtypes.uint32) * _u32(8))) & _u32(0xFF) elif dt in (dtypes.uint16, dtypes.int16): val = (val >> (((addr >> _const(adt, 1)) & _const(adt, 1)).cast(dtypes.uint32) * _u32(16))) & _u32(0xFFFF) @@ -862,7 +862,7 @@ class Parser: idx_native = (addr >> _const(adt, 2)).cast(dtypes.int64) idx_hi_native = ((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int64) safe_idx_hi = is_unaligned.where(idx_hi_native, idx_native) - hi = mem.index(safe_idx_hi, *gate) + hi = mindex(safe_idx_hi) combined = val.cast(dtypes.uint64) | (hi.cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32)) val = is_unaligned.where((combined >> (byte_off.cast(dtypes.uint64) * UOp.const(dtypes.uint64, 8))).cast(dtypes.uint32), val) return _cast_to(val, dt)