diff --git a/extra/onnx.py b/extra/onnx.py index c1d5bf0163..0d809440d2 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -2,8 +2,8 @@ from __future__ import annotations from google.protobuf.internal.containers import RepeatedCompositeFieldContainer import importlib import numpy as np -from tinygrad import Tensor, dtypes -from tinygrad.helpers import getenv, DEBUG +from tinygrad import Tensor, dtypes, Device +from tinygrad.helpers import getenv, DEBUG, CI, OSX from typing import List, Dict from onnx import AttributeProto, ModelProto, TensorProto, TypeProto # onnx 1.50 uses serialized file (see onnx/onnx-ml.proto) as descriptors try: @@ -24,6 +24,14 @@ def safe_numpy(t) -> np.ndarray: numpy_cache[t] = tmp return numpy_cache[t] +# copied from helpers.py +def is_dtype_supported(dtype, device: str = Device.DEFAULT): + if dtype == dtypes.bfloat16: return False + if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32] + if dtype == dtypes.half: return not (CI and device in {"GPU", "LLVM", "CUDA"}) + if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU") + return True + # src: onnx/mapping.py # not supported: STRING = 8 COMPLEX64 = 14, COMPLEX128 = 15 # NOTE: 17, 18, 19, 20 are float8, 10 is half @@ -58,19 +66,13 @@ def get_run_onnx(onnx_model: ModelProto): else: raise Exception(f"unknown attr: {attr}, {type_proto}") def buffer_parse(inp: TensorProto) -> Tensor: - if inp.data_type in (1,10,6,7,5,11): - # TODO: this is shared with below - if len(inp.float_data) > 0: - ret = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False) - elif len(inp.int64_data) > 0: - ret = Tensor(np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), requires_grad=False) - elif len(inp.int32_data) > 0: - ret = Tensor(np.array(inp.int32_data, dtype=np.int32).reshape(inp.dims), requires_grad=False) - else: - ret = Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).reshape(inp.dims).astype(np.float32).copy(), requires_grad=False) - else: - raise Exception(f"bad data type {inp.name} {inp.dims} {inp.data_type}") - return ret + if inp.data_type in (8,14,15): raise Exception(f"data type not supported {inp.name} {inp.dims} {inp.data_type}") + dtype = DTYPE_MAP[inp.data_type] if is_dtype_supported(DTYPE_MAP[inp.data_type]) else dtypes.float32 + if dat := list(inp.float_data) or list(inp.int32_data) or list(inp.int64_data): + return Tensor(dat, dtype=dtype, requires_grad=False).reshape(tuple(inp.dims)) + if len(inp.raw_data) > 0: + return Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).astype(dtype.np).copy(), requires_grad=False).reshape(tuple(inp.dims)) + return Tensor(None, requires_grad=False) def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]: # TODO: this is not complete, see onnx/onnx_ml_pb2.pyi for a complete list @@ -89,18 +91,7 @@ def get_run_onnx(onnx_model: ModelProto): # get weights and biases for inp in onnx_model.graph.initializer: - if len(inp.raw_data) > 0: - tensors[inp.name] = buffer_parse(inp) - elif len(inp.float_data) > 0: - tensors[inp.name] = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False) - elif len(inp.int64_data) > 0: - tensors[inp.name] = Tensor(np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), requires_grad=False) - elif len(inp.raw_data) == 0: - tensors[inp.name] = Tensor(np.array([], dtype=np.float32), requires_grad=False) - else: - print(inp.name, inp.dims, inp.data_type, len(inp.raw_data)) - print(inp) - raise Exception("no data") + tensors[inp.name] = buffer_parse(inp) # preparse the attributes attribute_dict = {}