mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
Tensor.isfinite (#9316)
This commit is contained in:
@@ -404,6 +404,9 @@ class TestOps(unittest.TestCase):
|
||||
def test_isnan(self):
|
||||
helper_test_op(None, torch.isnan, Tensor.isnan, vals=[[float('-inf'), 0., float('inf'), float('nan'), 1.1]], forward_only=True)
|
||||
|
||||
def test_isfinite(self):
|
||||
helper_test_op(None, torch.isfinite, Tensor.isfinite, vals=[[float('-inf'), 0., float('inf'), float('nan'), 1.1]], forward_only=True)
|
||||
|
||||
def test_lerp(self):
|
||||
helper_test_op([(45,35), (45,35), (45,35)], lambda x,y,z: x.lerp(y,z))
|
||||
helper_test_op(None, lambda x,y,z: x.lerp(y,z), vals=[[1.,2.,3.], [4.,5.,6.], 0.5])
|
||||
|
||||
@@ -1725,9 +1725,7 @@ class Tensor(SimpleMathTrait):
|
||||
print(Tensor([float('nan')]).isclose(Tensor([float('nan')]), equal_nan=True).numpy())
|
||||
```
|
||||
"""
|
||||
# TODO: Tensor.isfinite
|
||||
def isfinite(t): return (t.isinf()|t.isnan()).logical_not()
|
||||
is_finite_close = isfinite(self) & isfinite(other) & ((self - other).abs() <= atol + rtol * other.abs())
|
||||
is_finite_close = self.isfinite() & other.isfinite() & ((self - other).abs() <= atol + rtol * other.abs())
|
||||
is_infinite_close = (self.isinf() | other.isinf()) & (self == other)
|
||||
is_nan_close = (self.isnan() & other.isnan()) & equal_nan
|
||||
return is_finite_close | is_infinite_close | is_nan_close
|
||||
@@ -2793,7 +2791,7 @@ class Tensor(SimpleMathTrait):
|
||||
"""
|
||||
return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
|
||||
|
||||
def isinf(self:Tensor, detect_positive:bool=True, detect_negative:bool=True):
|
||||
def isinf(self:Tensor, detect_positive:bool=True, detect_negative:bool=True) -> Tensor:
|
||||
"""
|
||||
Checks the tensor element-wise to return True where the element is infinity, otherwise returns False
|
||||
|
||||
@@ -2802,7 +2800,7 @@ class Tensor(SimpleMathTrait):
|
||||
```
|
||||
"""
|
||||
return (self == float("inf")) * detect_positive + (self == float("-inf")) * detect_negative
|
||||
def isnan(self:Tensor):
|
||||
def isnan(self:Tensor) -> Tensor:
|
||||
"""
|
||||
Checks the tensor element-wise to return True where the element is NaN, otherwise returns False
|
||||
|
||||
@@ -2811,6 +2809,15 @@ class Tensor(SimpleMathTrait):
|
||||
```
|
||||
"""
|
||||
return self != self
|
||||
def isfinite(self:Tensor) -> Tensor:
|
||||
"""
|
||||
Checks the tensor element-wise to return True where the element is finite, otherwise returns False
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isfinite().numpy())
|
||||
```
|
||||
"""
|
||||
return (self.isinf()|self.isnan()).logical_not()
|
||||
|
||||
def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user