From d7d32d82eebbc0e4209b66ea9f4894358aef33e2 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 6 May 2026 21:39:34 +0300 Subject: [PATCH] viz/cli: print first uop with DEBUG=6 (#16065) * viz/cli: print first uop with DEBUG=6 * rename fmt to emit * define inst --- tinygrad/viz/cli.py | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/tinygrad/viz/cli.py b/tinygrad/viz/cli.py index a7421eca24..334798431c 100755 --- a/tinygrad/viz/cli.py +++ b/tinygrad/viz/cli.py @@ -64,21 +64,22 @@ def get(data:dict, key:str): def main(args) -> None: viz.load_rewrites(viz_data:=viz.VizData(viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {})))) - def fmt(val, to_str=str) -> str: return json.dumps(val if isinstance(val, dict) else {"value":val}) if args.json else to_str(val) + def emit(val, to_str=str) -> str: return json.dumps(val if isinstance(val, dict) else {"value":val}) if args.json else to_str(val) - def print_step(step:dict) -> None: + def print_step(step:dict, reconstruct_matches=False) -> None: data = viz.get_render(viz_data, step["query"]) if isinstance(data.get("value"), Iterator): for m in data["value"]: - if m.get("uop"): print(fmt(m["uop"])) + if m.get("uop"): print(emit(m["uop"])) + if not reconstruct_matches: return None if m.get("diff"): loc = pathlib.Path(m["upat"][0][0]) - print(fmt(f"{loc.parent.name}/{loc.name}:{m['upat'][0][1]}\n{m['upat'][1]}")) - for line in m["diff"]: print(fmt(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))) - if data.get("src") is not None: print(fmt(data["src"])) + print(emit(f"{loc.parent.name}/{loc.name}:{m['upat'][0][1]}\n{m['upat'][1]}")) + for line in m["diff"]: print(emit(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))) + if data.get("src") is not None: print(emit(data["src"])) - events:list = viz.load_pickle(args.profile_path, default=[]) - if (profile_bytes:=viz.get_profile(viz_data, events)) is None: raise RuntimeError(f"empty profile in {args.profile_path}") + profile_bytes = viz.get_profile(viz_data, viz.load_pickle(args.profile_path, default=[])) + if profile_bytes is None: raise RuntimeError(f"empty profile in {args.profile_path}") profile = decode_profile(profile_bytes) profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz_data.ctxs if c["name"].startswith("SQTT") for s in c["steps"] if s["name"].endswith(("PMC", "PKTS"))]) @@ -103,7 +104,7 @@ def main(args) -> None: op_name, ret, info = e.name.display_name, json.loads(e.name.ret[4:]) if e.name.ret else {}, "" color = next((v for k,v in viz.wave_colors.items() if k in op_name), None) op_str = hex_colored(op_name, color) if color and not NO_COLOR else op_name - phase, delay = None, 0 + inst, phase, delay = None, None, 0 idx = next(pkt_idxs.setdefault(e.device, itertools.count())) if e.device.startswith("WAVE"): inst = f"0x{pc:05x} {pc_map[pc]}" if (pc:=ret.get("pc")) is not None else f"{'':7} {op_name}" @@ -115,7 +116,7 @@ def main(args) -> None: if inst and phase: info = f"{phase:<8} {inst}" unit = e.device.replace(" ", "-") row = {"clk":int(e.st)-inst_st, "cycle":int(e.st), "unit":unit, "op":op_name, "dur":int(unwrap(e.en)-e.st), "delay":delay or "", "info":info} - print(fmt(row, lambda _: f"{row['clk']:<12} {unit:<20} {op_str}{' '*(22-ansilen(op_str))} {row['dur']:<4} {str(row['delay']):<4} {info}")) + print(emit(row, lambda _: f"{row['clk']:<12} {unit:<20} {op_str}{' '*(22-ansilen(op_str))} {row['dur']:<4} {str(row['delay']):<4} {info}")) # ** PMC printer elif "PMC" in args.src: @@ -123,15 +124,15 @@ def main(args) -> None: pmc_fmt:list[str] = [] for name,val,*detail in pmc["rows"]: pmc_fmt += [f"{name} {val}"]+([" ".join(f"{k}={v}" for k,v in zip(detail[0]["cols"], r)) for r in detail[0]["rows"]] if detail else []) - print(fmt(pmc, lambda _: "\n".join(pmc_fmt))) + print(emit(pmc, lambda _: "\n".join(pmc_fmt))) # ** Memory printer elif data is not None and data["event_type"] == 1: - print(fmt({"peak":data["peak"]}, lambda _: f"Peak: {data['peak']}"+"\n"+f"{'TS':<10} {'Event':<6} {'Key':>8} Info")) + print(emit({"peak":data["peak"]}, lambda _: f"Peak: {data['peak']}"+"\n"+f"{'TS':<10} {'Event':<6} {'Key':>8} Info")) for e in data["events"]: info = str(arg:=e.pop("arg", {})) if e["event"] == "free": info = ', '.join([f"{fmt_colored(k)} {['read','write','write+read'][m]}@data{n}" for _,k,n,m in arg["users"]]) - print(fmt({**e, "info":info}, lambda _: f"{e['ts']:<10} {e['event']:<6} {e.get('key', ''):>8} {info}")) + print(emit({**e, "info":info}, lambda _: f"{e['ts']:<10} {e['event']:<6} {e.get('key', ''):>8} {info}")) # ** Profiler printer else: @@ -174,8 +175,8 @@ def main(args) -> None: line = f"{file.split('/')[-1]}:{lineno} {fxn}" if fmt: ext.append(f"{line} {code}") elif not file.startswith("<") and not fxn.startswith("<"): fmt["loc"] = line - yield {"device":dev, "name":fmt_colored(e["name"]), "dur_ms":e["dur"]*1e-3, - "st_ms":e["st"]*1e-3, "fmt":fmt, "ref":e["ref"], "ext":"\n".join(ext)} + yield {"device":dev, "name":fmt_colored(e["name"]), "dur_ms":e["dur"]*1e-3, "st_ms":e["st"]*1e-3, "fmt":fmt, "ref":e["ref"], + "ext":"\n".join(ext)} def fmt_top(k:dict) -> str: return f"{fmt_colored(k['name'])}{' ' * max(0, 38-ansilen(k['name']))} {time_to_str(k['dur_ms']*1e-3, w=9)} {k['count']:7d} {k['pct']:6.2f}%"+\ (" "*4+fmt_data(k['fmt']) if k['fmt'] else "") @@ -187,15 +188,16 @@ def main(args) -> None: fmt_row = fmt_top if args.top else fmt_all seen_refs:set[int] = set() def render_event(k:dict, ls=args.list) -> None: - print(fmt(k, to_str=fmt_row)) + print(emit(k, to_str=fmt_row)) if k["ref"] is not None and k["ref"] not in seen_refs: seen_refs.add(k["ref"]) for s in viz_data.ctxs[k["ref"]]["steps"]: if DEBUG >= 3 and s["name"] == "View Base AST": print_step(s) if DEBUG >= 4 and s["name"] == "View Source": print_step(s) - if DEBUG >= 5 or ls: print(fmt(" "*s["depth"]+s["name"]+(f" - {s['match_count']}" if s.get('match_count', 0) else ''))) - if DEBUG >= 6 or args.item and len(args.item) > 1 and s["name"] == args.item[1]: print_step(s) - elif DEBUG >= 3 and k.get("ext"): print(fmt(k["ext"])) + if DEBUG >= 5 or ls: print(emit(" "*s["depth"]+s["name"]+(f" - {s['match_count']}" if s.get('match_count', 0) else ''))) + if DEBUG >= 6: print_step(s) + if DEBUG >= 7 or (args.item and len(args.item) > 1 and s["name"] == args.item[1]): print_step(s, reconstruct_matches=True) + elif DEBUG >= 3 and k.get("ext"): print(emit(k["ext"])) produce = produce_top_kernels if args.top else produce_all_kernels if args.item: if len(args.item) > 2: raise RuntimeError(f"-i takes at most 2 names (got {args.item})")