博客已迁移至知乎 本文地址:https://zhuanlan.zhihu.com/p/70804197

前言

在处理分类问题的神经网络模型中,很多都使用交叉熵 (cross entropy) 做损失函数。
这篇文章详细地介绍了交叉熵的由来、为什么使用交叉熵,以及它解决了什么问题,最后介绍了交叉熵损失函数的应用场景。


要讲交叉熵就要从最基本的信息熵说起。

1.信息熵

信息熵是消除不确定性所需信息量的度量。(多看几遍这句话)

信息熵就是信息的不确定程度,信息熵越小,信息越确定。

信 息 熵 = ∑ x = 1 n ( 信 息 x 发 生 的 概 率 × 验 证 信 息 x 需 要 的 信 息 量 ) 信息熵 = \sum_{x=1}^{n}(信息x发生的概率 × 验证信息x需要的信息量) 信息熵=x=1∑n​(信息x发生的概率×验证信息x需要的信息量)

(因为事件都有个概率分布,这里我们只考虑离散分布)

举个列子,比如说:今年中国取消高考了,这句话我们很不确定(甚至心里还觉得这TM是扯淡),那我们就要去查证了,这样就需要很多信息量(去查证);反之如果说今年正常高考,大家回想:这很正常啊,不怎么需要查证,这样需要的信息量就很小。从这里我们可以学到:根据信息的真实分布,我们能够找到一个最优策略,以最小的代价消除系统的不确定性,即最小信息熵

简而言之,概率越低,需要越多的信息去验证,所以验证真假需要的信息量和概率成反比。我们需要用数学表达式把它描述出来,推导:

考虑一个离散的随机变量 x x x,已知信息的量度依赖于概率分布 p ( x ) p(x) p(x),因此我们想要寻找一个函数 I ( x ) I(x) I(x),它是概率 p ( x ) p(x) p(x)的单调函数,表示信息量
怎么寻找呢?如果我们有两个不相关的事件 x x x 和 y y y,那么观察两个事件同时发生时获得的信息量应该等于观察到事件各自发生时获得的信息之和,即:
I ( x , y ) = I ( x ) + I ( y ) I(x,y)=I(x)+I(y) I(x,y)=I(x)+I(y)

因为两个事件是独立不相关的,因此
p ( x , y ) = p ( x ) p ( y ) p(x,y)=p(x)p(y) p(x,y)=p(x)p(y)

根据这两个关系,很容易看出 I ( x ) I(x) I(x)一定与 P ( x ) P(x) P(x) 的对数有关。
由对数的运算法则可知:
l o g a ( p ( x ) p ( y ) ) = l o g a p ( x ) + l o g a p ( y ) log_a(p(x)p(y))=log_ap(x)+log_ap(y) loga​(p(x)p(y))=loga​p(x)+loga​p(y)

因此,我们有
I ( x ) = − l o g ( p ( x ) ) I(x)=−log(p(x)) I(x)=−log(p(x))

其中负号是用来保证信息量是正数或者零。而 l o g log log 函数基的选择是任意的(信息论中基常常选择为2,因此信息的单位为比特bits;而机器学习中基常常选择为自然常数,因此单位常常被称为奈特nats)。 I ( x ) I(x) I(x) 也被称为随机变量 x 的自信息 (self-information),描述的是随机变量的某个事件发生所带来的信息量

以上推导借鉴了这篇博客。

信息熵即所有信息量的期望:
H ( X ) = − ∑ x p ( x ) l o g ( p ( x ) ) = − ∑ i = 1 n p ( x i ) l o g ( p ( x i ) ) H(X)=−∑_xp(x)log(p(x))=−∑_{i=1}^np(x_i)log(p(x_i)) H(X)=−x∑​p(x)log(p(x))=−i=1∑n​p(xi​)log(p(xi​))

其中n为事件的所有可能性。


2.相对熵(KL散度)

相对熵又称KL散度,如果对于同一个随机变量 x x x有两个单独的概率分布 p ( x ) p(x) p(x)和 q ( x ) q(x) q(x),可以使用相对熵来衡量这两个分布的差异。
D K L ( p ∣ ∣ q ) = ∑ i = 1 n p ( x i ) l o g ( p ( x i ) q ( x i ) ) D_{KL}(p||q)=\sum_{i=1}^np(x_i)log(\frac{p(x_i)}{q(x_i)}) DKL​(p∣∣q)=i=1∑n​p(xi​)log(q(xi​)p(xi​)​)

注: D K L D_{KL} DKL​越小,表示p(x)和q(x)的分布越近。


3.交叉熵

交叉熵公式:
H ( p , q ) = − ∑ i = 1 n p ( x i ) l o g ( q ( x i ) ) H(p,q)=-\sum_{i=1}^np(x_i)log(q(x_i)) H(p,q)=−i=1∑n​p(xi​)log(q(xi​))

相对熵的推导:
D K L ( p ∣ ∣ q ) = ∑ i = 1 n p ( x i ) l o g ( p ( x i ) ) − ∑ i = 1 n p ( x i ) l o g ( q ( x i ) ) = − H ( X ) + [ − ∑ i = 1 n p ( x i ) l o g ( q ( x i ) ) ] = [ − ∑ i = 1 n p ( x i ) l o g ( q ( x i ) ) ] − H ( X ) \begin{array}{l} \quad D_{KL}(p||q) \\\\ = \sum_{i=1}^np(x_i)log(p(x_i))-\sum_{i=1}^np(x_i)log(q(x_i)) \\\\ = -H(X)+[-\sum_{i=1}^np(x_i)log(q(x_i))]\\\\ = [-\sum_{i=1}^np(x_i)log(q(x_i))]-H(X)\\ \end{array} DKL​(p∣∣q)=∑i=1n​p(xi​)log(p(xi​))−∑i=1n​p(xi​)log(q(xi​))=−H(X)+[−∑i=1n​p(xi​)log(q(xi​))]=[−∑i=1n​p(xi​)log(q(xi​))]−H(X)​

在机器学习中,往往用 p ( x ) p(x) p(x)用来描述真实分布, q ( x ) q(x) q(x)用来描述模型预测的分布

计算损失,理应使用相对熵来计算概率分布的差异,然而由相对熵推导出的结果看:

相 对 熵 = 交 叉 熵 − 信 息 熵 相对熵=交叉熵-信息熵 相对熵=交叉熵−信息熵

由于信息熵描述的是消除 p p p (即真实分布) 的不确定性所需信息量的度量,所以其值应该是最小的、固定的。那么:优化减小相对熵也就是优化交叉熵,所以在机器学习中使用交叉熵就可以了。


4.为什么使用交叉熵

在机器学习中,我们希望模型在训练数据上学到的预测数据分布真实数据分布越相近越好,上面讲过了,用相对熵,但是为了简便计算使用交叉熵就可以了。

注意:此处真实数据分布指的就是训练数据的分布(标注)。

交叉熵损失函数:

L = − [ y l o g y ^ + ( 1 − y ) l o g ( 1 − y ^ ) ] L=-[ylog\ \hat y+(1-y)log\ (1-\hat y)] L=−[ylog y^​+(1−y)log (1−y^​)]

交叉熵损失函数一般用来代替均方差损失函数与sigmoid激活函数组合。
sigmoid激活函数表达式:
σ ( z ) = 1 1 + e − z \sigma(z) = \frac{1}{1+e^{-z}} σ(z)=1+e−z1​

下面是sigmoid函数及其导数的图像:
[外链图片转存失败(img-zGeLUZMx-1565158338849)(https://ws1.sinaimg.cn/large/e3bfcf62ly1fy1heq9aroj20sh0gymy7.jpg =800x)]

从图中可以看出,对于sigmoid函数,当 x x x的取值越大或越小,函数曲线变得越平缓,意味着导数 σ ′ ( x ) σ′(x) σ′(x)越趋近于0。

以单个样本的一次梯度下降为例:

z = w x + b z= wx+b z=wx+b

y ^ = a = σ ( z ) \hat{y}= a =\sigma(z) y^​=a=σ(z)

L 1 ( y , a ) = 1 2 ( y − a ) 2 L_1(y,a)=\frac{1}{2}(y-a)^2 L1​(y,a)=21​(y−a)2

L 2 ( y , a ) = − ( y l o g ( a ) + ( 1 − y ) l o g ( 1 − a ) ) L_2(y,a)=-(ylog(a)+(1-y)log(1-a)) L2​(y,a)=−(ylog(a)+(1−y)log(1−a))

前两个公式公式分别是前向传播的线性和非线性部分,第三个公式公式是均方差损失函数,第四个公式是交叉熵损失函数。梯度下降的目的,直白地说:是减小真实值和预测值的距离,而损失函数用来度量真实值和预测值之间距离,所以梯度下降目的也就是减小损失函数的值。怎么减小损失函数的值呢?变量只有 w w w和 b b b,所以我们要做的就是不断修改 w w w和 b b b的值以使损失函数越来越小。(这里例子只有一步,只修改一次)

w w w和 b b b的更新: 参 数 = 参 数 − 学 习 率 × 损 失 函 数 对 参 数 的 偏 导 参数=参数-学习率×损失函数对参数的偏导 参数=参数−学习率×损失函数对参数的偏导:

w = w − α ∂ L ( y , a ) ∂ w w = w - \alpha \frac{\partial L(y,a)}{\partial w} w=w−α∂w∂L(y,a)​

b = b − α ∂ L ( y , a ) ∂ w b = b - \alpha \frac{\partial L(y,a)}{\partial w} b=b−α∂w∂L(y,a)​

其中 α \alpha α 表示学习率,用来控制步长,即向下走一步的长度

为什么要这样更新参数呢,讲完下面的关键点我们会解释一下。

关键点来了,为什么用交叉熵而不是均方差呢?

均方差对参数的偏导:

∂ L 1 ( y , a ) ∂ w = − ∣ y − σ ( z ) ∣ σ ′ ( z ) x \frac{\partial L_1(y,a)}{\partial w}=-|y-\sigma(z)|\sigma'(z)x ∂w∂L1​(y,a)​=−∣y−σ(z)∣σ′(z)x

∂ L 1 ( y , a ) ∂ b = − ∣ y − σ ( z ) ∣ σ ′ ( z ) \frac{\partial L_1(y,a)}{\partial b}=-|y-\sigma(z)|\sigma'(z) ∂b∂L1​(y,a)​=−∣y−σ(z)∣σ′(z)

交叉熵对参数的偏导:

∂ L 2 ( y , a ) ∂ w = x [ σ ( z ) − y ] \frac{\partial L_2(y,a)}{\partial w}=x[\sigma(z)-y] ∂w∂L2​(y,a)​=x[σ(z)−y]

∂ L 2 ( y , a ) ∂ w = σ ( z ) − y \frac{\partial L_2(y,a)}{\partial w}=\sigma(z)-y ∂w∂L2​(y,a)​=σ(z)−y

注:为了简洁,以上公式中用 z z z 代替了 w x + b wx+b wx+b

从以上公式可以看出:均方差对参数的偏导的结果都乘了sigmoid的导数 σ ′ ( z ) x \sigma'(z)x σ′(z)x,而之前看图发现sigmoid导数在其变量值很大或很小时趋近于0,所以偏导数很有可能接近于0。
由参数更新公式: 参 数 = 参 数 − 学 习 率 × 损 失 函 数 对 参 数 的 偏 导 参数=参数-学习率×损失函数对参数的偏导 参数=参数−学习率×损失函数对参数的偏导
可知,偏导很小时,参数更新速度会变得很慢,而当偏导接近于0时,参数几乎就不更新了。

反观交叉熵对参数的偏导就没有sigmoid导数,所以不存在这个问题。这就是选择交叉熵而不选择均方差的原因。


梯度下降的原理,为什么要这样更新参数

借用吴恩达深度学习课上的图:

在这个图中,横轴表示参数w和b,在实践中,w可以是更高的维度,但是为了更好地绘图,我们定义w和b都是单一实数,损失函数 J ( w , b ) J(w,b) J(w,b)是在水平轴和上的曲面,因此曲面的高度就是 J ( w , b ) J(w,b) J(w,b)在某一点的函数值。我们所做的就是找到使得损失函数 J ( w , b ) J(w,b) J(w,b)函数值为最小值时,对应的参数w和b。
两个参数不太好说明,我们把它简化成一个参数来讲,假设损失函数只有 w w w一个参数:

图画的丑,能说明意思就行,曲线是损失函数,参数w为横坐标,红色的点记录参数 w w w的每次更新(这里例子只有一步,只更新一次)。
损失函数对 w w w的偏导 ∂ L 1 ( y , a ) ∂ w \frac{\partial L_1(y,a)}{\partial w} ∂w∂L1​(y,a)​相当于曲线的斜率, w = w − α ∂ L ( y , a ) ∂ w w = w - \alpha \frac{\partial L(y,a)}{\partial w} w=w−α∂w∂L(y,a)​,会使红点像曲线下端移动,这样就减小了损失函数。多个参数也是同样的道理。


5.使用场景

下面是知乎上看到的一张图,图中写得很清楚了。

CSDN图片显示不正常,博客已迁移至知乎:https://zhuanlan.zhihu.com/p/70804197


References:

[1] 详解机器学习中的熵、条件熵、相对熵和交叉熵
[2] 吴恩达深度学习课程
[3] 知乎:为什么交叉熵(cross-entropy)可以用于计算代价?
[4] 使用ReLU作为激活函数还有必要用交叉熵计算损失函数吗?

为什么用交叉熵做损失函数相关推荐

  1. 【李沐】softmax回归-1.使用交叉熵作为损失函数而不是MSE均方误差-(意思就是为什么经过了softmax之后不用MSE) 2.softmax与sigmoid

    2. 交叉熵的来源 2.1 信息量 一条信息的信息量大小和它的不确定性有很大的关系.一句话如果需要很多外部信息才能确定,我们就称这句话的信息量比较大.比如你听到"云南西双版纳下雪了" ...

  2. 交叉熵代价函数(损失函数)及其求导推导

    转自:http://blog.csdn.net/jasonzzj/article/details/52017438 前言 交叉熵损失函数 交叉熵损失函数的求导 前言 说明:本文只讨论Logistic回 ...

  3. 交叉熵代价函数(损失函数)及其求导推导 (Logistic Regression)

    目录 1. 前言 2. 交叉熵损失函数 3. 交叉熵损失函数的求导 前言 说明:本文只讨论Logistic回归的交叉熵,对Softmax回归的交叉熵类似(Logistic回归和Softmax回归两者本 ...

  4. 交叉熵损失函数和softmax笔记

    文章目录 1. 交叉熵定义 2.交叉熵损失思考 3.交叉熵损失代码 4. softmax 回归的简洁实现 4.1 代码 4.2 结果 5. torch.nn.Softmax代码测试 5.1 说明 注: ...

  5. 信息量、熵、交叉熵、KL散度、JS散度、Wasserstein距离

    信息量.熵.交叉熵.KL散度.JS散度 文章目录 信息量.熵.交叉熵.KL散度.JS散度 前言 一.信息量 二.熵 三.交叉熵 四.KL散度 五.JS散度 六. Wasserstein距离 1.解决的 ...

  6. softmax回归与交叉熵损失

    前言 回归与分类是机器学习中的两个主要问题,二者有着紧密的联系,但又有所不同.在一个预测任务中,回归问题解决的是多少的问题,如房价预测问题,而分类问题用来解决是什么的问题,如猫狗分类问题.分类问题又以 ...

  7. 交叉熵三连(3)——交叉熵及其使用

    相关文章: 交叉熵三连(1)--信息熵 交叉熵三连(2)--KL散度(相对熵) 交叉熵三连(3)--交叉熵及其使用   在神经网络中,我们经常使用交叉熵做多分类问题和二分类的损失函数,在通过前面的两篇 ...

  8. 交叉熵以及通过Python实现softmax_交叉熵(tensorflow验证)

    文章目录 交叉熵(Cross Entropy) 信息论 相对熵 交叉熵 机器学习中的交叉熵 为什么要用交叉熵做损失函数? 分类问题中的交叉熵 softmax softmax_cross_entropy ...

  9. 熵,交叉熵与softmax

    经常说交叉熵+softmax(或者交叉熵损失函数),应该是对多分类下,softmax作为最后一层输出,交叉熵再判断这softmax的概率和真实值之间的差异(感觉不对,应该是用softmax的公式去求交 ...

  10. dice系数 交叉熵_语义分割中的损失函数

    1 交叉熵 信息量:当一个事件发生的概率为 ,那么该事件对应的概率的信息量是 . 信息量的熵:信息量的期望,假设 事件 共有n种可能,发生 的概率为 ,那么该事件的熵 为: 相对熵,又称KL散度,如果 ...

最新文章

  1. C语言优势大揭露,你还在等什么呢?
  2. 百越杯 Reverse (crazy write up)
  3. 美国三院院士「迈克尔•乔丹」长文论述:为什么说「人工智能革命」尚未发生...
  4. 训练softmax分类器实例_CS224N NLP with Deep Learning(四):Window分类器与神经网络
  5. java面试题7 牛客:关于AWT和Swing说法正确的是?
  6. tp5易支付完整版源码
  7. Swing超市收银系统附图
  8. VS启动调试速度异常的缓慢问题
  9. FUSE—用户空间文件系统
  10. asp Eval()函数的一些使用总结
  11. JavaScript DOM高级程序设计 5动态修改样式和层叠样式表1(源代码)--我要坚持到底!...
  12. 如何看待软件测试培训?
  13. t分布 u分布 卡方分布_三大抽样分布:卡方分布,t分布和F分布的简单理解
  14. treeTable树结构表格的使用
  15. 【unprofessional use Blog003】基因数据库NCBI相关介绍
  16. Flask Web——Jinjia2模板的使用
  17. 利用Python进行博客图片压缩
  18. 第十一章 卡米洛特的黑暗时代
  19. ECCV 2022 Oral | 无需微调即可泛化!RegAD:少样本异常检测新框架
  20. 基于facenet人脸识别设计文档

热门文章

  1. SpringMVC--入门
  2. 最近又重温了一下像素的刀剑,整理了一下宝石炼造系统
  3. 上网记录监控软件有哪些(三款好用的上网行为监控软件)
  4. OSChina 周六乱弹 ——周末和女友在家看猫片吧
  5. 毕设——图像视觉显著性目标检测(第五周到第七周工作总结)
  6. 反应扩散方程与图灵图(世间万物神秘的斑图)
  7. #DTOJ 5134 小h的几何
  8. MongoDB命令行操作
  9. 职场高效的秘诀,竟是一个“懒”字
  10. python鼠标移动事件是真的吗_JS mousemove事件:鼠标移动事件