caffe模型的python前向测试

发布时间:2023年12月17日

当我们训练好一个网络模型后必不可少的就是对模型跑前向,看模型的实际性能如何。python绝对是最简单的环境,所以本文写一个python版本的前向测试。

import os
import cv2
import sys
import caffe
import glob
import argparse
from PIL import Image
import numpy as np


def parse_args():
    parser = argparse.ArgumentParser(description='deblur arguments')
    parser.add_argument('--image_root_dir', type=str, default='/.../data',  # 存放数据集目录的路径
                        help='test image root dir')
    parser.add_argument('--test_txt', type=str, default='/.../test.txt',  # 生成的txt的相对路径
                        help='test txt path')
    parser.add_argument('--caffe_model', type=str,
                        default='/.../xxx.caffemodel',
                        help='caffemodel path')
    parser.add_argument('--deploy', type=str, default='/.../deploy.prototxt',
                        help='deploy path')
    parser.add_argument('--num_cls', type=int, default='3', help='class number')
    parser.add_argument('--input_size', type=int, default='96', help='net input size')
    parser.add_argument('--save_dir', type=str, default='./results', help='test result dir')
    parser.add_argument('--roc_name', type=str, default='roc.txt', help='test roc name')
    parser.add_argument('--saveimg_flag', type=int, default='1', help='if 0, do not save img, else save img')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    order = Test(args)


def Test(args):
    if os.path.exists(args.save_dir) == False:
        os.mkdir(args.save_dir)

    roc_file = open(args.save_dir + '/' + args.roc_name, 'w+')

    net = caffe.Net(args.deploy, args.caffe_model, caffe.TEST)
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_transpose('data', (2, 0, 1))
    # transformer.set_mean('data', np.load(mean_file).mean(1).mean(1))
    transformer.set_raw_scale('data', 255)
    # transformer.set_mean('data', np.array([104, 117, 123]))
    transformer.set_channel_swap('data', (2, 1, 0))

    f1 = open(args.test_txt, 'r')
    datas = f1.readlines()

    wrong = np.zeros(args.num_cls)
    right = np.zeros(args.num_cls)
    for data in datas:
        tmp = data.split(' ')

        imgname = tmp[0]
        for i in range(1, len(tmp) - 1):
            imgname += ' ' + tmp[i]
        label = int(tmp[-1][0])
        print(tmp, imgname, label)

        im = caffe.io.load_image(args.image_root_dir + '/' + imgname)
        im = cv2.resize(im, (args.input_size, args.input_size))
        net.blobs['data'].data[...] = transformer.preprocess('data', im)
        out = net.forward()
        prob = net.blobs['prob'].data[0].flatten()
        print(prob)
        order = prob.argsort()[-1]  # small-big
        print(order, label)

        if args.saveimg_flag != 0:
            if os.path.exists(os.path.join(args.save_dir, str(label) + "-" + str(order))) == False:
                os.mkdir(os.path.join(args.save_dir, str(label) + "-" + str(order)))
            im = Image.open(args.image_root_dir + '/' + imgname).convert('RGB')
            smallname = imgname.split('/')[-1]
            if prob[order] >= 0.6:
                im.save(os.path.join(args.save_dir, str(label) + "-" + str(order), str(prob[order]) + smallname))

        for p in prob:
            roc_file.write('%s ' % p)
        roc_file.write('%s\n' % label)

        if order != label:
            wrong[label] += 1
        else:
            right[label] += 1
    print(wrong, right)
    roc_file.close()


if __name__ == '__main__':
    main()
    print("done") 

生成的roc.txt可用于在下一篇画roc曲线。

文章来源:https://blog.csdn.net/weixin_45354497/article/details/135033713
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。