diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index 3a29735056..0ddd673caa 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -167,7 +167,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]: def __setstate__(self, state): self.tensor = state[0] deserialized_objects: Dict[str, Any] = {} - intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32, + intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, + "IntStorage": dtypes.int32, "BoolStorage": dtypes.bool, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter} whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed class Dummy: pass diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index cffff7340f..ce79c630e6 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2173,7 +2173,7 @@ class Tensor: ``` """ return self*self - def clip(self, min_=None, max_=None): + def clamp(self, min_=None, max_=None): """ Clips (clamps) the values in the tensor between `min_` and `max_` element-wise. If `min_` is `None`, there is no lower bound. If `max_` is None, there is no upper bound. @@ -2185,6 +2185,11 @@ class Tensor: if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None") ret = self.maximum(min_) if min_ is not None else self return ret.minimum(max_) if max_ is not None else ret + def clip(self, min_=None, max_=None): + """ + Alias for `Tensor.clamp`. + """ + return self.clamp(min_, max_) def sign(self): """ Returns the sign of the tensor element-wise.