replace viz graph when it's sink (#6541)

This commit is contained in:
qazal
2024-09-16 16:00:27 +08:00
committed by GitHub
parent 2a5a53c3db
commit dae3615008

View File

@@ -3,7 +3,7 @@ from dataclasses import asdict, dataclass
from typing import Dict, List, Tuple
import pickle, re, os, sys, time, threading, webbrowser, json, difflib
from tinygrad.helpers import getenv
from tinygrad.ops import TrackedRewriteContext, UOp
from tinygrad.ops import TrackedRewriteContext, UOp, UOps
from tinygrad.engine.graph import uops_colors, word_wrap
from http.server import HTTPServer, BaseHTTPRequestHandler
@@ -43,7 +43,10 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
extra: List[List[str]] = [[str(ctx.sink)]]
for (first, rewritten, pattern) in ctx.rewrites:
diffs.append((pattern, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
uops.append(new_sink:=replace_uop(uops[-1], first, rewritten, {}))
# if the sink was replaced, we have to replace the entire graph, otherwise just replace the parent
new_sink = rewritten if first.op is UOps.SINK else replace_uop(uops[-1], first, rewritten, {})
assert new_sink.op is UOps.SINK
uops.append(new_sink)
extra.append([str(new_sink)])
return UOpRet(ctx.loc, list(map(uop_to_json, uops)), diffs, extra)