\documentclass[10pt,letterpaper]{article}
\usepackage[margin=0.75in]{geometry}
\usepackage{amsmath,amssymb}
\usepackage{booktabs}
\usepackage{array}
\usepackage[dvipsnames]{xcolor}
\usepackage{enumitem}
\usepackage{listings}
\lstset{language=Python, basicstyle=\ttfamily\small, columns=fullflexible, keepspaces=true}

\newcommand{\op}[1]{\textsc{#1}}
\definecolor{movgreen}{HTML}{2E7D32}
\definecolor{reducered}{HTML}{C62828}
\definecolor{elwyellow}{HTML}{F9A825}
\definecolor{callblue}{HTML}{1565C0}
\definecolor{assignbrown}{HTML}{795548}
\definecolor{loadred}{HTML}{c08080}
\definecolor{multipurple}{HTML}{7B1FA2}
\definecolor{markerorange}{HTML}{E65100}
% AxisType colors (from tinygrad)
\definecolor{axblue}{HTML}{1565C0} % GLOBAL
\definecolor{axcyan}{HTML}{00838F} % LOCAL
\definecolor{axbrcyan}{HTML}{00ACC1} % WARP
\definecolor{axbrblue}{HTML}{42A5F5} % THREAD
\definecolor{axwhite}{HTML}{616161} % LOOP (gray on white paper)
\definecolor{axred}{HTML}{C62828} % REDUCE
\definecolor{axbrred}{HTML}{E53935} % GROUP_REDUCE
\definecolor{axyellow}{HTML}{F9A825} % UPCAST
\definecolor{axmagenta}{HTML}{7B1FA2} % UNROLL

\title{tinygrad: a single dialect from Tensor programs to Command Buffers}
\author{tinygrad, Corp. \\ \texttt{research@tinygrad.org}}
\date{}

\begin{document}
\maketitle
\thispagestyle{empty}

\section*{UOps}

All nodes in the tinygrad graph are \textbf{UOps}. A UOp is a tuple $(\mathrm{op},\;\mathrm{src},\;\mathrm{arg},\;\mathrm{tag})$ where $\mathrm{op}$ is from the set below, $\mathrm{src}$ is a tuple of input UOps, $\mathrm{arg}$ is op-dependent, and $\mathrm{tag}$ is for temporary processing. The full program is a DAG of UOps. Each UOp has five derived properties --- \textbf{dtype}, \textbf{shape}, \textbf{device}, \textbf{min\_max}, and \textbf{axis} --- determined by the rules at the end of this document.
%% ============================================================
\subsection*{Source Ops \normalfont\small--- leaf nodes}

\begin{tabular}{@{}l p{3.2cm} p{3.0cm} p{6.2cm}@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Param} & $(\mathbf{s})$ & slot, dtype, device?, addrspace? & Placeholder with shape $\mathbf{s}$. Substituted in \op{Function}. \\[4pt]
\op{Buffer} & () & size, dtype, device, addrspace & Shape $(n \cdot \textit{size},)$ if device is $n$-tuple, else $(\textit{size},)$. \\
\op{Const} & () & value, dtype & A scalar constant with shape $(\ )$. \\
& & & Form vector consts with \op{Stack}. \\
\op{Binary} & () & data & Raw binary data; has dtype uint8 and shape $(\mathrm{len}(\textit{data}),)$. \\
\bottomrule
\end{tabular}

\smallskip
A \op{Buffer}'s \textbf{addrspace} is \texttt{GLOBAL}, \texttt{LOCAL}, or \texttt{REG}.

%% ============================================================
\subsection*{{\color{movgreen}Movement Ops} \normalfont\small--- no arithmetic; view, indexing, and reinterpretation only}

\begin{tabular}{@{}l l l l@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Permute} & $(T,)$ & axis order $\pi$ & Reorder axes. $\pi = (1,0)$ is transpose. \\
\op{Flip} & $(T,)$ & bools $\mathbf{f}$ & Reverse along flagged axes. \\
\op{Reshape} & $(T, \mathbf{s'})$ & --- & Reinterpret in row-major order. $\prod s_k = \prod s'_k$. \\
\op{Expand} & $(T, \mathbf{s'})$ & --- & Broadcast size-1 axes. $s_k \in \{1, s'_k\}$. \\
\op{Pad} & $(T, \mathbf{o}, \mathbf{s'})$ & --- & Place $T$ at offset $o_k$ in a zero-filled output of shape $s'_k$. \\
\op{Shrink} & $(T, \mathbf{o}, \mathbf{s'})$ & --- & Keep $s'_k$ elements starting at offset $o_k$ per axis. Inverse of \op{Pad}. \\
\op{Index} & $(T, i_0, i_1, \ldots)$ & --- & Index from left. $()$-shaped $i$ removes dim; $(k,)$-shaped makes it $k$. \\
\op{Stack} & $(T_0, T_1, \ldots)$ & --- & Join along a newly created leading axis. All shapes must match.
\\
\op{Replicated} & $(T,)$ & axes & Mark $T$ as replicated along axes. Collapse axes to $1$. \\
\op{Slice} & $(T, \mathrm{offset})$ & size, dtype & Zero-copy \textit{size} elems of dtype; offset is elems of $T$ dtype. \\
\op{Bitcast} & $(T,)$ & dtype & Reinterpret storage as target dtype; preserve total bytes. \\
\bottomrule
\end{tabular}

%% ============================================================
\subsection*{{\color{reducered}Reduce Ops} \normalfont\small--- collapse axes to size $1$}

\begin{tabular}{@{}l l l l@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Reduce} & ($T$, $r_0$, $r_1$, \ldots) & op, axes & Reduce $T$ along axes or ranges. Op is \op{Add}, \op{Max}, or \op{Mul}. \\
\bottomrule
\end{tabular}

%% ============================================================
\subsection*{{\color{callblue}Call Ops} \normalfont\small--- function abstraction, like the lambda calculus}

\begin{tabular}{@{}l l l l@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Function} & (body, $a_0$, $a_1$, \ldots) & --- & Substitute each \op{Param} $k$ in \op{Tuple} body with $a_k$. Gradient-able. \\
\op{Call} & (body, $a_0$, $a_1$, \ldots) & --- & Opaque invocation of a compiled kernel or custom function. \\
\op{Tuple} & $(v_0, v_1, \ldots)$ & --- & Pack values; required as \op{Function} body to return a value. \\
\op{GetTuple} & $(T,)$ & idx & Extract element at idx from a \op{Tuple}. \\
\bottomrule
\end{tabular}

%% ============================================================
\subsection*{{\color{loadred}Load Ops} \normalfont\small--- can change device or addrspace}

\begin{tabular}{@{}l l l l@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Load} & (buf, alt?, gate?) & device, addrspace & Read (pull) from buffer into a new anonymous buffer. \\
& & & Note: this replaces \op{Copy} and \op{Contiguous}.
\\
\bottomrule
\end{tabular}

%% ============================================================
\subsection*{{\color{multipurple}Store Ops} \normalfont\small--- the only op with observable side effects}

\begin{tabular}{@{}l l l l@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Store} & (buf, val, gate?) & --- & Write (push) val into buf. buf.shape $=$ val.shape. \\
& & & If gate is present, write only when gate is true. Output is void. \\
\bottomrule
\end{tabular}

%% ============================================================
\subsection*{{\color{assignbrown}Ordering Ops} \normalfont\small--- execution order}

\begin{tabular}{@{}l l l p{6.0cm}@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Range} & $(\text{bound},)$ & type & Iterator over $0, \ldots, \text{bound}-1$. \\
\op{End} & (body, range) & --- & Close a \op{Range} loop. \\
\op{After} & (buf, deps\ldots) & --- & Passthrough of buf; guarantees deps execute first. \\
\op{Group} & $(u_0, u_1, \ldots)$ & --- & Void no-op that merges multiple \op{Store}s into one node, unordered. \\
\op{Sink} & $(s_0, s_1, \ldots)$ & --- & Collect side effects into a single root node. \\
\op{Linear} & (uops\ldots) & --- & Linearized (toposorted) instruction sequence. \\
\bottomrule
\end{tabular}

\smallskip
Assign is \op{Store} followed by \op{After}: write the value, then return the buffer with an ordering dependency.

%% ============================================================
\subsection*{{\color{elwyellow}Elementwise Ops} \normalfont\small--- all inputs same shape, output same shape, applied per-element}

\begin{tabular}{@{}l l l l@{}}
\toprule
\textbf{Arity} & \textbf{src} & \textbf{Op} & \textbf{Semantics} \\
\midrule
Unary & $(T,)$ & \op{Recip} & $1/x$ \\
& & \op{Trunc} & $\mathrm{trunc}(x)$: round toward zero. \\
& & \op{Cast} & Convert to target dtype (specified in arg).
\\[4pt]
Binary & $(A, B)$ & \op{Add}, \op{Mul}, \op{Max}, \op{Mod}, \op{Idiv} & $a+b$, $a \cdot b$, $\max(a,b)$, $a \bmod b$, $\lfloor a/b \rfloor$ \\
& & \op{CmpLt}, \op{CmpNe} & $[a < b]$, $[a \ne b]$ \\
& & \op{Xor}, \op{Or}, \op{And}, \op{Shr}, \op{Shl} & $a \oplus b$, $a \mid b$, $a \mathbin{\&} b$, $a \gg b$, $a \ll b$ \\[4pt]
Ternary & $(P, A, B)$ & \op{Where} & $A[\mathbf{i}]$ if $P[\mathbf{i}] \ne 0$, else $B[\mathbf{i}]$ \\
\bottomrule
\end{tabular}

\medskip
\textbf{Decomposed elementwise ops} --- defined in terms of the primitives above.

\smallskip
\begin{tabular}{@{}l l l@{}}
\toprule
\textbf{Op} & \textbf{Decomposition} & \textbf{Semantics} \\
\midrule
\op{Neg} & \op{Mul}($A$, $-1$) & $-x$ \\
\op{Sub} & \op{Add}($A$, \op{Neg}($B$)) & $a - b$ \\
\op{Div} & \op{Mul}($A$, \op{Recip}($B$)) & $a / b$ \\
\op{CmpGt} & \op{CmpLt}($B$, $A$) & $[a > b]$ \\
\op{CmpGe} & \op{CmpNe}(\op{CmpLt}($A$, $B$),\, $1$) & $[a \ge b]$ \\
\op{CmpLe} & \op{CmpNe}(\op{CmpLt}($B$, $A$),\, $1$) & $[a \le b]$ \\
\op{CmpEq} & \op{CmpNe}(\op{CmpNe}($A$, $B$),\, $1$) & $[a = b]$ \\
\op{Not} & \op{CmpNe}($A$, $1$) & $\lnot a$ \\[4pt]
\op{Exp2} & polynomial approx + \op{Mul}, \op{Add} & $2^x$ \\
\op{Log2} & exponent extract + polynomial approx & $\log_2 x$ \\
\op{Sin} & argument reduction + polynomial approx & $\sin x$ \\
\op{Sqrt} & \op{Exp2}($0.5 \cdot$ \op{Log2}($A$)) & $\sqrt{x}$ \\
\op{Pow} & \op{Exp2}(\op{Log2}($A$) $\cdot\, B$) & $a^b$ \\
\op{Mulacc} & \op{Add}(\op{Mul}($A$, $B$),\, $C$) & $a \cdot b + c$ \\
\op{Threefry} & 5 rounds of add-rotate-xor (ARX) & Threefry 2x32 PRNG \\
\bottomrule
\end{tabular}

%% ============================================================
\subsection*{{\color{markerorange}Marker Ops} \normalfont\small--- identity on data}

\begin{tabular}{@{}l l l l@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Contiguous} & $(T,)$ & --- & Force contiguous memory layout.
\\
\op{ContiguousBackward} & $(T,)$ & --- & Force contiguous in backward pass. \\
\op{Detach} & $(T,)$ & --- & Stops gradient propagation. \\
\bottomrule
\end{tabular}

%% ============================================================
\subsection*{Codegen Ops \normalfont\small--- generated code primitives, these do not appear in the main graph}

\begin{tabular}{@{}l l l l@{}}
\toprule
\textbf{Op} & \textbf{src} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Barrier} & (deps\ldots) & --- & Synchronize threads within a workgroup. \\
\op{Ins} & \ldots & \ldots & A single machine instruction (e.g.\ AMD ISA). \\
\op{Special} & (bound,) & name & GPU thread/workgroup index (e.g.\ \texttt{gidx0}, \texttt{lidx1}). \\
\op{If} & (gate,) & --- & Begin conditional execution block. \\
\op{Endif} & (if,) & --- & End conditional execution block. \\
\op{Wmma} & (A, B, acc) & config & Warp matrix multiply-accumulate (tensor cores). \\
\op{Custom} & (args\ldots) & fmt & Inject custom code string into generated source. \\
\op{AtomicAdd} & (idx, val) & --- & Atomic read-modify-write: \texttt{buf[idx] += val}. \\[4pt]
\op{CustomFunction} & (meta\ldots) & name & Opaque device function (e.g.\ HW decode). Via \op{Call}. \\
\op{Program} & (linear, source, binary) & --- & Compiled kernel: instructions, source, and machine code. \\
\op{Source} & () & str & Human-readable rendered source code. \\
\op{Binary} & () & bytes & Compiled machine code. \\
\bottomrule
\end{tabular}

\smallskip
These ops are not part of the core specification and are subject to change.
%% ============================================================
\subsection*{Derived Properties}

Every UOp has a \textbf{dtype}, \textbf{shape}, \textbf{device}, \textbf{min\_max}, and \textbf{axis}, derived from its op, src, and arg:

\medskip
\begin{tabular}{@{}l l l l l@{}}
\toprule
\textbf{Op} & \textbf{dtype} & \textbf{shape} & \textbf{device} & \textbf{min\_max} \\
\midrule
\op{Buffer} & from arg & $(\text{size},)$ from arg & from arg & dtype range \\
\op{Const} & from arg & $()$ & \textsc{null} & $[v, v]$ \\
\op{Param} & from arg & from $\mathrm{src}[0]$ & from arg & from src or dtype range \\[3pt]
Movement ops & $\mathrm{src}[0].\mathrm{dtype}$ & (see op) & $\mathrm{src}[0].\mathrm{device}$ & $\mathrm{src}[0]$ \\
\op{Reduce} & $\mathrm{src}[0].\mathrm{dtype}$ & collapse axes to $1$ & $\mathrm{src}[0].\mathrm{device}$ & dtype range \\[3pt]
\op{Cast} & from arg & $\mathrm{src}[0].\mathrm{shape}$ & $\mathrm{src}[0].\mathrm{device}$ & clamped to dtype \\
\op{Bitcast} & from arg & $\mathrm{src}[0].\mathrm{shape}$ & $\mathrm{src}[0].\mathrm{device}$ & dtype range \\
\op{Copy} & $\mathrm{src}[0].\mathrm{dtype}$ & $\mathrm{src}[0].\mathrm{shape}$ & from arg & $\mathrm{src}[0]$ \\
ALU unary & $\mathrm{src}[0].\mathrm{dtype}$ & $\mathrm{src}[0].\mathrm{shape}$ & $\mathrm{src}[0].\mathrm{device}$ & dtype range \\
\op{Add} & $\mathrm{src}[0].\mathrm{dtype}$ & broadcast & $\mathrm{src}[0].\mathrm{device}$ & $[a+b,\, A+B]$ \\
\op{Mul} & $\mathrm{src}[0].\mathrm{dtype}$ & broadcast & $\mathrm{src}[0].\mathrm{device}$ & $[\min,\max]$ of products \\
\op{Max} & $\mathrm{src}[0].\mathrm{dtype}$ & broadcast & $\mathrm{src}[0].\mathrm{device}$ & $[\max(a,b),\, \max(A,B)]$ \\
Other binary & $\mathrm{src}[0].\mathrm{dtype}$ & broadcast & $\mathrm{src}[0].\mathrm{device}$ & dtype range \\
\op{CmpLt}, \op{CmpNe} & bool & broadcast & $\mathrm{src}[0].\mathrm{device}$ & from intervals \\
\op{Where} & $\mathrm{src}[1].\mathrm{dtype}$ & broadcast & $\mathrm{src}[0].\mathrm{device}$ &
$[\min(b,c),\, \max(B,C)]$ \\[3pt]
\op{Function}, \op{Call} & $\mathrm{src}[0].\mathrm{dtype}$ & substitute \op{Param} shapes & $\mathrm{src}[1].\mathrm{device}$ & dtype range \\
\op{Range} & index & $()$ & \textsc{null} & $[0,\, n{-}1]$ \\
\op{Index} & $\mathrm{src}[0].\mathrm{dtype}$ & remaining dims & $\mathrm{src}[0].\mathrm{device}$ & $\mathrm{src}[0]$ \\
\op{Store} & void & $()$ & $\mathrm{src}[0].\mathrm{device}$ & --- \\
\op{After} & $\mathrm{src}[0].\mathrm{dtype}$ & $\mathrm{src}[0].\mathrm{shape}$ & $\mathrm{src}[0].\mathrm{device}$ & $\mathrm{src}[0]$ \\
\bottomrule
\end{tabular}

\smallskip
$\mathrm{broadcast}$: right-align shapes, element-wise max; each axis must be equal or $1$. $[a,A]$, $[b,B]$, $[c,C]$ denote min\_max of $\mathrm{src}[0]$, $\mathrm{src}[1]$, $\mathrm{src}[2]$. Default \emph{dtype range}: $[\mathrm{dtype\_min},\, \mathrm{dtype\_max}]$.

\medskip
\textbf{axis} tracks the multi-device sharding dimension. \op{Buffer} with $n$-tuple device: axis $= 0$ (device dim). \op{Reshape} remaps axis to preserve the shard boundary. \op{Permute} follows the permutation. \op{Reduce} on the shard axis $\to$ \textsc{null}. \op{Replicated} on the shard axis $\to$ \textsc{null}. \op{Copy} $\to$ \textsc{null}. ALU ops inherit from sources. Default: \textsc{null}.

%% ============================================================
\subsection*{Kernel Optimizations (OptOps) \normalfont\small--- schedule-level transforms on kernel ranges}

Each kernel's iteration space is a set of \op{Range} axes. Every range has an \textbf{AxisType}:

\medskip
\begin{tabular}{@{}l l l l l@{}}
\toprule
\textbf{AxisType} & \textbf{Letter} & \textbf{Split from} & \textbf{Direction} & \textbf{Semantics} \\
\midrule
{\color{axblue}\texttt{GLOBAL}} & \texttt{g} & --- & --- & GPU global workgroup dimension. \\
{\color{axcyan}\texttt{LOCAL}} & \texttt{l} & g, L & inner & Workgroup local dimension (shared memory).
\\
{\color{axbrcyan}\texttt{WARP}} & \texttt{w} & \multicolumn{2}{l}{(created by \op{TC})} & Warp-level lanes for tensor cores. \\
{\color{axbrblue}\texttt{THREAD}} & \texttt{t} & g & outer & CPU thread parallelism. \\
{\color{axwhite}\texttt{LOOP}} & \texttt{L} & --- & --- & Generic sequential loop (initial state). \\
{\color{axred}\texttt{REDUCE}} & \texttt{R} & --- & --- & Reduction axis. \\
{\color{axbrred}\texttt{GROUP\_REDUCE}} & \texttt{G} & R & inner/outer & Shared-memory group reduction. \\
{\color{axyellow}\texttt{UPCAST}} & \texttt{u} & g, l, L & inner & Register-level vectorization. \\
{\color{axmagenta}\texttt{UNROLL}} & \texttt{r} & R, G & inner & Fully unrolled loop. \\
\bottomrule
\end{tabular}

\medskip
An optimization is a triple $(\mathrm{op},\;\mathrm{axis},\;\mathrm{arg})$:

\smallskip
\begin{tabular}{@{}l l l p{6.5cm}@{}}
\toprule
\textbf{OptOp} & \textbf{axis} & \textbf{arg} & \textbf{Semantics} \\
\midrule
\op{Split} & any & (factor $k$, target, top?) & Split axis $n$ by $k$ into $(n/k, k)$ or $(k, n/k)$ if top. New sub-axis gets target AxisType (see table above). \\
\op{Padto} & any & multiple $m$ & Pad axis to next multiple of $m$ with validity masks. \\[4pt]
\op{Swap} & axis$_i$ & axis$_j$ & Swap two axes $i \leftrightarrow j$. \\
\op{Nolocals} & --- & --- & Disable local memory; no workgroup dims emitted. \\
\op{TC} & reduce idx & (tc, opt, mode) & Apply tensor core \op{Wmma}: split reduce/output axes into \texttt{WARP}, \texttt{UPCAST}, and \texttt{UNROLL} dims. \\
\bottomrule
\end{tabular}

\smallskip
Optimizations compose left-to-right. \op{TC} must be first. The search space is explored by BEAM search or hand-coded heuristics.

%% ============================================================
\subsection*{Common Ops as Compositions}

All high-level tensor operations decompose into the primitives above.
\begin{lstlisting}
# gemm: C[M,N] = A[M,K] @ B[K,N]
def gemm(A, B):
  M,K = A.shape; _,N = B.shape
  return (A.reshape(M,K,1) * B.reshape(1,K,N)).sum(1)

# prefix_sum: cumulative sum via repeat+reshape sliding window trick
def prefix_sum(T):
  n = T.shape[0]
  x = T.pad((n-1, 0))                              # (2n-1,)
  x = x.reshape(1,2*n-1).expand(n+1,2*n-1)         # tile
  x = x.reshape((n+1)*(2*n-1)).shrink_to(2*n*n)    # trim
  x = x.reshape(n,2*n).shrink_to(n,n)              # windows
  return x.sum(-1)                                 # reduce

# arange: prefix_sum of all 1s gives [1,2,...,n], subtract 1 for [0,1,...,n-1]
def arange(n): return prefix_sum(Tensor(1).reshape(1).expand(n)) - 1

# gather: out[i] = T[idx[i]]. one-hot mask along gather axis, then reduce
def gather(T, idx):
  K = T.shape[0]
  pos = arange(K).reshape(K, 1)                        # (K, 1)
  mask = (pos == idx.reshape(1, -1)).cast(T.dtype)     # (K, D)
  return (T.reshape(K, 1) * mask).sum(0)               # (D,)

# scatter_add: T[idx[i]] += val[i]
def scatter_add(T, idx, val):
  K, D = T.shape[0], idx.shape[0]
  pos = arange(K).reshape(K, 1)                        # (K, 1)
  mask = (pos == idx.reshape(1, D)).cast(T.dtype)      # (K, D)
  return T + (mask * val.reshape(1, D)).sum(1)         # (K,)
\end{lstlisting}

%% ============================================================
\subsection*{{\color{multipurple}Multi-Device Collectives} \normalfont\small--- derived from primitives}

Let $D = (d_0, \ldots, d_{n-1})$ be an $n$-tuple device. \op{Copy} to an $n$-tuple device reshards with axis $= 0$. \op{Copy} never changes shape.

\begin{lstlisting}
# T has shape (s,) on a single device.

# broadcast: replicate T to all n devices
def broadcast(T): return T.reshape(1, s).expand(n, s).copy(D).replicated(0)  # (s,) on D, axis=null

# scatter: split T into n chunks, one per device
def scatter(T): return T.copy(D)  # (s,) on D, axis=0

# T has shape (n*s,) on D with axis=0, so each device holds (s,) elements.
# gather: collect all shards onto one device
def gather(T): return T.copy(D[0])  # (n*s,) on D[0], axis=null

# reduce: gather + sum
def reduce(T): return gather(T).reshape(n, s).sum(0)  # (s,) on D[0], axis=null

# allgather: collect all shards, replicate to all devices
def allgather(T): return T.reshape(1, n*s).expand(n, n*s).copy(D).replicated(0)  # (n*s,) on D, axis=null

# reduce_scatter: reduce across devices, scatter result
def reduce_scatter(T): return T.reshape(n, n, s//n).permute(1, 0, 2).copy(D).sum(1).reshape(s)  # (s,) on D, axis=0

# allreduce: reduce_scatter + allgather
def allreduce(T): return allgather(reduce_scatter(T))  # (s,) on D, axis=null
\end{lstlisting}

%% ============================================================
\subsection*{{\color{callblue}The \texttt{@function} Decorator} \normalfont\small--- graph capture via tracing}

The \texttt{@function} decorator transforms a Python function on Tensors into a single \op{Function} node.

\begin{lstlisting}
@function
def f(a: Tensor, b: Tensor) -> Tensor:
  return a + b
\end{lstlisting}

When \texttt{f(x, y)} is called, the decorator:
\begin{enumerate}[leftmargin=1.5em, itemsep=2pt]
\item \textbf{Extracts inputs}: walks all arguments to find every Tensor, deduplicates by identity.
\item \textbf{Runs the function} lazily (no device execution), building a UOp graph from the result.
\item \textbf{Parameterizes}: replaces each input UOp with a \op{Param}$(k)$ placeholder.
\item \textbf{Wraps the body} in a \op{Tuple} (even for single returns) and creates\\ \op{Function}(\op{Tuple}(body), $x$, $y$).
\item \textbf{Returns} the result via \op{GetTuple}$(0)$, or one \op{GetTuple} per element for tuple returns.
\end{enumerate}

The result is a reusable graph fragment: the body contains only \op{Param} references, not concrete buffers.
At schedule time, the \op{Function} is resolved by substituting each \op{Param}$(k)$ back with its corresponding argument $a_k$, or lowered into an opaque \op{Call} if it is to be compiled as a reusable kernel.

%% ============================================================
\subsection*{Lowering Pipeline \normalfont\small--- from Tensor graph to machine code}

\begin{tabular}{@{}l p{9.7cm}@{}}
\toprule
\textbf{Stage} & \textbf{Semantics} \\
\midrule
\textbf{Callify} & Transform the Tensor graph into a single stateless function. \\
\textbf{Rangeify} & Determine the kernel split of the function. Break everything down to shape (). \\
\textbf{Optimize} & Insert local buffers. Swap and split ranges, and determine which axes are parallel and which are serial. \\
\textbf{Expand} & Expand the parallel ranges into shape. \\
\textbf{Instruction Selection} & Select target instructions, including WMMA and devectorization. \\
\textbf{Linearize} & Topologically sort the graph and determine execution order. \\
\textbf{Register/Memory Plan} & Allocate and reuse \texttt{GLOBAL}, \texttt{LOCAL}, and \texttt{REG} storage for values with non-overlapping lifetimes. \\
\textbf{Render} & Output the machine code. \\
\bottomrule
\end{tabular}

\end{document}