Revert "combine get inputs and type_parse function in onnx (#8069)" (#8079)

This reverts commit 074a67a6eb.
This commit is contained in:
chenyu
2024-12-06 08:04:21 -05:00
committed by GitHub
parent c8313a3669
commit b73d9a7d24
2 changed files with 46 additions and 43 deletions

View File

@@ -1,12 +1,12 @@
from __future__ import annotations
from typing import List, Dict, Union, Callable, Any, Sequence
from typing import List, Dict, Union, Callable, Any
import importlib, functools
import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad.helpers import getenv, DEBUG, all_same
from tinygrad.helpers import getenv, DEBUG
from tinygrad.dtype import DType, ConstType
from tinygrad.device import is_dtype_supported
from onnx import AttributeProto, ModelProto, TensorProto, ValueInfoProto
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto
try:
from onnx.helper import tensor_dtype_to_np_dtype
except ImportError:
@@ -68,12 +68,31 @@ onnx_ops = importlib.import_module('extra.onnx_ops')
ONNXLIMIT = getenv("ONNXLIMIT", -1)
def get_run_onnx(onnx_model: ModelProto):
# model initialization data
def type_parse(type_proto: TypeProto):
ret = []
while True:
attr = type_proto.WhichOneof('value')
if attr == 'tensor_type':
if "dim_value" not in type_proto.tensor_type.shape.dim.__dir__(): return () # variable type, unable to determine shape
elif not ret:
return tuple([x.dim_value for x in type_proto.tensor_type.shape.dim])
else:
ret.extend([(x.dim_value,) for x in type_proto.tensor_type.shape.dim])
return tuple(ret)
elif attr == 'sequence_type':
type_proto = getattr(type_proto, attr).elem_type
ret.append(1)
elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type
elif attr == 'map_type': raise NotImplementedError(f"map_type is not implemented: {type_proto}")
elif attr == 'opaque_type': raise NotImplementedError(f"opaque_type is not implemented: {type_proto}")
elif attr == 'sparse_tensor_type': raise NotImplementedError(f"sparse_tensor_type is not implemented: {type_proto}")
else: raise AttributeError(f"unknown attr: {attr}, {type_proto}")
# initialization data
model_parameters = {inp.name:buffer_parse(inp) for inp in onnx_model.graph.initializer}
model_attributes = {num:{x.name:attribute_parse(x) for x in n.attribute} for num,n in enumerate(onnx_model.graph.node)}
# model descriptions
# TODO: need a better way of controlling training vs non-training
# model specs
is_onnx_preview_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in onnx_model.graph.node)
onnx_model_version = onnx_model.opset_import[0].version
@@ -84,42 +103,32 @@ def get_run_onnx(onnx_model: ModelProto):
"Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf")
}
# src: https://onnx.ai/onnx/repo-docs/IR.html#input-output-data-types
# parses and validates inputs based on their shape and dtype specified by model
def prepare_input(user_input:Any, model_input:ValueInfoProto):
type_proto = model_input.type
if type_proto.HasField("optional_type"):
if user_input is None: return Tensor(None)
type_proto = type_proto.optional_type.elem_type
if type_proto.HasField("sequence_type"):
if not isinstance(user_input, Sequence): raise RuntimeError(f"{model_input.name} received {user_input}, expected sequence type")
dtype = dtype_parse(type_proto.sequence_type.elem_type.tensor_type.elem_type)
sequence = [Tensor(i, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(i, Tensor) else i for i in user_input]
if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"shapes for {model_input.name} must be homogeneous")
# TODO: need true float16 for dtype checking
# if not all(t.dtype is dtype for t in sequence): raise RuntimeError(f"{model_input.name} received wrong dtype, expected {dtype}")
return sequence
if type_proto.HasField("tensor_type"):
dtype = dtype_parse(type_proto.tensor_type.elem_type)
tensor = Tensor(user_input, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(user_input, Tensor) else user_input
# TODO: need true float16 for dtype checking
# if dtype is not tensor.dtype: raise RuntimeError(f"{model_input.name} received dtype {inp.dtype}, expected {dtype}")
for d,onnx_dim in enumerate(type_proto.tensor_type.shape.dim):
# NOTE: `dim_value` is a variable when `dim_value` is not specified and `dim_param` is, e.g. dim {dim_param: "N"}
if onnx_dim.dim_value is not None and onnx_dim.dim_value != user_input.shape[d]:
raise RuntimeError(f"{model_input.name} received value {user_input.shape[d]} on dim {d}, expected {onnx_dim.dim_value}")
return tensor
type_field_names = [field.name for field,_ in type_proto.ListFields()]
raise NotImplementedError(f"{model_input.name} with {type_field_names=} is not supported")
def run_onnx(inputs={}, debug=0):
debug = getenv("DEBUGONNX") or debug
input_tensors: Dict[str,Tensor|List[Tensor]] = {}
intermediate_tensors: Dict[str,Tensor] = {}
input_tensors: Dict[str, Tensor | List[Tensor]] = {}
# get inputs
for model_input in onnx_model.graph.input:
if model_input.name in inputs: input_tensors[model_input.name] = prepare_input(inputs[model_input.name], model_input)
elif model_input.name not in model_parameters: raise RuntimeError(f"Please provide input data for {model_input.name}")
name = model_input.name
if name in model_parameters: continue
shape = type_parse(model_input.type)
if name in inputs:
if isinstance(inputs[name], Tensor):
input_tensors[name] = inputs[name]
elif isinstance(inputs[name], list):
input_tensors[name] = [Tensor(i, requires_grad=False) for i in inputs[name]]
# TODO: this is just to make training tests pass, need a principled way to handle training vs non-training
elif is_onnx_preview_training:
input_tensors[name] = Tensor(inputs[name], requires_grad=True)
else:
input_tensors[name] = Tensor(inputs[name], requires_grad=False)
if shape: # if only input_tensor is not variable type
ts = input_tensors[name]
input_shape = ts.shape if isinstance(ts, Tensor) else (1, *[i.shape for i in ts])
assert input_shape == shape, f"wrong shape for input {name}, {input_shape} isn't {shape}"
else:
raise RuntimeError(f"no data for {name} with shape {shape}")
def fetch_tensor(x: str):
if x in model_parameters: return model_parameters[x]

View File

@@ -70,17 +70,11 @@ backend_test.exclude('BFLOAT16') # not supported in numpy
# TODO: fix these with true onnx float16
backend_test.exclude('to_FLOAT16')
backend_test.exclude('cast_no_saturate')
backend_test.exclude('test_dequantizelinear_e4m3fn_float16_cpu')
backend_test.exclude('test_max_float16_cpu')
backend_test.exclude('test_min_float16_cpu')
backend_test.exclude('test_pow_types_int*')
backend_test.exclude('test_convinteger_*')
backend_test.exclude('test_matmulinteger_*')
backend_test.exclude('test_dequantizelinear_int4_cpu')
backend_test.exclude('test_dequantizelinear_uint4_cpu')
# we don't support indexes
backend_test.exclude('test_nonzero_*')