onnx: deduplicate simple proto parsers (#16085)

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
June
2026-05-07 18:44:27 -07:00
committed by GitHub
parent c106c73e51
commit 83eaefcd0f
2 changed files with 22 additions and 62 deletions

View File

@@ -144,7 +144,7 @@ class MetadataOnnxPBParser(OnnxPBParser):
for fid, wire_type in self._parse_message(self.reader.len):
match fid:
case 7: obj["graph"] = self._parse_GraphProto()
case 14: obj["metadata_props"].append(self._parse_StringStringEntryProto())
case 14: obj["metadata_props"].append(self._parse_proto(self._SIMPLE_PROTOS["StringStringEntryProto"]))
case _: self.reader.skip_field(wire_type)
return obj

View File

@@ -160,7 +160,7 @@ class OnnxPBParser:
case 4: obj["domain"] = self.reader.read_string()
case 5: obj["model_version"] = self.reader.read_int64()
case 7: obj["graph"] = self._parse_GraphProto()
case 8: obj["opset_import"].append(self._parse_OperatorSetIdProto())
case 8: obj["opset_import"].append(self._parse_proto(self._SIMPLE_PROTOS["OperatorSetIdProto"]))
case _: self.reader.skip_field(wire_type)
# update opset version
@@ -214,7 +214,7 @@ class OnnxPBParser:
case 9: obj["raw_data"] = self.reader.read_bytes()
case 10: obj["double_data"] = self.reader.read_packed_floats()
case 11: obj["uint64_data"] = self.reader.read_packed_int64s()
case 13: obj.setdefault("external_data", []).append(self._parse_StringStringEntryProto())
case 13: obj.setdefault("external_data", []).append(self._parse_proto(self._SIMPLE_PROTOS["StringStringEntryProto"]))
case 14: obj["data_location"] = self.reader.read_int64()
case _: self.reader.skip_field(wire_type)
@@ -281,7 +281,7 @@ class OnnxPBParser:
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: obj["name"] = self.reader.read_string()
case 2: obj["type"] = self._parse_TypeProto()
case 2: obj["type"] = self._parse_proto(self._SIMPLE_PROTOS["TypeProto"])
case _: self.reader.skip_field(wire_type)
# parse type
@@ -295,66 +295,26 @@ class OnnxPBParser:
OnnxDataType(type_obj['tensor_type']['elem_type']).to_dtype(), is_optional, is_sequence)
return obj
def _parse_TypeProto(self) -> dict:
# Declarative field tables for the messages handled by the generic _parse_proto.
# Layout: message name -> {protobuf field id: (output key, action)}.
#   * action: when it is itself a key of this table, the field is parsed as a nested
#     message via _parse_proto with that table; otherwise it is the name of a
#     zero-argument method looked up on self.reader (e.g. "read_int64", "read_string").
#   * output key: a leading "+" marks a field whose values are appended to a list stored
#     under the key without the "+"; any other key is assigned directly, so a later
#     occurrence of the same field id replaces the earlier value.
#   * field ids missing from a message's table are skipped with reader.skip_field.
# NOTE(review): a "+" list is only created when the first value is appended, so a message
# with no occurrences of that field has no such key in the result — confirm that callers
# use .get(...) rather than indexing the key directly.
_SIMPLE_PROTOS: dict[str, dict[int, tuple[str, str]]] = {
"TypeProto": {1: ("tensor_type", "TypeProtoTensor"), 4: ("sequence_type", "TypeProtoWrapper"),
9: ("optional_type", "TypeProtoWrapper")},
"TypeProtoTensor": {1: ("elem_type", "read_int64"), 2: ("shape", "TensorShapeProto")},
"TypeProtoWrapper": {1: ("elem_type", "TypeProto")},
"TensorShapeProto": {1: ("+dim", "TensorShapeProtoDimension")},
"TensorShapeProtoDimension": {1: ("dim_value", "read_int64"), 2: ("dim_param", "read_string")},
"StringStringEntryProto": {1: ("key", "read_string"), 2: ("value", "read_string")},
"OperatorSetIdProto": {1: ("domain", "read_string"), 2: ("version", "read_int64")},
}
def _parse_proto(self, fields: dict[int, tuple[str, str]]) -> dict:
obj: dict[str, Any] = {}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: obj["tensor_type"] = self._parse_TypeProtoTensor()
case 4: obj["sequence_type"] = self._parse_TypeProtoWrapper()
case 9: obj["optional_type"] = self._parse_TypeProtoWrapper()
case _: self.reader.skip_field(wire_type)
return obj
def _parse_TypeProtoTensor(self) -> dict:
obj: dict[str, Any] = {}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: obj["elem_type"] = self.reader.read_int64()
case 2: obj["shape"] = self._parse_TensorShapeProto()
case _: self.reader.skip_field(wire_type)
return obj
def _parse_TypeProtoWrapper(self) -> dict:
obj = {}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: obj["elem_type"] = self._parse_TypeProto()
case _: self.reader.skip_field(wire_type)
return obj
def _parse_TensorShapeProto(self) -> dict:
obj: dict[str, Any] = {"dim": []}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: obj["dim"].append(self._parse_TensorShapeProtoDimension())
case _: self.reader.skip_field(wire_type)
return obj
def _parse_TensorShapeProtoDimension(self) -> dict:
obj: dict[str, Any] = {}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: obj["dim_value"] = self.reader.read_int64()
case 2: obj["dim_param"] = self.reader.read_string()
case _: self.reader.skip_field(wire_type)
return obj
def _parse_StringStringEntryProto(self) -> dict:
obj: dict[str, Any] = {}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: obj["key"] = self.reader.read_string()
case 2: obj["value"] = self.reader.read_string()
case _: self.reader.skip_field(wire_type)
return obj
def _parse_OperatorSetIdProto(self) -> dict:
obj: dict[str, Any] = {}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: obj["domain"] = self.reader.read_string()
case 2: obj["version"] = self.reader.read_int64()
case _: self.reader.skip_field(wire_type)
if fid not in fields:
self.reader.skip_field(wire_type)
continue
name, action = fields[fid]
value = self._parse_proto(self._SIMPLE_PROTOS[action]) if action in self._SIMPLE_PROTOS else getattr(self.reader, action)()
if name[0] == "+": obj.setdefault(name[1:], []).append(value)
else: obj[name] = value
return obj
# ***** python const *****