From 039d84ff021bd5ee7ef8902bb60ea09c7b54f513 Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Mon, 11 May 2026 18:45:17 -0700 Subject: [PATCH] Revert "onnx: deduplicate simple proto parsers" (#16148) This reverts commit 83eaefcd0f6568237c8b1d66e2a4363bbdea48a1. --- test/external/external_test_onnx_runner.py | 2 +- tinygrad/nn/onnx.py | 82 ++++++++++++++++------ 2 files changed, 62 insertions(+), 22 deletions(-) diff --git a/test/external/external_test_onnx_runner.py b/test/external/external_test_onnx_runner.py index 77959ef45d..d91a2aec59 100644 --- a/test/external/external_test_onnx_runner.py +++ b/test/external/external_test_onnx_runner.py @@ -144,7 +144,7 @@ class MetadataOnnxPBParser(OnnxPBParser): for fid, wire_type in self._parse_message(self.reader.len): match fid: case 7: obj["graph"] = self._parse_GraphProto() - case 14: obj["metadata_props"].append(self._parse_proto(self._SIMPLE_PROTOS["StringStringEntryProto"])) + case 14: obj["metadata_props"].append(self._parse_StringStringEntryProto()) case _: self.reader.skip_field(wire_type) return obj diff --git a/tinygrad/nn/onnx.py b/tinygrad/nn/onnx.py index fdffb8db34..058d251bb2 100644 --- a/tinygrad/nn/onnx.py +++ b/tinygrad/nn/onnx.py @@ -160,7 +160,7 @@ class OnnxPBParser: case 4: obj["domain"] = self.reader.read_string() case 5: obj["model_version"] = self.reader.read_int64() case 7: obj["graph"] = self._parse_GraphProto() - case 8: obj["opset_import"].append(self._parse_proto(self._SIMPLE_PROTOS["OperatorSetIdProto"])) + case 8: obj["opset_import"].append(self._parse_OperatorSetIdProto()) case _: self.reader.skip_field(wire_type) # update opset version @@ -214,7 +214,7 @@ class OnnxPBParser: case 9: obj["raw_data"] = self.reader.read_bytes() case 10: obj["double_data"] = self.reader.read_packed_floats() case 11: obj["uint64_data"] = self.reader.read_packed_int64s() - case 13: obj.setdefault("external_data", []).append(self._parse_proto(self._SIMPLE_PROTOS["StringStringEntryProto"])) + case 13: obj.setdefault("external_data", []).append(self._parse_StringStringEntryProto()) case 14: obj["data_location"] = self.reader.read_int64() case _: self.reader.skip_field(wire_type) @@ -281,7 +281,7 @@ class OnnxPBParser: for fid, wire_type in self._parse_message(self._decode_end_pos()): match fid: case 1: obj["name"] = self.reader.read_string() - case 2: obj["type"] = self._parse_proto(self._SIMPLE_PROTOS["TypeProto"]) + case 2: obj["type"] = self._parse_TypeProto() case _: self.reader.skip_field(wire_type) # parse type @@ -295,26 +295,66 @@ class OnnxPBParser: OnnxDataType(type_obj['tensor_type']['elem_type']).to_dtype(), is_optional, is_sequence) return obj - _SIMPLE_PROTOS: dict[str, dict[int, tuple[str, str]]] = { - "TypeProto": {1: ("tensor_type", "TypeProtoTensor"), 4: ("sequence_type", "TypeProtoWrapper"), - 9: ("optional_type", "TypeProtoWrapper")}, - "TypeProtoTensor": {1: ("elem_type", "read_int64"), 2: ("shape", "TensorShapeProto")}, - "TypeProtoWrapper": {1: ("elem_type", "TypeProto")}, - "TensorShapeProto": {1: ("+dim", "TensorShapeProtoDimension")}, - "TensorShapeProtoDimension": {1: ("dim_value", "read_int64"), 2: ("dim_param", "read_string")}, - "StringStringEntryProto": {1: ("key", "read_string"), 2: ("value", "read_string")}, - "OperatorSetIdProto": {1: ("domain", "read_string"), 2: ("version", "read_int64")}, - } - def _parse_proto(self, fields: dict[int, tuple[str, str]]) -> dict: + def _parse_TypeProto(self) -> dict: obj: dict[str, Any] = {} for fid, wire_type in self._parse_message(self._decode_end_pos()): - if fid not in fields: - self.reader.skip_field(wire_type) - continue - name, action = fields[fid] - value = self._parse_proto(self._SIMPLE_PROTOS[action]) if action in self._SIMPLE_PROTOS else getattr(self.reader, action)() - if name[0] == "+": obj.setdefault(name[1:], []).append(value) - else: obj[name] = value + match fid: + case 1: obj["tensor_type"] = self._parse_TypeProtoTensor() + case 4: obj["sequence_type"] = self._parse_TypeProtoWrapper() + case 9: obj["optional_type"] = self._parse_TypeProtoWrapper() + case _: self.reader.skip_field(wire_type) + return obj + + def _parse_TypeProtoTensor(self) -> dict: + obj: dict[str, Any] = {} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["elem_type"] = self.reader.read_int64() + case 2: obj["shape"] = self._parse_TensorShapeProto() + case _: self.reader.skip_field(wire_type) + return obj + + def _parse_TypeProtoWrapper(self) -> dict: + obj = {} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["elem_type"] = self._parse_TypeProto() + case _: self.reader.skip_field(wire_type) + return obj + + def _parse_TensorShapeProto(self) -> dict: + obj: dict[str, Any] = {"dim": []} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["dim"].append(self._parse_TensorShapeProtoDimension()) + case _: self.reader.skip_field(wire_type) + return obj + + def _parse_TensorShapeProtoDimension(self) -> dict: + obj: dict[str, Any] = {} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["dim_value"] = self.reader.read_int64() + case 2: obj["dim_param"] = self.reader.read_string() + case _: self.reader.skip_field(wire_type) + return obj + + def _parse_StringStringEntryProto(self) -> dict: + obj: dict[str, Any] = {} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["key"] = self.reader.read_string() + case 2: obj["value"] = self.reader.read_string() + case _: self.reader.skip_field(wire_type) + return obj + + def _parse_OperatorSetIdProto(self) -> dict: + obj: dict[str, Any] = {} + for fid, wire_type in self._parse_message(self._decode_end_pos()): + match fid: + case 1: obj["domain"] = self.reader.read_string() + case 2: obj["version"] = self.reader.read_int64() + case _: self.reader.skip_field(wire_type) return obj # ***** python const *****