From 3fb79bb43aa86c098777ea7ffedacfbafeec9a42 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 12 Aug 2025 22:05:19 -0700 Subject: [PATCH] minor onnx cleanups (#11642) --- extra/onnx.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/extra/onnx.py b/extra/onnx.py index f7e56d4579..da623743b0 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -74,7 +74,7 @@ class OnnxNode: # ***** protobuf parsing ****** class PBBufferedReader(BufferedReader): def __init__(self, tensor: Tensor): - assert tensor.dtype is dtypes.uint8, tensor + assert tensor.dtype == dtypes.uint8, tensor super().__init__(TensorIO(tensor)) self.len = tensor.nbytes() @@ -109,9 +109,8 @@ class PBBufferedReader(BufferedReader): total_bytes_len = self.decode_varint() old_pos = self.tell() values = [] - while self.tell() < total_bytes_len + old_pos: - val = self.decode_varint() # need copy here because packed ints are varint - values.append(val - 2**64 if val & (1 << 63) else val) + # need copy here because packed ints are varint + while self.tell() < total_bytes_len + old_pos: values.append(self.read_int64()) return values def skip_field(self, wire_type: WireType) -> None: @@ -223,8 +222,8 @@ class OnnxPBParser: location, length, offset = None, None, 0 for kv in obj["external_data"]: if kv["key"] == "location": location = kv["value"] - if kv["key"] == "offset": offset = int(kv["value"]) - if kv["key"] == "length": length = int(kv["value"]) + elif kv["key"] == "offset": offset = int(kv["value"]) + elif kv["key"] == "length": length = int(kv["value"]) if location is None: raise ValueError("no location in external_data") if self.file_path is None: @@ -247,12 +246,12 @@ class OnnxPBParser: if not isinstance(data, Tensor): obj["parsed_tensor"] = Tensor(data, dtype=to_dtype).reshape(shape) return obj - assert isinstance(data, Tensor) and data.dtype is dtypes.uint8, data + assert isinstance(data, Tensor) and data.dtype == dtypes.uint8, data data = data.bitcast(true_dtype).reshape(shape) data = data.to(Device.DEFAULT) if true_dtype is to_dtype else data.to("cpu").cast(to_dtype).to(Device.DEFAULT) # const folding if shape == (): - if data.dtype is dtypes.float16 and sys.version_info < (3, 12): data = data.cast(dtypes.float32) + if data.dtype == dtypes.float16 and sys.version_info < (3, 12): data = data.cast(dtypes.float32) data = Tensor(data.item(), dtype=to_dtype).reshape(shape) obj["parsed_tensor"] = data return obj @@ -375,7 +374,7 @@ required_input_python_consts: dict[str, tuple[int, ...]] = { cache_misses = 0 @functools.cache def _cached_to_python_const(t:Tensor): - if t.dtype is dtypes.uint8: return t.data().tobytes() + if t.dtype == dtypes.uint8: return t.data().tobytes() if 0 in t.shape: return [] return t.tolist() @@ -551,7 +550,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False) if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False) if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False) - if value_string is not None or value_strings is not None and sparse_value is not None: + if value_string is not None or value_strings is not None or sparse_value is not None: raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value') def Range(start:float|int|list[float|int], limit:float|int|list[float|int], delta:float|int|list[float|int]): @@ -615,9 +614,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT def BitwiseOr(x:Tensor,y:Tensor): return x | y def BitwiseXor(x:Tensor,y:Tensor): return x ^ y def BitwiseNot(x:Tensor): return ~x - def Mod(x:Tensor,y:Tensor,fmod=0): - if fmod: return x - x.div(y, rounding_mode="trunc") * y - return x % y + def Mod(x:Tensor,y:Tensor,fmod=0): return x - x.div(y, rounding_mode="trunc") * y if fmod else x % y # ***** Casting Ops ***** # TODO: saturate @@ -1139,7 +1136,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT if axis is None: inp = inp.flatten() axis = 0 - if axis < 0: axis += inp.ndim + axis = inp._resolve_dim(axis) con = Tensor([i for i,cond in enumerate(condition) if cond]) # compress in python return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]