minor onnx cleanups (#11642)

This commit is contained in:
chenyu
2025-08-12 22:05:19 -07:00
committed by GitHub
parent e9e5a08a04
commit 3fb79bb43a

View File

@@ -74,7 +74,7 @@ class OnnxNode:
# ***** protobuf parsing ******
class PBBufferedReader(BufferedReader):
def __init__(self, tensor: Tensor):
assert tensor.dtype is dtypes.uint8, tensor
assert tensor.dtype == dtypes.uint8, tensor
super().__init__(TensorIO(tensor))
self.len = tensor.nbytes()
@@ -109,9 +109,8 @@ class PBBufferedReader(BufferedReader):
total_bytes_len = self.decode_varint()
old_pos = self.tell()
values = []
while self.tell() < total_bytes_len + old_pos:
val = self.decode_varint() # need copy here because packed ints are varint
values.append(val - 2**64 if val & (1 << 63) else val)
# need copy here because packed ints are varint
while self.tell() < total_bytes_len + old_pos: values.append(self.read_int64())
return values
def skip_field(self, wire_type: WireType) -> None:
@@ -223,8 +222,8 @@ class OnnxPBParser:
location, length, offset = None, None, 0
for kv in obj["external_data"]:
if kv["key"] == "location": location = kv["value"]
if kv["key"] == "offset": offset = int(kv["value"])
if kv["key"] == "length": length = int(kv["value"])
elif kv["key"] == "offset": offset = int(kv["value"])
elif kv["key"] == "length": length = int(kv["value"])
if location is None: raise ValueError("no location in external_data")
if self.file_path is None:
@@ -247,12 +246,12 @@ class OnnxPBParser:
if not isinstance(data, Tensor):
obj["parsed_tensor"] = Tensor(data, dtype=to_dtype).reshape(shape)
return obj
assert isinstance(data, Tensor) and data.dtype is dtypes.uint8, data
assert isinstance(data, Tensor) and data.dtype == dtypes.uint8, data
data = data.bitcast(true_dtype).reshape(shape)
data = data.to(Device.DEFAULT) if true_dtype is to_dtype else data.to("cpu").cast(to_dtype).to(Device.DEFAULT)
# const folding
if shape == ():
if data.dtype is dtypes.float16 and sys.version_info < (3, 12): data = data.cast(dtypes.float32)
if data.dtype == dtypes.float16 and sys.version_info < (3, 12): data = data.cast(dtypes.float32)
data = Tensor(data.item(), dtype=to_dtype).reshape(shape)
obj["parsed_tensor"] = data
return obj
@@ -375,7 +374,7 @@ required_input_python_consts: dict[str, tuple[int, ...]] = {
cache_misses = 0
# Memoized (functools.cache) conversion of a tensor into a plain Python value:
#   - a uint8 tensor is returned as t.data().tobytes()
#   - a tensor with a 0 anywhere in its shape is returned as an empty list
#   - anything else is returned as t.tolist()
# NOTE(review): the two uint8 checks below differ only in `is` vs `==`; in this diff rendering they
# look like the pre-change and post-change forms of a single line — confirm against the real file.
@functools.cache
def _cached_to_python_const(t:Tensor):
if t.dtype is dtypes.uint8: return t.data().tobytes()
if t.dtype == dtypes.uint8: return t.data().tobytes()
if 0 in t.shape: return []
return t.tolist()
@@ -551,7 +550,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
if value_string is not None or value_strings is not None and sparse_value is not None:
if value_string is not None or value_strings is not None or sparse_value is not None:
raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value')
def Range(start:float|int|list[float|int], limit:float|int|list[float|int], delta:float|int|list[float|int]):
@@ -615,9 +614,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
# ONNX BitwiseOr: returns the result of applying the `|` operator to x and y.
def BitwiseOr(x:Tensor,y:Tensor): return x | y
# ONNX BitwiseXor: returns the result of applying the `^` operator to x and y.
def BitwiseXor(x:Tensor,y:Tensor): return x ^ y
# ONNX BitwiseNot: returns the result of applying the unary `~` operator to x.
def BitwiseNot(x:Tensor): return ~x
# ONNX Mod. When fmod is truthy the result is x - x.div(y, rounding_mode="trunc") * y, i.e. the
# remainder left after a truncating division (presumably rounding toward zero — confirm against
# Tensor.div); otherwise the result is x % y.
# NOTE(review): two definitions follow — a three-line form and a one-line conditional expression
# that evaluates the same two branches. They look like the before/after of this hunk; confirm that
# only one of them exists in the real file.
def Mod(x:Tensor,y:Tensor,fmod=0):
if fmod: return x - x.div(y, rounding_mode="trunc") * y
return x % y
def Mod(x:Tensor,y:Tensor,fmod=0): return x - x.div(y, rounding_mode="trunc") * y if fmod else x % y
# ***** Casting Ops *****
# TODO: saturate
@@ -1139,7 +1136,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
if axis is None:
inp = inp.flatten()
axis = 0
if axis < 0: axis += inp.ndim
axis = inp._resolve_dim(axis)
con = Tensor([i for i,cond in enumerate(condition) if cond]) # compress in python
return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]