mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
@@ -2,8 +2,8 @@ from __future__ import annotations
|
||||
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
|
||||
import importlib
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.helpers import getenv, DEBUG
|
||||
from tinygrad import Tensor, dtypes, Device
|
||||
from tinygrad.helpers import getenv, DEBUG, CI, OSX
|
||||
from typing import List, Dict
|
||||
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto # onnx 1.50 uses serialized file (see onnx/onnx-ml.proto) as descriptors
|
||||
try:
|
||||
@@ -24,6 +24,14 @@ def safe_numpy(t) -> np.ndarray:
|
||||
numpy_cache[t] = tmp
|
||||
return numpy_cache[t]
|
||||
|
||||
# copied from helpers.py
|
||||
def is_dtype_supported(dtype, device: str = Device.DEFAULT):
|
||||
if dtype == dtypes.bfloat16: return False
|
||||
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
|
||||
if dtype == dtypes.half: return not (CI and device in {"GPU", "LLVM", "CUDA"})
|
||||
if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
|
||||
return True
|
||||
|
||||
# src: onnx/mapping.py
|
||||
# not supported: STRING = 8 COMPLEX64 = 14, COMPLEX128 = 15
|
||||
# NOTE: 17, 18, 19, 20 are float8, 10 is half
|
||||
@@ -58,19 +66,13 @@ def get_run_onnx(onnx_model: ModelProto):
|
||||
else: raise Exception(f"unknown attr: {attr}, {type_proto}")
|
||||
|
||||
def buffer_parse(inp: TensorProto) -> Tensor:
|
||||
if inp.data_type in (1,10,6,7,5,11):
|
||||
# TODO: this is shared with below
|
||||
if len(inp.float_data) > 0:
|
||||
ret = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
|
||||
elif len(inp.int64_data) > 0:
|
||||
ret = Tensor(np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), requires_grad=False)
|
||||
elif len(inp.int32_data) > 0:
|
||||
ret = Tensor(np.array(inp.int32_data, dtype=np.int32).reshape(inp.dims), requires_grad=False)
|
||||
else:
|
||||
ret = Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).reshape(inp.dims).astype(np.float32).copy(), requires_grad=False)
|
||||
else:
|
||||
raise Exception(f"bad data type {inp.name} {inp.dims} {inp.data_type}")
|
||||
return ret
|
||||
if inp.data_type in (8,14,15): raise Exception(f"data type not supported {inp.name} {inp.dims} {inp.data_type}")
|
||||
dtype = DTYPE_MAP[inp.data_type] if is_dtype_supported(DTYPE_MAP[inp.data_type]) else dtypes.float32
|
||||
if dat := list(inp.float_data) or list(inp.int32_data) or list(inp.int64_data):
|
||||
return Tensor(dat, dtype=dtype, requires_grad=False).reshape(tuple(inp.dims))
|
||||
if len(inp.raw_data) > 0:
|
||||
return Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).astype(dtype.np).copy(), requires_grad=False).reshape(tuple(inp.dims))
|
||||
return Tensor(None, requires_grad=False)
|
||||
|
||||
def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]:
|
||||
# TODO: this is not complete, see onnx/onnx_ml_pb2.pyi for a complete list
|
||||
@@ -89,18 +91,7 @@ def get_run_onnx(onnx_model: ModelProto):
|
||||
|
||||
# get weights and biases
|
||||
for inp in onnx_model.graph.initializer:
|
||||
if len(inp.raw_data) > 0:
|
||||
tensors[inp.name] = buffer_parse(inp)
|
||||
elif len(inp.float_data) > 0:
|
||||
tensors[inp.name] = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
|
||||
elif len(inp.int64_data) > 0:
|
||||
tensors[inp.name] = Tensor(np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), requires_grad=False)
|
||||
elif len(inp.raw_data) == 0:
|
||||
tensors[inp.name] = Tensor(np.array([], dtype=np.float32), requires_grad=False)
|
||||
else:
|
||||
print(inp.name, inp.dims, inp.data_type, len(inp.raw_data))
|
||||
print(inp)
|
||||
raise Exception("no data")
|
||||
tensors[inp.name] = buffer_parse(inp)
|
||||
|
||||
# preparse the attributes
|
||||
attribute_dict = {}
|
||||
|
||||
Reference in New Issue
Block a user