add Tensor.clamp and fix bool loading (#6069)

This commit is contained in:
George Hotz
2024-08-13 15:26:40 -07:00
committed by GitHub
parent 1782e4f64d
commit e039b2a920
2 changed files with 8 additions and 2 deletions

View File

@@ -167,7 +167,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
def __setstate__(self, state): self.tensor = state[0]
deserialized_objects: Dict[str, Any] = {}
intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32,
intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16,
"IntStorage": dtypes.int32, "BoolStorage": dtypes.bool,
"LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
class Dummy: pass

View File

@@ -2173,7 +2173,7 @@ class Tensor:
```
"""
return self*self
def clip(self, min_=None, max_=None):
def clamp(self, min_=None, max_=None):
"""
Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.
If `min_` is `None`, there is no lower bound. If `max_` is None, there is no upper bound.
@@ -2185,6 +2185,11 @@ class Tensor:
if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None")
ret = self.maximum(min_) if min_ is not None else self
return ret.minimum(max_) if max_ is not None else ret
def clip(self, min_=None, max_=None):
"""
Alias for `Tensor.clamp`.
"""
return self.clamp(min_, max_)
def sign(self):
"""
Returns the sign of the tensor element-wise.