From dae36150087f23145db0cb62ae09086fba3aeba8 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 16 Sep 2024 16:00:27 +0800 Subject: [PATCH] replace viz graph when it's sink (#6541) --- viz/serve.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/viz/serve.py b/viz/serve.py index 71bd7881a8..052566c137 100755 --- a/viz/serve.py +++ b/viz/serve.py @@ -3,7 +3,7 @@ from dataclasses import asdict, dataclass from typing import Dict, List, Tuple import pickle, re, os, sys, time, threading, webbrowser, json, difflib from tinygrad.helpers import getenv -from tinygrad.ops import TrackedRewriteContext, UOp +from tinygrad.ops import TrackedRewriteContext, UOp, UOps from tinygrad.engine.graph import uops_colors, word_wrap from http.server import HTTPServer, BaseHTTPRequestHandler @@ -43,7 +43,10 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet: extra: List[List[str]] = [[str(ctx.sink)]] for (first, rewritten, pattern) in ctx.rewrites: diffs.append((pattern, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines())))) - uops.append(new_sink:=replace_uop(uops[-1], first, rewritten, {})) + # if the sink was replaced, we have to replace the entire graph, otherwise just replace the parent + new_sink = rewritten if first.op is UOps.SINK else replace_uop(uops[-1], first, rewritten, {}) + assert new_sink.op is UOps.SINK + uops.append(new_sink) extra.append([str(new_sink)]) return UOpRet(ctx.loc, list(map(uop_to_json, uops)), diffs, extra)