mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
This reverts commit 074a67a6eb.
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
from __future__ import annotations
|
||||
from typing import List, Dict, Union, Callable, Any, Sequence
|
||||
from typing import List, Dict, Union, Callable, Any
|
||||
import importlib, functools
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.helpers import getenv, DEBUG, all_same
|
||||
from tinygrad.helpers import getenv, DEBUG
|
||||
from tinygrad.dtype import DType, ConstType
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from onnx import AttributeProto, ModelProto, TensorProto, ValueInfoProto
|
||||
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto
|
||||
try:
|
||||
from onnx.helper import tensor_dtype_to_np_dtype
|
||||
except ImportError:
|
||||
@@ -68,12 +68,31 @@ onnx_ops = importlib.import_module('extra.onnx_ops')
|
||||
ONNXLIMIT = getenv("ONNXLIMIT", -1)
|
||||
|
||||
def get_run_onnx(onnx_model: ModelProto):
|
||||
# model initialization data
|
||||
def type_parse(type_proto: TypeProto):
|
||||
ret = []
|
||||
while True:
|
||||
attr = type_proto.WhichOneof('value')
|
||||
if attr == 'tensor_type':
|
||||
if "dim_value" not in type_proto.tensor_type.shape.dim.__dir__(): return () # variable type, unable to determine shape
|
||||
elif not ret:
|
||||
return tuple([x.dim_value for x in type_proto.tensor_type.shape.dim])
|
||||
else:
|
||||
ret.extend([(x.dim_value,) for x in type_proto.tensor_type.shape.dim])
|
||||
return tuple(ret)
|
||||
elif attr == 'sequence_type':
|
||||
type_proto = getattr(type_proto, attr).elem_type
|
||||
ret.append(1)
|
||||
elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type
|
||||
elif attr == 'map_type': raise NotImplementedError(f"map_type is not implemented: {type_proto}")
|
||||
elif attr == 'opaque_type': raise NotImplementedError(f"opaque_type is not implemented: {type_proto}")
|
||||
elif attr == 'sparse_tensor_type': raise NotImplementedError(f"sparse_tensor_type is not implemented: {type_proto}")
|
||||
else: raise AttributeError(f"unknown attr: {attr}, {type_proto}")
|
||||
|
||||
# initialization data
|
||||
model_parameters = {inp.name:buffer_parse(inp) for inp in onnx_model.graph.initializer}
|
||||
model_attributes = {num:{x.name:attribute_parse(x) for x in n.attribute} for num,n in enumerate(onnx_model.graph.node)}
|
||||
|
||||
# model descriptions
|
||||
# TODO: need a better way of controlling training vs non-training
|
||||
# model specs
|
||||
is_onnx_preview_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in onnx_model.graph.node)
|
||||
onnx_model_version = onnx_model.opset_import[0].version
|
||||
|
||||
@@ -84,42 +103,32 @@ def get_run_onnx(onnx_model: ModelProto):
|
||||
"Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf")
|
||||
}
|
||||
|
||||
# src: https://onnx.ai/onnx/repo-docs/IR.html#input-output-data-types
|
||||
# parses and validates inputs based on their shape and dtype specified by model
|
||||
def prepare_input(user_input:Any, model_input:ValueInfoProto):
|
||||
type_proto = model_input.type
|
||||
if type_proto.HasField("optional_type"):
|
||||
if user_input is None: return Tensor(None)
|
||||
type_proto = type_proto.optional_type.elem_type
|
||||
if type_proto.HasField("sequence_type"):
|
||||
if not isinstance(user_input, Sequence): raise RuntimeError(f"{model_input.name} received {user_input}, expected sequence type")
|
||||
dtype = dtype_parse(type_proto.sequence_type.elem_type.tensor_type.elem_type)
|
||||
sequence = [Tensor(i, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(i, Tensor) else i for i in user_input]
|
||||
if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"shapes for {model_input.name} must be homogeneous")
|
||||
# TODO: need true float16 for dtype checking
|
||||
# if not all(t.dtype is dtype for t in sequence): raise RuntimeError(f"{model_input.name} received wrong dtype, expected {dtype}")
|
||||
return sequence
|
||||
if type_proto.HasField("tensor_type"):
|
||||
dtype = dtype_parse(type_proto.tensor_type.elem_type)
|
||||
tensor = Tensor(user_input, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(user_input, Tensor) else user_input
|
||||
# TODO: need true float16 for dtype checking
|
||||
# if dtype is not tensor.dtype: raise RuntimeError(f"{model_input.name} received dtype {inp.dtype}, expected {dtype}")
|
||||
for d,onnx_dim in enumerate(type_proto.tensor_type.shape.dim):
|
||||
# NOTE: `dim_value` is a variable when `dim_value` is not specified and `dim_param` is, e.g. dim {dim_param: "N"}
|
||||
if onnx_dim.dim_value is not None and onnx_dim.dim_value != user_input.shape[d]:
|
||||
raise RuntimeError(f"{model_input.name} received value {user_input.shape[d]} on dim {d}, expected {onnx_dim.dim_value}")
|
||||
return tensor
|
||||
type_field_names = [field.name for field,_ in type_proto.ListFields()]
|
||||
raise NotImplementedError(f"{model_input.name} with {type_field_names=} is not supported")
|
||||
|
||||
def run_onnx(inputs={}, debug=0):
|
||||
debug = getenv("DEBUGONNX") or debug
|
||||
input_tensors: Dict[str,Tensor|List[Tensor]] = {}
|
||||
intermediate_tensors: Dict[str,Tensor] = {}
|
||||
|
||||
input_tensors: Dict[str, Tensor | List[Tensor]] = {}
|
||||
# get inputs
|
||||
for model_input in onnx_model.graph.input:
|
||||
if model_input.name in inputs: input_tensors[model_input.name] = prepare_input(inputs[model_input.name], model_input)
|
||||
elif model_input.name not in model_parameters: raise RuntimeError(f"Please provide input data for {model_input.name}")
|
||||
name = model_input.name
|
||||
if name in model_parameters: continue
|
||||
shape = type_parse(model_input.type)
|
||||
if name in inputs:
|
||||
if isinstance(inputs[name], Tensor):
|
||||
input_tensors[name] = inputs[name]
|
||||
elif isinstance(inputs[name], list):
|
||||
input_tensors[name] = [Tensor(i, requires_grad=False) for i in inputs[name]]
|
||||
# TODO: this is just to make training tests pass, need a principled way to handle training vs non-training
|
||||
elif is_onnx_preview_training:
|
||||
input_tensors[name] = Tensor(inputs[name], requires_grad=True)
|
||||
else:
|
||||
input_tensors[name] = Tensor(inputs[name], requires_grad=False)
|
||||
if shape: # if only input_tensor is not variable type
|
||||
ts = input_tensors[name]
|
||||
input_shape = ts.shape if isinstance(ts, Tensor) else (1, *[i.shape for i in ts])
|
||||
assert input_shape == shape, f"wrong shape for input {name}, {input_shape} isn't {shape}"
|
||||
else:
|
||||
raise RuntimeError(f"no data for {name} with shape {shape}")
|
||||
|
||||
def fetch_tensor(x: str):
|
||||
if x in model_parameters: return model_parameters[x]
|
||||
|
||||
6
test/external/external_test_onnx_backend.py
vendored
6
test/external/external_test_onnx_backend.py
vendored
@@ -70,17 +70,11 @@ backend_test.exclude('BFLOAT16') # not supported in numpy
|
||||
# TODO: fix these with true onnx float16
|
||||
backend_test.exclude('to_FLOAT16')
|
||||
backend_test.exclude('cast_no_saturate')
|
||||
backend_test.exclude('test_dequantizelinear_e4m3fn_float16_cpu')
|
||||
backend_test.exclude('test_max_float16_cpu')
|
||||
backend_test.exclude('test_min_float16_cpu')
|
||||
|
||||
backend_test.exclude('test_pow_types_int*')
|
||||
backend_test.exclude('test_convinteger_*')
|
||||
backend_test.exclude('test_matmulinteger_*')
|
||||
|
||||
backend_test.exclude('test_dequantizelinear_int4_cpu')
|
||||
backend_test.exclude('test_dequantizelinear_uint4_cpu')
|
||||
|
||||
# we don't support indexes
|
||||
backend_test.exclude('test_nonzero_*')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user