mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
@@ -18,24 +18,36 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB"
|
||||
|
||||
# VIZ API
|
||||
|
||||
class GraphRewriteMetadata(TypedDict):
|
||||
loc: tuple[str, int] # [path, lineno] calling graph_rewrite
|
||||
match_count: int # total match count in this context
|
||||
code_line: str # source code calling graph_rewrite
|
||||
kernel_code: str|None # optionally render the final kernel code
|
||||
|
||||
class GraphRewriteDetails(TypedDict):
|
||||
graph: dict # JSON serialized UOp for this rewrite step
|
||||
uop: str # strigified UOp for this rewrite step
|
||||
diff: list[str]|None # string diff of the single UOp that changed
|
||||
changed_nodes: list[int]|None # the changed UOp id + all its parents ids
|
||||
upat: tuple[tuple[str, int], str]|None
|
||||
|
||||
# NOTE: if any extra rendering in VIZ fails, we don't crash
|
||||
def pcall(fxn:Callable[..., str], *args, **kwargs) -> str:
|
||||
try: return fxn(*args, **kwargs)
|
||||
except Exception as e: return f"ERROR: {e}"
|
||||
|
||||
# ** Metadata for a track_rewrites scope
|
||||
|
||||
class GraphRewriteMetadata(TypedDict):
|
||||
loc: tuple[str, int] # [path, lineno] calling graph_rewrite
|
||||
match_count: int # total match count in this context
|
||||
code_line: str # source code calling graph_rewrite
|
||||
kernel_code: str|None # optionally render the final kernel code
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _prg(k:Kernel): return k.to_program().src
|
||||
def to_metadata(k:Any, v:TrackedGraphRewrite) -> GraphRewriteMetadata:
|
||||
return {"loc":v.loc, "match_count":len(v.matches), "code_line":lines(v.loc[0])[v.loc[1]-1].strip(),
|
||||
"kernel_code":pcall(_prg, k) if isinstance(k, Kernel) else None}
|
||||
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]:
|
||||
return [(to_function_name(k.name) if isinstance(k, Kernel) else str(k), [to_metadata(k, v) for v in vals]) for k,vals in zip(keys, contexts)]
|
||||
|
||||
# ** Complete rewrite details for a graph_rewrite call
|
||||
|
||||
class GraphRewriteDetails(TypedDict):
|
||||
graph: dict # JSON serialized UOp for this rewrite step
|
||||
uop: str # strigified UOp for this rewrite step
|
||||
diff: list[str]|None # string diff of the single UOp that changed
|
||||
changed_nodes: list[int]|None # the changed UOp id + all its parents ids
|
||||
upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
|
||||
|
||||
def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]:
|
||||
assert isinstance(x, UOp)
|
||||
# NOTE: this is [id, [label, src_ids, color]]
|
||||
@@ -61,14 +73,6 @@ def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]:
|
||||
graph[id(u)] = (label, [id(x) for x in u.src if x not in excluded], uops_colors.get(u.op, "#ffffff"))
|
||||
return graph
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _prg(k:Kernel): return k.to_program().src
|
||||
def to_metadata(k:Any,v:TrackedGraphRewrite) -> GraphRewriteMetadata:
|
||||
return {"loc":v.loc, "match_count":len(v.matches), "code_line":lines(v.loc[0])[v.loc[1]-1].strip(),
|
||||
"kernel_code":pcall(_prg, k) if isinstance(k, Kernel) else None}
|
||||
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]:
|
||||
return [(to_function_name(k.name) if isinstance(k, Kernel) else str(k), [to_metadata(k, v) for v in vals]) for k,vals in zip(keys, contexts)]
|
||||
|
||||
def get_details(k:Any, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
|
||||
yield {"graph": (sink_json:=uop_to_json(sink:=ctx.sink)), "uop":str(sink), "changed_nodes":None, "diff":None, "upat":None}
|
||||
replaces: dict[UOp, UOp] = {}
|
||||
@@ -181,10 +185,6 @@ if __name__ == "__main__":
|
||||
# NOTE: this context is a tuple of list[keys] and list[values]
|
||||
kernels = get_metadata(*contexts) if contexts is not None else []
|
||||
|
||||
if getenv("FUZZ_VIZ"):
|
||||
ret = [get_details(contexts[0][i], contexts[1][i][j]) for i,v in tqdm(enumerate(kernels)) for j,args in enumerate(v[1])]
|
||||
print(f"fuzzed {len(ret)} rewrite details")
|
||||
|
||||
perfetto_profile = to_perfetto(profile) if profile is not None else None
|
||||
|
||||
server = HTTPServer(('', PORT), Handler)
|
||||
|
||||
Reference in New Issue
Block a user