diff --git a/extra/viz/README b/extra/viz/README index bd3408ab77..1041da76d6 100644 --- a/extra/viz/README +++ b/extra/viz/README @@ -38,8 +38,8 @@ You can select a specific trace with --source, Example workflow: VIZ=-2 python extra/gemm/amd_asm_matmul.py # View barriers -extra/viz/cli.py --profile -s "Exec kernel SQTT PKTS SE:0" | rg BARRIER | head -10 +extra/viz/cli.py --profile -s "SQTT kernel PKTS SE:0" | rg BARRIER | head -10 # Find the EXEC corresponding to a DISPATCH at cycle 410 -extra/viz/cli.py --profile -s "Exec kernel SQTT PKTS SE:0" | awk '/EXEC/ && $1 - $5 == 410' +extra/viz/cli.py --profile -s "SQTT kernel PKTS SE:0" | awk '/EXEC/ && $1 - $5 == 410' ``` diff --git a/extra/viz/cli.py b/extra/viz/cli.py index 47f1267b35..06a67ee0d0 100755 --- a/extra/viz/cli.py +++ b/extra/viz/cli.py @@ -60,7 +60,7 @@ def main(args) -> None: if (profile_bytes:=viz.get_profile(events)) is None: raise RuntimeError(f"empty profile in {args.profile_path}") profile = decode_profile(profile_bytes) viz.load_amd_counters(viz.ctxs, events) - profile["layout"].update([(f'{c["name"]} SQTT {s["name"]}', s["data"]) for c in viz.ctxs if c["name"].startswith("Exec") for s in c["steps"] + profile["layout"].update([(f'{c["name"]} {s["name"]}', s["data"]) for c in viz.ctxs if c["name"].startswith("SQTT") for s in c["steps"] if s["name"].startswith("PKTS")]) if args.source is None: for k in profile["layout"]: diff --git a/test/amd/test_sqtt_profiler.py b/test/amd/test_sqtt_profiler.py index 62a15a94b5..65bc97edc4 100644 --- a/test/amd/test_sqtt_profiler.py +++ b/test/amd/test_sqtt_profiler.py @@ -9,7 +9,7 @@ def save_sqtt(): Device[Device.DEFAULT].synchronize() Device[Device.DEFAULT]._at_profile_finalize() load_amd_counters(ret, Compiled.profile_events) - ret[:] = [r for r in ret if r["name"].startswith("Exec")] + ret[:] = [r for r in ret if r["name"].startswith("SQTT")] @unittest.skipUnless(Device.DEFAULT == "AMD", "only runs on AMD") class TestSQTTProfiler(unittest.TestCase): @@ -28,7 +28,7 @@ class TestSQTTProfiler(unittest.TestCase): ei = t.schedule()[0].lower() ei.run() self.assertEqual(len(sqtt), 1) - self.assertEqual(sqtt[0]["name"], f"Exec {ei.prg.p.function_name}") + self.assertEqual(sqtt[0]["name"], f"SQTT {ei.prg.p.function_name}") def test_multiple_runs(self): t = Tensor.empty(1) + 1 @@ -38,7 +38,7 @@ class TestSQTTProfiler(unittest.TestCase): ei.run() self.assertEqual(len(sqtt), N) for i in range(1, N): - self.assertEqual(sqtt[i]["name"], f"Exec {ei.prg.p.function_name} n{i+1}") + self.assertEqual(sqtt[i]["name"], f"SQTT {ei.prg.p.function_name} n{i+1}") def test_multiple_kernels(self): t = ((Tensor.empty(1) + 1).contiguous() + 2) @@ -47,7 +47,7 @@ class TestSQTTProfiler(unittest.TestCase): for si in sched: si.lower().run() self.assertEqual(len(sqtt), len(sched)) for i,k in enumerate(sched): - self.assertEqual(sqtt[i]["name"], f"Exec {k.lower().prg.p.function_name}") + self.assertEqual(sqtt[i]["name"], f"SQTT {k.lower().prg.p.function_name}") def test_multiple_kernels_lower(self): t = ((Tensor.empty(1) + 1).contiguous() + 2) @@ -57,7 +57,7 @@ class TestSQTTProfiler(unittest.TestCase): for p in prgs: p.run() self.assertEqual(len(sqtt), len(sched)) for i,ei in enumerate(prgs): - self.assertEqual(sqtt[i]["name"], f"Exec {ei.prg.p.function_name}") + self.assertEqual(sqtt[i]["name"], f"SQTT {ei.prg.p.function_name}") def test_jit(self): @TinyJit diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index dd89139bec..a95cd1d3ec 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -336,7 +336,7 @@ def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None: for e in sqtt: if e.itrace: steps.append(create_step(f"PKTS SE:{e.se}", (f"/prg-pkts-{e.se}", len(ctxs), len(steps)), data=(e.blob, prg_events[k].lib,arch))) steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch))) - ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps}) + ctxs.append({"name":f"SQTT {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps}) wave_colors = {"WMMA": "#1F7857", **{x:"#ffffc0" for x in ["VALU", "VINTERP"]}, "SALU": "#cef263", "SMEM": "#ffc0c0", "STORE": "#4fa3cc", **{x:"#b2b7c9" for x in ["VMEM", "SGMEM"]}, "LDS": "#9fb4a6", "IMMEDIATE": "#f3b44a", "BARRIER": "#d00000",