CrossEntropyLoss(交叉熵损失)与NLLLoss(negative log likelihood,负对数似然损失)都是适用于分类问题,基于log似然损失,即交叉熵损失函数的实现方式,其体现了两个分布的近似程度:L(p,q)=−∑ipilog⁡qiL(p,q)=-\sum_ip_i\log q_iL(p,q)=−i∑​pi​logqi​对于分类问题,其只在真实所属类别kkk上的编码为1,其他均为0(即one-hot编码),所以该损失函数可进一步写为:L=−Cklog⁡qkL=-C_k\log q_kL=−Ck​logqk​

CrossEntropyLoss和NLLLoss的区别在于:
(1)CrossEntropyLoss对输出层结果自动进行softmax概率归一化运算以及对数处理,然后再基于L=−Cklog⁡qkL=-C_k\log q_kL=−Ck​logqk​进行计算;
(2)NLLLoss直接计算L=−Cklog⁡qkL=-C_k\log q_kL=−Ck​logqk​,因此为了保证计算的正确,必须接在logsoftmax函数后面。

也就是说,在pytorch中:CrossEntropyLoss()=logsoftmax()+NLLLoss()

下面以一个例子来验证该结论:

import torch
import torch.nn as nn
output = torch.rand(3, 5)   # 随机输出的3*5未归一化结果
target = torch.tensor([0, 1, 4])   # 对应3个样本的真实分类criterion1 = nn.NLLLoss()
criterion2 = nn.CrossEntropyLoss()# 方案一:直接采用CrossEntropyLoss()
print(criterion2(output, target))# 方案二:在output后接入logsoftmax(),再计算NLLLoss()
logsoftmax = nn.LogSoftmax(dim=1)
output_logsoftmax = logsoftmax(output)
print(criterion1(output_logsoftmax, target))# 方案三:手动计算交叉熵损失函数的平均值
target_onehot = torch.zeros_like(output)
for i,j in enumerate(target):target_onehot[i][j] = 1
crossentropy = - (target_onehot * output_logsoftmax).sum()/len(target)
print(crossentropy)

结果表明,上述三种方法计算的结果是完全相同的。

此外,还有一个注意值得的细节为:作为Target的分类编码要从0开始计起

【Pytorch】对比CrossEntropyLoss与NLLLoss相关推荐

  1. pytorch中CrossEntropyLoss和NLLLoss的区别与联系

    pytorch中CrossEntropyLoss和NLLLoss的区别与联系 CrossEntropyLoss和NLLLoss主要是用在多分类问题的损失函数,他们两个既有不同,也有不浅的联系.先分别看 ...

  2. BCELoss、crossentropyLoss、NLLLoss的使用(pytorch)

    文章目录 BCELoss 参考文档 理解 demo 应用 crossentropyLoss.NLLLoss 参考文档 crossEntropyLoss NLLLoss BCELoss 用于二分类问题, ...

  3. PyTorch学习笔记——softmax和log_softmax的区别、CrossEntropyLoss() 与 NLLLoss() 的区别、log似然代价函数...

    1.softmax 函数 Softmax(x) 也是一个 non-linearity, 但它的特殊之处在于它通常是网络中一次操作. 这是因为它接受了一个实数向量并返回一个概率分布.其定义如下. 定义 ...

  4. Pytorch损失函数torch.nn.NLLLoss()详解

    在各种深度学习框架中,我们最常用的损失函数就是交叉熵(torch.nn.CrossEntropyLoss),熵是用来描述一个系统的混乱程度,通过交叉熵我们就能够确定预测数据与真是数据之间的相近程度.交 ...

  5. pytorch几种损失函数CrossEntropyLoss、NLLLoss、BCELoss、BCEWithLogitsLoss、focal_loss、heatmap_loss

    分类问题常用的几种损失,记录下来备忘,后续不断完善. nn.CrossEntropyLoss()交叉熵损失 常用于多分类问题 CE = nn.CrossEntropyLoss() loss = CE( ...

  6. [深度学习] Pytorch nn.CrossEntropyLoss()和nn.NLLLoss() 区别

    nn.NLLLoss()的参数是经过logsoftmax加工的,而CrossEntropyLoss的是原始输出数据 target = torch.tensor([1, 2]) entropy_out ...

  7. Pytorch CrossEntropyLoss和NLLLoss

    NLLloss class torch.nn.NLLLoss(weight=None, size_average=True) 作用:训练一个n类的分类器 参数 weight:可选的,应该是一个tens ...

  8. pytorch nn.CrossEntropyLoss

    应用 概念讲解 1)假设有m张图片,经过神经网络后输出为m*n的矩阵(m是图片个数,n是图片类别),下例中: m=2,n=2既有两张图片,供区分两种类别比如猫狗.假设第0维为猫,第1维为狗 impor ...

  9. PyTorch nn.CrossEntropyLoss() dimension out of range (expected to be in range of [-1, 0], but got 1)

    import torch import torch.nn as nn loss_fn = nn.CrossEntropyLoss() # 方便理解,此处假设batch_size = 1 x_input ...

最新文章

  1. python菜鸟教程函数-Python 函数装饰器
  2. 云供应商安全评估:小心落入陷阱
  3. linux删除之前的文件在哪里,Linux下,如何将最后修改时间在某个时间之前的文件删除去?...
  4. VTK:颜色边缘用法实战
  5. 程序清单3-1 测试能否对标准输入设置偏移量
  6. python筛选csv数据_pandas数据筛选和csv操作的实现方法
  7. 文件磁盘相关函数[2]-建立新文件 FileCreate
  8. 读写分离MYSQL类
  9. 工业互联网是什么?发展有多厉害?
  10. draco3D轻量化技术在Unity3D中应用
  11. Java、JSP高速公路收费系统
  12. 计算机硬件调查和报价600字,600字调查报告.docx
  13. 多线程:synchronized关键字解析
  14. 多家汽车金融公司拿下融担牌照,“助贷+融担”模式成主流
  15. Nginx 安全漏洞
  16. 论文笔记:Mind the Gap An Experimental Evaluation of Imputation ofMissing Values Techniques in TimeSeries
  17. 前端踩坑(八)前端使用Moment 时间格式化错误
  18. 2020年第十一届蓝桥杯A组省赛
  19. python 实现描述性统计、频数分布图、正态分布检验、概率密度曲线拟合
  20. java文件太大 上传不了怎么办_上传文件(200M)过大失败,想提高成800M

热门文章

  1. 全球最好的外贸B2B平台有哪些
  2. Vue修炼系列教程 - 元婴篇2
  3. 我的2019年年终总结
  4. 身份证实名认证api
  5. HDFS 双缓冲技术核心源码剖析
  6. Kindle 特价书
  7. 小程序: 长按识别图中二维码
  8. echarts科技饼图
  9. Python可视化:matplotlib 绘制堆积柱状图绘制
  10. 【源码】渐开线齿轮的MATLAB几何计算程序gearsInMesh