mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
@@ -72,14 +72,10 @@ class AMDComputeQueue(HWQueue):
|
||||
if self.dev.xccs > 1:
|
||||
self._q[prev_len-1] |= (len(self._q) - prev_len)
|
||||
|
||||
def set_grbm_broadcast(self):
|
||||
self.wreg(self.gc.regGRBM_GFX_INDEX, **{f'{f}_broadcast_writes': 1 for f in ['se', 'sh' if self.dev.target[0] == 9 else 'sa', 'instance']})
|
||||
def set_grbm_inst(self, n):
|
||||
self.wreg(self.gc.regGRBM_GFX_INDEX, **{f'{f}_broadcast_writes': 1 for f in ['se', 'sh' if self.dev.target[0] == 9 else 'sa']}, instance_index=n)
|
||||
def set_grbm_se_sh(self, se, sh):
|
||||
self.wreg(self.gc.regGRBM_GFX_INDEX, se_index=se, **{f'{"sh" if self.dev.target[0] == 9 else "sa"}_index':sh}, instance_broadcast_writes=1)
|
||||
def set_grbm_se_sh_wgp(self, se, sh, wgp): self.wreg(self.gc.regGRBM_GFX_INDEX, se_index=se, sa_index=sh, instance_index=wgp << 2)
|
||||
def set_grbm_se(self, se): self.wreg(self.gc.regGRBM_GFX_INDEX, se_index=se, sh_broadcast_writes=1, instance_broadcast_writes=1)
|
||||
def set_grbm(self, instance=None, se=None, sh=None, wgp=None):
|
||||
instance_val = (wgp << 2 | (instance or 0)) if wgp is not None else instance
|
||||
self.wreg(self.gc.regGRBM_GFX_INDEX, **{(f'{key}_broadcast_writes' if val is None else f'{key}_index'): (1 if val is None else val)
|
||||
for key, val in [('instance', instance_val), ('se', se), ('sh' if self.dev.target[0] == 9 else 'sa', sh)]})
|
||||
|
||||
def wait_reg_mem(self, value, mask=0xffffffff, mem=None, reg=None, reg_done=0, op=WAIT_REG_MEM_FUNCTION_GEQ):
|
||||
wrm_info_dw = self.pm4.WAIT_REG_MEM_MEM_SPACE(int(mem is not None)) | self.pm4.WAIT_REG_MEM_OPERATION(int(mem is None and reg_done > 0)) \
|
||||
@@ -140,7 +136,7 @@ class AMDComputeQueue(HWQueue):
|
||||
### PMC ###
|
||||
|
||||
def pmc_reset_counters(self, en=True):
|
||||
self.set_grbm_broadcast()
|
||||
self.set_grbm()
|
||||
self.wreg(self.gc.regCP_PERFMON_CNTL if self.dev.target[0] <= 11 else self.gc.regCP_PERFMON_CNTL_1, perfmon_state=0)
|
||||
if en: self.wreg(self.gc.regCP_PERFMON_CNTL if self.dev.target[0] <= 11 else self.gc.regCP_PERFMON_CNTL_1, perfmon_state=1)
|
||||
return self
|
||||
@@ -171,7 +167,7 @@ class AMDComputeQueue(HWQueue):
|
||||
return self.pmc_reset_counters(en=True)
|
||||
|
||||
def pmc_read(self, buf, sched):
|
||||
self.set_grbm_broadcast()
|
||||
self.set_grbm()
|
||||
self.wreg(self.gc.regCP_PERFMON_CNTL if self.dev.target[0] <= 11 else self.gc.regCP_PERFMON_CNTL_1, perfmon_state=1, perfmon_sample_enable=1)
|
||||
|
||||
for s in sched:
|
||||
@@ -180,9 +176,7 @@ class AMDComputeQueue(HWQueue):
|
||||
for xcc in range(s.xcc):
|
||||
with self.pred_exec(xcc_mask=1 << xcc):
|
||||
for inst, se_idx, sa_idx, wgp_idx in itertools.product(range(s.inst), range(s.se), range(s.sa), range(s.wgp)):
|
||||
if s.inst > 1: self.set_grbm_inst(inst)
|
||||
elif self.dev.target[0] == 9: self.set_grbm_se(se_idx)
|
||||
else: self.set_grbm_se_sh_wgp(se_idx, sa_idx, wgp_idx)
|
||||
self.set_grbm(**({'instance':inst} if s.inst > 1 else ({'se':se_idx}|({'sh':sa_idx, 'wgp':wgp_idx} if self.dev.target[0] != 9 else {}))))
|
||||
|
||||
# Copy counter to memory (src_sel = perf, dst_sel = tc_l2)
|
||||
lo, hi = getattr(self.gc, f'{s.regsample}_LO'), getattr(self.gc, f'{s.regsample}_HI', None)
|
||||
@@ -222,7 +216,7 @@ class AMDComputeQueue(HWQueue):
|
||||
def sqtt_start(self, buf0s:list[HCQBuffer], se_mask:int):
|
||||
self.memory_barrier()
|
||||
if self.dev.target[0] == 9:
|
||||
self.set_grbm_broadcast()
|
||||
self.set_grbm()
|
||||
self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, simd_en=0xf, cu_sel=0, sq_stall_en=1, spi_stall_en=1, reg_stall_en=1, vm_id_mask=0)
|
||||
for se in range(len(buf0s)):
|
||||
mask = (__SQTT_MISC:=1<<0) | (__SQTT_TIME:=1<<1) | (__SQTT_REG:=1<<2) | (__SQTT_WAVE_START:=1<<3) | (__SQTT_WAVE_END:=1<<6) \
|
||||
@@ -230,7 +224,7 @@ class AMDComputeQueue(HWQueue):
|
||||
if (se_mask >> se) & 0b1: mask |= (__SQTTINST:=1<<10) | (__SQTT_INST_PC:=1<<11) | (__SQTT_ISSUE:=1<<13)
|
||||
|
||||
with self.pred_exec(xcc_mask=1<<(se // (ses_per_xcc:=(self.dev.se_cnt // self.dev.xccs)))):
|
||||
self.set_grbm_se_sh(se % ses_per_xcc, 0)
|
||||
self.set_grbm(se=se % ses_per_xcc, sh=0)
|
||||
self.wreg(self.gc.regSQ_THREAD_TRACE_TOKEN_MASK, reg_mask=0xf, token_mask=mask)
|
||||
self.wreg(self.gc.regSQ_THREAD_TRACE_TOKEN_MASK2, inst_mask=0xffffffff)
|
||||
self.wreg(self.gc.regSQ_THREAD_TRACE_BASE, addr=lo32(buf0s[se].va_addr >> 12))
|
||||
@@ -242,7 +236,7 @@ class AMDComputeQueue(HWQueue):
|
||||
self.spi_config(tracing=True)
|
||||
# One buffer for one SE, mesa does it with a single buffer and ac_sqtt_get_data_offset, but this is simpler and should work just as well
|
||||
for se in range(len(buf0s)):
|
||||
self.set_grbm_se_sh(se, 0)
|
||||
self.set_grbm(se=se, sh=0)
|
||||
|
||||
buf0_lo, buf0_hi = data64_le(buf0s[se].va_addr >> 12)
|
||||
if self.dev.target >= (12,0,0):
|
||||
@@ -275,7 +269,7 @@ class AMDComputeQueue(HWQueue):
|
||||
**({} if self.dev.target < (12,0,0) else {'exclude_barrier_wait': 1}))
|
||||
self.sqtt_config(tracing=True)
|
||||
|
||||
self.set_grbm_broadcast()
|
||||
self.set_grbm()
|
||||
if self.dev.target[0] > 9: self.wreg(self.gc.regCOMPUTE_THREAD_TRACE_ENABLE, 1)
|
||||
self.memory_barrier()
|
||||
return self
|
||||
@@ -283,7 +277,7 @@ class AMDComputeQueue(HWQueue):
|
||||
# Magic values from src/amd/common/ac_sqtt.c:ac_sqtt_emit_stop and src/amd/common/ac_sqtt.c:ac_sqtt_emit_wait
|
||||
def sqtt_stop(self, ses:int, wptrs:HCQBuffer):
|
||||
self.memory_barrier()
|
||||
self.set_grbm_broadcast()
|
||||
self.set_grbm()
|
||||
|
||||
# Start shutting everything down
|
||||
if self.dev.target[0] == 9: self.wreg(self.gc.regSQ_THREAD_TRACE_MODE, mask_cs=1, autoflush_en=1, mode=0)
|
||||
@@ -293,7 +287,7 @@ class AMDComputeQueue(HWQueue):
|
||||
|
||||
# For each SE wait for finish to complete and copy regSQ_THREAD_TRACE_WPTR to know where in the buffer trace data ends
|
||||
for se in range(ses):
|
||||
self.set_grbm_se_sh(se, 0)
|
||||
self.set_grbm(se=se, sh=0)
|
||||
|
||||
status_reg = self.gc.regSQ_THREAD_TRACE_STATUS.addr[0] - (self.pm4.PACKET3_SET_UCONFIG_REG_START if self.dev.target[0] == 9 else 0)
|
||||
if self.dev.target >= (10, 0, 0):
|
||||
@@ -305,7 +299,7 @@ class AMDComputeQueue(HWQueue):
|
||||
# Copy WPTR to memory (src_sel = perf, dst_sel = tc_l2, wr_confirm = True)
|
||||
self.pkt3(self.pm4.PACKET3_COPY_DATA, 1 << 20 | 2 << 8 | 4, self.gc.regSQ_THREAD_TRACE_WPTR.addr[0], 0, *data64_le(wptrs.va_addr+(se*4)))
|
||||
|
||||
self.set_grbm_broadcast()
|
||||
self.set_grbm()
|
||||
if self.dev.target[0] > 9: self.spi_config(tracing=False)
|
||||
self.memory_barrier()
|
||||
return self
|
||||
|
||||
Reference in New Issue
Block a user