diff --git a/extra/onnx.py b/extra/onnx.py index 4d0a587f88..f90f0ede1f 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -38,13 +38,16 @@ def is_dtype_supported(dtype, device: str = Device.DEFAULT): if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU") return True -# src: onnx/mapping.py -# not supported: STRING = 8 COMPLEX64 = 14, COMPLEX128 = 15 -# NOTE: 17, 18, 19, 20 are float8, 10 is half -DTYPE_MAP = {1:dtypes.float, 2:dtypes.uint8, 3:dtypes.int8, 4:dtypes.uint16, 5:dtypes.int16, 6:dtypes.int32, 7:dtypes.int64, - 9:dtypes.bool, 10:dtypes.float, 11:dtypes.double, 12:dtypes.uint32, 13:dtypes.uint64, 16:dtypes.bfloat16, - 17:dtypes.float, 18:dtypes.float, 19:dtypes.float, 20:dtypes.float} -# TODO: fix buffer_parse to use this and fix get_weight_and_biases to only use buffer_parse +# src: onnx/mapping.py https://onnx.ai/onnx/api/mapping.html#l-mod-onnx-mapping +# not supported: STRING = 8 COMPLEX64 = 14, COMPLEX128 = 15, UINT4 = 21, INT4 = 22 +# TODO: use dtypes.float16 for FLOAT16 +DTYPE_MAP = { + TensorProto.FLOAT:dtypes.float, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8, TensorProto.UINT16:dtypes.uint16, + TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64, TensorProto.BOOL:dtypes.bool, + TensorProto.FLOAT16:dtypes.float, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32, TensorProto.UINT64:dtypes.uint64, + TensorProto.BFLOAT16:dtypes.bfloat16, TensorProto.FLOAT8E4M3FN:dtypes.float, TensorProto.FLOAT8E4M3FNUZ:dtypes.float, + TensorProto.FLOAT8E5M2:dtypes.float, TensorProto.FLOAT8E5M2FNUZ:dtypes.float +} onnx_ops = importlib.import_module('extra.onnx_ops') @@ -72,7 +75,8 @@ def get_run_onnx(onnx_model: ModelProto): else: raise Exception(f"unknown attr: {attr}, {type_proto}") def buffer_parse(inp: TensorProto) -> Tensor: - if inp.data_type in (8,14,15): raise Exception(f"data type not supported {inp.name} {inp.dims} {inp.data_type}") + if inp.data_type not in DTYPE_MAP: + raise NotImplementedError(f"data type not supported {inp.name} {inp.dims} {inp.data_type}") dtype = DTYPE_MAP[inp.data_type] if is_dtype_supported(DTYPE_MAP[inp.data_type]) else dtypes.float32 if dat := list(inp.float_data) or list(inp.int32_data) or list(inp.int64_data): return Tensor(dat, dtype=dtype, requires_grad=False).reshape(tuple(inp.dims))