背景知识

模型剪枝(Model Pruning)是一种模型压缩方法,对深度神经网络的稠密连接引入稀疏性,通过将“不重要”的权值直接置零来减少非零权值数量,其历史可追溯到上世纪 90 年代初。

在 Optimal Brain Damage【2】中,使用对角 Hessian 逼近计算每个权值的重要性,重要性低的权值被置零,然后重新训练网络。

在 Optimal Brain Surgeon【3】中,使用逆 Hessian 矩阵计算每个权值的重要性,重要性低的权值被置零,剩下的权值使用二阶泰勒逼近的 loss 增量更新。

最近比较流行基于幅度的权值剪枝方法【4】,该方法将权值取绝对值,与设定的 threshhold 值进行比较,低于门限的权值被置零。基于幅度的权值剪枝算法计算高效,可以应用到大部分模型和数据集。TensorFlow 也使用了基于幅度的权值剪枝算法。


TF 代码实现

TensorFlow 代码目录 tensorflow/contrib/model_pruning/ 提供了对 TensorFlow  框架的扩展,可在模型训练时实现剪枝。

对每个被选中做剪枝的层增加一个二进制掩模(mask)变量,形状和该层的权值张量形状完全相同。该掩模决定了哪些权值参与前向计算。掩模更新算法则需要为 TensorFlow 训练计算图注入特殊运算符,对当前层权值按绝对值大小排序,对幅度小于一定门限的权值将其对应掩模值设为 0。反向传播梯度也经过掩模,被屏蔽的权值(mask 为 0)在反向传播步骤中无法获得更新量。

研究发现稀疏度不宜从一开始就设置最大,这样容易将重要的权值剪掉造成无法挽回的准确率损失,更好的方法是渐进稀疏度,从初始稀疏度 (一般为 0 )开始,逐步增大到最终稀疏度 ,这期间二进制掩模变量 mask 经历了 n 次更新,每次更新时的门限由当时的稀疏度决定,稀疏度由如下公式计算得到:

随着训练过程,逐步提高稀疏度,直到达到期望的稀疏度 为止。

下图很直观地反映了渐进提高稀疏度的过程。

初始时刻,稀疏度提升较快,而越到后面,稀疏度提升速度会逐渐放缓,这个比较符合直觉,因为初始时有大量冗余的权值,而越到后面保留的权值数量越少,不能再“大刀阔斧”地修剪,而需要更谨慎些,避免“误伤无辜”。

下面 TensorFlow 代码创建了带有 mask 变量的 graph:

from tensorflow.contrib.model_pruning.python import pruningwith tf.variable_scope('conv1') as scope:# 创建权值 variablekernel = _variable_with_weight_decay('weights', shape=[5, 5, 3, 64], stddev=5e-2, wd=0.0)# 创建 conv2d op,权值 variable 增加 maskconv = tf.nn.conv2d(images, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')

下面代码给出了带剪枝的模型训练代码结构:

from tensorflow.contrib.model_pruning.python import pruning# 命令行参数解析pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)# 创建剪枝对象pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)# 使用剪枝对象向训练图增加更新 mask 的运算符# 当且仅当训练步骤位于 [begin_pruning_step, end_pruning_step] 之间时,# conditional_mask_update_op 才会更新 maskmask_update_op = pruning_obj.conditional_mask_update_op()# 使用剪枝对象写入 summaries,用于跟踪每层权值 sparsity 变化pruning_obj.add_pruning_summaries()with tf.train.MonitoredTrainingSession() as mon_sess:while not mon_sess.should_stop():mon_sess.run(train_op)# 更新 maskmon_sess.run(mask_update_op)  

其中 FLAGS.pruning_hparams 为一组逗号分隔的键值对,取值如下表所示:

超参名 类型 默认值 说明
begin_pruning_step integer 0 开始剪枝的全局 step
end_pruning_step integer -1 结束剪枝的全局 step,默认为 -1 标识剪枝一直持续到训练结束
do_not_prune list of strings [""] 一组层名,标记哪些层不做剪枝
threshold_decay float 0.9 衰减因子,用于门限衰减
pruning_frequency integer 10 mask 更新的频率,计数单位为全局 step 数
initial_sparsity float 0.0 初始稀疏度值
target_sparsity float 0.5 目标稀疏度值
sparsity_function_begin_step integer 0 渐进稀疏度函数开始时刻
sparsity_function_end_step integer 100 渐进稀疏度函数结束时刻
sparsity_function_exponent float 3.0

指数项,=1 则为线性增长,>1 则初始快后续慢

转者注: 详细代码可参照https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/examples/cifar10    (例子)

           和 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/model_pruning/python/pruning.py  (pruning源码)                                                                                                                                                                                                                

实践

TensorFlow model pruning 自带 CIFAR10 例程,实现了一个稀疏 CNN 模型,其中卷积层和 local 层的权值均做了稀疏化。

(1) 准备 TensorFlow r1.7 环境

硬件环境:GTX 1080

软件环境:CUDA 9.0 + cuDNN 7, Bazel 0.11.1

git clone https://github.com/tensorflow/tensorflow.git
cd tensorflow/
git checkout r1.7

(2) 编译、运行 tensorflow/contrib/model_pruning/

cd tensorflow/contrib/model_pruning/
bazel build -c opt examples/cifar10:cifar10_{train,val}
cd ../../
bazel-bin/contrib/model_pruning/examples/cifar10/cifar10_train -prune_hparams=name=cifar10_pruning,begin_pruning_step=10000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000

(3) 查看训练过程

运行 TensorBoard:

tensorboard --logdir /tmp/cifar10_train/

打开浏览器,输入 localhost:6006

可以看到随着训练步骤增加,conv1 和 conv2 的 sparsity 在不断增长。总的 loss 变化如下图所示:

(4) 查看计算图

切换到 GRAPHS 页面,双击 conv2 节点,可以看到在原有计算图基础上新增了 mask 和 threshold 节点用来做 model pruning。

(5) 模型评估

利用以下命令对训练模型进行评估:

bazel-bin/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_eval

训练 15 万次迭代的结果(仅供参考)

Sparsity Accuracy after 150K steps
0% 86%
50% 86%
75%
90%
95% 77%

论文【1】中一些结论

随着稀疏度提高,模型质量逐渐下降,其表现为分类准确率降低。下表为 InceptionV3 模型不同稀疏度的情况【1】:

从表中看到,50% 系数模型和基准模型(0% 稀疏度)表现一致,而 87.5% 稀疏度模型的 top-5 准确率相比基准模型只有 2% 降低,但模型非零权值数量减少为原来 1/8。

我们前面文章《用于移动和嵌入式视觉应用的 MobileNets》介绍过轻量 CNN 模型 MobileNet,是一类特别为移动视觉应用设计的高效卷积神经网络。MobileNet 基于 depthwise separable 卷积,将通道内滤波通道间线性组合分解为两个独立步骤,显著减少了参数数量。MobileNet 网络架构包括一个标准卷积层用于处理输入图片,一大堆 depthwise separable conv,最后为 average pooling 和全连接层。 width multiplier 是 MobileNet 的一个调节参数,能实现模型准确率和模型权值数量、计算量的 trade-off。

我们既可以通过设置更小的 width multiplier 实现尺寸更小的模型(准确率会降低),也可以通过对原始 MobileNet 做稀疏化得到尺寸更小的模型(准确率同样会降低),那么这两种方法哪种更有效呢?论文【1】 给出了结果:

基本结论为:大而稀疏的模型(large-sparse)表现优于小而稠密的模型(small-dense)。

例如,75% 稀疏度模型( 1.09 M 权值,top-1 accuracy 为 67.7%)优于稠密 0.5 MobileNet( 1.32 M 权值,top-1 accuracy 为 63.7%)。

类似地, 90% 稀疏度模型(0.46 M 权值,top-1 accuracy 61.8%)优于稠密 0.25 MobileNet(0.46 M 权值,top-1 accuracy 50.6%)。

通过这些结论,为轻量级模型设计提供了新的思路,同时也为专用硬件加速器设计提供了参考。

参考文献

【1】 To prune, or not to prune: exploring the efficacy of pruning for model compression, arXiv:1710.01878

【2】Yann LeCun et.al. Optimal brain damage. NIPS, 1990

【3】B.Hassibi et.al. Optimal brain surgeeon and general network pruning. ICNN, 1993

【4】Song Han et.al. Learning both weights and connections for efficient neural network. NIPS, 2015

tensorflow 网络修剪 剪枝操作相关推荐

  1. 【模型压缩】深度卷积网络的剪枝和加速(含完整代码)

    作者 | 贝壳er 研究 | 数据挖掘与异常检测 出品 | AI蜗牛车 " 记录一下去年12月份实验室的一个工作:模型的剪枝压缩,虽然模型是基于yolov3的魔改,但是剪枝的对象还是CBL层 ...

  2. 计算机硬件Word,[计算机硬件及网络]word的操作.doc

    [计算机硬件及网络]word的操作.doc Office Word 2010高级应用技术长文档排版一.考查知识点内置样式修改与使用.新建样式.题注.交叉引用.脚注与尾注.目录.图表目录.分节符的使用. ...

  3. 工程之星android版使用,安卓版工程之星软件网络1+1模式及网络cors连接操作详解...

    原标题:安卓版工程之星软件网络1+1模式及网络cors连接操作详解 现在,越来越多用户开始使用安卓版工程之星进行作业,科力达技术工程师总结了安卓版工程之星网络1+1模式及网络CORS连接方式操作步骤, ...

  4. 网路游侠:用网络运维操作管理平台进行网络安全管理

    本来前几天在 [ 使用WEB应用防火墙保护网站安全 ] 文章末尾曾经提到最近想写数据库安全的文章的,但是Cisco 3560被征用了,所以,还是写另一个热点产品:网络运维管理操作平台. 网络运维管理操 ...

  5. 如何在CAD中进行修剪命令操作?

    如何在CAD中进行修剪命令操作? 我们在进行CAD制图时,面对多余的部分,我们通常会需要用到修剪命令,那么修剪命令该如何使用呢?下面来教你具体的操作方法. 1.首先我们需要运行迅捷CAD编辑器绘制任意 ...

  6. MobileNetV3基于NNI剪枝操作

    NNI剪枝入门可参考:nni模型剪枝_benben044的博客-CSDN博客_nni 模型剪枝 1.背景 本文的剪枝操作针对CenterNet算法的BackBone,即MobileNetV3算法. 该 ...

  7. 脑网络分析软件Gretna操作--Network Analysis

    脑网络分析软件Gretna操作--Network Analysis 2018-10-25 15:03:31 云端浅蓝 阅读数 2923 版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版 ...

  8. 酷狗软件测试自学,酷狗音乐检测网络的详细操作

    想必当前不少伙伴们还不熟悉酷狗音乐检测网络的详细操作.下面就来看看酷狗音乐检测网络的操作方法吧.希望可以帮助到大家. 酷狗音乐检测网络的详细操作 1.进入到酷狗音乐的主界面,如果出现网络异常的情况,可 ...

  9. 海蜘蛛系统日志怎么保存到服务器,海蜘蛛软路由网络设置的操作步骤

    海蜘蛛软路由系统默认LAN(局域网接口)IP地址是:192.168.0.253,默认子网掩码是255.255.255.0,默认控制端口为880.这些值可以根据您的需要而改变.下面网吧路由栏目小编具体说 ...

最新文章

  1. SAP MM 采购申请中的物料组字段改成Optional
  2. 算法笔记_188:历届试题 危险系数(Java)
  3. Nginx的负载均衡 - 保持会话 (ip_hash)
  4. 前端如何正确使用中间件?
  5. 恭喜你!在25岁前看到了这篇最最靠谱的深度学习入门指南
  6. How GPUs Work
  7. ElasticSearch核心基础之索引管理
  8. 微课|玩转Python轻松过二级(2.4节):常用内置函数用法精要3
  9. 【CSS】text-align:justify 的使用
  10. 问题三十四:怎么用ray tracing画任意长方体(generalized box)
  11. 线程池如何确定线程数量
  12. 泛函密度 matlab,【讨论】密度泛函理论中“密度”究竟指什么 - 计算模拟 - 小木虫 - 学术 科研 互动社区...
  13. CUBA Platform 7.0.3 发布,企业级应用开发平台
  14. Java写txt—读txt—清空txt文件
  15. CHD+CM-2 初始化集群和安装软件
  16. 云洲无人船:驶向水上智能时代
  17. win7下登录中国银行网银,叫你四步搞定!
  18. Qt -设计嵌入式设备用户界面的利器
  19. Linux从头学09:x86 处理器如何进行-层层的内存保护?
  20. java中间件技术有哪些?

热门文章

  1. 70分钟,干货十足!百度CTO王海峰在新华社带来一场人工智能课
  2. 人工智能(AI)在金融行业的应用
  3. php订单系统 帝国cms,帝国cms有订单管理系统吗
  4. ET框架-15 Actor消息的编写 以及 ET框架实战之前的准备工作
  5. 设计了一个支撑 数亿 用户的系统
  6. android 实现图片旋转,移动,缩放,并且记录变化值,用另外一张图片显示出来
  7. 模拟拖拽-小火柴博客
  8. shiro分布式session共享
  9. Manjaro 安装wps
  10. 第五章、首次登陆与在线求助 man page