AMX in arch, better docs (#15871)

This commit is contained in:
Christopher Milan
2026-04-22 14:25:18 -07:00
committed by GitHub
parent e5891acab2
commit b0dc95a390
13 changed files with 45 additions and 28 deletions

View File

@@ -57,6 +57,8 @@ AMD:LLVM | use the AMD device with the LLVM renderer
NV:CUDA:sm_70 | use the NV device with the CUDA renderer targetting sm_70
AMD::gfx950 | use the AMD device targetting gfx950
USB+AMD | use the AMD device over the USB interface
CPU:LLVM | use the CPU device with the LLVM renderer
CPU:LLVM:x86_64,znver2,avx2,-avx512f | use the CPU device with the LLVM renderer, with [additional arch flags](runtime.md#cpu-arch)
### Debug breakdown

View File

@@ -10,7 +10,7 @@ tinygrad supports various runtimes, enabling your code to scale across a wide ra
| [METAL](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_metal.py) | Utilizes Metal for acceleration on Apple devices | - | M1+ Macs; Metal 3.0+ for `bfloat` support |
| [CUDA](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cuda.py) | Utilizes CUDA for acceleration on NVIDIA GPUs | nvrtc (default)<br> PTX (`DEV=CUDA:PTX`) | NVIDIA GPU with CUDA support |
| [CL](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cl.py) | Accelerates computations using OpenCL on GPUs | - | OpenCL 2.0 compatible device |
| [CPU](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cpu.py) | Runs on CPU using the clang or llvm compiler | Clang JIT (default)<br>LLVM IR (`DEV=CPU:LLVM`) | `clang` compiler in system `PATH` |
| [CPU](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cpu.py) | Runs on CPU using the clang or llvm compiler | Clang JIT (default)<br>LLVM IR (`DEV=CPU:LLVM`) | `clang` compiler in system `PATH`<br>You can specify additional arch parameters via [the `DEV` variable](env_vars.md#dev-variable). See [CPU arch](#cpu-arch) for details. |
| [WEBGPU](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_webgpu.py) | Runs on GPU using the Dawn WebGPU engine (used in Google Chrome) | - | Dawn library installed and discoverable. Binaries: [pydawn v0.3.0](https://github.com/wpmed92/pydawn/releases/tag/v0.3.0) |
@@ -79,3 +79,13 @@ NV backend supports several interfaces for communicating with devices:
* `NVK`: uses the nvidia driver
* `PCI`: uses the [NV driver](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/support/nv/nvdev.py)
## CPU Arch
The CPU renderers may be additionally configured using the arch component of [the `DEV` environment variable](env_vars.md#dev-variable).
CPU arch should be specified as a comma-separated list of parameters, and must contain at least two values: the architecture family (ie. x86_64, arm64, or riscv64) and the cpu type (as accepted by `clang`'s `-march`).
If native is specified as the cpu type, tinygrad (or delegate compiler) will query the host cpu type. Additional comma-separated values may be specified as follows:
* `AMX`: emit Apple silicon AMX instructions
All other additional values are interpreted as cpu feature flags. When a value is preceded by a `-` character, the corresponding feature flag will be disabled, otherwise the flag will be enabled.
Note that enabled feature flags should not be preceded by a `+`.

View File

@@ -132,6 +132,7 @@ class TestDevVar(unittest.TestCase):
for d, t in [("AMD", Target(device="AMD", renderer="")), ("AMD:LLVM", Target(device="AMD", renderer="LLVM")),
(":LLVM", Target(device="", renderer="LLVM")), ("AMD::gfx1100", Target(device="AMD", arch="gfx1100")),
("AMD:LLVM:gfx1100", Target(device="AMD", renderer="LLVM", arch="gfx1100")), ("::gfx1100", Target(arch="gfx1100")),
("CPU:LLVM:arm64,native,AMX", Target(device="CPU", renderer="LLVM", arch="arm64,native,AMX")),
("USB+", Target(interface="USB")), ("USB+AMD", Target(device="AMD", interface="USB")),
("PCI:0+AMD", Target(device="AMD", interface="PCI", indices="0")), (":0+AMD", Target(device="AMD", indices="0")),
("PCI:0,1+AMD", Target(device="AMD", interface="PCI", indices="0,1")),

View File

@@ -23,7 +23,7 @@ class TestElfLoader(unittest.TestCase):
}
'''
with self.assertRaisesRegex(RuntimeError, 'evil_external_function'):
ClangJITCompiler({'AMD64':'x86_64', 'aarch64':'arm64'}.get(m:=platform.machine(), m)+",native").compile(src)
ClangJITCompiler([{'AMD64':'x86_64', 'aarch64':'arm64'}.get(m:=platform.machine(), m), "native"]).compile(src)
def test_link(self):
src = '''
float powf(float, float); // from libm

View File

@@ -3,9 +3,11 @@ from tinygrad import Device, Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.engine.realize import get_program
from tinygrad.helpers import AMX
from tinygrad.helpers import DEV
from test.helpers import replace_opts
AMX = "AMX" in DEV.arch
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4")
class TestFloat4(unittest.TestCase):
@staticmethod

View File

@@ -7,7 +7,7 @@ from tinygrad.tensor import _to_np_dtype
from tinygrad.uop.ops import Ops
from tinygrad.dtype import DType
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import AMX, DEV, Context
from tinygrad.helpers import DEV, Context
from test.helpers import slow, replace_opts
from tinygrad.engine.realize import CompiledRunner, get_program
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
@@ -18,6 +18,8 @@ from test.backend.test_linearizer import helper_realized_ast, helper_linearizer_
# NOTE: get_program always passes in Device[Device.DEFAULT].renderer explicitly for process_replay!!!
AMX = "AMX" in DEV.arch
def helper_tc_ensure_uops_and_opts_count(N: int, M:int, K:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0,
ensure_triggered:bool=True):
a, b = Tensor.rand(M, K, dtype=dtype_in), Tensor.rand(K, N, dtype=dtype_in)

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass
from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid, PtrDType
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, identity_element
from tinygrad.uop.symbolic import uop_given_valid, parse_valid, invalid_gate
from tinygrad.helpers import getenv, flatten, AMX, prod
from tinygrad.helpers import getenv, flatten, prod
from tinygrad.renderer import Renderer
# ***** image load valid simplification *****
@@ -171,7 +171,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
lengths = [4]
elif ctx is not None and ctx.supports_float4:
# TODO: a better way to get this than ctx
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2])
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if "AMX" in ctx.target.arch else [4,2])
lengths.append(1) # worst case, it's not folded
# filter fold lengths that don't divide

View File

@@ -1,6 +1,6 @@
import itertools
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
from tinygrad.helpers import getenv, DEBUG, prod, NOLOCALS, TC_OPT, TC_SELECT, USE_TC, AMX, IMAGE
from tinygrad.helpers import getenv, DEBUG, prod, NOLOCALS, TC_OPT, TC_SELECT, USE_TC, IMAGE
from tinygrad.dtype import PtrDType, ImageDType
from tinygrad.uop.ops import Ops, resolve, AxisType
from tinygrad.codegen.opt.postrange import Scheduler
@@ -34,7 +34,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
except KernelOptError:
pass
# skip hand-coded TC opts if AMX, upcasting will make kernel slower
if good_tc_opt and not AMX:
if good_tc_opt and "AMX" not in k.ren.target.arch:
if rngs is not None:
for tc_dim in [1,0]: # attempt to upcast M and N
szs = [sz for sz in [5,4,3,2] if rngs[tc_dim].src[0].divides(sz) is not None]

View File

@@ -232,7 +232,7 @@ DEV, DEBUG, BEAM, NOOPT = _DEV("DEV", ""), ContextVar("DEBUG", 0), ContextVar("B
IMAGE, FLOAT16, OPENPILOT_HACKS = ContextVar("IMAGE", 0), ContextVar("FLOAT16", 0), ContextVar("OPENPILOT_HACKS", 0)
JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVar("JIT_BATCH_SIZE", 32)
WINO, CAPTURING, TRACEMETA, NO_COLOR = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1), ContextVar("NO_COLOR", 0)
USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
USE_TC, TC_SELECT, TC_OPT = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0)
TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0)
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("LRU", 1)
RING, ALL2ALL, ALLREDUCE_CAST = ContextVar("RING", 1), ContextVar("ALL2ALL", 0), ContextVar("ALLREDUCE_CAST", 1)

View File

@@ -3,7 +3,7 @@ import math, sys, struct
from collections import defaultdict, Counter
from tinygrad.codegen.opt import tc
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str, axis_letters
from tinygrad.helpers import strip_parens, getenv, prod, dedup, Target, AMX, CPU_COUNT
from tinygrad.helpers import strip_parens, getenv, prod, dedup, Target, CPU_COUNT
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate, float_to_bf16
from tinygrad.renderer import Renderer
from tinygrad.codegen.late.devectorizer import no_vectorized_alu
@@ -226,7 +226,6 @@ class ClangRenderer(CStyleLanguage):
global_max = (CPU_COUNT.value, 0, 0)
infinity = "__builtin_inff()"
nan = '__builtin_nanf("")'
if AMX: tensor_cores = tc.amx
# language options
buffer_suffix = " restrict"
@@ -280,7 +279,8 @@ class ClangJITRenderer(ClangRenderer):
def __init__(self, target:Target):
super().__init__(target)
from tinygrad.runtime.support.compiler_cpu import ClangJITCompiler
self.compiler = ClangJITCompiler(target.arch)
if "AMX" in target.arch: self.tensor_cores = tc.amx
self.compiler = ClangJITCompiler([x for x in target.arch.split(",") if x != "AMX"])
class OpenCLRenderer(CStyleLanguage):
has_aux = True

View File

@@ -6,7 +6,7 @@ from tinygrad.renderer.cstyle import HIPRenderer, create_non_native_float_pats,
from tinygrad.uop.decompositions import xexp2, xlog2
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, range_str
from tinygrad.dtype import dtypes, float_to_fp8, DType, PtrDType, truncate
from tinygrad.helpers import prod, Target, AMX, CPU_COUNT, getenv
from tinygrad.helpers import prod, Target, CPU_COUNT, getenv
def ldt(dt:DType):
if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
@@ -134,7 +134,6 @@ class LLVMRenderer(Renderer):
abi: str | None
string_rewrite: PatternMatcher
code_for_op = {k:lambda:None for v in lop.values() for k in v.keys()}
if AMX: tensor_cores = tc.amx
extra_matcher = create_non_native_float_pats((dtypes.bfloat16,)) + pm_manual_bf16_cast
def _render_fn(self, name:str, args:list[tuple[str,DType]], kernel:list[str], prefix:list[str]|None=None) -> str:
@@ -149,7 +148,7 @@ class LLVMRenderer(Renderer):
local_args: list[str] = []
for u in uops:
if AMX and u.op is Ops.WMMA: # prealloc aux buffers as AMX can only load from memory
if self.tensor_cores == tc.amx and u.op is Ops.WMMA: # prealloc aux buffers as AMX can only load from memory
vc += 1
r[u] = f"%wmma{vc}"
for i, dtype in enumerate(u.arg[2].vec(sz) for sz in [prod(size for _, size in upcast) for upcast in u.arg[6]]):
@@ -204,7 +203,8 @@ class CPULLVMRenderer(LLVMRenderer):
def __init__(self, target:Target):
super().__init__(target)
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler
self.compiler = CPULLVMCompiler(target.arch)
if "AMX" in target.arch: self.tensor_cores = tc.amx
self.compiler = CPULLVMCompiler([x for x in target.arch.split(",") if x != "AMX"])
barrier = 'fence syncscope("workgroup") release\ntail call void @llvm.amdgcn.s.barrier()\nfence syncscope("workgroup") acquire\n'
code_for_workitem = {"g": lambda x: f"tail call i32 @llvm.amdgcn.workgroup.id.{chr(120+int(x))}()",

View File

@@ -5,17 +5,17 @@ from tinygrad.runtime.support.elf import jit_loader
from tinygrad.runtime.autogen import llvm
class ClangJITCompiler(Compiler):
def __init__(self, arch, cachekey="compile_clang_jit"):
self.arch, cpu, feats = (sp:=arch.split(',', 2)) + [""] * (3 - len(sp))
def __init__(self, arch:list[str], cachekey="compile_clang_jit"):
self.arch, cpu, *feats = arch
assert self.arch and cpu, f"invalid arch string: {arch!r}, expected '<arch>,<cpu>,[<feats>]' (eg. 'x86_64,znver2')"
match self.arch:
case "x86_64": self.args = [f"-march={cpu}"] + [f"-mno{f}" if f.startswith("-") else f"-m{f}" for f in feats.split(',') if f]
case "x86_64": self.args = [f"-march={cpu}"] + [f"-mno{f}" if f.startswith("-") else f"-m{f}" for f in feats]
# on arm march means "runs on this arch and superset" instead of "optimize for this arch". x86 march == arm mcpu
# x18 is a reserved platform register. It is clobbered on context switch in macos and is used to store TEB pointer in windows on arm
case "arm64": self.args = ["-ffixed-x18", "-mcpu=" + "+".join([cpu] + ["no"+f[1:] if f.startswith("-") else f for f in feats.split(',') if f])]
case "riscv64": self.args = ["-march=" + "_".join(["rv64g" if cpu == "native" else cpu] + [f for f in feats.split(',') if f])]
case "arm64": self.args = ["-ffixed-x18", "-mcpu=" + "+".join([cpu] + ["no"+f[1:] if f.startswith("-") else f for f in feats])]
case "riscv64": self.args = ["-march=" + "_".join(["rv64g" if cpu == "native" else cpu] + feats)]
case _: raise RuntimeError(f"unsupported arch: {self.arch!r}")
super().__init__(f"{cachekey}_{arch}")
super().__init__(f"{cachekey}_{'_'.join(arch)}")
def compile_to_obj(self, src:str) -> bytes:
"""Compile C source to ELF object file (before linking)."""
@@ -91,14 +91,14 @@ class LLVMCompiler(Compiler):
class CPULLVMCompiler(LLVMCompiler):
def __init__(self, arch, cache_key=None):
self.arch, cpu, feats = (sp:=arch.split(',', 2)) + [""] * (3 - len(sp))
def __init__(self, arch:list[str], cache_key=None):
self.arch, cpu, *feats = arch
assert self.arch and cpu, f"invalid arch string: {arch!r}, expected '<arch>,<cpu>,[<feats>]' (eg. 'x86_64,znver2')"
feats = ','.join(f if f.startswith('-') else '+'+f for f in feats.split(',') if f)
featstr = ','.join(f if f.startswith('-') else '+'+f for f in feats)
if cpu == "native":
cpu = ctypes.string_at(llvm.LLVMGetHostCPUName()).decode()
feats = (feats + "," if feats else "") + ctypes.string_at(llvm.LLVMGetHostCPUFeatures()).decode()
featstr = (featstr + "," if featstr else "") + ctypes.string_at(llvm.LLVMGetHostCPUFeatures()).decode()
# +reserve-x18 here does the same thing as -ffixed-x18 in ClangJITCompiler, see comments there for why it's needed on arm osx
super().__init__(self.arch, cpu, ('+reserve-x18,' if self.arch == "arm64" else '') + feats, cache_key)
super().__init__(self.arch, cpu, ('+reserve-x18,' if self.arch == "arm64" else '') + featstr, cache_key)
def disassemble(self, lib:bytes): capstone_flatdump(lib, self.arch)

View File

@@ -17,7 +17,7 @@ def deserialize(enc_src, opts):
return mesa.nir_deserialize(None, ctypes.cast(opts, ctypes.POINTER(mesa.nir_shader_compiler_options)), blobreader)
class LVPCompiler(CPULLVMCompiler):
def __init__(self, arch): CPULLVMCompiler.__init__(self, arch, cache_key="compile_lvp")
def __init__(self, arch): CPULLVMCompiler.__init__(self, arch.split(","), cache_key="compile_lvp")
def compile(self, src) -> bytes:
shader, ctx = deserialize(src, mesa.lvp_nir_options), llvm.LLVMGetGlobalContext()