mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
fix amd
This commit is contained in:
@@ -375,7 +375,7 @@ def _mem_store(mem: UOp, addr: UOp, val: UOp, active: UOp, addr_bits: int = 32,
|
||||
"""Conditional memory store with sub-word support. Returns list of store UOps."""
|
||||
adt = dtypes.uint64 if addr_bits == 64 else dtypes.uint32
|
||||
word_addr = addr >> UOp.const(adt, 2)
|
||||
idx = mem.index(word_addr.cast(dtypes.int), active)
|
||||
idx = mem.index(word_addr.cast(dtypes.int).valid(active))
|
||||
if data_bits == 32: return [idx.store(active.where(_to_u32(val), idx))]
|
||||
# Sub-word store: read-modify-write with mask
|
||||
byte_pos = addr.cast(dtypes.uint32) & _c(3)
|
||||
@@ -388,7 +388,7 @@ def _mem_store(mem: UOp, addr: UOp, val: UOp, active: UOp, addr_bits: int = 32,
|
||||
is_cross = byte_pos.eq(_c(3))
|
||||
cross_word0 = (idx & _c(0x00FFFFFF)) | ((val_u32 & _c(0xFF)) << _c(24))
|
||||
store0 = idx.store(active.where(is_cross.where(cross_word0, new_word), idx))
|
||||
next_idx = mem.index((word_addr + UOp.const(adt, 1)).cast(dtypes.int), active & is_cross)
|
||||
next_idx = mem.index((word_addr + UOp.const(adt, 1)).cast(dtypes.int).valid(active & is_cross))
|
||||
cross_word1 = (next_idx & _c(0xFFFFFF00)) | ((val_u32 >> _c(8)) & _c(0xFF))
|
||||
return [store0, next_idx.store((active & is_cross).where(cross_word1, next_idx))]
|
||||
|
||||
@@ -398,7 +398,7 @@ def _mem_store_bytes(mem: UOp, addr: UOp, val: UOp, active: UOp, data_bits: int
|
||||
val_u32 = val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val
|
||||
for i in range(data_bits // 8):
|
||||
byte_val = (val_u32 >> UOp.const(dtypes.uint32, i * 8)) & UOp.const(dtypes.uint32, 0xFF)
|
||||
stores.append(mem.index((addr + UOp.const(dtypes.uint64, i)).cast(dtypes.int), active).store(byte_val.cast(dtypes.uint8)))
|
||||
stores.append(mem.index((addr + UOp.const(dtypes.uint64, i)).cast(dtypes.int).valid(active)).store(byte_val.cast(dtypes.uint8)))
|
||||
return stores
|
||||
|
||||
def _collect_data_slices(assigns: list[tuple[str, UOp]], data_prefix: str, pcode_vars: dict | None = None, op_name: str = "") -> dict[int, UOp]:
|
||||
@@ -516,14 +516,14 @@ class _Ctx:
|
||||
# Dynamic register access (takes UOp index instead of int)
|
||||
def rsgpr_dyn(self, reg: UOp, valid: UOp | None = None) -> UOp:
|
||||
"""Read SGPR with dynamic register index."""
|
||||
if valid is not None: return self.sgpr.index(reg.cast(dtypes.int), valid, ptr=True).load()
|
||||
if valid is not None: return self.sgpr.index(reg.cast(dtypes.int).valid(valid), ptr=True).load()
|
||||
return self.sgpr.index(reg.cast(dtypes.int), ptr=True).load()
|
||||
|
||||
def wsgpr_dyn(self, reg: UOp, val: UOp) -> UOp:
|
||||
"""Write SGPR with dynamic register index. On RDNA, index 124 = NULL (writes discarded). On CDNA, index 124 = M0 (read/write)."""
|
||||
# RDNA: NULL (124) discards writes. CDNA: M0 (124) is writable.
|
||||
valid = None if self.wave_size == 64 else reg.ne(_c(124))
|
||||
return self.sgpr.index(reg.cast(dtypes.int), valid).store(val.cast(dtypes.uint32))
|
||||
return self.sgpr.index(reg.cast(dtypes.int).valid(valid) if valid is not None else reg.cast(dtypes.int)).store(val.cast(dtypes.uint32))
|
||||
|
||||
def wmask(self, reg: UOp, val: UOp) -> list[UOp]:
|
||||
"""Write a lane mask (VCC/EXEC). Splits into lo/hi for wave64."""
|
||||
@@ -540,24 +540,24 @@ class _Ctx:
|
||||
def rvgpr_dyn(self, reg: UOp, lane: UOp, valid: UOp | None = None) -> UOp:
|
||||
"""Read VGPR with dynamic register index."""
|
||||
idx = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
|
||||
return self.vgpr.index(idx, valid, ptr=True).load() if valid is not None else self.vgpr.index(idx, ptr=True).load()
|
||||
return self.vgpr.index(idx.valid(valid), ptr=True).load() if valid is not None else self.vgpr.index(idx, ptr=True).load()
|
||||
|
||||
def wvgpr_dyn(self, reg: UOp, lane: UOp, val: UOp, exec_mask: UOp, after: UOp | None = None) -> UOp:
|
||||
"""Write VGPR with dynamic register index."""
|
||||
buf = self.vgpr.after(after) if after is not None else self.vgpr
|
||||
offset = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
|
||||
return buf.index(offset, _lane_active(exec_mask, lane)).store(val.cast(dtypes.uint32))
|
||||
return buf.index(offset.valid(_lane_active(exec_mask, lane))).store(val.cast(dtypes.uint32))
|
||||
|
||||
def raccvgpr_dyn(self, reg: UOp, lane: UOp, valid: UOp | None = None) -> UOp:
|
||||
"""Read ACCVGPR with dynamic register index (CDNA only)."""
|
||||
idx = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
|
||||
return self.accvgpr.index(idx, valid, ptr=True).load() if valid is not None else self.accvgpr.index(idx, ptr=True).load()
|
||||
return self.accvgpr.index(idx.valid(valid), ptr=True).load() if valid is not None else self.accvgpr.index(idx, ptr=True).load()
|
||||
|
||||
def waccvgpr_dyn(self, reg: UOp, lane: UOp, val: UOp, exec_mask: UOp, after: UOp | None = None) -> UOp:
|
||||
"""Write ACCVGPR with dynamic register index (CDNA only)."""
|
||||
buf = self.accvgpr.after(after) if after is not None else self.accvgpr
|
||||
offset = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
|
||||
return buf.index(offset, _lane_active(exec_mask, lane)).store(val.cast(dtypes.uint32))
|
||||
return buf.index(offset.valid(_lane_active(exec_mask, lane))).store(val.cast(dtypes.uint32))
|
||||
|
||||
def rsrc_dyn(self, off: UOp, lane: UOp | None, bits: int = 32, literal: UOp | None = None, is_f64: bool = False, do_cast: bool = True) -> UOp:
|
||||
"""Read source operand with dynamic offset. Handles SGPR/inline constants (<256), VGPR (>=256).
|
||||
@@ -713,7 +713,7 @@ class _Ctx:
|
||||
old = self.vgpr.index(val[0].cast(dtypes.int), ptr=True).load()
|
||||
new_val = _set_bits(old, _val_to_bits(val[1]), width, lo_bit).cast(dtypes.uint32)
|
||||
active = _lane_active(exec_mask, lane)
|
||||
raw_stores.append(('vgpr_direct', self.vgpr.index(val[0].cast(dtypes.int), active).store(new_val)))
|
||||
raw_stores.append(('vgpr_direct', self.vgpr.index(val[0].cast(dtypes.int).valid(active)).store(new_val)))
|
||||
continue
|
||||
if 'D0' in dest and '[laneId]' in dest:
|
||||
old_vcc = self.rmask(_c(VCC_LO.offset))
|
||||
@@ -1847,7 +1847,7 @@ def _compile_mem_op(inst: ir3.DS|ir3.FLAT|ir3.GLOBAL|ir3.SCRATCH|ir4.DS|ir4.VFLA
|
||||
if data_bits < 32:
|
||||
# Sub-dword LDS write: read-modify-write within the uint32 slot
|
||||
word_addr = (addr >> addr_shift).cast(dtypes.int)
|
||||
idx = mem.index(word_addr, active)
|
||||
idx = mem.index(word_addr.valid(active))
|
||||
byte_pos = addr.cast(dtypes.uint32) & _c(3)
|
||||
byte_shift = byte_pos * _c(8)
|
||||
size_mask = _c(0xFF if data_bits == 8 else 0xFFFF)
|
||||
@@ -2005,17 +2005,18 @@ def _compile_mubuf(inst: irc.MUBUF, ctx: _Ctx) -> UOp:
|
||||
word_addr = (addr + UOp.const(dtypes.uint64, i * 4)) >> UOp.const(dtypes.uint64, 2)
|
||||
val = in_bounds.where(mem.index(word_addr.cast(dtypes.int64), ptr=True).load(), _c(0))
|
||||
lds_idx = ((lds_addr + _c(i * 4)) >> _c(2)).cast(dtypes.int)
|
||||
stores.append(ctx.lds.index(lds_idx, active).store(active.where(val, ctx.lds.index(lds_idx, active))))
|
||||
lds_slot = ctx.lds.index(lds_idx.valid(active))
|
||||
stores.append(lds_slot.store(active.where(val, lds_slot)))
|
||||
elif is_store:
|
||||
for i in range(n_dwords):
|
||||
word_addr = (addr + UOp.const(dtypes.uint64, i * 4)) >> UOp.const(dtypes.uint64, 2)
|
||||
idx = mem.index(word_addr.cast(dtypes.int64), in_bounds)
|
||||
idx = mem.index(word_addr.cast(dtypes.int64).valid(in_bounds))
|
||||
val = (ctx.raccvgpr_dyn if use_acc else ctx.rvgpr_dyn)(vdata + _c(i), lane)
|
||||
stores.append(idx.store(in_bounds.where(_to_u32(val), idx)))
|
||||
else:
|
||||
for i in range(n_dwords):
|
||||
word_addr = (addr + UOp.const(dtypes.uint64, i * 4)) >> UOp.const(dtypes.uint64, 2)
|
||||
val = in_bounds.where(mem.index(word_addr.cast(dtypes.int64), in_bounds, ptr=True).load(), _c(0))
|
||||
val = in_bounds.where(mem.index(word_addr.cast(dtypes.int64).valid(in_bounds), ptr=True).load(), _c(0))
|
||||
stores.append((ctx.waccvgpr_dyn if use_acc else ctx.wvgpr_dyn)(vdata + _c(i), lane, val, exec_mask))
|
||||
return UOp.sink(UOp.group(*stores).end(lane), *ctx.inc_pc())
|
||||
|
||||
|
||||
@@ -828,28 +828,28 @@ class Parser:
|
||||
assert mem is not None, "memory load requires _vmem or _lds"
|
||||
adt = dtypes.uint64 if addr.dtype == dtypes.uint64 else dtypes.uint32
|
||||
active = self.vars.get('_active')
|
||||
gate = (active,) if active is not None else ()
|
||||
def mindex(idx:UOp, ptr=False): return mem.index(idx.valid(active) if active is not None else idx, ptr=ptr)
|
||||
byte_mem = mem.dtype.base == dtypes.uint8
|
||||
if byte_mem:
|
||||
idx = addr.cast(dtypes.int)
|
||||
if dt in (dtypes.uint64, dtypes.int64, dtypes.float64):
|
||||
val = _u32(0).cast(dtypes.uint64)
|
||||
for i in range(8): val = val | (mem.index(idx + _const(dtypes.int, i), *gate, ptr=True).load().cast(dtypes.uint64) << _u64(i * 8))
|
||||
for i in range(8): val = val | (mindex(idx + _const(dtypes.int, i), ptr=True).load().cast(dtypes.uint64) << _u64(i * 8))
|
||||
elif dt in (dtypes.uint8, dtypes.int8):
|
||||
val = mem.index(idx, *gate, ptr=True).load().cast(dt)
|
||||
val = mindex(idx, ptr=True).load().cast(dt)
|
||||
elif dt in (dtypes.uint16, dtypes.int16, dtypes.short):
|
||||
lo = mem.index(idx, *gate, ptr=True).load().cast(dtypes.uint32)
|
||||
hi = mem.index(idx + _const(dtypes.int, 1), *gate, ptr=True).load().cast(dtypes.uint32)
|
||||
lo = mindex(idx, ptr=True).load().cast(dtypes.uint32)
|
||||
hi = mindex(idx + _const(dtypes.int, 1), ptr=True).load().cast(dtypes.uint32)
|
||||
val = (lo | (hi << _u32(8))).cast(dt)
|
||||
else:
|
||||
val = _u32(0)
|
||||
for i in range(4): val = val | (mem.index(idx + _const(dtypes.int, i), *gate, ptr=True).load().cast(dtypes.uint32) << _u32(i * 8))
|
||||
for i in range(4): val = val | (mindex(idx + _const(dtypes.int, i), ptr=True).load().cast(dtypes.uint32) << _u32(i * 8))
|
||||
else:
|
||||
idx = (addr >> _const(addr.dtype, 2)).cast(dtypes.int)
|
||||
val = mem.index(idx, *gate)
|
||||
val = mindex(idx)
|
||||
if dt in (dtypes.uint64, dtypes.int64, dtypes.float64):
|
||||
idx2 = ((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int)
|
||||
val = val.cast(dtypes.uint64) | (mem.index(idx2, *gate).cast(dtypes.uint64) << _u64(32))
|
||||
val = val.cast(dtypes.uint64) | (mindex(idx2).cast(dtypes.uint64) << _u64(32))
|
||||
elif dt in (dtypes.uint8, dtypes.int8): val = (val >> ((addr & _const(adt, 3)).cast(dtypes.uint32) * _u32(8))) & _u32(0xFF)
|
||||
elif dt in (dtypes.uint16, dtypes.int16):
|
||||
val = (val >> (((addr >> _const(adt, 1)) & _const(adt, 1)).cast(dtypes.uint32) * _u32(16))) & _u32(0xFFFF)
|
||||
@@ -862,7 +862,7 @@ class Parser:
|
||||
idx_native = (addr >> _const(adt, 2)).cast(dtypes.int64)
|
||||
idx_hi_native = ((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int64)
|
||||
safe_idx_hi = is_unaligned.where(idx_hi_native, idx_native)
|
||||
hi = mem.index(safe_idx_hi, *gate)
|
||||
hi = mindex(safe_idx_hi)
|
||||
combined = val.cast(dtypes.uint64) | (hi.cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
|
||||
val = is_unaligned.where((combined >> (byte_off.cast(dtypes.uint64) * UOp.const(dtypes.uint64, 8))).cast(dtypes.uint32), val)
|
||||
return _cast_to(val, dt)
|
||||
|
||||
Reference in New Issue
Block a user