From 7238df7a9413f8a5dbe418afb7d77de8a796468a Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 26 Nov 2025 04:10:10 +0800 Subject: [PATCH] viz: cleanup sort_fn (#13454) --- test/unit/test_viz.py | 14 +++++++++++--- tinygrad/viz/serve.py | 17 +++++++++-------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index 36d69d1a44..cb71de868f 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -6,7 +6,7 @@ from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatch from tinygrad.uop.symbolic import sym from tinygrad.dtype import dtypes from tinygrad.helpers import PROFILE, colored, ansistrip, flatten, TracingKey, ProfileRangeEvent, ProfileEvent, Context, cpu_events, profile_marker -from tinygrad.helpers import VIZ +from tinygrad.helpers import VIZ, cpu_profile from tinygrad.device import Buffer @track_rewrites(name=True) @@ -415,8 +415,8 @@ class TestVizProfiler(BaseTestViz): tracks = list(j['layout']) self.assertEqual(tracks[0], 'NV') - self.assertEqual(tracks[1], 'NV:1') - self.assertEqual(tracks[2], 'NV Graph') + self.assertEqual(tracks[1], 'NV Graph') + self.assertEqual(tracks[2], 'NV:1') nv_events = j['layout']['NV']['events'] self.assertEqual(nv_events[0]['name'], 'E_25_4n2') @@ -470,6 +470,14 @@ class TestVizProfiler(BaseTestViz): assert kernels[0]["st"] <= markers[0]["ts"] <= kernels[1]["st"] assert markers[1]["ts"] >= kernels[1]["st"]+kernels[1]["dur"] + def test_layout_order(self): + def fn(): return + for dname in ["TINY", "USER", "TEST:1 N1", "TEST:2 N1", "TEST:1 N2"]: + with cpu_profile("fn", dname): fn() + layout = list(load_profile(cpu_events)["layout"]) + self.assertListEqual(layout[:2], ["USER","TINY"]) + self.assertListEqual(layout[2:], ["TEST:1 N1","TEST:1 N2", "TEST:2 N1"]) + def _alloc(b:int): a = Tensor.empty(b, device="NULL", dtype=dtypes.char) a.uop.buffer.allocate() diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 77954833fa..279751c7ca 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -261,10 +261,11 @@ def load_sqtt(profile:list[ProfileEvent]) -> None: steps.append(create_step(k.replace(cu, ""), ("/sqtt-insts", len(ctxs), len(steps)), wave_insts[cu][k], depth=3)) ctxs.append({"name":"Counters", "steps":steps}) -def device_sort_fn(k): - order = {"USER": 0, "TINY": 1} - dname, *rest = k.split() - return order.get(dname, len(order)+len(rest)) +def device_sort_fn(k:str) -> tuple[int, str, int]: + order = {"USER": 0, "TINY": 1, "DISK": 999} + dname = k.split()[0] + dev_rank = next((v for k,v in order.items() if dname.startswith(k)), len(order)) + return (dev_rank, dname, len(k)) def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_sort_fn) -> bytes|None: # start by getting the time diffs @@ -292,12 +293,12 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_ scache:dict[str, int] = {} peaks:list[int] = [] dtype_size:dict[str, int] = {} - for k in sorted(dev_events, key=sort_fn): - (v:=dev_events[k]).sort(key=lambda e:e[0]) + for k,v in dev_events.items(): + v.sort(key=lambda e:e[0]) layout[k] = timeline_layout(v, start_ts, scache) layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtype_size, scache) - groups = layout.items() if sort_fn is not None else sorted(layout.items(), key=lambda x: '' if len(ss:=x[0].split(" ")) == 1 else ss[1]) - ret = [b"".join([struct.pack("