参考文章:

Pruning Filters for Efficient Convnets

Compressing deep neural nets

压缩神经网络 实验记录(剪枝 + rebirth + mobilenet)

为了在手机上加速运行深度学习模型,目前实现的方式基本分为两类:一是深度学习框架层面的加速,另一个方向是深度学习模型层面的加速。

深度学习模型的加速又可以分为采用新的卷积算子来加速模型,另一个方向是通过对已有模型进行剪枝操作得到一个参数更少的模型来加速模型。

通过观察深度学习模型,可以发现其中很多kernel的权重很小,均在-1~1之间震荡,对于这些绝对值很小的参数,可以视其对整体模型贡献很小,将其删除,然后将剩余的权重构成新的模型,以达到模型压缩,加速,并保证精准度不变的目的。

基本步骤:

1.实现原始网络,并将其训练到收敛,保存权重

2.观察对每一层的权重,判断其对模型的贡献大小,删除贡献较小的kernel,评判标准可以是std,sum(abs),mean等

3.当删除部分kernel后,会导致输出层的channel数变化,需要删除输出层对应kernel的对应channel

4.构建剪枝后的网络,加载剪枝后的权重,与原模型对比精准度。

5.使用较小的学习率,rebirth剪枝后的模型

6.重复第1步

上图展示了conv的kernel剪枝后导致的输出维度变化

对于conv层后面接续全连接层的情况:

conv层在接续全连接层前,会先reshape为一个维度。假设conv层输出为 (h,w,c),其会reshape为 h*w*c, 假设删除的kernel下标为[ 2,5,7],对应的conv输出通道也会减少 [2,5,7] 。reshape后 会减少 [2,5,7,...,h*w*2,h*w*5,h*w*7]。

全连接层接全连接层的逻辑基本和conv接conv层的逻辑一样。

一个使用mnist的简单示例

# 读取保存的权重和所有训练的var
model_path = './checkpoints/net_2018-12-19-10-05-17.ckpt-99900'
reader = tf.train.NewCheckpointReader(model_path)
all_variables = reader.get_variable_to_shape_map()
{'conv1/biases': [16],'conv1/weights': [3, 3, 1, 16],'conv2/biases': [32],'conv2/weights': [3, 3, 16, 32],'conv3/biases': [32],'conv3/weights': [3, 3, 32, 32],'fc1/biases': [128],'fc1/weights': [512, 128],'fc2/biases': [256],'fc2/weights': [128, 256],'global_step': [],'logits/biases': [10],'logits/weights': [256, 10]}
# 分析 conv1 的权重
conv1_weight = reader.get_tensor("conv1/weights")
# 计算每个kernel权重的和 (也可以使用其他指标,如std,mean等)
conv1_weight_sum = np.sum(conv1_weight, (0,1,2))
sort_conv1_weights = np.sort(conv1_weight_sum)
# 绘制conv1的
x = np.arange(0,len(sort_conv1_weights),step=1)
plt.plot(x,sort_conv1_weights)

# 保留权重和最大的8个kernel
pure_conv1_weight_index = np.where(conv1_weight_sum >= sort_conv1_weights[8])
pure_conv1_weight = conv1_weight[:,:,:,pure_conv1_weight_index[0]]
# conv1对应的bias 也做相同处理
conv1_bias = reader.get_tensor("conv1/biases")
pure_conv1_bias = conv1_bias[pure_conv1_weight_index[0]]
# 对后面接续的 conv 层的kernel做相同处理
conv2_weight = reader.get_tensor("conv2/weights")
conv2_bias = reader.get_tensor("conv2/biases")
conv2_weight = conv2_weight[:,:,pure_conv1_weight_index[0],:]

后面层重复以上操作

剪枝后结果对比

原始模型精度

剪枝后模型精度

模型权重大小从400多kb减小到了100多kb

jupyter文件及代码

深度学习 模型 剪枝相关推荐

  1. PyTorch 深度学习模型压缩开源库(含量化、剪枝、轻量化结构、BN融合)

    点击我爱计算机视觉标星,更快获取CVML新技术 本文为52CV群友666dzy666投稿,介绍了他最近开源的PyTorch模型压缩库,该库开源不到20天已经收获 219 颗星,是最近值得关注的模型压缩 ...

  2. 深度学习模型压缩(量化、剪枝、轻量化结构、batch-normalization融合)

    "目前在深度学习领域分类两个派别,一派为学院派,研究强大.复杂的模型网络和实验方法,为了追求更高的性能:另一派为工程派,旨在将算法更稳定.高效的落地在硬件平台上,效率是其追求的目标.复杂的模 ...

  3. 深度学习模型压缩算法综述(二):模型剪枝算法

    深度学习模型压缩算法综述(二):模型剪枝算法 本文禁止转载 联系作者: 模型剪枝算法 : L1(L2)NormFilterPruner: 主要思想: 修剪策略: 微调策略: 残差网络的处理: 缺点: ...

  4. 深度学习模型压缩方法(3)-----模型剪枝(Pruning)

    link 前言 上一章,将基于核的稀疏化方法的模型压缩方法进行了介绍,提出了几篇值得大家去学习的论文,本章,将继续对深度学习模型压缩方法进行介绍,主要介绍的方向为基于模型裁剪的方法,由于本人主要研究的 ...

  5. 深度学习模型压缩与加速综述!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:Pikachu5808,编辑:极市平台 来源丨https://zh ...

  6. 深度学习模型压缩与加速综述

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 导读 本文详细介绍了4种主流的压缩与加速技术:结构优化.剪枝.量化 ...

  7. 一文看懂深度学习模型压缩和加速

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自:opencv学堂 1 前言 近年来深度学习模型在计算机视 ...

  8. 不用GPU,稀疏化也能加速你的YOLOv3深度学习模型

    水木番 发自 凹非寺 来自|量子位 你还在为神经网络模型里的冗余信息烦恼吗? 或者手上只有CPU,对一些只能用昂贵的GPU建立的深度学习模型"望眼欲穿"吗? 最近,创业公司Neur ...

  9. 深度学习模型的中毒攻击与防御综述

    来源:专知本文约2000字,建议阅读5分钟本文首次综述了深度学习中的中毒攻击方法,回顾深度学习中的中毒攻击,分析了此类攻击存在的可能性,并研究了现有的针对这些攻击的防御措施.最后,对未来中毒攻击的研究 ...

最新文章

  1. http反向代理调度算法追朔
  2. android地址格式转换,Android(安卓)时间戳和日期之间的转化
  3. java 视图对象转换,使用spring boot开发时java对象和Json对象转换的问题_JavaScript_网络编程...
  4. 人工智能到来的时代,你曾经瞧不起的职业,可能会非常吃香!
  5. Hemberg-lab单细胞转录组数据分析(四)
  6. python脚本多少钱一个_一个python脚本
  7. browser.html – HTML 实现 Firefox UI
  8. Facebook又开两处AI实验室,在西雅图和匹兹堡招兵买马
  9. 富文本编辑器复制word文档中的图片
  10. 月薪过万的php面试题目
  11. DELL win10插入耳机后声音仍然外放(亲测有效)
  12. 《HTML CSS JavaScript 网页制作》第六章-创建框架结构网页
  13. rand和randc有什么区别
  14. 2.4.U-Boot配置和编译过程详解-U-Boot和系统移植第4部分视频课程笔记
  15. undefined reference to `__stack_chk_fail'
  16. Linux内核配置(9)
  17. Linux开机无网络连接解决方案
  18. 软件測试系列之入门篇(一)
  19. SQL语句在dos操作MySQL数据库
  20. 【转】小生我怕怕工具包[2010.06.17](转自52破解论坛)

热门文章

  1. 高级艺术二维码制作保姆级教程!
  2. strtol全面解析
  3. (C++)输入一行字符,分别统计出其中英文字母、空格、数字和其他字符的个数
  4. PHPExcel的常用功能
  5. 使用qrcode.vue生成二维码
  6. sqlite在H5中封装及使用(笔记)
  7. 如何选择适合自己企业的B2B2C商城系统?
  8. [Programe.Linux.Shell]
  9. 【金猿技术展】一种分布式 HTAP 数据库上基于索引的数据任意分布方法——为 HTAP 数据库实现 Collocation 优化...
  10. ionic实现微信,QQ,微博分享