diff --git a/viz/serve.py b/viz/serve.py index 71bd7881a8..052566c137 100755 --- a/viz/serve.py +++ b/viz/serve.py @@ -3,7 +3,7 @@ from dataclasses import asdict, dataclass from typing import Dict, List, Tuple import pickle, re, os, sys, time, threading, webbrowser, json, difflib from tinygrad.helpers import getenv -from tinygrad.ops import TrackedRewriteContext, UOp +from tinygrad.ops import TrackedRewriteContext, UOp, UOps from tinygrad.engine.graph import uops_colors, word_wrap from http.server import HTTPServer, BaseHTTPRequestHandler @@ -43,7 +43,10 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet: extra: List[List[str]] = [[str(ctx.sink)]] for (first, rewritten, pattern) in ctx.rewrites: diffs.append((pattern, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines())))) - uops.append(new_sink:=replace_uop(uops[-1], first, rewritten, {})) + # if the sink was replaced, we have to replace the entire graph, otherwise just replace the parent + new_sink = rewritten if first.op is UOps.SINK else replace_uop(uops[-1], first, rewritten, {}) + assert new_sink.op is UOps.SINK + uops.append(new_sink) extra.append([str(new_sink)]) return UOpRet(ctx.loc, list(map(uop_to_json, uops)), diffs, extra)