# test to compare every packet with the rocprof decoder
import unittest, pickle, contextlib, io
from typing import Iterator
from pathlib import Path
from tinygrad.helpers import DEBUG, getenv, temp, ansistrip
from tinygrad.renderer.amd.sqtt import print_packets, map_insts
from tinygrad.runtime.autogen.amd.rdna3.ins import s_endpgm
from tinygrad.viz.serve import sqtt_timeline
from test.amd.disasm import disasm

import tinygrad
EXAMPLES_DIR = Path(tinygrad.__file__).parent.parent / "extra/sqtt/examples"

def run_cli(*cli_args) -> str:
  """Run the viz CLI with `cli_args` as its argv and return everything it printed to stdout, stripped."""
  from extra.viz.cli import main, get_arg_parser
  args = get_arg_parser().parse_args(cli_args)
  with contextlib.redirect_stdout(buf:=io.StringIO()):
    main(args)
  return buf.getvalue().strip()

def rocprof_inst_traces_match(sqtt, prg, target):
  """Check every instruction `map_insts` finds in `sqtt.blob` against the rocprof decode of the same trace.

  Returns `(passed_insts, n_waves, n_units)`: the number of instructions that matched, the number of waves rocprof
  reported for this kernel execution, and the number of distinct `wave_id` values among them.
  Returns `(0, 0, 0)` when rocprof reports no waves. Raises AssertionError on the first mismatch.
  """
  from tinygrad.viz.serve import amd_decode
  from extra.sqtt.roc import decode as roc_decode, InstExec
  addr_table = amd_decode(prg.lib, target)
  # rocprof pcs are compared after subtracting prg.base, so key the disassembly by addr+prg.base
  disasm_map = {addr+prg.base:inst for addr,inst in addr_table.items()}
  rctx = roc_decode([sqtt], {prg.tag:disasm_map})
  rwaves = rctx.inst_execs.get((sqtt.kern, sqtt.exec_tag), [])
  if not rwaves: return 0, 0, 0

  # wave unit (0-15) -> list of inst trace iterators for all executions on that unit
  rwaves_iter:dict[int, list[Iterator[InstExec]]] = {}
  for w in rwaves: rwaves_iter.setdefault(w.wave_id, []).append(w.unpack_insts())

  passed_insts = 0
  for pkt, info in map_insts(sqtt.blob, prg.lib, target):
    if DEBUG >= 2: print_packets([(pkt, info)])
    # packets with no instruction info have nothing to compare
    if info is None: continue
    if DEBUG >= 2: print(f"{' '*29}{disasm(info.inst)}")
    # the first iterator in the list is the execution currently being consumed on this wave unit
    rocprof_inst = next(rwaves_iter[info.wave][0])
    ref_pc = rocprof_inst.pc-prg.base
    # always check pc matches
    assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm_map[rocprof_inst.pc]} != {info.pc}:{disasm(info.inst)}"
    # special handling for s_endpgm, it marks the wave completion.
    if info.inst == s_endpgm():
      # drop the finished execution; it must have no instructions left
      completed_wave = list(rwaves_iter[info.wave].pop(0))
      assert len(completed_wave) == 0, f"incomplete instructions in wave {info.wave}"
    # otherwise the packet timestamp is time + "stall"
    else:
      assert pkt._time == rocprof_inst.time+rocprof_inst.stall
    passed_insts += 1

  # every execution on every wave unit must have been consumed by an s_endpgm
  for k,v in rwaves_iter.items():
    assert len(v) == 0, f"incomplete wave {k}"

  return passed_insts, len(rwaves), len(rwaves_iter)

class TestSQTTMapBase(unittest.TestCase):
  """Shared SQTT tests run against pickled profiles; subclasses set `target`. The base class itself is skipped."""
  # architecture string passed to the decoders, also the name of the examples subdirectory
  target: str
  # example name (pickle file stem) -> (sqtt events, {program tag: program event}, target)
  examples: dict

  @classmethod
  def setUpClass(cls):
    """Load the profile pickles for `cls.target` into `cls.examples`."""
    if cls is TestSQTTMapBase: raise unittest.SkipTest("base class")
    cls.examples = {}
    # LOAD_PROFILE selects the single profile.pkl in the temp dir instead of the checked-in examples
    if getenv("LOAD_PROFILE"): pkl_paths = [Path(temp("profile.pkl", append_user=True))]
    else: pkl_paths = sorted((EXAMPLES_DIR/cls.target).glob("*.pkl"))
    for pkl_path in pkl_paths:
      # NOTE: pickle.load executes arbitrary code from the file, only point this at trusted profiles
      with open(pkl_path, "rb") as f: data = pickle.load(f)
      # events are selected by class name, not isinstance
      sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"]
      kern_events = {e.tag:e for e in data if type(e).__name__ == "ProfileProgramEvent"}
      # keep only profiles that have both SQTT data and at least one program
      if sqtt_events and kern_events:
        cls.examples[pkl_path.stem] = (sqtt_events, kern_events, cls.target)

  def test_rocprof_inst_traces_match(self):
    """Every instruction-traced SQTT event with a known program must match the rocprof decode."""
    for name, (events, kern_events, target) in self.examples.items():
      if "sync" in name and self.target.startswith("gfx12"):
        # NOTE: skipTest raises, so this ends the whole test method, not just this example
        self.skipTest("our timestamps are off by a few cycles because rocprof patches timestamps for rdna4 barriers")
      for event in events:
        if not event.itrace: continue
        if event.kern not in kern_events: continue
        with self.subTest(example=name, kern=event.kern):
          passed_insts, n_waves, n_units = rocprof_inst_traces_match(event, kern_events[event.kern], target)
          if n_waves: print(f"{name}: passed for {passed_insts} instructions across {n_waves} waves scheduled on {n_units} wave units")

  def test_sqtt_timeline(self):
    """The timeline must have a positive mean frequency and as many counted EXEC ranges as counted WAVE ranges."""
    # WAVE rows with these names are not counted as instructions
    no_exec = {"IMMEDIATE", "IMMEDIATE_MASK", "JUMP", "JUMP_NO", "MESSAGE", "BARRIER", "BARRIER_SIGNAL", "WAVEEND", "WAVERDY"}
    for name, (events, kern_events, target) in self.examples.items():
      for event in events:
        if (p:=kern_events.get(event.kern)) is None: continue
        with self.subTest(example=name, kern=event.kern):
          # skip if there's no SQTT frequency data
          if not (timeline:=list(sqtt_timeline(event.blob, p.lib, target))): continue
          if not (frequency:=[e.key for e in timeline if type(e).__name__ == "ProfilePointEvent" and e.name == "freq_hz"]): continue
          mean = sum(frequency) / len(frequency)
          variance = sum((v - mean) ** 2 for v in frequency) / len(frequency)
          self.assertGreater(mean, 0)
          if DEBUG >= 2: print(f"{name:20s} SE:{event.se} {mean/1e9:.2f} GHz mean, {variance/1e18:.2f} GHz^2 variance")
          # named range_events so the `events` list being iterated above is not rebound
          range_events = [e for e in timeline if type(e).__name__ == "ProfileRangeEvent"]
          insts, execs = 0, 0
          for e in range_events:
            if "EXEC" in e.device:
              if "ALT" not in e.name.display_name: execs += 1
            elif "WAVE" in e.device:
              # sopk/immediates don't get ALU/MEM EXEC
              if e.name.display_name not in no_exec and not e.name.display_name.startswith("OTHER_"): insts += 1
            else: raise Exception(f"timeline row must be INST or EXEC, got {e.device}")
          self.assertEqual(execs, insts)

  def test_wave_sync(self):
    """Every BARRIER range in the timeline must last longer than one cycle."""
    for name, (events, kern_events, target) in self.examples.items():
      for event in events:
        # skip SQTT events whose kernel has no program event, like the other tests do
        if (p:=kern_events.get(event.kern)) is None: continue
        # timeline row (device) -> BARRIER ranges on that row
        wave_barriers = {}
        for e in sqtt_timeline(event.blob, p.lib, target):
          if type(e).__name__ == "ProfileRangeEvent" and e.name.display_name == "BARRIER": wave_barriers.setdefault(e.device, []).append(e)
        for barriers in wave_barriers.values():
          for e in barriers:
            assert e.en-e.st > 1, f"all barriers must have a duration greater than 1, got {e}"

  def test_sqtt_cli(self):
    """The CLI must list SQTT traces and print a table whose rows start with a numeric clock for each one."""
    for pkl_path in sorted((EXAMPLES_DIR/self.target).glob("*.pkl")):
      out = run_cli("--profile", "--profile-path", str(pkl_path))
      sqtt_traces = [l.strip() for l in out.split("\n") if "SQTT" in l]
      for name in sqtt_traces:
        # select the trace by its listed name with the ANSI escapes removed
        out = run_cli("--profile", "--profile-path", str(pkl_path), "-s", ansistrip(name))
        lines = out.split("\n")
        self.assertIn("Clk", lines[0])
        # rows are checked from the third output line on
        for r in lines[2:]:
          parts = r.split()
          self.assertTrue(parts[0].isdigit(), f"expected clock timestamp, got {parts[0]}")

class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100"

class TestSQTTMapRDNA4(TestSQTTMapBase):
  target = "gfx1200"

  @unittest.expectedFailure
  def test_rdna4_wmma(self):
    """Ranges on the "ALUEXEC:0 WMMA" row must not start before the previous one ended (currently expected to fail)."""
    events, kernels, target = self.examples["profile_handwritten_run_0"]
    # row (device) -> end time of the last range seen on it
    row_ends = {}
    for e in sqtt_timeline(events[0].blob, list(kernels.values())[0].lib, target):
      if type(e).__name__ != "ProfileRangeEvent" or e.device != "ALUEXEC:0 WMMA": continue
      if (et:=row_ends.get(e.device)) is not None and e.st < et:
        raise RuntimeError(f"WMMA exec overlaps in {e.device}: {e.st} {et}.")
      row_ends[e.device] = e.en

class TestSQTTMapCDNA(TestSQTTMapBase):
  target = "gfx950"

  def test_rocprof_inst_traces_match(self):
    """Override the base test: it is always skipped on this target."""
    self.skipTest("requires timestamp patching to match rocprof, currently it's off by a few cycles")

if __name__ == "__main__":
  unittest.main()