From 5a30a32af88521f8eb99dc5a980cd24256300371 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 17 Sep 2024 14:37:53 +0800 Subject: [PATCH] small viz fixups from the swizzle pads branch [run_process_replay] (#6557) * small viz fixups from the swizzle pads branch [run_process_replay] * handle indexed ones --- tinygrad/ops.py | 3 ++- viz/serve.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 122399a2f4..c76165a1a2 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -632,7 +632,8 @@ def get_location() -> Tuple[str, int]: while (frm.f_code.co_filename.split('/')[-1] in {"ops.py", ''}) and frm.f_back is not None: frm = frm.f_back return frm.f_code.co_filename, frm.f_lineno @functools.lru_cache(None) -def lines(fn) -> List[str]: return open(fn).readlines() +def lines(fn) -> List[str]: + with open(fn) as f: return f.readlines() class UPat(MathTrait): __slots__ = ["op", "dtype", "arg", "name", "src"] diff --git a/viz/serve.py b/viz/serve.py index 052566c137..6b8deba387 100755 --- a/viz/serve.py +++ b/viz/serve.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 from dataclasses import asdict, dataclass from typing import Dict, List, Tuple -import pickle, re, os, sys, time, threading, webbrowser, json, difflib +import pickle, re, os, sys, time, threading, webbrowser, json, difflib, contextlib from tinygrad.helpers import getenv from tinygrad.ops import TrackedRewriteContext, UOp, UOps from tinygrad.engine.graph import uops_colors, word_wrap @@ -21,6 +21,9 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: graph: Dict[int, Tuple[str, str, List[int], str, str]] = {} for u in x.sparents: label = f"{str(u.op)[5:]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}" + if getenv("WITH_SHAPE"): + with contextlib.suppress(Exception): # if the UOp is indexed already it's fine + if u.st is not None: label += f"\n{u.st.shape}" graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src], str(u.arg), uops_colors.get(u.op, "#ffffff")) return graph