mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-10 15:06:18 +08:00
@@ -42,6 +42,7 @@ Elementwise ops operate on a per element basis. They don't change the shape of t
|
||||
::: tinygrad.Tensor.asinh
|
||||
::: tinygrad.Tensor.acosh
|
||||
::: tinygrad.Tensor.hardtanh
|
||||
::: tinygrad.Tensor.erf
|
||||
::: tinygrad.Tensor.gelu
|
||||
::: tinygrad.Tensor.quick_gelu
|
||||
::: tinygrad.Tensor.leakyrelu
|
||||
|
||||
@@ -231,12 +231,7 @@ class BertOutput:
|
||||
return hidden_states
|
||||
|
||||
def gelu(x):
  """Gaussian Error Linear Unit of `x`, built on the module-level `erf` helper.

  Evaluates `x * 0.5 * (1 + erf(x / 1.41421))`; the literal 1.41421 is a
  truncated decimal value of sqrt(2).
  """
  scaled = x / 1.41421
  return x * 0.5 * (1.0 + erf(scaled))
|
||||
|
||||
# approximation of the error function
|
||||
def erf(x):
  """Approximate the error function of `x` element-wise.

  `x` must provide the `abs`, `sign`, `square`, `reciprocal` and `exp` methods
  used below. The magnitude is computed from `|x|` and `x**2`, and the sign of
  `x` is re-applied at the end, so the result is odd in `x` and is 0 where
  `x.sign()` is 0.
  """
  t = (1 + 0.3275911 * x.abs()).reciprocal()
  # Horner evaluation of the degree-4 polynomial in t, highest coefficient first.
  poly = 1.061405429
  for coeff in (-1.453152027, 1.421413741, -0.284496736, 0.254829592):
    poly = poly * t + coeff
  return x.sign() * (1 - poly * t * (-(x.square())).exp())
|
||||
return x * 0.5 * (1.0 + (x / 1.41421).erf())
|
||||
|
||||
class BertIntermediate:
|
||||
def __init__(self, hidden_size, intermediate_size):
|
||||
|
||||
@@ -8,7 +8,7 @@ import numpy as np
|
||||
|
||||
tensor_methods = {"Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Relu", "Sigmoid", "MatMul",
|
||||
"Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign", "Asinh", "Acosh", "Atanh",
|
||||
"Elu", "Celu", "Xor", "Round"}
|
||||
"Elu", "Celu", "Xor", "Round", "Erf"}
|
||||
|
||||
# **************** Free Ops ****************
|
||||
|
||||
@@ -43,7 +43,7 @@ def Constant(value:Optional[Tensor]=None, value_float=None, value_floats=None, v
|
||||
if value_string is not None or value_strings is not None: raise NotImplementedError('value_string or value_strings not implemented for Constant op')
|
||||
|
||||
def HardSigmoid(x: Tensor, alpha=0.2, beta=0.5):
  """Hard sigmoid: the affine map `alpha * x + beta` clipped to the range [0, 1].

  `alpha` is the slope and `beta` the offset of the affine map.
  """
  affine = alpha*x + beta
  return affine.clip(0, 1)
|
||||
def Gelu(x:Tensor, approximate=None):
  """GELU activation with an optional tanh approximation.

  When `approximate == "tanh"` the work is delegated to `x.gelu()`. For any
  other value the erf form `0.5 * x * (1 + Erf(x / sqrt(2)))` is evaluated
  with the module-level `Erf` op.
  """
  if approximate == "tanh":
    return x.gelu()
  return 0.5 * x * (1 + Erf(x/math.sqrt(2)))
|
||||
def Gelu(x:Tensor, approximate=None):
  """GELU activation with an optional tanh approximation.

  When `approximate == "tanh"` the work is delegated to `x.gelu()`. For any
  other value the erf form `0.5 * x * (1 + erf(x / sqrt(2)))` is evaluated
  through the tensor's own `erf` method.
  """
  if approximate == "tanh":
    return x.gelu()
  return 0.5 * x * (1 + (x/math.sqrt(2)).erf())
|
||||
def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875):
  """Scaled exponential linear unit.

  Computes `gamma * (relu(X) - relu(alpha - alpha * exp(X)))`. The first relu
  keeps the positive part of `X`; the second is non-zero only where
  `exp(X) < 1`, i.e. for negative `X`, and supplies the exponential branch.
  """
  positive_part = X.relu()
  negative_part = (-alpha*X.exp()+alpha).relu()
  return gamma * (positive_part - negative_part)
|
||||
def PRelu(X:Tensor, slope:Tensor):
|
||||
slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
|
||||
@@ -505,17 +505,6 @@ def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1):
|
||||
cond = indices[:,None] == Tensor.arange(depth).reshape((1,) * len(ls) + (depth,) + (1,) * len(rs))
|
||||
return cond.where(values[1], values[0])
|
||||
|
||||
def Erf(x: Tensor):
  """Element-wise approximation of the error function.

  Builds a five-term polynomial in `t = 1 / (1 + 0.3275911 * |x|)`, scales it
  by `exp(-x * x)` and subtracts the product from 1 to get the magnitude. The
  final `where` returns that magnitude where `x > 0` and its negation
  everywhere else (including `x == 0`).
  """
  t = 1.0 / (1.0 + 0.3275911 * x.abs())
  # Accumulate the terms in ascending power order, left to right.
  series = 0.254829592 * t
  higher_terms = (-0.284496736, 1.421413741, -1.453152027, 1.061405429)
  for power, coeff in enumerate(higher_terms, start=2):
    series = series + coeff * t ** power
  magnitude = 1.0 - series * (-x * x).exp()
  return (x > 0).where(magnitude, -magnitude)
|
||||
|
||||
def Compress(inp: Tensor, condition: Tensor, axis=None):
|
||||
if axis is None:
|
||||
inp = inp.flatten()
|
||||
|
||||
@@ -647,6 +647,12 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6)
|
||||
helper_test_op([()], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6)
|
||||
|
||||
def test_erf(self):
  # Compare Tensor.erf against torch.erf: a (45,65) input with the helper's
  # default arguments, the same shape with low/high set to a large positive
  # and a large negative range, and finally a shape-() input.
  cases = (
    ((45,65), {}),
    ((45,65), {"low": 300, "high": 400}),
    ((45,65), {"low": -400, "high": -300}),
    ((), {}),
  )
  for shape, bounds in cases:
    helper_test_op([shape], torch.erf, Tensor.erf, **bounds)
|
||||
|
||||
def test_gelu(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu)
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, low=300, high=400)
|
||||
|
||||
@@ -7,7 +7,7 @@ from collections import defaultdict
|
||||
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
|
||||
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
|
||||
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch
|
||||
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN
|
||||
from tinygrad.multi import MultiLazyBuffer
|
||||
from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, UOps, BinaryOps, sint, Variable, SimpleMathTrait
|
||||
from tinygrad.device import Device, Buffer, BufferOptions
|
||||
@@ -2651,6 +2651,20 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
"""
|
||||
return self.clip(min_val, max_val)
|
||||
|
||||
def erf(self):
  """
  Computes the error function of each element of the tensor.

  The result is a numerical approximation built from `abs`, `sign`, `square`, `exp` and a polynomial.

  - Described: https://en.wikipedia.org/wiki/Error_function

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).erf().numpy())
  ```
  """
  # Formula 7.1.26 from https://personal.math.ubc.ca/~cbm/aands/page_299.htm
  # It is evaluated on |x| and x**2; multiplying by sign(x) restores the sign of the input.
  coeffs = [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]
  t = 1.0 / (1.0 + 0.3275911 * self.abs())
  correction = t * polyN(t, coeffs) * (-self.square()).exp()
  return self.sign() * (1.0 - correction)
|
||||
|
||||
def gelu(self):
|
||||
"""
|
||||
Applies the Gaussian Error Linear Unit (GELU) function element-wise.
|
||||
|
||||
Reference in New Issue
Block a user