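% File-viewer header captured along with the source; commented out so nothing precedes \documentclass.
% Files
% tinygrad/spec/tinyspec.tex
% 2026-06-04 10:12:31 -07:00
% 458 lines
% 22 KiB
% TeX
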
\documentclass[10pt,letterpaper]{article}
\usepackage[margin=0.75in]{geometry}
\usepackage{amsmath,amssymb}
\usepackage{booktabs}
\usepackage{array}
\usepackage[dvipsnames]{xcolor}
\usepackage{enumitem}
\usepackage{listings}
\lstset{language=Python, basicstyle=\ttfamily\small, columns=fullflexible, keepspaces=true}
\newcommand{\op}[1]{\textsc{#1}}
\definecolor{movgreen}{HTML}{2E7D32}
\definecolor{reducered}{HTML}{C62828}
\definecolor{elwyellow}{HTML}{F9A825}
\definecolor{callblue}{HTML}{1565C0}
\definecolor{assignbrown}{HTML}{795548}
\definecolor{loadred}{HTML}{c08080}
\definecolor{multipurple}{HTML}{7B1FA2}
\definecolor{markerorange}{HTML}{E65100}
% AxisType colors (from tinygrad)
\definecolor{axblue}{HTML}{1565C0} % GLOBAL
\definecolor{axcyan}{HTML}{00838F} % LOCAL
\definecolor{axbrcyan}{HTML}{00ACC1} % WARP
\definecolor{axbrblue}{HTML}{42A5F5} % THREAD
\definecolor{axwhite}{HTML}{616161} % LOOP (gray on white paper)
\definecolor{axred}{HTML}{C62828} % REDUCE
\definecolor{axbrred}{HTML}{E53935} % GROUP_REDUCE
\definecolor{axyellow}{HTML}{F9A825} % UPCAST
\definecolor{axmagenta}{HTML}{7B1FA2} % UNROLL
\title{tinygrad: a single dialect from Tensor programs to Command Buffers}
\author{tinygrad, Corp. \\ \texttt{research@tinygrad.org}}
\date{}
\begin{document}
\maketitle
\thispagestyle{empty}
\section*{UOps}
All nodes in the tinygrad graph are \textbf{UOps}. A UOp is a tuple $(\mathrm{op},\;\mathrm{src},\;\mathrm{arg},\;\mathrm{tag})$ where $\mathrm{op}$ is from the set below, $\mathrm{src}$ is a tuple of input UOps, $\mathrm{arg}$ is op-dependent, and $\mathrm{tag}$ is for temporary processing. The full program is a DAG of UOps. Each UOp has five derived properties --- \textbf{dtype}, \textbf{shape}, \textbf{device}, \textbf{addrspace}, and \textbf{min\_max} --- determined by the rules in the Derived Properties section below.
%% ============================================================
\subsection*{Source Ops \normalfont\small--- leaf nodes}
\begin{tabular}{@{}l p{3.2cm} p{3.0cm} p{6.2cm}@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Param} & $(\mathbf{s})$ & slot, dtype, device?, addrspace? &
Placeholder with shape $\mathbf{s}$. Substituted in \op{Function}. \\[4pt]
\op{Buffer} & $(\mathbf{s})$ & slot, dtype, device, addrspace &
Concrete buffer slot with shape $\mathbf{s}$. If device is a tuple, it creates the fully sized buffer across multiple devices. \\
\op{Const} & () & value, dtype &
A scalar constant with shape $(\ )$. \\
& & & Form vector consts with \op{Stack}. \\
\op{Binary} & () & data & Raw binary data with dtype uint8 and shape $(\mathrm{len}(\mathit{data}),)$. \\
\bottomrule
\end{tabular}
\smallskip
\textbf{addrspace} is \texttt{GLOBAL}, \texttt{LOCAL}, or \texttt{REG}.
%% ============================================================
\subsection*{{\color{movgreen}Movement Ops} \normalfont\small--- no arithmetic; view, indexing, and reinterpretation only}
\begin{tabular}{@{}l l l l@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Permute} & $(T,)$ & axis order $\pi$ & Reorder axes. $\pi = (1,0)$ is transpose. \\
\op{Flip} & $(T,)$ & bools $\mathbf{f}$ & Reverse along flagged axes. \\
\op{Reshape} & $(T, \mathbf{s'})$ & --- & Reinterpret in row-major order. $\prod s_k = \prod s'_k$. \\
\op{Expand} & $(T, \mathbf{s'})$ & --- & Broadcast size-1 axes. $s_k \in \{1, s'_k\}$. \\
\op{Pad} & $(T, \mathbf{o}, \mathbf{s'})$ & --- & Place $T$ at offset $o_k$ along each axis $k$ of an invalid-filled output of shape $\mathbf{s'}$. \\
\op{Shrink} & $(T, \mathbf{o}, \mathbf{s'})$ & --- & Keep $s'_k$ elements starting at offset $o_k$ per axis. Inverse of \op{Pad}. \\
\op{Index} & $(T, i_0, i_1, \ldots)$ & --- & Index from left. $()$-shaped $i$ removes dim; $(k,)$-shaped makes it $k$. \\
\op{Stack} & $(T_0, T_1, \ldots)$ & --- & Join along a newly created leading axis. All shapes must match. \\
\op{Bitcast} & $(T,)$ & dtype & Reinterpret storage as target dtype; preserve total bytes. \\
\bottomrule
\end{tabular}
%% ============================================================
\subsection*{{\color{reducered}Reduce Ops} \normalfont\small--- collapse axes to size $1$}
\begin{tabular}{@{}l l l l@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Reduce} & ($T$, $r_0$, $r_1$, \ldots) & op, axes & Reduce $T$ along axes or ranges. Op is \op{Add}, \op{Max}, or \op{Mul}. \\
\bottomrule
\end{tabular}
%% ============================================================
\subsection*{{\color{callblue}Call Ops} \normalfont\small--- function abstraction, like the lambda calculus}
\begin{tabular}{@{}l l l l@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Function} & (body, $a_0$, $a_1$, \ldots) & --- & Substitute each \op{Param} $k$ in \op{Tuple} body with $a_k$. Gradient-able. \\
\op{Call} & (body, $a_0$, $a_1$, \ldots) & --- & Opaque invocation of a compiled kernel or custom function. \\
\op{Tuple} & $(v_0, v_1, \ldots)$ & --- & Pack values; required as \op{Function} body to return a value. \\
\op{GetTuple} & $(T,)$ & idx & Extract element at idx from a \op{Tuple}. \\
\bottomrule
\end{tabular}
%% ============================================================
\subsection*{{\color{loadred}Load Ops} \normalfont\small--- can change device or addrspace}
\begin{tabular}{@{}l l l l@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Load} & (buf, alt?, gate?) & device, addrspace & Read (pull) from buffer into a new anonymous buffer. \\
& & & Note: this replaces \op{Copy} and \op{Contiguous}. \\
\bottomrule
\end{tabular}
%% ============================================================
\subsection*{{\color{multipurple}Store Ops} \normalfont\small--- the only op with observable side effects}
\begin{tabular}{@{}l l l l@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Store} & (buf, val, gate?) & --- & Write (push) val into buf. buf.shape $=$ val.shape. \\
& & & If gate is present, write only when gate is true. Output is void. \\
\bottomrule
\end{tabular}
%% ============================================================
\subsection*{{\color{assignbrown}Ordering Ops} \normalfont\small--- execution order}
\begin{tabular}{@{}l l l p{6.0cm}@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Range} & $(\text{bound},)$ & type & Iterator over $0, 1, \ldots, \text{bound}-1$ (bound is exclusive). \\
\op{End} & (body, range) & --- & Close a \op{Range} loop. \\
\op{After} & (buf, deps\ldots) & --- & Passthrough of buf; guarantees deps execute first. \\
\op{Group} & $(u_0, u_1, \ldots)$ & --- & Void no-op that merges multiple \op{Store}s into one node, unordered. \\
\op{Sink} & $(s_0, s_1, \ldots)$ & --- & Collect side effects into a single root node. \\
\op{Linear} & (uops\ldots) & --- & Linearized (toposorted) instruction sequence. \\
\bottomrule
\end{tabular}
\smallskip
Assign is \op{Store} followed by \op{After}: write the value, then return the buffer with an ordering dependency.
%% ============================================================
\subsection*{{\color{elwyellow}Elementwise Ops} \normalfont\small--- all inputs same shape, output same shape, applied per-element}
\begin{tabular}{@{}l l l l@{}}
\toprule
\textbf{Arity} & \textbf{src} & \textbf{Op} & \textbf{Semantics} \\
\midrule
Unary & $(T,)$
& \op{Recip}
& $1/x$ \\
& & \op{Trunc}
& $\mathrm{trunc}(x)$: round toward zero. \\
& & \op{Cast}
& Convert to target dtype (specified in arg). \\[4pt]
Binary & $(A, B)$
& \op{Add}, \op{Mul}, \op{Max}, \op{Mod}, \op{Idiv}
& $a+b$, $a \cdot b$, $\max(a,b)$, $a - b\,\mathrm{trunc}(a/b)$, $\mathrm{trunc}(a/b)$ \\
& & \op{CmpLt}, \op{CmpNe}
& $[a < b]$, $[a \ne b]$ \\
& & \op{Xor}, \op{Or}, \op{And}, \op{Shr}, \op{Shl}
& $a \oplus b$, $a \mid b$, $a \mathbin{\&} b$, $a \gg b$, $a \ll b$ \\[4pt]
Ternary & $(P, A, B)$
& \op{Where}
& $A[\mathbf{i}]$ if $P[\mathbf{i}] \ne 0$, else $B[\mathbf{i}]$ \\
\bottomrule
\end{tabular}
\medskip
\textbf{Decomposed elementwise ops} --- defined in terms of the primitives above.
\smallskip
\begin{tabular}{@{}l l l@{}}
\toprule
\textbf{Op} & \textbf{Decomposition} & \textbf{Semantics} \\
\midrule
\op{Neg} & \op{Mul}($A$, $-1$) & $-x$ \\
\op{Sub} & \op{Add}($A$, \op{Neg}($B$)) & $a - b$ \\
\op{Div} & \op{Mul}($A$, \op{Recip}($B$)) & $a / b$ \\
\op{CmpGt} & \op{CmpLt}($B$, $A$) & $[a > b]$ \\
\op{CmpGe} & \op{CmpNe}(\op{CmpLt}($A$, $B$),\, $1$) & $[a \ge b]$ \\
\op{CmpLe} & \op{CmpNe}(\op{CmpLt}($B$, $A$),\, $1$) & $[a \le b]$ \\
\op{CmpEq} & \op{CmpNe}(\op{CmpNe}($A$, $B$),\, $1$) & $[a = b]$ \\
\op{Not} & \op{CmpNe}($A$, $1$) & $\lnot a$ \\[4pt]
\op{Exp2} & polynomial approx + \op{Mul}, \op{Add} & $2^x$ \\
\op{Log2} & exponent extract + polynomial approx & $\log_2 x$ \\
\op{Sin} & argument reduction + polynomial approx & $\sin x$ \\
\op{Sqrt} & \op{Exp2}($0.5 \cdot$ \op{Log2}($A$)) & $\sqrt{x}$ \\
\op{Pow} & \op{Exp2}(\op{Log2}($A$) $\cdot\, B$) & $a^b$ \\
\op{Mulacc} & \op{Add}(\op{Mul}($A$, $B$),\, $C$) & $a \cdot b + c$ \\
\op{Threefry} & 20 rounds of add-rotate-xor (ARX) + key injection & Threefry 2x32 PRNG \\
\bottomrule
\end{tabular}
%% ============================================================
\subsection*{{\color{markerorange}Marker Ops} \normalfont\small--- identity on data}
\begin{tabular}{@{}l l l l@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Contiguous} & $(T,)$ & --- & Force contiguous memory layout. \\
\op{ContiguousBackward} & $(T,)$ & --- & Force contiguous in backward pass. \\
\op{Detach} & $(T,)$ & --- & Stops gradient propagation. \\
\bottomrule
\end{tabular}
%% ============================================================
\subsection*{Codegen Ops \normalfont\small--- generated code primitives, these do not appear in the main graph}
\begin{tabular}{@{}l l l l@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Barrier} & (deps\ldots) & --- & Synchronize threads within a workgroup. \\
\op{Ins} & \ldots & \ldots & A single machine instruction (e.g.\ AMD ISA). \\
\op{Special} & (bound,) & name & GPU thread/workgroup index (e.g.\ \texttt{gidx0}, \texttt{lidx1}). \\
\op{If} & (gate,) & --- & Begin conditional execution block. \\
\op{Endif} & (if,) & --- & End conditional execution block. \\
\op{Wmma} & (A, B, acc) & config & Warp matrix multiply-accumulate (tensor cores). \\
\op{Custom} & (args\ldots) & fmt & Inject custom code string into generated source. \\
\op{AtomicAdd} & (idx, val) & --- & Atomic read-modify-write: \texttt{buf[idx] += val}. \\[4pt]
\op{CustomFunction} & (meta\ldots) & name & Opaque device function (e.g.\ HW decode). Via \op{Call}. \\
\op{Program} & (linear, source, binary) & --- & Compiled kernel: instructions, source, and machine code. \\
\op{Source} & () & str & Human-readable rendered source code. \\
\op{Binary} & () & bytes & Compiled machine code. \\
\bottomrule
\end{tabular}
\smallskip
These ops are not part of the core specification and are subject to change.
%% ============================================================
\subsection*{Derived Properties}
Every UOp has a \textbf{dtype}, \textbf{shape}, \textbf{device}, \textbf{addrspace}, and \textbf{min\_max}, derived from its op, src, and arg. The table below gives the dtype, shape, device, and min\_max rules:
\medskip
\begin{tabular}{@{}l l l l l@{}}
\toprule
\textbf{Op} & \textbf{dtype} & \textbf{shape} & \textbf{device} & \textbf{min\_max} \\
\midrule
\op{Buffer} & from arg & from $\mathrm{src}[0]$ & from arg & dtype range \\
\op{Const} & from arg & $()$ & \textsc{null} & $[v, v]$ \\
\op{Param} & from arg & from $\mathrm{src}[0]$ & from arg & from src or dtype range \\[3pt]
Movement ops & $\mathrm{src}[0].\mathrm{dtype}$ & (see op) & $\mathrm{src}[0].\mathrm{device}$ & $\mathrm{src}[0]$ \\
\op{Reduce} & $\mathrm{src}[0].\mathrm{dtype}$ & collapse axes to $1$ & $\mathrm{src}[0].\mathrm{device}$ & dtype range \\[3pt]
\op{Cast} & from arg & $\mathrm{src}[0].\mathrm{shape}$ & $\mathrm{src}[0].\mathrm{device}$ & clamped to dtype \\
\op{Bitcast} & from arg & $\mathrm{src}[0].\mathrm{shape}$ & $\mathrm{src}[0].\mathrm{device}$ & dtype range \\
\op{Copy} & $\mathrm{src}[0].\mathrm{dtype}$ & $\mathrm{src}[0].\mathrm{shape}$ & from arg & $\mathrm{src}[0]$ \\
ALU unary & $\mathrm{src}[0].\mathrm{dtype}$ & $\mathrm{src}[0].\mathrm{shape}$ & $\mathrm{src}[0].\mathrm{device}$ & dtype range \\
\op{Add} & $\mathrm{src}[0].\mathrm{dtype}$ & broadcast & $\mathrm{src}[0].\mathrm{device}$ & $[a+b,\, A+B]$ \\
\op{Mul} & $\mathrm{src}[0].\mathrm{dtype}$ & broadcast & $\mathrm{src}[0].\mathrm{device}$ & $[\min,\max]$ of products \\
\op{Max} & $\mathrm{src}[0].\mathrm{dtype}$ & broadcast & $\mathrm{src}[0].\mathrm{device}$ & $[\max(a,b),\, \max(A,B)]$ \\
Other binary & $\mathrm{src}[0].\mathrm{dtype}$ & broadcast & $\mathrm{src}[0].\mathrm{device}$ & dtype range \\
\op{CmpLt}, \op{CmpNe} & bool & broadcast & $\mathrm{src}[0].\mathrm{device}$ & from intervals \\
\op{Where} & $\mathrm{src}[1].\mathrm{dtype}$ & broadcast & $\mathrm{src}[0].\mathrm{device}$ & $[\min(b,c),\, \max(B,C)]$ \\[3pt]
\op{Function}, \op{Call} & $\mathrm{src}[0].\mathrm{dtype}$ & substitute \op{Param} shapes & $\mathrm{src}[1].\mathrm{device}$ & dtype range \\
\op{Range} & index & $()$ & \textsc{null} & $[0,\, n{-}1]$ \\
\op{Index} & $\mathrm{src}[0].\mathrm{dtype}$ & remaining dims & $\mathrm{src}[0].\mathrm{device}$ & $\mathrm{src}[0]$ \\
\op{Store} & void & $()$ & $\mathrm{src}[0].\mathrm{device}$ & --- \\
\op{After} & $\mathrm{src}[0].\mathrm{dtype}$ & $\mathrm{src}[0].\mathrm{shape}$ & $\mathrm{src}[0].\mathrm{device}$ & $\mathrm{src}[0]$ \\
\bottomrule
\end{tabular}
\smallskip
$\mathrm{broadcast}$: right-align shapes, element-wise max; each axis must be equal or $1$.
$[a,A]$, $[b,B]$, $[c,C]$ denote min\_max of $\mathrm{src}[0]$, $\mathrm{src}[1]$, $\mathrm{src}[2]$.
Default \emph{dtype range}: $[\mathrm{dtype\_min},\, \mathrm{dtype\_max}]$.
\medskip
\textbf{axis} tracks the multi-device sharding dimension. \op{Buffer} with $n$-tuple device: axis $= 0$ (device dim).
\op{Reshape} remaps axis to preserve the shard boundary. \op{Permute} follows the permutation.
\op{Reduce} on the shard axis $\to$ \textsc{null}. \op{Replicated} on the shard axis $\to$ \textsc{null}. \op{Copy} $\to$ \textsc{null}. ALU ops inherit from sources. Default: \textsc{null}.
%% ============================================================
\subsection*{Kernel Optimizations (OptOps) \normalfont\small--- schedule-level transforms on kernel ranges}
Each kernel's iteration space is a set of \op{Range} axes. Every range has an \textbf{AxisType}:
\medskip
\begin{tabular}{@{}l l l l l@{}}
\toprule
\textbf{AxisType} & \textbf{Letter} & \textbf{Split from} & \textbf{Direction} & \textbf{Semantics} \\
\midrule
{\color{axblue}\texttt{GLOBAL}} & \texttt{g} & --- & --- & GPU global workgroup dimension. \\
{\color{axcyan}\texttt{LOCAL}} & \texttt{l} & g, L & inner & Workgroup local dimension (shared memory). \\
{\color{axbrcyan}\texttt{WARP}} & \texttt{w} & \multicolumn{2}{l}{(created by \op{TC})} & Warp-level lanes for tensor cores. \\
{\color{axbrblue}\texttt{THREAD}} & \texttt{t} & g & outer & CPU thread parallelism. \\
{\color{axwhite}\texttt{LOOP}} & \texttt{L} & --- & --- & Generic sequential loop (initial state). \\
{\color{axred}\texttt{REDUCE}} & \texttt{R} & --- & --- & Reduction axis. \\
{\color{axbrred}\texttt{GROUP\_REDUCE}} & \texttt{G} & R & inner/outer & Shared-memory group reduction. \\
{\color{axyellow}\texttt{UPCAST}} & \texttt{u} & g, l, L & inner & Register-level vectorization. \\
{\color{axmagenta}\texttt{UNROLL}} & \texttt{r} & R, G & inner & Fully unrolled loop. \\
\bottomrule
\end{tabular}
\medskip
An optimization is a triple $(\mathrm{op},\;\mathrm{axis},\;\mathrm{arg})$:
\smallskip
\begin{tabular}{@{}l l l p{6.5cm}@{}}
\toprule
\textbf{OptOp} & \textbf{axis} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Split} & any & (factor $k$, target, top?) &
Split an axis of size $n$ by $k$ into $(n/k, k)$, or into $(k, n/k)$ if top. The new sub-axis gets the target AxisType (see table above). \\
\op{Padto} & any & multiple $m$ &
Pad axis to next multiple of $m$ with validity masks. \\[4pt]
\op{Swap} & axis$_i$ & axis$_j$ &
Swap two axes $i \leftrightarrow j$. \\
\op{Nolocals} & --- & --- &
Disable local memory; no workgroup dims emitted. \\
\op{TC} & reduce idx & (tc, opt, mode) &
Apply tensor core \op{Wmma}: split reduce/output axes into \texttt{WARP}, \texttt{UPCAST}, and \texttt{UNROLL} dims. \\
\bottomrule
\end{tabular}
\smallskip
Optimizations compose left-to-right. \op{TC} must be first. The search space is explored by BEAM search or hand-coded heuristics.
%% ============================================================
\subsection*{Common Ops as Compositions}
All high-level tensor operations decompose into the primitives above.
\begin{lstlisting}
# gemm: C[M,N] = A[M,K] @ B[K,N]
def gemm(A, B):
    M,K = A.shape; _,N = B.shape
    return (A.reshape(M,K,1) * B.reshape(1,K,N)).sum(1)
# prefix_sum: cumulative sum via repeat+reshape sliding window trick
def prefix_sum(T):
    n = T.shape[0]
    x = T.pad((n-1, 0)) # (2n-1,)
    x = x.reshape(1,2*n-1).expand(n+1,2*n-1) # tile
    x = x.reshape((n+1)*(2*n-1)).shrink_to(2*n*n) # trim
    x = x.reshape(n,2*n).shrink_to(n,n) # windows
    return x.sum(-1) # reduce
# arange: prefix_sum of all 1s gives [1,2,...,n], subtract 1 for [0,1,...,n-1]
def arange(n):
    return prefix_sum(Tensor(1).reshape(1).expand(n)) - 1
# gather: out[i] = T[idx[i]]. one-hot mask along gather axis, then reduce
def gather(T, idx):
    K = T.shape[0]
    pos = arange(K).reshape(K, 1) # (K, 1)
    mask = (pos == idx.reshape(1, -1)).cast(T.dtype) # (K, D)
    return (T.reshape(K, 1) * mask).sum(0) # (D,)
# scatter_add: T[idx[i]] += val[i]
def scatter_add(T, idx, val):
    K, D = T.shape[0], idx.shape[0]
    pos = arange(K).reshape(K, 1) # (K, 1)
    mask = (pos == idx.reshape(1, D)).cast(T.dtype) # (K, D)
    return T + (mask * val.reshape(1, D)).sum(1) # (K,)
\end{lstlisting}
%% ============================================================
\subsection*{{\color{multipurple}Multi-Device Collectives} \normalfont\small--- derived from primitives}
Let $D = (d_0, \ldots, d_{n-1})$ be an $n$-tuple device.
\op{Copy} to an $n$-tuple device reshards with axis $= 0$. \op{Copy} never changes shape.
\begin{lstlisting}
# T has shape (s,) on a single device.
# broadcast: replicate T to all n devices
def broadcast(T):
    return T.reshape(1, s).expand(n, s).copy(D).replicated(0) # (s,) on D, axis=null
# scatter: split T into n chunks, one per device
def scatter(T):
    return T.copy(D) # (s,) on D, axis=0
# T has shape (n*s,) on D with axis=0, so each device holds (s,) elements.
# gather: collect all shards onto one device
def gather(T):
    return T.copy(D[0]) # (n*s,) on D[0], axis=null
# reduce: gather + sum
def reduce(T):
    return gather(T).reshape(n, s).sum(0) # (s,) on D[0], axis=null
# allgather: collect all shards, replicate to all devices
def allgather(T):
    m = T.shape[0] # full length: n*s here, s when called from allreduce
    return T.reshape(1, m).expand(n, m).copy(D).replicated(0) # (m,) on D, axis=null
# reduce_scatter: reduce across devices, scatter result
def reduce_scatter(T):
    return T.reshape(n, n, s//n).permute(1, 0, 2).copy(D).sum(1).reshape(s) # (s,) on D, axis=0
# allreduce: reduce_scatter + allgather
def allreduce(T):
    return allgather(reduce_scatter(T)) # (s,) on D, axis=null
\end{lstlisting}
%% ============================================================
\subsection*{{\color{callblue}The \texttt{@function} Decorator} \normalfont\small--- graph capture via tracing}
The \texttt{@function} decorator transforms a Python function on Tensors into a single \op{Function} node.
\begin{lstlisting}
@function
def f(a: Tensor, b: Tensor) -> Tensor:
    return a + b
\end{lstlisting}
When \texttt{f(x, y)} is called, the decorator:
\begin{enumerate}[leftmargin=1.5em, itemsep=2pt]
\item \textbf{Extracts inputs}: walks all arguments to find every Tensor, deduplicates by identity.
\item \textbf{Runs the function} lazily (no device execution), building a UOp graph from the result.
\item \textbf{Parameterizes}: replaces each input UOp with a \op{Param}$(k)$ placeholder.
\item \textbf{Wraps the body} in a \op{Tuple} (even for single returns) and creates
\op{Function}(\op{Tuple}(body), $x$, $y$).
\item \textbf{Returns} the result via \op{GetTuple}$(0)$, or one \op{GetTuple} per element for tuple returns.
\end{enumerate}
The result is a reusable graph fragment: the body contains only \op{Param} references, not concrete buffers. At schedule time, the \op{Function} is resolved by substituting each \op{Param}$(k)$ back with its corresponding argument $a_k$, or lowered into an opaque \op{Call} if it is to be compiled as a reusable kernel.
%% ============================================================
\subsection*{Lowering Pipeline \normalfont\small--- from Tensor graph to machine code}
\begin{tabular}{@{}l p{9.7cm}@{}}
\toprule
\textbf{Stage} & \textbf{Semantics} \\
\midrule
\textbf{Callify} & Transform the Tensor graph into a single stateless function. \\
\textbf{Rangeify} & Determine the kernel split of the function. Break everything down to shape $()$. \\
\textbf{Optimize} & Insert local buffers. Swap and split ranges, and determine which axes are parallel and which are serial. \\
\textbf{Expand} & Expand the parallel ranges into shape. \\
\textbf{Instruction Selection} & Select target instructions, including WMMA and devectorization. \\
\textbf{Linearize} & Topologically sort the graph and determine execution order. \\
\textbf{Register/Memory Plan} & Allocate and reuse \texttt{GLOBAL}, \texttt{LOCAL}, and \texttt{REG} storage for values with non-overlapping lifetimes. \\
\textbf{Render} & Output the machine code. \\
\bottomrule
\end{tabular}
\end{document}