viz/server cleanups [pr] (#9127)

* viz/server cleanups [pr]

* space
This commit is contained in:
qazal
2025-02-17 09:59:41 +02:00
committed by GitHub
parent a38b47e026
commit fe260ac4d7

View File

@@ -18,24 +18,36 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB"
# VIZ API
class GraphRewriteMetadata(TypedDict):
loc: tuple[str, int] # [path, lineno] calling graph_rewrite
match_count: int # total match count in this context
code_line: str # source code calling graph_rewrite
kernel_code: str|None # optionally render the final kernel code
class GraphRewriteDetails(TypedDict):
graph: dict # JSON serialized UOp for this rewrite step
uop: str # strigified UOp for this rewrite step
diff: list[str]|None # string diff of the single UOp that changed
changed_nodes: list[int]|None # the changed UOp id + all its parents ids
upat: tuple[tuple[str, int], str]|None
# NOTE: if any extra rendering in VIZ fails, we don't crash
def pcall(fxn:Callable[..., str], *args, **kwargs) -> str:
try: return fxn(*args, **kwargs)
except Exception as e: return f"ERROR: {e}"
# ** Metadata for a track_rewrites scope
class GraphRewriteMetadata(TypedDict):
loc: tuple[str, int] # [path, lineno] calling graph_rewrite
match_count: int # total match count in this context
code_line: str # source code calling graph_rewrite
kernel_code: str|None # optionally render the final kernel code
@functools.lru_cache(None)
def _prg(k:Kernel): return k.to_program().src
def to_metadata(k:Any, v:TrackedGraphRewrite) -> GraphRewriteMetadata:
return {"loc":v.loc, "match_count":len(v.matches), "code_line":lines(v.loc[0])[v.loc[1]-1].strip(),
"kernel_code":pcall(_prg, k) if isinstance(k, Kernel) else None}
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]:
return [(to_function_name(k.name) if isinstance(k, Kernel) else str(k), [to_metadata(k, v) for v in vals]) for k,vals in zip(keys, contexts)]
# ** Complete rewrite details for a graph_rewrite call
class GraphRewriteDetails(TypedDict):
graph: dict # JSON serialized UOp for this rewrite step
uop: str # strigified UOp for this rewrite step
diff: list[str]|None # string diff of the single UOp that changed
changed_nodes: list[int]|None # the changed UOp id + all its parents ids
upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]:
assert isinstance(x, UOp)
# NOTE: this is [id, [label, src_ids, color]]
@@ -61,14 +73,6 @@ def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]:
graph[id(u)] = (label, [id(x) for x in u.src if x not in excluded], uops_colors.get(u.op, "#ffffff"))
return graph
@functools.lru_cache(None)
def _prg(k:Kernel): return k.to_program().src
def to_metadata(k:Any,v:TrackedGraphRewrite) -> GraphRewriteMetadata:
return {"loc":v.loc, "match_count":len(v.matches), "code_line":lines(v.loc[0])[v.loc[1]-1].strip(),
"kernel_code":pcall(_prg, k) if isinstance(k, Kernel) else None}
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]:
return [(to_function_name(k.name) if isinstance(k, Kernel) else str(k), [to_metadata(k, v) for v in vals]) for k,vals in zip(keys, contexts)]
def get_details(k:Any, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
yield {"graph": (sink_json:=uop_to_json(sink:=ctx.sink)), "uop":str(sink), "changed_nodes":None, "diff":None, "upat":None}
replaces: dict[UOp, UOp] = {}
@@ -181,10 +185,6 @@ if __name__ == "__main__":
# NOTE: this context is a tuple of list[keys] and list[values]
kernels = get_metadata(*contexts) if contexts is not None else []
if getenv("FUZZ_VIZ"):
ret = [get_details(contexts[0][i], contexts[1][i][j]) for i,v in tqdm(enumerate(kernels)) for j,args in enumerate(v[1])]
print(f"fuzzed {len(ret)} rewrite details")
perfetto_profile = to_perfetto(profile) if profile is not None else None
server = HTTPServer(('', PORT), Handler)