mirror of
https://github.com/firestar5683/StarPilot.git
synced 2026-06-27 17:42:04 +08:00
load model before calling convert_fp16_to_fp32
This commit is contained in:
@@ -10,8 +10,7 @@ def attributeproto_fp16_to_fp32(attr):
|
||||
attr.data_type = 1
|
||||
attr.raw_data = float32_list.astype(np.float32).tobytes()
|
||||
|
||||
def convert_fp16_to_fp32(onnx_path):
|
||||
model = onnx.load(onnx_path)
|
||||
def convert_fp16_to_fp32(model):
|
||||
for i in model.graph.initializer:
|
||||
if i.data_type == 10:
|
||||
attributeproto_fp16_to_fp32(i)
|
||||
@@ -33,6 +32,6 @@ def make_onnx_cpu_runner(model_path):
|
||||
options.intra_op_num_threads = 4
|
||||
options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
||||
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
model_data = convert_fp16_to_fp32(model_path)
|
||||
model_data = convert_fp16_to_fp32(onnx.load(model_path))
|
||||
return ort.InferenceSession(model_data, options, providers=['CPUExecutionProvider'])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user