mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
display viz rewrites with tabbing if they are subrewrites (#10097)
* display viz rewrites with tabbing if they are subrewrites * update viz api
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import unittest, decimal, json
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, UPat
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, UPat, Ops
|
||||
from tinygrad.codegen.symbolic import symbolic
|
||||
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt, _substitute
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
|
||||
@@ -14,6 +14,10 @@ inner_rewrite = TrackedPatternMatcher([
|
||||
(UPat.cvar("x"), lambda x: None if x.dtype == dtypes.float32 else UOp.const(dtypes.float32, x.arg)),
|
||||
])
|
||||
|
||||
l2 = TrackedPatternMatcher([(UPat(Ops.CUSTOM, arg=2, name="x"), lambda x: x.replace(arg=3))])
|
||||
l1 = TrackedPatternMatcher([(UPat(Ops.CUSTOM, arg=1, name="x"), lambda x: graph_rewrite(x.replace(arg=2), l2))])
|
||||
l0 = TrackedPatternMatcher([(UPat(Ops.CUSTOM, arg=0, name="x"), lambda x: graph_rewrite(x.replace(arg=1), l1))])
|
||||
|
||||
class TestViz(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clear the global context
|
||||
@@ -170,10 +174,24 @@ class TestViz(unittest.TestCase):
|
||||
self.assertEqual(len(contexts), 1)
|
||||
tracked = contexts[0]
|
||||
self.assertEqual(len(tracked), 3)
|
||||
self.assertEqual(tracked[0].depth, 0)
|
||||
self.assertEqual(tracked[1].depth, 1)
|
||||
self.assertEqual(tracked[2].depth, 1)
|
||||
# NOTE: this is sorted by the time called, maybe it should be by depth
|
||||
self.assertEqual([x.name for x in tracked], ["outer", "inner_x", "inner_y"])
|
||||
self.assertEqual([len(x.matches) for x in tracked], [1, 1, 1])
|
||||
|
||||
def test_depth_level(self):
|
||||
@track_rewrites(named=True)
|
||||
def fxn(u:UOp): return graph_rewrite(u, l0)
|
||||
ret = fxn(UOp(Ops.CUSTOM, arg=0))
|
||||
assert ret is UOp(Ops.CUSTOM, arg=3)
|
||||
self.assertEqual(len(contexts), 1)
|
||||
tracked = contexts[0]
|
||||
self.assertEqual(tracked[0].depth, 0)
|
||||
self.assertEqual(tracked[1].depth, 1)
|
||||
self.assertEqual(tracked[2].depth, 2)
|
||||
|
||||
def test_shape_label(self):
|
||||
a = UOp.new_buffer("CPU", 1, dtypes.uint8).expand((4,))
|
||||
b = UOp.new_buffer("CPU", 1, dtypes.uint8).expand((8,))
|
||||
|
||||
@@ -879,6 +879,7 @@ class TrackedGraphRewrite:
|
||||
bottom_up: bool
|
||||
matches: list[tuple[UOp, UOp, UPat]] # before+after of all the matches
|
||||
name: str|None
|
||||
depth: int
|
||||
tracked_keys:list[Any] = []
|
||||
tracked_ctxs:list[list[TrackedGraphRewrite]] = []
|
||||
_name_cnt:dict[str, int] = {}
|
||||
@@ -900,7 +901,8 @@ def track_matches(func):
|
||||
def _track_func(*args, **kwargs):
|
||||
if tracking:=(TRACK_MATCH_STATS >= 2 and tracked_ctxs):
|
||||
loc = ((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno)
|
||||
tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, args[0], kwargs.get("bottom_up", False), [], kwargs.get("name", None)))
|
||||
depth = len(active_rewrites)
|
||||
tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, args[0], kwargs.get("bottom_up", False),[], kwargs.get("name", None), depth))
|
||||
active_rewrites.append(ctx)
|
||||
ret = func(*args, **kwargs)
|
||||
if tracking: active_rewrites.pop()
|
||||
|
||||
@@ -297,6 +297,7 @@ async function main() {
|
||||
const inner = ul.appendChild(document.createElement("ul"));
|
||||
if (i === currentKernel && j === currentUOp) inner.className = "active";
|
||||
inner.innerText = `${u.name ?? u.loc[0].replaceAll("\\", "/").split("/").pop()+':'+u.loc[1]} - ${u.match_count}`;
|
||||
inner.style.marginLeft = `${8*u.depth}px`;
|
||||
inner.style.display = i === currentKernel && expandKernel ? "block" : "none";
|
||||
inner.onclick = (e) => {
|
||||
e.stopPropagation();
|
||||
|
||||
@@ -31,13 +31,14 @@ class GraphRewriteMetadata(TypedDict):
|
||||
code_line: str # source code calling graph_rewrite
|
||||
kernel_code: str|None # optionally render the final kernel code
|
||||
name: str|None # optional name of the rewrite
|
||||
depth: int # depth if it's a subrewrite
|
||||
|
||||
@functools.cache
|
||||
def render_program(k:Kernel): return k.opts.render(k.uops)
|
||||
|
||||
def to_metadata(k:Any, v:TrackedGraphRewrite) -> GraphRewriteMetadata:
|
||||
return {"loc":v.loc, "match_count":len(v.matches), "code_line":lines(v.loc[0])[v.loc[1]-1].strip(),
|
||||
"kernel_code":pcall(render_program, k) if isinstance(k, Kernel) else None, "name":v.name}
|
||||
"kernel_code":pcall(render_program, k) if isinstance(k, Kernel) else None, "name":v.name, "depth":v.depth}
|
||||
|
||||
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]:
|
||||
return [(k.name if isinstance(k, Kernel) else str(k), [to_metadata(k, v) for v in vals]) for k,vals in zip(keys, contexts)]
|
||||
|
||||
Reference in New Issue
Block a user