From fe260ac4d7aa5ea864d23d8f201aa2b6e111eb35 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 17 Feb 2025 09:59:41 +0200 Subject: [PATCH] viz/server cleanups [pr] (#9127) * viz/server cleanups [pr] * space --- tinygrad/viz/serve.py | 50 +++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index eb2aa4e7a1..291107258c 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -18,24 +18,36 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB" # VIZ API -class GraphRewriteMetadata(TypedDict): - loc: tuple[str, int] # [path, lineno] calling graph_rewrite - match_count: int # total match count in this context - code_line: str # source code calling graph_rewrite - kernel_code: str|None # optionally render the final kernel code - -class GraphRewriteDetails(TypedDict): - graph: dict # JSON serialized UOp for this rewrite step - uop: str # strigified UOp for this rewrite step - diff: list[str]|None # string diff of the single UOp that changed - changed_nodes: list[int]|None # the changed UOp id + all its parents ids - upat: tuple[tuple[str, int], str]|None - # NOTE: if any extra rendering in VIZ fails, we don't crash def pcall(fxn:Callable[..., str], *args, **kwargs) -> str: try: return fxn(*args, **kwargs) except Exception as e: return f"ERROR: {e}" +# ** Metadata for a track_rewrites scope + +class GraphRewriteMetadata(TypedDict): + loc: tuple[str, int] # [path, lineno] calling graph_rewrite + match_count: int # total match count in this context + code_line: str # source code calling graph_rewrite + kernel_code: str|None # optionally render the final kernel code + +@functools.lru_cache(None) +def _prg(k:Kernel): return k.to_program().src +def to_metadata(k:Any, v:TrackedGraphRewrite) -> GraphRewriteMetadata: + return {"loc":v.loc, "match_count":len(v.matches), "code_line":lines(v.loc[0])[v.loc[1]-1].strip(), + "kernel_code":pcall(_prg, k) if isinstance(k, Kernel) else None} +def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]: + return [(to_function_name(k.name) if isinstance(k, Kernel) else str(k), [to_metadata(k, v) for v in vals]) for k,vals in zip(keys, contexts)] + +# ** Complete rewrite details for a graph_rewrite call + +class GraphRewriteDetails(TypedDict): + graph: dict # JSON serialized UOp for this rewrite step + uop: str # strigified UOp for this rewrite step + diff: list[str]|None # string diff of the single UOp that changed + changed_nodes: list[int]|None # the changed UOp id + all its parents ids + upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat + def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]: assert isinstance(x, UOp) # NOTE: this is [id, [label, src_ids, color]] @@ -61,14 +73,6 @@ def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]: graph[id(u)] = (label, [id(x) for x in u.src if x not in excluded], uops_colors.get(u.op, "#ffffff")) return graph -@functools.lru_cache(None) -def _prg(k:Kernel): return k.to_program().src -def to_metadata(k:Any,v:TrackedGraphRewrite) -> GraphRewriteMetadata: - return {"loc":v.loc, "match_count":len(v.matches), "code_line":lines(v.loc[0])[v.loc[1]-1].strip(), - "kernel_code":pcall(_prg, k) if isinstance(k, Kernel) else None} -def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]: - return [(to_function_name(k.name) if isinstance(k, Kernel) else str(k), [to_metadata(k, v) for v in vals]) for k,vals in zip(keys, contexts)] - def get_details(k:Any, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]: yield {"graph": (sink_json:=uop_to_json(sink:=ctx.sink)), "uop":str(sink), "changed_nodes":None, "diff":None, "upat":None} replaces: dict[UOp, UOp] = {} @@ -181,10 +185,6 @@ if __name__ == "__main__": # NOTE: this context is a tuple of list[keys] and list[values] kernels = get_metadata(*contexts) if contexts is not None else [] - if getenv("FUZZ_VIZ"): - ret = [get_details(contexts[0][i], contexts[1][i][j]) for i,v in tqdm(enumerate(kernels)) for j,args in enumerate(v[1])] - print(f"fuzzed {len(ret)} rewrite details") - perfetto_profile = to_perfetto(profile) if profile is not None else None server = HTTPServer(('', PORT), Handler)