文章目录

  • 摘要
  • 创新点
  • UDA介绍
  • 一些训练技巧
    • 1.训练信号退火(Training Signal Annealing, TSA)
    • 2.增强预测(Sharpening Predictions)
    • 3.领域相关数据过滤
  • 实验
  • 结论

论文链接:https://arxiv.org/pdf/1904.12848v2.pdf
代码链接:https://github.com/google-research/uda

摘要

在标记数据稀缺的情况下,半监督学习在改进深度学习模型方面表现出了很大的潜力。在最近的方法中,常见的是对大量未标记数据进行一致性训练,以约束模型预测对输入噪声保持不变。在本文中,作者建议在半监督学习环境中将数据增强应用于未标记的数据,使用的方法称为“无监督数据增强”(Unsupervised Data Augmentation, UDA),使模型预测在未标记样本和扩充的未标记样本之间保持一致。
与以前使用诸如高斯噪声或压降噪声之类的随机噪声的方法不同,UDA利用了由最新数据生成的更难、更逼真的噪声增强方法(如RandAugment和back translation)代替简单的去噪操作。即使标记数据集非常小,这个不同也使得它在六个语言任务和三个视觉任务有了实质性改进。 例如,在仅带有20个标记样本的IMDb文本分类数据集上,UDA的错误率达到4.20,胜过了在25,000个标记样本上训练的最新模型。 在标准的半监督学习基准CIFAR-10和SVHN上,UDA的性能优于所有以前的方法,在只有4,000个样本的CIFAR-10上实现了2.7%的错误率;在仅有250个示例的SVHN上实现了2.85%的错误率,几乎与大一到两个数量级的全带标签训练的模型性能相同。UDA在大型数据集(如ImageNet)上也能很好地作用。 当使用10%的标记数据集进行训练时,UDA将top-1 / top-5的准确性从55.1 / 77.3%提高到68.7 / 88.5%。 对于具有130万额外未标记数据的完整ImageNet数据集,UDA进一步将性能从78.3 / 94.4%提升至79.0 / 94.5%


创新点

  1. 我们显示在监督学习中优秀的数据增强方法也适用于半监督学习的一致性训练中。
  2. UDA可以媲美甚至超越监督学习的效果。而这些监督学习却使用了比UDA多很多的标注数据。无论是在视觉任务还是语言任务上。UDA只需使用很少的标注数据。
  3. 提出了一种称为TSA的训练技术,当未标记数据远多于标记数据时,该技术可以有效地防止过拟合
  4. 开发了一种方法,使得UDA甚至可以应用于标记和未标记数据不匹配的分布。

UDA介绍

符号说明:

半监督学习的最新工作是利用未标记的样本来增强模型的平滑性。 这些工作的一般形式可以概括如下:

  • 给定输入x,通过注入一个小的噪声来计算给定x的输出分布pθ(y∣x)p_{\theta}(y \mid x)pθ​(y∣x)和一个有噪声的版本pθ(y∣x,ϵ)p_{\theta}(y \mid x, \epsilon)pθ​(y∣x,ϵ)。 噪声可以应用于x或隐藏层,也可以用于更改计算过程。
  • 最小化两个预测分布D(pθ(y∣x)∥pθ(y∣x,ϵ))\mathcal{D}\left(p_{\theta}(y \mid x) \| p_{\theta}(y \mid x, \epsilon)\right)D(pθ​(y∣x)∥pθ​(y∣x,ϵ))之间的散度度量。

此过程使模型对噪声不敏感,因此就输入(或隐藏)空间的变化而言更平滑。在本文中,作者对现有的平滑性/一致性执行工作进行了简单的改动,并扩展了使用数据增强作为扰动。 作者建议使用针对不同任务的最新数据增强作为一种特殊形式的扰动,并针对未标记的样本优化相同的平滑度或一致性。 具体来说,建议使用一组在各种监督设置下验证的丰富的最新数据增强,在未标记的示例上注入噪声并优化相同的一致性训练目标。当与标记的示例联合训练时,我们使用权重因子λ平衡有监督交叉熵和无监督一致性训练损失,UDA训练流程如下图所示:

具体的公式为:

其中:
q(x^∣x)q(\hat{x} \mid x)q(x^∣x)数据增强转换, θ~\tilde{\theta}θ~是当前参数θ{\theta}θ的固定副本,指梯度并不通过 θ{\theta}θ传播;加号前表示正常的交叉熵
通过最小化一致性损失,UDA允许标签信息从有标签样本传播到无标签样本。 对于大多数实验,作者将λ设置为1,并对受监督的数据和无监督的数据使用不同的batch。 结果发现,对于某些数据集,在无监督的数据上使用较大的batch会带来更好的性能。

一些训练技巧

1.训练信号退火(Training Signal Annealing, TSA)

由于获取未标记的数据比标记的数据容易得多,因此在实践中,经常遇到这样一种情况,即未标记的数据量和标记的数据量之间存在很大的差距。为了使UDA能够利用尽可能多的未标记数据,通常需要足够大的模型,但是较大的模型可能会很容易导致数量有限的标记数据过拟合。为此,作者引入了一种称为训练信号退火(TSA)的新训练技术。
TSA的主要思想是,随着模型在越来越多的未标记数据上进行训练,逐渐释放标记样本的训练信号。具体来说,对于每个训练步骤t,将阈值设置为1/K≤ηt≤1,其中K为类别数。 当标记样本的正确类别pθ(y∗∣x)p_{\theta}\left(y^{*} \mid x\right)pθ​(y∗∣x)的概率高于阈值ηt时,就从损失函数中删除该样本,仅对小批量中的其他标记样本进行训练。 给定一小批带有标记的数据B,模型的目标如下:

其中I( )是指标函数,Z 是归一化因子。 阈值ηt充当上限,以防止模型对模型已经确定的样本进行过度训练。 当在训练过程中将ηt从1 / K逐渐退火到1时,该模型只是缓慢接受标记样本的监督,从而大大缓解了过拟合问题。

2.增强预测(Sharpening Predictions)

在问题很困难且标记样本数量非常少的情况下,基于未标记样本和增强的未标记样本的预测分布在各个类别上往往过于平坦。 因此,来自KL散度的无监督训练信号相对较弱,因此受到受监督部分的支配。作者发现增强未标记样本上产生的预测分布很有帮助,并可采用三种技术: Confidence-based masking、Entropy minimization、Softmax temperature controlling。实际上,作者发现结合使用Confidence-based masking和Softmax temperature controlling在标记数据量非常少时最有效,而熵最小化在标记数据量相对较大的情况下效果很好。

3.领域相关数据过滤

理想情况下,希望使用域外未标记的数据,因为通常它们更容易收集,但是域外数据的类分布通常与域内数据的类分布不匹配。因此,使用域外的未标记数据会损害性能。 为了获得与手头任务领域相关的数据,作者采用了一种常见的技术来检测域外数据——使用在域内数据上训练的基线模型来推断大型域外数据集中的数据标签,并挑选出该模型最有信心的样本(在类之间平均分布)。 具体来说,就是对于每个类别,根据属于该类别的分类概率对所有域外样本进行排序,然后选择概率最高的样本。


实验

实验设置: 按照标准的半监督学习设置,将UDA与先前在CIFAR-10和SVHN进行了比较,采用WideResNet-28-2作为我们的基准模型。 我们将UDA与Pseudo-Label(基于自训练的算法),虚拟对抗训练(VAT,对输入生成对抗性高斯扰动的算法)在Π-Model、Mean Teach、MixMatch上进行了比较。

与现有的半监督学习方法的比较。对比如图5所示。UDA以明显的优势胜过现有所有方法,包括MixMatch并行模型。例如,对于250个样本,UDA在CIFAR-10上的错误率是8.41,在SVHN上是2.85,而MixMatch在CIFAR-10上的错误率是11.08,在SVHN上是3.78。有趣的是,当未使用AutoAugment时,UDA可以匹配在完全监督的数据上训练的模型的性能。

首先,如表6所示,如果将通过AutoAugment在SVHN上找到的增强策略应用于CIFAR-10(由Switched Augment表示),错误率将从5.10增加到5.59,这证明了目标性数据增强的有效性。此外,如果我们去除Augment并仅使用Cutout,那么错误率将增加到6.42。最后,如果仅使用简单的裁剪和翻转作为增强,则错误率将增加到16.17。在SVHN上,不同增强方法的效果相似。

结论

本文证明了数据增强和半监督学习是紧密相连的。UDA使用高度针对性的数据增强来生成各种逼真的扰动,并使模型相对于这些扰动保持平稳。 本文还提出了TSA技术,当有很多未标记的数据时,该技术可以有效地防止UDA过度拟合监督的数据。 对于文本任务,UDA可以与表示学习(例如BERT)很好地结合在一起,并且在小数据集上非常有效。对于图像任务,UDA在性能指标较高的半监督学习设置中将错误率降低了30%以上。

主要对图像任务进行了说明

UDA:Unsupervised Data Augmentation for Consistency Training相关推荐

  1. [UDA]Unsupervised Data Augmentation for Consistency Training

    目录 Abstract 1 Introduction 2 Unsupervised Data Augmentation (UDA) Discussion Augmentation Strategies ...

  2. UDA(Unsupervised Data Augmentation)-半监督学习与数据增强

    1 简介 当标注好的数据很少时,半监督学习在深度学习模型中有非常好的表现.目前常用的方法是一致性训练,基于大量的非标注数据进行训练来使模型可以应对各种输入噪声(或者隐状态的噪声). 有些方法是来设计各 ...

  3. 论文阅读 Jointly Optimize Data Augmentation and Network Training

    平常的过,再过几分钟就25岁了,不知道怎么捕捉这个时刻,越来越喜欢孤独的感觉,常哭,常生气,希望未来的人会出现吧,真的要长大了 –槛外人– Abstract 随机的数据扩增对于网络训练很重要,以前的方 ...

  4. 数据增强 data augmentation

    有人称为数据扩充,不过更多按原意翻译为:数据增强(补充:数据增广更准确) 数据增强的方法种类 一些常见方法,如裁剪/缩放/彩色变换/翻转等,可参考:https://www.cnblogs.com/zh ...

  5. WS-DAN:Weakly Supervised Data Augmentation Netowrk for Fine-Grained Visual Classification

    See Better Before Looking Closer: Weakly Supervised Data Augmentation Netowrk for Fine-Grained Visua ...

  6. AutoML论文笔记(十四)Automatic Data Augmentation via Deep Reinforcement Learning for Effective Kidney Tumor

    文章题目:Automatic Data Augmentation via Deep Reinforcement Learning for Effective Kidney Tumor Segmenta ...

  7. 语音识别(ASR)论文优选:A comparison of streaming models and data augmentation methods for robust speech recog

    声明:平时看些文章做些笔记分享出来,文章中难免存在错误的地方,还望大家海涵.搜集一些资料,方便查阅学习:http://yqli.tech/page/speech.html.语音合成领域论文列表请访问h ...

  8. 论文阅读:AutoAugment: Learning Augmentation Strategies from Data

    文章目录 1.论文总述 2.MNIST 与 ImageNet 数据集上有效数据增强的不同 3.The key difference between our method and GAN 4.A sea ...

  9. Dataset之DA:数据增强(Data Augmentation)的简介、方法、案例应用之详细攻略

    Dataset之DA:数据增强(Data Augmentation)的简介.方法.案例应用之详细攻略 目录 DA的简介 DA的方法 DA的案例应用 DA的简介 数据集增强主要是为了减少网络的过拟合现象 ...

最新文章

  1. java怎么复制动态数组_Java 数组排序复制等操作(Java Arraycopy)
  2. 物体分割--Deep Watershed Transform for Instance Segmentation
  3. STL初探——构造和析构的基本工具: construct()和destroy()
  4. Libnids库-网络入侵检测的基础框架
  5. C语言嵌入式系统编程修炼之道——屏幕操作篇
  6. 如何让电脑成为看图说话的高手?计算机视觉顶会ICCV论文解读
  7. Android 自定义Application
  8. 13-容器的端口映射
  9. CentOS 安装OciLib 4.2.1 (Linux)
  10. python requests form data_Python requests模块 multipart/form-data类型文件上传
  11. PHP 缓存 内存,php - 一个大型数组变量的APC内存缓存(22MB)
  12. 剑指offer——面试题33:把数组排成最小数
  13. 常见概率分布图表总结
  14. JWT结合Springboot+shiro,session、token同时存在来应对不同的业务场景(物联网设备管理及开放api)...
  15. 使用AJAX时出现“Microsoft JScript 运行时错误: 'Sys' 未定义”提示的解决方法
  16. SQL Server 2016 SP1 标准版等同企业版?!
  17. NGFF(M.2) m.2中Bkey接口Mkey接口有什么不同
  18. 程序员内功:八大排序算法
  19. 中国牛奶市场竞争态势分析及未来发展前景预测报告2022-2028年版
  20. 水安ABC考试多选练习题库(6)

热门文章

  1. Quidem repellendus similique reiciendis quas.ExTable blond sorte bcepturi voluptatibus ipsa aliquid.
  2. 对话框AlertDialog的使用
  3. 利用Fierce2查询子域名
  4. Microsoft .NET Framework 3.5 sp1离线安装(DotNetFX35)
  5. Sublime Text 2搭建Go开发环境(Windows)
  6. java毕业设计迅腾游戏交流网站Mybatis+系统+数据库+调试部署
  7. View的这些基础知识你必须要知道,聪明人已经收藏了!
  8. Swift 头像上传(2)http://blog.csdn.net/wei_chong_chong/article/details/52611110
  9. 华为鸿蒙手机曝光,华为鸿蒙手机新特性曝光:充电期间系统将进行深度优化
  10. 《解忧杂货店》-东野圭吾