【模型压缩】神经网络裁剪1

发布时间:2024年01月24日

神经网络裁剪基础知识

网络裁剪是一种模型压方法,对深度神经网络的稠密连接引入稀疏性,通过将“不重要”的权值直接置零来减少非零权值数量。网络裁剪的目的是在保证网络精度的情况下,保留重要权重,去掉不重要权重。网络裁剪分为边训练边裁剪和训练后裁剪两大类。在嵌入式智能平台上,可以有以下三种网络裁剪思路。

1基于训练与裁剪同步的裁剪思想

1.1原理

如下图所示,网络裁剪包括剪枝和神经元裁剪。神经网络裁剪前后对比
基于训练与裁剪同步的算法是一种经典可复用的网络裁剪算法,本文在以下实验中采用的是将BN层的缩放因子作为优化约束策略,即将其作为判断通道重要性的依据。某个通道的缩放因子越小,代表该通道越不重要。针对该算法的实现原理如下所示。
基于训练与裁剪同步的网络裁剪原理

1.2实验

实验的源码来自(https://github.com/Eric-mingjie/network-slimming),本次实验仅限于跑通源码,总结剪枝流程,熟悉一些通用的剪枝程序。参考资料来源于(https://mp.weixin.qq.com/s?__biz=MzA4MjY4NTk0NQ==&mid=2247484538&idx=1&sn=3c3c5ec4296b0fce745bdf756b065847&scene=21#wechat_redirect)

源码目录如下:
源码程序目录
STEP1:压缩前模型训练
这一步指定相关参数,运行main.py,可以选择较小的数据集cifar10,vgg网络,网络深度可选:16或19,迭代次数程序默认160;这一部分包含训练,运行时间大概一天,觉得没必要尝试;主要是了解相关参数的含义就行。
STEP2:稀疏训练
论文这一部分的思想是对每一个通道都引入一个缩放因子,然后和通道的输出相乘。接着联合训练网络权重和这些缩放因子,最后将小缩放因子的通道直接减掉,并且微调训练后的网络。

STEP3:模型剪枝

#计算需要剪枝的变量个数total
total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]
#确定剪枝的全局阈值
bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index+size)] = m.weight.data.abs().clone()
        index += size
#按照权值大小排序
y, i = torch.sort(bn)
thre_index = int(total * args.percent)

#确定要剪枝的阈值
thre = y[thre_index]
#***********************************预剪枝**********************************#
pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        mask = weight_copy.gt(thre).float().cuda()
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
            format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

pruned_ratio = pruned/total

STEP4:剪枝

start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
#遍历原始模型和新模型的每一层
for [m0, m1] in zip(model.modules(), newmodel.modules()):
#当遇到BatchNorm2d层时,根据end_mask的索引来选择保留的通道,并将对应的权重、偏置、running_mean和running_var复制到新模型中。
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:
            idx1 = np.resize(idx1,(1,))
        m1.weight.data = m0.weight.data[idx1.tolist()].clone()
        m1.bias.data = m0.bias.data[idx1.tolist()].clone()
        m1.running_mean = m0.running_mean[idx1.tolist()].clone()
        m1.running_var = m0.running_var[idx1.tolist()].clone()
        layer_id_in_cfg += 1
        start_mask = end_mask.clone()
        if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
            end_mask = cfg_mask[layer_id_in_cfg]
    #当遇到Conv2d层时,根据start_mask和end_mask的索引来选择保留的输入通道和输出通道,并将对应的权重复制到新模型中。
    elif isinstance(m0, nn.Conv2d):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
        w1 = w1[idx1.tolist(), :, :, :].clone()
        m1.weight.data = w1.clone()
    #当遇到Linear层时,根据start_mask的索引来选择保留的输入通道,并将对应的权重和偏置复制到新模型中。
    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()
#保存新模型
torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned.pth.tar'))

STEP5 :对新模型进行retrain
剪枝完成后,需要对新模型进行retrain ,使用main.py 修改相应的参数即可完成:

python main.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg --depth 16 --epochs 160

以上这些环节只是针对vgg模型的剪枝,对不同网络结构的剪枝过程略有不同,以后有时间会再看别的。

2基于动态裁剪与剪接、训练同步的裁剪思想

3基于裁剪滤波器的训练后裁剪算法

后面两个裁剪算法将在后续的学习过程中总结。从上一次发文到现在好像隔了好久,忙完实验室的工作后,今天终于有时间学习了,接下来也要坚持哦!

文章来源:https://blog.csdn.net/weixin_44793127/article/details/135520978
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。