mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
replace viz graph when it's sink (#6541)
This commit is contained in:
@@ -3,7 +3,7 @@ from dataclasses import asdict, dataclass
|
||||
from typing import Dict, List, Tuple
|
||||
import pickle, re, os, sys, time, threading, webbrowser, json, difflib
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp, UOps
|
||||
from tinygrad.engine.graph import uops_colors, word_wrap
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
|
||||
@@ -43,7 +43,10 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
|
||||
extra: List[List[str]] = [[str(ctx.sink)]]
|
||||
for (first, rewritten, pattern) in ctx.rewrites:
|
||||
diffs.append((pattern, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
|
||||
uops.append(new_sink:=replace_uop(uops[-1], first, rewritten, {}))
|
||||
# if the sink was replaced, we have to replace the entire graph, otherwise just replace the parent
|
||||
new_sink = rewritten if first.op is UOps.SINK else replace_uop(uops[-1], first, rewritten, {})
|
||||
assert new_sink.op is UOps.SINK
|
||||
uops.append(new_sink)
|
||||
extra.append([str(new_sink)])
|
||||
return UOpRet(ctx.loc, list(map(uop_to_json, uops)), diffs, extra)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user