知识蒸馏源自Hinton et al.于2014年发表在NIPS的一篇文章:Distilling the Knowledge in a Neural Network。

1. 背景

一般情况下,我们在训练模型的时候使用了大量训练数据和计算资源来提取知识,但这不方便在工业中部署,原因有二:
(1)大模型推理速度慢
(2)对设备的资源要求高(大内存)
因此我们希望对训练好的模型进行压缩,在保证推理效果的前提下减小模型的体量,知识蒸馏(Knownledge Distillation)属于模型压缩的一种方法 [1]。

2. 知识蒸馏

名词解释:
cumbersome model:原始模型或者说大模型,但在后续的论文中一般称它为teacher model;
distilled model:蒸馏后的小模型,在后续的论文中一般称它为stududent model;
hard targets:像[1, 0, 0]这样的标签,也叫做ground-truth label;
soft targets:像[0.7, 0.2, 0.1]这样的标签;
transfer set:训练student model的数据

好模型的目标不是拟合训练数据,而是学习如何泛化到新的数据。所以蒸馏的目标是让student学习到teacher的泛化能力,理论上得到的结果会比单纯拟合训练数据的student要好 [3]。显然,soft target可以提供更大的信息熵,所以studetn model可以学习到更多的信息。

通俗的来讲,粗暴的使用one-hot编码把原本有帮助的类内variance和类间distance都忽略了,比如猫和狗的相似性要比猫与摩托车的相似性要多,狗的某些特征可能对识别猫也会有帮助(比如毛发),因此使用soft target可以恢复被one-hot编码丢弃的信息 [2]。

在Hinton et al. 发表的这篇论文中,作者提出了"softmax temperature"的概念,其公式为:
qi=exp⁡(zi/T)∑jexp⁡(zj/T)q_{i}=\frac{\exp (z_{i}/T)}{\sum_{j}^{}\exp (z_{j}/T)} qi​=∑j​exp(zj​/T)exp(zi​/T)​
Python代码:

import numpy as np
def softmax_t(x,t):x_exp = np.exp(x / t)return x_exp / np.sum(x_exp)

qiq_{i}qi​代表第iii类的输出概率,ziz_{i}zi​和zjz_{j}zj​为softmax的输入,即上一层神经元的输出(logits),T表示temperature参数。通常情况下,我们使用的softmax函数T为1,但TTT可以控制输出soft的程度。比如对于z=[0.3,0.5,0.8,0.1,0.2]z=[0.3, 0.5, 0.8, 0.1, 0.2]z=[0.3,0.5,0.8,0.1,0.2],我们分别取T=[0.5,1,5,20]T=[0.5, 1, 5, 20]T=[0.5,1,5,20],然后画出softmax函数的输出可以看到,TTT越小,输出的预测结果越“硬”(曲线更加曲折),T越大输出的结果越“软”(曲线更加平和)。

插一句题外话,为什么这里的参数是叫温度(temperature)呢?这和蒸馏(distillation)这一热力学工艺有关。在蒸馏工艺中,温度越高提取到的物质越纯越浓缩。而在知识蒸馏中,参数T越大(温度越高),teacher model产生的label越"soft",信息熵就越高,提炼的知识更具有一般性(generalization)。所以说作者将这一参数取名temperature十分有趣。

知识蒸馏的实现过程可以概括为:

  1. 训练teacher model;
  2. 使用高温T将teacher model中的知识蒸馏到student model(在测试时温度T设为1)。

student modeld的目标函数由一下两项的加权平均组成:

  1. distillation loss:soft targets(由teacher model产生) 和student model的soft predictions的交叉熵,这里的T使用的是和训练teacher model相同的值。(保证student model和teacher model的结果尽可能一致)
  2. student loss:hard targets 和student model的输出数据的交叉熵,但T设置为1。(保证student model的结果和实际类别标签尽可能一致)

总体的损失函数可以写作:
L(x,W)=α∗CE(y,σ(zs;T=1))+β∗CE(σ(zt;T=τ),σ(zs,T=τ))\mathcal{L}(x,W)=\alpha \ast \text{CE}(y,\sigma(z_{s};T=1))+\beta \ast \text{CE}(\sigma (z_{t};T=\tau ),\sigma(z_{s},T=\tau)) L(x,W)=α∗CE(y,σ(zs​;T=1))+β∗CE(σ(zt​;T=τ),σ(zs​,T=τ))
其中,xxx表示输入,WWW表示student model的参数,yyy是ground-truth label,CE\text{CE}CE是交叉熵损失函数,σ\sigmaσ是刚刚提到的softmax temperature激活函数,zsz_{s}zs​和ztz_{t}zt​分别表示student和teacher model神经元的输出(logits), α\alphaα和β\betaβ表示两个权重参数 [4].

原论文指出,α\alphaα要比β\betaβ相对小一些可以取得更好的结果,因为在求梯度时soft targets被缩放了1/T21/T^{2}1/T2,所以第2项要乘以一个更小的权值来平衡二者在优化时的比重 [1].

换一个角度来想,这里的知识蒸馏其实是相对于对于原始交叉熵添加了一个正则项:
L(x,W)=CE(y,y^)+λsoft_loss(y′,y^)\mathcal {L}(x,W)=\text{CE}(y,\hat{y})+\lambda \text{soft\_loss}(y', \hat{y}) L(x,W)=CE(y,y^​)+λsoft_loss(y′,y^​)
利用teacher model的先验知识对student model进行正则化 [5]。

本文原载于简书,未经授权,不得转载。


References:

[1] Distilling the Knowledge in a Neural Network.
[2] # Distilling the Knowledge in a Neural Network 论文笔记
[3] 深度神经网络模型蒸馏Distillation
[4] Knowledge Distillation
[5] 神经网络知识蒸馏 Knowledge Distillation

知识蒸馏Knownledge Distillation相关推荐

  1. 知识蒸馏(Knowledge Distillation)详细深入透彻理解重点

    知识蒸馏是一种模型压缩方法,是一种基于"教师-学生网络思想"的训练方法,由于其简单,有效,在工业界被广泛应用.这一技术的理论来自于2015年Hinton发表的一篇神作: 论文链接 ...

  2. 知识蒸馏(Distillation)

    Hinton的文章<Distilling the Knowledge in a Neural Network>首次提出了知识蒸馏的概念,通过引入教师网络用以诱导学生网络的训练,实现知识迁移 ...

  3. Knowledge Distillation | 知识蒸馏经典解读

    作者 | 小小 整理 | NewBeeNLP 写在前面 知识蒸馏是一种模型压缩方法,是一种基于"教师-学生网络思想"的训练方法,由于其简单,有效,在工业界被广泛应用.这一技术的理论 ...

  4. 【深度学习】深度学习中的知识蒸馏技术(上)简介

    本文概览: 1. 知识蒸馏介绍 1.1 什么是知识蒸馏? 在化学中,蒸馏是一种有效的分离不同沸点组分的方法,大致步骤是先升温使低沸点的组分汽化,然后降温冷凝,达到分离出目标物质的目的.化学蒸馏条件:( ...

  5. [深度学习]知识蒸馏技术

    一 知识蒸馏(Knowledge Distillation)介绍 名词解释 teacher - 原始模型或模型ensemble student - 新模型 transfer set - 用来迁移tea ...

  6. 给Bert加速吧!NLP中的知识蒸馏论文 Distilled BiLSTM解读

    论文题目:Distilling Task-Specific Knowledge from BERT into Simple Neural Networks 论文链接:https://arxiv.org ...

  7. 目标检测中的知识蒸馏方法

    目标检测中的知识蒸馏方法 知识蒸馏 (Knowledge Distillation KD) 是模型压缩(轻量化)的一种有效的解决方案,这种方法可以使轻量级的学生模型获得繁琐的教师模型中的知识.知识蒸馏 ...

  8. 知识蒸馏是什么?一份入门随笔

    点击上方,选择星标或置顶,每天给你送干货! 作者丨LinT@知乎 来源丨https://zhuanlan.zhihu.com/p/90049906 编辑丨极市平台 知识蒸馏的核心思想是通过迁移知识,从 ...

  9. 深度学习中的知识蒸馏技术(上)

    本文概览: 1. 知识蒸馏介绍 1.1 什么是知识蒸馏? 在化学中,蒸馏是一种有效的分离不同沸点组分的方法,大致步骤是先升温使低沸点的组分汽化,然后降温冷凝,达到分离出目标物质的目的.化学蒸馏条件:( ...

最新文章

  1. 【c语言】输入输出格式练习
  2. linux序列比对程序,序列比对软件简单使用教程
  3. mysql备份与还原-mysqldump备份、mysql与source还原
  4. Linux读写执行(RWX)权限
  5. JavaFX官方教程(六)之带有JavaFX CSS的花式表单
  6. UVALive 6511 Term Project
  7. 网络篇:朋友面试之TCP/IP,回去等通知吧
  8. L1-041 寻找250 (10 分)—团体程序设计天梯赛
  9. h5 打包后效果失效
  10. community_louvain社群划分方法
  11. python查火车票_Python查询火车票(三)
  12. java调用opencc进行中文简体繁体转换
  13. 一个小透明作者到出版书籍,我的心路历程分享给各位作者
  14. 思维导图(自我介绍)
  15. datastage dsjob命令
  16. 四足机器人技术及进展
  17. vscode 学习(四)如何设置右键使用vscode打开
  18. MSP430+LCD1602显示实验
  19. 编译安装 Python
  20. 数字化时代-11:从马斯洛需求层次看未来选择做什么样的产品

热门文章

  1. 程序员到底有多可爱?笑死我了!
  2. linux下初始化磁盘
  3. 程序人生 - 996(一)马云谈996:只有付出巨大的代价才可能有回报
  4. 李宏毅ML lecture-10 CNN
  5. java 1.3 下载_我的世界Java版1.16.3
  6. 竞拍H5,源码,玉石、字画等转拍程序
  7. 数据库语言(DDL和DML)
  8. JAVAWEB复习知识六:根据Bootstrap框架做网站首页
  9. 从功能到测开,阿里巴巴软件测试面经大揭秘,看看大厂的技术栈
  10. 一、VS2015update2环境下DirectX11编程说明(2016.5.5更新)