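Reducing Model Parameters: Model Pruning (Pruning Deep Neural Networks)
Introduction
Model pruning ranks the neurons in a network by how much they contribute and removes the lowest-ranked ones, yielding a smaller and faster model.
[Figure: diagram of the basic idea]
Neurons are ranked by the L1 or L2 norm of their weights. Accuracy drops after pruning, so the network is usually recovered by alternating training and pruning (trained-pruned-trained-pruned). Pruning too much in a single pass can damage the network beyond recovery, so in practice this is an iterative process, commonly called iterative pruning: prune, train, repeat.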
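The ranking step can be sketched in a few lines. This is a minimal, framework-free illustration, not the code used later in this post: plain Python lists stand in for weight tensors, the helper names `l1_norm` and `lowest_ranked` are hypothetical, and the weight values are made up.

```python
def l1_norm(weights):
    """Sum of absolute values of one neuron's (or filter's) weights."""
    return sum(abs(w) for w in weights)


def lowest_ranked(neurons, fraction):
    """Return the indices of the `fraction` of neurons with the smallest L1 norm."""
    order = sorted(range(len(neurons)), key=lambda i: l1_norm(neurons[i]))
    n_prune = int(len(neurons) * fraction)
    return sorted(order[:n_prune])


# Toy layer with four neurons; the values are invented for illustration.
layer = [
    [0.9, -1.2, 0.4],      # large norm, kept
    [0.01, 0.02, -0.01],   # tiny norm, pruned first
    [0.5, 0.5, -0.5],
    [0.1, -0.05, 0.0],
]
to_prune = lowest_ranked(layer, 0.5)
print(to_prune)  # [1, 3]
```

In an iterative schedule, a small `fraction` would be removed per round, with retraining between rounds, rather than removing everything in one call.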
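Training with L1 regularization drives parameters toward sparsity.
L1: sparsity and feature selection; L2: smoothing (weights shrink but rarely reach exactly zero).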
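To make the difference concrete, here is a small sketch that applies only the regularization part of a gradient step to one small and one large weight. The function name `shrink`, the learning rate, and the penalty strength are assumptions chosen for illustration.

```python
def sign(x):
    """Return 1 for positive x, -1 for negative x, 0 for zero."""
    return (x > 0) - (x < 0)


def shrink(w, steps, lr, lam, penalty):
    """Apply only the regularization part of a gradient step, `steps` times."""
    for _ in range(steps):
        if penalty == "l1":
            grad = lam * sign(w)   # subgradient of lam * |w|
        else:
            grad = lam * w         # gradient of (lam / 2) * w ** 2
        w -= lr * grad
    return w


small, large = 0.01, 1.0
l1 = [shrink(w, 5, lr=0.1, lam=0.01, penalty="l1") for w in (small, large)]
l2 = [shrink(w, 5, lr=0.1, lam=0.01, penalty="l2") for w in (small, large)]
print(l1)  # small weight is roughly halved, large weight barely changes
print(l2)  # both weights shrink by the same factor of about 0.995
```

L1 subtracts the same fixed amount from every weight, so small weights are pushed to zero quickly; L2 scales every weight by the same factor, so small weights stay small but nonzero.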
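Code implementation
Pre-training:
The original network must be built from Conv2d + BatchNorm2d + ReLU blocks, each treated as a single unit.
During training, add an L1 penalty on the BatchNorm layers to train for sparsity. Each feature map (channel) gets its own gamma (γ) value: the smaller γ is, the less important the corresponding feature map. The L1 penalty on γ is what lets γ act as a feature selector.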
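Written out, the objective this describes is the task loss plus an L1 penalty on the BatchNorm scales. The notation below is my own and is meant as a sketch: l is the task loss, f the network with weights W, Γ the set of all BatchNorm γ values, and s the penalty strength (the value passed as `s` to `updateBN` in this section).

```latex
L = \sum_{(x, y)} l\bigl(f(x, W),\, y\bigr) + s \sum_{\gamma \in \Gamma} |\gamma|,
\qquad
\frac{\partial}{\partial \gamma}\bigl(s\,|\gamma|\bigr) = s \cdot \operatorname{sign}(\gamma) \quad (\gamma \neq 0)
```

The second expression is why adding s · sign(γ) to the gradient of each BatchNorm weight after `loss.backward()` amounts to training with the penalized loss; at γ = 0 the sign is taken as 0.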
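def updateBN(model, s):
    """Add the L1 subgradient s * sign(gamma) to the gradient of every BatchNorm2d scale."""
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            # sign() is 1 for positive values, -1 for negative values, and 0 for zero
            m.weight.grad.data.add_(s * torch.sign(m.weight.data))

# ...
# Called inside the training function:
# ...
loss.backward()
# Sparsity penalty for pruning
sr = 0.0001
if sr:
    updateBN(self.model, sr)
self.optimizer.step()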
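Pruning:
Load the pre-trained model, prune it, and save the pruned model.
Three arguments must be given: --percent (the pruning ratio), --model (the pre-trained model), and --save (the file name for the pruned model).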
import os
import argparse

import numpy as np
import torch
import torch.nn as nn

# from vgg import vgg
from model.model import ASPNET

# Prune settings
parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
parser.add_argument('--dataset', type=str, default='cifar10',
                    help='training dataset (default: cifar10)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--percent', type=float, default=0.5,
                    help='scale sparse rate (default: 0.5)')
parser.add_argument('--model', default='', type=str, metavar='PATH',
                    help='path to raw trained model (default: none)')
parser.add_argument('--save', default='', type=str, metavar='PATH',
                    help='path to save prune model (default: none)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# model = vgg()
model = ASPNET()
if args.cuda:
    model.cuda()
if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['monitor_best']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.model, checkpoint['epoch'], best_prec1))
    else:
        # The path comes from --model (the script has no --resume argument)
        print("=> no checkpoint found at '{}'".format(args.model))

print(model)
total = 0  # total number of feature maps (channels) over all BN layers
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]

bn = torch.zeros(total)  # collects every gamma; each feature map has its own γ and β
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index+size)] = m.weight.data.abs().clone()
        index += size

y, i = torch.sort(bn)
thre_index = int(total * args.percent)
thre = y[thre_index].item()  # global threshold on |gamma|

pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.clone()
        # .gt() tests whether each |gamma| is strictly greater than the threshold;
        # the mask stays on the same device as the weights
        mask = weight_copy.abs().gt(thre).float()
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        m.weight.data.mul_(mask)  # zero the gamma of pruned channels
        m.bias.data.mul_(mask)    # zero the beta of pruned channels
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'
              .format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

pruned_ratio = pruned/total
print('Pre-processing Successful!')

# Perform the pruning
print(cfg)
# newmodel = vgg(cfg=cfg)  # pruned model
newmodel = ASPNET(net_name=cfg)  # pruned model
if args.cuda:
    newmodel.cuda()

# Copy the surviving weights into the pruned model
layer_id_in_cfg = 0
start_mask = torch.ones(1)            # mask over input channels
end_mask = cfg_mask[layer_id_in_cfg]  # mask over output channels
for m0, m1 in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        # Copy the parameters and running statistics of the kept channels
        m1.weight.data = m0.weight.data[idx1.tolist()].clone()
        m1.bias.data = m0.bias.data[idx1.tolist()].clone()
        m1.running_mean = m0.running_mean[idx1.tolist()].clone()
        m1.running_var = m0.running_var[idx1.tolist()].clone()
        layer_id_in_cfg += 1
        start_mask = end_mask.clone()  # this layer's output is the next layer's input
        if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        print(idx0)
        print(idx1)
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        # print('In shape: {:d} Out shape:{:d}'.format(idx0.shape[0], idx1.shape[0]))
        w = m0.weight.data[:, idx0.tolist(), :, :].clone()  # keep surviving input channels
        w = w[idx1.tolist(), :, :, :].clone()               # keep surviving output channels
        m1.weight.data = w.clone()  # assign the selected weights to the pruned model
        # m1.bias.data = m0.bias.data[idx1].clone()
    elif isinstance(m0, nn.Linear):
        # idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        # m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.weight.data = m0.weight.data.clone()

torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, args.save)
print(newmodel)
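The global threshold step of the script is easier to follow on toy numbers. The sketch below mirrors its logic with plain Python lists instead of tensors; the function name `channel_masks` and the γ values are hypothetical.

```python
def channel_masks(gammas_per_layer, percent):
    """Pick one global |gamma| threshold and build a keep-mask for every layer."""
    all_abs = sorted(abs(g) for layer in gammas_per_layer for g in layer)
    thre = all_abs[int(len(all_abs) * percent)]
    # Strict comparison, as in the script: a gamma equal to the threshold is pruned
    masks = [[1 if abs(g) > thre else 0 for g in layer] for layer in gammas_per_layer]
    cfg = [sum(m) for m in masks]  # remaining channels per layer
    return thre, masks, cfg


# Two BN layers with four channels each; values invented for illustration.
layers = [[0.9, 0.05, 0.6, 0.01], [0.02, 0.8, 0.7, 0.03]]
thre, masks, cfg = channel_masks(layers, 0.5)
print(thre)   # 0.6
print(masks)  # [[1, 0, 0, 0], [0, 1, 1, 0]]
print(cfg)    # [1, 2]
```

Two things are worth noticing. Because the comparison is strict, the channel sitting exactly at the threshold is removed too, so here 5 of 8 channels go even though `percent` is 0.5. And because the threshold is global, nothing stops one layer from losing far more channels than another, in the extreme case all of them.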
In this experiment the parameter count was 27102 before pruning and 7185 after pruning, a reduction of roughly 73%.
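A reduction larger than the pruning ratio is expected, because removing a channel shrinks both the layer that produces it and the layer that consumes it. The following back-of-the-envelope sketch shows this for a hypothetical stack of Conv2d + BatchNorm2d blocks; the channel widths, the 3x3 kernels, and the absence of a convolution bias are all assumptions, and this is not the ASPNET architecture.

```python
def conv_bn_params(c_in, c_out, k):
    """Learnable parameters of one Conv2d (no bias) followed by BatchNorm2d."""
    conv = c_in * c_out * k * k
    bn = 2 * c_out  # one gamma and one beta per channel
    return conv + bn


def count_params(channels, c_in=3, k=3):
    """Total parameters of a chain of Conv+BN blocks with the given output widths."""
    total = 0
    for c_out in channels:
        total += conv_bn_params(c_in, c_out, k)
        c_in = c_out  # this block's output feeds the next block
    return total


before = count_params([16, 32, 64])
after = count_params([8, 16, 32])   # every layer keeps half of its channels
print(before, after)                # 23696 6088
print(round(after / before, 3))     # 0.257
```

Halving every layer's width leaves only about a quarter of the parameters in this toy stack, since most convolution weights scale with the product of input and output channels.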
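Retraining with the pruned model:
Build the network with the pruned architecture and initialize it from the saved parameters of the pruned model.
refine = '<path to the pruned model>'  # checkpoint written by the pruning script via --save
if refine:
    checkpoint = torch.load(refine)
    print(checkpoint['cfg'])
    model = ASPNET(net_name=checkpoint['cfg'])  # build the pruned architecture
    model.cuda()
    model.load_state_dict(checkpoint['state_dict'])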