ONNX分割模型推理

发布时间：2024年01月18日

以下脚本中的模型路径、测试图片目录、输出目录和输入尺寸等参数可根据具体需要修改。

import glob
import os
import shutil
import time
import cv2

import onnxruntime
import numpy as np
import tqdm

# Run an exported ONNX segmentation model over every file in the test
# directory. For each input, write a copy of the original image and a
# grayscale prediction mask resized back to the original resolution.

# Load the ONNX model.
onnx_file_name = "./save_weights/best_model.onnx"
session = onnxruntime.InferenceSession(onnx_file_name)
# Only the first input and first output are used; look their names up once
# instead of on every loop iteration.
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

start_time = time.time()
paths = glob.glob("./UNet/Dataset/test/*.*")

result_path = './onnx'

# Start every run from an empty output directory.
if os.path.exists(result_path):
    shutil.rmtree(result_path)
os.makedirs(result_path)

image_size = 224
# Number of images actually run through the model; used for the FPS figure.
processed_count = 0
for path in tqdm.tqdm(paths):
    # Skip files whose path contains 'implant' but not 'image'.
    # NOTE(review): presumably these are label masks stored alongside the
    # test images -- confirm against the dataset's naming scheme.
    if 'image' not in path and 'implant' in path:
        continue

    origin_image = cv2.imread(path)
    # cv2.imread returns None (it does not raise) for unreadable or non-image
    # files, which the "*.*" glob can match; skip them instead of crashing.
    if origin_image is None:
        tqdm.tqdm.write("skipping unreadable file: {}".format(path))
        continue
    original_height, original_width, _ = origin_image.shape

    # Preprocess: BGR -> RGB, resize to the network size, scale pixel values
    # to [-1, 1], then add a leading batch axis, giving a (1, H, W, 3)
    # float32 array.
    image = cv2.cvtColor(origin_image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (image_size, image_size), interpolation=cv2.INTER_CUBIC)
    image = image.astype(np.float32) / 127.5 - 1
    image = np.expand_dims(image, 0)

    result = session.run([output_name], {input_name: image})

    # Per-pixel class index: argmax over axis 3 (class scores are last),
    # then drop the batch axis.
    output_image = result[0].argmax(3).squeeze(0)
    # Map class ids 1..3 to distinct gray levels so the mask is visible when
    # viewed as an image; class 0 stays 0.
    output_image[output_image == 1] = 129
    output_image[output_image == 2] = 192
    output_image[output_image == 3] = 255
    prediction = output_image.astype(np.uint8)

    # Nearest-neighbour resizing keeps the mask values discrete.
    predict_result = cv2.resize(prediction, (original_width, original_height), interpolation=cv2.INTER_NEAREST)
    stem = os.path.splitext(os.path.basename(path))[0]
    cv2.imwrite(os.path.join(result_path, stem + "_origin.png"), origin_image)
    cv2.imwrite(os.path.join(result_path, stem + "_predict.png"), predict_result)
    processed_count += 1

total_time = time.time() - start_time
# Throughput is based on the number of images actually processed, not a
# fixed count.
fps = processed_count / total_time if total_time > 0 else 0.0
print("time {}s, fps {}".format(total_time, fps))

文章来源：https://blog.csdn.net/qq_45152711/article/details/135668836
本文来自互联网用户投稿，该文观点仅代表作者本人，不代表本站立场。本站仅提供信息存储空间服务，不拥有所有权，不承担相关法律责任。