mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
add Tensor.clamp and fix bool loading (#6069)
This commit is contained in:
@@ -167,7 +167,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
|
||||
def __setstate__(self, state): self.tensor = state[0]
|
||||
|
||||
deserialized_objects: Dict[str, Any] = {}
|
||||
intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32,
|
||||
intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16,
|
||||
"IntStorage": dtypes.int32, "BoolStorage": dtypes.bool,
|
||||
"LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
|
||||
whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
|
||||
class Dummy: pass
|
||||
|
||||
@@ -2173,7 +2173,7 @@ class Tensor:
|
||||
```
|
||||
"""
|
||||
return self*self
|
||||
def clip(self, min_=None, max_=None):
|
||||
def clamp(self, min_=None, max_=None):
|
||||
"""
|
||||
Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.
|
||||
If `min_` is `None`, there is no lower bound. If `max_` is None, there is no upper bound.
|
||||
@@ -2185,6 +2185,11 @@ class Tensor:
|
||||
if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None")
|
||||
ret = self.maximum(min_) if min_ is not None else self
|
||||
return ret.minimum(max_) if max_ is not None else ret
|
||||
def clip(self, min_=None, max_=None):
|
||||
"""
|
||||
Alias for `Tensor.clamp`.
|
||||
"""
|
||||
return self.clamp(min_, max_)
|
||||
def sign(self):
|
||||
"""
|
||||
Returns the sign of the tensor element-wise.
|
||||
|
||||
Reference in New Issue
Block a user