display viz rewrites with tabbing if they are subrewrites (#10097)

* display viz rewrites with tabbing if they are subrewrites

* update viz api
This commit is contained in:
qazal
2025-04-29 12:57:21 +03:00
committed by GitHub
parent 73c2f6602f
commit cbf7347cd6
4 changed files with 25 additions and 3 deletions

View File

@@ -1,6 +1,6 @@
import unittest, decimal, json
from tinygrad.dtype import dtypes
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, UPat
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, UPat, Ops
from tinygrad.codegen.symbolic import symbolic
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt, _substitute
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
@@ -14,6 +14,10 @@ inner_rewrite = TrackedPatternMatcher([
(UPat.cvar("x"), lambda x: None if x.dtype == dtypes.float32 else UOp.const(dtypes.float32, x.arg)),
])
l2 = TrackedPatternMatcher([(UPat(Ops.CUSTOM, arg=2, name="x"), lambda x: x.replace(arg=3))])
l1 = TrackedPatternMatcher([(UPat(Ops.CUSTOM, arg=1, name="x"), lambda x: graph_rewrite(x.replace(arg=2), l2))])
l0 = TrackedPatternMatcher([(UPat(Ops.CUSTOM, arg=0, name="x"), lambda x: graph_rewrite(x.replace(arg=1), l1))])
class TestViz(unittest.TestCase):
def setUp(self):
# clear the global context
@@ -170,10 +174,24 @@ class TestViz(unittest.TestCase):
self.assertEqual(len(contexts), 1)
tracked = contexts[0]
self.assertEqual(len(tracked), 3)
self.assertEqual(tracked[0].depth, 0)
self.assertEqual(tracked[1].depth, 1)
self.assertEqual(tracked[2].depth, 1)
# NOTE: this is sorted by the time called, maybe it should be by depth
self.assertEqual([x.name for x in tracked], ["outer", "inner_x", "inner_y"])
self.assertEqual([len(x.matches) for x in tracked], [1, 1, 1])
def test_depth_level(self):
@track_rewrites(named=True)
def fxn(u:UOp): return graph_rewrite(u, l0)
ret = fxn(UOp(Ops.CUSTOM, arg=0))
assert ret is UOp(Ops.CUSTOM, arg=3)
self.assertEqual(len(contexts), 1)
tracked = contexts[0]
self.assertEqual(tracked[0].depth, 0)
self.assertEqual(tracked[1].depth, 1)
self.assertEqual(tracked[2].depth, 2)
def test_shape_label(self):
a = UOp.new_buffer("CPU", 1, dtypes.uint8).expand((4,))
b = UOp.new_buffer("CPU", 1, dtypes.uint8).expand((8,))

View File

@@ -879,6 +879,7 @@ class TrackedGraphRewrite:
bottom_up: bool
matches: list[tuple[UOp, UOp, UPat]] # before+after of all the matches
name: str|None
depth: int
tracked_keys:list[Any] = []
tracked_ctxs:list[list[TrackedGraphRewrite]] = []
_name_cnt:dict[str, int] = {}
@@ -900,7 +901,8 @@ def track_matches(func):
def _track_func(*args, **kwargs):
if tracking:=(TRACK_MATCH_STATS >= 2 and tracked_ctxs):
loc = ((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno)
tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, args[0], kwargs.get("bottom_up", False), [], kwargs.get("name", None)))
depth = len(active_rewrites)
tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, args[0], kwargs.get("bottom_up", False),[], kwargs.get("name", None), depth))
active_rewrites.append(ctx)
ret = func(*args, **kwargs)
if tracking: active_rewrites.pop()

View File

@@ -297,6 +297,7 @@ async function main() {
const inner = ul.appendChild(document.createElement("ul"));
if (i === currentKernel && j === currentUOp) inner.className = "active";
inner.innerText = `${u.name ?? u.loc[0].replaceAll("\\", "/").split("/").pop()+':'+u.loc[1]} - ${u.match_count}`;
inner.style.marginLeft = `${8*u.depth}px`;
inner.style.display = i === currentKernel && expandKernel ? "block" : "none";
inner.onclick = (e) => {
e.stopPropagation();

View File

@@ -31,13 +31,14 @@ class GraphRewriteMetadata(TypedDict):
code_line: str # source code calling graph_rewrite
kernel_code: str|None # optionally render the final kernel code
name: str|None # optional name of the rewrite
depth: int # depth if it's a subrewrite
@functools.cache
def render_program(k:Kernel): return k.opts.render(k.uops)
def to_metadata(k:Any, v:TrackedGraphRewrite) -> GraphRewriteMetadata:
return {"loc":v.loc, "match_count":len(v.matches), "code_line":lines(v.loc[0])[v.loc[1]-1].strip(),
"kernel_code":pcall(render_program, k) if isinstance(k, Kernel) else None, "name":v.name}
"kernel_code":pcall(render_program, k) if isinstance(k, Kernel) else None, "name":v.name, "depth":v.depth}
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]:
return [(k.name if isinstance(k, Kernel) else str(k), [to_metadata(k, v) for v in vals]) for k,vals in zip(keys, contexts)]