Typed methods in tensor.py (#9356)

* types for tensor.py

* x

* more

* remove some casts

* more typing

* fix linting issues

* -1 line

* add last type

* cast 🤙🤙
This commit is contained in:
Friedrich Carl Eichenroth
2025-03-06 01:34:18 +00:00
committed by GitHub
parent 77f7ddf62a
commit dbdefbbe54
2 changed files with 119 additions and 95 deletions

View File

@@ -82,7 +82,7 @@ class LARS(Optimizer):
if self.tcoef != 0:
r1 = t.detach().square().sum().sqrt()
r2 = g.square().sum().sqrt()
r = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
r:Tensor|float = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
else: r = 1.0
g = g + self.wd * t.detach()
# classic momentum does post learning rate update
@@ -141,7 +141,7 @@ class LAMB(Optimizer):
if not self.adam:
r1 = t.detach().square().sum().sqrt()
r2 = up.square().sum().sqrt()
r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
r: Tensor|float = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
else:
r = 1.0
t.assign((t.detach() - self.lr * r * up).cast(t.dtype))

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
from contextlib import ContextDecorator
from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex, ParamSpec, TypeVar
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
@@ -46,11 +46,10 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:
# **** Tensor helper functions ****
def _metaop(op, shape:tuple[sint,...], dtype:DType, device:str|tuple[str, ...], arg=None):
def _metaop(op, shape:tuple[sint,...], dtype:DType, device:str|tuple[str, ...], arg=None) -> UOp:
if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)
def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821
ret = UOp.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
# fake realize
@@ -97,7 +96,7 @@ def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]):
def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]) -> Tensor:
# apply mask to values (already broadcasted) and reduce such that if mask contains repeated indices the last one remains
values = values * mask
for dim in axes: mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim)))
@@ -293,8 +292,8 @@ class Tensor(SimpleMathTrait):
if 0 in self.shape: return memoryview(bytearray(0))
# NOTE: this realizes on the object from as_buffer being a Python object
cpu = self.cast(self.dtype.base).contiguous().to("CPU").realize()
buf = cast(UOp, cpu.lazydata).base.realized
assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized"
buf = cpu.lazydata.base.realized
assert buf is not None, f"{cpu.lazydata.base} was not realized"
if self.device != "CPU": buf.options = BufferSpec(nolru=True)
return buf.as_buffer(allow_zero_copy=True if self.device != "CPU" else False)
@@ -310,7 +309,7 @@ class Tensor(SimpleMathTrait):
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.base.fmt != "e"
return cast(memoryview, self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape))
return self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape)
def item(self) -> ConstType:
"""
@@ -376,7 +375,7 @@ class Tensor(SimpleMathTrait):
if self.grad is not None: ret.grad = self.grad.to(device)
return ret
def to_(self, device:str|tuple[str, ...]|None):
def to_(self, device:str|tuple[str, ...]|None) -> Tensor:
"""
Moves the tensor to the given device in place.
"""
@@ -398,7 +397,7 @@ class Tensor(SimpleMathTrait):
mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None)
return Tensor(mlb, device=devices, requires_grad=self.requires_grad)
def shard_(self, devices:tuple[str, ...], axis:int|None=None):
def shard_(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor:
"""
Shards the tensor across the given devices in place.
"""
@@ -415,7 +414,7 @@ class Tensor(SimpleMathTrait):
# ***** creation entrypoint *****
@staticmethod
def _metaop(op, shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None, arg=None, **kwargs):
def _metaop(op, shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None, arg=None, **kwargs) -> Tensor:
dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float
if isinstance(device, tuple):
return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None),
@@ -423,7 +422,7 @@ class Tensor(SimpleMathTrait):
return Tensor(UOp.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs)
@staticmethod
def empty(*shape, **kwargs):
def empty(*shape, **kwargs) -> Tensor:
"""
Creates an empty tensor with the given shape.
@@ -468,7 +467,7 @@ class Tensor(SimpleMathTrait):
_device_seeds: dict[str, Tensor] = {}
_device_rng_counters: dict[str, Tensor] = {}
@staticmethod
def manual_seed(seed=0):
def manual_seed(seed=0) -> None:
"""
Sets the seed for random operations.
@@ -486,7 +485,7 @@ class Tensor(SimpleMathTrait):
Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {}
@staticmethod
def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor) -> Tensor:
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
@@ -1069,7 +1068,7 @@ class Tensor(SimpleMathTrait):
if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
if mode == "constant":
def _constant(x:Tensor,px,v):
def _constant(x:Tensor,px,v) -> Tensor:
return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
_constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
@@ -1464,7 +1463,7 @@ class Tensor(SimpleMathTrait):
order[dim0], order[dim1] = order[dim1], order[dim0]
return self.permute(order)
def flatten(self, start_dim=0, end_dim=-1):
def flatten(self, start_dim=0, end_dim=-1) -> Tensor:
"""
Flattens the tensor by reshaping it into a one-dimensional tensor.
If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
@@ -1480,7 +1479,7 @@ class Tensor(SimpleMathTrait):
start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
def unflatten(self, dim:int, sizes:tuple[int,...]):
def unflatten(self, dim:int, sizes:tuple[int,...]) -> Tensor:
"""
Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
@@ -1565,7 +1564,7 @@ class Tensor(SimpleMathTrait):
ret = self._apply_uop(UOp.r, op=op, axis=axis)
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))
def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None):
def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None) -> Tensor:
"""
Returns the sum of the elements of the tensor along the specified axis or axes.
@@ -1592,7 +1591,7 @@ class Tensor(SimpleMathTrait):
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None):
def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None) -> Tensor:
"""
Returns the product of the elements of the tensor along the specified axis or axes.
@@ -1618,7 +1617,7 @@ class Tensor(SimpleMathTrait):
"""
return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)
def max(self, axis:int|Sequence[int]|None=None, keepdim=False):
def max(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
"""
Returns the maximum value of the tensor along the specified axis or axes.
@@ -1641,9 +1640,9 @@ class Tensor(SimpleMathTrait):
"""
return self._reduce(Ops.MAX, axis, keepdim)
def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
def _inverse(self) -> Tensor: return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
def min(self, axis:int|Sequence[int]|None=None, keepdim=False):
def min(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
"""
Returns the minimum value of the tensor along the specified axis or axes.
@@ -1666,7 +1665,7 @@ class Tensor(SimpleMathTrait):
"""
return self._inverse().max(axis=axis, keepdim=keepdim)._inverse()
def any(self, axis:int|Sequence[int]|None=None, keepdim=False):
def any(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
"""
Tests if any element evaluates to `True` along the specified axis or axes.
@@ -1688,7 +1687,7 @@ class Tensor(SimpleMathTrait):
"""
return self.bool().max(axis, keepdim)
def all(self, axis:int|Sequence[int]|None=None, keepdim=False):
def all(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
"""
Tests if all elements evaluate to `True` along the specified axis or axes.
@@ -1730,7 +1729,7 @@ class Tensor(SimpleMathTrait):
is_nan_close = (self.isnan() & other.isnan()) & equal_nan
return is_finite_close | is_infinite_close | is_nan_close
def mean(self, axis:int|Sequence[int]|None=None, keepdim=False):
def mean(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
"""
Returns the mean value of the tensor along the specified axis or axes.
@@ -1754,9 +1753,10 @@ class Tensor(SimpleMathTrait):
"""
output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])).cast(output_dtype)
return numerator.div(prod([cast(int, si) for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])) \
.cast(output_dtype)
def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor:
"""
Returns the variance of the tensor along the specified axis or axes.
@@ -1782,7 +1782,7 @@ class Tensor(SimpleMathTrait):
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
return squares.sum(axis=axis, keepdim=keepdim).div(smax([0, n-correction]))
def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]:
"""
Calculates the variance and mean over the dimensions specified by `axis`.
Syntactic sugar around `Tensor.var` and `Tensor.mean` to match `torch.var_mean`.
@@ -1799,7 +1799,7 @@ class Tensor(SimpleMathTrait):
"""
return self.var(axis, keepdim, correction), self.mean(axis, keepdim)
def std(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
def std(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor:
"""
Returns the standard deviation of the tensor along the specified axis or axes.
@@ -1823,7 +1823,7 @@ class Tensor(SimpleMathTrait):
"""
return self.var(axis, keepdim, correction).sqrt()
def std_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
def std_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]:
"""
Calculates the standard deviation and mean over the dimensions specified by `axis`.
Syntactic sugar around `Tensor.std` and `Tensor.mean` to match `torch.std_mean`.
@@ -1840,13 +1840,13 @@ class Tensor(SimpleMathTrait):
"""
return self.std(axis, keepdim, correction), self.mean(axis, keepdim)
def _softmax(self, axis, dtype:DTypeLike|None=None):
def _softmax(self, axis, dtype:DTypeLike|None=None) -> tuple[Tensor, Tensor, Tensor]:
m = self - self.max(axis=axis, keepdim=True).detach()
if dtype is not None: m = m.cast(dtype)
e = m.exp()
return m, e, e.sum(axis=axis, keepdim=True)
def softmax(self, axis=-1, dtype:DTypeLike|None=None):
def softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
"""
Applies the softmax function to the tensor along the specified axis.
@@ -1869,7 +1869,7 @@ class Tensor(SimpleMathTrait):
_, e, ss = self._softmax(axis, dtype)
return e.div(ss)
def log_softmax(self, axis=-1, dtype:DTypeLike|None=None):
def log_softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
"""
Applies the log-softmax function to the tensor along the specified axis.
@@ -1892,7 +1892,7 @@ class Tensor(SimpleMathTrait):
m, _, ss = self._softmax(axis, dtype)
return m - ss.log()
def logsumexp(self, axis=None, keepdim=False):
def logsumexp(self, axis=None, keepdim=False) -> Tensor:
"""
Computes the log-sum-exp of the tensor along the specified axis or axes.
@@ -1919,7 +1919,7 @@ class Tensor(SimpleMathTrait):
m = self.max(axis=axis, keepdim=True)
return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis)
def logcumsumexp(self, axis=0):
def logcumsumexp(self, axis=0) -> Tensor:
"""
Computes the log-cumsum-exp of the tensor along the specified axis or axes.
@@ -1954,7 +1954,7 @@ class Tensor(SimpleMathTrait):
ret = ((x_expand - x_cummax).exp() * mask).sum(-1).log() + x_cummax.squeeze(-1)
return ret.reshape(*x.shape).transpose(-1, axis)
def argmax(self, axis=None, keepdim=False):
def argmax(self, axis=None, keepdim=False) -> Tensor:
"""
Returns the indices of the maximum value of the tensor along the specified axis.
@@ -1981,7 +1981,7 @@ class Tensor(SimpleMathTrait):
idx = m * Tensor.arange(self.shape[axis],0,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)).cast(dtypes.int32)
def argmin(self, axis=None, keepdim=False):
def argmin(self, axis=None, keepdim=False) -> Tensor:
"""
Returns the indices of the minimum value of the tensor along the specified axis.
@@ -2353,7 +2353,7 @@ class Tensor(SimpleMathTrait):
ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0), value=identity_element(op, self.dtype)).unflatten(-1, (-1, SPLIT))._cumalu(-1, op)
base = ret[..., -1]._cumalu(-1, op, _include_initial=True)
base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
def fix(x: Tensor) -> Tensor: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
return fix(ret) + fix(base) if op is Ops.ADD else fix(ret).maximum(fix(base))
def cumsum(self, axis:int=0) -> Tensor:
@@ -2565,7 +2565,7 @@ class Tensor(SimpleMathTrait):
# ***** unary ops *****
def logical_not(self):
def logical_not(self) -> Tensor:
"""
Computes the logical NOT of the tensor element-wise.
@@ -2574,7 +2574,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True)
def neg(self):
def neg(self) -> Tensor:
"""
Negates the tensor element-wise.
@@ -2583,17 +2584,20 @@ class Tensor(SimpleMathTrait):
```
"""
return self*-1 if self.dtype != dtypes.bool else self.logical_not()
def contiguous(self):
def contiguous(self) -> Tensor:
"""
Returns a contiguous tensor.
"""
return self._apply_uop(UOp.contiguous)
def contiguous_backward(self):
def contiguous_backward(self) -> Tensor:
"""
Inserts a contiguous operation in the backward pass.
"""
return self._apply_uop(UOp.contiguous_backward)
def log(self):
def log(self) -> Tensor:
"""
Computes the natural logarithm element-wise.
@@ -2604,7 +2608,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.log2()*math.log(2)
def log2(self):
def log2(self) -> Tensor:
"""
Computes the base-2 logarithm element-wise.
@@ -2615,7 +2620,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2)
def exp(self):
def exp(self) -> Tensor:
"""
Computes the exponential function element-wise.
@@ -2626,7 +2632,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.mul(1/math.log(2)).exp2()
def exp2(self):
def exp2(self) -> Tensor:
"""
Computes the base-2 exponential function element-wise.
@@ -2637,7 +2644,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2)
def relu(self):
def relu(self) -> Tensor:
"""
Applies the Rectified Linear Unit (ReLU) function element-wise.
@@ -2649,7 +2657,7 @@ class Tensor(SimpleMathTrait):
"""
return (self>0).where(self, 0)
def sigmoid(self):
def sigmoid(self) -> Tensor:
"""
Applies the Sigmoid function element-wise.
@@ -2661,7 +2669,7 @@ class Tensor(SimpleMathTrait):
"""
return (1 + (self * (-1/math.log(2))).exp2()).reciprocal()
def hardsigmoid(self, alpha:float=1/6, beta:float=0.5):
def hardsigmoid(self, alpha:float=1/6, beta:float=0.5) -> Tensor:
"""
Applies the Hardsigmoid function element-wise.
NOTE: the default `alpha` and `beta` values are taken from torch
@@ -2675,7 +2683,7 @@ class Tensor(SimpleMathTrait):
"""
return (alpha * self + beta).relu() - (alpha * self + beta - 1).relu()
def sqrt(self):
def sqrt(self) -> Tensor:
"""
Computes the square root of the tensor element-wise.
@@ -2684,7 +2692,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt)
def rsqrt(self):
def rsqrt(self) -> Tensor:
"""
Computes the reciprocal of the square root of the tensor element-wise.
@@ -2693,7 +2702,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.sqrt().reciprocal()
def sin(self):
def sin(self) -> Tensor:
"""
Computes the sine of the tensor element-wise.
@@ -2702,7 +2712,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin)
def cos(self):
def cos(self) -> Tensor:
"""
Computes the cosine of the tensor element-wise.
@@ -2711,7 +2722,8 @@ class Tensor(SimpleMathTrait):
```
"""
return ((math.pi/2)-self).sin()
def tan(self):
def tan(self) -> Tensor:
"""
Computes the tangent of the tensor element-wise.
@@ -2721,7 +2733,7 @@ class Tensor(SimpleMathTrait):
"""
return self.sin() / self.cos()
def asin(self):
def asin(self) -> Tensor:
"""
Computes the inverse sine (arcsine) of the tensor element-wise.
@@ -2734,7 +2746,7 @@ class Tensor(SimpleMathTrait):
x = math.pi / 2 - (1.0 - self.abs()).sqrt() * polyN(self.abs(), coefficients)
return self.sign() * x
def acos(self):
def acos(self) -> Tensor:
"""
Computes the inverse cosine (arccosine) of the tensor element-wise.
@@ -2744,7 +2756,7 @@ class Tensor(SimpleMathTrait):
"""
return math.pi / 2 - self.asin()
def atan(self):
def atan(self) -> Tensor:
"""
Computes the inverse tangent (arctan) of the tensor element-wise.
@@ -2765,6 +2777,7 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(dtypes.int32).cast(self.dtype)
def ceil(self: Tensor) -> Tensor:
"""
Rounds the tensor element-wise towards positive infinity.
@@ -2774,6 +2787,7 @@ class Tensor(SimpleMathTrait):
```
"""
return (self > (b := self.trunc())).where(b+1, b)
def floor(self: Tensor) -> Tensor:
"""
Rounds the tensor element-wise towards negative infinity.
@@ -2783,6 +2797,7 @@ class Tensor(SimpleMathTrait):
```
"""
return (self < (b := self.trunc())).where(b-1, b)
def round(self: Tensor) -> Tensor:
"""
Rounds the tensor element-wise with rounding half to even.
@@ -2802,6 +2817,7 @@ class Tensor(SimpleMathTrait):
```
"""
return (self == float("inf")) * detect_positive + (self == float("-inf")) * detect_negative
def isnan(self:Tensor) -> Tensor:
"""
Checks the tensor element-wise to return True where the element is NaN, otherwise returns False
@@ -2811,6 +2827,7 @@ class Tensor(SimpleMathTrait):
```
"""
return self != self
def isfinite(self:Tensor) -> Tensor:
"""
Checks the tensor element-wise to return True where the element is finite, otherwise returns False
@@ -2834,7 +2851,7 @@ class Tensor(SimpleMathTrait):
return (self+(((end - self).cast(dtypes.int8) * w_i + (1<<W_PREC-1)).cast(dtypes.uint16) >> W_PREC)).cast(dtypes.uint8)
return self + (end - self) * weight
def square(self):
def square(self) -> Tensor:
"""
Squares the tensor element-wise.
Equivalent to `self*self`.
@@ -2844,7 +2861,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self*self
def clamp(self, min_=None, max_=None):
def clamp(self, min_=None, max_=None) -> Tensor:
"""
Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.
If `min_` is `None`, there is no lower bound. If `max_` is None, there is no upper bound.
@@ -2856,12 +2874,14 @@ class Tensor(SimpleMathTrait):
if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None")
ret = self.maximum(min_) if min_ is not None else self
return ret.minimum(max_) if max_ is not None else ret
def clip(self, min_=None, max_=None):
def clip(self, min_=None, max_=None) -> Tensor:
"""
Alias for `Tensor.clamp`.
"""
return self.clamp(min_, max_)
def sign(self):
def sign(self) -> Tensor:
"""
Returns the sign of the tensor element-wise.
@@ -2870,7 +2890,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0
def abs(self):
def abs(self) -> Tensor:
"""
Computes the absolute value of the tensor element-wise.
@@ -2879,7 +2900,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self * self.sign()
def reciprocal(self):
def reciprocal(self) -> Tensor:
"""
Compute `1/x` element-wise.
@@ -2891,7 +2913,7 @@ class Tensor(SimpleMathTrait):
# ***** activation functions *****
def elu(self, alpha=1.0):
def elu(self, alpha=1.0) -> Tensor:
"""
Applies the Exponential Linear Unit (ELU) function element-wise.
@@ -2904,7 +2926,7 @@ class Tensor(SimpleMathTrait):
"""
return self.relu() - alpha*(1-self.exp()).relu()
def celu(self, alpha=1.0):
def celu(self, alpha=1.0) -> Tensor:
"""
Applies the Continuously differentiable Exponential Linear Unit (CELU) function element-wise.
@@ -2917,7 +2939,7 @@ class Tensor(SimpleMathTrait):
"""
return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
def selu(self, alpha=1.67326, gamma=1.0507):
def selu(self, alpha=1.67326, gamma=1.0507) -> Tensor:
"""
Applies the Scaled Exponential Linear Unit (SELU) function element-wise.
@@ -2930,7 +2952,7 @@ class Tensor(SimpleMathTrait):
"""
return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1))
def swish(self):
def swish(self) -> Tensor:
"""
See `.silu()`
@@ -2942,7 +2964,7 @@ class Tensor(SimpleMathTrait):
"""
return self * self.sigmoid()
def silu(self):
def silu(self) -> Tensor:
"""
Applies the Sigmoid Linear Unit (SiLU) function element-wise.
@@ -2955,7 +2977,7 @@ class Tensor(SimpleMathTrait):
"""
return self.swish() # The SiLU function is also known as the swish function.
def relu6(self):
def relu6(self) -> Tensor:
"""
Applies the ReLU6 function element-wise.
@@ -2968,7 +2990,7 @@ class Tensor(SimpleMathTrait):
"""
return self.relu() - (self-6).relu()
def hardswish(self):
def hardswish(self) -> Tensor:
"""
Applies the Hardswish function element-wise.
@@ -2981,7 +3003,7 @@ class Tensor(SimpleMathTrait):
"""
return self * (self+3).relu6() * (1/6)
def tanh(self):
def tanh(self) -> Tensor:
"""
Applies the Hyperbolic Tangent (tanh) function element-wise.
@@ -2993,7 +3015,7 @@ class Tensor(SimpleMathTrait):
"""
return 2.0 * ((2.0 * self).sigmoid()) - 1.0
def sinh(self):
def sinh(self) -> Tensor:
"""
Applies the Hyperbolic Sine (sinh) function element-wise.
@@ -3005,7 +3027,7 @@ class Tensor(SimpleMathTrait):
"""
return (self.exp() - self.neg().exp()) / 2
def cosh(self):
def cosh(self) -> Tensor:
"""
Applies the Hyperbolic Cosine (cosh) function element-wise.
@@ -3017,7 +3039,7 @@ class Tensor(SimpleMathTrait):
"""
return (self.exp() + self.neg().exp()) / 2
def atanh(self):
def atanh(self) -> Tensor:
"""
Applies the Inverse Hyperbolic Tangent (atanh) function element-wise.
@@ -3029,7 +3051,7 @@ class Tensor(SimpleMathTrait):
"""
return ((1 + self)/(1 - self)).log() / 2
def asinh(self):
def asinh(self) -> Tensor:
"""
Applies the Inverse Hyperbolic Sine (asinh) function element-wise.
@@ -3041,7 +3063,7 @@ class Tensor(SimpleMathTrait):
"""
return (self + (self.square() + 1).sqrt()).log()
def acosh(self):
def acosh(self) -> Tensor:
"""
Applies the Inverse Hyperbolic Cosine (acosh) function element-wise.
@@ -3053,7 +3075,7 @@ class Tensor(SimpleMathTrait):
"""
return (self + (self.square() - 1).sqrt()).log()
def hardtanh(self, min_val=-1, max_val=1):
def hardtanh(self, min_val=-1, max_val=1) -> Tensor:
"""
Applies the Hardtanh function element-wise.
@@ -3065,7 +3087,7 @@ class Tensor(SimpleMathTrait):
"""
return self.clip(min_val, max_val)
def erf(self):
def erf(self) -> Tensor:
"""
Applies the error function element-wise.
@@ -3079,7 +3101,7 @@ class Tensor(SimpleMathTrait):
t = 1.0 / (1.0 + 0.3275911 * self.abs())
return self.sign() * (1.0 - t * polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]) * (-self.square()).exp())
def gelu(self):
def gelu(self) -> Tensor:
"""
Applies the Gaussian Error Linear Unit (GELU) function element-wise.
@@ -3092,7 +3114,7 @@ class Tensor(SimpleMathTrait):
"""
return 0.5 * self * (1 + (math.sqrt(2 / math.pi) * (self + 0.044715 * self ** 3)).tanh())
def quick_gelu(self):
def quick_gelu(self) -> Tensor:
"""
Applies the Sigmoid GELU approximation element-wise.
@@ -3104,7 +3126,7 @@ class Tensor(SimpleMathTrait):
"""
return self * (self * 1.702).sigmoid()
def leaky_relu(self, neg_slope=0.01):
def leaky_relu(self, neg_slope=0.01) -> Tensor:
"""
Applies the Leaky ReLU function element-wise.
@@ -3119,7 +3141,7 @@ class Tensor(SimpleMathTrait):
"""
return (self<0).where(neg_slope*self, self)
def mish(self):
def mish(self) -> Tensor:
"""
Applies the Mish function element-wise.
@@ -3132,7 +3154,7 @@ class Tensor(SimpleMathTrait):
"""
return self * self.softplus().tanh()
def softplus(self, beta=1):
def softplus(self, beta=1) -> Tensor:
"""
Applies the Softplus function element-wise.
@@ -3144,7 +3166,7 @@ class Tensor(SimpleMathTrait):
"""
return (1/beta) * (1 + (self*beta).exp()).log()
def softsign(self):
def softsign(self) -> Tensor:
"""
Applies the Softsign function element-wise.
@@ -3360,7 +3382,7 @@ class Tensor(SimpleMathTrait):
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return self.logical_not() if self.dtype == dtypes.bool else self ^ -1
def lshift(self, x:int):
def lshift(self, x:int) -> Tensor:
"""
Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
Equivalent to `self << x`.
@@ -3372,7 +3394,7 @@ class Tensor(SimpleMathTrait):
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
return self.mul(2 ** x)
def rshift(self, x:int):
def rshift(self, x:int) -> Tensor:
"""
Computes right arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
Equivalent to `self >> x`.
@@ -3434,7 +3456,7 @@ class Tensor(SimpleMathTrait):
t, x = self._broadcasted(x)
return t._inverse().maximum(x._inverse())._inverse()
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint):
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
"""
Return a tensor of elements selected from either `x` or `y`, depending on `self`.
`output_i = x_i if self_i else y_i`.
@@ -3458,7 +3480,7 @@ class Tensor(SimpleMathTrait):
cond, y = cond._broadcasted(y, match_dtype=False)
return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y))
def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType): return mask.where(value, self)
def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType) -> Tensor: return mask.where(value, self)
def copysign(self, other) -> Tensor:
"""
@@ -3503,7 +3525,7 @@ class Tensor(SimpleMathTrait):
# ***** functional nn ops *****
def linear(self, weight:Tensor, bias:Tensor|None=None):
def linear(self, weight:Tensor, bias:Tensor|None=None) -> Tensor:
"""
Applies a linear transformation to `self` using `weight` and `bias`.
@@ -3519,7 +3541,7 @@ class Tensor(SimpleMathTrait):
x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
return x.add(bias) if bias is not None else x
def sequential(self, ll:list[Callable[[Tensor], Tensor]]):
def sequential(self, ll:list[Callable[[Tensor], Tensor]]) -> Tensor:
"""
Applies a sequence of functions to `self` chaining the output of each function to the input of the next.
@@ -3592,7 +3614,7 @@ class Tensor(SimpleMathTrait):
return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
# helper function commonly used for indexing
def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1):
def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1) -> Tensor:
if not dtypes.is_int(self.dtype): raise RuntimeError(f"_one_hot_along_dim expects int index tensor, getting {self.dtype}")
offset = self.ndim - self._resolve_dim(dim) - 1
return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)
@@ -3821,7 +3843,7 @@ class Tensor(SimpleMathTrait):
# ***** cast ops *****
def llvm_bf16_cast(self, dtype:DTypeLike):
def llvm_bf16_cast(self, dtype:DTypeLike) -> Tensor:
# hack for devices that don't support bfloat16
assert self.dtype == dtypes.bfloat16
return self.to("LLVM").cast(dtype)
@@ -4011,8 +4033,10 @@ class Tensor(SimpleMathTrait):
ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
def _metadata_wrapper(fn):
def _wrapper(*args, **kwargs):
P = ParamSpec("P")
T = TypeVar("T")
def _metadata_wrapper(fn: Callable[P, T]) -> Callable[P, T]:
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
if _METADATA.get() is not None: return fn(*args, **kwargs)
if TRACEMETA >= 2: