# Adjust the paths and settings below as needed (可根据具体需要修改).
import glob
import os
import shutil
import time
import cv2
import onnxruntime
import numpy as np
import tqdm
# Load the ONNX model once; the session is reused for every test image.
onnx_file_name = "./save_weights/best_model.onnx"
session = onnxruntime.InferenceSession(onnx_file_name)
# Input/output names do not change between images, so look them up once
# instead of on every loop iteration.
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

start_time = time.time()
paths = glob.glob("./UNet/Dataset/test/*.*")
result_path = './onnx'
# Recreate the output directory so results from a previous run are not mixed in.
if os.path.exists(result_path):
    shutil.rmtree(result_path)
os.mkdir(result_path)
image_size = 224
# Gray level written for each predicted class id. Class 0, and any id not
# listed here, keeps its raw id value in the saved mask.
label_to_gray = {1: 129, 2: 192, 3: 255}

processed = 0  # number of images actually run through the model
for path in tqdm.tqdm(paths):
    # NOTE(review): skips paths containing 'implant' but not 'image' --
    # presumably label/mask files; confirm this is the intended filter.
    if 'image' not in path and 'implant' in path:
        continue
    origin_image = cv2.imread(path)
    if origin_image is None:
        # cv2.imread returns None (it does not raise) for files it cannot
        # decode, and the "*.*" glob may match non-image files.
        tqdm.tqdm.write("skipping unreadable file: {}".format(path))
        continue
    original_height, original_width = origin_image.shape[:2]

    # Preprocess: BGR -> RGB, resize to the model input size, scale pixels
    # from [0, 255] to [-1, 1], then add a leading batch axis.
    image = cv2.cvtColor(origin_image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (image_size, image_size), interpolation=cv2.INTER_CUBIC)
    image = image.astype(np.float32) / 127.5 - 1
    image = np.expand_dims(image, 0)

    result = session.run([output_name], {input_name: image})

    # Per-pixel class id: argmax over axis 3, then drop the batch axis.
    # Assumes the output is (1, H, W, num_classes) -- TODO confirm against the model.
    labels = result[0].argmax(3).squeeze(0)
    prediction = labels.copy()
    # Build every mask from the untouched labels so a value written by one
    # remapping can never be re-matched by a later one.
    for label, gray in label_to_gray.items():
        prediction[labels == label] = gray
    prediction = prediction.astype(np.uint8)
    # Nearest-neighbour resizing does not interpolate, so the mask keeps only
    # the discrete gray levels written above.
    predict_result = cv2.resize(prediction, (original_width, original_height), interpolation=cv2.INTER_NEAREST)

    stem = os.path.splitext(os.path.basename(path))[0]
    cv2.imwrite(os.path.join(result_path, stem + "_origin.png"), origin_image)
    cv2.imwrite(os.path.join(result_path, stem + "_predict.png"), predict_result)
    processed += 1

total_time = time.time() - start_time
# Throughput over the images actually processed; guard against a zero interval.
fps = processed / total_time if total_time > 0 else 0.0
print("time {}s, fps {}".format(total_time, fps))