diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index f13c5f23d5..d20428ac10 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -108,6 +108,14 @@ class TestSafetensors(unittest.TestCase): import json assert json.loads(dat[8:8+sz])['__metadata__']['hello'] == 'world' + def test_save_all_dtypes(self): + for dtype in dtypes.fields().values(): + if dtype in [dtypes.bfloat16, dtypes._arg_int32]: continue # not supported in numpy + path = temp("ones.safetensors") + ones = Tensor.rand((10,10), dtype=dtype) + safe_save(get_state_dict(ones), path) + assert ones == list(safe_load(path).values())[0] + def helper_test_disk_tensor(fn, data, np_fxn, tinygrad_fxn=None): if tinygrad_fxn is None: tinygrad_fxn = np_fxn pathlib.Path(temp(fn)).unlink(missing_ok=True) diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index 8350c5e91f..bc97ea83be 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -5,7 +5,7 @@ from tinygrad.tensor import Tensor from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI, unwrap from tinygrad.shape.view import strides_for_shape -safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64} +safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64, "F64": dtypes.double, "B": dtypes.bool, "I16": dtypes.short, "U16": dtypes.ushort, "UI": dtypes.uint, "UL": dtypes.ulong} inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()} def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]: