mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
@@ -38,8 +38,8 @@ You can select a specific trace with --source, Example workflow:
|
||||
VIZ=-2 python extra/gemm/amd_asm_matmul.py
|
||||
|
||||
# View barriers
|
||||
extra/viz/cli.py --profile -s "Exec kernel SQTT PKTS SE:0" | rg BARRIER | head -10
|
||||
extra/viz/cli.py --profile -s "SQTT kernel PKTS SE:0" | rg BARRIER | head -10
|
||||
|
||||
# Find the EXEC corresponding to a DISPATCH at cycle 410
|
||||
extra/viz/cli.py --profile -s "Exec kernel SQTT PKTS SE:0" | awk '/EXEC/ && $1 - $5 == 410'
|
||||
extra/viz/cli.py --profile -s "SQTT kernel PKTS SE:0" | awk '/EXEC/ && $1 - $5 == 410'
|
||||
```
|
||||
|
||||
@@ -60,7 +60,7 @@ def main(args) -> None:
|
||||
if (profile_bytes:=viz.get_profile(events)) is None: raise RuntimeError(f"empty profile in {args.profile_path}")
|
||||
profile = decode_profile(profile_bytes)
|
||||
viz.load_amd_counters(viz.ctxs, events)
|
||||
profile["layout"].update([(f'{c["name"]} SQTT {s["name"]}', s["data"]) for c in viz.ctxs if c["name"].startswith("Exec") for s in c["steps"]
|
||||
profile["layout"].update([(f'{c["name"]} {s["name"]}', s["data"]) for c in viz.ctxs if c["name"].startswith("SQTT") for s in c["steps"]
|
||||
if s["name"].startswith("PKTS")])
|
||||
if args.source is None:
|
||||
for k in profile["layout"]:
|
||||
|
||||
@@ -9,7 +9,7 @@ def save_sqtt():
|
||||
Device[Device.DEFAULT].synchronize()
|
||||
Device[Device.DEFAULT]._at_profile_finalize()
|
||||
load_amd_counters(ret, Compiled.profile_events)
|
||||
ret[:] = [r for r in ret if r["name"].startswith("Exec")]
|
||||
ret[:] = [r for r in ret if r["name"].startswith("SQTT")]
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "AMD", "only runs on AMD")
|
||||
class TestSQTTProfiler(unittest.TestCase):
|
||||
@@ -28,7 +28,7 @@ class TestSQTTProfiler(unittest.TestCase):
|
||||
ei = t.schedule()[0].lower()
|
||||
ei.run()
|
||||
self.assertEqual(len(sqtt), 1)
|
||||
self.assertEqual(sqtt[0]["name"], f"Exec {ei.prg.p.function_name}")
|
||||
self.assertEqual(sqtt[0]["name"], f"SQTT {ei.prg.p.function_name}")
|
||||
|
||||
def test_multiple_runs(self):
|
||||
t = Tensor.empty(1) + 1
|
||||
@@ -38,7 +38,7 @@ class TestSQTTProfiler(unittest.TestCase):
|
||||
ei.run()
|
||||
self.assertEqual(len(sqtt), N)
|
||||
for i in range(1, N):
|
||||
self.assertEqual(sqtt[i]["name"], f"Exec {ei.prg.p.function_name} n{i+1}")
|
||||
self.assertEqual(sqtt[i]["name"], f"SQTT {ei.prg.p.function_name} n{i+1}")
|
||||
|
||||
def test_multiple_kernels(self):
|
||||
t = ((Tensor.empty(1) + 1).contiguous() + 2)
|
||||
@@ -47,7 +47,7 @@ class TestSQTTProfiler(unittest.TestCase):
|
||||
for si in sched: si.lower().run()
|
||||
self.assertEqual(len(sqtt), len(sched))
|
||||
for i,k in enumerate(sched):
|
||||
self.assertEqual(sqtt[i]["name"], f"Exec {k.lower().prg.p.function_name}")
|
||||
self.assertEqual(sqtt[i]["name"], f"SQTT {k.lower().prg.p.function_name}")
|
||||
|
||||
def test_multiple_kernels_lower(self):
|
||||
t = ((Tensor.empty(1) + 1).contiguous() + 2)
|
||||
@@ -57,7 +57,7 @@ class TestSQTTProfiler(unittest.TestCase):
|
||||
for p in prgs: p.run()
|
||||
self.assertEqual(len(sqtt), len(sched))
|
||||
for i,ei in enumerate(prgs):
|
||||
self.assertEqual(sqtt[i]["name"], f"Exec {ei.prg.p.function_name}")
|
||||
self.assertEqual(sqtt[i]["name"], f"SQTT {ei.prg.p.function_name}")
|
||||
|
||||
def test_jit(self):
|
||||
@TinyJit
|
||||
|
||||
@@ -336,7 +336,7 @@ def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None:
|
||||
for e in sqtt:
|
||||
if e.itrace: steps.append(create_step(f"PKTS SE:{e.se}", (f"/prg-pkts-{e.se}", len(ctxs), len(steps)), data=(e.blob, prg_events[k].lib,arch)))
|
||||
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch)))
|
||||
ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
|
||||
ctxs.append({"name":f"SQTT {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
|
||||
|
||||
wave_colors = {"WMMA": "#1F7857", **{x:"#ffffc0" for x in ["VALU", "VINTERP"]}, "SALU": "#cef263", "SMEM": "#ffc0c0", "STORE": "#4fa3cc",
|
||||
**{x:"#b2b7c9" for x in ["VMEM", "SGMEM"]}, "LDS": "#9fb4a6", "IMMEDIATE": "#f3b44a", "BARRIER": "#d00000",
|
||||
|
||||
Reference in New Issue
Block a user