Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
geohotstan
2024-04-05 23:57:44 +08:00
committed by GitHub
parent 750ecf8fef
commit dafa42e864

View File

@@ -2,8 +2,8 @@ from __future__ import annotations
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
import importlib
import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad.helpers import getenv, DEBUG
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import getenv, DEBUG, CI, OSX
from typing import List, Dict
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto # onnx 1.50 uses serialized file (see onnx/onnx-ml.proto) as descriptors
try:
@@ -24,6 +24,14 @@ def safe_numpy(t) -> np.ndarray:
numpy_cache[t] = tmp
return numpy_cache[t]
# copied from helpers.py
def is_dtype_supported(dtype, device: str = Device.DEFAULT):
if dtype == dtypes.bfloat16: return False
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
if dtype == dtypes.half: return not (CI and device in {"GPU", "LLVM", "CUDA"})
if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
return True
# src: onnx/mapping.py
# not supported: STRING = 8 COMPLEX64 = 14, COMPLEX128 = 15
# NOTE: 17, 18, 19, 20 are float8, 10 is half
@@ -58,19 +66,13 @@ def get_run_onnx(onnx_model: ModelProto):
else: raise Exception(f"unknown attr: {attr}, {type_proto}")
def buffer_parse(inp: TensorProto) -> Tensor:
if inp.data_type in (1,10,6,7,5,11):
# TODO: this is shared with below
if len(inp.float_data) > 0:
ret = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
elif len(inp.int64_data) > 0:
ret = Tensor(np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), requires_grad=False)
elif len(inp.int32_data) > 0:
ret = Tensor(np.array(inp.int32_data, dtype=np.int32).reshape(inp.dims), requires_grad=False)
else:
ret = Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).reshape(inp.dims).astype(np.float32).copy(), requires_grad=False)
else:
raise Exception(f"bad data type {inp.name} {inp.dims} {inp.data_type}")
return ret
if inp.data_type in (8,14,15): raise Exception(f"data type not supported {inp.name} {inp.dims} {inp.data_type}")
dtype = DTYPE_MAP[inp.data_type] if is_dtype_supported(DTYPE_MAP[inp.data_type]) else dtypes.float32
if dat := list(inp.float_data) or list(inp.int32_data) or list(inp.int64_data):
return Tensor(dat, dtype=dtype, requires_grad=False).reshape(tuple(inp.dims))
if len(inp.raw_data) > 0:
return Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).astype(dtype.np).copy(), requires_grad=False).reshape(tuple(inp.dims))
return Tensor(None, requires_grad=False)
def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]:
# TODO: this is not complete, see onnx/onnx_ml_pb2.pyi for a complete list
@@ -89,18 +91,7 @@ def get_run_onnx(onnx_model: ModelProto):
# get weights and biases
for inp in onnx_model.graph.initializer:
if len(inp.raw_data) > 0:
tensors[inp.name] = buffer_parse(inp)
elif len(inp.float_data) > 0:
tensors[inp.name] = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
elif len(inp.int64_data) > 0:
tensors[inp.name] = Tensor(np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), requires_grad=False)
elif len(inp.raw_data) == 0:
tensors[inp.name] = Tensor(np.array([], dtype=np.float32), requires_grad=False)
else:
print(inp.name, inp.dims, inp.data_type, len(inp.raw_data))
print(inp)
raise Exception("no data")
tensors[inp.name] = buffer_parse(inp)
# preparse the attributes
attribute_dict = {}