diff --git a/extra/onnx.py b/extra/onnx.py index 00cd77caea..0b4493db2b 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -1,12 +1,12 @@ from __future__ import annotations -from typing import List, Dict, Union, Callable, Any, Sequence +from typing import List, Dict, Union, Callable, Any import importlib, functools import numpy as np from tinygrad import Tensor, dtypes -from tinygrad.helpers import getenv, DEBUG, all_same +from tinygrad.helpers import getenv, DEBUG from tinygrad.dtype import DType, ConstType from tinygrad.device import is_dtype_supported -from onnx import AttributeProto, ModelProto, TensorProto, ValueInfoProto +from onnx import AttributeProto, ModelProto, TensorProto, TypeProto try: from onnx.helper import tensor_dtype_to_np_dtype except ImportError: @@ -68,12 +68,31 @@ onnx_ops = importlib.import_module('extra.onnx_ops') ONNXLIMIT = getenv("ONNXLIMIT", -1) def get_run_onnx(onnx_model: ModelProto): - # model initialization data + def type_parse(type_proto: TypeProto): + ret = [] + while True: + attr = type_proto.WhichOneof('value') + if attr == 'tensor_type': + if "dim_value" not in type_proto.tensor_type.shape.dim.__dir__(): return () # variable type, unable to determine shape + elif not ret: + return tuple([x.dim_value for x in type_proto.tensor_type.shape.dim]) + else: + ret.extend([(x.dim_value,) for x in type_proto.tensor_type.shape.dim]) + return tuple(ret) + elif attr == 'sequence_type': + type_proto = getattr(type_proto, attr).elem_type + ret.append(1) + elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type + elif attr == 'map_type': raise NotImplementedError(f"map_type is not implemented: {type_proto}") + elif attr == 'opaque_type': raise NotImplementedError(f"opaque_type is not implemented: {type_proto}") + elif attr == 'sparse_tensor_type': raise NotImplementedError(f"sparse_tensor_type is not implemented: {type_proto}") + else: raise AttributeError(f"unknown attr: {attr}, {type_proto}") + + # initialization data model_parameters = {inp.name:buffer_parse(inp) for inp in onnx_model.graph.initializer} model_attributes = {num:{x.name:attribute_parse(x) for x in n.attribute} for num,n in enumerate(onnx_model.graph.node)} - # model descriptions - # TODO: need a better way of controlling training vs non-training + # model specs is_onnx_preview_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in onnx_model.graph.node) onnx_model_version = onnx_model.opset_import[0].version @@ -84,42 +103,32 @@ def get_run_onnx(onnx_model: ModelProto): "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf") } - # src: https://onnx.ai/onnx/repo-docs/IR.html#input-output-data-types - # parses and validates inputs based on their shape and dtype specified by model - def prepare_input(user_input:Any, model_input:ValueInfoProto): - type_proto = model_input.type - if type_proto.HasField("optional_type"): - if user_input is None: return Tensor(None) - type_proto = type_proto.optional_type.elem_type - if type_proto.HasField("sequence_type"): - if not isinstance(user_input, Sequence): raise RuntimeError(f"{model_input.name} received {user_input}, expected sequence type") - dtype = dtype_parse(type_proto.sequence_type.elem_type.tensor_type.elem_type) - sequence = [Tensor(i, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(i, Tensor) else i for i in user_input] - if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"shapes for {model_input.name} must be homogeneous") - # TODO: need true float16 for dtype checking - # if not all(t.dtype is dtype for t in sequence): raise RuntimeError(f"{model_input.name} received wrong dtype, expected {dtype}") - return sequence - if type_proto.HasField("tensor_type"): - dtype = dtype_parse(type_proto.tensor_type.elem_type) - tensor = Tensor(user_input, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(user_input, Tensor) else user_input - # TODO: need true float16 for dtype checking - # if dtype is not tensor.dtype: raise RuntimeError(f"{model_input.name} received dtype {inp.dtype}, expected {dtype}") - for d,onnx_dim in enumerate(type_proto.tensor_type.shape.dim): - # NOTE: `dim_value` is a variable when `dim_value` is not specified and `dim_param` is, e.g. dim {dim_param: "N"} - if onnx_dim.dim_value is not None and onnx_dim.dim_value != user_input.shape[d]: - raise RuntimeError(f"{model_input.name} received value {user_input.shape[d]} on dim {d}, expected {onnx_dim.dim_value}") - return tensor - type_field_names = [field.name for field,_ in type_proto.ListFields()] - raise NotImplementedError(f"{model_input.name} with {type_field_names=} is not supported") - def run_onnx(inputs={}, debug=0): debug = getenv("DEBUGONNX") or debug + input_tensors: Dict[str,Tensor|List[Tensor]] = {} intermediate_tensors: Dict[str,Tensor] = {} - input_tensors: Dict[str, Tensor | List[Tensor]] = {} + # get inputs for model_input in onnx_model.graph.input: - if model_input.name in inputs: input_tensors[model_input.name] = prepare_input(inputs[model_input.name], model_input) - elif model_input.name not in model_parameters: raise RuntimeError(f"Please provide input data for {model_input.name}") + name = model_input.name + if name in model_parameters: continue + shape = type_parse(model_input.type) + if name in inputs: + if isinstance(inputs[name], Tensor): + input_tensors[name] = inputs[name] + elif isinstance(inputs[name], list): + input_tensors[name] = [Tensor(i, requires_grad=False) for i in inputs[name]] + # TODO: this is just to make training tests pass, need a principled way to handle training vs non-training + elif is_onnx_preview_training: + input_tensors[name] = Tensor(inputs[name], requires_grad=True) + else: + input_tensors[name] = Tensor(inputs[name], requires_grad=False) + if shape: # if only input_tensor is not variable type + ts = input_tensors[name] + input_shape = ts.shape if isinstance(ts, Tensor) else (1, *[i.shape for i in ts]) + assert input_shape == shape, f"wrong shape for input {name}, {input_shape} isn't {shape}" + else: + raise RuntimeError(f"no data for {name} with shape {shape}") def fetch_tensor(x: str): if x in model_parameters: return model_parameters[x] diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py index b3f6ae86ee..0cb704d48e 100644 --- a/test/external/external_test_onnx_backend.py +++ b/test/external/external_test_onnx_backend.py @@ -70,17 +70,11 @@ backend_test.exclude('BFLOAT16') # not supported in numpy # TODO: fix these with true onnx float16 backend_test.exclude('to_FLOAT16') backend_test.exclude('cast_no_saturate') -backend_test.exclude('test_dequantizelinear_e4m3fn_float16_cpu') -backend_test.exclude('test_max_float16_cpu') -backend_test.exclude('test_min_float16_cpu') backend_test.exclude('test_pow_types_int*') backend_test.exclude('test_convinteger_*') backend_test.exclude('test_matmulinteger_*') -backend_test.exclude('test_dequantizelinear_int4_cpu') -backend_test.exclude('test_dequantizelinear_uint4_cpu') - # we don't support indexes backend_test.exclude('test_nonzero_*')