KL散度

KL散度,又叫相对熵,用于衡量两个分布(离散分布和连续分布)之间的距离。

设p(x)p(x)p(x) 、q(x)q(x)q(x) 是离散随机变量XXX的两个概率分布,则ppp 对qqq 的KL散度是:

DKL(p∥q)=Ep(x)log⁡p(x)q(x)=∑i=1Np(xi)⋅(log⁡p(xi)−log⁡q(xi))D_{K L}(p \| q)=E_{p(x)} \log \frac{p(x)}{q(x)}=\sum_{i=1}^{N} p\left(x_{i}\right) \cdot\left(\log p\left(x_{i}\right)-\log q\left(x_{i}\right)\right)DKL​(p∥q)=Ep(x)​logq(x)p(x)​=i=1∑N​p(xi​)⋅(logp(xi​)−logq(xi​))

KLDivLoss

对于包含NNN个样本的batch数据 D(x,y)D(x, y)D(x,y),xxx是神经网络的输出,并且进行了归一化和对数化;yyy是真实的标签(默认为概率),xxx与yyy同维度。

第nnn个样本的损失值lnl_{n}ln​计算如下:

ln=yn⋅(log⁡yn−xn)l_{n}=y_{n} \cdot\left(\log y_{n}-x_{n}\right)ln​=yn​⋅(logyn​−xn​)

class KLDivLoss(_Loss):__constants__ = ['reduction']def __init__(self, size_average=None, reduce=None, reduction='mean'):super(KLDivLoss, self).__init__(size_average, reduce, reduction)def forward(self, input, target):return F.kl_div(input, target, reduction=self.reduction)

pytorch中通过torch.nn.KLDivLoss类实现,也可以直接调用F.kl_div 函数,代码中的size_averagereduce已经弃用。reduction有四种取值mean,batchmean, sum, none,对应不同的返回ℓ(x,y)\ell(x, y)ℓ(x,y)。 默认为mean

L={l1,…,lN}L=\left\{l_{1}, \ldots, l_{N}\right\}L={l1​,…,lN​}

ℓ(x,y)={L⁡,if reduction =’none’ mean⁡(L),if reduction =’mean’ N∗mean⁡(L),if reduction =’batchmean’ sum⁡(L),if reduction =’sum’ \ell(x, y)=\left\{\begin{array}{ll}\operatorname L, & \text { if reduction }=\text { 'none' } \\ \operatorname{mean}(L), & \text { if reduction }=\text { 'mean' } \\ N*\operatorname {mean}(L), & \text { if reduction }=\text { 'batchmean' } \\ \operatorname{sum}(L), & \text { if reduction }=\text { 'sum' }\end{array} \right.ℓ(x,y)=⎩⎪⎪⎨⎪⎪⎧​L,mean(L),N∗mean(L),sum(L),​ if reduction = ’none’  if reduction = ’mean’  if reduction = ’batchmean’  if reduction = ’sum’ ​

例子:

import torch
import torch.nn as nn
import mathdef validate_loss(output, target):val = 0for li_x, li_y in zip(output, target):for i, xy in enumerate(zip(li_x, li_y)):x, y = xyloss_val = y * (math.log(y, math.e) - x)val += loss_valreturn val / output.nelement()torch.manual_seed(20)
loss = nn.KLDivLoss()
input = torch.Tensor([[-2, -6, -8], [-7, -1, -2], [-1, -9, -2.3], [-1.9, -2.8, -5.4]])
target = torch.Tensor([[0.8, 0.1, 0.1], [0.1, 0.7, 0.2], [0.5, 0.2, 0.3], [0.4, 0.3, 0.3]])
output = loss(input, target)
print("default loss:", output)output = validate_loss(input, target)
print("validate loss:", output)loss = nn.KLDivLoss(reduction="batchmean")
output = loss(input, target)
print("batchmean loss:", output)loss = nn.KLDivLoss(reduction="mean")
output = loss(input, target)
print("mean loss:", output)loss = nn.KLDivLoss(reduction="none")
output = loss(input, target)
print("none loss:", output)

输出:

default loss: tensor(0.6209)
validate loss: tensor(0.6209)
batchmean loss: tensor(1.8626)
mean loss: tensor(0.6209)
none loss: tensor([[1.4215, 0.3697, 0.5697],[0.4697, 0.4503, 0.0781],[0.1534, 1.4781, 0.3288],[0.3935, 0.4788, 1.2588]])

loss函数之KLDivLoss相关推荐

  1. 《深度学习笔记》——loss函数的学习笔记

    1 loss的作用 在南溪看来,loss函数是对目标target和预测prediction之间的一种距离度量的公式: 2 loss函数的设计原则 此设计原则参考了距离的定义,(注意:距离跟范数是两个概 ...

  2. 【Dual-Path-RNN-Pytorch源码分析】loss函数:SI-SNR

    DPRNN使用的loss函数是 SI-SNR SI-SNR 是scale-invariant source-to-noise ratio的缩写,中文翻译为尺度不变的信噪比,意思是不受信号变化影响的信噪 ...

  3. tensorflow学习(4.loss函数以及正则化的使用 )

    本文还是以MNIST的CNN分析为例 loss函数一般有MSE均方差函数.交叉熵损失函数,说明见 https://blog.csdn.net/John_xyz/article/details/6121 ...

  4. 多分类loss函数本质理解

    一.面对一个多分类问题,如何设计合理的损失函数呢? 1.损失函数的本质在数学上称为目标函数:这个目标函数的目标值符合最完美的需求:损失函数的目标值肯定是0,完美分类的损失必然为0 : 2.损失函数分为 ...

  5. 深度学习基础(三)loss函数

    loss函数,即损失函数,是决定网络学习质量的关键.若网络结构不变的前提下,损失函数选择不当会导致模型精度差等后果.若有错误,敬请指正,Thank you! 目录 一.loss函数定义 二.常见的lo ...

  6. Keras自定义Loss函数

    Keras作为一个深度学习库,非常适合新手.在做神经网络时,它自带了许多常用的目标函数,优化方法等等,基本能满足新手学习时的一些需求.具体包含目标函数和优化方法.但它也支持用户自定义目标函数,下边介绍 ...

  7. 商汤使用AutoML设计Loss函数,全面超越人工设计

    点击我爱计算机视觉标星,更快获取CVML新技术 深度学习领域,神经架构搜索得到的算法如雨后春笋般出现. 今天一篇arXiv论文<AM-LFS: AutoML for Loss Function ...

  8. 深度学习中的损失函数总结以及Center Loss函数笔记

    北京 上海巡回站 | NVIDIA DLI深度学习培训 2018年1月26/1月12日 NVIDIA 深度学习学院 带你快速进入火热的DL领域 阅读全文                        ...

  9. 'int' object has no attribute 'backward'报错 使用Pytorch编写 Hinge loss函数

    在编写SVM中的Hinge loss函数的时候报错"'int' object has no attribute 'backward'" for epoch in range(50) ...

最新文章

  1. 《研磨设计模式》chap22 装饰模式Decorator(1)模式简介
  2. 学长毕业日记 :本科毕业论文写成博士论文的神操作20170401
  3. [coco2d]pageView:addPage时,page无法对齐
  4. python随机生成数字列表_详解Python利用random生成一个列表内的随机数
  5. JavaScript实现接口的三种经典方式
  6. 十家全国学会就IEEE“审稿门”事件发表联合声明
  7. 蓝桥杯 ADV-169 算法提高 士兵排队问题
  8. 注册事件的两种方式(传统注册事件、方法监听注册事件)
  9. ArcGIS API 离线字体库加载及跨域问题解决
  10. java 视频合并_java 实现分段视频合并
  11. 融合正弦余弦和变异选择的蝗虫优化算法
  12. VMware.exe应用程序错误--应用程序无法正常启动(0xc000007b)错误解决方法
  13. 爬虫小练习:堆糖图片抓取--爬虫正式学习day1
  14. 第四届中国软件开源创新大赛通知
  15. “舒淇半停工原因”上热搜:人生下半场,拼的是健康
  16. 初中女生数学不好能学计算机,初中女生必看:学好数学的方法及窍门
  17. linux下 不显示光驱,Windows7电脑下不显示光驱盘符的解决方法
  18. 未明学院:量化金融训练营开始报名,成为兼具数据分析技能+项目实战经验的复合型人才!
  19. 重带电粒子的能量歧离(energy straggling)
  20. 6月22日!苹果WWDC大会,全球免费参加-首次在线举行!

热门文章

  1. 如何使用Navicat MySQL导入.sql文件
  2. aix创建oracle表空间,Oracle for AIX基于裸设备的表空间扩充步聚
  3. html5 PHP 分片上传,H5分片上传含前端JS和后端处理(thinkphp)
  4. 计算机应用技术专业考试试题,全国专业技术人员计算机应用能力考试模拟试题笔试题.docx...
  5. 新手快速入门自动化测试第一步
  6. oracle巡检项,Oracle数据库巡检参考项
  7. 学生用计算机记录表,计算机教室学生上机记录表第14周
  8. Python自动化办公 | 如何实现报表自动化?
  9. Jmeter数据库及接口测试
  10. linux 自动挂载usb设备,Raspberry Pi 自动挂载USB存储设备