0.持续学习
1.LwF算法相关链接
2.基本想法
- 针对问题:在无法获得原始任务训练数据的情况下,适合使视觉系统适应新任务,并且保证其在旧任务上的性能
- 问题建模:学习对新任务具有判别能力的参数,同时保留训练数据上原始任务的输出
- 将网络分为所有任务共享部分和特定任务独享部分,网络架构如下:
图片
3.损失函数
- 待学习参数有三种:共享部分参数、旧任务们的独享参数、新任务独享参数
- 由三部分组成:旧任务损失、新任务损失、正则化项
- 旧任务损失:增长后的网络的输出与增长前的输出尽可能相同,采用知识蒸馏损失,类似交叉熵损失,只不过加大了较小概率的惩罚权重(其中关键参数T,要大于1来加大小概率的权重,文中通过网格搜索将其定位2)
- 新任务损失:对于新任务的预测与真实值尽可能相同,使用交叉熵损失或者NLL损失
- 正则化项:限制网络中所有参数,权重0.0005
- 新旧任务权衡:在新任务损失前面有一个系数来表示对新旧任务性能的权衡,文中取1,参数越大,在新任务上的性能越好,在旧任务上的性能越差。通过改变该参数可以获得新旧任务性能曲线。
4.训练流程
- 热身阶段(warm-up step):冻结共享部分参数、旧任务们的独享参数,单独训练新任务独享参数
- 联合优化阶段(joint-optimize step):优化所有参数
5.特点
- 与传统联合调优方法相比:无需存储旧任务的数据,新任务只需要通过一次共享层便可以用来进行旧任务和新任务的更新,却具有了联合调优的优点。但因为不同任务的分布会不相同,所以文中的方法效果会不如传统联合调优,传统联合调优的效果可以视为本文方法的上限。
- 效率分析:
- 最慢:共享参数的正反向传播
- 最快:特征提取层,因为只需要训练新任务的参数
- 与传统微调相比:多了一步旧任务的独享参数更新,效率稍微低一点
- 与传统联合调优相比:新旧任务共享的参数只需要进行一次前后向传播,效率更高
6.具体细节
- 使用动量0.9的随机梯度下降
- 在全连接层使用了dropout
- 用旧任务的信息对新任务进行归一化
- 数据增强:
- 5X5的网格上对调整过大小的图像进行随机的固定尺寸裁剪
- 随机镜像裁剪
- RGB值上添加方差
- 使用Xavier初始化新任务独享参数
- 学习率是原网络学习率的0.1-0.02倍
- 由于任务独享的特征提取部分参数量少,所以使用5倍学习率
- 对于学习速度相似的方法,使用相同的训练epoch来进行公平比较
- 有时为了防止过拟合、提升学习速度,会接近平稳在的时候将学习率变为0.1倍
- 为了公平比较,将热身阶段后的共享网络作为联合训练和微调训练的起始点
7.实验
- 添加单个新任务
- 添加多个新任务
- 数据集大小的影响
- 网络设计的影响
- 不同损失
- 扩展网络结构的效用
- 小学习率微调来保证旧任务的影响
- 改变任务专属部分的网络层数
8.结论
- 对于增长节点式的任务专属网络,其性能与原本的LwF性能相近,但是计算开销却大很多
- 仅仅降低共享网络的学习率对保留旧任务性能的帮助并不大,但却会很大程度影响新任务
- 用网络输出的变化来现在旧任务的变化要优于用网络参数的变化来衡量,因为网络参数一点小小的改变就可能引起输出巨大的改变
- 知识蒸馏损失略优于L1、L2、交叉熵损失,但优势很小
- 训练速度优于联合优化,对新任务的性能优于微调
- 本文针对旧任务的损失对旧任务性能上的表现更可解释
9.未来工作
- 应用到图像分类、跟踪等更多领域:分割、检测、视觉外的任务
- 探索根据任务分布针对性地保留一些过去的任务数据和输出(由于是面对重尾分布)
10.代码解读
import torch
torch.backends.cudnn.benchmark=True
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
from PIL import Image
from tqdm import tqdm
import time
import copy
import torchvision.models as models
import torchvision.transforms as transforms
def MultiClassCrossEntropy(logits, labels, T):
labels = Variable(labels.data, requires_grad=False).cuda()
outputs = torch.log_softmax(logits/T, dim=1)
labels = torch.softmax(labels/T, dim=1)
outputs = torch.sum(outputs * labels, dim=1, keepdim=False)
outputs = -torch.mean(outputs, dim=0, keepdim=False)
return Variable(outputs.data, requires_grad=True).cuda()
def kaiming_normal_init(m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
elif isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid')
class Model(nn.Module):
'''
分为超参数、网络架构、类增加三个部分
前向传播里没有softmax
'''
def __init__(self, classes, classes_map, args):
self.init_lr = args.init_lr
self.num_epochs = args.num_epochs
self.batch_size = args.batch_size
self.lower_rate_epoch = [int(0.7 * self.num_epochs), int(0.9 * self.num_epochs)]
self.lr_dec_factor = 10
self.pretrained = False
self.momentum = 0.9
self.weight_decay = 0.0001
self.epsilon = 1e-16
super(Model, self).__init__()
self.model = models.resnet34(pretrained=self.pretrained)
self.model.apply(kaiming_normal_init)
"""独享层:一层全连接层,与classes数量有关,且没有偏置"""
num_features = self.model.fc.in_features
self.model.fc = nn.Linear(num_features, classes, bias=False)
self.fc = self.model.fc
'''共享层:resnet34除去最后一层'''
self.feature_extractor = nn.Sequential(*list(self.model.children())[:-1])
self.feature_extractor = nn.DataParallel(self.feature_extractor)
self.n_classes = 0
self.n_known = 0
self.classes_map = classes_map
def forward(self, x):
x = self.feature_extractor(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def increment_classes(self, new_classes):
"""Add n classes in the final fc layer"""
n = len(new_classes)
print('new classes: ', n)
in_features = self.fc.in_features
out_features = self.fc.out_features
weight = self.fc.weight.data
if self.n_known == 0:
new_out_features = n
else:
new_out_features = out_features + n
print('new out features: ', new_out_features)
self.model.fc = nn.Linear(in_features, new_out_features, bias=False)
self.fc = self.model.fc
kaiming_normal_init(self.fc.weight)
self.fc.weight.data[:out_features] = weight
self.n_classes += n
def classify(self, images):
"""Classify images by softmax
Args:
x: input image batch
Returns:
preds: Tensor of size (batch_size,)
"""
_, preds = torch.max(torch.softmax(self.forward(images), dim=1), dim=1, keepdim=False)
return preds
def update(self, dataset, class_map, args):
self.compute_means = True
prev_model = copy.deepcopy(self)
prev_model.cuda()
classes = list(set(dataset.train_labels))
print('Known: ', self.n_known)
if self.n_classes == 1 and self.n_known == 0:
new_classes = [classes[i] for i in range(1,len(classes))]
else:
new_classes = [cl for cl in classes if class_map[cl] >= self.n_known]
if len(new_classes) > 0:
self.increment_classes(new_classes)
self.cuda()
loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
shuffle=True, num_workers=12)
print("Batch Size (for n_classes classes) : ", len(dataset))
optimizer = optim.SGD(self.parameters(), lr=self.init_lr, momentum = self.momentum, weight_decay=self.weight_decay)
with tqdm(total=self.num_epochs) as pbar:
for epoch in range(self.num_epochs):
for i, (indices, images, labels) in enumerate(loader):
seen_labels = []
images = Variable(torch.FloatTensor(images)).cuda()
seen_labels = torch.LongTensor([class_map[label] for label in labels.numpy()])
labels = Variable(seen_labels).cuda()
optimizer.zero_grad()
logits = self.forward(images)
cls_loss = nn.CrossEntropyLoss()(logits, labels)
if self.n_classes//len(new_classes) > 1:
dist_target = prev_model.forward(images)
logits_dist = logits[:,:-(self.n_classes-self.n_known)]
dist_loss = MultiClassCrossEntropy(logits_dist, dist_target, 2)
loss = dist_loss+cls_loss
else:
loss = cls_loss
loss.backward()
optimizer.step()
if (i+1) % 1 == 0:
tqdm.write('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
%(epoch+1, self.num_epochs, i+1, np.ceil(len(dataset)/self.batch_size), loss.data))
pbar.update(1)