【GiantPandaCV引言】 知识回顾(KR)发现学生网络深层可以通过利用教师网络浅层特征进行学习,基于此提出了回顾机制,包括ABF和HCL两个模块,可以在很多分类任务上得到一致性的提升。

摘要

知识蒸馏通过将知识从教师网络传递到学生网络,但是之前的方法主要关注提出特征变换和实施相同层的特征。

知识回顾Knowledge Review选择研究教师与学生网络之间不同层之间的路径链接。

简单来说就是研究教师网络向学生网络传递知识的链接方式。

代码在:https://github.com/Jia-Research-Lab/ReviewKD

KD简单回顾

KD最初的蒸馏对象是logits层,也即最经典的Hinton的那篇Knowledge Distillation,让学生网络和教师网络的logits KL散度尽可能小。

随后FitNets出现开始蒸馏中间层,一般通过使用MSE Loss让学生网络和教师网络特征图尽可能接近。

Attention Transfer进一步发展了FitNets,提出使用注意力图来作为引导知识的传递。

PKT(Probabilistic knowledge transfer for deep representation learning)将知识作为概率分布进行建模。

Contrastive representation Distillation(CRD)引入对比学习来进行知识迁移。

以上方法主要关注于知识迁移的形式以及选择不同的loss function,但KR关注于如何选择教师网络和学生网络的链接,一下图为例:

(a-c)都是传统的知识蒸馏方法,通常都是相同层的信息进行引导,(d)代表KR的蒸馏方式,可以使用教师网络浅层特征来作为学生网络深层特征的监督,并发现学生网络深层特征可以从教师网络的浅层学习到知识。

教师网络浅层到深层分别对应的知识抽象程度不断提高,学习难度也进行了提升,所以学生网络如果能在初期学习到教师网络浅层的知识会对整体有帮助。

KR认为浅层的知识可以作为旧知识,并进行不断回顾,温故知新。如何从教师网络中提取多尺度信息是本文待解决的关键:

  • 提出了Attention based fusion(ABF) 进行特征fusion

  • 提出了Hierarchical context loss(HCL) 增强模型的学习能力。

Knowledge Review

形式化描述

X是输入图像,S代表学生网络,其中(S1,S2,⋯,Sn,Sc)\left(\mathcal{S}_{1}, \mathcal{S}_{2}, \cdots, \mathcal{S}_{n}, \mathcal{S}_{c}\right)(S1​,S2​,⋯,Sn​,Sc​)代表学生网络各个层的组成。

Ys=Sc∘Sn∘⋯∘S1(X)\mathbf{Y}_{s}=\mathcal{S}_{c} \circ \mathcal{S}_{n} \circ \cdots \circ \mathcal{S}_{1}(\mathbf{X}) Ys​=Sc​∘Sn​∘⋯∘S1​(X)

Ys代表X经过整个网络以后的输出。(Fs1,⋯,Fsn)\left(\mathbf{F}_{s}^{1}, \cdots, \mathbf{F}_{s}^{n}\right)(Fs1​,⋯,Fsn​)代表各个层中间层输出。

那么单层知识蒸馏可以表示为:

LSKD=D(Msi(Fsi),Mti(Fti))\mathcal{L}_{S K D}=\mathcal{D}\left(\mathcal{M}_{s}^{i}\left(\mathbf{F}_{s}^{i}\right), \mathcal{M}_{t}^{i}\left(\mathbf{F}_{t}^{i}\right)\right) LSKD​=D(Msi​(Fsi​),Mti​(Fti​))

M代表一个转换,从而让Fs和Ft的特征图相匹配。D代表衡量两者分布的距离函数。

同理多层知识蒸馏表示为:

LMKD=∑i∈ID(Msi(Fsi),Mti(Fti))\mathcal{L}_{M K D}=\sum_{i \in \mathbf{I}} \mathcal{D}\left(\mathcal{M}_{s}^{i}\left(\mathbf{F}_{s}^{i}\right), \mathcal{M}_{t}^{i}\left(\mathbf{F}_{t}^{i}\right)\right) LMKD​=i∈I∑​D(Msi​(Fsi​),Mti​(Fti​))

以上公式是学生和教师网络层层对应,那么单层KR表示方式为:

具体具体 具体

与之前不同的是,这里计算的是从j=1 to i 代表第i层学生网络的学习需要用到从第1到i层所有知识。

同理,多层的KR表示为:

LMKD−R=∑i∈I(∑j=1iD(Msi,j(Fsi),Mtj,i(Ftj)))\mathcal{L}_{M K D_{-} R}=\sum_{i \in \mathbf{I}}\left(\sum_{j=1}^{i} \mathcal{D}\left(\mathcal{M}_{s}^{i, j}\left(\mathbf{F}_{s}^{i}\right), \mathcal{M}_{t}^{j, i}\left(\mathbf{F}_{t}^{j}\right)\right)\right) LMKD−​R​=i∈I∑​(j=1∑i​D(Msi,j​(Fsi​),Mtj,i​(Ftj​)))

Fusion方式设计

已经确定了KR的形式,即学生每一层回顾教师网络的所有靠前的层,那么最简单的方法是:

直接缩放学生网络最后一层feature,让其形状和教师网络进行匹配,这样Msi,j\mathcal{M}_s^{i,j}Msi,j​可以简单使用一个卷积层配合插值层完成形状的匹配过程。这种方式是让学生网络更接近教师网络。

这张图表示扩展了学生网络所有层对应的处理方式,也即按照第一张图的处理方式进行形状匹配。

这种处理方式可能并不是最优的,因为会导致stage之间出现巨大的差异性,同时处理过程也非常复杂,带来了额外的计算代价。

为了让整个过程更加可行,提出了Attention based fusion $\mathcal{U}
$, 这样整体蒸馏变为:

∑i=jnD(Fsi,Ftj)≈D(U(Fsj,⋯,Fsn),Ftj)\sum_{i=j}^{n} \mathcal{D}\left(\mathbf{F}_{s}^{i}, \mathbf{F}_{t}^{j}\right) \approx \mathcal{D}\left(\mathcal{U}\left(\mathbf{F}_{s}^{j}, \cdots, \mathbf{F}_{s}^{n}\right), \mathbf{F}_{t}^{j}\right) i=j∑n​D(Fsi​,Ftj​)≈D(U(Fsj​,⋯,Fsn​),Ftj​)

如果引入了fusion的模块,那整体流程就变为下图所示:

但是为了更高的效率,再对其进行改进:

可以发现,这个过程将fusion的中间结果进行了利用,即Fsjand U(Fsj+1,⋯,Fsn)\mathbf{F}_{s}^{j} \text { and } \mathcal{U}\left(\mathbf{F}_{s}^{j+1}, \cdots, \mathbf{F}_{s}^{n}\right)Fsj​ and U(Fsj+1​,⋯,Fsn​), 这样循环从后往前进行迭代,就可以得到最终的loss。

具体来说,ABF的设计如下(a)所示,采用了注意力机制融合特征,具体来说中间的1x1 conv对两个level的feature提取综合空间注意力特征图,然后再进行特征重标定,可以看做SKNet的空间注意力版本。

而HCL Hierarchical context loss 这里对分别来自于学生网络和教师网络的特征进行了空间池化金字塔的处理,L2 距离用于衡量两者之间的距离。

KR认为这种方式可以捕获不同level的语义信息,可以在不同的抽象等级提取信息。

实验

实验部分主要关注消融实验:

第一个是使用不同stage的结果:

蓝色的值代表比baseline 69.1更好,红色代表要比baseline更差。通过上述结果可以发现使用教师网络浅层知识来监督学生网络深层知识是有效的。

第二个是各个模块的作用:

源码

主要关注ABF, HCL的实现:

ABF实现:

class ABF(nn.Module):def __init__(self, in_channel, mid_channel, out_channel, fuse):super(ABF, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channel, mid_channel, kernel_size=1, bias=False),nn.BatchNorm2d(mid_channel),)self.conv2 = nn.Sequential(nn.Conv2d(mid_channel, out_channel,kernel_size=3,stride=1,padding=1,bias=False),nn.BatchNorm2d(out_channel),)if fuse:self.att_conv = nn.Sequential(nn.Conv2d(mid_channel*2, 2, kernel_size=1),nn.Sigmoid(),)else:self.att_conv = Nonenn.init.kaiming_uniform_(self.conv1[0].weight, a=1)  # pyre-ignorenn.init.kaiming_uniform_(self.conv2[0].weight, a=1)  # pyre-ignoredef forward(self, x, y=None, shape=None, out_shape=None):n,_,h,w = x.shape# transform student featuresx = self.conv1(x)if self.att_conv is not None:# upsample residual featuresy = F.interpolate(y, (shape,shape), mode="nearest")# fusionz = torch.cat([x, y], dim=1)z = self.att_conv(z)x = (x * z[:,0].view(n,1,h,w) + y * z[:,1].view(n,1,h,w))# output if x.shape[-1] != out_shape:x = F.interpolate(x, (out_shape, out_shape), mode="nearest")y = self.conv2(x)return y, x

HCL实现:

def hcl(fstudent, fteacher):
# 两个都是list,存各个stage对象loss_all = 0.0for fs, ft in zip(fstudent, fteacher):n,c,h,w = fs.shapeloss = F.mse_loss(fs, ft, reduction='mean')cnt = 1.0tot = 1.0for l in [4,2,1]:if l >=h:continuetmpfs = F.adaptive_avg_pool2d(fs, (l,l))tmpft = F.adaptive_avg_pool2d(ft, (l,l))cnt /= 2.0loss += F.mse_loss(tmpfs, tmpft, reduction='mean') * cnttot += cntloss = loss / totloss_all = loss_all + lossreturn loss_all

ReviewKD实现:

class ReviewKD(nn.Module):def __init__(self, student, in_channels, out_channels, shapes, out_shapes,):  super(ReviewKD, self).__init__()self.student = studentself.shapes = shapesself.out_shapes = shapes if out_shapes is None else out_shapesabfs = nn.ModuleList()mid_channel = min(512, in_channels[-1])for idx, in_channel in enumerate(in_channels):abfs.append(ABF(in_channel, mid_channel, out_channels[idx], idx < len(in_channels)-1))self.abfs = abfs[::-1]self.to('cuda')def forward(self, x):student_features = self.student(x,is_feat=True)logit = student_features[1]x = student_features[0][::-1]results = []out_features, res_features = self.abfs[0](x[0], out_shape=self.out_shapes[0])results.append(out_features)for features, abf, shape, out_shape in zip(x[1:], self.abfs[1:], self.shapes[1:], self.out_shapes[1:]):out_features, res_features = abf(features, res_features, shape, out_shape)results.insert(0, out_features)return results, logit

参考

https://zhuanlan.zhihu.com/p/363994781

https://arxiv.org/pdf/2104.09044.pdf

https://github.com/dvlab-research/ReviewKD

【知识蒸馏】Knowledge Review相关推荐

  1. 知识蒸馏(Knowledge Distillation)详细深入透彻理解重点

    知识蒸馏是一种模型压缩方法,是一种基于"教师-学生网络思想"的训练方法,由于其简单,有效,在工业界被广泛应用.这一技术的理论来自于2015年Hinton发表的一篇神作: 论文链接 ...

  2. 知识蒸馏 knowledge distill 相关论文理解

    Knowledge Distil 相关文章 1.FitNets : Hints For Thin Deep Nets (ICLR2015) 2.A Gift from Knowledge Distil ...

  3. Knowledge Distillation | 知识蒸馏经典解读

    作者 | 小小 整理 | NewBeeNLP 写在前面 知识蒸馏是一种模型压缩方法,是一种基于"教师-学生网络思想"的训练方法,由于其简单,有效,在工业界被广泛应用.这一技术的理论 ...

  4. 【深度学习】深度学习中的知识蒸馏技术(上)简介

    本文概览: 1. 知识蒸馏介绍 1.1 什么是知识蒸馏? 在化学中,蒸馏是一种有效的分离不同沸点组分的方法,大致步骤是先升温使低沸点的组分汽化,然后降温冷凝,达到分离出目标物质的目的.化学蒸馏条件:( ...

  5. [深度学习]知识蒸馏技术

    一 知识蒸馏(Knowledge Distillation)介绍 名词解释 teacher - 原始模型或模型ensemble student - 新模型 transfer set - 用来迁移tea ...

  6. 给Bert加速吧!NLP中的知识蒸馏论文 Distilled BiLSTM解读

    论文题目:Distilling Task-Specific Knowledge from BERT into Simple Neural Networks 论文链接:https://arxiv.org ...

  7. 目标检测中的知识蒸馏方法

    目标检测中的知识蒸馏方法 知识蒸馏 (Knowledge Distillation KD) 是模型压缩(轻量化)的一种有效的解决方案,这种方法可以使轻量级的学生模型获得繁琐的教师模型中的知识.知识蒸馏 ...

  8. 知识蒸馏是什么?一份入门随笔

    点击上方,选择星标或置顶,每天给你送干货! 作者丨LinT@知乎 来源丨https://zhuanlan.zhihu.com/p/90049906 编辑丨极市平台 知识蒸馏的核心思想是通过迁移知识,从 ...

  9. 深度学习中的知识蒸馏技术(上)

    本文概览: 1. 知识蒸馏介绍 1.1 什么是知识蒸馏? 在化学中,蒸馏是一种有效的分离不同沸点组分的方法,大致步骤是先升温使低沸点的组分汽化,然后降温冷凝,达到分离出目标物质的目的.化学蒸馏条件:( ...

最新文章

  1. 达观杯_构建模型(四)贝叶斯
  2. SQL2000 好书 《SQL Server 2000数据库管理与开发技术大全》----求是科技 人民邮电出版社
  3. 有趣的0-1背包问题:分割等和子集
  4. [BX] 和 loop指令
  5. wxWidgets:编写非英语应用程序
  6. 『ACM-算法-二分法』信息竞赛进阶指南--二分法
  7. centos 7 /etc/rc.local 开机不执行的问题
  8. 基于评论文本的深度推荐系统总结
  9. python三本经典书籍-Python入门经典书籍有哪些?有这三本就够了
  10. as没有add as library选项
  11. oracle的档案软件,思源档案管理系统(WEB版)
  12. narwal机器人_Narwal云鲸智能扫拖机器人,会自己洗拖布
  13. excel批量生成批处理语句另存为.bat文件批量改名
  14. 生产环境安装、配置、管理PostgreSQL14.5数据库集群。pgpool 4.3.3参数中文说明
  15. NTLite 1 2 0 4453授权注册版
  16. 公众号榜单 | 2020·5月公众号地区排行榜重磅发布
  17. Git仓库完整迁移 含历史记录
  18. CPU 使用率低高负载的原因,看看这篇!
  19. 迈拓网络硬盘软件全攻略(5)mldonkey
  20. Ipopt-3.12.7在ubuntu18.04安装

热门文章

  1. GE IC系列PLC IC687RCM711 GE IC687BEM744 IC697CFR782 IC697CFR28 IC687BEM713IC687BEM731 IC687BEM744-EB
  2. Kotlin学习——了解Kotlin
  3. python 无限循环小数_有关无限循环小数的一处漏洞
  4. 78系列79系列稳压芯片
  5. ECHARTS 水球图
  6. ta点读笔客户端_PIYO PEN点读笔=早教机+故事机+智能音箱+伴眠神器
  7. 零跑汽车迎难而上,坚持全域自研战略指引
  8. vscode 中 react 代码保存后,代码格式乱了怎么办?react代码点击保存格式化错误代码错乱处理
  9. 原来,PPTV是在筹划国内上市呐!
  10. 一篇文章让你了解大数据