原始的YoloV8封装的层次太高,想要为我们所用可能需要阅读很多API,下面给出比较简单的使用方式
os
:用于操作文件系统。cv2
(OpenCV):用于图像处理。numpy
:提供数学运算,特别是对数组的操作。ultralytics.YOLO
:一个现成的YOLO模型实现,用于对象检测。torch
:PyTorch深度学习框架,用于处理模型。serial
:用于串口通信。time
:用于时间相关的操作。init_serial
init_serial
函数用于初始化串口通信。try-except
结构来捕获异常。load_model
load_model
函数用于加载YOLO模型。YOLO
类从给定路径加载模型,并将其传输到指定的计算设备(CPU或GPU)。process_images
process_images
函数负责处理指定路径下的图像。.jpg
图像。q
退出循环。main
process_images
函数处理图像。if __name__ == "__main__":
确保在直接运行脚本时执行主函数。import os
import cv2
import numpy as np
from ultralytics import YOLO
import torch
import serial
import time
def init_serial(port, baudrate, bytesize, parity, stopbits):
try:
ser = serial.Serial(port, baudrate, bytesize, parity, stopbits)
time.sleep(1) # 等待串口初始化
return ser
except serial.SerialException as e:
print(f"Error opening serial port: {e}")
exit()
def load_model(weights_path, device):
if not os.path.exists(weights_path):
print("Model weights not found!")
exit()
model = YOLO(weights_path).to(device)
model.fuse()
model.info(verbose=False)
return model
def process_images(path, model, serial_port):
if not os.path.exists(path):
print(f"Path {path} does not exist!")
exit()
for img_file in os.listdir(path):
if not img_file.endswith(".jpg"):
continue
img_path = os.path.join(path, img_file)
img = cv2.imread(img_path)
if img is None:
print(f"Failed to load image {img_path}")
continue
mask = img.copy()
result = model(img)
cls, xywh = result[0].boxes.cls, result[0].boxes.xywh
cls_, xywh_ = cls.detach().cpu().numpy(), xywh.detach().cpu().numpy()
for pos, cls_value in zip(xywh_, cls_):
pt1, pt2 = (np.int_([pos[0] - pos[2] / 2, pos[1] - pos[3] / 2]),
np.int_([pos[0] + pos[2] / 2, pos[1] + pos[3] / 2]))
color = [0, 0, 255] if cls_value == 0 else [0, 255, 0]
cv2.rectangle(mask, tuple(pt1), tuple(pt2), color, 2)
res_ = "Yes" if np.any(cls_ == 1) else "No"
print(res_)
serial_port.write((res_ + "\r\n").encode())
cv2.imshow("result", mask)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cv2.destroyAllWindows()
def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
serial_port = init_serial("/dev/ttyTHS1", 115200, serial.EIGHTBITS, serial.PARITY_NONE, serial.STOPBITS_ONE)
model = load_model("./weights/best.pt", device)
process_images("./datasets/pre/", model, serial_port)
if __name__ == "__main__":
main()