From dbdefbbe540cb36ebd16ec1c6c795e8a15612e86 Mon Sep 17 00:00:00 2001 From: Friedrich Carl Eichenroth Date: Thu, 6 Mar 2025 01:34:18 +0000 Subject: [PATCH] Typed methods in tensor.py (#9356) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * types for tensor.py * x * more * remove some casts * more typing * fix linting issues * -1 line * add last type * cast 🤙🤙 --- tinygrad/nn/optim.py | 4 +- tinygrad/tensor.py | 210 ++++++++++++++++++++++++------------------- 2 files changed, 119 insertions(+), 95 deletions(-) diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py index e5cfb4aa11..6baffe99c2 100644 --- a/tinygrad/nn/optim.py +++ b/tinygrad/nn/optim.py @@ -82,7 +82,7 @@ class LARS(Optimizer): if self.tcoef != 0: r1 = t.detach().square().sum().sqrt() r2 = g.square().sum().sqrt() - r = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0) + r:Tensor|float = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0) else: r = 1.0 g = g + self.wd * t.detach() # classic momentum does post learning rate update @@ -141,7 +141,7 @@ class LAMB(Optimizer): if not self.adam: r1 = t.detach().square().sum().sqrt() r2 = up.square().sum().sqrt() - r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0) + r: Tensor|float = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0) else: r = 1.0 t.assign((t.detach() - self.lr * r * up).cast(t.dtype)) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 8c50b23a9f..df2418706f 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2,7 +2,7 @@ from __future__ import annotations import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref from contextlib import ContextDecorator -from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex +from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex, ParamSpec, TypeVar from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate from tinygrad.dtype import _from_np_dtype, _to_np_dtype from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup @@ -46,11 +46,10 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None: # **** Tensor helper functions **** -def _metaop(op, shape:tuple[sint,...], dtype:DType, device:str|tuple[str, ...], arg=None): +def _metaop(op, shape:tuple[sint,...], dtype:DType, device:str|tuple[str, ...], arg=None) -> UOp: if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg) return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None) - def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821 ret = UOp.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY") # fake realize @@ -97,7 +96,7 @@ def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]: def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]: return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes))) -def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]): +def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]) -> Tensor: # apply mask to values (already broadcasted) and reduce such that if mask contains repeated indices the last one remains values = values * mask for dim in axes: mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim))) @@ -293,8 +292,8 @@ class Tensor(SimpleMathTrait): if 0 in self.shape: return memoryview(bytearray(0)) # NOTE: this realizes on the object from as_buffer being a Python object cpu = self.cast(self.dtype.base).contiguous().to("CPU").realize() - buf = cast(UOp, cpu.lazydata).base.realized - assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized" + buf = cpu.lazydata.base.realized + assert buf is not None, f"{cpu.lazydata.base} was not realized" if self.device != "CPU": buf.options = BufferSpec(nolru=True) return buf.as_buffer(allow_zero_copy=True if self.device != "CPU" else False) @@ -310,7 +309,7 @@ class Tensor(SimpleMathTrait): assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}" assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}" if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.base.fmt != "e" - return cast(memoryview, self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape)) + return self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape) def item(self) -> ConstType: """ @@ -376,7 +375,7 @@ class Tensor(SimpleMathTrait): if self.grad is not None: ret.grad = self.grad.to(device) return ret - def to_(self, device:str|tuple[str, ...]|None): + def to_(self, device:str|tuple[str, ...]|None) -> Tensor: """ Moves the tensor to the given device in place. """ @@ -398,7 +397,7 @@ class Tensor(SimpleMathTrait): mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None) return Tensor(mlb, device=devices, requires_grad=self.requires_grad) - def shard_(self, devices:tuple[str, ...], axis:int|None=None): + def shard_(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor: """ Shards the tensor across the given devices in place. """ @@ -415,7 +414,7 @@ class Tensor(SimpleMathTrait): # ***** creation entrypoint ***** @staticmethod - def _metaop(op, shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None, arg=None, **kwargs): + def _metaop(op, shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None, arg=None, **kwargs) -> Tensor: dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float if isinstance(device, tuple): return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None), @@ -423,7 +422,7 @@ class Tensor(SimpleMathTrait): return Tensor(UOp.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs) @staticmethod - def empty(*shape, **kwargs): + def empty(*shape, **kwargs) -> Tensor: """ Creates an empty tensor with the given shape. @@ -468,7 +467,7 @@ class Tensor(SimpleMathTrait): _device_seeds: dict[str, Tensor] = {} _device_rng_counters: dict[str, Tensor] = {} @staticmethod - def manual_seed(seed=0): + def manual_seed(seed=0) -> None: """ Sets the seed for random operations. @@ -486,7 +485,7 @@ class Tensor(SimpleMathTrait): Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {} @staticmethod - def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor): + def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor) -> Tensor: x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64) x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64)) counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32) @@ -1069,7 +1068,7 @@ class Tensor(SimpleMathTrait): if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}") X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX) if mode == "constant": - def _constant(x:Tensor,px,v): + def _constant(x:Tensor,px,v) -> Tensor: return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v)) return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \ _constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value) @@ -1464,7 +1463,7 @@ class Tensor(SimpleMathTrait): order[dim0], order[dim1] = order[dim1], order[dim0] return self.permute(order) - def flatten(self, start_dim=0, end_dim=-1): + def flatten(self, start_dim=0, end_dim=-1) -> Tensor: """ Flattens the tensor by reshaping it into a one-dimensional tensor. If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened. @@ -1480,7 +1479,7 @@ class Tensor(SimpleMathTrait): start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim) return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:]) - def unflatten(self, dim:int, sizes:tuple[int,...]): + def unflatten(self, dim:int, sizes:tuple[int,...]) -> Tensor: """ Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function. @@ -1565,7 +1564,7 @@ class Tensor(SimpleMathTrait): ret = self._apply_uop(UOp.r, op=op, axis=axis) return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis)) - def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None): + def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None) -> Tensor: """ Returns the sum of the elements of the tensor along the specified axis or axes. @@ -1592,7 +1591,7 @@ class Tensor(SimpleMathTrait): ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim) return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret - def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None): + def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None) -> Tensor: """ Returns the product of the elements of the tensor along the specified axis or axes. @@ -1618,7 +1617,7 @@ class Tensor(SimpleMathTrait): """ return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim) - def max(self, axis:int|Sequence[int]|None=None, keepdim=False): + def max(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor: """ Returns the maximum value of the tensor along the specified axis or axes. @@ -1641,9 +1640,9 @@ class Tensor(SimpleMathTrait): """ return self._reduce(Ops.MAX, axis, keepdim) - def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not() + def _inverse(self) -> Tensor: return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not() - def min(self, axis:int|Sequence[int]|None=None, keepdim=False): + def min(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor: """ Returns the minimum value of the tensor along the specified axis or axes. @@ -1666,7 +1665,7 @@ class Tensor(SimpleMathTrait): """ return self._inverse().max(axis=axis, keepdim=keepdim)._inverse() - def any(self, axis:int|Sequence[int]|None=None, keepdim=False): + def any(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor: """ Tests if any element evaluates to `True` along the specified axis or axes. @@ -1688,7 +1687,7 @@ class Tensor(SimpleMathTrait): """ return self.bool().max(axis, keepdim) - def all(self, axis:int|Sequence[int]|None=None, keepdim=False): + def all(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor: """ Tests if all element evaluates to `True` along the specified axis or axes. @@ -1730,7 +1729,7 @@ class Tensor(SimpleMathTrait): is_nan_close = (self.isnan() & other.isnan()) & equal_nan return is_finite_close | is_infinite_close | is_nan_close - def mean(self, axis:int|Sequence[int]|None=None, keepdim=False): + def mean(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor: """ Returns the mean value of the tensor along the specified axis or axes. @@ -1754,9 +1753,10 @@ class Tensor(SimpleMathTrait): """ output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32 numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim) - return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])).cast(output_dtype) + return numerator.div(prod([cast(int, si) for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])) \ + .cast(output_dtype) - def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1): + def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor: """ Returns the variance of the tensor along the specified axis or axes. @@ -1782,7 +1782,7 @@ class Tensor(SimpleMathTrait): n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)]) return squares.sum(axis=axis, keepdim=keepdim).div(smax([0, n-correction])) - def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1): + def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]: """ Calculates the variance and mean over the dimensions specified by dim. Syntactic sugar around `Tensor.var` and `Tensor.mean` to match `torch.var_mean`. @@ -1799,7 +1799,7 @@ class Tensor(SimpleMathTrait): """ return self.var(axis, keepdim, correction), self.mean(axis, keepdim) - def std(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1): + def std(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor: """ Returns the standard deviation of the tensor along the specified axis or axes. @@ -1823,7 +1823,7 @@ class Tensor(SimpleMathTrait): """ return self.var(axis, keepdim, correction).sqrt() - def std_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1): + def std_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]: """ Calculates the standard deviation and mean over the dimensions specified by dim. Syntactic sugar around `Tensor.std` and `Tensor.mean` to match `torch.std_mean`. @@ -1840,13 +1840,13 @@ class Tensor(SimpleMathTrait): """ return self.std(axis, keepdim, correction), self.mean(axis, keepdim) - def _softmax(self, axis, dtype:DTypeLike|None=None): + def _softmax(self, axis, dtype:DTypeLike|None=None) -> tuple[Tensor, Tensor, Tensor]: m = self - self.max(axis=axis, keepdim=True).detach() if dtype is not None: m = m.cast(dtype) e = m.exp() return m, e, e.sum(axis=axis, keepdim=True) - def softmax(self, axis=-1, dtype:DTypeLike|None=None): + def softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor: """ Applies the softmax function to the tensor along the specified axis. @@ -1869,7 +1869,7 @@ class Tensor(SimpleMathTrait): _, e, ss = self._softmax(axis, dtype) return e.div(ss) - def log_softmax(self, axis=-1, dtype:DTypeLike|None=None): + def log_softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor: """ Applies the log-softmax function to the tensor along the specified axis. @@ -1892,7 +1892,7 @@ class Tensor(SimpleMathTrait): m, _, ss = self._softmax(axis, dtype) return m - ss.log() - def logsumexp(self, axis=None, keepdim=False): + def logsumexp(self, axis=None, keepdim=False) -> Tensor: """ Computes the log-sum-exp of the tensor along the specified axis or axes. @@ -1919,7 +1919,7 @@ class Tensor(SimpleMathTrait): m = self.max(axis=axis, keepdim=True) return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis) - def logcumsumexp(self, axis=0): + def logcumsumexp(self, axis=0) -> Tensor: """ Computes the log-cumsum-exp of the tensor along the specified axis or axes. @@ -1954,7 +1954,7 @@ class Tensor(SimpleMathTrait): ret = ((x_expand - x_cummax).exp() * mask).sum(-1).log() + x_cummax.squeeze(-1) return ret.reshape(*x.shape).transpose(-1, axis) - def argmax(self, axis=None, keepdim=False): + def argmax(self, axis=None, keepdim=False) -> Tensor: """ Returns the indices of the maximum value of the tensor along the specified axis. @@ -1981,7 +1981,7 @@ class Tensor(SimpleMathTrait): idx = m * Tensor.arange(self.shape[axis],0,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1)) return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)).cast(dtypes.int32) - def argmin(self, axis=None, keepdim=False): + def argmin(self, axis=None, keepdim=False) -> Tensor: """ Returns the indices of the minimum value of the tensor along the specified axis. @@ -2353,7 +2353,7 @@ class Tensor(SimpleMathTrait): ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0), value=identity_element(op, self.dtype)).unflatten(-1, (-1, SPLIT))._cumalu(-1, op) base = ret[..., -1]._cumalu(-1, op, _include_initial=True) base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1]) - def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1) + def fix(x: Tensor) -> Tensor: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1) return fix(ret) + fix(base) if op is Ops.ADD else fix(ret).maximum(fix(base)) def cumsum(self, axis:int=0) -> Tensor: @@ -2565,7 +2565,7 @@ class Tensor(SimpleMathTrait): # ***** unary ops ***** - def logical_not(self): + def logical_not(self) -> Tensor: """ Computes the logical NOT of the tensor element-wise. @@ -2574,7 +2574,8 @@ class Tensor(SimpleMathTrait): ``` """ return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True) - def neg(self): + + def neg(self) -> Tensor: """ Negates the tensor element-wise. @@ -2583,17 +2584,20 @@ class Tensor(SimpleMathTrait): ``` """ return self*-1 if self.dtype != dtypes.bool else self.logical_not() - def contiguous(self): + + def contiguous(self) -> Tensor: """ Returns a contiguous tensor. """ return self._apply_uop(UOp.contiguous) - def contiguous_backward(self): + + def contiguous_backward(self) -> Tensor: """ Inserts a contiguous operation in the backward pass. """ return self._apply_uop(UOp.contiguous_backward) - def log(self): + + def log(self) -> Tensor: """ Computes the natural logarithm element-wise. @@ -2604,7 +2608,8 @@ class Tensor(SimpleMathTrait): ``` """ return self.log2()*math.log(2) - def log2(self): + + def log2(self) -> Tensor: """ Computes the base-2 logarithm element-wise. @@ -2615,7 +2620,8 @@ class Tensor(SimpleMathTrait): ``` """ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2) - def exp(self): + + def exp(self) -> Tensor: """ Computes the exponential function element-wise. @@ -2626,7 +2632,8 @@ class Tensor(SimpleMathTrait): ``` """ return self.mul(1/math.log(2)).exp2() - def exp2(self): + + def exp2(self) -> Tensor: """ Computes the base-2 exponential function element-wise. @@ -2637,7 +2644,8 @@ class Tensor(SimpleMathTrait): ``` """ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2) - def relu(self): + + def relu(self) -> Tensor: """ Applies the Rectified Linear Unit (ReLU) function element-wise. @@ -2649,7 +2657,7 @@ class Tensor(SimpleMathTrait): """ return (self>0).where(self, 0) - def sigmoid(self): + def sigmoid(self) -> Tensor: """ Applies the Sigmoid function element-wise. @@ -2661,7 +2669,7 @@ class Tensor(SimpleMathTrait): """ return (1 + (self * (-1/math.log(2))).exp2()).reciprocal() - def hardsigmoid(self, alpha:float=1/6, beta:float=0.5): + def hardsigmoid(self, alpha:float=1/6, beta:float=0.5) -> Tensor: """ Applies the Hardsigmoid function element-wise. NOTE: default `alpha` and `beta` values is taken from torch @@ -2675,7 +2683,7 @@ class Tensor(SimpleMathTrait): """ return (alpha * self + beta).relu() - (alpha * self + beta - 1).relu() - def sqrt(self): + def sqrt(self) -> Tensor: """ Computes the square root of the tensor element-wise. @@ -2684,7 +2692,8 @@ class Tensor(SimpleMathTrait): ``` """ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt) - def rsqrt(self): + + def rsqrt(self) -> Tensor: """ Computes the reciprocal of the square root of the tensor element-wise. @@ -2693,7 +2702,8 @@ class Tensor(SimpleMathTrait): ``` """ return self.sqrt().reciprocal() - def sin(self): + + def sin(self) -> Tensor: """ Computes the sine of the tensor element-wise. @@ -2702,7 +2712,8 @@ class Tensor(SimpleMathTrait): ``` """ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin) - def cos(self): + + def cos(self) -> Tensor: """ Computes the cosine of the tensor element-wise. @@ -2711,7 +2722,8 @@ class Tensor(SimpleMathTrait): ``` """ return ((math.pi/2)-self).sin() - def tan(self): + + def tan(self) -> Tensor: """ Computes the tangent of the tensor element-wise. @@ -2721,7 +2733,7 @@ class Tensor(SimpleMathTrait): """ return self.sin() / self.cos() - def asin(self): + def asin(self) -> Tensor: """ Computes the inverse sine (arcsine) of the tensor element-wise. @@ -2734,7 +2746,7 @@ class Tensor(SimpleMathTrait): x = math.pi / 2 - (1.0 - self.abs()).sqrt() * polyN(self.abs(), coefficients) return self.sign() * x - def acos(self): + def acos(self) -> Tensor: """ Computes the inverse cosine (arccosine) of the tensor element-wise. @@ -2744,7 +2756,7 @@ class Tensor(SimpleMathTrait): """ return math.pi / 2 - self.asin() - def atan(self): + def atan(self) -> Tensor: """ Computes the inverse tangent (arctan) of the tensor element-wise. @@ -2765,6 +2777,7 @@ class Tensor(SimpleMathTrait): ``` """ return self.cast(dtypes.int32).cast(self.dtype) + def ceil(self: Tensor) -> Tensor: """ Rounds the tensor element-wise towards positive infinity. @@ -2774,6 +2787,7 @@ class Tensor(SimpleMathTrait): ``` """ return (self > (b := self.trunc())).where(b+1, b) + def floor(self: Tensor) -> Tensor: """ Rounds the tensor element-wise towards negative infinity. @@ -2783,6 +2797,7 @@ class Tensor(SimpleMathTrait): ``` """ return (self < (b := self.trunc())).where(b-1, b) + def round(self: Tensor) -> Tensor: """ Rounds the tensor element-wise with rounding half to even. @@ -2802,6 +2817,7 @@ class Tensor(SimpleMathTrait): ``` """ return (self == float("inf")) * detect_positive + (self == float("-inf")) * detect_negative + def isnan(self:Tensor) -> Tensor: """ Checks the tensor element-wise to return True where the element is NaN, otherwise returns False @@ -2811,6 +2827,7 @@ class Tensor(SimpleMathTrait): ``` """ return self != self + def isfinite(self:Tensor) -> Tensor: """ Checks the tensor element-wise to return True where the element is finite, otherwise returns False @@ -2834,7 +2851,7 @@ class Tensor(SimpleMathTrait): return (self+(((end - self).cast(dtypes.int8) * w_i + (1<> W_PREC)).cast(dtypes.uint8) return self + (end - self) * weight - def square(self): + def square(self) -> Tensor: """ Squares the tensor element-wise. Equivalent to `self*self`. @@ -2844,7 +2861,8 @@ class Tensor(SimpleMathTrait): ``` """ return self*self - def clamp(self, min_=None, max_=None): + + def clamp(self, min_=None, max_=None) -> Tensor: """ Clips (clamps) the values in the tensor between `min_` and `max_` element-wise. If `min_` is `None`, there is no lower bound. If `max_` is None, there is no upper bound. @@ -2856,12 +2874,14 @@ class Tensor(SimpleMathTrait): if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None") ret = self.maximum(min_) if min_ is not None else self return ret.minimum(max_) if max_ is not None else ret - def clip(self, min_=None, max_=None): + + def clip(self, min_=None, max_=None) -> Tensor: """ Alias for `Tensor.clamp`. """ return self.clamp(min_, max_) - def sign(self): + + def sign(self) -> Tensor: """ Returns the sign of the tensor element-wise. @@ -2870,7 +2890,8 @@ class Tensor(SimpleMathTrait): ``` """ return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0 - def abs(self): + + def abs(self) -> Tensor: """ Computes the absolute value of the tensor element-wise. @@ -2879,7 +2900,8 @@ class Tensor(SimpleMathTrait): ``` """ return self * self.sign() - def reciprocal(self): + + def reciprocal(self) -> Tensor: """ Compute `1/x` element-wise. @@ -2891,7 +2913,7 @@ class Tensor(SimpleMathTrait): # ***** activation functions ***** - def elu(self, alpha=1.0): + def elu(self, alpha=1.0) -> Tensor: """ Applies the Exponential Linear Unit (ELU) function element-wise. @@ -2904,7 +2926,7 @@ class Tensor(SimpleMathTrait): """ return self.relu() - alpha*(1-self.exp()).relu() - def celu(self, alpha=1.0): + def celu(self, alpha=1.0) -> Tensor: """ Applies the Continuously differentiable Exponential Linear Unit (CELU) function element-wise. @@ -2917,7 +2939,7 @@ class Tensor(SimpleMathTrait): """ return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0) - def selu(self, alpha=1.67326, gamma=1.0507): + def selu(self, alpha=1.67326, gamma=1.0507) -> Tensor: """ Applies the Scaled Exponential Linear Unit (SELU) function element-wise. @@ -2930,7 +2952,7 @@ class Tensor(SimpleMathTrait): """ return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1)) - def swish(self): + def swish(self) -> Tensor: """ See `.silu()` @@ -2942,7 +2964,7 @@ class Tensor(SimpleMathTrait): """ return self * self.sigmoid() - def silu(self): + def silu(self) -> Tensor: """ Applies the Sigmoid Linear Unit (SiLU) function element-wise. @@ -2955,7 +2977,7 @@ class Tensor(SimpleMathTrait): """ return self.swish() # The SiLU function is also known as the swish function. - def relu6(self): + def relu6(self) -> Tensor: """ Applies the ReLU6 function element-wise. @@ -2968,7 +2990,7 @@ class Tensor(SimpleMathTrait): """ return self.relu() - (self-6).relu() - def hardswish(self): + def hardswish(self) -> Tensor: """ Applies the Hardswish function element-wise. @@ -2981,7 +3003,7 @@ class Tensor(SimpleMathTrait): """ return self * (self+3).relu6() * (1/6) - def tanh(self): + def tanh(self) -> Tensor: """ Applies the Hyperbolic Tangent (tanh) function element-wise. @@ -2993,7 +3015,7 @@ class Tensor(SimpleMathTrait): """ return 2.0 * ((2.0 * self).sigmoid()) - 1.0 - def sinh(self): + def sinh(self) -> Tensor: """ Applies the Hyperbolic Sine (sinh) function element-wise. @@ -3005,7 +3027,7 @@ class Tensor(SimpleMathTrait): """ return (self.exp() - self.neg().exp()) / 2 - def cosh(self): + def cosh(self) -> Tensor: """ Applies the Hyperbolic Cosine (cosh) function element-wise. @@ -3017,7 +3039,7 @@ class Tensor(SimpleMathTrait): """ return (self.exp() + self.neg().exp()) / 2 - def atanh(self): + def atanh(self) -> Tensor: """ Applies the Inverse Hyperbolic Tangent (atanh) function element-wise. @@ -3029,7 +3051,7 @@ class Tensor(SimpleMathTrait): """ return ((1 + self)/(1 - self)).log() / 2 - def asinh(self): + def asinh(self) -> Tensor: """ Applies the Inverse Hyperbolic Sine (asinh) function element-wise. @@ -3041,7 +3063,7 @@ class Tensor(SimpleMathTrait): """ return (self + (self.square() + 1).sqrt()).log() - def acosh(self): + def acosh(self) -> Tensor: """ Applies the Inverse Hyperbolic Cosine (acosh) function element-wise. @@ -3053,7 +3075,7 @@ class Tensor(SimpleMathTrait): """ return (self + (self.square() - 1).sqrt()).log() - def hardtanh(self, min_val=-1, max_val=1): + def hardtanh(self, min_val=-1, max_val=1) -> Tensor: """ Applies the Hardtanh function element-wise. @@ -3065,7 +3087,7 @@ class Tensor(SimpleMathTrait): """ return self.clip(min_val, max_val) - def erf(self): + def erf(self) -> Tensor: """ Applies error function element-wise. @@ -3079,7 +3101,7 @@ class Tensor(SimpleMathTrait): t = 1.0 / (1.0 + 0.3275911 * self.abs()) return self.sign() * (1.0 - t * polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]) * (-self.square()).exp()) - def gelu(self): + def gelu(self) -> Tensor: """ Applies the Gaussian Error Linear Unit (GELU) function element-wise. @@ -3092,7 +3114,7 @@ class Tensor(SimpleMathTrait): """ return 0.5 * self * (1 + (math.sqrt(2 / math.pi) * (self + 0.044715 * self ** 3)).tanh()) - def quick_gelu(self): + def quick_gelu(self) -> Tensor: """ Applies the Sigmoid GELU approximation element-wise. @@ -3104,7 +3126,7 @@ class Tensor(SimpleMathTrait): """ return self * (self * 1.702).sigmoid() - def leaky_relu(self, neg_slope=0.01): + def leaky_relu(self, neg_slope=0.01) -> Tensor: """ Applies the Leaky ReLU function element-wise. @@ -3119,7 +3141,7 @@ class Tensor(SimpleMathTrait): """ return (self<0).where(neg_slope*self, self) - def mish(self): + def mish(self) -> Tensor: """ Applies the Mish function element-wise. @@ -3132,7 +3154,7 @@ class Tensor(SimpleMathTrait): """ return self * self.softplus().tanh() - def softplus(self, beta=1): + def softplus(self, beta=1) -> Tensor: """ Applies the Softplus function element-wise. @@ -3144,7 +3166,7 @@ class Tensor(SimpleMathTrait): """ return (1/beta) * (1 + (self*beta).exp()).log() - def softsign(self): + def softsign(self) -> Tensor: """ Applies the Softsign function element-wise. @@ -3360,7 +3382,7 @@ class Tensor(SimpleMathTrait): if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported") return self.logical_not() if self.dtype == dtypes.bool else self ^ -1 - def lshift(self, x:int): + def lshift(self, x:int) -> Tensor: """ Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype. Equivalent to `self << x`. @@ -3372,7 +3394,7 @@ class Tensor(SimpleMathTrait): assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}" return self.mul(2 ** x) - def rshift(self, x:int): + def rshift(self, x:int) -> Tensor: """ Computes right arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype. Equivalent to `self >> x`. @@ -3434,7 +3456,7 @@ class Tensor(SimpleMathTrait): t, x = self._broadcasted(x) return t._inverse().maximum(x._inverse())._inverse() - def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint): + def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor: """ Return a tensor of elements selected from either `x` or `y`, depending on `self`. `output_i = x_i if self_i else y_i`. @@ -3458,7 +3480,7 @@ class Tensor(SimpleMathTrait): cond, y = cond._broadcasted(y, match_dtype=False) return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y)) - def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType): return mask.where(value, self) + def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType) -> Tensor: return mask.where(value, self) def copysign(self, other) -> Tensor: """ @@ -3503,7 +3525,7 @@ class Tensor(SimpleMathTrait): # ***** functional nn ops ***** - def linear(self, weight:Tensor, bias:Tensor|None=None): + def linear(self, weight:Tensor, bias:Tensor|None=None) -> Tensor: """ Applies a linear transformation to `self` using `weight` and `bias`. @@ -3519,7 +3541,7 @@ class Tensor(SimpleMathTrait): x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight) return x.add(bias) if bias is not None else x - def sequential(self, ll:list[Callable[[Tensor], Tensor]]): + def sequential(self, ll:list[Callable[[Tensor], Tensor]]) -> Tensor: """ Applies a sequence of functions to `self` chaining the output of each function to the input of the next. @@ -3592,7 +3614,7 @@ class Tensor(SimpleMathTrait): return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p) # helper function commonly used for indexing - def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1): + def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1) -> Tensor: if not dtypes.is_int(self.dtype): raise RuntimeError(f"_one_hot_along_dim expects int index tensor, getting {self.dtype}") offset = self.ndim - self._resolve_dim(dim) - 1 return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset) @@ -3821,7 +3843,7 @@ class Tensor(SimpleMathTrait): # ***** cast ops ***** - def llvm_bf16_cast(self, dtype:DTypeLike): + def llvm_bf16_cast(self, dtype:DTypeLike) -> Tensor: # hack for devices that don't support bfloat16 assert self.dtype == dtypes.bfloat16 return self.to("LLVM").cast(dtype) @@ -4011,8 +4033,10 @@ class Tensor(SimpleMathTrait): ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2) return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1)) -def _metadata_wrapper(fn): - def _wrapper(*args, **kwargs): +P = ParamSpec("P") +T = TypeVar("T") +def _metadata_wrapper(fn: Callable[P, T]) -> Callable[P, T]: + def _wrapper(*args: P.args, **kwargs: P.kwargs) -> T: if _METADATA.get() is not None: return fn(*args, **kwargs) if TRACEMETA >= 2: