训练好一个网络模型后，必不可少的一步就是对模型做前向推理，检验模型的实际性能。Python 是最简便的环境，所以本文给出一个 Python 版本的前向测试脚本。
import os
import cv2
import sys
import caffe
import glob
import argparse
from PIL import Image
import numpy as np
def parse_args(argv=None):
    """Parse command-line options for the Caffe classification forward test.

    Args:
        argv: optional list of argument strings. ``None`` (the default) means
            ``sys.argv[1:]``, exactly as before. Passing a list lets other
            code (or tests) reuse the parser.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    parser = argparse.ArgumentParser(description='caffe classification test arguments')
    parser.add_argument('--image_root_dir', type=str, default='/.../data',  # directory that holds the dataset images
                        help='test image root dir')
    parser.add_argument('--test_txt', type=str, default='/.../test.txt',  # path of the generated image-list txt
                        help='test txt path')
    parser.add_argument('--caffe_model', type=str,
                        default='/.../xxx.caffemodel',
                        help='caffemodel path')
    parser.add_argument('--deploy', type=str, default='/.../deploy.prototxt',
                        help='deploy path')
    # Numeric defaults are real ints/floats instead of strings that argparse
    # would have to re-parse through ``type``.
    parser.add_argument('--num_cls', type=int, default=3, help='class number')
    parser.add_argument('--input_size', type=int, default=96, help='net input size')
    parser.add_argument('--save_dir', type=str, default='./results', help='test result dir')
    parser.add_argument('--roc_name', type=str, default='roc.txt', help='test roc name')
    parser.add_argument('--saveimg_flag', type=int, default=1, help='if 0, do not save img, else save img')
    # Replaces the threshold that used to be hard-coded in Test().
    parser.add_argument('--save_thresh', type=float, default=0.6,
                        help='minimum top-1 probability for an image to be saved')
    return parser.parse_args(argv)
def main():
    """Parse command-line arguments and run the forward test."""
    args = parse_args()
    # Test() returns nothing, so its result is not bound to a name.
    Test(args)
def _parse_line(line):
    """Split one line of the test list into ``(image_relpath, label)``.

    Each non-blank line is expected to be ``<image path> <label>``. The path
    may itself contain spaces, so only the last whitespace-separated token is
    treated as the label. Returns ``None`` for blank lines.

    Raises:
        ValueError: if the line has no label token or the label is not an int.
    """
    line = line.strip()
    if not line:
        return None
    parts = line.rsplit(None, 1)
    if len(parts) != 2:
        raise ValueError('malformed test line (expected "<path> <label>"): %r' % line)
    # Parse the whole token so multi-digit labels (>= 10) are read correctly.
    return parts[0], int(parts[1])


def _ensure_dir(path):
    """Create directory ``path`` (including missing parents) if it does not exist."""
    if not os.path.isdir(path):
        os.makedirs(path)


def Test(args):
    """Run a forward pass over every image listed in ``args.test_txt``.

    For each listed image this function:
      * prints the network's ``prob`` blob (presumably the softmax output)
        and the predicted / true class;
      * writes one line to ``<save_dir>/<roc_name>``: every class probability
        followed by the ground-truth label, space-separated;
      * when ``args.saveimg_flag != 0``, makes ``<save_dir>/<label>-<pred>/``
        and copies the image there if its top probability is at least
        ``args.save_thresh`` (0.6 when that attribute is absent);
      * updates per-class right/wrong counters, which are printed at the end
        together with the overall accuracy.

    Args:
        args: namespace produced by ``parse_args()``.

    Raises:
        ValueError: on a malformed list line or a label outside [0, num_cls).
    """
    _ensure_dir(args.save_dir)
    save_thresh = getattr(args, 'save_thresh', 0.6)

    net = caffe.Net(args.deploy, args.caffe_model, caffe.TEST)

    # caffe.io.load_image returns HxWx3 RGB floats in [0, 1]. Convert them to
    # CxHxW, scale to [0, 255] and swap RGB -> BGR.
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_transpose('data', (2, 0, 1))
    transformer.set_raw_scale('data', 255)
    transformer.set_channel_swap('data', (2, 1, 0))
    # Mean subtraction is disabled. Enable one of these if the model was
    # trained on mean-subtracted inputs:
    # transformer.set_mean('data', np.load(mean_file).mean(1).mean(1))
    # transformer.set_mean('data', np.array([104, 117, 123]))

    wrong = np.zeros(args.num_cls)
    right = np.zeros(args.num_cls)

    with open(args.test_txt, 'r') as list_file:
        lines = list_file.readlines()

    with open(os.path.join(args.save_dir, args.roc_name), 'w') as roc_file:
        for line in lines:
            parsed = _parse_line(line)
            if parsed is None:  # tolerate blank / trailing empty lines
                continue
            imgname, label = parsed
            if not 0 <= label < args.num_cls:
                raise ValueError('label %d out of range [0, %d) in line %r'
                                 % (label, args.num_cls, line))
            print(imgname, label)

            img_path = os.path.join(args.image_root_dir, imgname)
            im = caffe.io.load_image(img_path)
            im = cv2.resize(im, (args.input_size, args.input_size))
            net.blobs['data'].data[...] = transformer.preprocess('data', im)
            net.forward()
            prob = net.blobs['prob'].data[0].flatten()
            print(prob)
            order = int(prob.argmax())  # predicted (top-1) class
            print(order, label)

            if args.saveimg_flag != 0:
                out_dir = os.path.join(args.save_dir, str(label) + '-' + str(order))
                _ensure_dir(out_dir)
                if prob[order] >= save_thresh:
                    # Re-read the original file so the saved copy keeps its
                    # full resolution, not the resized network input.
                    smallname = os.path.basename(imgname)
                    Image.open(img_path).convert('RGB').save(
                        os.path.join(out_dir, str(prob[order]) + smallname))

            # One line per image: all class probabilities, then the true label.
            roc_file.write(''.join('%s ' % p for p in prob) + '%s\n' % label)

            if order == label:
                right[label] += 1
            else:
                wrong[label] += 1

    print(wrong, right)
    total = right.sum() + wrong.sum()
    if total:
        print('accuracy: %.4f (%d/%d)' % (right.sum() / total, right.sum(), total))
if __name__ == '__main__':
    # Script entry point. "done" is printed only when run as a script, not
    # when this module is imported to reuse parse_args()/Test().
    main()
    print("done")
生成的 roc.txt 可用于下一篇文章中绘制 ROC 曲线。