在前面一篇 Python 版 Caffe 前向推理的文章中，生成了一个用于绘制 ROC 曲线的 txt 文件，它将作为本代码的输入（文件每行依次为各类别的得分，最后一列为真实类别的整数编号）：
# -*- coding:utf-8
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve,auc
import numpy as np
# Matplotlib colour specs for the per-class ROC curves, assigned to the
# classes in order.
colors = [
    'r',
    'g',
    'b',
    'yellow',
    'pink',
    'black',
    'purple',
    'lime',
]
def get_output_file1(output_file, classNumber):
    """Parse a score file into per-class scores and one-hot labels.

    Each non-empty line of *output_file* must contain ``classNumber``
    whitespace-separated scores followed by the integer index of the
    ground-truth class. Blank lines are ignored.

    Args:
        output_file: Path of the text file to read.
        classNumber: Number of classes (an int, or any value ``int()``
            accepts, e.g. ``"3"``).

    Returns:
        A tuple ``(prod_all, label_all)``:
        ``prod_all`` is a list with one list of float scores per line;
        ``label_all`` is a list with one one-hot list per line
        (1 at the ground-truth index, 0 elsewhere).

    Raises:
        IndexError: if a non-empty line has fewer than ``classNumber + 1``
            fields.
        ValueError: if a score or the label cannot be parsed.
    """
    # Convert once so every use below agrees (the count may arrive as a str).
    n_classes = int(classNumber)
    prod_all = []
    label_all = []
    # ``with`` guarantees the file handle is closed even if parsing fails.
    with open(output_file) as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Skip blank lines, e.g. a trailing newline at end of file.
                continue
            prod_all.append([float(v) for v in fields[:n_classes]])
            tag = int(fields[n_classes])
            label_all.append([1 if j == tag else 0 for j in range(n_classes)])
    return prod_all, label_all
def ROC(prod_all, label_all, classLabel, output_txtname, rgb="r", leged="line"):
    """Plot the one-vs-rest ROC curve of one class and log its points.

    The ROC is computed from column ``classLabel`` of the label matrix
    (ground truth) and of the score matrix (predicted scores).

    Args:
        prod_all: 2-D score matrix (rows = samples, columns = classes),
            e.g. the first value returned by ``get_output_file1``.
        label_all: 2-D one-hot label matrix with the same layout.
        classLabel: Column index of the class to evaluate.
        output_txtname: Text file that a section is appended to. The
            section holds the class index, a " fpr tpr thr" header, and
            one "fpr tpr thr" line per ROC point.
        rgb: Matplotlib colour of the curve.
        leged: Legend text; the AUC (first 6 characters) is appended to it.

    Returns:
        The area under the ROC curve for this class.
    """
    y_true = np.array(label_all)
    y_predict = np.array(prod_all)
    fpr, tpr, thr = roc_curve(y_true[:, classLabel], y_predict[:, classLabel])
    # Append so that successive calls (one per class) share one file. The
    # ``with`` block closes and flushes the file instead of leaking the handle.
    with open(output_txtname, 'a') as fid:
        fid.write(str(classLabel) + "\n" + " fpr tpr thr" + "\n")
        fid.writelines(
            str(f) + " " + str(t) + " " + str(th) + "\n"
            for f, t, th in zip(fpr, tpr, thr)
        )
    AUC = auc(fpr, tpr)
    # Draws onto the current pyplot figure; callers save/show it themselves.
    plt.plot(fpr, tpr, clip_on=False, color=rgb, label=leged + '-' + str(AUC)[0:6])
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.legend(loc='best')
    return AUC
if __name__ == '__main__':
    classes = ['eye', 'nose', 'ear']
    # Derive the class count from the names so the two cannot drift apart.
    classNumber = len(classes)
    input_txtname = '/.../roc.txt'  # placeholder: path of the score file
    output_txtname = './roc1_th.txt'
    output_imgname = './roc8.jpg'
    prod_all, label_all = get_output_file1(input_txtname, classNumber)
    # ROC() appends to output_txtname; truncate it first so re-running the
    # script does not accumulate stale results from earlier runs.
    open(output_txtname, 'w').close()
    AUC = []
    for i in range(classNumber):
        # Cycle through the palette so any number of classes gets a colour
        # instead of raising IndexError when classes outnumber colors.
        AUC.append(ROC(prod_all, label_all, i, output_txtname,
                       rgb=colors[i % len(colors)], leged=classes[i]))
    # Add plt.title(...) / plt.show() here to title or display the figure.
    plt.savefig(output_imgname)
这里的 colors 列表按顺序为每个类别的曲线指定颜色；颜色数量应不少于类别个数，否则颜色会不够用。