diff --git a/extra/sqtt/roc.py b/extra/sqtt/roc.py index 5fec24c2a8..bab27a0a19 100644 --- a/extra/sqtt/roc.py +++ b/extra/sqtt/roc.py @@ -1,4 +1,5 @@ import ctypes, pathlib, argparse, pickle, re, functools, dataclasses, itertools, threading +from typing import Generator from tinygrad.helpers import temp, unwrap, DEBUG from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileProgramEvent from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent @@ -31,7 +32,7 @@ def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]: @dataclasses.dataclass(frozen=True) class InstExec: typ:str - inst:str + pc:int stall:int dur:int time:int @@ -44,7 +45,13 @@ class WaveExec: se:int begin_time:int end_time:int - insts:list[InstExec] + insts:bytearray + def unpack_insts(self) -> Generator[InstExec, None, None]: + sz = ctypes.sizeof(struct:=rocprof.rocprofiler_thread_trace_decoder_inst_t) + insts_array = (struct*(len(self.insts)//sz)).from_buffer(self.insts) + for inst in insts_array: + inst_typ = rocprof.enum_rocprofiler_thread_trace_decoder_inst_category_t.get(inst.category) + yield InstExec(inst_typ, inst.pc.address, inst.stall, inst.duration, inst.time) class _ROCParseCtx: def __init__(self, dev_evs:dict[str, ProfileDeviceEvent], sqtt_evs:list[ProfileSQTTEvent], prog_evs:list[ProfileProgramEvent]): @@ -70,18 +77,12 @@ class _ROCParseCtx: def on_wave_ev(self, ev:rocprof.rocprofiler_thread_trace_decoder_wave_t): if DEBUG >= 5: print("WAVE", ev.wave_id, self.active_se, ev.cu, ev.simd, ev.contexts, ev.begin_time, ev.end_time) - inst_execs:list[InstExec] = [] - disasm = self.disasms[unwrap(self.active_kern)] - for j in range(ev.instructions_size): - inst_ev = ev.instructions_array[j] - inst_typ = rocprof.enum_rocprofiler_thread_trace_decoder_inst_category_t.get(inst_ev.category) - inst_disasm = disasm[unwrap(inst_ev.pc.address)][0] - inst_execs.append(InstExec(inst_typ, inst_disasm, inst_ev.stall, inst_ev.duration, inst_ev.time)) - if DEBUG >= 8: print(inst_execs[-1]) + insts_blob = bytearray(sz:=ev.instructions_size * ctypes.sizeof(rocprof.rocprofiler_thread_trace_decoder_inst_t)) + ctypes.memmove((ctypes.c_char * sz).from_buffer(insts_blob), ev.instructions_array, sz) if ev.instructions_size > 0: self.inst_execs.setdefault(unwrap(self.active_kern), []).append(WaveExec(ev.wave_id, ev.cu, ev.simd, unwrap(self.active_se), ev.begin_time, - ev.end_time, inst_execs)) + ev.end_time, insts_blob)) def decode(profile:list[ProfileEvent]) -> _ROCParseCtx: dev_events:dict[str, ProfileDeviceEvent] = {} diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 6caaba7e99..329946fc80 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -225,13 +225,7 @@ def load_sqtt(profile:list[ProfileEvent]) -> None: if not rctx.inst_execs: return err("EMPTY SQTT OUTPUT", f"{len(sqtt_events)} SQTT events recorded, none got decoded") steps:list[dict] = [] for name,waves in rctx.inst_execs.items(): - # Idle: The total time gap between the completion of previous instruction and the beginning of the current instruction. - # The idle time can be caused by: - # * Arbiter loss - # * Source or destination register dependency - # * Instruction cache miss - # Stall: The total number of cycles the hardware pipe couldn't issue an instruction. - # Duration: Total latency in cycles, defined as "Stall time + Issue time" for gfx9 or "Stall time + Execute time" for gfx10+. + disasm = rctx.disasms[name] units:dict[str, int] = {} events:list[ProfileEvent] = [] wave_execs:dict[str, dict] = {} @@ -239,19 +233,13 @@ def load_sqtt(profile:list[ProfileEvent]) -> None: if (row:=f"SE:{w.se} CU:{w.cu} SIMD:{w.simd} WAVE:{w.wave_id}") not in units: units[row] = 0 units[row] += 1 events.append(ProfileRangeEvent(row, f"N:{units[row]}", Decimal(w.begin_time), Decimal(w.end_time))) - rows, prev_instr = [], w.begin_time - for i,e in enumerate(w.insts): - rows.append((e.inst, e.time, max(0, e.time-prev_instr), e.dur, e.stall, str(e.typ).split("_")[-1])) - prev_instr = max(prev_instr, e.time + e.dur) - summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SE", "value":w.se}, {"label":"CU", "value":w.cu}, - {"label":"SIMD", "value":w.simd}, {"label":"Wave ID", "value":w.wave_id}, {"label":"Run number", "value":units[row]}] - wave_execs[f"{row} N:{units[row]}"] = {"rows":rows, "cols":["Instruction", "Clk", "Idle", "Duration", "Stall", "Type"], "summary":summary} + wave_execs[f"{row} N:{units[row]}"] = {"wave":w, "disasm":disasm, "run_number":units[row]} # gather and sort all wave execs of this kernel events = [ProfilePointEvent(unit, "start", unit, ts=Decimal(0)) for unit in units]+events kernel = trace.keys[r].ret if (r:=ref_map.get(name)) else None steps.append(create_step(kernel.name if kernel is not None else name, ("/counters", len(ctxs), len(steps)), {"value":get_profile(events, sort_fn=row_tuple), "content_type":"application/octet-stream"}, depth=1)) - for k in sorted(wave_execs, key=row_tuple): steps.append(create_step(k, ("/counters", len(ctxs), len(steps)), wave_execs[k], depth=2)) + for k in sorted(wave_execs, key=row_tuple): steps.append(create_step(k, ("/sqtt-insts", len(ctxs), len(steps)), wave_execs[k], depth=2)) ctxs.append({"name":"Counters", "steps":steps}) def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]|None=None) -> bytes|None: @@ -330,6 +318,24 @@ def get_render(i:int, j:int, fmt:str) -> dict: return get_llvm_mca(disasm_str, ctypes.string_at(llvm.LLVMGetTargetMachineTriple(tm:=compiler.target_machine)).decode(), ctypes.string_at(llvm.LLVMGetTargetMachineCPU(tm)).decode()) return {"src":disasm_str, "lang":"x86asm"} + if fmt == "sqtt-insts": + columns = ["Instruction", "Clk", "Idle", "Duration", "Stall", "Type"] + # Idle: The total time gap between the completion of previous instruction and the beginning of the current instruction. + # The idle time can be caused by: + # * Arbiter loss + # * Source or destination register dependency + # * Instruction cache miss + # Stall: The total number of cycles the hardware pipe couldn't issue an instruction. + # Duration: Total latency in cycles, defined as "Stall time + Issue time" for gfx9 or "Stall time + Execute time" for gfx10+. + prev_instr = (w:=data["wave"]).begin_time + pc_to_inst = data["disasm"] + rows:list[tuple] = [] + for e in w.unpack_insts(): + rows.append((pc_to_inst[e.pc][0], e.time, max(0, e.time-prev_instr), e.dur, e.stall, str(e.typ).split("_")[-1])) + prev_instr = max(prev_instr, e.time + e.dur) + summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SE", "value":w.se}, {"label":"CU", "value":w.cu}, + {"label":"SIMD", "value":w.simd}, {"label":"Wave ID", "value":w.wave_id}, {"label":"Run number", "value":data["run_number"]}] + return {"rows":rows, "cols":columns, "summary":summary} return data # ** HTTP server