amd: refactor sqtt into sep functions (#9816)

* amd: refactor sqtt into sep functions

* fix
This commit is contained in:
nimlgen
2025-04-10 00:39:45 +03:00
committed by GitHub
parent 0ca98b9f20
commit d7330ea6ad

View File

@@ -59,11 +59,6 @@ class AMDComputeQueue(HWQueue):
if self.dev.xccs > 1:
self._q[prev_len-1] |= (len(self._q) - prev_len)
def sqtt_userdata(self, data, *extra_dwords):
data_ints = [x[0] for x in struct.iter_unpack('<I', bytes(data))] + list(extra_dwords)
for i in range(0, len(data_ints), 2):
self.wreg(self.gc.regSQ_THREAD_TRACE_USERDATA_2, *data_ints[i:i+2])
def wait_reg_mem(self, value, mask=0xffffffff, mem=None, reg_req=None, reg_done=None):
wrm_info_dw = self.pm4.WAIT_REG_MEM_MEM_SPACE(int(mem is not None)) | self.pm4.WAIT_REG_MEM_OPERATION(int(mem is None)) \
| self.pm4.WAIT_REG_MEM_FUNCTION(WAIT_REG_MEM_FUNCTION_GEQ) | self.pm4.WAIT_REG_MEM_ENGINE(0)
@@ -138,12 +133,19 @@ class AMDComputeQueue(HWQueue):
self.wreg(self.gc.regSPI_CONFIG_CNTL, ps_pkr_priority_cntl=3, exp_priority_order=3, gpr_write_priority=0x2c688,
enable_sqg_bop_events=int(tracing), enable_sqg_top_events=int(tracing))
### SQTT ###
def sqtt_userdata(self, data, *extra_dwords):
data_ints = [x[0] for x in struct.iter_unpack('<I', bytes(data))] + list(extra_dwords)
for i in range(0, len(data_ints), 2):
self.wreg(self.gc.regSQ_THREAD_TRACE_USERDATA_2, *data_ints[i:i+2])
def sqtt_config(self, tracing:bool):
self.wreg(self.gc.regSQ_THREAD_TRACE_CTRL, draw_event_en=1, spi_stall_en=1, sq_stall_en=1, reg_at_hwm=2, hiwater=1,
rt_freq=self.soc.SQ_TT_RT_FREQ_4096_CLK, util_timer=self.soc.SQ_TT_UTIL_TIMER_250_CLK, mode=int(tracing))
# Magic values from mesa/src/amd/vulkan/radv_sqtt.c:radv_emit_spi_config_cntl and src/amd/common/ac_sqtt.c:ac_sqtt_emit_start
def start_trace(self, buf0s:list[HCQBuffer], se_mask:int):
def sqtt_start(self, buf0s:list[HCQBuffer], se_mask:int):
self.memory_barrier()
self.spi_config(tracing=True)
# One buffer for one SE, mesa does it with a single buffer and ac_sqtt_get_data_offset, but this is simpler and should work just as well
@@ -176,7 +178,7 @@ class AMDComputeQueue(HWQueue):
return self
# Magic values from src/amd/common/ac_sqtt.c:ac_sqtt_emit_stop and src/amd/common/ac_sqtt.c:ac_sqtt_emit_wait
def stop_trace(self, ses: int, wptrs: HCQBuffer):
def sqtt_stop(self, ses: int, wptrs: HCQBuffer):
self.memory_barrier()
# Start shutting everything down
self.wreg(self.gc.regCOMPUTE_THREAD_TRACE_ENABLE, 0)
@@ -203,18 +205,33 @@ class AMDComputeQueue(HWQueue):
self.memory_barrier()
return self
def sqtt_prg_marker(self, prg:AMDProgram, global_size:tuple[sint, ...]):
BIND_POINT_COMPUTE = 1
self.sqtt_userdata(sqtt.struct_rgp_sqtt_marker_pipeline_bind(
_0=sqtt.union_rgp_sqtt_marker_pipeline_bind_0(_0=sqtt.struct_rgp_sqtt_marker_pipeline_bind_0_0(
identifier=sqtt.RGP_SQTT_MARKER_IDENTIFIER_BIND_PIPELINE, bind_point=BIND_POINT_COMPUTE)),
_1=sqtt.union_rgp_sqtt_marker_pipeline_bind_1(api_pso_hash=data64_le(prg.libhash[0]))))
self.sqtt_userdata(sqtt.struct_rgp_sqtt_marker_event(
_0=sqtt.union_rgp_sqtt_marker_event_0(_0=sqtt.struct_rgp_sqtt_marker_event_0_0(has_thread_dims=1)),
_2=sqtt.union_rgp_sqtt_marker_event_2(cmd_id=prg.dev.cmd_id)), *global_size)
prg.dev.cmd_id += 1
def exec(self, prg:AMDProgram, args_state:CLikeArgsState, global_size:tuple[sint, ...], local_size:tuple[sint, ...]):
self.bind_args_state(args_state)
self.acquire_mem(gli=0, gl2=0)
user_regs = []
if prg.enable_private_segment_sgpr:
assert self.dev.xccs == 1, "Only architected flat scratch is suppored on multi-xcc"
scratch_hilo = data64_le(prg.dev.scratch.va_addr)
# sgpr word1 bit31 enables swizzle
# sgpr word3 = 0x14 << 12 | 2 << 28 | 2 << 21 | 1 << 23
user_regs = [scratch_hilo[0], scratch_hilo[1] | 1 << 31, 0xffffffff, 0x20c14000] if prg.enable_private_segment_sgpr else []
else: user_regs = []
user_regs = [scratch_hilo[0], scratch_hilo[1] | 1 << 31, 0xffffffff, 0x20c14000]
if prg.enable_dispatch_ptr:
dp = hsa.hsa_kernel_dispatch_packet_t.from_address(dp_addr:=args_state.ptr + prg.kernargs_segment_size)
@@ -225,45 +242,38 @@ class AMDComputeQueue(HWQueue):
user_regs += [*data64_le(args_state.ptr)]
if prg.dev.sqtt_enabled:
self.sqtt_userdata(sqtt.struct_rgp_sqtt_marker_pipeline_bind(
_0=sqtt.union_rgp_sqtt_marker_pipeline_bind_0(_0=sqtt.struct_rgp_sqtt_marker_pipeline_bind_0_0(
identifier=sqtt.RGP_SQTT_MARKER_IDENTIFIER_BIND_PIPELINE,
bind_point=1, # compute
)),
_1=sqtt.union_rgp_sqtt_marker_pipeline_bind_1(api_pso_hash=data64_le(prg.libhash[0])),
))
self.sqtt_userdata(sqtt.struct_rgp_sqtt_marker_event(
_0=sqtt.union_rgp_sqtt_marker_event_0(_0=sqtt.struct_rgp_sqtt_marker_event_0_0(has_thread_dims=1)),
_2=sqtt.union_rgp_sqtt_marker_event_2(cmd_id=prg.dev.cmd_id),
), *global_size)
prg.dev.cmd_id += 1
if prg.dev.sqtt_enabled: self.sqtt_prg_marker(prg, global_size)
self.wreg(self.gc.regCOMPUTE_PGM_LO, *data64_le(prg.prog_addr >> 8))
self.wreg(self.gc.regCOMPUTE_PGM_RSRC1, prg.rsrc1, prg.rsrc2)
self.wreg(self.gc.regCOMPUTE_PGM_RSRC3, prg.rsrc3)
self.wreg(self.gc.regCOMPUTE_TMPRING_SIZE, prg.dev.tmpring_size)
if prg.dev.has_scratch_base_registers:
for xcc_id in range(self.dev.xccs):
with self.pred_exec(xcc_mask=1<<xcc_id):
scratch_base = prg.dev.scratch.va_addr + (prg.dev.scratch.size // self.dev.xccs * xcc_id)
self.wreg(self.gc.regCOMPUTE_DISPATCH_SCRATCH_BASE_LO, *data64_le(scratch_base >> 8))
if (10,0,0) <= prg.dev.target < (11,0,0): self.wreg(self.gc.mmCP_COHER_START_DELAY, 0x20)
self.wreg(self.gc.regCOMPUTE_RESTART_X, 0, 0, 0)
self.wreg(self.gc.regCOMPUTE_STATIC_THREAD_MGMT_SE0, 0xFFFFFFFF, 0xFFFFFFFF)
self.wreg(self.gc.regCOMPUTE_STATIC_THREAD_MGMT_SE2, 0xFFFFFFFF, 0xFFFFFFFF)
if prg.dev.target >= (11,0,0):
self.wreg(self.gc.regCOMPUTE_STATIC_THREAD_MGMT_SE4, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF)
if prg.dev.target >= (11,0,0): self.wreg(self.gc.regCOMPUTE_STATIC_THREAD_MGMT_SE4, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF)
self.wreg(self.gc.regCOMPUTE_USER_DATA_0, *user_regs)
self.wreg(self.gc.regCOMPUTE_RESOURCE_LIMITS, 0)
self.wreg(self.gc.regCOMPUTE_START_X, 0, 0, 0, *local_size, 0, 0)
self.wreg(self.gc.regCOMPUTE_RESOURCE_LIMITS, 0)
gfx10p = {'cs_w32_en': int(prg.wave32)} if prg.dev.target >= (10,0,0) else {}
DISPATCH_INITIATOR = self.gc.regCOMPUTE_DISPATCH_INITIATOR.encode(**gfx10p, force_start_at_000=1, compute_shader_en=1)
self.pkt3(self.pm4.PACKET3_DISPATCH_DIRECT, *global_size, DISPATCH_INITIATOR)
if prg.dev.sqtt_enabled: self.pkt3(self.pm4.PACKET3_EVENT_WRITE, self.pm4.EVENT_TYPE(self.soc.THREAD_TRACE_MARKER) | self.pm4.EVENT_INDEX(0))
self.pkt3(self.pm4.PACKET3_EVENT_WRITE, self.pm4.EVENT_TYPE(self.soc.CS_PARTIAL_FLUSH) | self.pm4.EVENT_INDEX(EVENT_INDEX_PARTIAL_FLUSH))
if self.dev.xccs > 1:
self.release_mem(cache_flush=True)
self.acquire_mem(gli=0)
@@ -850,7 +860,7 @@ class AMDDevice(HCQCompiled):
self.sqtt_buffers = [self.allocator.alloc(SQTT_BUFFER_SIZE*1024*1024, BufferSpec(cpu_access=True, nolru=True)) for _ in range(SQTT_NUM)]
self.sqtt_itrace_se_mask = getenv("SQTT_ITRACE_SE_MASK", 2) # -1 enable all, 0 disable all, >0 bitmask for where to enable instruction tracing
self.cmd_id = 0
AMDComputeQueue(self).start_trace(self.sqtt_buffers, self.sqtt_itrace_se_mask).submit(self)
AMDComputeQueue(self).sqtt_start(self.sqtt_buffers, self.sqtt_itrace_se_mask).submit(self)
def create_queue(self, queue_type, ring_size, ctx_save_restore_size=0, eop_buffer_size=0, ctl_stack_size=0, debug_memory_size=0):
ring = self.dev_iface.alloc(ring_size, uncached=True, cpu_access=True)
@@ -888,7 +898,7 @@ class AMDDevice(HCQCompiled):
if self.sqtt_enabled:
wptrs_buf = self.allocator.alloc(round_up(len(self.sqtt_buffers), 0x1000), BufferSpec(cpu_access=True, nolru=True))
wptrs = to_mv(wptrs_buf.va_addr, wptrs_buf.size)
AMDComputeQueue(self).stop_trace(len(self.sqtt_buffers), wptrs_buf).signal(self.timeline_signal, self.next_timeline()).submit(self)
AMDComputeQueue(self).sqtt_stop(len(self.sqtt_buffers), wptrs_buf).signal(self.timeline_signal, self.next_timeline()).submit(self)
self.synchronize()
if DEBUG>=2: print('Saving SQTT in profile...')
for i,buf0 in enumerate(self.sqtt_buffers):