【CVPR2019】SpotTune: Transfer Learning through Adaptive Fine-tuning

论文链接:SpotTune: Transfer Learning through Adaptive Fine-tuning

一. Introduction

使用深度学习模型时,微调(fine-tune)是应用最普遍的迁移学习方法。它具体指先在源任务上获得预训练模型,然后在目标任务上进一步训练它,从而,可以减少对目标标签数据需求的同时,提升模型的性能。

常用的微调方式有以下两种:第一个是使用目标数据集优化预训练模型中的所有参数,它的一大缺陷是,当目标数据集小且预训练网络的参数过大时,可能会产生过拟合;第二个是依据目标任务中训练集有限以及初始层学到的低级特征可以在多个任务间共享这一经验,选择微调深度网络的最后几层的参数,冻结前面其他层的参数,但是由于需要手动选择初始冻结层数,这不利于提升优化效率。并且,像ResNet这种由多个浅层网络集成的模型,初始层学到的低级特征可以共享这一前提不再适用,所以仅是微调模型的最后几层并不一定是最优的选择。

目前的方法也均是采用全局微调的策略,即,对目标任务中的所有样本采取(在某些网络层)freeze参数或者是fine-tune参数的决定。这就相当于假设该决定对整个目标数据分布是最优的,但是,现实往往并非如此。

例如,目标任务中的某些类与源任务之间的相似性较高,这些类的样本可能倾向于finetune较少的预训练参数,与之相反的样本则希望能finetune更多的预训练参数,以达最好的准确率。

所以,理想的情况是,为目标任务中的每一样本,在每一层,都制定一个该finetune还是该freeze参数的决策。

就如图1所示,上面的是在源任务上得到预训练模型,下面,在目标任务中,有两个猫的训练样本,第一个猫样本在前两块选择冻结参数,也就是保留预训练模型原有的参数,后两块做了微调,而第二个猫样本在第一三块选择了微调,在二四块选择冻结参数。而这样的选择对他们来说是达到了最优的微调策略。

本文提出了一种方法SpotTune。它可以学习依赖输入(input-dependent)的微调策略,大体指从一个轻量级神经网络的输出所构成的离散分布中采样,来为每一个样本决定在哪一层该fine-tune,哪一层该freeze。

由于策略函数是离散的,不可微,所以,采用了Gumbel Softmax 采样方法来训练策略网络。

在测试期间,策略网络就可以决定来自上一层的特征(feature)该进入原预训练的网络层,还是需做微调的网络层。

本文贡献:

•提出了一个依赖输入的微调方法,能为每个目标样本自动决定在哪些层fine-tune。

•还提出了上述方法的一个global变体,即,约束所有样本在相同的k层做fine-tune,其中,这k个层可以分布于网络中的任意部分。该变体可以使最终模型有较少的参数。

•通过大量实验证明,本文提出的方法在14个数据集中有12个超过了标准fine-tune方法,并且,在Visual Decathlon Challenge(10个用于多域学习算法性能测试的基准数据集),相比其他先进的方法,取得了最高的score。

二. Proposed Approach

本文提出的方法能应用于不同的神经网络架构,但是由于ResNet相当于是多个较浅的层组合而成,使它对残差块之间的交换具有弹性,也就是,交换残差块对网络性能影响不大。该性质更合乎本文提出的方法。所以,接下来的实验,均是基于ResNet网络架构。下边这个图是resnet的一个基本残差块形式:

假定ResNet预训练模型的第l块的输出表示为:

为了在训练期间,决定某一residual块,是否被fine-tune,先freeze了该原始块,再创建一个与它并排的初始参数相同的trainable的块。此时,第l层的输出可以表示为:

其中,是一个二进制随机变量,可以为输入图片指示该residual块是被微调还是frozen。 它是从一个轻量级的策略网络输出所构成的离散分布中抽样所得,取值若为0,表示重用第个frozen块,若为1,表示通过优化来微调第个块。

图2是文中提出的SpotTune方法架构的图例说明,上面的黄色块表示策略网络,下面的两排表示预训练模型,浅棕色块的这一排表示不做微调,对应于式2中的F,深棕色表示做微调,对应于式2中的F尖,通过策略网络得到微调策略I(x),I(x)的取值就可以决定每个残差块前面的开关的开合,从而决定上一层的输出接下来该选择走微调的残差块还是冻结的残差块。

对于策略网络,它是一个轻量级的resnet网络,由于预训练ResNet模型有L个残差块,所以,它的输出logits就是一个L*2的二维矩阵,然后通过Gumbel-max采样得到微调策略I(x),它是一个L维的向量,取值不是0就是1。Gumbel-max采样过程可以分为4步,分别是:

Gumbel-max的采样过程可简述如下(参考自【一文学会】Gumbel-Softmax的采样技巧)

  • 对于网络输出的一个zz维向量,生成zz个服从均匀分布U(0,1)的独立样本,z表示类别数,由于策略网络只有两类(fine-tune or freeze参数),所以z=2;
  • 通过计算得到
  • 上述两步结果对应相加,得到新的向量
  • 最后通过argmax取上述向量最大值的索引。

这4步可以简化为式3的表示方式:

由于Gumbel-Max的采样结果(不是0就是1),是离散的,不可微,也就不能用于反向传播优化网络参数,所以,在反向传播时,作者采用了Gumbel-softmax采样方式,也就是将式3中的argmax用softmax替换,得到式(4),其中,?τ是控制输出向量Y离散程度的参数,当它逼近于0时,生成的分布就逼近于离散分布,当它越大时,可以使生成的分布越平滑。所以,当?τ>0时,Gumbel softmax分布是平滑的,这样就可以解决了之前的无法反向传播优化网络的问题。

除了特定图像的微调策略,作者还对其做了扩展,提出了它的一个全局变体,也就是,限制所有的图像在ResNet的相同的k个块作微调,这k个块可以分布在resnet的任何部位。为了实现这一变体,作者引入了两个损失函数,分别是式5和式6,

式5中的是指对第个residual块,目标数据集中选择了fine-tune的图像比例,取值为0到1。它可以使所有训练样本趋向于选择k个微调块。式6可以迫使精确到0或1,这样,就可以使所有图像在第l块,要么全部微调,要么都不微调,从而保证了所有图像在相同的k个块做了微调,并且这k个块可以分布于预训练模型的任何部位。最后,将这两个损失函数与分类损失函数结合,就可以得到最终的损失函数,式7。该变体相比于手动选择k个块,能实现最好的准确率。并且,由于它在测试阶段不需要策略网络,且k被设为一个较小的数时,它可以减少内存占用和计算成本。

三. Experiments

数据集:

作者通过实验比较了SpotTune方法与其他微调方法和正则化方法的效果,数据集包含两部分,第一部分就是表1中列出的5个数据集,其中,前3个是细粒度分类基准,后面两个数据集较大,并且与ImageNet不匹配。第二部分用于评估来自多个域的图像的视觉识别算法的数据集,包含10个。为了减少计算负担,这10个数据集中的图片的长宽均调成72pixels。

度量方式:

Baselines:将SpotTune与下面的几种微调和正则化技术做了比较实验

Pre-trained model(预训练模型):

Policy network architecture(策略网络架构):

SpotTune vs. Fine-tuning Baselines:

表2中列出了spottune方法和其他几种微调方式在第一部分数据集上的测试结果,很明显,文中提出的SpotTune方法在几个数据集上基本上都超过了其他方法的性能,只有,在WiKiArt数据集上,比微调resnet-101略低,作者推测这是因为这个数据集的训练样本比较多,所以,作者只选取其中的25%训练样本和10%训练样本,再次比较了两种方式,结果就是上面的绿色区域,可以看出,减少了训练样本,spottune的性能胜过了微调resnet101,并且随着样本数减少,差距越来越大。其次,可以看出,只是微调后面1个或2个或3个残差块的效果均不如标准微调方式。

结果中的第一行是将预训练网络当作特征提取器,当应用于目标数据集时,它能减少参数量,但是由于域转换,致使网络性能下降。

这个正则化方法的结果非常接近于spottune的结果,但是,作者在文中提到,可以将它用于补充spottune方法,两者结合,应该能得到一个更好的结果。

Visualization of Policies:to

为了能更好地理解,策略网络学到的微调策略,作者将第一部分数据集上对应的每一个残差块的策略做了可视化,如图3,从下往上,每一横排代表第多少个残差块,每一列代表一种数据集,每一方块的颜色深浅代表对应数据集中,在该残差块选择了做微调的图像所占比例,占比越大,颜色越深。从图3中可以看出,不同的数据集有不同的微调策略,而SpotTune能为每个数据集,甚至每个样本自动地确认恰当的微调策略。

Visualization of Block Usage:

此外,作者还对测试时, 每个数据集使用的微调块的数量的分布做了研究,上面图4就是其结果,纵轴表示测试样本数量,横轴表示做了微调的残差块数量,通过不同的颜色表示几种数据集,比如,图中红色椭圆圈出的部分表示,Flowers测试集中有大概1500个样本在6个残差块中做了微调。可以从中看出,对于每一种数据集,不同的图像倾向于使用不同数量的微调残差块。再次佐证了特定图像的微调策略能比所有图像的全局微调策略的准确率更高。

上面的图5,展示了几张CUBS和flowers数据集中,使用较少微调块的图像样本和使用较多微调块的图像样本,第一排的图都是使用微调块较少的图,可以看出,它们的背景比较干净,下面这排是使用微调块数目多的图,它们的背景相对比较复杂些。

Visual Decathlon Challenge:

表3列出了spotTune和其他方法在第二部分数据集上的实验结果对比。可以看出,spottune方法基本超过了其他所有方式。与黄色标出的标准微调方式相比,spottune的参数量与它相近,但是,最终的得分3612远超过标准微调方式的得分3096,再加上第一部分数据集的实验结果,spottune在14个数据中的12个上,超过了标准微调方式,只有红色线划出的两个数据集不如标准微调方式。倒数第二行的全局变体方法,在这里设定k=3,它相对于spottune方法,参数量有大幅度减少,并且分数为3401,仅次于spottune。

四. 总结

提出了一种自适应微调算法,SpotTune。它是针对于目标数据集中的每一个样本的微调策略。并且在大量的数据集上验证,SpotTune的性能基本上超过了常用的几种微调方式。

【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning相关推荐

  1. 论文笔记----Selective Transfer Learning for EEG-Based Drowsiness Detection

    对session进行评估,判断其是否适合使用迁移学习来提升性能.阅读重点,如何cross-subject. 全文核心:文中提出了一种新的可被迁移性的度量指标LSG,可以衡量一组数据是否适合接受来自其他 ...

  2. 论文笔记:CVPR 2022 Cross-Domain Adaptive Teacher for Object Detection

    摘要 我们解决了对象检测中的域适应任务,其中有注释的源域和没有注释的感兴趣的目标域之间存在域间隙(注:在一个数据集上训练模型,再另外一个数据集上进行预测性能下降很大,在一个数据集上训练好的模型无法应用 ...

  3. 图像隐写术分析论文笔记:Deep learning for steganalysis via convolutional neural networks

    好久没有写论文笔记了,这里开始一个新任务,即图像的steganalysis任务的深度网络模型.现在是论文阅读阶段,会陆续分享一些相关论文,以及基础知识,以及传统方法的思路,以资借鉴. 这一篇是Medi ...

  4. ECCV2016【论文笔记】Unsupervised Learning of Visual Representations by Solving Jigsaw Puzzles

    1.INTRO 本文作者旨在通过解决拼图问题来进行self-supervised learning,这样可以训练一个网络去识别目标的组成部分. 2.Solving Jigsaw Puzzles 当前一 ...

  5. 论文笔记 Traffic Data Reconstruction via Adaptive Spatial-Temporal Correlations

    IEEE TRANSACTIONS ON INTELLIGENT TRANSPORTATION SYSTEMS 2019 0 摘要 数据缺失仍然是交通信息系统中的一个难点和重要问题,严重制约了智能交通 ...

  6. 论文笔记:Hashtag2Vec: Learning Hashtag Representation with Relational Hierarchical Embedding Model

    感想 这是一片IJCAI 2018的论文,一开始看到这个东西的时候,我感觉还是比较新的,把社交网络的hashtag和tweet的网络结构融入到embedding中,做了一个network embedd ...

  7. 【论文笔记 · RL】Learning Phase Competition for Traffic Signal Control

    Learning Phase Competition for Traffic Signal Control 摘要 FRAP模型基于交通信号控制中相位竞争的思想:当两个交通信号出现竞争时,应该给交通流动 ...

  8. [深度学习论文笔记]Modality-aware Mutual Learning for Multi-modal Medical Image Segmentation

    Modality-aware Mutual Learning for Multi-modal Medical Image Segmentation 多模态医学图像分割中的模态感知互学习 Publish ...

  9. 【论文笔记之 APA】an adaptive filtering algorithm using an orthogonal projection to an affine subspace ...

    本文对 Kazuhiko Ozeki 等人于 1984 年在 Electronics and Communications in Japan (Part I: Communications) 上发表的 ...

最新文章

  1. python代码获取今天、昨天、明天的日期
  2. java 代码 设置环境变量_Java 配置环境变量教程
  3. 为什么分布式一定要有redis,redis的一些优缺点
  4. Djanog结合jquery实现ajax
  5. 【大数据部落】 17年房贷市场数据调研报告
  6. 直线度误差 matlab,基于MATLAB的直线度误差数据处理
  7. 什么是Http无状态协议?
  8. Maven报错Please ensure you are using JDK 1.4 or above and not a JRE解决方法!
  9. [HZAU]华中农业大学第四届程序设计大赛网络同步赛
  10. RTI_DDS线程模型
  11. 华为状态栏图标替换_【新手教程】状态栏图标替换教程
  12. Android混淆注意事项
  13. 数学模型:传染病模型
  14. [edu #63][div2 #554][div3 #555]
  15. 芒种时节,某地为何无人收割小麦?
  16. win vista/win 7/win 2008 超级激活工具
  17. PS轻松打造低多边形风格图像
  18. 关于pr文件导入的问题
  19. Ubuntu 安装 TP_LINK驱动 TL-WDN5200H 2.0无线网卡
  20. 当产品经理和程序员成为情侣后...

热门文章

  1. Oracle中的 DML, DDL,DCL
  2. 可爱的头像【 InsCode Stable Diffusion 美图活动一期】
  3. kitten学多久才能学python_KittenCode (Python学习软件)官方版v1.0.8
  4. 电视台搞的知识竞赛原来都用了这些软件硬件设备
  5. android梦网物联卡信息,摘自TechWeb:梦网物联云:实现智能穿戴只需一张物联网卡...
  6. Apple 计划在 2022 年推出五款新 Mac,包括入门级 MacBook Pro Refresh
  7. 计算机应用基础是文管二级吗,计算机应用基础
  8. 《Keras深度学习:入门、实战与进阶》之印第安人糖尿病诊断
  9. 随着移动支付普及那么如何进行支付测试呢?
  10. 切水果游戏中的刀的实现