diff --git a/spec/README.md b/spec/README.md new file mode 100644 index 0000000000..f316625abb --- /dev/null +++ b/spec/README.md @@ -0,0 +1 @@ +Run `./render.sh` whenever you update tinyspec.tex to regenerate tinyspec.pdf. diff --git a/spec/render.sh b/spec/render.sh new file mode 100755 index 0000000000..23a9a3b1f3 --- /dev/null +++ b/spec/render.sh @@ -0,0 +1,10 @@ +#!/bin/bash +set -e + +if ! command -v tectonic &>/dev/null; then + echo "tectonic not found, installing..." + sudo pacman -S --noconfirm tectonic +fi + +tectonic tinyspec.tex +echo "done: tinyspec.pdf" diff --git a/spec/tinyspec.pdf b/spec/tinyspec.pdf new file mode 100644 index 0000000000..249befa073 Binary files /dev/null and b/spec/tinyspec.pdf differ diff --git a/spec/tinyspec.tex b/spec/tinyspec.tex new file mode 100644 index 0000000000..28dc5a9e86 --- /dev/null +++ b/spec/tinyspec.tex @@ -0,0 +1,450 @@ +\documentclass[10pt,letterpaper]{article} + +\usepackage[margin=0.75in]{geometry} +\usepackage{amsmath,amssymb} +\usepackage{booktabs} +\usepackage{array} +\usepackage[dvipsnames]{xcolor} +\usepackage{enumitem} +\usepackage{listings} +\lstset{language=Python, basicstyle=\ttfamily\small, columns=fullflexible, keepspaces=true} + +\newcommand{\op}[1]{\textsc{#1}} + +\definecolor{movgreen}{HTML}{2E7D32} +\definecolor{reducered}{HTML}{C62828} +\definecolor{elwyellow}{HTML}{F9A825} +\definecolor{callblue}{HTML}{1565C0} +\definecolor{assignbrown}{HTML}{795548} +\definecolor{multipurple}{HTML}{7B1FA2} +\definecolor{markerorange}{HTML}{E65100} +% AxisType colors (from tinygrad) +\definecolor{axblue}{HTML}{1565C0} % GLOBAL +\definecolor{axcyan}{HTML}{00838F} % LOCAL +\definecolor{axbrcyan}{HTML}{00ACC1} % WARP +\definecolor{axbrblue}{HTML}{42A5F5} % THREAD +\definecolor{axwhite}{HTML}{616161} % LOOP (gray on white paper) +\definecolor{axred}{HTML}{C62828} % REDUCE +\definecolor{axbrred}{HTML}{E53935} % GROUP_REDUCE +\definecolor{axyellow}{HTML}{F9A825} % UPCAST +\definecolor{axmagenta}{HTML}{7B1FA2} % 
UNROLL + +\title{tinygrad: a single dialect from Tensor programs to Command Buffers} +\author{tinygrad, Corp. \\ \texttt{research@tinygrad.org}} +\date{} + +\begin{document} +\maketitle +\thispagestyle{empty} + +\section*{UOps} + +All nodes in the tinygrad graph are \textbf{UOps}. A UOp is a tuple $(\mathrm{op},\;\mathrm{src},\;\mathrm{arg},\;\mathrm{tag})$ where $\mathrm{op}$ is from the set below, $\mathrm{src}$ is a tuple of input UOps, $\mathrm{arg}$ is op-dependent, and $\mathrm{tag}$ is for temporary processing. The full program is a DAG of UOps. Each UOp has five derived properties --- \textbf{dtype}, \textbf{shape}, \textbf{device}, \textbf{min\_max}, and \textbf{axis} --- determined by the rules at the end of this document. + +%% ============================================================ +\subsection*{Source Ops \normalfont\small--- leaf nodes} + +\begin{tabular}{@{}l p{3.2cm} p{3.0cm} p{6.2cm}@{}} +\toprule +\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\ +\midrule +\op{Buffer} & () & size, dtype, device, addrspace & + Shape $(n \cdot \textit{size},)$ if device is $n$-tuple, else $(\textit{size},)$. \\ +\op{BufferView} & (buf,) & size, dtype, offset & + Typed access into a buffer. Zero-copy $(\textit{size},)$ slice at offset; inherits addrspace. \\ +\op{Param} & $(\mathbf{s})$ or $(\mathbf{s}, \text{min}, \text{max})$ & slot, dtype, device? & + Placeholder with shape $\mathbf{s}$. Substituted in \op{Function}. \\[4pt] +\op{Const} & () & value, dtype & + A scalar constant with shape $(\ )$. \\ +\op{Vconst} & () & values, dtype & + A vector constant with shape $(n,)$. \\ +\bottomrule +\end{tabular} + +\smallskip +A \op{Buffer}'s \textbf{addrspace} is \texttt{GLOBAL}, \texttt{LOCAL}, or \texttt{REG}. 
+ +%% ============================================================ +\subsection*{{\color{movgreen}Movement Ops} \normalfont\small--- no arithmetic, shapes are $(k,)$-shaped UOps with dtype \texttt{index} in src} + +\begin{tabular}{@{}l l l l@{}} +\toprule +\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\ +\midrule +\op{Permute} & $(T,)$ & axis order $\pi$ & Reorder axes. $\pi = (1,0)$ is transpose. \\ +\op{Flip} & $(T,)$ & bools $\mathbf{f}$ & Reverse along flagged axes. \\ +\op{Reshape} & $(T, \mathbf{s'})$ & --- & Reinterpret in row-major order. $\prod s_k = \prod s'_k$. \\ +\op{Expand} & $(T, \mathbf{s'})$ & --- & Broadcast size-1 axes. $s_k \in \{1, s'_k\}$. \\ +\op{Pad} & $(T, \mathbf{b}, \mathbf{e})$ & --- & Pad with $0$s: $b_k$ before, $e_k$ after each axis. \\ +\op{Shrink} & $(T, \mathbf{b}, \mathbf{e})$ & --- & Keep $[b_k, e_k)$ per axis. Inverse of \op{Pad}. \\ +\op{Index} & $(T, i_0, i_1, \ldots)$ & --- & Index from left. $()$-shaped $i$ removes dim; $(k,)$-shaped makes it $k$. \\ +\op{Stack} & $(T_0, T_1, \ldots)$ & --- & Join along a newly created leading axis. All shapes must match. \\ +\op{Replicated} & $(T,)$ & axes & Mark $T$ as replicated along axes. Collapse axes to $1$. \\ +\bottomrule +\end{tabular} + +%% ============================================================ +\subsection*{{\color{reducered}Reduce Ops} \normalfont\small--- collapse axes to size $1$} + +\begin{tabular}{@{}l l l l@{}} +\toprule +\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\ +\midrule +\op{Reduce} & $(T,)$ & op, axes & Reduce $T$ along axes. Op is \op{Add}, \op{Max}, or \op{Mul}. 
\\ +\bottomrule +\end{tabular} + +%% ============================================================ +\subsection*{{\color{callblue}Call Ops} \normalfont\small--- function abstraction, like the lambda calculus} + +\begin{tabular}{@{}l l l l@{}} +\toprule +\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\ +\midrule +\op{Function} & (body, $a_0$, $a_1$, \ldots) & --- & Substitute each \op{Param} $k$ in \op{Tuple} body with $a_k$. Gradient-able. \\ +\op{Call} & (body, $a_0$, $a_1$, \ldots) & --- & Opaque invocation of a compiled kernel or custom function. \\ +\op{Tuple} & $(v_0, v_1, \ldots)$ & --- & Pack values; required as \op{Function} body to return a value. \\ +\op{GetTuple} & $(T,)$ & idx & Extract element at idx from a \op{Tuple}. \\ +\bottomrule +\end{tabular} + +%% ============================================================ +\subsection*{{\color{multipurple}Store Ops} \normalfont\small--- side effects} + +\begin{tabular}{@{}l l l l@{}} +\toprule +\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\ +\midrule +\op{Store} & (buf, val, gate?) & --- & Write val into buf. buf.shape $=$ val.shape. \\ + & & & If gate is present, write only when gate is true. Output is void. \\ +\bottomrule +\end{tabular} + +%% ============================================================ +\subsection*{{\color{assignbrown}Ordering Ops} \normalfont\small--- execution order} + +\begin{tabular}{@{}l l l p{6.0cm}@{}} +\toprule +\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\ +\midrule +\op{Range} & $(\text{bound},)$ & type & Iterator over $0, 1, \ldots, \text{bound}-1$; bound is exclusive. \\ +\op{End} & (body, range) & --- & Close a \op{Range} loop. \\ +\op{After} & (buf, deps\ldots) & --- & Passthrough of buf; guarantees deps execute first. \\ +\op{Group} & $(u_0, u_1, \ldots)$ & --- & Void no-op that merges multiple \op{Store}s into one node, unordered. \\ +\op{Sink} & $(s_0, s_1, \ldots)$ & --- & Collect side effects into a single root node. 
\\ +\op{Linear} & (uops\ldots) & --- & Linearized (toposorted) instruction sequence. \\ +\bottomrule +\end{tabular} + +\smallskip +Assign is \op{Store} followed by \op{After}: write the value, then return the buffer with an ordering dependency. + +%% ============================================================ +\subsection*{{\color{elwyellow}Elementwise Ops} \normalfont\small--- all inputs same shape, output same shape, applied per-element} + +\begin{tabular}{@{}l l l l@{}} +\toprule +\textbf{Arity} & \textbf{src} & \textbf{Op} & \textbf{Semantics} \\ +\midrule +Unary & $(T,)$ + & \op{Recip} + & $1/x$ \\ +& & \op{Trunc} + & $\mathrm{trunc}(x)$: round toward zero. \\ +& & \op{Cast} + & Convert to target dtype (specified in arg). \\ +& & \op{Bitcast} + & Reinterpret bits as target dtype. Must be same size. \\[4pt] +Binary & $(A, B)$ + & \op{Add}, \op{Mul}, \op{Max}, \op{Mod}, \op{Idiv} + & $a+b$, $a \cdot b$, $\max(a,b)$, $a \bmod b$, $\lfloor a/b \rfloor$ \\ +& & \op{CmpLt}, \op{CmpNe} + & $[a < b]$, $[a \ne b]$ \\ +& & \op{Xor}, \op{Or}, \op{And}, \op{Shr}, \op{Shl} + & $a \oplus b$, $a \mid b$, $a \mathbin{\&} b$, $a \gg b$, $a \ll b$ \\[4pt] +Ternary & $(P, A, B)$ + & \op{Where} + & $A[\mathbf{i}]$ if $P[\mathbf{i}] \ne 0$, else $B[\mathbf{i}]$ \\ + +\bottomrule +\end{tabular} + +\medskip +\textbf{Decomposed elementwise ops} --- defined in terms of the primitives above. 
+ +\smallskip +\begin{tabular}{@{}l l l@{}} +\toprule +\textbf{Op} & \textbf{Decomposition} & \textbf{Semantics} \\ +\midrule +\op{Neg} & \op{Mul}($A$, $-1$) & $-x$ \\ +\op{Sub} & \op{Add}($A$, \op{Neg}($B$)) & $a - b$ \\ +\op{Div} & \op{Mul}($A$, \op{Recip}($B$)) & $a / b$ \\ +\op{CmpGt} & \op{CmpLt}($B$, $A$) & $[a > b]$ \\ +\op{CmpGe} & \op{CmpNe}(\op{CmpLt}($A$, $B$),\, $1$) & $[a \ge b]$ \\ +\op{CmpLe} & \op{CmpNe}(\op{CmpLt}($B$, $A$),\, $1$) & $[a \le b]$ \\ +\op{CmpEq} & \op{CmpNe}(\op{CmpNe}($A$, $B$),\, $1$) & $[a = b]$ \\ +\op{Not} & \op{CmpNe}($A$, $1$) & $\lnot a$ \\[4pt] +\op{Exp2} & polynomial approx + \op{Mul}, \op{Add} & $2^x$ \\ +\op{Log2} & exponent extract + polynomial approx & $\log_2 x$ \\ +\op{Sin} & argument reduction + polynomial approx & $\sin x$ \\ +\op{Sqrt} & \op{Exp2}($0.5 \cdot$ \op{Log2}($A$)) & $\sqrt{x}$ \\ +\op{Pow} & \op{Exp2}(\op{Log2}($A$) $\cdot\, B$) & $a^b$ \\ +\op{Mulacc} & \op{Add}(\op{Mul}($A$, $B$),\, $C$) & $a \cdot b + c$ \\ +\op{Threefry} & 5 rounds of add-rotate-xor (ARX) & Threefry 2x32 PRNG \\ +\bottomrule +\end{tabular} + +%% ============================================================ +\subsection*{{\color{markerorange}Marker Ops} \normalfont\small--- identity on data} + +\begin{tabular}{@{}l l l l@{}} +\toprule +\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\ +\midrule +\op{Contiguous} & $(T,)$ & --- & Force contiguous memory layout. \\ +\op{ContiguousBackward} & $(T,)$ & --- & Force contiguous in backward pass. \\ +\op{Detach} & $(T,)$ & --- & Stops gradient propagation. \\ +\op{Copy} & $(T,)$ & device & Copy to target device. \\ +\bottomrule +\end{tabular} + +%% ============================================================ +\subsection*{Codegen Ops \normalfont\small--- generated code primitives, these do not appear in the main graph} + +\begin{tabular}{@{}l l l l@{}} +\toprule +\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\ +\midrule +\op{Load} & (idx,alt?,gate?) 
& --- & Dereference: read element at index from buffer. \\ + & & & All loads will be replaced by \op{Store}. \\ +\op{Barrier} & (deps\ldots) & --- & Synchronize threads within a workgroup. \\ +\op{Ins} & \ldots & \ldots & A single machine instruction (e.g.\ AMD ISA). \\ +\op{Special} & (bound,) & name & GPU thread/workgroup index (e.g.\ \texttt{gidx0}, \texttt{lidx1}). \\ +\op{If} & (gate,) & --- & Begin conditional execution block. \\ +\op{Endif} & (if,) & --- & End conditional execution block. \\ +\op{Wmma} & (A, B, acc) & config & Warp matrix multiply-accumulate (tensor cores). \\ +\op{Custom} & (args\ldots) & fmt & Inject custom code string into generated source. \\ +\op{AtomicAdd} & (idx, val) & --- & Atomic read-modify-write: \texttt{buf[idx] += val}. \\[4pt] +\op{CustomFunction} & (meta\ldots) & name & Opaque device function (e.g.\ HW decode). Via \op{Call}. \\ +\op{Program} & (linear, source, binary) & --- & Compiled kernel: instructions, source, and machine code. \\ +\op{Source} & () & str & Human-readable rendered source code. \\ +\op{Binary} & () & bytes & Compiled machine code. \\ +\bottomrule +\end{tabular} + +\smallskip +These ops are not part of the core specification and are subject to change. 
+ +%% ============================================================ +\subsection*{Derived Properties} + +Every UOp has a \textbf{dtype}, \textbf{shape}, \textbf{device}, \textbf{min\_max}, and \textbf{axis}, derived from its op, src, and arg: + +\medskip +\begin{tabular}{@{}l l l l l@{}} +\toprule +\textbf{Op} & \textbf{dtype} & \textbf{shape} & \textbf{device} & \textbf{min\_max} \\ +\midrule +\op{Buffer} & from arg & $(\text{size},)$ from arg & from arg & dtype range \\ +\op{Const} & from arg & $()$ & \textsc{null} & $[v, v]$ \\ +\op{Param} & from arg & from $\mathrm{src}[0]$ & from arg & from src or dtype range \\[3pt] +Movement ops & $\mathrm{src}[0].\mathrm{dtype}$ & (see op) & $\mathrm{src}[0].\mathrm{device}$ & $\mathrm{src}[0]$ \\ +\op{Reduce} & $\mathrm{src}[0].\mathrm{dtype}$ & collapse axes to $1$ & $\mathrm{src}[0].\mathrm{device}$ & dtype range \\[3pt] +\op{Cast} & from arg & $\mathrm{src}[0].\mathrm{shape}$ & $\mathrm{src}[0].\mathrm{device}$ & clamped to dtype \\ +\op{Bitcast} & from arg & $\mathrm{src}[0].\mathrm{shape}$ & $\mathrm{src}[0].\mathrm{device}$ & dtype range \\ +\op{Copy} & $\mathrm{src}[0].\mathrm{dtype}$ & $\mathrm{src}[0].\mathrm{shape}$ & from arg & $\mathrm{src}[0]$ \\ +ALU unary & $\mathrm{src}[0].\mathrm{dtype}$ & $\mathrm{src}[0].\mathrm{shape}$ & $\mathrm{src}[0].\mathrm{device}$ & dtype range \\ +\op{Add} & $\mathrm{src}[0].\mathrm{dtype}$ & broadcast & $\mathrm{src}[0].\mathrm{device}$ & $[a+b,\, A+B]$ \\ +\op{Mul} & $\mathrm{src}[0].\mathrm{dtype}$ & broadcast & $\mathrm{src}[0].\mathrm{device}$ & $[\min,\max]$ of products \\ +\op{Max} & $\mathrm{src}[0].\mathrm{dtype}$ & broadcast & $\mathrm{src}[0].\mathrm{device}$ & $[\max(a,b),\, \max(A,B)]$ \\ +Other binary & $\mathrm{src}[0].\mathrm{dtype}$ & broadcast & $\mathrm{src}[0].\mathrm{device}$ & dtype range \\ +\op{CmpLt}, \op{CmpNe} & bool & broadcast & $\mathrm{src}[0].\mathrm{device}$ & from intervals \\ +\op{Where} & $\mathrm{src}[1].\mathrm{dtype}$ & broadcast & 
$\mathrm{src}[0].\mathrm{device}$ & $[\min(b,c),\, \max(B,C)]$ \\[3pt] +\op{Function}, \op{Call} & $\mathrm{src}[0].\mathrm{dtype}$ & substitute \op{Param} shapes & $\mathrm{src}[1].\mathrm{device}$ & dtype range \\ +\op{Range} & index & $()$ & \textsc{null} & $[0,\, \text{bound}{-}1]$ \\ +\op{Index} & $\mathrm{src}[0].\mathrm{dtype}$ & remaining dims & $\mathrm{src}[0].\mathrm{device}$ & $\mathrm{src}[0]$ \\ +\op{Store} & void & $()$ & $\mathrm{src}[0].\mathrm{device}$ & --- \\ +\op{After} & $\mathrm{src}[0].\mathrm{dtype}$ & $\mathrm{src}[0].\mathrm{shape}$ & $\mathrm{src}[0].\mathrm{device}$ & $\mathrm{src}[0]$ \\ +\bottomrule +\end{tabular} + +\smallskip +$\mathrm{broadcast}$: right-align shapes, element-wise max; each axis must be equal or $1$. +$[a,A]$, $[b,B]$, $[c,C]$ denote min\_max of $\mathrm{src}[0]$, $\mathrm{src}[1]$, $\mathrm{src}[2]$. +Default \emph{dtype range}: $[\mathrm{dtype\_min},\, \mathrm{dtype\_max}]$. + +\medskip +\textbf{axis} tracks the multi-device sharding dimension. \op{Buffer} with $n$-tuple device: axis $= 0$ (device dim). +\op{Reshape} remaps axis to preserve the shard boundary. \op{Permute} follows the permutation. +\op{Reduce} on the shard axis $\to$ \textsc{null}. \op{Replicated} on the shard axis $\to$ \textsc{null}. \op{Copy} to a single device $\to$ \textsc{null}; \op{Copy} to an $n$-tuple device $\to 0$. ALU ops inherit from sources. Default: \textsc{null}. + +%% ============================================================ +\subsection*{Kernel Optimizations (OptOps) \normalfont\small--- schedule-level transforms on kernel ranges} + +Each kernel's iteration space is a set of \op{Range} axes. Every range has an \textbf{AxisType}: + +\medskip +\begin{tabular}{@{}l l l l l@{}} +\toprule +\textbf{AxisType} & \textbf{Letter} & \textbf{Split from} & \textbf{Direction} & \textbf{Semantics} \\ +\midrule +{\color{axblue}\texttt{GLOBAL}} & \texttt{g} & --- & --- & GPU global workgroup dimension. \\ +{\color{axcyan}\texttt{LOCAL}} & \texttt{l} & g, L & inner & Workgroup local dimension (shared memory). 
\\ +{\color{axbrcyan}\texttt{WARP}} & \texttt{w} & \multicolumn{2}{l}{(created by \op{TC})} & Warp-level lanes for tensor cores. \\ +{\color{axbrblue}\texttt{THREAD}} & \texttt{t} & g & outer & CPU thread parallelism. \\ +{\color{axwhite}\texttt{LOOP}} & \texttt{L} & --- & --- & Generic sequential loop (initial state). \\ +{\color{axred}\texttt{REDUCE}} & \texttt{R} & --- & --- & Reduction axis. \\ +{\color{axbrred}\texttt{GROUP\_REDUCE}} & \texttt{G} & R & inner/outer & Shared-memory group reduction. \\ +{\color{axyellow}\texttt{UPCAST}} & \texttt{u} & g, l, L & inner & Register-level vectorization. \\ +{\color{axmagenta}\texttt{UNROLL}} & \texttt{r} & R, G & inner & Fully unrolled loop. \\ +\bottomrule +\end{tabular} + +\medskip +An optimization is a triple $(\mathrm{op},\;\mathrm{axis},\;\mathrm{arg})$: + +\smallskip +\begin{tabular}{@{}l l l p{6.5cm}@{}} +\toprule +\textbf{OptOp} & \textbf{axis} & \textbf{arg} & \textbf{Semantics} \\ +\midrule +\op{Split} & any & (factor $k$, target, top?) & + Split axis $n$ by $k$ into $(n/k, k)$ or $(k, n/k)$ if top. New sub-axis gets target AxisType (see table above). \\ +\op{Padto} & any & multiple $m$ & + Pad axis to next multiple of $m$ with validity masks. \\[4pt] +\op{Swap} & axis$_i$ & axis$_j$ & + Swap two axes $i \leftrightarrow j$. \\ +\op{Nolocals} & --- & --- & + Disable local memory; no workgroup dims emitted. \\ +\op{TC} & reduce idx & (tc, opt, mode) & + Apply tensor core \op{Wmma}: split reduce/output axes into \texttt{WARP}, \texttt{UPCAST}, and \texttt{UNROLL} dims. \\ +\bottomrule +\end{tabular} + +\smallskip +Optimizations compose left-to-right. \op{TC} must be first. The search space is explored by BEAM search or hand-coded heuristics. + +%% ============================================================ +\subsection*{Common Ops as Compositions} + +All high-level tensor operations decompose into the primitives above. 
+ +\begin{lstlisting} +# gemm: C[M,N] = A[M,K] @ B[K,N] +def gemm(A, B): + M,K = A.shape; _,N = B.shape + return (A.reshape(M,K,1) * B.reshape(1,K,N)).sum(1) + +# prefix_sum: cumulative sum via repeat+reshape sliding window trick +def prefix_sum(T): + n = T.shape[0] + x = T.pad((n-1, 0)) # (2n-1,) + x = x.reshape(1,2*n-1).expand(n+1,2*n-1) # tile + x = x.reshape((n+1)*(2*n-1)).shrink_to(2*n*n) # trim + x = x.reshape(n,2*n).shrink_to(n,n) # windows + return x.sum(-1) # reduce + +# arange: prefix_sum of all 1s gives [1,2,...,n], subtract 1 for [0,1,...,n-1] +def arange(n): + return prefix_sum(Tensor(1).reshape(1).expand(n)) - 1 + +# gather: out[i] = T[idx[i]]. one-hot mask along gather axis, then reduce +def gather(T, idx): + K = T.shape[0] + pos = arange(K).reshape(K, 1) # (K, 1) + mask = (pos == idx.reshape(1, -1)).cast(T.dtype) # (K, D) + return (T.reshape(K, 1) * mask).sum(0) # (D,) + +# scatter_add: T[idx[i]] += val[i] +def scatter_add(T, idx, val): + K, D = T.shape[0], idx.shape[0] + pos = arange(K).reshape(K, 1) # (K, 1) + mask = (pos == idx.reshape(1, D)).cast(T.dtype) # (K, D) + return T + (mask * val.reshape(1, D)).sum(1) # (K,) +\end{lstlisting} + +%% ============================================================ +\subsection*{{\color{multipurple}Multi-Device Collectives} \normalfont\small--- derived from primitives} + +Let $D = (d_0, \ldots, d_{n-1})$ be an $n$-tuple device. +\op{Copy} to an $n$-tuple device reshards with axis $= 0$. \op{Copy} never changes shape. + +\begin{lstlisting} +# T has shape (s,) on a single device. + +# broadcast: replicate T to all n devices +def broadcast(T): + return T.reshape(1, s).expand(n, s).copy(D).replicated(0) # (s,) on D, axis=null + +# scatter: split T into n chunks, one per device +def scatter(T): + return T.copy(D) # (s,) on D, axis=0 + +# T has shape (n*s,) on D with axis=0, so each device holds (s,) elements. 
+ +# gather: collect all shards onto one device +def gather(T): + return T.copy(D[0]) # (n*s,) on D[0], axis=null + +# reduce: gather + sum +def reduce(T): + return gather(T).reshape(n, s).sum(0) # (s,) on D[0], axis=null + +# allgather: collect all shards, replicate to all devices +def allgather(T): + return T.reshape(1, n*s).expand(n, n*s).copy(D).replicated(0) # (n*s,) on D, axis=null + +# reduce_scatter: reduce across devices, scatter result +def reduce_scatter(T): + return T.reshape(n, n, s//n).permute(1, 0, 2).copy(D).sum(1).reshape(s) # (s,) on D, axis=0 + +# allreduce: reduce_scatter + allgather +def allreduce(T): + return allgather(reduce_scatter(T)) # (s,) on D, axis=null +\end{lstlisting} + +%% ============================================================ +\subsection*{{\color{callblue}The \texttt{@function} Decorator} \normalfont\small--- graph capture via tracing} + +The \texttt{@function} decorator transforms a Python function on Tensors into a single \op{Function} node. + +\begin{lstlisting} +@function +def f(a: Tensor, b: Tensor) -> Tensor: + return a + b +\end{lstlisting} + +When \texttt{f(x, y)} is called, the decorator: + +\begin{enumerate}[leftmargin=1.5em, itemsep=2pt] + \item \textbf{Extracts inputs}: walks all arguments to find every Tensor, deduplicates by identity. + \item \textbf{Runs the function} lazily (no device execution), building a UOp graph from the result. + \item \textbf{Parameterizes}: replaces each input UOp with a \op{Param}$(k)$ placeholder. + \item \textbf{Wraps the body} in a \op{Tuple} (even for single returns) and creates\\ + \op{Function}(\op{Tuple}(body), $x$, $y$). + \item \textbf{Returns} the result via \op{GetTuple}$(0)$, or one \op{GetTuple} per element for tuple returns. +\end{enumerate} + +The result is a reusable graph fragment: the body contains only \op{Param} references, not concrete buffers. 
At schedule time, the \op{Function} is resolved by substituting each \op{Param}$(k)$ back with its corresponding argument $a_k$, or lowered into an opaque \op{Call} if it is to be compiled as a reusable kernel. + +%% ============================================================ +\subsection*{Lowering Pipeline \normalfont\small--- from Tensor graph to machine code} + +\begin{tabular}{@{}l p{9.7cm}@{}} +\toprule +\textbf{Stage} & \textbf{Semantics} \\ +\midrule +\textbf{Callify} & Transform the Tensor graph into a single stateless function. \\ +\textbf{Rangeify} & Determine the kernel split of the function. Break everything down to shape $()$. \\ +\textbf{Optimize} & Insert local buffers. Swap and split ranges, and determine which axes are parallel and which are serial. \\ +\textbf{Expand} & Expand the parallel ranges into shape. \\ +\textbf{Instruction Selection} & Select target instructions, including WMMA and devectorization. \\ +\textbf{Linearize} & Topologically sort the graph and determine execution order. \\ +\textbf{Register/Memory Plan} & Allocate and reuse \texttt{GLOBAL}, \texttt{LOCAL}, and \texttt{REG} storage for values with non-overlapping lifetimes. \\ +\textbf{Render} & Output the machine code. \\ +\bottomrule +\end{tabular} + + +\end{document}