网络裁剪是一种模型压方法,对深度神经网络的稠密连接引入稀疏性,通过将“不重要”的权值直接置零来减少非零权值数量。网络裁剪的目的是在保证网络精度的情况下,保留重要权重,去掉不重要权重。网络裁剪分为边训练边裁剪和训练后裁剪两大类。在嵌入式智能平台上,可以有以下三种网络裁剪思路。
如下图所示,网络裁剪包括剪枝和神经元裁剪。
基于训练与裁剪同步的算法是一种经典可复用的网络裁剪算法,本文在以下实验中采用的是将BN层的缩放因子作为优化约束策略,即将其作为判断通道重要性的依据。某个通道的缩放因子越小,代表该通道越不重要。针对该算法的实现原理如下所示。
实验的源码来自(https://github.com/Eric-mingjie/network-slimming),本次实验仅限于跑通源码,总结剪枝流程,熟悉一些通用的剪枝程序。参考资料来源于(https://mp.weixin.qq.com/s?__biz=MzA4MjY4NTk0NQ==&mid=2247484538&idx=1&sn=3c3c5ec4296b0fce745bdf756b065847&scene=21#wechat_redirect)
源码目录如下:
STEP1:压缩前模型训练
这一步指定相关参数,运行main.py,可以选择较小的数据集cifar10,vgg网络,网络深度可选:16或19,迭代次数程序默认160;这一部分包含训练,运行时间大概一天,觉得没必要尝试;主要是了解相关参数的含义就行。
STEP2:稀疏训练
论文这一部分的思想是对每一个通道都引入一个缩放因子,然后和通道的输出相乘。接着联合训练网络权重和这些缩放因子,最后将小缩放因子的通道直接减掉,并且微调训练后的网络。
STEP3:模型剪枝
#计算需要剪枝的变量个数total
total = 0
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
total += m.weight.data.shape[0]
#确定剪枝的全局阈值
bn = torch.zeros(total)
index = 0
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
size = m.weight.data.shape[0]
bn[index:(index+size)] = m.weight.data.abs().clone()
index += size
#按照权值大小排序
y, i = torch.sort(bn)
thre_index = int(total * args.percent)
#确定要剪枝的阈值
thre = y[thre_index]
#***********************************预剪枝**********************************#
pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
if isinstance(m, nn.BatchNorm2d):
weight_copy = m.weight.data.abs().clone()
mask = weight_copy.gt(thre).float().cuda()
pruned = pruned + mask.shape[0] - torch.sum(mask)
m.weight.data.mul_(mask)
m.bias.data.mul_(mask)
cfg.append(int(torch.sum(mask)))
cfg_mask.append(mask.clone())
print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
format(k, mask.shape[0], int(torch.sum(mask))))
elif isinstance(m, nn.MaxPool2d):
cfg.append('M')
pruned_ratio = pruned/total
STEP4:剪枝
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
#遍历原始模型和新模型的每一层
for [m0, m1] in zip(model.modules(), newmodel.modules()):
#当遇到BatchNorm2d层时,根据end_mask的索引来选择保留的通道,并将对应的权重、偏置、running_mean和running_var复制到新模型中。
if isinstance(m0, nn.BatchNorm2d):
idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
if idx1.size == 1:
idx1 = np.resize(idx1,(1,))
m1.weight.data = m0.weight.data[idx1.tolist()].clone()
m1.bias.data = m0.bias.data[idx1.tolist()].clone()
m1.running_mean = m0.running_mean[idx1.tolist()].clone()
m1.running_var = m0.running_var[idx1.tolist()].clone()
layer_id_in_cfg += 1
start_mask = end_mask.clone()
if layer_id_in_cfg < len(cfg_mask): # do not change in Final FC
end_mask = cfg_mask[layer_id_in_cfg]
#当遇到Conv2d层时,根据start_mask和end_mask的索引来选择保留的输入通道和输出通道,并将对应的权重复制到新模型中。
elif isinstance(m0, nn.Conv2d):
idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
if idx0.size == 1:
idx0 = np.resize(idx0, (1,))
if idx1.size == 1:
idx1 = np.resize(idx1, (1,))
w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
w1 = w1[idx1.tolist(), :, :, :].clone()
m1.weight.data = w1.clone()
#当遇到Linear层时,根据start_mask的索引来选择保留的输入通道,并将对应的权重和偏置复制到新模型中。
elif isinstance(m0, nn.Linear):
idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
if idx0.size == 1:
idx0 = np.resize(idx0, (1,))
m1.weight.data = m0.weight.data[:, idx0].clone()
m1.bias.data = m0.bias.data.clone()
#保存新模型
torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned.pth.tar'))
STEP5 :对新模型进行retrain
剪枝完成后,需要对新模型进行retrain ,使用main.py 修改相应的参数即可完成:
python main.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg --depth 16 --epochs 160
以上这些环节只是针对vgg模型的剪枝,对不同网络结构的剪枝过程略有不同,以后有时间会再看别的。
后面两个裁剪算法将在后续的学习过程中总结。从上一次发文到现在好像隔了好久,忙完实验室的工作后,今天终于有时间学习了,接下来也要坚持哦!