Files
tinygrad/test/amd/test_sqttmap.py
qazal afc3904e58 viz/cli: unit tests in CI (#15788)
* simple failing test

* test stdout

* cleanup sqttmap
2026-04-17 22:34:44 +09:00

145 lines
7.0 KiB
Python

# test to compare every packet with the rocprof decoder
import unittest, pickle
from typing import Iterator
from pathlib import Path
from tinygrad.helpers import DEBUG, getenv, temp, ansistrip
from tinygrad.renderer.amd.sqtt import print_packets, map_insts
from tinygrad.runtime.autogen.amd.rdna3.ins import s_endpgm
from tinygrad.viz.serve import sqtt_timeline
from test.amd.disasm import disasm
from test.null.test_viz import run_cli
import tinygrad
EXAMPLES_DIR = Path(tinygrad.__file__).parent.parent / "extra/sqtt/examples"
def rocprof_inst_traces_match(sqtt, prg, target):
from tinygrad.viz.serve import amd_decode
from extra.sqtt.roc import decode as roc_decode, InstExec
addr_table = amd_decode(prg.lib, target)
disasm_map = {addr+prg.base:inst for addr,inst in addr_table.items()}
rctx = roc_decode([sqtt], {prg.tag:disasm_map})
rwaves = rctx.inst_execs.get((sqtt.kern, sqtt.exec_tag), [])
rwaves_iter:dict[int, list[Iterator[InstExec]]] = {} # wave unit (0-15) -> list of inst trace iterators for all executions on that unit
for w in rwaves: rwaves_iter.setdefault(w.wave_id, []).append(w.unpack_insts())
if not rwaves: return 0, 0, 0
passed_insts = 0
for pkt, info in map_insts(sqtt.blob, prg.lib, target):
if DEBUG >= 2: print_packets([(pkt, info)])
if info is None: continue
if DEBUG >= 2: print(f"{' '*29}{disasm(info.inst)}")
rocprof_inst = next(rwaves_iter[info.wave][0])
ref_pc = rocprof_inst.pc-prg.base
# always check pc matches
assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm_map[rocprof_inst.pc]} != {info.pc}:{disasm(info.inst)}"
# special handling for s_endpgm, it marks the wave completion.
if info.inst == s_endpgm():
completed_wave = list(rwaves_iter[info.wave].pop(0))
assert len(completed_wave) == 0, f"incomplete instructions in wave {info.wave}"
# otherwise the packet timestamp is time + "stall"
else:
assert pkt._time == rocprof_inst.time+rocprof_inst.stall
passed_insts += 1
for k,v in rwaves_iter.items():
assert len(v) == 0, f"incomplete wave {k}"
return passed_insts, len(rwaves), len(rwaves_iter)
class TestSQTTMapBase(unittest.TestCase):
target: str
examples: dict
@classmethod
def setUpClass(cls):
if cls is TestSQTTMapBase: raise unittest.SkipTest("base class")
cls.examples = {}
for pkl_path in ([Path(temp("profile.pkl", append_user=True))] if getenv("LOAD_PROFILE") else sorted((EXAMPLES_DIR/cls.target).glob("*.pkl"))):
with open(pkl_path, "rb") as f:
data = pickle.load(f)
sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"]
kern_events = {e.tag:e for e in data if type(e).__name__ == "ProfileProgramEvent"}
if sqtt_events and kern_events:
cls.examples[pkl_path.stem] = (sqtt_events, kern_events, cls.target)
def test_rocprof_inst_traces_match(self):
for name, (events, kern_events, target) in self.examples.items():
if "sync" in name and self.target.startswith("gfx12"):
self.skipTest("our timestamps are off by a few cycles because rocprof patches timestamps for rdna4 barriers")
for event in events:
if not event.itrace: continue
if event.kern not in kern_events: continue
with self.subTest(example=name, kern=event.kern):
passed_insts, n_waves, n_units = rocprof_inst_traces_match(event, kern_events[event.kern], target)
if n_waves: print(f"{name}: passed for {passed_insts} instructions across {n_waves} waves scheduled on {n_units} wave units")
def test_sqtt_timeline(self):
for name, (events, kern_events, target) in self.examples.items():
for event in events:
if (p:=kern_events.get(event.kern)) is None: continue
with self.subTest(example=name, kern=event.kern):
# skip if there's no SQTT frequency data
if not (timeline:=list(sqtt_timeline(event.blob, p.lib, target))): continue
if not (frequency:=[e.key for e in timeline if type(e).__name__ == "ProfilePointEvent" and e.name == "freq_hz"]): continue
mean = sum(frequency) / len(frequency)
variance = sum((v - mean) ** 2 for v in frequency) / len(frequency)
self.assertGreater(mean, 0)
if DEBUG >= 2: print(f"{name:20s} SE:{event.se} {mean/1e9:.2f} GHz mean, {variance/1e18:.2f} GHz^2 variance")
events = [e for e in timeline if type(e).__name__ == "ProfileRangeEvent"]
insts, execs = 0, 0
for e in events:
if "EXEC" in e.device:
if "ALT" not in e.name.display_name: execs += 1
elif "WAVE" in e.device:
# sopk/immediates don't get ALU/MEM EXEC
if e.name.display_name not in {"IMMEDIATE", "IMMEDIATE_MASK", "JUMP", "JUMP_NO", "MESSAGE", "BARRIER", "BARRIER_SIGNAL",
"WAVEEND", "WAVERDY"} and not e.name.display_name.startswith("OTHER_"): insts += 1
else: raise Exception(f"timeline row must be INST or EXEC, got {e.device}")
self.assertEqual(execs, insts)
def test_wave_sync(self):
for name, (events, kern_events, target) in self.examples.items():
for event in events:
wave_barriers = {}
for e in sqtt_timeline(event.blob, kern_events[event.kern].lib, target):
if type(e).__name__ == "ProfileRangeEvent" and e.name.display_name == "BARRIER": wave_barriers.setdefault(e.device, []).append(e)
if not wave_barriers: continue
for row, events in wave_barriers.items():
for e in events:
assert e.en-e.st > 1, f"all barriers must have a duration greater than 1, got {e}"
def test_sqtt_cli(self):
for pkl_path in sorted((EXAMPLES_DIR/self.target).glob("*.pkl")):
out = run_cli("--profile", "--profile-path", str(pkl_path))
sqtt_traces = [l.strip() for l in out.split("\n") if "SQTT" in l]
for name in sqtt_traces:
out = run_cli("--profile", "--profile-path", str(pkl_path), "-s", ansistrip(name))
lines = out.split("\n")
self.assertIn("Clk", lines[0])
for r in lines[2:]:
parts = r.split()
self.assertTrue(parts[0].isdigit(), f"expected clock timestamp, got {parts[0]}")
class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100"
class TestSQTTMapRDNA4(TestSQTTMapBase):
target = "gfx1200"
@unittest.expectedFailure
def test_rdna4_wmma(self):
events, kernels, target = self.examples["profile_handwritten_run_0"]
row_ends = {}
for e in sqtt_timeline(events[0].blob, list(kernels.values())[0].lib, target):
if type(e).__name__ != "ProfileRangeEvent" or e.device != "ALUEXEC:0 WMMA": continue
if (et:=row_ends.get(e.device)) is not None and e.st < et:
raise RuntimeError(f"WMMA exec overlaps in {e.device}: {e.st} {et}.")
row_ends[e.device] = e.en
class TestSQTTMapCDNA(TestSQTTMapBase):
target = "gfx950"
def test_rocprof_inst_traces_match(self): self.skipTest("requires timestamp patching to match rocprof, currently it's off by a few cycles")
if __name__ == "__main__":
unittest.main()