mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
amd: refactor SQTT into separate functions (#9816)
* amd: refactor SQTT into separate functions
* fix
This commit is contained in:
@@ -59,11 +59,6 @@ class AMDComputeQueue(HWQueue):
|
||||
if self.dev.xccs > 1:
|
||||
self._q[prev_len-1] |= (len(self._q) - prev_len)
|
||||
|
||||
def sqtt_userdata(self, data, *extra_dwords):
  """Write `data` followed by `extra_dwords` to SQ_THREAD_TRACE_USERDATA_2.

  `data` is converted with `bytes()` and decoded as little-endian 32-bit words, so its
  byte length must be a multiple of 4 (`struct.iter_unpack` raises `struct.error` otherwise).
  `extra_dwords` are appended after the decoded words unchanged.
  """
  data_ints = [x[0] for x in struct.iter_unpack('<I', bytes(data))] + list(extra_dwords)
  # Each register write carries at most two dwords; an odd total leaves a single dword in the last write.
  for i in range(0, len(data_ints), 2):
    self.wreg(self.gc.regSQ_THREAD_TRACE_USERDATA_2, *data_ints[i:i+2])
|
||||
|
||||
def wait_reg_mem(self, value, mask=0xffffffff, mem=None, reg_req=None, reg_done=None):
|
||||
wrm_info_dw = self.pm4.WAIT_REG_MEM_MEM_SPACE(int(mem is not None)) | self.pm4.WAIT_REG_MEM_OPERATION(int(mem is None)) \
|
||||
| self.pm4.WAIT_REG_MEM_FUNCTION(WAIT_REG_MEM_FUNCTION_GEQ) | self.pm4.WAIT_REG_MEM_ENGINE(0)
|
||||
@@ -138,12 +133,19 @@ class AMDComputeQueue(HWQueue):
|
||||
self.wreg(self.gc.regSPI_CONFIG_CNTL, ps_pkr_priority_cntl=3, exp_priority_order=3, gpr_write_priority=0x2c688,
|
||||
enable_sqg_bop_events=int(tracing), enable_sqg_top_events=int(tracing))
|
||||
|
||||
### SQTT ###
|
||||
|
||||
def sqtt_userdata(self, data, *extra_dwords):
  """Emit `data` plus any `extra_dwords` through the SQ_THREAD_TRACE_USERDATA_2 register.

  `data` is passed through `bytes()` and unpacked as little-endian 32-bit words; a byte
  length that is not a multiple of 4 makes `struct.iter_unpack` raise `struct.error`.
  The values in `extra_dwords` are appended as-is after the unpacked words.
  """
  data_ints = [x[0] for x in struct.iter_unpack('<I', bytes(data))] + list(extra_dwords)
  # Words are written in pairs; with an odd count the final write carries one dword.
  for i in range(0, len(data_ints), 2):
    self.wreg(self.gc.regSQ_THREAD_TRACE_USERDATA_2, *data_ints[i:i+2])
|
||||
|
||||
def sqtt_config(self, tracing:bool):
  """Program SQ_THREAD_TRACE_CTRL; `mode` is set to `int(tracing)`, all other fields are fixed."""
  self.wreg(self.gc.regSQ_THREAD_TRACE_CTRL, draw_event_en=1, spi_stall_en=1, sq_stall_en=1, reg_at_hwm=2, hiwater=1,
            rt_freq=self.soc.SQ_TT_RT_FREQ_4096_CLK, util_timer=self.soc.SQ_TT_UTIL_TIMER_250_CLK, mode=int(tracing))
|
||||
|
||||
# Magic values from mesa/src/amd/vulkan/radv_sqtt.c:radv_emit_spi_config_cntl and src/amd/common/ac_sqtt.c:ac_sqtt_emit_start
|
||||
def start_trace(self, buf0s:list[HCQBuffer], se_mask:int):
|
||||
def sqtt_start(self, buf0s:list[HCQBuffer], se_mask:int):
|
||||
self.memory_barrier()
|
||||
self.spi_config(tracing=True)
|
||||
# One buffer for one SE, mesa does it with a single buffer and ac_sqtt_get_data_offset, but this is simpler and should work just as well
|
||||
@@ -176,7 +178,7 @@ class AMDComputeQueue(HWQueue):
|
||||
return self
|
||||
|
||||
# Magic values from src/amd/common/ac_sqtt.c:ac_sqtt_emit_stop and src/amd/common/ac_sqtt.c:ac_sqtt_emit_wait
|
||||
def stop_trace(self, ses: int, wptrs: HCQBuffer):
|
||||
def sqtt_stop(self, ses: int, wptrs: HCQBuffer):
|
||||
self.memory_barrier()
|
||||
# Start shutting everything down
|
||||
self.wreg(self.gc.regCOMPUTE_THREAD_TRACE_ENABLE, 0)
|
||||
@@ -203,18 +205,33 @@ class AMDComputeQueue(HWQueue):
|
||||
self.memory_barrier()
|
||||
return self
|
||||
|
||||
def sqtt_prg_marker(self, prg:AMDProgram, global_size:tuple[sint, ...]):
  """Emit the SQTT user-data markers for one dispatch of `prg` and advance the device command id.

  Two markers are written via `sqtt_userdata`, in order:
    1. a pipeline-bind marker whose `api_pso_hash` is `data64_le(prg.libhash[0])`;
    2. an event marker with `has_thread_dims=1` and `cmd_id=prg.dev.cmd_id`, followed by the
       values of `global_size` as extra dwords.
  `prg.dev.cmd_id` is incremented afterwards.
  """
  BIND_POINT_COMPUTE = 1

  self.sqtt_userdata(sqtt.struct_rgp_sqtt_marker_pipeline_bind(
    _0=sqtt.union_rgp_sqtt_marker_pipeline_bind_0(_0=sqtt.struct_rgp_sqtt_marker_pipeline_bind_0_0(
      identifier=sqtt.RGP_SQTT_MARKER_IDENTIFIER_BIND_PIPELINE, bind_point=BIND_POINT_COMPUTE)),
    _1=sqtt.union_rgp_sqtt_marker_pipeline_bind_1(api_pso_hash=data64_le(prg.libhash[0]))))

  self.sqtt_userdata(sqtt.struct_rgp_sqtt_marker_event(
    _0=sqtt.union_rgp_sqtt_marker_event_0(_0=sqtt.struct_rgp_sqtt_marker_event_0_0(has_thread_dims=1)),
    _2=sqtt.union_rgp_sqtt_marker_event_2(cmd_id=prg.dev.cmd_id)), *global_size)

  prg.dev.cmd_id += 1
|
||||
|
||||
def exec(self, prg:AMDProgram, args_state:CLikeArgsState, global_size:tuple[sint, ...], local_size:tuple[sint, ...]):
|
||||
self.bind_args_state(args_state)
|
||||
|
||||
self.acquire_mem(gli=0, gl2=0)
|
||||
|
||||
user_regs = []
|
||||
if prg.enable_private_segment_sgpr:
|
||||
assert self.dev.xccs == 1, "Only architected flat scratch is suppored on multi-xcc"
|
||||
scratch_hilo = data64_le(prg.dev.scratch.va_addr)
|
||||
# sgpr word1 bit31 enables swizzle
|
||||
# sgpr word3 = 0x14 << 12 | 2 << 28 | 2 << 21 | 1 << 23
|
||||
user_regs = [scratch_hilo[0], scratch_hilo[1] | 1 << 31, 0xffffffff, 0x20c14000] if prg.enable_private_segment_sgpr else []
|
||||
else: user_regs = []
|
||||
user_regs = [scratch_hilo[0], scratch_hilo[1] | 1 << 31, 0xffffffff, 0x20c14000]
|
||||
|
||||
if prg.enable_dispatch_ptr:
|
||||
dp = hsa.hsa_kernel_dispatch_packet_t.from_address(dp_addr:=args_state.ptr + prg.kernargs_segment_size)
|
||||
|
||||
@@ -225,45 +242,38 @@ class AMDComputeQueue(HWQueue):
|
||||
|
||||
user_regs += [*data64_le(args_state.ptr)]
|
||||
|
||||
if prg.dev.sqtt_enabled:
|
||||
self.sqtt_userdata(sqtt.struct_rgp_sqtt_marker_pipeline_bind(
|
||||
_0=sqtt.union_rgp_sqtt_marker_pipeline_bind_0(_0=sqtt.struct_rgp_sqtt_marker_pipeline_bind_0_0(
|
||||
identifier=sqtt.RGP_SQTT_MARKER_IDENTIFIER_BIND_PIPELINE,
|
||||
bind_point=1, # compute
|
||||
)),
|
||||
_1=sqtt.union_rgp_sqtt_marker_pipeline_bind_1(api_pso_hash=data64_le(prg.libhash[0])),
|
||||
))
|
||||
self.sqtt_userdata(sqtt.struct_rgp_sqtt_marker_event(
|
||||
_0=sqtt.union_rgp_sqtt_marker_event_0(_0=sqtt.struct_rgp_sqtt_marker_event_0_0(has_thread_dims=1)),
|
||||
_2=sqtt.union_rgp_sqtt_marker_event_2(cmd_id=prg.dev.cmd_id),
|
||||
), *global_size)
|
||||
prg.dev.cmd_id += 1
|
||||
if prg.dev.sqtt_enabled: self.sqtt_prg_marker(prg, global_size)
|
||||
|
||||
self.wreg(self.gc.regCOMPUTE_PGM_LO, *data64_le(prg.prog_addr >> 8))
|
||||
self.wreg(self.gc.regCOMPUTE_PGM_RSRC1, prg.rsrc1, prg.rsrc2)
|
||||
self.wreg(self.gc.regCOMPUTE_PGM_RSRC3, prg.rsrc3)
|
||||
self.wreg(self.gc.regCOMPUTE_TMPRING_SIZE, prg.dev.tmpring_size)
|
||||
|
||||
if prg.dev.has_scratch_base_registers:
|
||||
for xcc_id in range(self.dev.xccs):
|
||||
with self.pred_exec(xcc_mask=1<<xcc_id):
|
||||
scratch_base = prg.dev.scratch.va_addr + (prg.dev.scratch.size // self.dev.xccs * xcc_id)
|
||||
self.wreg(self.gc.regCOMPUTE_DISPATCH_SCRATCH_BASE_LO, *data64_le(scratch_base >> 8))
|
||||
|
||||
if (10,0,0) <= prg.dev.target < (11,0,0): self.wreg(self.gc.mmCP_COHER_START_DELAY, 0x20)
|
||||
|
||||
self.wreg(self.gc.regCOMPUTE_RESTART_X, 0, 0, 0)
|
||||
self.wreg(self.gc.regCOMPUTE_STATIC_THREAD_MGMT_SE0, 0xFFFFFFFF, 0xFFFFFFFF)
|
||||
self.wreg(self.gc.regCOMPUTE_STATIC_THREAD_MGMT_SE2, 0xFFFFFFFF, 0xFFFFFFFF)
|
||||
if prg.dev.target >= (11,0,0):
|
||||
self.wreg(self.gc.regCOMPUTE_STATIC_THREAD_MGMT_SE4, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF)
|
||||
if prg.dev.target >= (11,0,0): self.wreg(self.gc.regCOMPUTE_STATIC_THREAD_MGMT_SE4, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF)
|
||||
|
||||
self.wreg(self.gc.regCOMPUTE_USER_DATA_0, *user_regs)
|
||||
self.wreg(self.gc.regCOMPUTE_RESOURCE_LIMITS, 0)
|
||||
|
||||
self.wreg(self.gc.regCOMPUTE_START_X, 0, 0, 0, *local_size, 0, 0)
|
||||
self.wreg(self.gc.regCOMPUTE_RESOURCE_LIMITS, 0)
|
||||
|
||||
gfx10p = {'cs_w32_en': int(prg.wave32)} if prg.dev.target >= (10,0,0) else {}
|
||||
DISPATCH_INITIATOR = self.gc.regCOMPUTE_DISPATCH_INITIATOR.encode(**gfx10p, force_start_at_000=1, compute_shader_en=1)
|
||||
self.pkt3(self.pm4.PACKET3_DISPATCH_DIRECT, *global_size, DISPATCH_INITIATOR)
|
||||
|
||||
if prg.dev.sqtt_enabled: self.pkt3(self.pm4.PACKET3_EVENT_WRITE, self.pm4.EVENT_TYPE(self.soc.THREAD_TRACE_MARKER) | self.pm4.EVENT_INDEX(0))
|
||||
self.pkt3(self.pm4.PACKET3_EVENT_WRITE, self.pm4.EVENT_TYPE(self.soc.CS_PARTIAL_FLUSH) | self.pm4.EVENT_INDEX(EVENT_INDEX_PARTIAL_FLUSH))
|
||||
|
||||
if self.dev.xccs > 1:
|
||||
self.release_mem(cache_flush=True)
|
||||
self.acquire_mem(gli=0)
|
||||
@@ -850,7 +860,7 @@ class AMDDevice(HCQCompiled):
|
||||
self.sqtt_buffers = [self.allocator.alloc(SQTT_BUFFER_SIZE*1024*1024, BufferSpec(cpu_access=True, nolru=True)) for _ in range(SQTT_NUM)]
|
||||
self.sqtt_itrace_se_mask = getenv("SQTT_ITRACE_SE_MASK", 2) # -1 enable all, 0 disable all, >0 bitmask for where to enable instruction tracing
|
||||
self.cmd_id = 0
|
||||
AMDComputeQueue(self).start_trace(self.sqtt_buffers, self.sqtt_itrace_se_mask).submit(self)
|
||||
AMDComputeQueue(self).sqtt_start(self.sqtt_buffers, self.sqtt_itrace_se_mask).submit(self)
|
||||
|
||||
def create_queue(self, queue_type, ring_size, ctx_save_restore_size=0, eop_buffer_size=0, ctl_stack_size=0, debug_memory_size=0):
|
||||
ring = self.dev_iface.alloc(ring_size, uncached=True, cpu_access=True)
|
||||
@@ -888,7 +898,7 @@ class AMDDevice(HCQCompiled):
|
||||
if self.sqtt_enabled:
|
||||
wptrs_buf = self.allocator.alloc(round_up(len(self.sqtt_buffers), 0x1000), BufferSpec(cpu_access=True, nolru=True))
|
||||
wptrs = to_mv(wptrs_buf.va_addr, wptrs_buf.size)
|
||||
AMDComputeQueue(self).stop_trace(len(self.sqtt_buffers), wptrs_buf).signal(self.timeline_signal, self.next_timeline()).submit(self)
|
||||
AMDComputeQueue(self).sqtt_stop(len(self.sqtt_buffers), wptrs_buf).signal(self.timeline_signal, self.next_timeline()).submit(self)
|
||||
self.synchronize()
|
||||
if DEBUG>=2: print('Saving SQTT in profile...')
|
||||
for i,buf0 in enumerate(self.sqtt_buffers):
|
||||
|
||||
Reference in New Issue
Block a user