Files
dragonpilot/tinygrad_repo/tinygrad/mixin/movement.py
Vehicle Researcher 6928314c89 openpilot v0.10.3 release
date: 2025-12-18T23:23:16
master commit: 154c2334110373950bac1c36fc6e943cb1208326
2025-12-18 23:23:21 -08:00

377 lines
16 KiB
Python

# mixins add syntactic sugar to Tensor and UOp
import functools
from typing import TypeAlias, TYPE_CHECKING, Self
from tinygrad.uop import Ops
from tinygrad.helpers import prod, argfix, flatten, dedup, make_tuple, ceildiv
from tinygrad.uop.ops import resolve, smax
if TYPE_CHECKING:
from tinygrad.uop.ops import UOp
sint: TypeAlias = "UOp | int"
def _align_left(*shapes: tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
    """
    Left-pad every shape with 1s so that all of them share the rank of the longest one.

    This is the "prepend size-1 dimensions" step of broadcasting, e.g. `_align_left((3,), (2, 3))`
    gives `((1, 3), (2, 3))`. Calling it with no shapes raises `ValueError` (from `max`).
    """
    rank = max(map(len, shapes))
    return tuple(tuple([1] * (rank - len(shp))) + shp for shp in shapes)
class MovementMixin:
    """
    Movement (shape-changing) operations shared by `Tensor` and `UOp`.

    A subclass only has to provide `_mop` and `shape`; every other method here is expressed in terms of those two
    (plus the helpers imported at the top of this module).
    """

    # required to implement
    def _mop(self, op: Ops, arg) -> Self:
        """Apply the primitive movement op `op` (this mixin uses EXPAND, RESHAPE, SHRINK, PERMUTE and FLIP) with `arg`."""
        raise NotImplementedError

    @property
    def shape(self) -> tuple[sint, ...]:
        """The shape of this object. Subclasses must override."""
        raise NotImplementedError

    # great functions you get!
    @property
    def ndim(self) -> int:
        """
        Returns the number of dimensions in the tensor.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([[1, 2], [3, 4]])
        print(t.ndim)
        ```
        """
        return len(self.shape)

    def numel(self) -> sint:
        """
        Returns the total number of elements in the tensor.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
        print(t.numel())
        ```
        """
        return prod(self.shape)

    def _resolve_dim(self, dim: int, *, extra: bool = False) -> int:
        """
        Validate an axis index and wrap a negative one by adding the rank.

        With `extra=True` the rank is counted as `ndim + 1` (used by `unsqueeze`, which may insert after the last axis).
        The valid range is never narrower than `[-1, 0]`, so 0-d objects accept `0` and `-1`; note that for them
        `-1` is returned unchanged. Raises `IndexError` for anything out of range.
        """
        total = self.ndim + int(extra)
        if not -max(1, total) <= dim <= max(1, total) - 1:
            raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total) - 1]}")
        return dim + total if dim < 0 else dim

    def _broadcast_to(self, new_shape: tuple[sint, ...]) -> Self:
        """
        Broadcast to `new_shape`: left-pad the shape with 1s, then expand the size-1 axes.

        Raises `ValueError` if `new_shape` has fewer dims than `self`, or if some axis is neither 1 nor equal to the target.
        """
        if self.shape == new_shape:
            return self
        if self.ndim > len(new_shape):
            raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}")
        # first unsqueeze left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
        shape, _ = _align_left(self.shape, new_shape)
        # for each dimension, check either dim is 1, or it does not change
        if not all(s == ns or s == 1 for s, ns in zip(shape, new_shape)):
            raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
        reshaped = self.reshape(shape)
        ret = reshaped._mop(Ops.EXPAND, arg=new_shape)
        # skip the EXPAND if it turned out to be a no-op
        return reshaped if ret.shape == reshaped.shape else ret

    def expand(self, shape, *args) -> Self:
        """
        Returns a tensor that is expanded to the shape that is specified.
        Expand can also increase the number of dimensions that a tensor has.

        Passing a `-1` or `None` to a dimension means that its size will not be changed.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([1, 2, 3])
        print(t.expand(4, -1).numpy())
        ```
        """
        # align ranks first so `-1`/`None` entries pair up with the existing size (1 for newly added leading axes)
        new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_align_left(self.shape, argfix(shape, *args)))))
        return self._broadcast_to(new_shape)

    def reshape(self, shape, *args) -> Self:
        """
        Returns a tensor with the same data as the original tensor but with a different shape.
        `shape` can be passed as a tuple or as separate arguments.

        A `None` entry keeps the size of the dimension at the same position; a single `-1` entry is inferred.
        Raises `RuntimeError` for more than one `-1` and `ValueError` if the element count changes.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.arange(6)
        print(t.reshape(2, 3).numpy())
        ```
        """
        # resolve None and args
        new_shape = tuple([s if s is not None else self.shape[i] for i, s in enumerate(argfix(shape, *args))])
        # resolve -1
        if (c := new_shape.count(-1)) > 1:
            raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
        if c:
            # prod(new_shape) is negative here (it includes the -1), so the negated numerator gives a positive size
            new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
        if prod(self.shape) != prod(new_shape):
            raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
        ret = self._mop(Ops.RESHAPE, arg=new_shape)
        # hand back `self` unchanged when the shape did not actually change
        return self if ret.shape == self.shape else ret

    def shrink(self, arg: tuple[tuple[sint, sint] | None, ...]) -> Self:
        """
        Returns a tensor that shrinks each axis based on input arg.
        `arg` must have the same length as `self.ndim`.
        For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.arange(9).reshape(3, 3)
        print(t.numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.shrink(((None, (1, 3)))).numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.shrink((((0, 2), (0, 2)))).numpy())
        ```
        """
        if self.ndim != len(arg):
            raise ValueError(f"{self.ndim=} != {len(arg)=}")
        ret = self._mop(Ops.SHRINK, arg=[x if x is not None else (0, s) for x, s in zip(arg, self.shape)])
        return self if ret.shape == self.shape else ret

    def permute(self, order, *args) -> Self:
        """
        Returns a tensor that is a permutation of the original tensor.
        The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
        `order` can be passed as a tuple or as separate arguments.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.empty(2, 3, 5)
        print(t.shape)
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.permute(2, 0, 1).shape)
        ```
        """
        order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
        if sorted(order_arg) != list(range(self.ndim)):
            raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
        # the identity permutation is a no-op
        return self._mop(Ops.PERMUTE, arg=order_arg) if order_arg != tuple(range(self.ndim)) else self

    def flip(self, axis, *args) -> Self:
        """
        Returns a tensor that reverses the order of the original tensor along given `axis`.
        `axis` can be passed as a tuple or as separate arguments.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.arange(6).reshape(2, 3)
        print(t.numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.flip(0).numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.flip((0, 1)).numpy())
        ```
        """
        axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
        # `_resolve_dim` returns a bool axis unchanged, so this also rejects `True`/`False`
        assert all(not isinstance(x, bool) and x >= 0 and x < self.ndim for x in axis_arg), f"flip args must be axis ints {axis_arg}"
        if len(axis_arg) != len(dedup(axis_arg)):
            raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
        # FLIP takes one boolean per axis
        flip_arg = tuple([i in axis_arg for i in range(len(self.shape))])
        return self._mop(Ops.FLIP, arg=flip_arg) if any(flip_arg) else self

    # **** high level ****
    def shrink_to(self, shape, *args) -> Self:
        """Shrink each axis to `[0, size)`; a `None` entry keeps that axis whole. One entry per axis is required."""
        return self.shrink(tuple([None if ns is None else (0, ns) for ns in argfix(shape, *args)]))

    def view(self, shape, *args) -> Self:
        """`.view` is an alias for `.reshape`."""
        return self.reshape(shape, *args)

    def squeeze(self, dim: int | None = None) -> Self:
        """
        Returns a tensor with specified dimensions of input of size 1 removed.
        If `dim` is not specified, all dimensions with size 1 are removed.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.zeros(2, 1, 2, 1, 2)
        print(t.squeeze().shape)
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.squeeze(0).shape)
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.squeeze(1).shape)
        ```
        """
        if dim is None:
            return self.reshape(tuple(size for size in self.shape if size != 1))
        dim = self._resolve_dim(dim)
        # 0-d objects and non-unit axes are returned unchanged
        return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim + 1 :])

    def unsqueeze(self, dim: int) -> Self:
        """
        Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([1, 2, 3, 4])
        print(t.unsqueeze(0).numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.unsqueeze(1).numpy())
        ```
        """
        # `extra=True`: inserting after the last axis (dim == ndim) is allowed
        dim = self._resolve_dim(dim, extra=True)
        return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])

    @property
    def T(self) -> Self:
        """`.T` is an alias for `.transpose()`."""
        return self.transpose()

    def transpose(self, dim0=1, dim1=0) -> Self:
        """
        Returns a tensor that is a transposed version of the original tensor.
        The given dimensions `dim0` and `dim1` are swapped. Raises `IndexError` for out-of-range dims.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.arange(6).reshape(2, 3)
        print(t.numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.transpose(0, 1).numpy())
        ```
        """
        order = list(range(self.ndim))
        # validate through `_resolve_dim` like the other methods, for a clearer error than a bare list IndexError
        dim0, dim1 = self._resolve_dim(dim0), self._resolve_dim(dim1)
        order[dim0], order[dim1] = order[dim1], order[dim0]
        return self.permute(order)

    def flatten(self, start_dim=0, end_dim=-1) -> Self:
        """
        Flattens the tensor by reshaping it into a one-dimensional tensor.
        If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
        Raises `ValueError` if `start_dim` comes after `end_dim`.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.arange(8).reshape(2, 2, 2)
        print(t.flatten().numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.flatten(start_dim=1).numpy())
        ```
        """
        start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
        # an inverted range would otherwise silently insert a size-1 axis (e.g. (3, 4).flatten(1, 0) -> (3, 1, 4)).
        # 0-d objects are exempt: `_resolve_dim` leaves end_dim=-1 as -1 for them, and they flatten to shape (1,).
        if self.ndim and start_dim > end_dim:
            raise ValueError(f"flatten start_dim ({start_dim}) cannot come after end_dim ({end_dim})")
        return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim : end_dim + 1]),) + self.shape[end_dim + 1 :])

    def unflatten(self, dim: int, sizes: tuple[sint, ...] | list[sint]) -> Self:
        """
        Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
        `sizes` may be any sequence (tuple or list); one entry may be `-1`, which is inferred by `reshape`.

        ```python exec="true" source="above" session="tensor" result="python"
        print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
        ```
        """
        dim = self._resolve_dim(dim)
        # tuple() so a list also works; `tuple + list` would raise TypeError
        return self.reshape(self.shape[:dim] + tuple(sizes) + self.shape[dim + 1 :])

    def rearrange(self, formula: str, **sizes) -> Self:
        """
        Rearranges input according to formula

        See: https://einops.rocks/api/rearrange/

        `sizes` gives the size of axes that are created by splitting a group on the left-hand side.

        ```python exec="true" source="above" session="tensor" result="python"
        x = Tensor([[1, 2], [3, 4]])
        print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
        ```
        """

        def parse_formula(formula: str):
            # Tokenize one side of the formula:
            #  - the unicode ellipsis U+2026 is accepted as an alias for "..."
            #  - parentheses become standalone tokens
            #  - every space is doubled so that adjacent " 1 " tokens (e.g. "1 1") each keep their own surrounding
            #    spaces; str.replace matches are non-overlapping, so otherwise every second "1" would be missed
            #  - a bare "1" axis becomes an empty group "( )"
            tokens = (
                f" {formula} ".replace("\u2026", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", "  ").replace(" 1 ", " ( ) ").split()
            )
            lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
            pairs = list(zip(lparens, rparens))
            assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
            # group (start, end) indices into the paren-free name list: each earlier group removed 2 paren tokens
            return [name for name in tokens if name not in ("(", ")")], [(s - 2 * i, e - 1 - 2 * i) for i, (s, e) in enumerate(pairs)]

        assert formula.count("->") == 1, 'need exactly one "->" in formula'
        (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
        for name in sizes:
            assert name in lhs, f"axis {name} is not used in transform"
        assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
        for name in flatten((lhs, rhs)):
            assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
        assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
        assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
        # resolve ellipsis: replace "..." with the concrete names "...0", "...1", ... and shift group indices after it
        if "..." in lhs:
            ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
            lhs, rhs = map(lambda l: l[: (i := l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1 :] if "..." in l else l, (lhs, rhs))
            unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
            flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
        # apply movement ops in order unflatten -> permute -> flatten/unsqueeze
        t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
        for i, name in enumerate(lhs):
            assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
        t = t.permute([lhs.index(name) for name in rhs])
        # reversed so earlier groups' indices are not shifted by later flattens; empty groups become unsqueezes
        return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0] < dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)

    # *** movement ops with expand ***
    def repeat_interleave(self, repeats: int, dim: int | None = None) -> Self:
        """
        Repeats elements of a tensor.
        With `dim=None` the tensor is flattened first and the result is 1-D.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([1, 2, 3])
        print(t.repeat_interleave(2).numpy())
        ```
        """
        x, dim = (self.flatten(), 0) if dim is None else (self, self._resolve_dim(dim))
        shp = x.shape
        # insert a size-1 axis right after `dim`, expand it to `repeats`, then merge it back into `dim`
        x = x.reshape(*shp[: dim + 1], 1, *shp[dim + 1 :])
        x = x.expand(*shp[: dim + 1], repeats, *shp[dim + 1 :])
        x = x.reshape(*shp[:dim], shp[dim] * repeats, *shp[dim + 1 :])
        return x

    def repeat(self, repeats, *args) -> Self:
        """
        Repeats tensor number of times along each dimension specified by `repeats`.
        `repeats` can be passed as a tuple or as separate arguments.
        It needs at least one entry per dimension; extra leading entries add new dimensions.
        Raises `ValueError` if it has fewer entries than `self.ndim`.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([1, 2, 3])
        print(t.repeat(4, 2).numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.repeat(4, 2, 1).shape)
        ```
        """
        repeats = argfix(repeats, *args)
        # with fewer repeats than dims the zips below would silently drop trailing axes
        if len(repeats) < self.ndim:
            raise ValueError(f"number of repeats ({len(repeats)}) cannot be smaller than the number of dimensions ({self.ndim})")
        base_shape = _align_left(self.shape, repeats)[0]
        # each repeated axis s becomes (1, s) -> expand to (r, s) -> merge to r*s; axes with r == 1 are left alone
        unsqueezed_shape = flatten([[s] if r == 1 else [1, s] for r, s in zip(repeats, base_shape)])
        expanded_shape = flatten([[s] if r == 1 else [r, s] for r, s in zip(repeats, base_shape)])
        final_shape = [r * s for r, s in zip(repeats, base_shape)]
        return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)

    # **** pool level ****
    def _pool(self, k_: tuple[sint, ...], stride: int | tuple[int, ...] = 1, dilation: int | tuple[int, ...] = 1) -> Self:
        """
        Extract sliding windows over the last `len(k_)` axes using only movement ops.

        The result has shape `(*leading, *o_, *k_)`, where `o_[j] = ceildiv(i_j - d_j * (k_j - 1), s_j)` is the
        number of output positions per pooled axis.
        """
        assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
        s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
        assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
        noop, i_ = [None] * (self.ndim - len(k_)), self.shape[-len(k_) :]
        assert all(resolve(d * (k - 1) + 1 <= i) for k, d, i in zip(k_, d_, i_)), "kernel size cannot be greater than actual input size"
        o_ = [ceildiv(i - d * (k - 1), s) for i, d, k, s in zip(i_, d_, k_, s_)]
        # input size scaling factor to make sure shrink for stride is possible
        f_ = [smax(1, ceildiv(o * s - d, i)) for o, s, i, d in zip(o_, s_, i_, d_)]
        # repeats such that we don't need padding
        x = self.repeat([1] * len(noop) + [ceildiv(k * (i * f + d), i) for k, i, d, f in zip(k_, i_, d_, f_)])
        # handle dilation
        x = x.shrink_to(noop + [k * (i * f + d) for k, i, d, f in zip(k_, i_, d_, f_)])
        x = x.reshape(noop + flatten((k, (i * f + d)) for k, i, d, f in zip(k_, i_, d_, f_)))
        # handle stride
        x = x.shrink_to(noop + flatten((k, o * s) for k, o, s in zip(k_, o_, s_))).reshape(noop + flatten((k, o, s) for k, o, s in zip(k_, o_, s_)))
        x = x.shrink_to(noop + flatten((k, o, 1) for k, o in zip(k_, o_))).reshape(noop + flatten((k, o) for k, o in zip(k_, o_)))
        # permute to move reduce to the end: layout is (..., k0, o0, k1, o1, ...) -> (..., o0, o1, ..., k0, k1, ...)
        return x.permute(*range(len(noop)), *[len(noop) + i * 2 + 1 for i in range(len(i_))], *[len(noop) + i * 2 for i in range(len(i_))])