《SELF-ADAPTIVE NETWORK PRUNING》论文笔记
参考代码:无
1. 概述
导读:这篇文章提出了一种channel剪枝的算法,在网络中通过嵌入SPM(Saliency-and-Pruning Module )模块得到卷积过程中重要的channel,之后通过一个阈值得到一个二值标志序列,之后通过将其中为0的位置“置0”从而达到网络剪枝的目的。CNN中重要的channel是通过计算特征图自身的特性(文章中为均值)之后连接一个fc得到的,之后给定一个期待的计算量开销目标,之后在训练的过程中将网络现有的开销与期望的开销计算损失,从而约束CNN网络中的channel数量。
文章的作者在一些基于分类的任务中发现了如下的亮点规律:
- 1)对于CNN网络中的每一层卷积其适用的剪裁比例是不一致的,因而使用固定比例的方式进行剪枝是次优的,应该以数据驱动;
- 2)在CNN分类网络中其实卷积中只有很少的一部分channel对某一类别有较强的反应(特征图的统计意义上),那么这就说明其中是存在较大的冗余的,是存在剪枝的空间的;
上述的两点观察可以从下图看出:
2. 方法设计
2.1 网络结构
文章提出的剪枝整体pipline见下图所示:
在上图中文章通过在每个卷积层上添加SPM模块提取出显著性(重要)的channel:
sl(xl−1)=SaliencyPrediction(xl−1,W)s^l(x^{l-1})=SaliencyPrediction(x^{l-1},W)sl(xl−1)=SaliencyPrediction(xl−1,W)
其中,xl−1x^{l-1}xl−1是上一层卷积输出的特征图。之后将这些显著性channel(经过阈值)得到需要剪除的部分:
bl(xl−1)=Binarize(xl−1)b^l(x^{l-1})=Binarize(x^{l-1})bl(xl−1)=Binarize(xl−1)
在得到上述的二值序列掩膜之后,便是与之前的重要性置信度组合起来,从而这一层的卷积输出描述为:
xl=sl(xl−1)⋅bl(xl−1)⋅BatchNorm(fl∗xl−1)x^l=s^l(x^{l-1})\cdot b^l(x^{l-1})\cdot BatchNorm(f^l*x^{l-1})xl=sl(xl−1)⋅bl(xl−1)⋅BatchNorm(fl∗xl−1)
其中,flf^lfl是当前层的卷积参数。之后通过二值化的结果计算一个开销损失,从而与原本的损失函数进行联合训练。
2.2 channel重要性度量函数
在文章中对于channel重要性的度量是通过计算特征图在channel上的均值,之后经过一个FC层得到的,首先计算其均值:
d=1Hl−1∗Wl−1∑i=1Hl−1∑j=1Wl−1xl−1(i,j)d=\frac{1}{H_{l-1}*W_{l-1}}\sum_{i=1}^{H_{l-1}}\sum_{j=1}^{W_{l-1}}x^{l-1}(i,j)d=Hl−1∗Wl−11i=1∑Hl−1j=1∑Wl−1xl−1(i,j)
之后再将其与一个FC连接得到预测结果:
sl(xl−1)=SaliencyPrediction(xl−1,W)=W2δ(W1d)s^l(x^{l-1})=SaliencyPrediction(x^{l-1},W)=W_2\delta(W_1d)sl(xl−1)=SaliencyPrediction(xl−1,W)=W2δ(W1d)
其中,δ\deltaδ是ReLU。
2.3 重要性二值函数
通过上面的内容得到重要性置信度之后,文章引入了一个二值函数用以区分那些channel是需要保留的,反之就需要被剪枝。在训练的过程中文章引入了高斯噪声ξ∼N(0,1)Cl\xi\sim N(0,1)^{C_l}ξ∼N(0,1)Cl,从而得到:
s1=max(0,min(1,a⋅σ(sl(xl−1)+ξ)−b))s_1=max(0,min(1,a\cdot\sigma(s^l(x^{l-1})+\xi)-b))s1=max(0,min(1,a⋅σ(sl(xl−1)+ξ)−b))
其中,σ\sigmaσ是sigmoid函数,a,ba,ba,b是超参数。之后通过一个设定的阈值得到二值化的掩膜序列:
s2=1(s1>0.5)s_2=\mathcal{1}(s_1\gt0.5)s2=1(s1>0.5)
2.4 网络损失函数
除了分类网络自身的分类损失之外,文章还对网络的开销进行损失监督(这部分监督可以看作是在网络channel上去做L1正则化,使其稀疏化),其损失函数描述为:
Lmulti=Lcls+λ1Nc∑l=1L∣∣sl∣∣1L_{multi}=L_{cls}+\lambda\frac{1}{N_c}\sum_{l=1}^L||s^l||_1Lmulti=Lcls+λNc1l=1∑L∣∣sl∣∣1
其中,λ\lambdaλ是通过ptp_tpt(网络估计出来剪枝之后的开销)p0p_0p0(网络的总开销)ppp(目标开销)参数组合得到的,其是一个变化的比例,其表示为:
λ=λ0⋅(pt−p)p0\lambda=\lambda_0\cdot\frac{(p_t-p)}{p_0}λ=λ0⋅p0(pt−p)
3. 实验结果
CIFAR-10:
CIFAR-100:
《SELF-ADAPTIVE NETWORK PRUNING》论文笔记相关推荐
- 论文笔记之Understanding and Diagnosing Visual Tracking Systems
Understanding and Diagnosing Visual Tracking Systems 论文链接:http://dwz.cn/6qPeIb 本文的主要思想是为了剖析出一个跟踪算法中到 ...
- 《Understanding and Diagnosing Visual Tracking Systems》论文笔记
本人为目标追踪初入小白,在博客下第一次记录一下自己的论文笔记,如有差错,恳请批评指正!! 论文相关信息:<Understanding and Diagnosing Visual Tracking ...
- 论文笔记Understanding and Diagnosing Visual Tracking Systems
最近在看目标跟踪方面的论文,看到王乃岩博士发的一篇分析跟踪系统的文章,将目标跟踪系统拆分为多个独立的部分进行分析,比较各个部分的效果.本文主要对该论文的重点的一个大致翻译,刚入门,水平有限,如有理解错 ...
- 目标跟踪笔记Understanding and Diagnosing Visual Tracking Systems
Understanding and Diagnosing Visual Tracking Systems 原文链接:https://blog.csdn.net/u010515206/article/d ...
- 追踪系统分模块解析(Understanding and Diagnosing Visual Tracking Systems)
追踪系统分模块解析(Understanding and Diagnosing Visual Tracking Systems) PROJECT http://winsty.net/tracker_di ...
- ICCV 2015 《Understanding and Diagnosing Visual Tracking Systems》论文笔记
目录 写在前面 文章大意 一些benchmark 实验 实验设置 基本模型 数据集 实验1 Featrue Extractor 实验2 Observation Model 实验3 Motion Mod ...
- Understanding and Diagnosing Visual Tracking Systems
文章把一个跟踪器分为几个模块,分别为motion model, feature extractor, observation model, model updater, and ensemble po ...
- CVPR 2017 SANet:《SANet: Structure-Aware Network for Visual Tracking》论文笔记
理解出错之处望不吝指正. 本文模型叫做SANet.作者在论文中提到,CNN模型主要适用于类间判别,对于相似物体的判别能力不强.作者提出使用RNN对目标物体的self-structure进行建模,用于提 ...
- ICCV 2017 UCT:《UCT: Learning Unified Convolutional Networks forReal-time Visual Tracking》论文笔记
理解出错之处望不吝指正. 本文模型叫做UCT.就像论文题目一样,作者提出了一个基于卷积神经网络的end2end的tracking模型.模型的整体结构如下图所示(图中实线代表online trackin ...
- CVPR 2018 STRCF:《Learning Spatial-Temporal Regularized Correlation Filters for Visual Tracking》论文笔记
理解出错之处望不吝指正. 本文提出的模型叫做STRCF. 在DCF中存在边界效应,SRDCF在DCF的基础上中通过加入spatial惩罚项解决了边界效应,但是SRDCF在tracking的过程中要使用 ...
最新文章
- Ajax请求Session超时解决
- 如何重新打开Windows防火墙提示?
- openoffice将html转成pdf,通过openOffice将office文件转成pdf
- 基于强化学习的倒立摆控制策略Matlab实现(附代码)
- 001-JavaScript简介
- java实现数据库回滚,java 数据库操作,事宜回滚
- JAVA反射机制Reflection详解
- 互联网大数据时代下亚马逊是如何解决数据存储的
- C语言分支/顺序作业总结
- 递归,举几个简单的例子
- 【Vue 4 笔记 】(一)
- ODBC连接数据库使用动态密码
- 【待办】三国杀单挑测试脚本
- 【ICPR 2021】遥感图中的密集小目标检测:Tiny Object Detection in Aerial Images
- 关于 ajax Content-Type 的问题 贼拉有用的!!!
- 蓝桥杯-第九届决赛——采油
- doubb超时_dubbo源码分析(二):超时原理以及应用场景
- Linux图形子系统之GEM内存管理
- [翻译] 求生之路AI系统讲稿
- 用vue写一套销客多电商分销后台系统