From be0028d3ceda95c18249a90ea46cadac6bb32ca1 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Mon, 3 Nov 2025 03:35:55 +0800 Subject: [PATCH] amd: universal set_grbm (#13062) * amd: universal set_grbm * fix --- tinygrad/runtime/ops_amd.py | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index bf9f95a9e5..5ffa85d60a 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -72,14 +72,10 @@ class AMDComputeQueue(HWQueue): if self.dev.xccs > 1: self._q[prev_len-1] |= (len(self._q) - prev_len) - def set_grbm_broadcast(self): - self.wreg(self.gc.regGRBM_GFX_INDEX, **{f'{f}_broadcast_writes': 1 for f in ['se', 'sh' if self.dev.target[0] == 9 else 'sa', 'instance']}) - def set_grbm_inst(self, n): - self.wreg(self.gc.regGRBM_GFX_INDEX, **{f'{f}_broadcast_writes': 1 for f in ['se', 'sh' if self.dev.target[0] == 9 else 'sa']}, instance_index=n) - def set_grbm_se_sh(self, se, sh): - self.wreg(self.gc.regGRBM_GFX_INDEX, se_index=se, **{f'{"sh" if self.dev.target[0] == 9 else "sa"}_index':sh}, instance_broadcast_writes=1) - def set_grbm_se_sh_wgp(self, se, sh, wgp): self.wreg(self.gc.regGRBM_GFX_INDEX, se_index=se, sa_index=sh, instance_index=wgp << 2) - def set_grbm_se(self, se): self.wreg(self.gc.regGRBM_GFX_INDEX, se_index=se, sh_broadcast_writes=1, instance_broadcast_writes=1) + def set_grbm(self, instance=None, se=None, sh=None, wgp=None): + instance_val = (wgp << 2 | (instance or 0)) if wgp is not None else instance + self.wreg(self.gc.regGRBM_GFX_INDEX, **{(f'{key}_broadcast_writes' if val is None else f'{key}_index'): (1 if val is None else val) + for key, val in [('instance', instance_val), ('se', se), ('sh' if self.dev.target[0] == 9 else 'sa', sh)]}) def wait_reg_mem(self, value, mask=0xffffffff, mem=None, reg=None, reg_done=0, op=WAIT_REG_MEM_FUNCTION_GEQ): wrm_info_dw = self.pm4.WAIT_REG_MEM_MEM_SPACE(int(mem is not None)) | self.pm4.WAIT_REG_MEM_OPERATION(int(mem is None and reg_done > 0)) \ @@ -140,7 +136,7 @@ class AMDComputeQueue(HWQueue): ### PMC ### def pmc_reset_counters(self, en=True): - self.set_grbm_broadcast() + self.set_grbm() self.wreg(self.gc.regCP_PERFMON_CNTL if self.dev.target[0] <= 11 else self.gc.regCP_PERFMON_CNTL_1, perfmon_state=0) if en: self.wreg(self.gc.regCP_PERFMON_CNTL if self.dev.target[0] <= 11 else self.gc.regCP_PERFMON_CNTL_1, perfmon_state=1) return self @@ -171,7 +167,7 @@ class AMDComputeQueue(HWQueue): return self.pmc_reset_counters(en=True) def pmc_read(self, buf, sched): - self.set_grbm_broadcast() + self.set_grbm() self.wreg(self.gc.regCP_PERFMON_CNTL if self.dev.target[0] <= 11 else self.gc.regCP_PERFMON_CNTL_1, perfmon_state=1, perfmon_sample_enable=1) for s in sched: @@ -180,9 +176,7 @@ class AMDComputeQueue(HWQueue): for xcc in range(s.xcc): with self.pred_exec(xcc_mask=1 << xcc): for inst, se_idx, sa_idx, wgp_idx in itertools.product(range(s.inst), range(s.se), range(s.sa), range(s.wgp)): - if s.inst > 1: self.set_grbm_inst(inst) - elif self.dev.target[0] == 9: self.set_grbm_se(se_idx) - else: self.set_grbm_se_sh_wgp(se_idx, sa_idx, wgp_idx) + self.set_grbm(**({'instance':inst} if s.inst > 1 else ({'se':se_idx}|({'sh':sa_idx, 'wgp':wgp_idx} if self.dev.target[0] != 9 else {})))) # Copy counter to memory (src_sel = perf, dst_sel = tc_l2) lo, hi = getattr(self.gc, f'{s.regsample}_LO'), getattr(self.gc, f'{s.regsample}_HI', None) @@ -222,7 +216,7 @@ class AMDComputeQueue(HWQueue): def sqtt_start(self, buf0s:list[HCQBuffer], se_mask:int): self.memory_barrier() if self.dev.target[0] == 9: - self.set_grbm_broadcast() + self.set_grbm() self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, simd_en=0xf, cu_sel=0, sq_stall_en=1, spi_stall_en=1, reg_stall_en=1, vm_id_mask=0) for se in range(len(buf0s)): mask = (__SQTT_MISC:=1<<0) | (__SQTT_TIME:=1<<1) | (__SQTT_REG:=1<<2) | (__SQTT_WAVE_START:=1<<3) | (__SQTT_WAVE_END:=1<<6) \ @@ -230,7 +224,7 @@ class AMDComputeQueue(HWQueue): if (se_mask >> se) & 0b1: mask |= (__SQTTINST:=1<<10) | (__SQTT_INST_PC:=1<<11) | (__SQTT_ISSUE:=1<<13) with self.pred_exec(xcc_mask=1<<(se // (ses_per_xcc:=(self.dev.se_cnt // self.dev.xccs)))): - self.set_grbm_se_sh(se % ses_per_xcc, 0) + self.set_grbm(se=se % ses_per_xcc, sh=0) self.wreg(self.gc.regSQ_THREAD_TRACE_TOKEN_MASK, reg_mask=0xf, token_mask=mask) self.wreg(self.gc.regSQ_THREAD_TRACE_TOKEN_MASK2, inst_mask=0xffffffff) self.wreg(self.gc.regSQ_THREAD_TRACE_BASE, addr=lo32(buf0s[se].va_addr >> 12)) @@ -242,7 +236,7 @@ class AMDComputeQueue(HWQueue): self.spi_config(tracing=True) # One buffer for one SE, mesa does it with a single buffer and ac_sqtt_get_data_offset, but this is simpler and should work just as well for se in range(len(buf0s)): - self.set_grbm_se_sh(se, 0) + self.set_grbm(se=se, sh=0) buf0_lo, buf0_hi = data64_le(buf0s[se].va_addr >> 12) if self.dev.target >= (12,0,0): @@ -275,7 +269,7 @@ class AMDComputeQueue(HWQueue): **({} if self.dev.target < (12,0,0) else {'exclude_barrier_wait': 1})) self.sqtt_config(tracing=True) - self.set_grbm_broadcast() + self.set_grbm() if self.dev.target[0] > 9: self.wreg(self.gc.regCOMPUTE_THREAD_TRACE_ENABLE, 1) self.memory_barrier() return self @@ -283,7 +277,7 @@ class AMDComputeQueue(HWQueue): # Magic values from src/amd/common/ac_sqtt.c:ac_sqtt_emit_stop and src/amd/common/ac_sqtt.c:ac_sqtt_emit_wait def sqtt_stop(self, ses:int, wptrs:HCQBuffer): self.memory_barrier() - self.set_grbm_broadcast() + self.set_grbm() # Start shutting everything down if self.dev.target[0] == 9: self.wreg(self.gc.regSQ_THREAD_TRACE_MODE, mask_cs=1, autoflush_en=1, mode=0) @@ -293,7 +287,7 @@ class AMDComputeQueue(HWQueue): # For each SE wait for finish to complete and copy regSQ_THREAD_TRACE_WPTR to know where in the buffer trace data ends for se in range(ses): - self.set_grbm_se_sh(se, 0) + self.set_grbm(se=se, sh=0) status_reg = self.gc.regSQ_THREAD_TRACE_STATUS.addr[0] - (self.pm4.PACKET3_SET_UCONFIG_REG_START if self.dev.target[0] == 9 else 0) if self.dev.target >= (10, 0, 0): @@ -305,7 +299,7 @@ class AMDComputeQueue(HWQueue): # Copy WPTR to memory (src_sel = perf, dst_sel = tc_l2, wr_confirm = True) self.pkt3(self.pm4.PACKET3_COPY_DATA, 1 << 20 | 2 << 8 | 4, self.gc.regSQ_THREAD_TRACE_WPTR.addr[0], 0, *data64_le(wptrs.va_addr+(se*4))) - self.set_grbm_broadcast() + self.set_grbm() if self.dev.target[0] > 9: self.spi_config(tracing=False) self.memory_barrier() return self