From 69aa2054f60335b16418422841852522e9aad989 Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Thu, 28 May 2026 19:41:58 -0700 Subject: [PATCH 1/5] rename clangjit to clang (#16423) --- .github/workflows/test.yml | 10 +++++----- test/null/test_device.py | 24 ++++++++++++------------ test/null/test_elf.py | 4 ++-- tinygrad/renderer/cstyle.py | 5 ++--- tinygrad/runtime/ops_cpu.py | 4 ++-- tinygrad/runtime/support/compiler_cpu.py | 4 ++-- 6 files changed, 25 insertions(+), 26 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3e58a0eee8..6f6c8e7fd9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -572,7 +572,7 @@ jobs: fail-fast: false matrix: dev: - - 'CPU:CLANGJIT' + - 'CPU:CLANG' - 'CPU:LLVM' - 'CPU:LVP' - 'CPU:X86' @@ -591,12 +591,12 @@ jobs: key: linux-${{ matrix.dev }} deps: testing_unit python-version: '3.12' - llvm: ${{ contains(matrix.dev, 'LLVM') || contains(matrix.dev, 'LVP') || contains(matrix.dev, 'CLANGJIT') }} + llvm: ${{ contains(matrix.dev, 'LLVM') || contains(matrix.dev, 'LVP') || contains(matrix.dev, 'CLANG') }} mesa: ${{ contains(matrix.dev, 'LVP') && 'cpu' || 'false' }} webgpu: ${{ matrix.dev == 'WEBGPU' }} opencl: ${{ matrix.dev == 'CL' }} - name: Set env - run: printf "DEV=${{ matrix.dev }}${{ matrix.dev == 'CPU:CLANGJIT' && '\nCPU_COUNT=2' || '' }}" >> $GITHUB_ENV + run: printf "DEV=${{ matrix.dev }}${{ matrix.dev == 'CPU:CLANG' && '\nCPU_COUNT=2' || '' }}" >> $GITHUB_ENV - name: Check Device.DEFAULT and print some source run: | python -c "from tinygrad import Device; from tinygrad.helpers import Target; assert Device.DEFAULT == Target.parse('${{ matrix.dev }}').device" @@ -889,7 +889,7 @@ jobs: fail-fast: false matrix: dev: - - 'CPU:CLANGJIT' + - 'CPU:CLANG' - 'CPU:LLVM' - 'CPU:X86' - 'WEBGPU' @@ -908,7 +908,7 @@ jobs: pydeps: ${{ matrix.dev == 'WEBGPU' && 'dawn-python' || '' }} - name: Set env shell: bash - run: printf "DEV=${{ matrix.dev }}${{ matrix.dev == 'CPU:CLANGJIT' && '\nCPU_COUNT=2' || '' }}" >> $GITHUB_ENV + run: printf "DEV=${{ matrix.dev }}${{ matrix.dev == 'CPU:CLANG' && '\nCPU_COUNT=2' || '' }}" >> $GITHUB_ENV - name: Check Device.DEFAULT and print some source shell: bash run: | diff --git a/test/null/test_device.py b/test/null/test_device.py index b31be94173..ec62d5eaa6 100644 --- a/test/null/test_device.py +++ b/test/null/test_device.py @@ -28,8 +28,8 @@ class TestDevice(unittest.TestCase): def test_nonexistent_renderer(self): with self.assertRaisesRegex(RuntimeError, "has no renderer"): with Context(DEV="CPU:TYPO"): Device[Device.DEFAULT].renderer - with self.assertRaisesRegex(RuntimeError, "did you mean: 'CLANGJIT'"): - with Context(DEV="CPU:CLANG"): Device[Device.DEFAULT].renderer + with self.assertRaisesRegex(RuntimeError, "did you mean: 'CLANG'"): + with Context(DEV="CPU:CLANGJIT"): Device[Device.DEFAULT].renderer @unittest.skipIf(Device.DEFAULT != "AMD", "only run on AMD") def test_nonexistent_iface(self): @@ -69,17 +69,17 @@ class TestDevice(unittest.TestCase): @unittest.skipIf(WIN, "skipping windows test") # TODO: subprocess causes memory violation? def test_env_overwrite_default_compiler(self): if Device.DEFAULT == "CPU": - from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler - try: _, _ = CPULLVMCompiler(), ClangJITCompiler() + from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangCompiler + try: _, _ = CPULLVMCompiler(), ClangCompiler() except Exception as e: self.skipTest(f"skipping compiler test: not all compilers: {e}") - imports = "from tinygrad import Device; from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler" + imports = "from tinygrad import Device; from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangCompiler" subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, CPULLVMCompiler)"'], shell=True, check=True, env={**os.environ, "DEV": "CPU:LLVM"}) - subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, ClangJITCompiler)"'], + subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, ClangCompiler)"'], shell=True, check=True, env={**os.environ, "DEV": "CPU"}) - subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, ClangJITCompiler)"'], - shell=True, check=True, env={**os.environ, "DEV": "CPU:CLANGJIT"}) + subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, ClangCompiler)"'], + shell=True, check=True, env={**os.environ, "DEV": "CPU:CLANG"}) elif Device.DEFAULT == "AMD": from tinygrad.runtime.support.compiler_amd import HIPCompiler, AMDLLVMCompiler try: _, _ = HIPCompiler(Device[Device.DEFAULT].arch), AMDLLVMCompiler(Device[Device.DEFAULT].arch) @@ -96,15 +96,15 @@ class TestDevice(unittest.TestCase): @unittest.skipIf(WIN, "skipping windows test") def test_env_online(self): - from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler - try: _, _ = CPULLVMCompiler(), ClangJITCompiler() + from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangCompiler + try: _, _ = CPULLVMCompiler(), ClangCompiler() except Exception as e: self.skipTest(f"skipping compiler test: not all compilers: {e}") with Context(DEV="CPU:LLVM"): inst = Device["CPU"].compiler self.assertIsInstance(Device["CPU"].compiler, CPULLVMCompiler) with Context(DEV="CPU"): - self.assertIsInstance(Device["CPU"].compiler, ClangJITCompiler) + self.assertIsInstance(Device["CPU"].compiler, ClangCompiler) with Context(DEV="CPU:LLVM"): self.assertIsInstance(Device["CPU"].compiler, CPULLVMCompiler) assert inst is Device["CPU"].compiler # cached @@ -118,7 +118,7 @@ class TestDevice(unittest.TestCase): dev = Device["CPU"] dev.cached_renderer.clear() - with patch("tinygrad.renderer.cstyle.ClangJITRenderer.__init__", side_effect=RuntimeError("broken")): + with patch("tinygrad.renderer.cstyle.ClangRenderer.__init__", side_effect=RuntimeError("broken")): self.assertIsInstance(dev.renderer.compiler, CPULLVMCompiler) def test_dev_contextvar(self): diff --git a/test/null/test_elf.py b/test/null/test_elf.py index f7d350bd34..0f9a3d48f0 100644 --- a/test/null/test_elf.py +++ b/test/null/test_elf.py @@ -1,5 +1,5 @@ import unittest, subprocess, platform -from tinygrad.runtime.support.compiler_cpu import ClangJITCompiler +from tinygrad.runtime.support.compiler_cpu import ClangCompiler from tinygrad.runtime.support.elf import elf_loader class TestElfLoader(unittest.TestCase): @@ -23,7 +23,7 @@ class TestElfLoader(unittest.TestCase): } ''' with self.assertRaisesRegex(RuntimeError, 'evil_external_function'): - ClangJITCompiler([{'AMD64':'x86_64', 'aarch64':'arm64'}.get(m:=platform.machine(), m), "native"]).compile(src) + ClangCompiler([{'AMD64':'x86_64', 'aarch64':'arm64'}.get(m:=platform.machine(), m), "native"]).compile(src) def test_link(self): src = ''' float powf(float, float); // from libm diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index f50623df75..dc0801e7a9 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -275,12 +275,11 @@ class ClangRenderer(CStyleLanguage): def supported_dtypes(self): return {d for d in super().supported_dtypes() if (d != dtypes.bfloat16 or self.target.arch.startswith(("x86", "arm"))) and d not in dtypes.fp8s} -class ClangJITRenderer(ClangRenderer): def __init__(self, target:Target): super().__init__(target) - from tinygrad.runtime.support.compiler_cpu import ClangJITCompiler + from tinygrad.runtime.support.compiler_cpu import ClangCompiler if "AMX" in target.arch: self.tensor_cores = tc.amx - self.compiler = ClangJITCompiler([x for x in target.arch.split(",") if x != "AMX"]) + self.compiler = ClangCompiler([x for x in target.arch.split(",") if x != "AMX"]) class OpenCLRenderer(CStyleLanguage): has_aux = True diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py index ca4e442c94..37c59c9338 100644 --- a/tinygrad/runtime/ops_cpu.py +++ b/tinygrad/runtime/ops_cpu.py @@ -4,7 +4,7 @@ from tinygrad.helpers import to_mv, OSX, WIN, mv_address, suppress_finalizing, u from tinygrad.device import BufferSpec from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface from tinygrad.runtime.support.hcq import CLikeArgsState -from tinygrad.renderer.cstyle import ClangJITRenderer +from tinygrad.renderer.cstyle import ClangRenderer from tinygrad.renderer.llvmir import CPULLVMRenderer from tinygrad.renderer.nir import LVPRenderer from tinygrad.renderer.isa.x86 import X86Renderer @@ -138,5 +138,5 @@ class CPUDevice(HCQCompiled): def __init__(self, device:str=""): self.tasks:queue.Queue = queue.Queue() CPUWorker(self, self.tasks, thread_id=0).start() - super().__init__(device, CPUAllocator(self), [ClangJITRenderer, CPULLVMRenderer, LVPRenderer, X86Renderer], functools.partial(CPUProgram, self), + super().__init__(device, CPUAllocator(self), [ClangRenderer, CPULLVMRenderer, LVPRenderer, X86Renderer], functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue, arch={'amd64':'x86_64', 'aarch64':'arm64'}.get(m:=platform.machine().lower(), m)+",native") diff --git a/tinygrad/runtime/support/compiler_cpu.py b/tinygrad/runtime/support/compiler_cpu.py index 5b6121ebb7..2fc5ede229 100644 --- a/tinygrad/runtime/support/compiler_cpu.py +++ b/tinygrad/runtime/support/compiler_cpu.py @@ -4,7 +4,7 @@ from tinygrad.helpers import getenv, capstone_flatdump, DEBUG, unwrap from tinygrad.runtime.support.elf import jit_loader from tinygrad.runtime.autogen import llvm -class ClangJITCompiler(Compiler): +class ClangCompiler(Compiler): def __init__(self, arch:list[str], cachekey="compile_clang_jit"): assert len(arch) >= 2, f"invalid arch string: {','.join(arch)!r}, expected ',,[]' (eg. 'x86_64,znver2')" self.arch, cpu, *feats = arch @@ -98,7 +98,7 @@ class CPULLVMCompiler(LLVMCompiler): if cpu == "native": cpu = ctypes.string_at(llvm.LLVMGetHostCPUName()).decode() featstr = (featstr + "," if featstr else "") + ctypes.string_at(llvm.LLVMGetHostCPUFeatures()).decode() - # +reserve-x18 here does the same thing as -ffixed-x18 in ClangJITCompiler, see comments there for why it's needed on arm osx + # +reserve-x18 here does the same thing as -ffixed-x18 in ClangCompiler, see comments there for why it's needed on arm osx super().__init__(self.arch, cpu, ('+reserve-x18,' if self.arch == "arm64" else '') + featstr, cache_key) def disassemble(self, lib:bytes): capstone_flatdump(lib, self.arch) From 6e0d5262dc8834e19e0413abf7172529b54631eb Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Thu, 28 May 2026 20:14:35 -0700 Subject: [PATCH 2/5] ci: autocancel outdated pr jobs (#16424) --- .github/workflows/test.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6f6c8e7fd9..84632d77ad 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,6 +14,10 @@ on: pull_request: workflow_dispatch: +concurrency: + group: test-${{ github.event_name }}-${{ github.event_name == 'pull_request' && github.event.pull_request.number || github.run_id }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} + jobs: docs: name: Docs From f86966af56446de96dc802dde2495c7ca0aefe24 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Thu, 28 May 2026 23:18:11 -0400 Subject: [PATCH 3/5] llama: optim amax margin (#16425) --- examples/mlperf/optim.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/mlperf/optim.py b/examples/mlperf/optim.py index a1e20c99f7..2c83c35329 100644 --- a/examples/mlperf/optim.py +++ b/examples/mlperf/optim.py @@ -6,6 +6,7 @@ from tinygrad.uop.ops import UOp, Ops STOCHASTIC_ROUND = getenv("STOCHASTIC_ROUND", 0) MASTER_WEIGHTS = getenv("MASTER_WEIGHTS", 0) +FP8_AMAX_MARGIN = getenv("FP8_AMAX_MARGIN", 1.1) def stochastic_round_bf16(x:Tensor) -> Tensor: bits = x.bitcast(dtypes.uint32) @@ -95,7 +96,7 @@ class GradAccClipAdamW(Optimizer): scaled = (new_w * scale).clamp(-FP8_MAX, FP8_MAX) ret = scaled.cast(t.dtype) # update inv_scale for next step from quantized result - new_amax = (ret.float().abs().max(axis=tuple(range(1, ret.ndim))) * t._inv_scale).detach() + new_amax = (ret.float().abs().max(axis=tuple(range(1, ret.ndim))) * t._inv_scale * FP8_AMAX_MARGIN).detach() new_inv = ((new_amax + 1e-8) / FP8_MAX).cast(t._inv_scale.dtype) t._next_inv_scale.assign(new_inv.shard_like(t._next_inv_scale) if offloaded else new_inv) return ret.shard_like(t) if offloaded else ret From 814d414f41401273fe76da2bb19b3ed5c08471b5 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 29 May 2026 07:16:34 +0300 Subject: [PATCH 4/5] viz: set label offset for asm (#16426) --- tinygrad/viz/js/worker.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinygrad/viz/js/worker.js b/tinygrad/viz/js/worker.js index d7e7532e5d..3e677f79ed 100644 --- a/tinygrad/viz/js/worker.js +++ b/tinygrad/viz/js/worker.js @@ -31,7 +31,7 @@ const layoutCfg = (g, { blocks, paths, pc_tokens }) => { width = Math.max(width, ctx.measureText(tokens.map((t) => t.st).join("")).width); height += lineHeight; } - g.setNode(lead, { ...rectDims(width, height), label, id:lead, color:"#1a1b26" }); + g.setNode(lead, { ...rectDims(width, height), label, labelX:0, id:lead, color:"#1a1b26" }); } // paths become edges between basic blocks const pathColors = {0:"#3f7564", 1:"#7a4540", 2:"#3b5f7e"}; From 54cfb794b8c304650f0990a4633cd7860618b328 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 29 May 2026 11:25:07 +0300 Subject: [PATCH 5/5] viz: addrspace little colored box (#16427) * return addrspace * layout * render * addrspace encodes color * update colors * in input_ast all are params are green * update stroke --- test/null/test_viz.py | 7 +++++-- tinygrad/viz/js/index.js | 6 ++++-- tinygrad/viz/js/worker.js | 8 ++++---- tinygrad/viz/serve.py | 9 +++++++-- 4 files changed, 20 insertions(+), 10 deletions(-) diff --git a/test/null/test_viz.py b/test/null/test_viz.py index a7a76d4737..42f03d45a9 100644 --- a/test/null/test_viz.py +++ b/test/null/test_viz.py @@ -5,13 +5,13 @@ from typing import Generator from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatcher, graph_rewrite, track_rewrites, profile_matches from tinygrad.uop.symbolic import sym -from tinygrad.dtype import dtypes +from tinygrad.dtype import dtypes, AddrSpace from tinygrad.helpers import colored, ansistrip, flatten, TracingKey, ProfileRangeEvent, ProfileEvent, Context, cpu_events, profile_marker from tinygrad.helpers import cpu_profile, ProfilePointEvent, unwrap from tinygrad.device import Buffer from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, active_group, _name_cnt, RewriteTrace -from tinygrad.viz.serve import load_rewrites, get_full_rewrite, uop_to_json, VizData, get_render +from tinygrad.viz.serve import load_rewrites, get_full_rewrite, uop_to_json, VizData, get_render, addrspace_colors from tinygrad.codegen import to_program_cache from tinygrad.codegen import to_program @@ -348,6 +348,9 @@ class TestVizIntegration(unittest.TestCase): self.assertEqual(lst[0]["name"], "Callify 1 Buffer n1") self.assertEqual(lst[1]["name"], "Schedule 1 Kernel n1") self.assertEqual(lst[2]["name"], prg.arg.name) + input_ast = next(viz.get_details(2, 0))["graph"].values() + for u in input_ast: + if u["label"].startswith("PARAM\n"): self.assertEqual(u["addrspace"], addrspace_colors[AddrSpace.GLOBAL]) # schedule graph CALL nodes have a link to jump to codegen def test_link_sched_codegen(self): diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 3d1b566fa8..b47a9012e3 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -57,9 +57,9 @@ function intersectRect(r1, r2) { } function addTags(root, path) { - root.selectAll("circle").data(d => d.rect ? [] : [d]).join("circle").attr("r", 5).style("fill", d => d.fill ?? null); + root.selectAll("circle").data(d => d.rect ? [] : [d]).join("circle").attr("r", 5).style("fill", d => d.fill ?? null).style("stroke", d => d.stroke ?? null); root.selectAll("rect").data(d => d.rect ? [d] : []).join("rect").attr("x", d => -d.width/2).attr("y", d => -d.height/2) - .attr("width", d => d.width).attr("height", d => d.height).style("fill", d => d.fill ?? null); + .attr("width", d => d.width).attr("height", d => d.height).style("fill", d => d.fill ?? null).style("stroke", d => d.stroke ?? null); if (path != null) root.selectAll("path").data(d => [d]).join("path").attr("d", path); else root.selectAll("text").data(d => [d]).join("text").text(d => d.text).attr("dy", "0.35em"); } @@ -115,6 +115,8 @@ const drawGraph = (data) => { }); addTags(nodes.selectAll("g.tag").data(d => d.tag != null ? [d] : []).join("g").attr("class", "tag") .attr("transform", d => `translate(${-d.width/2+8}, ${-d.height/2+8})`).datum(e => ({ text:e.tag }))); + addTags(nodes.selectAll("g.addrspace").data(d => d.addrspace != null ? [d] : []).join("g").attr("class", "tag addrspace") + .attr("transform", d => `translate(${d.width/2-8}, ${-d.height/2+8})`).datum(e => ({ rect:true, width:10, height:10, fill:e.addrspace, stroke:"none" }))); const CALL_TAG_WIDTH = 14; addTags(nodes.selectAll("g.type").data(d => d.collapsible ? [d] : []).join("g").attr("class", d => `tag clickable ${d.collapsed ? 'collapsed' : 'expanded'}`) .attr("transform", d => d.callNode ? `translate(${CALL_TAG_WIDTH/2-d.width/2}, ${0})` : `translate(${-d.width/2}, ${0})`) diff --git a/tinygrad/viz/js/worker.js b/tinygrad/viz/js/worker.js index 3e677f79ed..a809de89ec 100644 --- a/tinygrad/viz/js/worker.js +++ b/tinygrad/viz/js/worker.js @@ -31,7 +31,7 @@ const layoutCfg = (g, { blocks, paths, pc_tokens }) => { width = Math.max(width, ctx.measureText(tokens.map((t) => t.st).join("")).width); height += lineHeight; } - g.setNode(lead, { ...rectDims(width, height), label, labelX:0, id:lead, color:"#1a1b26" }); + g.setNode(lead, { ...rectDims(width, height), label, labelX:0, id:lead, color:"#1a1b26", addrspace:null }); } // paths become edges between basic blocks const pathColors = {0:"#3f7564", 1:"#7a4540", 2:"#3b5f7e"}; @@ -45,9 +45,9 @@ const layoutUOp = (g, { graph, change }, opts) => { const lineHeight = 14; g.setGraph({ rankdir: "LR", font:"sans-serif", lh:lineHeight }); ctx.font = `350 ${lineHeight}px ${g.graph().font}`; - if (change?.length) g.setNode("overlay", {label:"", labelWidth:0, labelHeight:0, className:"overlay"}); + if (change?.length) g.setNode("overlay", {label:"", labelWidth:0, labelHeight:0, labelX:0, className:"overlay"}); let callCount = 0; - for (const [k, {label, src, ref, color, tag, exclude }] of Object.entries(graph)) { + for (const [k, {label, src, ref, color, tag, exclude, addrspace}] of Object.entries(graph)) { // adjust node dims by label size (excluding escape codes) + add padding let [width, height] = [0, 0]; for (line of label.replace(/\u001B\[(?:K|.*?m)/g, "").split("\n")) { @@ -56,7 +56,7 @@ const layoutUOp = (g, { graph, change }, opts) => { } const callNode = label.startsWith("CALL\n") || label.startsWith("FUNCTION\n"); if (callNode) callCount++; - g.setNode(k, {...rectDims(width, height), label, labelX:0, ref, id:k, color, tag, callNode, exclude}); + g.setNode(k, {...rectDims(width, height), label, labelX:0, ref, id:k, color, tag, callNode, exclude, addrspace}); // add edges const edgeCounts = {}; for (const [_, s] of src) edgeCounts[s] = (edgeCounts[s] || 0)+1; diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 753910c4b3..6608f25362 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -43,7 +43,7 @@ from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, GroupO from tinygrad.uop.ops import KernelInfo from tinygrad.uop.render import print_uops, pyrender from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, ProfileProgramEvent -from tinygrad.dtype import dtypes +from tinygrad.dtype import dtypes, AddrSpace uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.SHAPED_WMMA: "#FF5B5B", @@ -56,6 +56,8 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D", Ops.STAGE: "#AC640D", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"} +addrspace_colors = {AddrSpace.REG:"#e68181", AddrSpace.LOCAL:"#e7c86a", AddrSpace.GLOBAL:"#75bd7b"} + # VIZ API # A step is a lightweight descriptor for a trace entry @@ -160,8 +162,11 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]: # limit SOURCE labels line count if u.op is Ops.SOURCE and len(lines:=label.split("\n")) > 40: label = "\n".join(lines[:30]) + "\n..." + try: addrspace = u.addrspace + except Exception: addrspace = None graph[id(u)] = {"label":label, "src":[(i,id(x)) for i,x in enumerate(u.src)], "exclude":u in excluded, "color":uops_colors.get(u.op, "#ffffff"), - "ref":ref, "tag":repr(u.tag) if u.tag is not None else None} + "ref":ref, "tag":repr(u.tag) if u.tag is not None else None, + "addrspace":addrspace_colors.get(addrspace, None) if addrspace is not None else None} return graph def _reconstruct(data:VizData, a:int, depth:int|None=None):