1. motivation

目前的方法在源域和目标域存在较大域间偏差时实用性较差。本文认为:

1) 无监督学习可以缓解监督崩溃问题,并且训练得到的模型可以更好地推广到目标域中。

2) 因为源数据集和目标数据集之间存在很大差异,因此对源任务有用的特征可能对目标任务没有帮助,甚至有害。所以本文期望在小样本的情况下,通过提取更少的特征来提升泛化性能。

2. contribution

本文提出了一个“对比学习和特征选择系统”(Comparative learning and Feature Selection System)的小样本学习框架Confess,解决了基类和新类之间存在较大域偏移的问题。包含三个部分:

1) 在源域上基于对比损失无监督训练backbone;

2) 引入了mask module在目标域上训练来选择更适合目标域分类的相关特征;

3) 在目标域上微调分类器和backbone 。

实验部分在ECCV2020 challenge benchmark 上取得了很好的效果。

3. 核心内容

3.1 overall framework

3.2 无监督训练backbone

在预训练阶段,使用各种变换从训练批中的现有样本中扩充样本,并使用这些扩充样本和原始样本来计算对比损失。

具体来说,在每个批次中有 个样本,表示为 。对于每个样本 ,都用 个变化得到对应的扩充样本 。让扩充样本 的特征接近原样本 并远离其他样本 ,使用如下交叉熵损失:

具体的变化方式:color distortion (A Simple Framework for Contrastive Learning of Visual Representations, ICML2020)。

3.3 训练mask generator

从源域上训练得到的特征提取器为 。给定目标域上的样本,可以得到每个样本的特征为:

将该特征输入到mask生成模块M中得到对应的mask:

根据得到的mask,可以为每个特征得到对应的positive和negative的特征:

这里希望确保positive feature是有类别区分性的,而negative feature是没有类别区分性的:

其中, 是交叉熵损失, 是两个线性分类器。。

除此之外,最大化正特征集合 和负特征 集合之间的统计距离:

上述损失被联合起来用于训练mask generator:

3.4 微调过程

定义在目标域上的特征提取器为 ,其初始化为 ,对于每个目标域上的样本可以得到其特征为 。将其输入到线性分类器C中,计算交叉熵损失:

除此之外,Positive feature为:

提取到的特征应该和positive feature更接近:

那么微调损失就包含两个部分,分类损失和回归损失:

4. 实验部分

在跨域上得到的效果非常好:

5. 总结

存在一些有疑问的地方,mask generator是基于 在目标域上训练获得的,他是否适用于 提取得到的特征,这里的训练逻辑上有点绕。在迁移学习中,会在源域上预训练得到一个模型,然后再目标域上训练时固定大部分参数,微调剩下的部分参数,这里的mask起到的作用似乎也是类似的,就是把模型迁移到目标域上,但是又在目标域上继续微调了特征提取器。

小样本学习论文阅读 | Confess: A framework for single source cross-domain few-shot learning, ICLR 2022 poster相关推荐

  1. 深度学习论文阅读图像分类篇(五):ResNet《Deep Residual Learning for Image Recognition》

    深度学习论文阅读图像分类篇(五):ResNet<Deep Residual Learning for Image Recognition> Abstract 摘要 1. Introduct ...

  2. 论文阅读:Piggyback: Adapting a Single Network to Multiple Tasks by Learning to Mask Weights

    ECCV2018 , 在网络上训练一个mask,以适应新任务. 1.Introduction Packnet通过迭代地剪枝再训练扩展网络学习新任务,然而真的有必要调整网络的全部参数吗? 基于这个ide ...

  3. 论文中文翻译——Automated Vulnerability Detection in Source Code Using Deep Representation Learning

    本论文相关内容 论文下载地址--Web Of Science 论文中文翻译--Automated Vulnerability Detection in Source Code Using Deep R ...

  4. 深度学习论文阅读目标检测篇(四)中英文对照版:YOLOv1《 You Only Look Once: Unified, Real-Time Object Detection》

    深度学习论文阅读目标检测篇(四)中英文对照版:YOLOv1< You Only Look Once: Unified, Real-Time Object Detection> Abstra ...

  5. 深度学习论文阅读目标检测篇(一):R-CNN《Rich feature hierarchies for accurate object detection and semantic...》

    深度学习论文阅读目标检测篇(一):R-CNN<Rich feature hierarchies for accurate object detection and semantic segmen ...

  6. 深度学习论文阅读图像分类篇(三):VGGNet《Very Deep Convolutional Networks for Large-Scale Image Recognition》

    深度学习论文阅读图像分类篇(三):VGGNet<Very Deep Convolutional Networks for Large-Scale Image Recognition> Ab ...

  7. 深度学习论文阅读目标检测篇(三):Faster R-CNN《 Towards Real-Time Object Detection with Region Proposal Networks》

    深度学习论文阅读目标检测篇(三):Faster R-CNN< Towards Real-Time Object Detection with Region Proposal Networks&g ...

  8. 并行多任务学习论文阅读(二)同步和异步优化算法

    1.并行与分布式多任务学习(Multi-task Learning, MTL)简介 我们在上一篇文章<并行多任务学习论文阅读(一)多任务学习速览>(链接:https://www.cnblo ...

  9. 深度学习论文阅读目标检测篇(五)中英对照版:YOLOv2《 YOLO9000: Better, Faster, Stronger》

    深度学习论文阅读目标检测篇(五)中文版:YOLOv2< YOLO9000: Better, Faster, Stronger> Abstract 摘要 1. Introduction 1. ...

最新文章

  1. 全面理解java内存模型_深入理解Java内存模型(八)——总结
  2. 《人月神话》阅读笔记2
  3. 自定义webpart显示Lync状态球
  4. Matlab回显语句
  5. Spark SQL 之SparkSession
  6. 深入理解javascript原型和闭包(1)——一切都是对象
  7. 从零开始——电子商务平台01
  8. Actuator提供的endpoint
  9. 猿题库 iOS 客户端架构设计-唐巧
  10. Tiny模板语言(VelocityPlus)初步入门
  11. Java 调用 Impala - JDBC 调用Impala
  12. oracle tb级别数据量,备份TB级别Oracle数据库的一些技巧
  13. 计算火车运行时间(pta)
  14. java ==陷阱_Java小陷阱
  15. Linux基础之vim文本编辑器
  16. Django项目 BBS论坛
  17. 缓存应用(一)Ehcache使用介绍
  18. QT所有版本和VS插件下载
  19. 【ASP.NET】ASP.NET入门
  20. iText - OCR 截图识字 - 新版小幅更新

热门文章

  1. el-select使用filterable右侧箭头消失
  2. (高通平台)pdaf log打印不出来的检查步骤
  3. a 标签的 href 属性的获取与拼接
  4. 并联型APF/有源电力滤波器/Matlab/Simulink仿真 *dq/FBD谐波/无功检测
  5. Revit2017 外部工具添加
  6. 怎么删除PDF文件不要的页面?
  7. 文字打怪小游戏 源码(c++)
  8. Project软件安装包及教程|横道图|进度计划|项目管理
  9. 搜集素材“搜”出的产品设计灵感
  10. 应收账款与存货-坏账确认核算