pytorch 语义分割前后处理代码详解

发布时间:2024年01月19日

1. 语义分割模型预测流程

  • (1) 读入图片
  • (2) Letterbox处理
  • (3) 归一化、HWC 转 CHW,并expand维度到NCHW
  • (4) 前向推理
  • (5) softmax 计算像素类别概率
  • (6) 截取灰条部分,并resize到原图尺寸
  • (7) 利用argmax,计算像素类别
  • (8) 分割效果可视化输出

2. 代码详解

2.1 输入图片预处理

2.1.1 读取图片,并转为RGB
img_path = "test.png"
image = Image.open(img_path)
image = cvtColor(image)  

其中cvtColor的函数实现如下:

#---------------------------------------------------------#
#   将图像转换成RGB图像,防止灰度图在预测时报错。
#   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
#---------------------------------------------------------#
def cvtColor(image):
    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
        # print("bgr------------------------------")
        return image 
    else:
        # print("rgb-------------------------------")
        image = image.convert('RGB')
        return image 

对原始备份,用于可视化绘图

old_img     = copy.deepcopy(image)
orininal_h  = np.array(image).shape[0]
orininal_w  = np.array(image).shape[1]
2.1.2 Letterbox 处理

letterbox的作用是: 给图像增加灰条,实现不失真的resize,经过letterbox后获得指定input_size大小的不失真图

 image_data, nw, nh  = resize_image(image, (self.input_shape[1],self.input_shape[0]))

self.input_shape: 模型输入指定的图片大小, 其中通过resize_image实现letterbox功能, resize_image代码如下:

def resize_image(image, size):
    iw, ih  = image.size    # image size
    w, h    = size          # input_size

    scale   = min(w/iw, h/ih)
    nw      = int(iw*scale)
    nh      = int(ih*scale)

    image   = image.resize((nw,nh), Image.BICUBIC)
    new_image = Image.new('RGB', size, (128,128,128))
    new_image.paste(image, ((w-nw)//2, (h-nh)//2))

    return new_image, nw, nh
  • (1) 以原图长边为准,将长边缩放到input_size大小;对应的短边, 进行等比缩放
scale   = min(w/iw, h/ih)
nw      = int(iw*scale)
nh      = int(ih*scale)
image   = image.resize((nw,nh), Image.BICUBIC)

nw,nh为原图经过缩放的大小

  • (2) 对缩放的图片居中对齐上下左右填充
new_image = Image.new('RGB', size, (128,128,128))
new_image.paste(image, ((w-nw)//2, (h-nh)//2))
  1. 首先生成一个input_size(size),大小的空图片,图片填充像素值(128,128,128)
  2. 将缩放后的image, 从new_image的左上角((w-nw)//2, (h-nh)//2))处粘贴到new_image, 这就相当于将缩放后的图片,前后左右padding, padding 的像素为(128,128,128),然后得到padding后的new_image
  3. 函数返回,填充后的new_image, 以及原图image经缩放后的尺寸: nw, nh
2.1.3 归一化、HWC 转 CHW,并expand维度到NCHW
 image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)
  • (1) 对图像归一化
def preprocess_input(image):
    image /= 255.0
    return image
  • (2) HWC 转 CHW
np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1))
  • (3) 扩展batch维度
    因为模型输入的Tensor的大小,必须为[N,C,H,W], 因此利用np.expand_dims在第0维度扩充一个batch维度
 image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

2.2 模型推理

with torch.no_grad():
      images = torch.from_numpy(image_data)
      if self.cuda:
          images = images.cuda()
          
      #---------------------------------------------------#
      #   图片传入网络进行预测
      #---------------------------------------------------#
      pr = self.net(images)[0]
  • 输入图片进行推理: self.net(images), 因为是单张图片预测,所以输出的大小为 (1,c,h,w)
  • 通过取索引[0],去掉batch维度,方便后处理,,因此pr的输出为c,h,w

2.3 输出结果后处理

(1) 利用softmax计算像素的类别概率

 pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()
  • 首先将c,h,w的tensor, 利用torch.permute转换为(h,w,c)
  • 然后,对最后一维channel维度,进行softmax计算,获得各个像素每个类别的预测概率

(2) 将预测的输出resize到原图尺寸

  • 首先将输出的padding灰条部分截取掉

因为输入图片上下左右四周都进行了padding, 所以预测结果部分,四周的padding部分就需要截取掉, 上下左右padding的大小为: padding 大小 = int((self.input_shape[0] - nh) // 2) , 截取灰条代码如下:

pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
     int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
  • 将图片resize到原图尺寸
    将截取了灰条部分的预测输出pr, resize到原图大小
pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)

(3) 计算每个像素的类别
利用argmax获得每个像素的类别预测概率最大的索引,也就是获得对应类别的索引

   pr = pr.argmax(axis=-1)

2.4 分割结果可视化

将输出的分割图与原始图进行混合输出

  • (1) 定义各个类别的rgb颜色值
if self.num_classes <= 21:
     self.colors = [ (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128),
                     (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128),
                     (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
                     (128, 64, 12)]
     # self.colors = [ (0, 0, 0),(0, 0, 0), (0, 1, 0)]

 else:
     hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
     self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
     self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
  • (2) 为每个像素赋予颜色

通过以下方式(1)方式(2)都可以实现,方式(1)的写法非常简洁值得借鉴。

方式1:

seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
  • 经过argmax后,pr大小为(h,w), 每个位置的值为像素索引
  • 通过np.reshape(pr, [-1]), 将pr的shape由(h,w), reshape为一维: (h*w, ) , 每个值为类别索引
  • np.array(self.colors, np.uint8) 对应的值如下, shape大小为(22,3):
array([[  0,   0,   0],
       [128,   0,   0],
       [  0, 128,   0],
       [128, 128,   0],
       [  0,   0, 128],
       [128,   0, 128],
       [  0, 128, 128],
       [128, 128, 128],
       [ 64,   0,   0],
       [192,   0,   0],
       [ 64, 128,   0],
       [192, 128,   0],
       [ 64,   0, 128],
       [192,   0, 128],
       [ 64, 128, 128],
       [192, 128, 128],
       [  0,  64,   0],
       [128,  64,   0],
       [  0, 192,   0],
       [128, 192,   0],
       [  0,  64, 128],
       [128,  64,  12]], dtype=uint8)
  • 利用np.array(self.colors, np.uint8)[np.reshape(pr, [-1])] ,计算得到每个像素的颜色值array, 大小为(h*w,3)

    • 其中[np.reshape(pr, [-1])]为h*w个行索引值,根据行索引值,然后从np.array(self.colors, np.uint8),取对应的,总共有h*w个索引, 因此经过np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], 得到(h*w,3)大小的array,存储了每个像素的rgb颜色。
  • 然后通过reshape,将大小为(h*w,3), reshape为(h,w,3), 得到分割后的3通道的图片seg_img

方式2:

seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
for c in range(self.num_classes):
    seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8')
    seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8')
    seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8')

遍历所有类别,为各个类别位置,赋予对应的像素值。

  • (3) 分割图与原图按比率融合
image   = Image.fromarray(np.uint8(seg_img))
image   = Image.blend(old_img, image, 0.7)

2.5 保留预测的像素索引图

pr = np.uint8(pr)
image = pr
image = Image.fromarray(pr)

2.6 仅扣去背景,仅保留原图中的目标

seg_img = (np.expand_dims(pr != 0, -1) * np.array(old_img, np.float32)).astype('uint8')
#------------------------------------------------#
#   将新图片转换成Image的形式
#------------------------------------------------#
image = Image.fromarray(np.uint8(seg_img))
  • 像素为0,对应的是背景,pr != 0即去掉了背景区域,非背景区域都为True,背景区域为False
  • 然后与原图old_img做乘积,就可以扣去背景,仅保留图中的目标。
  • 因为pr的shape为(h,w); 而old_imgshape为(h,w,c), 无法直接相乘,所以通过np.expand_dims(pr != 0, -1),将pr扩展到shape大小为(h,w,1)

2.7 统计每张图各类像素占比

if count:
    classes_nums        = np.zeros([self.num_classes])
    total_points_num    = orininal_h * orininal_w
    print('-' * 63)
    print("|%25s | %15s | %15s|"%("Key", "Value", "Ratio"))
    print('-' * 63)
    for i in range(self.num_classes):
        num     = np.sum(pr == i)
        ratio   = num / total_points_num * 100
        if num > 0:
            print("|%25s | %15s | %14.2f%%|"%(str(name_classes[i]), str(num), ratio))
            print('-' * 63)
        classes_nums[i] = num
    print("classes_nums:", classes_nums)

3. 完整代码

import colorsys
import copy
import time

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn

from nets.deeplabv3_plus import DeepLab
from utils.utils import cvtColor, preprocess_input, resize_image, show_config


#-----------------------------------------------------------------------------------#
#   使用自己训练好的模型预测需要修改3个参数
#   model_path、backbone和num_classes都需要修改!
#   如果出现shape不匹配,一定要注意训练时的model_path、backbone和num_classes的修改
#-----------------------------------------------------------------------------------#
class DeeplabV3(object):
    _defaults = {
        #-------------------------------------------------------------------#
        #   model_path指向logs文件夹下的权值文件
        #   训练好后logs文件夹下存在多个权值文件,选择验证集损失较低的即可。
        #   验证集损失较低不代表miou较高,仅代表该权值在验证集上泛化性能较好。
        #-------------------------------------------------------------------#
        "model_path"        : 'logs/best_epoch_weights1213_40_10.pth',
        # "model_path"        : 'logs/best_epoch_weights_prune1117_40_10.pth',
        # "model_path"        : 'logs/best_epoch_weights_1127_15fps.pth',
       # "model_path"         : 'model_data/best_11--3---.pth',
        #----------------------------------------#
        #   所需要区分的类的个数+1
        #----------------------------------------#
        "num_classes"       : 2,
        #----------------------------------------#
        #   所使用的的主干网络:
        #   mobilenet
        #   xception    
        #----------------------------------------#
        # "backbone"          : "CSPnet",
        "backbone":  "mobilenet",  #  "RestNet18", #
        #----------------------------------------#
        #   输入图片的大小
        #----------------------------------------#
      #  "input_shape"       : [384, 640],
        "input_shape"       : [192, 320],
        # "input_shape"       : [256, 480],
        # "input_shape"       : [512, 512],
        #----------------------------------------#
        #   下采样的倍数,一般可选的为8和16
        #   与训练时设置的一样即可
        #----------------------------------------#
        "downsample_factor" : 16,
        #-------------------------------------------------#
        #   mix_type参数用于控制检测结果的可视化方式
        #
        #   mix_type = 0的时候代表原图与生成的图进行混合
        #   mix_type = 1的时候代表仅保留生成的图
        #   mix_type = 2的时候代表仅扣去背景,仅保留原图中的目标
        #-------------------------------------------------#
        "mix_type"          : 0,
        #-------------------------------#
        #   是否使用Cuda
        #   没有GPU可以设置成False
        #-------------------------------#
        "cuda"              : True,
    }

    #---------------------------------------------------#
    #   初始化Deeplab
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        #---------------------------------------------------#
        #   画框设置不同的颜色
        #---------------------------------------------------#
        if self.num_classes <= 21:
            self.colors = [ (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128),
                            (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128),
                            (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
                            (128, 64, 12)]
            # self.colors = [ (0, 0, 0),(0, 0, 0), (0, 1, 0)]

        else:
            hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
            self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
            self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
        #---------------------------------------------------#
        #   获得模型
        #---------------------------------------------------#
        self.generate()
        
        show_config(**self._defaults)
                    
    #---------------------------------------------------#
    #   获得所有的分类
    #---------------------------------------------------#
    def generate(self, onnx=False):
        #-------------------------------#
        #   载入模型与权值
        #-------------------------------#
        self.net = DeepLab(num_classes=self.num_classes, backbone=self.backbone, downsample_factor=self.downsample_factor, pretrained=False)

        device      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net = torch.load(self.model_path, map_location=device)
        
        # self.net    = self.net.eval()
        print('{} model, and classes loaded.'.format(self.model_path))
        if not onnx:
            if self.cuda:
                self.net = nn.DataParallel(self.net)
                self.net = self.net.cuda()

    #---------------------------------------------------#
    #   检测图片
    #---------------------------------------------------#
    def detect_image(self, image, count=False, name_classes=None):
        #---------------------------------------------------------#
        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
        #---------------------------------------------------------#
        image       = cvtColor(image)
        #---------------------------------------------------#
        #   对输入图像进行一个备份,后面用于绘图
        #---------------------------------------------------#
        old_img     = copy.deepcopy(image)
        orininal_h  = np.array(image).shape[0]
        orininal_w  = np.array(image).shape[1]
        #---------------------------------------------------------#
        #   给图像增加灰条,实现不失真的resize
        #   也可以直接resize进行识别
        #---------------------------------------------------------#
        image_data, nw, nh  = resize_image(image, (self.input_shape[1],self.input_shape[0]))
        #---------------------------------------------------------#
        #   添加上batch_size维度
        #---------------------------------------------------------#
        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
                
            #---------------------------------------------------#
            #   图片传入网络进行预测
            #---------------------------------------------------#
            pr = self.net(images)[0]
            #---------------------------------------------------#
            #   取出每一个像素点的种类
            #---------------------------------------------------#
            pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()
            #--------------------------------------#
            #   将灰条部分截取掉
            #--------------------------------------#
            pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                    int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
            #---------------------------------------------------#
            #   进行图片的resize
            #---------------------------------------------------#
            pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
            #---------------------------------------------------#
            #   取出每一个像素点的种类
            #---------------------------------------------------#
            pr = pr.argmax(axis=-1)
            # print(pr)

            # file.write("/")
            # import scipy.io as io

            # pr = np.arange(pr)
            # result1 = np.array(pr)
            # np.savetxt('npresult1.txt', result1)

            # io.savemat('save.mat', {'result1': result1})
        
        #---------------------------------------------------------#
        #   计数
        #---------------------------------------------------------#
        if count:
            classes_nums        = np.zeros([self.num_classes])
            total_points_num    = orininal_h * orininal_w
            print('-' * 63)
            print("|%25s | %15s | %15s|"%("Key", "Value", "Ratio"))
            print('-' * 63)
            for i in range(self.num_classes):
                num     = np.sum(pr == i)
                ratio   = num / total_points_num * 100
                if num > 0:
                    print("|%25s | %15s | %14.2f%%|"%(str(name_classes[i]), str(num), ratio))
                    print('-' * 63)
                classes_nums[i] = num
            print("classes_nums:", classes_nums)
    
        if self.mix_type == 0:
            # seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
            # for c in range(self.num_classes):
            #     seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8')
            #     seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8')
            #     seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8')
            seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
            #------------------------------------------------#
            #   将新图片转换成Image的形式
            #------------------------------------------------#
            image   = Image.fromarray(np.uint8(seg_img))
            #------------------------------------------------#
            #   将新图与原图及进行混合
            #------------------------------------------------#
            image   = Image.blend(old_img, image, 0.7)

        elif self.mix_type == 1:
            seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
            for c in range(self.num_classes):
                #seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8')
                #seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8')
                #seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8')
                seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8')
                seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8')
                seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8')

            seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
            #------------------------------------------------#
            #   将新图片转换成Image的形式
            #------------------------------------------------#

            #image   = Image.fromarray(np.uint8(seg_img))

            # 没有梯形的二值图
            pr = np.uint8(pr)
            image = pr
            image = Image.fromarray(pr)


        elif self.mix_type == 2:
            seg_img = (np.expand_dims(pr != 0, -1) * np.array(old_img, np.float32)).astype('uint8')
            #------------------------------------------------#
            #   将新图片转换成Image的形式
            #------------------------------------------------#
            image = Image.fromarray(np.uint8(seg_img))
        
        return image

    def get_FPS(self, image, test_interval):
        #---------------------------------------------------------#
        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
        #---------------------------------------------------------#
        image       = cvtColor(image)
        #---------------------------------------------------------#
        #   给图像增加灰条,实现不失真的resize
        #   也可以直接resize进行识别
        #---------------------------------------------------------#
        image_data, nw, nh  = resize_image(image, (self.input_shape[1],self.input_shape[0]))
        #---------------------------------------------------------#
        #   添加上batch_size维度
        #---------------------------------------------------------#
        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
                
            #---------------------------------------------------#
            #   图片传入网络进行预测
            #---------------------------------------------------#
            pr = self.net(images)[0]
            #---------------------------------------------------#
            #   取出每一个像素点的种类
            #---------------------------------------------------#
            pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy().argmax(axis=-1)
            #--------------------------------------#
            #   将灰条部分截取掉
            #--------------------------------------#
            pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                    int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]

        t1 = time.time()
        for _ in range(test_interval):
            with torch.no_grad():
                #---------------------------------------------------#
                #   图片传入网络进行预测
                #---------------------------------------------------#
                pr = self.net(images)[0]
                #---------------------------------------------------#
                #   取出每一个像素点的种类
                #---------------------------------------------------#
                pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy().argmax(axis=-1)
                #--------------------------------------#
                #   将灰条部分截取掉
                #--------------------------------------#
                pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                        int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
        t2 = time.time()
        tact_time = (t2 - t1) / test_interval
        return tact_time

    def convert_to_onnx(self, simplify, model_path):
        import onnx
        self.generate(onnx=True)

        im                  = torch.zeros(1, 3, *self.input_shape).to('cuda')  # image size(1, 3, 512, 512) BCHW
        input_layer_names   = ["images"]
        output_layer_names  = ["output"]
        
        # Export the model
        print(f'Starting export with onnx {onnx.__version__}.')
        torch.onnx.export(self.net,
                        im,
                        f               = model_path,
                        verbose         = False,
                        opset_version   = 11,  # 12,
                        training        = torch.onnx.TrainingMode.EVAL,
                        do_constant_folding = True,
                        input_names     = input_layer_names,
                        output_names    = output_layer_names,
                        dynamic_axes    = None)

        # Checks
        model_onnx = onnx.load(model_path)  # load onnx model
        onnx.checker.check_model(model_onnx)  # check onnx model

        # Simplify onnx
        if simplify:
            import onnxsim
            print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')
            model_onnx, check = onnxsim.simplify(
                model_onnx,
                dynamic_input_shape=False,
                input_shapes=None)
            assert check, 'assert check failed'
            onnx.save(model_onnx, model_path)

        print('Onnx model save as {}'.format(model_path))
    
    def get_miou_png(self, image):
        #---------------------------------------------------------#
        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
        #---------------------------------------------------------#
        image       = cvtColor(image)
        orininal_h  = np.array(image).shape[0]
        orininal_w  = np.array(image).shape[1]
        #---------------------------------------------------------#
        #   给图像增加灰条,实现不失真的resize
        #   也可以直接resize进行识别
        #---------------------------------------------------------#
        image_data, nw, nh  = resize_image(image, (self.input_shape[1],self.input_shape[0]))
        #---------------------------------------------------------#
        #   添加上batch_size维度
        #---------------------------------------------------------#
        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
                
            #---------------------------------------------------#
            #   图片传入网络进行预测
            #---------------------------------------------------#
            pr = self.net(images)[0]
            #---------------------------------------------------#
            #   取出每一个像素点的种类
            #---------------------------------------------------#
            pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()
            #--------------------------------------#
            #   将灰条部分截取掉
            #--------------------------------------#
            pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                    int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
            #---------------------------------------------------#
            #   进行图片的resize
            #---------------------------------------------------#
            pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
            #---------------------------------------------------#
            #   取出每一个像素点的种类
            #---------------------------------------------------#
            pr = pr.argmax(axis=-1)
    
        image = Image.fromarray(np.uint8(pr))
        return image

参考

代码地址:https://github.com/bubbliiiing/deeplabv3-plus-pytorch/blob/main/deeplab.py

文章来源:https://blog.csdn.net/weixin_38346042/article/details/135695165
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。