loss函数之KLDivLoss
KL散度
KL散度,又叫相对熵,用于衡量两个分布(离散分布和连续分布)之间的距离。
设p(x)p(x)p(x) 、q(x)q(x)q(x) 是离散随机变量XXX的两个概率分布,则ppp 对qqq 的KL散度是:
DKL(p∥q)=Ep(x)logp(x)q(x)=∑i=1Np(xi)⋅(logp(xi)−logq(xi))D_{K L}(p \| q)=E_{p(x)} \log \frac{p(x)}{q(x)}=\sum_{i=1}^{N} p\left(x_{i}\right) \cdot\left(\log p\left(x_{i}\right)-\log q\left(x_{i}\right)\right)DKL(p∥q)=Ep(x)logq(x)p(x)=i=1∑Np(xi)⋅(logp(xi)−logq(xi))
KLDivLoss
对于包含NNN个样本的batch数据 D(x,y)D(x, y)D(x,y),xxx是神经网络的输出,并且进行了归一化和对数化;yyy是真实的标签(默认为概率),xxx与yyy同维度。
第nnn个样本的损失值lnl_{n}ln计算如下:
ln=yn⋅(logyn−xn)l_{n}=y_{n} \cdot\left(\log y_{n}-x_{n}\right)ln=yn⋅(logyn−xn)
class KLDivLoss(_Loss):__constants__ = ['reduction']def __init__(self, size_average=None, reduce=None, reduction='mean'):super(KLDivLoss, self).__init__(size_average, reduce, reduction)def forward(self, input, target):return F.kl_div(input, target, reduction=self.reduction)
pytorch中通过torch.nn.KLDivLoss
类实现,也可以直接调用F.kl_div
函数,代码中的size_average
与reduce
已经弃用。reduction有四种取值mean
,batchmean
, sum
, none
,对应不同的返回ℓ(x,y)\ell(x, y)ℓ(x,y)。 默认为mean
L={l1,…,lN}L=\left\{l_{1}, \ldots, l_{N}\right\}L={l1,…,lN}
ℓ(x,y)={L,if reduction =’none’ mean(L),if reduction =’mean’ N∗mean(L),if reduction =’batchmean’ sum(L),if reduction =’sum’ \ell(x, y)=\left\{\begin{array}{ll}\operatorname L, & \text { if reduction }=\text { 'none' } \\ \operatorname{mean}(L), & \text { if reduction }=\text { 'mean' } \\ N*\operatorname {mean}(L), & \text { if reduction }=\text { 'batchmean' } \\ \operatorname{sum}(L), & \text { if reduction }=\text { 'sum' }\end{array} \right.ℓ(x,y)=⎩⎪⎪⎨⎪⎪⎧L,mean(L),N∗mean(L),sum(L), if reduction = ’none’ if reduction = ’mean’ if reduction = ’batchmean’ if reduction = ’sum’
例子:
import torch
import torch.nn as nn
import mathdef validate_loss(output, target):val = 0for li_x, li_y in zip(output, target):for i, xy in enumerate(zip(li_x, li_y)):x, y = xyloss_val = y * (math.log(y, math.e) - x)val += loss_valreturn val / output.nelement()torch.manual_seed(20)
loss = nn.KLDivLoss()
input = torch.Tensor([[-2, -6, -8], [-7, -1, -2], [-1, -9, -2.3], [-1.9, -2.8, -5.4]])
target = torch.Tensor([[0.8, 0.1, 0.1], [0.1, 0.7, 0.2], [0.5, 0.2, 0.3], [0.4, 0.3, 0.3]])
output = loss(input, target)
print("default loss:", output)output = validate_loss(input, target)
print("validate loss:", output)loss = nn.KLDivLoss(reduction="batchmean")
output = loss(input, target)
print("batchmean loss:", output)loss = nn.KLDivLoss(reduction="mean")
output = loss(input, target)
print("mean loss:", output)loss = nn.KLDivLoss(reduction="none")
output = loss(input, target)
print("none loss:", output)
输出:
default loss: tensor(0.6209)
validate loss: tensor(0.6209)
batchmean loss: tensor(1.8626)
mean loss: tensor(0.6209)
none loss: tensor([[1.4215, 0.3697, 0.5697],[0.4697, 0.4503, 0.0781],[0.1534, 1.4781, 0.3288],[0.3935, 0.4788, 1.2588]])
loss函数之KLDivLoss相关推荐
- 《深度学习笔记》——loss函数的学习笔记
1 loss的作用 在南溪看来,loss函数是对目标target和预测prediction之间的一种距离度量的公式: 2 loss函数的设计原则 此设计原则参考了距离的定义,(注意:距离跟范数是两个概 ...
- 【Dual-Path-RNN-Pytorch源码分析】loss函数:SI-SNR
DPRNN使用的loss函数是 SI-SNR SI-SNR 是scale-invariant source-to-noise ratio的缩写,中文翻译为尺度不变的信噪比,意思是不受信号变化影响的信噪 ...
- tensorflow学习(4.loss函数以及正则化的使用 )
本文还是以MNIST的CNN分析为例 loss函数一般有MSE均方差函数.交叉熵损失函数,说明见 https://blog.csdn.net/John_xyz/article/details/6121 ...
- 多分类loss函数本质理解
一.面对一个多分类问题,如何设计合理的损失函数呢? 1.损失函数的本质在数学上称为目标函数:这个目标函数的目标值符合最完美的需求:损失函数的目标值肯定是0,完美分类的损失必然为0 : 2.损失函数分为 ...
- 深度学习基础(三)loss函数
loss函数,即损失函数,是决定网络学习质量的关键.若网络结构不变的前提下,损失函数选择不当会导致模型精度差等后果.若有错误,敬请指正,Thank you! 目录 一.loss函数定义 二.常见的lo ...
- Keras自定义Loss函数
Keras作为一个深度学习库,非常适合新手.在做神经网络时,它自带了许多常用的目标函数,优化方法等等,基本能满足新手学习时的一些需求.具体包含目标函数和优化方法.但它也支持用户自定义目标函数,下边介绍 ...
- 商汤使用AutoML设计Loss函数,全面超越人工设计
点击我爱计算机视觉标星,更快获取CVML新技术 深度学习领域,神经架构搜索得到的算法如雨后春笋般出现. 今天一篇arXiv论文<AM-LFS: AutoML for Loss Function ...
- 深度学习中的损失函数总结以及Center Loss函数笔记
北京 上海巡回站 | NVIDIA DLI深度学习培训 2018年1月26/1月12日 NVIDIA 深度学习学院 带你快速进入火热的DL领域 阅读全文 ...
- 'int' object has no attribute 'backward'报错 使用Pytorch编写 Hinge loss函数
在编写SVM中的Hinge loss函数的时候报错"'int' object has no attribute 'backward'" for epoch in range(50) ...
最新文章
- 《研磨设计模式》chap22 装饰模式Decorator(1)模式简介
- 学长毕业日记 :本科毕业论文写成博士论文的神操作20170401
- [coco2d]pageView:addPage时,page无法对齐
- python随机生成数字列表_详解Python利用random生成一个列表内的随机数
- JavaScript实现接口的三种经典方式
- 十家全国学会就IEEE“审稿门”事件发表联合声明
- 蓝桥杯 ADV-169 算法提高 士兵排队问题
- 注册事件的两种方式(传统注册事件、方法监听注册事件)
- ArcGIS API 离线字体库加载及跨域问题解决
- java 视频合并_java 实现分段视频合并
- 融合正弦余弦和变异选择的蝗虫优化算法
- VMware.exe应用程序错误--应用程序无法正常启动(0xc000007b)错误解决方法
- 爬虫小练习:堆糖图片抓取--爬虫正式学习day1
- 第四届中国软件开源创新大赛通知
- “舒淇半停工原因”上热搜:人生下半场,拼的是健康
- 初中女生数学不好能学计算机,初中女生必看:学好数学的方法及窍门
- linux下 不显示光驱,Windows7电脑下不显示光驱盘符的解决方法
- 未明学院:量化金融训练营开始报名,成为兼具数据分析技能+项目实战经验的复合型人才!
- 重带电粒子的能量歧离(energy straggling)
- 6月22日!苹果WWDC大会,全球免费参加-首次在线举行!
热门文章
- 如何使用Navicat MySQL导入.sql文件
- aix创建oracle表空间,Oracle for AIX基于裸设备的表空间扩充步聚
- html5 PHP 分片上传,H5分片上传含前端JS和后端处理(thinkphp)
- 计算机应用技术专业考试试题,全国专业技术人员计算机应用能力考试模拟试题笔试题.docx...
- 新手快速入门自动化测试第一步
- oracle巡检项,Oracle数据库巡检参考项
- 学生用计算机记录表,计算机教室学生上机记录表第14周
- Python自动化办公 | 如何实现报表自动化?
- Jmeter数据库及接口测试
- linux 自动挂载usb设备,Raspberry Pi 自动挂载USB存储设备