diff --git a/extra/onnx.py b/extra/onnx.py index f89bdbb256..af4cef92c1 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -1,6 +1,6 @@ from types import SimpleNamespace from typing import Any, Sequence, cast, Literal, Callable -import dataclasses, functools, io, math, types +import dataclasses, functools, io, math, types, warnings from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort from tinygrad.dtype import DType, ConstType, dtypes, ImageDType @@ -14,7 +14,7 @@ def has_field(onnx_type: TypeProto|SimpleNamespace, field): if isinstance(onnx_type, TypeProto): return onnx_type.HasField(field) return hasattr(onnx_type, field) -def dtype_parse(onnx_dtype: int) -> DType: +def dtype_parse(onnx_dtype: int, fallback_context: str | None = None) -> DType: supported: dict[int, DType] = { TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8, TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64, @@ -26,7 +26,13 @@ def dtype_parse(onnx_dtype: int) -> DType: TensorProto.FLOAT8E5M2, TensorProto.FLOAT8E5M2FNUZ, TensorProto.UINT4, TensorProto.INT4 } if onnx_dtype in unsupported: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported") - return supported[onnx_dtype] if is_dtype_supported(supported[onnx_dtype]) else dtypes.float + if is_dtype_supported(dtype := supported[onnx_dtype]): return dtype + # if fallback_context is provided, we can fall back to a default dtype + if fallback_context is not None: + default_dtype = dtypes.float + warnings.warn(f"dtype {dtype} on {Device.DEFAULT} from {fallback_context} is not supported, falling back to {default_dtype}") + return default_dtype + raise RuntimeError(f"dtype {dtype} on device {Device.DEFAULT} is not supported") def attribute_parse(onnx_attribute: AttributeProto): supported: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = { @@ -46,7 +52,7 @@ def attribute_parse(onnx_attribute: AttributeProto): def buffer_parse(onnx_tensor: TensorProto) -> Tensor: if onnx_tensor.string_data: raise NotImplementedError("Parsing for buffer with string data is not implemented.") - dtype, shape = dtype_parse(onnx_tensor.data_type), tuple(onnx_tensor.dims) + dtype, shape = dtype_parse(onnx_tensor.data_type, "buffer parse"), tuple(onnx_tensor.dims) data = None if len(onnx_tensor.float_data): data = onnx_tensor.float_data elif len(onnx_tensor.int32_data): data = onnx_tensor.int32_data @@ -76,7 +82,7 @@ def type_parse(onnx_type: TypeProto): if has_field(elem_type, "tensor_type"): shape = tuple(getattr(d, "dim_param", None) or getattr(d, "dim_value") for d in elem_type.tensor_type.shape.dim) \ if has_field(elem_type.tensor_type, "shape") else None # test_identity_sequence_cpu - dtype = dtype_parse(elem_type.tensor_type.elem_type) + dtype = dtype_parse(elem_type.tensor_type.elem_type, "input type spec parse") return OnnxValue(shape, dtype, is_optional, is_sequence) raise RuntimeError(f"TypeProto was not parsed properly: {onnx_type=}") @@ -145,15 +151,15 @@ class OnnxRunner: if spec.is_optional and value is None: return None # TODO: need true float16 for dtype checking if spec.is_sequence: - if not isinstance(value, Sequence): raise RuntimeError(f"{name} received {value}, expected a sequence type") + if not isinstance(value, Sequence): raise RuntimeError(f"input {name} received {value}, expected a sequence type") sequence = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value] - if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"Shapes for {name} sequence must be homogeneous") + if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"Shapes for input {name} sequence must be homogeneous") return sequence tensor = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(value, Tensor) else value for dim, (onnx_dim, user_dim_input) in enumerate(zip(spec.shape, tensor.shape, strict=True)): if isinstance(onnx_dim, str): onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input)) - if user_dim_input != onnx_dim: raise RuntimeError(f"{name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.") + if user_dim_input != onnx_dim: raise RuntimeError(f"input {name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.") return tensor def _dispatch_op(self, op, inps, opts): @@ -284,7 +290,7 @@ def get_onnx_ops(): raise ValueError(f"pixel_format={pixel_format!r} is not supported.") def EyeLike(x:Tensor, dtype:int|None=None, k:int=0): - ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype) + ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype, "EyeLike op") if dtype is not None else x.dtype) return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape)) def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0) @@ -338,7 +344,7 @@ def get_onnx_ops(): # ***** Casting Ops ***** # TODO: saturate - def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to)) + def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to, "Cast op")) def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype) # ***** Reduce Ops ***** @@ -731,7 +737,9 @@ def get_onnx_ops(): # ***** Quantization Ops ***** def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1): - out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8 + if isinstance(y_zero_point, Tensor): out_dtype = y_zero_point.dtype + elif output_dtype != 0: out_dtype = dtype_parse(output_dtype, "QuantizeLinear op") + else: out_dtype = dtypes.uint8 y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size) if out_dtype == dtypes.uchar: # this appears to work in practice, at least for uchar out_dtype. it folds with the quantize stuff