目录
文章性质:学习笔记 📖
视频教程:FCN源码解析(Pytorch)- 4 通过混淆矩阵计算评价指标
主要内容:根据 视频教程 中提供的 FCN 源代码(PyTorch),讲解了如何通过混淆矩阵计算评价指标。
在 train_and_val.py 文件中的?evaluate?函数代码如下:
def evaluate(model, data_loader, device, num_classes):
model.eval()
confmat = utils.ConfusionMatrix(num_classes)
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
with torch.no_grad():
for image, target in metric_logger.log_every(data_loader, 100, header):
image, target = image.to(device), target.to(device)
output = model(image)
output = output['out']
confmat.update(target.flatten(), output.argmax(1).flatten())
confmat.reduce_from_all_processes()
return confmat
【代码解析】对 evaluate 函数代码的具体解析(结合下图):
【注意】?output.argmax(1) 中的 1 是指在 channel 维度,而 argmax 方法用于?将每个像素预测值最大的类别作为其预测类别 。
在 distributed_utils.py 文件中的?ConfusionMatrix 类代码如下:
class ConfusionMatrix(object):
def __init__(self, num_classes):
self.num_classes = num_classes
self.mat = None
def update(self, a, b):
n = self.num_classes
if self.mat is None:
# 创建混淆矩阵
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
with torch.no_grad():
# 寻找GT中为目标的像素索引
k = (a >= 0) & (a < n)
# 统计像素真实类别a[k]被预测成类别b[k]的个数(这里的做法很巧妙)
inds = n * a[k].to(torch.int64) + b[k]
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
def reset(self):
if self.mat is not None:
self.mat.zero_()
def compute(self):
h = self.mat.float()
# 计算全局预测准确率(混淆矩阵的对角线为预测正确的个数)
acc_global = torch.diag(h).sum() / h.sum()
# 计算每个类别的准确率
acc = torch.diag(h) / h.sum(1)
# 计算每个类别预测与真实目标的iou
iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
return acc_global, acc, iu
def reduce_from_all_processes(self):
if not torch.distributed.is_available():
return
if not torch.distributed.is_initialized():
return
torch.distributed.barrier()
torch.distributed.all_reduce(self.mat)
def __str__(self):
acc_global, acc, iu = self.compute()
return (
'global correct: {:.1f}\n'
'average row correct: {}\n'
'IoU: {}\n'
'mean IoU: {:.1f}').format(
acc_global.item() * 100,
['{:.1f}'.format(i) for i in (acc * 100).tolist()],
['{:.1f}'.format(i) for i in (iu * 100).tolist()],
iu.mean().item() * 100)
【代码解析】ConfusionMatrix 类中的 update 函数传入了真实标签 a 和预测标签 b 等参数,代码的具体解析(结合上图):
【注意】关于?FCN 源码中的混淆矩阵,其横坐标是预测标签,纵坐标是真实标签,与【回顾】中的混淆矩阵恰好相反。
【代码解析】具体的计算过程可以参考【回顾】中的截图,注意代码中混淆矩阵的横纵坐标与【回顾】示例中的相反:
【说明】用 * 100 表示百分数,使用 iu.mean() 计算平均数,输出格式如下图所示:
常见的语义分割评价指标主要包括 Pixel Accuracy ( Global Accuracy )、mean Accuracy、mean IoU 等:
关于语义分割评价指标,微臣在本系列的第一篇文章中已经作了详细的讲解,王子公主们请移驾我的这篇博文: