Merge branch 'master' into shrink_in_render

This commit is contained in:
George Hotz
2026-05-29 01:37:25 -07:00
committed by GitHub
11 changed files with 51 additions and 37 deletions

View File

@@ -14,6 +14,10 @@ on:
pull_request:
workflow_dispatch:
concurrency:
group: test-${{ github.event_name }}-${{ github.event_name == 'pull_request' && github.event.pull_request.number || github.run_id }}
cancel-in-progress: ${{ github.event_name == 'pull_request' }}
jobs:
docs:
name: Docs
@@ -572,7 +576,7 @@ jobs:
fail-fast: false
matrix:
dev:
- 'CPU:CLANGJIT'
- 'CPU:CLANG'
- 'CPU:LLVM'
- 'CPU:LVP'
- 'CPU:X86'
@@ -591,12 +595,12 @@ jobs:
key: linux-${{ matrix.dev }}
deps: testing_unit
python-version: '3.12'
llvm: ${{ contains(matrix.dev, 'LLVM') || contains(matrix.dev, 'LVP') || contains(matrix.dev, 'CLANGJIT') }}
llvm: ${{ contains(matrix.dev, 'LLVM') || contains(matrix.dev, 'LVP') || contains(matrix.dev, 'CLANG') }}
mesa: ${{ contains(matrix.dev, 'LVP') && 'cpu' || 'false' }}
webgpu: ${{ matrix.dev == 'WEBGPU' }}
opencl: ${{ matrix.dev == 'CL' }}
- name: Set env
run: printf "DEV=${{ matrix.dev }}${{ matrix.dev == 'CPU:CLANGJIT' && '\nCPU_COUNT=2' || '' }}" >> $GITHUB_ENV
run: printf "DEV=${{ matrix.dev }}${{ matrix.dev == 'CPU:CLANG' && '\nCPU_COUNT=2' || '' }}" >> $GITHUB_ENV
- name: Check Device.DEFAULT and print some source
run: |
python -c "from tinygrad import Device; from tinygrad.helpers import Target; assert Device.DEFAULT == Target.parse('${{ matrix.dev }}').device"
@@ -889,7 +893,7 @@ jobs:
fail-fast: false
matrix:
dev:
- 'CPU:CLANGJIT'
- 'CPU:CLANG'
- 'CPU:LLVM'
- 'CPU:X86'
- 'WEBGPU'
@@ -908,7 +912,7 @@ jobs:
pydeps: ${{ matrix.dev == 'WEBGPU' && 'dawn-python' || '' }}
- name: Set env
shell: bash
run: printf "DEV=${{ matrix.dev }}${{ matrix.dev == 'CPU:CLANGJIT' && '\nCPU_COUNT=2' || '' }}" >> $GITHUB_ENV
run: printf "DEV=${{ matrix.dev }}${{ matrix.dev == 'CPU:CLANG' && '\nCPU_COUNT=2' || '' }}" >> $GITHUB_ENV
- name: Check Device.DEFAULT and print some source
shell: bash
run: |

View File

@@ -6,6 +6,7 @@ from tinygrad.uop.ops import UOp, Ops
STOCHASTIC_ROUND = getenv("STOCHASTIC_ROUND", 0)
MASTER_WEIGHTS = getenv("MASTER_WEIGHTS", 0)
FP8_AMAX_MARGIN = getenv("FP8_AMAX_MARGIN", 1.1)
def stochastic_round_bf16(x:Tensor) -> Tensor:
bits = x.bitcast(dtypes.uint32)
@@ -95,7 +96,7 @@ class GradAccClipAdamW(Optimizer):
scaled = (new_w * scale).clamp(-FP8_MAX, FP8_MAX)
ret = scaled.cast(t.dtype)
# update inv_scale for next step from quantized result
new_amax = (ret.float().abs().max(axis=tuple(range(1, ret.ndim))) * t._inv_scale).detach()
new_amax = (ret.float().abs().max(axis=tuple(range(1, ret.ndim))) * t._inv_scale * FP8_AMAX_MARGIN).detach()
new_inv = ((new_amax + 1e-8) / FP8_MAX).cast(t._inv_scale.dtype)
t._next_inv_scale.assign(new_inv.shard_like(t._next_inv_scale) if offloaded else new_inv)
return ret.shard_like(t) if offloaded else ret

View File

@@ -28,8 +28,8 @@ class TestDevice(unittest.TestCase):
def test_nonexistent_renderer(self):
with self.assertRaisesRegex(RuntimeError, "has no renderer"):
with Context(DEV="CPU:TYPO"): Device[Device.DEFAULT].renderer
with self.assertRaisesRegex(RuntimeError, "did you mean: 'CLANGJIT'"):
with Context(DEV="CPU:CLANG"): Device[Device.DEFAULT].renderer
with self.assertRaisesRegex(RuntimeError, "did you mean: 'CLANG'"):
with Context(DEV="CPU:CLANGJIT"): Device[Device.DEFAULT].renderer
@unittest.skipIf(Device.DEFAULT != "AMD", "only run on AMD")
def test_nonexistent_iface(self):
@@ -69,17 +69,17 @@ class TestDevice(unittest.TestCase):
@unittest.skipIf(WIN, "skipping windows test") # TODO: subprocess causes memory violation?
def test_env_overwrite_default_compiler(self):
if Device.DEFAULT == "CPU":
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler
try: _, _ = CPULLVMCompiler(), ClangJITCompiler()
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangCompiler
try: _, _ = CPULLVMCompiler(), ClangCompiler()
except Exception as e: self.skipTest(f"skipping compiler test: not all compilers: {e}")
imports = "from tinygrad import Device; from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler"
imports = "from tinygrad import Device; from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangCompiler"
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, CPULLVMCompiler)"'],
shell=True, check=True, env={**os.environ, "DEV": "CPU:LLVM"})
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, ClangJITCompiler)"'],
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, ClangCompiler)"'],
shell=True, check=True, env={**os.environ, "DEV": "CPU"})
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, ClangJITCompiler)"'],
shell=True, check=True, env={**os.environ, "DEV": "CPU:CLANGJIT"})
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, ClangCompiler)"'],
shell=True, check=True, env={**os.environ, "DEV": "CPU:CLANG"})
elif Device.DEFAULT == "AMD":
from tinygrad.runtime.support.compiler_amd import HIPCompiler, AMDLLVMCompiler
try: _, _ = HIPCompiler(Device[Device.DEFAULT].arch), AMDLLVMCompiler(Device[Device.DEFAULT].arch)
@@ -96,15 +96,15 @@ class TestDevice(unittest.TestCase):
@unittest.skipIf(WIN, "skipping windows test")
def test_env_online(self):
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler
try: _, _ = CPULLVMCompiler(), ClangJITCompiler()
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangCompiler
try: _, _ = CPULLVMCompiler(), ClangCompiler()
except Exception as e: self.skipTest(f"skipping compiler test: not all compilers: {e}")
with Context(DEV="CPU:LLVM"):
inst = Device["CPU"].compiler
self.assertIsInstance(Device["CPU"].compiler, CPULLVMCompiler)
with Context(DEV="CPU"):
self.assertIsInstance(Device["CPU"].compiler, ClangJITCompiler)
self.assertIsInstance(Device["CPU"].compiler, ClangCompiler)
with Context(DEV="CPU:LLVM"):
self.assertIsInstance(Device["CPU"].compiler, CPULLVMCompiler)
assert inst is Device["CPU"].compiler # cached
@@ -118,7 +118,7 @@ class TestDevice(unittest.TestCase):
dev = Device["CPU"]
dev.cached_renderer.clear()
with patch("tinygrad.renderer.cstyle.ClangJITRenderer.__init__", side_effect=RuntimeError("broken")):
with patch("tinygrad.renderer.cstyle.ClangRenderer.__init__", side_effect=RuntimeError("broken")):
self.assertIsInstance(dev.renderer.compiler, CPULLVMCompiler)
def test_dev_contextvar(self):

View File

@@ -1,5 +1,5 @@
import unittest, subprocess, platform
from tinygrad.runtime.support.compiler_cpu import ClangJITCompiler
from tinygrad.runtime.support.compiler_cpu import ClangCompiler
from tinygrad.runtime.support.elf import elf_loader
class TestElfLoader(unittest.TestCase):
@@ -23,7 +23,7 @@ class TestElfLoader(unittest.TestCase):
}
'''
with self.assertRaisesRegex(RuntimeError, 'evil_external_function'):
ClangJITCompiler([{'AMD64':'x86_64', 'aarch64':'arm64'}.get(m:=platform.machine(), m), "native"]).compile(src)
ClangCompiler([{'AMD64':'x86_64', 'aarch64':'arm64'}.get(m:=platform.machine(), m), "native"]).compile(src)
def test_link(self):
src = '''
float powf(float, float); // from libm

View File

@@ -5,13 +5,13 @@ from typing import Generator
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatcher, graph_rewrite, track_rewrites, profile_matches
from tinygrad.uop.symbolic import sym
from tinygrad.dtype import dtypes
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.helpers import colored, ansistrip, flatten, TracingKey, ProfileRangeEvent, ProfileEvent, Context, cpu_events, profile_marker
from tinygrad.helpers import cpu_profile, ProfilePointEvent, unwrap
from tinygrad.device import Buffer
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, active_group, _name_cnt, RewriteTrace
from tinygrad.viz.serve import load_rewrites, get_full_rewrite, uop_to_json, VizData, get_render
from tinygrad.viz.serve import load_rewrites, get_full_rewrite, uop_to_json, VizData, get_render, addrspace_colors
from tinygrad.codegen import to_program_cache
from tinygrad.codegen import to_program
@@ -348,6 +348,9 @@ class TestVizIntegration(unittest.TestCase):
self.assertEqual(lst[0]["name"], "Callify 1 Buffer n1")
self.assertEqual(lst[1]["name"], "Schedule 1 Kernel n1")
self.assertEqual(lst[2]["name"], prg.arg.name)
input_ast = next(viz.get_details(2, 0))["graph"].values()
for u in input_ast:
if u["label"].startswith("PARAM\n"): self.assertEqual(u["addrspace"], addrspace_colors[AddrSpace.GLOBAL])
# schedule graph CALL nodes have a link to jump to codegen
def test_link_sched_codegen(self):

View File

@@ -276,12 +276,11 @@ class ClangRenderer(CStyleLanguage):
def supported_dtypes(self):
return {d for d in super().supported_dtypes() if (d != dtypes.bfloat16 or self.target.arch.startswith(("x86", "arm"))) and d not in dtypes.fp8s}
class ClangJITRenderer(ClangRenderer):
def __init__(self, target:Target):
super().__init__(target)
from tinygrad.runtime.support.compiler_cpu import ClangJITCompiler
from tinygrad.runtime.support.compiler_cpu import ClangCompiler
if "AMX" in target.arch: self.tensor_cores = tc.amx
self.compiler = ClangJITCompiler([x for x in target.arch.split(",") if x != "AMX"])
self.compiler = ClangCompiler([x for x in target.arch.split(",") if x != "AMX"])
class OpenCLRenderer(CStyleLanguage):
has_aux = True

View File

@@ -4,7 +4,7 @@ from tinygrad.helpers import to_mv, OSX, WIN, mv_address, suppress_finalizing, u
from tinygrad.device import BufferSpec
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface
from tinygrad.runtime.support.hcq import CLikeArgsState
from tinygrad.renderer.cstyle import ClangJITRenderer
from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.renderer.llvmir import CPULLVMRenderer
from tinygrad.renderer.nir import LVPRenderer
from tinygrad.renderer.isa.x86 import X86Renderer
@@ -138,5 +138,5 @@ class CPUDevice(HCQCompiled):
def __init__(self, device:str=""):
self.tasks:queue.Queue = queue.Queue()
CPUWorker(self, self.tasks, thread_id=0).start()
super().__init__(device, CPUAllocator(self), [ClangJITRenderer, CPULLVMRenderer, LVPRenderer, X86Renderer], functools.partial(CPUProgram, self),
super().__init__(device, CPUAllocator(self), [ClangRenderer, CPULLVMRenderer, LVPRenderer, X86Renderer], functools.partial(CPUProgram, self),
CPUSignal, CPUComputeQueue, arch={'amd64':'x86_64', 'aarch64':'arm64'}.get(m:=platform.machine().lower(), m)+",native")

View File

@@ -4,7 +4,7 @@ from tinygrad.helpers import getenv, capstone_flatdump, DEBUG, unwrap
from tinygrad.runtime.support.elf import jit_loader
from tinygrad.runtime.autogen import llvm
class ClangJITCompiler(Compiler):
class ClangCompiler(Compiler):
def __init__(self, arch:list[str], cachekey="compile_clang_jit"):
assert len(arch) >= 2, f"invalid arch string: {','.join(arch)!r}, expected '<arch>,<cpu>,[<feats>]' (eg. 'x86_64,znver2')"
self.arch, cpu, *feats = arch
@@ -98,7 +98,7 @@ class CPULLVMCompiler(LLVMCompiler):
if cpu == "native":
cpu = ctypes.string_at(llvm.LLVMGetHostCPUName()).decode()
featstr = (featstr + "," if featstr else "") + ctypes.string_at(llvm.LLVMGetHostCPUFeatures()).decode()
# +reserve-x18 here does the same thing as -ffixed-x18 in ClangJITCompiler, see comments there for why it's needed on arm osx
# +reserve-x18 here does the same thing as -ffixed-x18 in ClangCompiler, see comments there for why it's needed on arm osx
super().__init__(self.arch, cpu, ('+reserve-x18,' if self.arch == "arm64" else '') + featstr, cache_key)
def disassemble(self, lib:bytes): capstone_flatdump(lib, self.arch)

View File

@@ -57,9 +57,9 @@ function intersectRect(r1, r2) {
}
function addTags(root, path) {
root.selectAll("circle").data(d => d.rect ? [] : [d]).join("circle").attr("r", 5).style("fill", d => d.fill ?? null);
root.selectAll("circle").data(d => d.rect ? [] : [d]).join("circle").attr("r", 5).style("fill", d => d.fill ?? null).style("stroke", d => d.stroke ?? null);
root.selectAll("rect").data(d => d.rect ? [d] : []).join("rect").attr("x", d => -d.width/2).attr("y", d => -d.height/2)
.attr("width", d => d.width).attr("height", d => d.height).style("fill", d => d.fill ?? null);
.attr("width", d => d.width).attr("height", d => d.height).style("fill", d => d.fill ?? null).style("stroke", d => d.stroke ?? null);
if (path != null) root.selectAll("path").data(d => [d]).join("path").attr("d", path);
else root.selectAll("text").data(d => [d]).join("text").text(d => d.text).attr("dy", "0.35em");
}
@@ -115,6 +115,8 @@ const drawGraph = (data) => {
});
addTags(nodes.selectAll("g.tag").data(d => d.tag != null ? [d] : []).join("g").attr("class", "tag")
.attr("transform", d => `translate(${-d.width/2+8}, ${-d.height/2+8})`).datum(e => ({ text:e.tag })));
addTags(nodes.selectAll("g.addrspace").data(d => d.addrspace != null ? [d] : []).join("g").attr("class", "tag addrspace")
.attr("transform", d => `translate(${d.width/2-8}, ${-d.height/2+8})`).datum(e => ({ rect:true, width:10, height:10, fill:e.addrspace, stroke:"none" })));
const CALL_TAG_WIDTH = 14;
addTags(nodes.selectAll("g.type").data(d => d.collapsible ? [d] : []).join("g").attr("class", d => `tag clickable ${d.collapsed ? 'collapsed' : 'expanded'}`)
.attr("transform", d => d.callNode ? `translate(${CALL_TAG_WIDTH/2-d.width/2}, ${0})` : `translate(${-d.width/2}, ${0})`)

View File

@@ -31,7 +31,7 @@ const layoutCfg = (g, { blocks, paths, pc_tokens }) => {
width = Math.max(width, ctx.measureText(tokens.map((t) => t.st).join("")).width);
height += lineHeight;
}
g.setNode(lead, { ...rectDims(width, height), label, id:lead, color:"#1a1b26" });
g.setNode(lead, { ...rectDims(width, height), label, labelX:0, id:lead, color:"#1a1b26", addrspace:null });
}
// paths become edges between basic blocks
const pathColors = {0:"#3f7564", 1:"#7a4540", 2:"#3b5f7e"};
@@ -45,9 +45,9 @@ const layoutUOp = (g, { graph, change }, opts) => {
const lineHeight = 14;
g.setGraph({ rankdir: "LR", font:"sans-serif", lh:lineHeight });
ctx.font = `350 ${lineHeight}px ${g.graph().font}`;
if (change?.length) g.setNode("overlay", {label:"", labelWidth:0, labelHeight:0, className:"overlay"});
if (change?.length) g.setNode("overlay", {label:"", labelWidth:0, labelHeight:0, labelX:0, className:"overlay"});
let callCount = 0;
for (const [k, {label, src, ref, color, tag, exclude }] of Object.entries(graph)) {
for (const [k, {label, src, ref, color, tag, exclude, addrspace}] of Object.entries(graph)) {
// adjust node dims by label size (excluding escape codes) + add padding
let [width, height] = [0, 0];
for (line of label.replace(/\u001B\[(?:K|.*?m)/g, "").split("\n")) {
@@ -56,7 +56,7 @@ const layoutUOp = (g, { graph, change }, opts) => {
}
const callNode = label.startsWith("CALL\n") || label.startsWith("FUNCTION\n");
if (callNode) callCount++;
g.setNode(k, {...rectDims(width, height), label, labelX:0, ref, id:k, color, tag, callNode, exclude});
g.setNode(k, {...rectDims(width, height), label, labelX:0, ref, id:k, color, tag, callNode, exclude, addrspace});
// add edges
const edgeCounts = {};
for (const [_, s] of src) edgeCounts[s] = (edgeCounts[s] || 0)+1;

View File

@@ -43,7 +43,7 @@ from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, GroupO
from tinygrad.uop.ops import KernelInfo
from tinygrad.uop.render import print_uops, pyrender
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, ProfileProgramEvent
from tinygrad.dtype import dtypes
from tinygrad.dtype import dtypes, AddrSpace
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.SHAPED_WMMA: "#FF5B5B",
@@ -56,6 +56,8 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
Ops.STAGE: "#AC640D", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
addrspace_colors = {AddrSpace.REG:"#e68181", AddrSpace.LOCAL:"#e7c86a", AddrSpace.GLOBAL:"#75bd7b"}
# VIZ API
# A step is a lightweight descriptor for a trace entry
@@ -160,8 +162,11 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
# limit SOURCE labels line count
if u.op is Ops.SOURCE and len(lines:=label.split("\n")) > 40:
label = "\n".join(lines[:30]) + "\n..."
try: addrspace = u.addrspace
except Exception: addrspace = None
graph[id(u)] = {"label":label, "src":[(i,id(x)) for i,x in enumerate(u.src)], "exclude":u in excluded, "color":uops_colors.get(u.op, "#ffffff"),
"ref":ref, "tag":repr(u.tag) if u.tag is not None else None}
"ref":ref, "tag":repr(u.tag) if u.tag is not None else None,
"addrspace":addrspace_colors.get(addrspace, None) if addrspace is not None else None}
return graph
def _reconstruct(data:VizData, a:int, depth:int|None=None):