Introduction

Model pruning ranks the neurons in a network by how much they contribute, then removes the lowest-ranked neurons, yielding a smaller and faster model.

(Figure omitted: schematic of the basic pruning idea.)

Neurons are typically ranked by the L1/L2 norm of their weights. Accuracy drops after pruning, so the network is usually recovered through a train-prune-train-prune cycle. If too much is pruned at once, the network may be damaged beyond recovery. In practice this is therefore an iterative process, commonly called iterative pruning: prune, fine-tune, repeat.
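The prune/fine-tune loop can be caricatured on plain numbers (a hypothetical stand-in for a real network, not the training code below):

```python
# Minimal sketch of iterative pruning on a list of "neuron scores".
# Each round drops the weakest 10%; in a real network you would
# fine-tune between rounds to recover accuracy.
def prune_step(scores, fraction=0.1):
    k = max(1, int(len(scores) * fraction))
    return sorted(scores, reverse=True)[:-k]  # drop the k lowest-ranked neurons

scores = list(range(100))        # 100 neurons ranked by contribution
for _ in range(3):               # three prune/retrain rounds
    scores = prune_step(scores)  # prune a little...
    # ...then fine-tune here before pruning again

print(len(scores))  # 100 -> 90 -> 81 -> 73
```

Pruning 30% in three small steps with recovery in between is gentler than removing 30% in one shot, which is the point of the iterative schedule.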

Applying L1 regularization during training drives parameters toward sparsity.

L1: sparsity and feature selection; L2: smooth, small weights.
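A one-step numerical illustration of why L1 yields sparsity while L2 only shrinks (the proximal/soft-threshold view, numpy only; all values here are made up for illustration):

```python
import numpy as np

w = np.array([0.8, 0.05, -0.3, 0.01])
lam, lr = 0.5, 0.1  # penalty strength and step size

# One proximal step of L2 (ridge): uniform shrinkage, nothing reaches exactly zero
w_l2 = w / (1 + lr * lam)

# One proximal step of L1 (lasso): soft-thresholding, small weights snap to zero
w_l1 = np.sign(w) * np.maximum(np.abs(w) - lr * lam, 0.0)

print(w_l2)  # all entries still nonzero
print(w_l1)  # entries with |w| <= 0.05 become exactly 0
```

This is why the L1 penalty on BN scale factors below produces exact zeros that can be read off as "prune this channel", whereas an L2 penalty would only make every γ a bit smaller.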

Code Implementation

Pretraining:

The original network must be built from Conv2d + BatchNorm2d + ReLU as a single unit.
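A minimal sketch of such a unit, using a hypothetical helper name `conv_bn_relu` (not from the original code). The point is that every conv output channel gets its own BN scale factor γ, which is what the pruning criterion ranks:

```python
import torch.nn as nn

def conv_bn_relu(c_in, c_out, k=3):
    # Conv2d -> BatchNorm2d -> ReLU: one prunable unit.
    # bias=False because the BN affine transform absorbs the bias.
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, k, padding=k // 2, bias=False),
        nn.BatchNorm2d(c_out),
        nn.ReLU(inplace=True),
    )
```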

During training, an L1 penalty is added on the BatchNorm layers for sparsity training, producing a γ (gamma) value for each feature map: the smaller γ is, the less important the corresponding feature map. The L1 penalty on γ is what gives it this feature-selection effect.

def updateBN(model, s):
    # L1 subgradient: sign(w) is 1 for w > 0, -1 for w < 0, and 0 for w = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(s * torch.sign(m.weight.data))

# Called inside the training loop:
loss.backward()
# sparsity-inducing step on the BN gammas
sr = 0.0001
if sr:
    updateBN(self.model, sr)
self.optimizer.step()
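For intuition, the gradient term that updateBN injects is exactly what autograd would compute for an explicit L1 penalty s·Σ|γ| added to the loss. A self-contained sketch with a toy model (the model and shapes here are illustrative, not the article's ASPNET):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(1, 4, 3), nn.BatchNorm2d(4), nn.ReLU())
s = 0.0001  # sparsity strength, same role as `sr` above

x = torch.randn(2, 1, 8, 8)
loss = model(x).mean()

# Equivalent formulation: add the L1 penalty on all BN gammas to the loss,
# and let autograd produce the s * sign(gamma) term
l1 = sum(m.weight.abs().sum() for m in model.modules()
         if isinstance(m, nn.BatchNorm2d))
(loss + s * l1).backward()
```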

Pruning:

Load the pretrained model, prune it, then save the pruned model.

Specify --percent (pruning ratio), --model (the pretrained model), and --save (filename for the pruned model).

import os
import argparse
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from model.model import ASPNET
import numpy as np

# Prune settings
parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
parser.add_argument('--dataset', type=str, default='cifar10',
                    help='training dataset (default: cifar10)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--percent', type=float, default=0.5,
                    help='scale sparse rate (default: 0.5)')
parser.add_argument('--model', default='', type=str, metavar='PATH',
                    help='path to raw trained model (default: none)')
parser.add_argument('--save', default='', type=str, metavar='PATH',
                    help='path to save prune model (default: none)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

model = ASPNET()
if args.cuda:
    model.cuda()

if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['monitor_best']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.model, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.model))
print(model)

# Total channel count over all BN layers
# (each feature map has its own gamma and beta)
total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]

# Collect the absolute value of every gamma
bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index + size)] = m.weight.data.abs().clone()
        index += size

# Global threshold: the lowest `percent` fraction of gammas gets pruned
y, i = torch.sort(bn)
thre_index = int(total * args.percent)
thre = y[thre_index]

pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.clone()
        mask = weight_copy.abs().gt(thre).float().cuda()  # .gt: keep channels with |gamma| > threshold
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        m.weight.data.mul_(mask)  # zero out pruned BN gammas
        m.bias.data.mul_(mask)    # and the corresponding betas
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'
              .format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')
pruned_ratio = pruned / total
print('Pre-processing Successful!')

# Build the pruned model
print(cfg)
newmodel = ASPNET(net_name=cfg)
newmodel.cuda()

# Copy the surviving weights into the pruned model
layer_id_in_cfg = 0
start_mask = torch.ones(1)            # input-channel mask
end_mask = cfg_mask[layer_id_in_cfg]  # output-channel mask
for [m0, m1] in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        m1.weight.data = m0.weight.data[idx1].clone()
        m1.bias.data = m0.bias.data[idx1].clone()
        m1.running_mean = m0.running_mean[idx1].clone()
        m1.running_var = m0.running_var[idx1].clone()
        layer_id_in_cfg += 1
        start_mask = end_mask.clone()  # this layer's output is the next layer's input
        if layer_id_in_cfg < len(cfg_mask):  # do not change in final FC
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        w = m0.weight.data[:, idx0, :, :].clone()  # select surviving input channels
        w = w[idx1, :, :, :].clone()               # select surviving output channels
        m1.weight.data = w.clone()                 # assign to the pruned model
        # m1.bias.data = m0.bias.data[idx1].clone()
    elif isinstance(m0, nn.Linear):
        m1.weight.data = m0.weight.data.clone()

torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, args.save)
print(newmodel)
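The weight-copying step above hinges on turning a 0/1 channel mask into integer indices with np.argwhere + np.squeeze. A tiny standalone illustration with a made-up mask and toy weight tensor:

```python
import numpy as np

mask = np.array([1., 0., 1., 1., 0.])  # keep channels 0, 2, 3
idx = np.squeeze(np.argwhere(mask))    # -> integer indices of kept channels
print(idx)  # [0 2 3]

w = np.arange(5 * 2).reshape(5, 2)     # toy (out_channels, ...) weight
print(w[idx])                          # only rows 0, 2, 3 survive
```

The `np.resize(idx, (1,))` guards in the script exist because np.squeeze collapses a single-match result to a 0-d scalar, which would break the fancy indexing.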

Parameter count before pruning: 27,102; after pruning: 7,185 (roughly a 3.8x reduction).

Retraining with the pruned model:

Use the pruned network architecture, and initialize it by loading the pruned model's parameters.

refine = 'pruned.pth'  # path to the pruned model saved above
if refine:
    checkpoint = torch.load(refine)
    print(checkpoint['cfg'])
    model = ASPNET(net_name=checkpoint['cfg'])  # use the pruned architecture
    model.cuda()
    model.load_state_dict(checkpoint['state_dict'])
