diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html index d7ff8f9114..a44e113e8e 100644 --- a/tinygrad/viz/index.html +++ b/tinygrad/viz/index.html @@ -82,10 +82,10 @@ position: relative; height: 100%; } - .metadata > * + *, .rewrite-container > * + *, .kernel-list > * + * { + .metadata > * + *, .rewrite-container > * + *, .ctx-list > * + * { margin-top: 12px; } - .kernel-list > ul > * + * { + .ctx-list > ul > * + * { margin-top: 4px; } .graph { @@ -93,12 +93,12 @@ inset: 0; z-index: 1; } - .kernel-list-parent { + .ctx-list-parent { width: 15%; padding-top: 50px; border-right: 1px solid #4a4b56; } - .kernel-list { + .ctx-list { width: 100%; height: 100%; overflow-y: auto; @@ -193,7 +193,7 @@ -
+
Rendering new layout...
diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index d54c2b69d1..a2042e7529 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -28,7 +28,7 @@ async function renderDag(graph, additions, recenter=false) { if (timeout != null) clearTimeout(timeout); const progressMessage = document.querySelector(".progress-message"); timeout = setTimeout(() => {progressMessage.style.display = "block"}, 2000); - worker.postMessage({graph, additions, kernels}); + worker.postMessage({graph, additions, ctxs}); worker.onmessage = (e) => { progressMessage.style.display = "none"; clearTimeout(timeout); @@ -38,7 +38,7 @@ async function renderDag(graph, additions, recenter=false) { const STROKE_WIDTH = 1.4; const nodes = d3.select("#nodes").selectAll("g").data(g.nodes().map(id => g.node(id)), d => d).join("g") .attr("transform", d => `translate(${d.x},${d.y})`).classed("clickable", d => d.ref != null) - .on("click", (_,d) => d.ref != null && setState({ expandKernel: true, currentKernel:d.ref, currentUOp:0, currentRewrite:0 })); + .on("click", (_,d) => d.ref != null && setState({ expandSteps: true, currentCtx:d.ref, currentStep:0, currentRewrite:0 })); nodes.selectAll("rect").data(d => [d]).join("rect").attr("width", d => d.width).attr("height", d => d.height).attr("fill", d => d.color) .attr("x", d => -d.width/2).attr("y", d => -d.height/2).attr("style", d => d.style ?? `stroke:#4a4b57; stroke-width:${STROKE_WIDTH}px;`); nodes.selectAll("g.label").data(d => [d]).join("g").attr("class", "label").attr("transform", d => { @@ -218,7 +218,7 @@ document.getElementById("zoom-to-fit-btn").addEventListener("click", () => { const svg = d3.select("#graph-svg"); svg.call(zoom.transform, d3.zoomIdentity); const mainRect = rect(".main-container"); - const x0 = rect(".kernel-list-parent").right; + const x0 = rect(".ctx-list-parent").right; const x1 = rect(".metadata-parent").left; const pad = 16; const R = { x: x0+pad, y: mainRect.top+pad, width: (x1>0 ? x1-x0 : mainRect.width)-2*pad, height: mainRect.height-2*pad }; @@ -265,25 +265,29 @@ hljs.registerLanguage("cpp", (hljs) => ({ var ret = []; var cache = {}; -var kernels = null; +var ctxs = null; const evtSources = []; -const state = {currentKernel:-1, currentUOp:0, currentRewrite:0, expandKernel:false}; +// VIZ displays graph rewrites in 3 levels, from bottom-up: +// rewrite: a single UOp transformation +// step: collection of rewrites +// context: collection of steps +const state = {currentCtx:-1, currentStep:0, currentRewrite:0, expandSteps:false}; function setState(ns) { Object.assign(state, ns); main(); } async function main() { - const { currentKernel, currentUOp, currentRewrite, expandKernel } = state; - // ** left sidebar kernel list - if (kernels == null) { - kernels = await (await fetch("/kernels")).json(); - setState({ currentKernel:-1 }); + const { currentCtx, currentStep, currentRewrite, expandSteps } = state; + // ** left sidebar context list + if (ctxs == null) { + ctxs = await (await fetch("/ctxs")).json(); + setState({ currentCtx:-1 }); } - const kernelList = document.querySelector(".kernel-list"); - kernelList.innerHTML = ""; - for (const [i,{name, steps}] of kernels.entries()) { - const ul = kernelList.appendChild(document.createElement("ul")); - if (i === currentKernel) { + const ctxList = document.querySelector(".ctx-list"); + ctxList.innerHTML = ""; + for (const [i,{name, steps}] of ctxs.entries()) { + const ul = ctxList.appendChild(document.createElement("ul")); + if (i === currentCtx) { ul.className = "active"; requestAnimationFrame(() => ul.scrollIntoView({ behavior: "auto", block: "nearest" })); } @@ -293,27 +297,28 @@ async function main() { return `${st}`; }); p.onclick = () => { - setState(i === currentKernel ? { expandKernel:!expandKernel } : { expandKernel:true, currentKernel:i, currentUOp:0, currentRewrite:0 }); + setState(i === currentCtx ? { expandSteps:!expandSteps } : { expandSteps:true, currentCtx:i, currentStep:0, currentRewrite:0 }); } for (const [j,u] of steps.entries()) { const inner = ul.appendChild(document.createElement("ul")); - if (i === currentKernel && j === currentUOp) { + if (i === currentCtx && j === currentStep) { inner.className = "active"; requestAnimationFrame(() => inner.scrollIntoView({ behavior: "auto", block: "nearest" })); } inner.innerText = `${u.name ?? u.loc[0].replaceAll("\\", "/").split("/").pop()+':'+u.loc[1]} - ${u.match_count}`; inner.style.marginLeft = `${8*u.depth}px`; - inner.style.display = i === currentKernel && expandKernel ? "block" : "none"; + inner.style.display = i === currentCtx && expandSteps ? "block" : "none"; inner.onclick = (e) => { e.stopPropagation(); - setState({ currentUOp:j, currentKernel:i, currentRewrite:0 }); + setState({ currentStep:j, currentCtx:i, currentRewrite:0 }); } } } // ** center graph - if (currentKernel == -1) return; - const kernel = kernels[currentKernel].steps[currentUOp]; - const cacheKey = `kernel=${currentKernel}&idx=${currentUOp}`; + if (currentCtx == -1) return; + const ctx = ctxs[currentCtx]; + const step = ctx.steps[currentStep]; + const cacheKey = `ctx=${currentCtx}&idx=${currentStep}`; // close any pending event sources let activeSrc = null; for (const e of evtSources) { @@ -323,11 +328,11 @@ async function main() { if (cacheKey in cache) { ret = cache[cacheKey]; } - // if we don't have a complete cache yet we start streaming this kernel - if (!(cacheKey in cache) || (cache[cacheKey].length !== kernel.match_count+1 && activeSrc == null)) { + // if we don't have a complete cache yet we start streaming rewrites in this step + if (!(cacheKey in cache) || (cache[cacheKey].length !== step.match_count+1 && activeSrc == null)) { ret = []; cache[cacheKey] = ret; - const eventSource = new EventSource(`/kernels?kernel=${currentKernel}&idx=${currentUOp}`); + const eventSource = new EventSource(`/ctxs?ctx=${currentCtx}&idx=${currentStep}`); evtSources.push(eventSource); eventSource.onmessage = (e) => { if (e.data === "END") return eventSource.close(); @@ -341,20 +346,20 @@ async function main() { }; } if (ret.length === 0) return; - if (kernel.name == "View Memory Graph") { + if (ctx.name == "View Memory Graph") { renderMemoryGraph(ret[currentRewrite].graph); } else { renderDag(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes || [], recenter=currentRewrite === 0); } // ** right sidebar code blocks const metadata = document.querySelector(".metadata"); - const [code, lang] = kernels[currentKernel].kernel_code != null ? [kernels[currentKernel].kernel_code, "cpp"] : [ret[currentRewrite].uop, "python"]; - metadata.replaceChildren(codeBlock(kernel.code_line, "python", { loc:kernel.loc, wrap:true }), codeBlock(code, lang, { wrap:false })); + const [code, lang] = ctx.kernel_code != null ? [ctx.kernel_code, "cpp"] : [ret[currentRewrite].uop, "python"]; + metadata.replaceChildren(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true }), codeBlock(code, lang, { wrap:false })); // ** rewrite steps - if (kernel.match_count >= 1) { + if (step.match_count >= 1) { const rewriteList = metadata.appendChild(document.createElement("div")); rewriteList.className = "rewrite-list"; - for (let s=0; s<=kernel.match_count; s++) { + for (let s=0; s<=step.match_count; s++) { const ul = rewriteList.appendChild(document.createElement("ul")); ul.innerText = s; ul.id = `rewrite-${s}`; @@ -408,36 +413,36 @@ function appendResizer(element, { minWidth, maxWidth }, left=false) { }, { once: true }); }); } -appendResizer(document.querySelector(".kernel-list-parent"), { minWidth: 15, maxWidth: 50 }, left=true); +appendResizer(document.querySelector(".ctx-list-parent"), { minWidth: 15, maxWidth: 50 }, left=true); appendResizer(document.querySelector(".metadata-parent"), { minWidth: 20, maxWidth: 50 }); // **** keyboard shortcuts document.addEventListener("keydown", async function(event) { - const { currentKernel, currentUOp, currentRewrite, expandKernel } = state; - // up and down change the UOp or kernel from the list + const { currentCtx, currentStep, currentRewrite, expandSteps } = state; + // up and down change the step or context from the list if (event.key == "ArrowUp") { event.preventDefault(); - if (expandKernel) { - return setState({ currentRewrite:0, currentUOp:Math.max(0, currentUOp-1) }); + if (expandSteps) { + return setState({ currentRewrite:0, currentStep:Math.max(0, currentStep-1) }); } - return setState({ currentUOp:0, currentRewrite:0, currentKernel:Math.max(0, currentKernel-1) }); + return setState({ currentStep:0, currentRewrite:0, currentCtx:Math.max(0, currentCtx-1) }); } if (event.key == "ArrowDown") { event.preventDefault(); - if (expandKernel) { - const totalUOps = kernels[currentKernel].steps.length-1; - return setState({ currentRewrite:0, currentUOp:Math.min(totalUOps, currentUOp+1) }); + if (expandSteps) { + const totalUOps = ctxs[currentCtx].steps.length-1; + return setState({ currentRewrite:0, currentStep:Math.min(totalUOps, currentStep+1) }); } - return setState({ currentUOp:0, currentRewrite:0, currentKernel:Math.min(kernels.length-1, currentKernel+1) }); + return setState({ currentStep:0, currentRewrite:0, currentCtx:Math.min(ctxs.length-1, currentCtx+1) }); } // enter toggles focus on a single rewrite stage if (event.key == "Enter") { event.preventDefault() - if (state.currentKernel === -1) { - return setState({ currentKernel:0, expandKernel:true }); + if (currentCtx === -1) { + return setState({ currentCtx:0, expandSteps:true }); } - return setState({ currentUOp:0, currentRewrite:0, expandKernel:!expandKernel }); + return setState({ currentStep:0, currentRewrite:0, expandSteps:!expandSteps }); } // left and right go through rewrites in a single UOp if (event.key == "ArrowLeft") { diff --git a/tinygrad/viz/js/worker.js b/tinygrad/viz/js/worker.js index 0f6774b0d0..31b7ba10bb 100644 --- a/tinygrad/viz/js/worker.js +++ b/tinygrad/viz/js/worker.js @@ -7,13 +7,13 @@ ctx.font = `${LINE_HEIGHT}px sans-serif`; const ansiStrip = (st, tag) => st.replace(/\u001b\[(\d+)m(.*?)\u001b\[0m/g, (_,__,st) => st); onmessage = (e) => { - const { graph, additions, kernels } = e.data; + const { graph, additions, ctxs } = e.data; const g = new dagre.graphlib.Graph({ compound: true }); g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; }); if (additions.length !== 0) g.setNode("addition", {label:"", style:"fill: rgba(26, 27, 38, 0.5);", padding:0}); for (let [k, {label, src, ref, ...rest }] of Object.entries(graph)) { - const idx = ref ? kernels.findIndex(k => k.ref === ref) : -1; - if (idx != -1) label += `\ncodegen@${ansiStrip(kernels[idx].name)}`; + const idx = ref ? ctxs.findIndex(k => k.ref === ref) : -1; + if (idx != -1) label += `\ncodegen@${ansiStrip(ctxs[idx].name)}`; // adjust node dims by label size + add padding let [width, height] = [0, 0]; for (line of label.split("\n")) { diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 050ea0f355..df7102f941 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -133,9 +133,9 @@ class Handler(BaseHTTPRequestHandler): if url.path.endswith(".js"): content_type = "application/javascript" if url.path.endswith(".css"): content_type = "text/css" except FileNotFoundError: status_code = 404 - elif url.path == "/kernels": - if "kernel" in (query:=parse_qs(url.query)): - kidx, ridx = int(query["kernel"][0]), int(query["idx"][0]) + elif url.path == "/ctxs": + if "ctx" in (query:=parse_qs(url.query)): + kidx, ridx = int(query["ctx"][0]), int(query["idx"][0]) try: # stream details self.send_response(200) @@ -149,7 +149,7 @@ class Handler(BaseHTTPRequestHandler): return self.wfile.flush() # pass if client closed connection except (BrokenPipeError, ConnectionResetError): return - ret, content_type = json.dumps(kernels).encode(), "application/json" + ret, content_type = json.dumps(ctxs).encode(), "application/json" elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json" else: status_code = 404 @@ -194,7 +194,7 @@ if __name__ == "__main__": contexts, profile = load_pickle(args.kernels), load_pickle(args.profile) # NOTE: this context is a tuple of list[keys] and list[values] - kernels = get_metadata(*contexts) if contexts is not None else [] + ctxs = get_metadata(*contexts) if contexts is not None else [] perfetto_profile = to_perfetto(profile) if profile is not None else None