【CVPR 2021】基于Wasserstein Distance对比表示蒸馏方法:Wasserstein Contrastive Representation Distillation

  • 论文地址:
  • 主要问题:
  • 主要思路:
  • Wasserstein Distance:
    • 基本内容:
    • 定义:
  • 具体实现:
    • Global Contrastive Knowledge Transfer:
    • Local Contrastive Knowledge Transfer
    • Unifying Global and Local Knowledge Transfer:
  • 实验结果:
  • 关注我的公众号:
  • 联系作者:

论文地址:

https://arxiv.org/abs/2012.08674

主要问题:

目前大部分知识蒸馏(例如使用KL散度的知识蒸馏方法)可能无法在教师网络中捕获重要的结构性知识,并且往往缺乏特征泛化的能力,特别是在教师和学生被用来解决不同分类任务的情况下。

主要思路:

作者提出了一个基于Wasserstein Distance对比表示蒸馏方法,称之为Wasserstein Contrastive Representation Distillation。在该算法中,作者同时使用基础形式和对偶形式的Wasserstein Distance。

其中:

对偶形式用于度量global的知识迁移,即产生一个相反的学习目标,最大化教师和学生网络之间相互信息的下界(跟第一篇思路很类似,只不过第一篇里面距离仍用的是KL距离);

原始形式用于mini-batch内的局部对比知识迁移,有效地匹配了教师和学生网络之间的特征分布。最终结果能够达到state-of-the-art的效果。

Wasserstein Distance:

基本内容:

Wasserstein Distance是最近在基于对比的知识蒸馏方法中提出的一种距离度量,使用它的目标往往是将相似的样本移动得更近,同时将特征空间中不同的样本分开。

定义:

考虑两个概率分布:x1∼p1x_1\sim p_1x1​∼p1​和x2∼p2x_2\sim p_2x2​∼p2​,那么p1,p2p_1,p_2p1​,p2​的Wasserstein-1距离就可以写作:

W(p1,p2)=infπ∈∏(p1,p2)∫M×Mc(x1,x2)dπ(x1,x2)W(p_1,p_2)=inf_{\pi\in\prod(p_1,p_2)}\int_{M\times M}c(x_1,x_2)d\pi(x_1,x_2)W(p1​,p2​)=infπ∈∏(p1​,p2​)​∫M×M​c(x1​,x2​)dπ(x1​,x2​)

其中c(⋅)c(\cdot)c(⋅)是一个点对点的用来评估距离的损失函数,∏\prod∏是p1(x1)p_1(x_1)p1​(x1​)和p2(x2)p_2(x_2)p2​(x2​)所有可能的联合概率分布,MMM是x1,x2x_1,x_2x1​,x2​所在的特征空间,π(x1,x2)\pi(x_1,x_2)π(x1​,x2​)是满足∫Mπ(x1,x2)dx2=p1(x1)\int_{M}\pi(x_1,x_2)dx_2=p_1(x_1)∫M​π(x1​,x2​)dx2​=p1​(x1​)和∫Mπ(x1,x2)dx1=p2(x2)\int_{M}\pi(x_1,x_2)dx_1=p_2(x_2)∫M​π(x1​,x2​)dx1​=p2​(x2​)的联合概率分布

基于Kantorovich-Rubenstein二元性,WD可以写作对偶的形式:

W(p1,p2)=sup∣∣g∣∣L≤1Ex1∼p1[g(x1)]−Ex2∼p2[g(x2)]W(p_1,p_2)=sup_{||g||_L\leq1}\mathbb{E_{x_1\sim p_1}[g(x_1)]}-\mathbb{E_{x_2\sim p_2}[g(x_2)]}W(p1​,p2​)=sup∣∣g∣∣L​≤1​Ex1​∼p1​​[g(x1​)]−Ex2​∼p2​​[g(x2​)]

其中ggg是一个满足1-Lipschitz约束的函数(往往是个神经网络)

具体实现:

Global Contrastive Knowledge Transfer:

对于全局对比知识迁移,作者考虑在logits之前的层最大化两个特征表示hS,hTh^S,h^ThS,hT的相关信息(MIMIMI),即试图通过KL散度将联合分布p(hT,hS)p(h^T,h^S)p(hT,hS)与边缘分布µ(hT)µ(h^T)µ(hT)和ν(hS)ν(h^S)ν(hS)的乘积相匹配:

I(hS,hT)=KL(p(hS,hT)∣∣µ(hT)ν(hS))I(h^S,h^T)=KL(p(h^S,h^T)||µ(h^T)ν(h^S))I(hS,hT)=KL(p(hS,hT)∣∣µ(hT)ν(hS))

由于联合分布和边缘分布都是隐式的(即我们没法直接计算),因此我们可以用NCE(Noise Contrastive Estimation)的方法来近似估计MIMIMI。

具体地说,我们将来自联合分布的对表示为同余对(congruent pair),独立的边缘分布的乘积的对表示为不余对。换句话说就是同余对是指将相同的数据输入提供给教师和学生网络,而不同余对由不同的数据输入组成。

跟 Complementary Relation Contrastive Distillation 论文中的做法类似,作者也引入了带有隐变量 η\etaη 的分布 qqq:

q(hT,hS∣η=1)=p(hT,hS)q(h^T,h^{S}|\eta=1)=p(h^T,h^{S})q(hT,hS∣η=1)=p(hT,hS)

q(hT,hS∣η=0)=μ(hT)ν(hS)q(h^T,h^{S}|\eta=0)=\mu(h^T)\nu(h^{S})q(hT,hS∣η=0)=μ(hT)ν(hS)

这里我们假设1个相关关系对带有111个不相关关系对,那么q(η=1)=q(η=0)=1/2q(\eta=1)=q(\eta=0)=1/2q(η=1)=q(η=0)=1/2

基于Complementary Relation Contrastive Distillation我们同样可以推导出:

I(hT,hS)≥Eq(hT,hS∣η=1)log⁡q(η=1∣hT,hS)I(h^T,h^{S})\geq\mathbb{E}_{q(h^T,h^{S}|\eta=1)}\log q(\eta=1|h^T,h^{S})I(hT,hS)≥Eq(hT,hS∣η=1)​logq(η=1∣hT,hS)

同样使用一个函数ggg来评估一个关系对是否来自联合分布,并且可以通过NCE loss来学习:

LNCE=Eq(hT,hS∣η=1))log⁡g(hT,hS)+Eq(hT,hS∣η=0)log⁡[1−g(hT,hS)]\mathcal{L}_{NCE}=\mathbb{E}_{q(h^T,h^S|\eta=1))}\log g(h^T,h^S)+\mathbb{E}_{q(h^T,h^S|\eta=0)}\log [1-g(h^T,h^S)]LNCE​=Eq(hT,hS∣η=1))​logg(hT,hS)+Eq(hT,hS∣η=0)​log[1−g(hT,hS)]

并且LNCE\mathcal{L}_{NCE}LNCE​可同时优化函数ggg的参数和网络SSS的参数

不同于Complementary Relation Contrastive Distillation使用神经网络作为函数ggg,因为这样有两个缺陷:

  • g可能对输入中的小数值变化很敏感,从而产生比较差的性能,尤其是当网络架构或学生和教师网络的训练数据集不同的时候
  • g可能会出现模式坍塌的问题(参考Wasserstein GAN)

为了解决这个问题,作者使用了 spectral normalization。即对于一个任意矩阵AAA,它的spectral normalization定义为:

σ(A)=max∣∣β∣∣2≤1∣∣Aβ∣∣2\sigma(A)=max_{||\beta||_2\leq1}||A\beta||_2σ(A)=max∣∣β∣∣2​≤1​∣∣Aβ∣∣2​

它相当于A的最大奇异值

通过将此正则化器应用于g^\hat{g}g^​中每个层的权重,就可以满足1-Lipschitz约束了,因此最终将其损失函数改写为:

LGCKT=Eq(hT,hS∣η=1))log⁡g^(hT,hS)−MEg^(hT,hS∣η=0)log⁡g^(hT,hS)\mathcal{L}_{GCKT}=\mathbb{E}_{q(h^T,h^S|\eta=1))}\log \hat{g}(h^T,h^S)-M\mathbb{E}_{\hat{g}(h^T,h^S|\eta=0)}\log \hat{g}(h^T,h^S)LGCKT​=Eq(hT,hS∣η=1))​logg^​(hT,hS)−MEg^​(hT,hS∣η=0)​logg^​(hT,hS)

训练中采样方法则与Complementary Relation Contrastive Distillation类似

Local Contrastive Knowledge Transfer

对比学习也可以应用于一个mini-batch,以进一步提高性能

具体来说,在小批量处理中,在训练学生网络时,从教师网络中提取的特征{hiT}i=1n\{h^T_i\}^n_{i=1}{hiT​}i=1n​可以被视为一个固定的集合。理想情况下,分类信息被封装在特征空间中,因此来自学生网络中的每个元素{hjS}j=1n\{h^S_j\}^n_{j=1}{hjS​}j=1n​都应该能够在这一个固定的集合中找到nerighbor

而且我们可以推测,在特征空间临近的样本可能共享相同的类。因此,我们可以尝试鼓励模型将hjSh^S_jhjS​同时推送到几个邻居{hiT}i=1n\{h^T_i\}^n_{i=1}{hiT​}i=1n​,而不是仅仅从教师网络中的一个以更好地泛化模型性能

这可以用WD的原始形式有效地实现。即当只使用有限的训练样本时,原始形式可以被解释为将概率质量从µ(hT)µ(h^T)µ(hT)转移到ν(hS)ν(h^S)ν(hS)的一种简单有效的方法

具体解释为我当我们有µ(hT)=∑i=1nuiδhiTµ(h^T)=\sum^n_{i=1}u_i\delta_{h_i^T}µ(hT)=∑i=1n​ui​δhiT​​和ν(hS)=∑j=1nvjδhjSν(h^S)=\sum^n_{j=1}v_j\delta_{h_j^S}ν(hS)=∑j=1n​vj​δhjS​​,其中δx\delta_xδx​是以x为中心的狄拉克函数

那么我们可以把WD的一般形式写作:
W(µ,ν)=min⁡π∑i=1n∑j=1nπijc(hiT,hjS)=min⁡π〈π,C〉W(µ,ν)=\min_{\pi}\sum^n_{i=1} \sum^n_{j=1}\pi_{ij}c(h^T_i,h^S_j)=\min_{\pi}〈\pi,C〉W(µ,ν)=minπ​∑i=1n​∑j=1n​πij​c(hiT​,hjS​)=minπ​〈π,C〉

其中∑j=1nπij=ui\sum^n_{j=1}\pi_{ij}=u_i∑j=1n​πij​=ui​,∑i=1nπij=νj\sum^n_{i=1}\pi_{ij}=ν_j∑i=1n​πij​=νj​,πππ是hTh^ThT和hSh^ShS中的离散联合概率,CCC是由Cij=c(hiT,hjS)C_{ij}=c(h^T_i,h^S_j)Cij​=c(hiT​,hjS​)给出的损失矩阵,〈π,C〉=Tr〈πTC〉〈\pi,C〉=Tr〈\pi^TC〉〈π,C〉=Tr〈πTC〉表示Frobenius 点积,C(⋅)C(\cdot)C(⋅)
表示一个用来度量两个特征向量不相关性的损失函数(例如cosine距离)

理想情况下,可以使用线性规划获得上式的全局最优值。但是该方法是不可微的,使它与现有的深度学习框架不兼容。作为一种替代方案,作者应用了Sinkhorn算法,通过添加一个凸正则化项来求解上式,即:

LLCKT=min⁡π∑i,jπijc(hiT,hjS)+ϵH(π)\mathcal{L}_{LCKT}=\min_{\pi}\sum_{i,j} \pi_{ij}c(h^T_i,h^S_j)+\epsilon H(\pi)LLCKT​=minπ​∑i,j​πij​c(hiT​,hjS​)+ϵH(π)

其中H(π)=∑i,jπijlog⁡πijH(\pi)=\sum_{i,j}\pi_{ij}\log \pi_{ij}H(π)=∑i,j​πij​logπij​,ϵ\epsilonϵ是一个超参数

更具体的算法可以看论文给出的伪代码:

Unifying Global and Local Knowledge Transfer:

虽然GCKT和LCKT是为不同的目标而设计的,但它们是互补的。通过优化LCKT,我们的目标是最小化边缘分布之间的差异,这相当于减少两个特征空间之间的差异,以便LCKT可以为GCKT提供一个更受约束的特征空间。另一方面,通过优化GCKT,学习到的表示也可以形成一个更好的特征空间,这反过来又帮助LCKT匹配边缘分布。

因此最终的Loss就可以写作:

LWCoRD(θS,ϕ)=LCE(θS)−λ1LGCKT(θS,ϕ)+λ2LLCKT(θS)L_{WCoRD}(\theta_S,\phi)=L_{CE}(\theta_S)-\lambda_1L_{GCKT}(\theta_S,\phi)+\lambda_2L_{LCKT}(\theta_S)LWCoRD​(θS​,ϕ)=LCE​(θS​)−λ1​LGCKT​(θS​,ϕ)+λ2​LLCKT​(θS​)

实验结果:

关注我的公众号:

感兴趣的同学关注我的公众号——可达鸭的深度学习教程:

联系作者:

B站:https://space.bilibili.com/470550823

CSDN:https://blog.csdn.net/weixin_44936889

AI Studio:https://aistudio.baidu.com/aistudio/personalcenter/thirdview/67156

Github:https://github.com/Sharpiless

【CVPR 2021】基于Wasserstein Distance对比表示蒸馏方法:Wasserstein Contrastive Representation Distillation相关推荐

  1. 蚂蚁金服AAAI论文:基于长短期老师的样本蒸馏方法和自动车险定损系统的最新突破...

    来源 | 蚂蚁金服 出品 | AI科技大本营(ID:rgznai100) 一年一度在人工智能方向的顶级会议之一AAAI 2020于2月7日至12日在美国纽约举行,旨在汇集世界各地的人工智能理论和领域应 ...

  2. 【CVPR 2021】Knowledge Review:知识蒸馏新解法

    [CVPR 2021]Knowledge Review:知识蒸馏新解法 论文地址: 主要问题: 主要思路: 符号假设: 具体实现: 实验结果: 关注我的公众号: 联系作者: 论文地址: https:/ ...

  3. 【CVPR 2021】通用的实例级蒸馏:General Instance Distillation for Object Detection

    [CVPR 2021]通用的实例级蒸馏:General Instance Distillation for Object Detection 论文地址: 主要问题: 主要思路: 主要贡献: 具体实现: ...

  4. 【CVPR 2021】树状决策知识蒸馏:Tree-like Decision Distillation

    [CVPR 2021]树状决策知识蒸馏:Tree-like Decision Distillation 论文地址: 主要问题: 主要思路: 具体实现: 基本符号: Tree-like Decision ...

  5. CVPR 2021 | 基于随机标签的神经架构搜索

    本文转自旷视研究院. 今日分享一篇来自旷视被收录为 CVPR2021 的论文『Neural Architecture Search with Random Labels』.详情如下: 论文名称:Neu ...

  6. CVPR 2021 | 基于语义聚合与自适应2D-1D配准的手部三维重建(快手)

    来源丨arXiv每日学术速递 今天,我们介绍的是快手Y-tech入选CVPR 2021的工作之一,Camera-Space Hand Mesh Recovery via Semantic Aggreg ...

  7. CVPR 2021 | 基于跨任务场景结构知识迁移的单张深度图像超分辨率方法

    ©PaperWeekly 原创 · 作者|孙宝利 学校|大连理工大学硕士 研究方向|计算机视觉 项目主页: http://faculty.dlut.edu.cn/yexinchen/zh_CN/zdy ...

  8. CVPR 2021 | 基于Transformer的端到端视频实例分割方法

    实例分割是计算机视觉中的基础问题之一.目前,静态图像中的实例分割业界已经进行了很多的研究,但是对视频的实例分割(Video Instance Segmentation,简称VIS)的研究却相对较少.而 ...

  9. CVPR 2021 | 基于帧场学习的多边形建筑提取

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 参考论文: Polygonal Building Extraction by Frame Field ...

最新文章

  1. css3宽度变大动画_不会仪表?太尴尬了。14种动画让你轻松掌握各种流量计工作原理...
  2. Spark 安装配置简单测试
  3. Boost.MultiIndex 使用随机访问索引的示例
  4. P1083 借教室(差分+二分)
  5. Impala-shell 启动异常 - Python版本为3.x 启动脚本为2.x
  6. vmware虚拟机安装win7_图文分享虚拟机怎么安装win7系统
  7. 让8只数码管初始显示零,每隔大约1s加一显示,到数码管显示9后,再从一开始显示
  8. Flinksql读取Kafka写入Iceberg 实践亲测
  9. MySQL8单表记录多少_mysql学习笔记之8(单表数据记录查询)_mysql
  10. 一步一步写算法(之排序二叉树插入)
  11. 独家 | 微软披露拓扑量子计算机计划!
  12. C#高仿腾讯QQ截图程序
  13. Centos磁盘管理和文件系统管理
  14. PB与各种数据库连接
  15. axure如何实现跳转_Axure 9 教程:如何做跑马灯广告、弹幕
  16. 微软图表控件MsChart使用初探
  17. 如何下载太原市卫星地图高清版大图
  18. 触发器的三种触发方式:电平触发、边沿触发、脉冲触发区别
  19. 不要因为错爱而寂寞一生
  20. 人工智能真正值得担心的是缺德,而不是聪明

热门文章

  1. 练习之彩票三 添加号码相关代码
  2. 免费给你的QQ个人信息面板加上彩色背景(转)
  3. 计算机一会儿黑屏,电脑一会黑屏一会亮怎么处理?
  4. pdf编辑软件哪个好 如何在pdf上修改
  5. 小伙开私人影院,裁掉员工玩套路,你见过哪个老板敢这么玩?
  6. 2007年web开发技术预言
  7. PS怎么抠出圆形图(可调整边缘)
  8. 翔谈设计模式——观察者模式
  9. Mac通过aapt获取apk文件的基本信息
  10. FPGA和CPLD芯片选型介绍(一)