mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-10 15:06:18 +08:00
onnx: deduplicate simple proto parsers (#16085)
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
2
test/external/external_test_onnx_runner.py
vendored
2
test/external/external_test_onnx_runner.py
vendored
@@ -144,7 +144,7 @@ class MetadataOnnxPBParser(OnnxPBParser):
|
||||
for fid, wire_type in self._parse_message(self.reader.len):
|
||||
match fid:
|
||||
case 7: obj["graph"] = self._parse_GraphProto()
|
||||
case 14: obj["metadata_props"].append(self._parse_StringStringEntryProto())
|
||||
case 14: obj["metadata_props"].append(self._parse_proto(self._SIMPLE_PROTOS["StringStringEntryProto"]))
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
return obj
|
||||
|
||||
|
||||
@@ -160,7 +160,7 @@ class OnnxPBParser:
|
||||
case 4: obj["domain"] = self.reader.read_string()
|
||||
case 5: obj["model_version"] = self.reader.read_int64()
|
||||
case 7: obj["graph"] = self._parse_GraphProto()
|
||||
case 8: obj["opset_import"].append(self._parse_OperatorSetIdProto())
|
||||
case 8: obj["opset_import"].append(self._parse_proto(self._SIMPLE_PROTOS["OperatorSetIdProto"]))
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
|
||||
# update opset version
|
||||
@@ -214,7 +214,7 @@ class OnnxPBParser:
|
||||
case 9: obj["raw_data"] = self.reader.read_bytes()
|
||||
case 10: obj["double_data"] = self.reader.read_packed_floats()
|
||||
case 11: obj["uint64_data"] = self.reader.read_packed_int64s()
|
||||
case 13: obj.setdefault("external_data", []).append(self._parse_StringStringEntryProto())
|
||||
case 13: obj.setdefault("external_data", []).append(self._parse_proto(self._SIMPLE_PROTOS["StringStringEntryProto"]))
|
||||
case 14: obj["data_location"] = self.reader.read_int64()
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
|
||||
@@ -281,7 +281,7 @@ class OnnxPBParser:
|
||||
for fid, wire_type in self._parse_message(self._decode_end_pos()):
|
||||
match fid:
|
||||
case 1: obj["name"] = self.reader.read_string()
|
||||
case 2: obj["type"] = self._parse_TypeProto()
|
||||
case 2: obj["type"] = self._parse_proto(self._SIMPLE_PROTOS["TypeProto"])
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
|
||||
# parse type
|
||||
@@ -295,66 +295,26 @@ class OnnxPBParser:
|
||||
OnnxDataType(type_obj['tensor_type']['elem_type']).to_dtype(), is_optional, is_sequence)
|
||||
return obj
|
||||
|
||||
def _parse_TypeProto(self) -> dict:
|
||||
_SIMPLE_PROTOS: dict[str, dict[int, tuple[str, str]]] = {
|
||||
"TypeProto": {1: ("tensor_type", "TypeProtoTensor"), 4: ("sequence_type", "TypeProtoWrapper"),
|
||||
9: ("optional_type", "TypeProtoWrapper")},
|
||||
"TypeProtoTensor": {1: ("elem_type", "read_int64"), 2: ("shape", "TensorShapeProto")},
|
||||
"TypeProtoWrapper": {1: ("elem_type", "TypeProto")},
|
||||
"TensorShapeProto": {1: ("+dim", "TensorShapeProtoDimension")},
|
||||
"TensorShapeProtoDimension": {1: ("dim_value", "read_int64"), 2: ("dim_param", "read_string")},
|
||||
"StringStringEntryProto": {1: ("key", "read_string"), 2: ("value", "read_string")},
|
||||
"OperatorSetIdProto": {1: ("domain", "read_string"), 2: ("version", "read_int64")},
|
||||
}
|
||||
def _parse_proto(self, fields: dict[int, tuple[str, str]]) -> dict:
|
||||
obj: dict[str, Any] = {}
|
||||
for fid, wire_type in self._parse_message(self._decode_end_pos()):
|
||||
match fid:
|
||||
case 1: obj["tensor_type"] = self._parse_TypeProtoTensor()
|
||||
case 4: obj["sequence_type"] = self._parse_TypeProtoWrapper()
|
||||
case 9: obj["optional_type"] = self._parse_TypeProtoWrapper()
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
return obj
|
||||
|
||||
def _parse_TypeProtoTensor(self) -> dict:
|
||||
obj: dict[str, Any] = {}
|
||||
for fid, wire_type in self._parse_message(self._decode_end_pos()):
|
||||
match fid:
|
||||
case 1: obj["elem_type"] = self.reader.read_int64()
|
||||
case 2: obj["shape"] = self._parse_TensorShapeProto()
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
return obj
|
||||
|
||||
def _parse_TypeProtoWrapper(self) -> dict:
|
||||
obj = {}
|
||||
for fid, wire_type in self._parse_message(self._decode_end_pos()):
|
||||
match fid:
|
||||
case 1: obj["elem_type"] = self._parse_TypeProto()
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
return obj
|
||||
|
||||
def _parse_TensorShapeProto(self) -> dict:
|
||||
obj: dict[str, Any] = {"dim": []}
|
||||
for fid, wire_type in self._parse_message(self._decode_end_pos()):
|
||||
match fid:
|
||||
case 1: obj["dim"].append(self._parse_TensorShapeProtoDimension())
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
return obj
|
||||
|
||||
def _parse_TensorShapeProtoDimension(self) -> dict:
|
||||
obj: dict[str, Any] = {}
|
||||
for fid, wire_type in self._parse_message(self._decode_end_pos()):
|
||||
match fid:
|
||||
case 1: obj["dim_value"] = self.reader.read_int64()
|
||||
case 2: obj["dim_param"] = self.reader.read_string()
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
return obj
|
||||
|
||||
def _parse_StringStringEntryProto(self) -> dict:
|
||||
obj: dict[str, Any] = {}
|
||||
for fid, wire_type in self._parse_message(self._decode_end_pos()):
|
||||
match fid:
|
||||
case 1: obj["key"] = self.reader.read_string()
|
||||
case 2: obj["value"] = self.reader.read_string()
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
return obj
|
||||
|
||||
def _parse_OperatorSetIdProto(self) -> dict:
|
||||
obj: dict[str, Any] = {}
|
||||
for fid, wire_type in self._parse_message(self._decode_end_pos()):
|
||||
match fid:
|
||||
case 1: obj["domain"] = self.reader.read_string()
|
||||
case 2: obj["version"] = self.reader.read_int64()
|
||||
case _: self.reader.skip_field(wire_type)
|
||||
if fid not in fields:
|
||||
self.reader.skip_field(wire_type)
|
||||
continue
|
||||
name, action = fields[fid]
|
||||
value = self._parse_proto(self._SIMPLE_PROTOS[action]) if action in self._SIMPLE_PROTOS else getattr(self.reader, action)()
|
||||
if name[0] == "+": obj.setdefault(name[1:], []).append(value)
|
||||
else: obj[name] = value
|
||||
return obj
|
||||
|
||||
# ***** python const *****
|
||||
|
||||
Reference in New Issue
Block a user