more generic naming in VIZ [pr] (#10695)

* note

* rename kernel to ctx

* rename uop things to currentStep + expandSteps

* already destructured

* some things that were called ctx are steps

* still a kernel
This commit is contained in:
qazal
2025-06-08 15:37:39 +03:00
committed by GitHub
parent c70486908e
commit 1ad8062591
4 changed files with 61 additions and 56 deletions

View File

@@ -82,10 +82,10 @@
position: relative;
height: 100%;
}
.metadata > * + *, .rewrite-container > * + *, .kernel-list > * + * {
.metadata > * + *, .rewrite-container > * + *, .ctx-list > * + * {
margin-top: 12px;
}
.kernel-list > ul > * + * {
.ctx-list > ul > * + * {
margin-top: 4px;
}
.graph {
@@ -93,12 +93,12 @@
inset: 0;
z-index: 1;
}
.kernel-list-parent {
.ctx-list-parent {
width: 15%;
padding-top: 50px;
border-right: 1px solid #4a4b56;
}
.kernel-list {
.ctx-list {
width: 100%;
height: 100%;
overflow-y: auto;
@@ -193,7 +193,7 @@
</svg>
</button>
</div>
<div class="container kernel-list-parent"><div class="kernel-list"></div></div>
<div class="container ctx-list-parent"><div class="ctx-list"></div></div>
<div class="graph">
<div class="progress-message">Rendering new layout...</div>
<svg id="graph-svg" preserveAspectRatio="xMidYMid meet">

View File

@@ -28,7 +28,7 @@ async function renderDag(graph, additions, recenter=false) {
if (timeout != null) clearTimeout(timeout);
const progressMessage = document.querySelector(".progress-message");
timeout = setTimeout(() => {progressMessage.style.display = "block"}, 2000);
worker.postMessage({graph, additions, kernels});
worker.postMessage({graph, additions, ctxs});
worker.onmessage = (e) => {
progressMessage.style.display = "none";
clearTimeout(timeout);
@@ -38,7 +38,7 @@ async function renderDag(graph, additions, recenter=false) {
const STROKE_WIDTH = 1.4;
const nodes = d3.select("#nodes").selectAll("g").data(g.nodes().map(id => g.node(id)), d => d).join("g")
.attr("transform", d => `translate(${d.x},${d.y})`).classed("clickable", d => d.ref != null)
.on("click", (_,d) => d.ref != null && setState({ expandKernel: true, currentKernel:d.ref, currentUOp:0, currentRewrite:0 }));
.on("click", (_,d) => d.ref != null && setState({ expandSteps: true, currentCtx:d.ref, currentStep:0, currentRewrite:0 }));
nodes.selectAll("rect").data(d => [d]).join("rect").attr("width", d => d.width).attr("height", d => d.height).attr("fill", d => d.color)
.attr("x", d => -d.width/2).attr("y", d => -d.height/2).attr("style", d => d.style ?? `stroke:#4a4b57; stroke-width:${STROKE_WIDTH}px;`);
nodes.selectAll("g.label").data(d => [d]).join("g").attr("class", "label").attr("transform", d => {
@@ -218,7 +218,7 @@ document.getElementById("zoom-to-fit-btn").addEventListener("click", () => {
const svg = d3.select("#graph-svg");
svg.call(zoom.transform, d3.zoomIdentity);
const mainRect = rect(".main-container");
const x0 = rect(".kernel-list-parent").right;
const x0 = rect(".ctx-list-parent").right;
const x1 = rect(".metadata-parent").left;
const pad = 16;
const R = { x: x0+pad, y: mainRect.top+pad, width: (x1>0 ? x1-x0 : mainRect.width)-2*pad, height: mainRect.height-2*pad };
@@ -265,25 +265,29 @@ hljs.registerLanguage("cpp", (hljs) => ({
var ret = [];
var cache = {};
var kernels = null;
var ctxs = null;
const evtSources = [];
const state = {currentKernel:-1, currentUOp:0, currentRewrite:0, expandKernel:false};
// VIZ displays graph rewrites in 3 levels, from bottom-up:
// rewrite: a single UOp transformation
// step: collection of rewrites
// context: collection of steps
const state = {currentCtx:-1, currentStep:0, currentRewrite:0, expandSteps:false};
function setState(ns) {
Object.assign(state, ns);
main();
}
async function main() {
const { currentKernel, currentUOp, currentRewrite, expandKernel } = state;
// ** left sidebar kernel list
if (kernels == null) {
kernels = await (await fetch("/kernels")).json();
setState({ currentKernel:-1 });
const { currentCtx, currentStep, currentRewrite, expandSteps } = state;
// ** left sidebar context list
if (ctxs == null) {
ctxs = await (await fetch("/ctxs")).json();
setState({ currentCtx:-1 });
}
const kernelList = document.querySelector(".kernel-list");
kernelList.innerHTML = "";
for (const [i,{name, steps}] of kernels.entries()) {
const ul = kernelList.appendChild(document.createElement("ul"));
if (i === currentKernel) {
const ctxList = document.querySelector(".ctx-list");
ctxList.innerHTML = "";
for (const [i,{name, steps}] of ctxs.entries()) {
const ul = ctxList.appendChild(document.createElement("ul"));
if (i === currentCtx) {
ul.className = "active";
requestAnimationFrame(() => ul.scrollIntoView({ behavior: "auto", block: "nearest" }));
}
@@ -293,27 +297,28 @@ async function main() {
return `<span style="${`color: color-mix(in srgb, ${colors[(parseInt(code)-30+60)%60]} 60%, white)`}">${st}</span>`;
});
p.onclick = () => {
setState(i === currentKernel ? { expandKernel:!expandKernel } : { expandKernel:true, currentKernel:i, currentUOp:0, currentRewrite:0 });
setState(i === currentCtx ? { expandSteps:!expandSteps } : { expandSteps:true, currentCtx:i, currentStep:0, currentRewrite:0 });
}
for (const [j,u] of steps.entries()) {
const inner = ul.appendChild(document.createElement("ul"));
if (i === currentKernel && j === currentUOp) {
if (i === currentCtx && j === currentStep) {
inner.className = "active";
requestAnimationFrame(() => inner.scrollIntoView({ behavior: "auto", block: "nearest" }));
}
inner.innerText = `${u.name ?? u.loc[0].replaceAll("\\", "/").split("/").pop()+':'+u.loc[1]} - ${u.match_count}`;
inner.style.marginLeft = `${8*u.depth}px`;
inner.style.display = i === currentKernel && expandKernel ? "block" : "none";
inner.style.display = i === currentCtx && expandSteps ? "block" : "none";
inner.onclick = (e) => {
e.stopPropagation();
setState({ currentUOp:j, currentKernel:i, currentRewrite:0 });
setState({ currentStep:j, currentCtx:i, currentRewrite:0 });
}
}
}
// ** center graph
if (currentKernel == -1) return;
const kernel = kernels[currentKernel].steps[currentUOp];
const cacheKey = `kernel=${currentKernel}&idx=${currentUOp}`;
if (currentCtx == -1) return;
const ctx = ctxs[currentCtx];
const step = ctx.steps[currentStep];
const cacheKey = `ctx=${currentCtx}&idx=${currentStep}`;
// close any pending event sources
let activeSrc = null;
for (const e of evtSources) {
@@ -323,11 +328,11 @@ async function main() {
if (cacheKey in cache) {
ret = cache[cacheKey];
}
// if we don't have a complete cache yet we start streaming this kernel
if (!(cacheKey in cache) || (cache[cacheKey].length !== kernel.match_count+1 && activeSrc == null)) {
// if we don't have a complete cache yet we start streaming rewrites in this step
if (!(cacheKey in cache) || (cache[cacheKey].length !== step.match_count+1 && activeSrc == null)) {
ret = [];
cache[cacheKey] = ret;
const eventSource = new EventSource(`/kernels?kernel=${currentKernel}&idx=${currentUOp}`);
const eventSource = new EventSource(`/ctxs?ctx=${currentCtx}&idx=${currentStep}`);
evtSources.push(eventSource);
eventSource.onmessage = (e) => {
if (e.data === "END") return eventSource.close();
@@ -341,20 +346,20 @@ async function main() {
};
}
if (ret.length === 0) return;
if (kernel.name == "View Memory Graph") {
if (ctx.name == "View Memory Graph") {
renderMemoryGraph(ret[currentRewrite].graph);
} else {
renderDag(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes || [], recenter=currentRewrite === 0);
}
// ** right sidebar code blocks
const metadata = document.querySelector(".metadata");
const [code, lang] = kernels[currentKernel].kernel_code != null ? [kernels[currentKernel].kernel_code, "cpp"] : [ret[currentRewrite].uop, "python"];
metadata.replaceChildren(codeBlock(kernel.code_line, "python", { loc:kernel.loc, wrap:true }), codeBlock(code, lang, { wrap:false }));
const [code, lang] = ctx.kernel_code != null ? [ctx.kernel_code, "cpp"] : [ret[currentRewrite].uop, "python"];
metadata.replaceChildren(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true }), codeBlock(code, lang, { wrap:false }));
// ** rewrite steps
if (kernel.match_count >= 1) {
if (step.match_count >= 1) {
const rewriteList = metadata.appendChild(document.createElement("div"));
rewriteList.className = "rewrite-list";
for (let s=0; s<=kernel.match_count; s++) {
for (let s=0; s<=step.match_count; s++) {
const ul = rewriteList.appendChild(document.createElement("ul"));
ul.innerText = s;
ul.id = `rewrite-${s}`;
@@ -408,36 +413,36 @@ function appendResizer(element, { minWidth, maxWidth }, left=false) {
}, { once: true });
});
}
appendResizer(document.querySelector(".kernel-list-parent"), { minWidth: 15, maxWidth: 50 }, left=true);
appendResizer(document.querySelector(".ctx-list-parent"), { minWidth: 15, maxWidth: 50 }, left=true);
appendResizer(document.querySelector(".metadata-parent"), { minWidth: 20, maxWidth: 50 });
// **** keyboard shortcuts
document.addEventListener("keydown", async function(event) {
const { currentKernel, currentUOp, currentRewrite, expandKernel } = state;
// up and down change the UOp or kernel from the list
const { currentCtx, currentStep, currentRewrite, expandSteps } = state;
// up and down change the step or context from the list
if (event.key == "ArrowUp") {
event.preventDefault();
if (expandKernel) {
return setState({ currentRewrite:0, currentUOp:Math.max(0, currentUOp-1) });
if (expandSteps) {
return setState({ currentRewrite:0, currentStep:Math.max(0, currentStep-1) });
}
return setState({ currentUOp:0, currentRewrite:0, currentKernel:Math.max(0, currentKernel-1) });
return setState({ currentStep:0, currentRewrite:0, currentCtx:Math.max(0, currentCtx-1) });
}
if (event.key == "ArrowDown") {
event.preventDefault();
if (expandKernel) {
const totalUOps = kernels[currentKernel].steps.length-1;
return setState({ currentRewrite:0, currentUOp:Math.min(totalUOps, currentUOp+1) });
if (expandSteps) {
const totalUOps = ctxs[currentCtx].steps.length-1;
return setState({ currentRewrite:0, currentStep:Math.min(totalUOps, currentStep+1) });
}
return setState({ currentUOp:0, currentRewrite:0, currentKernel:Math.min(kernels.length-1, currentKernel+1) });
return setState({ currentStep:0, currentRewrite:0, currentCtx:Math.min(ctxs.length-1, currentCtx+1) });
}
// enter toggles focus on a single rewrite stage
if (event.key == "Enter") {
event.preventDefault()
if (state.currentKernel === -1) {
return setState({ currentKernel:0, expandKernel:true });
if (currentCtx === -1) {
return setState({ currentCtx:0, expandSteps:true });
}
return setState({ currentUOp:0, currentRewrite:0, expandKernel:!expandKernel });
return setState({ currentStep:0, currentRewrite:0, expandSteps:!expandSteps });
}
// left and right go through rewrites in a single UOp
if (event.key == "ArrowLeft") {

View File

@@ -7,13 +7,13 @@ ctx.font = `${LINE_HEIGHT}px sans-serif`;
const ansiStrip = (st, tag) => st.replace(/\u001b\[(\d+)m(.*?)\u001b\[0m/g, (_,__,st) => st);
onmessage = (e) => {
const { graph, additions, kernels } = e.data;
const { graph, additions, ctxs } = e.data;
const g = new dagre.graphlib.Graph({ compound: true });
g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
if (additions.length !== 0) g.setNode("addition", {label:"", style:"fill: rgba(26, 27, 38, 0.5);", padding:0});
for (let [k, {label, src, ref, ...rest }] of Object.entries(graph)) {
const idx = ref ? kernels.findIndex(k => k.ref === ref) : -1;
if (idx != -1) label += `\ncodegen@${ansiStrip(kernels[idx].name)}`;
const idx = ref ? ctxs.findIndex(k => k.ref === ref) : -1;
if (idx != -1) label += `\ncodegen@${ansiStrip(ctxs[idx].name)}`;
// adjust node dims by label size + add padding
let [width, height] = [0, 0];
for (line of label.split("\n")) {

View File

@@ -133,9 +133,9 @@ class Handler(BaseHTTPRequestHandler):
if url.path.endswith(".js"): content_type = "application/javascript"
if url.path.endswith(".css"): content_type = "text/css"
except FileNotFoundError: status_code = 404
elif url.path == "/kernels":
if "kernel" in (query:=parse_qs(url.query)):
kidx, ridx = int(query["kernel"][0]), int(query["idx"][0])
elif url.path == "/ctxs":
if "ctx" in (query:=parse_qs(url.query)):
kidx, ridx = int(query["ctx"][0]), int(query["idx"][0])
try:
# stream details
self.send_response(200)
@@ -149,7 +149,7 @@ class Handler(BaseHTTPRequestHandler):
return self.wfile.flush()
# pass if client closed connection
except (BrokenPipeError, ConnectionResetError): return
ret, content_type = json.dumps(kernels).encode(), "application/json"
ret, content_type = json.dumps(ctxs).encode(), "application/json"
elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"
else: status_code = 404
@@ -194,7 +194,7 @@ if __name__ == "__main__":
contexts, profile = load_pickle(args.kernels), load_pickle(args.profile)
# NOTE: this context is a tuple of list[keys] and list[values]
kernels = get_metadata(*contexts) if contexts is not None else []
ctxs = get_metadata(*contexts) if contexts is not None else []
perfetto_profile = to_perfetto(profile) if profile is not None else None