Tensor.erf (#7419)

the same one used in onnx and the one in bert.
This commit is contained in:
chenyu
2024-10-30 18:12:28 -04:00
committed by GitHub
parent e955aa1bee
commit fb694a63eb
5 changed files with 25 additions and 20 deletions

View File

@@ -42,6 +42,7 @@ Elementwise ops operate on a per element basis. They don't change the shape of t
::: tinygrad.Tensor.asinh
::: tinygrad.Tensor.acosh
::: tinygrad.Tensor.hardtanh
::: tinygrad.Tensor.erf
::: tinygrad.Tensor.gelu
::: tinygrad.Tensor.quick_gelu
::: tinygrad.Tensor.leakyrelu

View File

@@ -231,12 +231,7 @@ class BertOutput:
return hidden_states
def gelu(x):
return x * 0.5 * (1.0 + erf(x / 1.41421))
# approximation of the error function
def erf(x):
t = (1 + 0.3275911 * x.abs()).reciprocal()
return x.sign() * (1 - ((((1.061405429 * t + -1.453152027) * t + 1.421413741) * t + -0.284496736) * t + 0.254829592) * t * (-(x.square())).exp())
return x * 0.5 * (1.0 + (x / 1.41421).erf())
class BertIntermediate:
def __init__(self, hidden_size, intermediate_size):

View File

@@ -8,7 +8,7 @@ import numpy as np
tensor_methods = {"Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Relu", "Sigmoid", "MatMul",
"Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign", "Asinh", "Acosh", "Atanh",
"Elu", "Celu", "Xor", "Round"}
"Elu", "Celu", "Xor", "Round", "Erf"}
# **************** Free Ops ****************
@@ -43,7 +43,7 @@ def Constant(value:Optional[Tensor]=None, value_float=None, value_floats=None, v
if value_string is not None or value_strings is not None: raise NotImplementedError('value_string or value_strings not implemented for Constant op')
def HardSigmoid(x: Tensor, alpha=0.2, beta=0.5): return (alpha*x + beta).clip(0, 1)
def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + Erf(x/math.sqrt(2)))
def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
def PRelu(X:Tensor, slope:Tensor):
slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
@@ -505,17 +505,6 @@ def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1):
cond = indices[:,None] == Tensor.arange(depth).reshape((1,) * len(ls) + (depth,) + (1,) * len(rs))
return cond.where(values[1], values[0])
def Erf(x: Tensor):
t = 1.0 / (1.0 + 0.3275911 * x.abs())
term1 = 0.254829592 * t
term2 = -0.284496736 * t ** 2
term3 = 1.421413741 * t ** 3
term4 = -1.453152027 * t ** 4
term5 = 1.061405429 * t ** 5
y = (term1 + term2 + term3 + term4 + term5)
z = 1.0 - y * (-x * x).exp()
return (x > 0).where(z, -z)
def Compress(inp: Tensor, condition: Tensor, axis=None):
if axis is None:
inp = inp.flatten()

View File

@@ -647,6 +647,12 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6)
helper_test_op([()], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6)
def test_erf(self):
helper_test_op([(45,65)], torch.erf, Tensor.erf)
helper_test_op([(45,65)], torch.erf, Tensor.erf, low=300, high=400)
helper_test_op([(45,65)], torch.erf, Tensor.erf, low=-400, high=-300)
helper_test_op([()], torch.erf, Tensor.erf)
def test_gelu(self):
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu)
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, low=300, high=400)

View File

@@ -7,7 +7,7 @@ from collections import defaultdict
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN
from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, UOps, BinaryOps, sint, Variable, SimpleMathTrait
from tinygrad.device import Device, Buffer, BufferOptions
@@ -2651,6 +2651,20 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
"""
return self.clip(min_val, max_val)
def erf(self):
"""
Applies error function element-wise.
- Described: https://en.wikipedia.org/wiki/Error_function
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).erf().numpy())
```
"""
# https://personal.math.ubc.ca/~cbm/aands/page_299.htm 7.1.26
t = 1.0 / (1.0 + 0.3275911 * self.abs())
return self.sign() * (1.0 - t * polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]) * (-self.square()).exp())
def gelu(self):
"""
Applies the Gaussian Error Linear Unit (GELU) function element-wise.