mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
minor onnx cleanups (#11642)
This commit is contained in:
@@ -74,7 +74,7 @@ class OnnxNode:
|
||||
# ***** protobuf parsing ******
|
||||
class PBBufferedReader(BufferedReader):
|
||||
def __init__(self, tensor: Tensor):
|
||||
assert tensor.dtype is dtypes.uint8, tensor
|
||||
assert tensor.dtype == dtypes.uint8, tensor
|
||||
super().__init__(TensorIO(tensor))
|
||||
self.len = tensor.nbytes()
|
||||
|
||||
@@ -109,9 +109,8 @@ class PBBufferedReader(BufferedReader):
|
||||
total_bytes_len = self.decode_varint()
|
||||
old_pos = self.tell()
|
||||
values = []
|
||||
while self.tell() < total_bytes_len + old_pos:
|
||||
val = self.decode_varint() # need copy here because packed ints are varint
|
||||
values.append(val - 2**64 if val & (1 << 63) else val)
|
||||
# need copy here because packed ints are varint
|
||||
while self.tell() < total_bytes_len + old_pos: values.append(self.read_int64())
|
||||
return values
|
||||
|
||||
def skip_field(self, wire_type: WireType) -> None:
|
||||
@@ -223,8 +222,8 @@ class OnnxPBParser:
|
||||
location, length, offset = None, None, 0
|
||||
for kv in obj["external_data"]:
|
||||
if kv["key"] == "location": location = kv["value"]
|
||||
if kv["key"] == "offset": offset = int(kv["value"])
|
||||
if kv["key"] == "length": length = int(kv["value"])
|
||||
elif kv["key"] == "offset": offset = int(kv["value"])
|
||||
elif kv["key"] == "length": length = int(kv["value"])
|
||||
if location is None: raise ValueError("no location in external_data")
|
||||
|
||||
if self.file_path is None:
|
||||
@@ -247,12 +246,12 @@ class OnnxPBParser:
|
||||
if not isinstance(data, Tensor):
|
||||
obj["parsed_tensor"] = Tensor(data, dtype=to_dtype).reshape(shape)
|
||||
return obj
|
||||
assert isinstance(data, Tensor) and data.dtype is dtypes.uint8, data
|
||||
assert isinstance(data, Tensor) and data.dtype == dtypes.uint8, data
|
||||
data = data.bitcast(true_dtype).reshape(shape)
|
||||
data = data.to(Device.DEFAULT) if true_dtype is to_dtype else data.to("cpu").cast(to_dtype).to(Device.DEFAULT)
|
||||
# const folding
|
||||
if shape == ():
|
||||
if data.dtype is dtypes.float16 and sys.version_info < (3, 12): data = data.cast(dtypes.float32)
|
||||
if data.dtype == dtypes.float16 and sys.version_info < (3, 12): data = data.cast(dtypes.float32)
|
||||
data = Tensor(data.item(), dtype=to_dtype).reshape(shape)
|
||||
obj["parsed_tensor"] = data
|
||||
return obj
|
||||
@@ -375,7 +374,7 @@ required_input_python_consts: dict[str, tuple[int, ...]] = {
|
||||
cache_misses = 0
|
||||
@functools.cache
|
||||
def _cached_to_python_const(t:Tensor):
|
||||
if t.dtype is dtypes.uint8: return t.data().tobytes()
|
||||
if t.dtype == dtypes.uint8: return t.data().tobytes()
|
||||
if 0 in t.shape: return []
|
||||
return t.tolist()
|
||||
|
||||
@@ -551,7 +550,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
|
||||
if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
|
||||
if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
|
||||
if value_string is not None or value_strings is not None and sparse_value is not None:
|
||||
if value_string is not None or value_strings is not None or sparse_value is not None:
|
||||
raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value')
|
||||
|
||||
def Range(start:float|int|list[float|int], limit:float|int|list[float|int], delta:float|int|list[float|int]):
|
||||
@@ -615,9 +614,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
def BitwiseOr(x:Tensor,y:Tensor): return x | y
|
||||
def BitwiseXor(x:Tensor,y:Tensor): return x ^ y
|
||||
def BitwiseNot(x:Tensor): return ~x
|
||||
def Mod(x:Tensor,y:Tensor,fmod=0):
|
||||
if fmod: return x - x.div(y, rounding_mode="trunc") * y
|
||||
return x % y
|
||||
def Mod(x:Tensor,y:Tensor,fmod=0): return x - x.div(y, rounding_mode="trunc") * y if fmod else x % y
|
||||
|
||||
# ***** Casting Ops *****
|
||||
# TODO: saturate
|
||||
@@ -1139,7 +1136,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
if axis is None:
|
||||
inp = inp.flatten()
|
||||
axis = 0
|
||||
if axis < 0: axis += inp.ndim
|
||||
axis = inp._resolve_dim(axis)
|
||||
con = Tensor([i for i,cond in enumerate(condition) if cond]) # compress in python
|
||||
return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user