在各种深度学习框架中,我们最常用的损失函数就是交叉熵(torch.nn.CrossEntropyLoss),熵是用来描述一个系统的混乱程度,通过交叉熵我们就能够确定预测数据与真是数据之间的相近程度。交叉熵越小,表示数据越接近真实样本。

交叉熵计算公式:

就是我们预测的概率的对数与标签的乘积,当qk->1的时候,它的损失接近零。

nn.NLLLoss
官方文档中介绍称: nn.NLLLoss输入是一个对数概率向量和一个目标标签,它与nn.CrossEntropyLoss的关系可以描述为:softmax(x)+log(x)+nn.NLLLoss====>nn.CrossEntropyLoss

CrossEntropyLoss()=log_softmax() + NLLLoss()

其中softmax函数又称为归一化指数函数,它可以把一个多维向量压缩在(0,1)之间,并且它们的和为1.

计算公式:

示例代码

import math
z = [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]
z_exp = [math.exp(i) for i in z]
print(z_exp)  # Result: [2.72, 7.39, 20.09, 54.6, 2.72, 7.39, 20.09]
sum_z_exp = sum(z_exp)
print(sum_z_exp)  # Result: 114.98
softmax = [round(i / sum_z_exp, 3) for i in z_exp]
print(softmax)  # Result: [0.024, 0.064, 0.175, 0.475, 0.024, 0.064, 0.175]

log_softmax
log_softmax是指在softmax函数的基础上,再进行一次log运算,此时结果有正有负,log函数的值域是负无穷到正无穷,当x在0—1之间的时候,log(x)值在负无穷到0之间。

nn.NLLLoss
此时,nn.NLLLoss的结果就是把上面的输出与Label对应的那个值拿出来,再去掉负号,再求均值。
代码示例:

import torch
input=torch.randn(3,3)
soft_input = torch.nn.Softmax(dim=0)
soft_input(input)
Out[20]:
tensor([[0.7284, 0.7364, 0.3343],[0.1565, 0.0365, 0.0408],[0.1150, 0.2270, 0.6250]])#对softmax结果取log
torch.log(soft_input(input))
Out[21]:
tensor([[-0.3168, -0.3059, -1.0958],[-1.8546, -3.3093, -3.1995],[-2.1625, -1.4827, -0.4701]])

假设标签是[0,1,2],第一行取第0个元素,第二行取第1个,第三行取第2个,去掉负号,即[0.3168,3.3093,0.4701],求平均值,就可以得到损失值。

(0.3168+3.3093+0.4701)/3
Out[22]: 1.3654000000000002#验证一下loss=torch.nn.NLLLoss()
target=torch.tensor([0,1,2])
loss(input,target)
Out[26]: tensor(0.1365)

nn.CrossEntropyLoss

loss=torch.nn.NLLLoss()
target=torch.tensor([0,1,2])
loss(input,target)
Out[26]: tensor(-0.1399)
loss =torch.nn.CrossEntropyLoss()
input = torch.tensor([[ 1.1879,  1.0780,  0.5312],[-0.3499, -1.9253, -1.5725],[-0.6578, -0.0987,  1.1570]])
target = torch.tensor([0,1,2])
loss(input,target)
Out[30]: tensor(0.1365)以上为全部实验验证两个loss函数之间的关系!!!

Pytorch损失函数torch.nn.NLLLoss()详解相关推荐

  1. PyTorch中的torch.nn.Parameter() 详解

    PyTorch中的torch.nn.Parameter() 详解 今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实 ...

  2. 【Pytorch】torch.argmax 函数详解

    文章目录 一.一个参数时的 torch.argmax 函数 1. 介绍 2. 实例 二.多个参数时的 torch.argmax 函数 1. 介绍 2. 实例 实例1:二维矩阵 实例2:三维矩阵 实例3 ...

  3. torch.nn.Linear详解

    在学习transformer时,遇到过非常频繁的nn.Linear()函数,这里对nn.Linear进行一个详解. 参考:https://pytorch.org/docs/stable/_module ...

  4. torch.nn.MaxPool2d详解

    注意:这里展示的是本篇博文写时的版本最新的实现,但是后续会代码可能会迭代更新,建议对照官方文档进行学习. 先来看源码: # 这个类是是许多池化类的基类,这里有必要了解一下 class _MaxPool ...

  5. torch.nn.parameter详解

    :-- 目录: 参考: 1.parameter基本解释: 2.参数requires_grad的深入理解: 2.1 Parameter级别的requires_grad 2.2Module级别的requi ...

  6. PyTorch中torch.norm函数详解

    torch.norm() 是 PyTorch 中的一个函数,用于计算输入张量沿指定维度的范数.具体而言,当给定一个输入张量 x 和一个整数 p 时,torch.norm(x, p) 将返回输入张量 x ...

  7. 深入理解Pytorch负对数似然函数(torch.nn.NLLLoss)和交叉熵损失函数(torch.nn.CrossEntropyLoss)

    在看Pytorch的交叉熵损失函数torch.nn.CrossEntropyLoss官方文档介绍中,给出的表达式如下.不免有点疑惑为何交叉熵损失的表达式是这个样子的 loss ⁡ ( y , clas ...

  8. python如何画损失函数图_Pytorch 的损失函数Loss function使用详解

    1.损失函数 损失函数,又叫目标函数,是编译一个神经网络模型必须的两个要素之一.另一个必不可少的要素是优化器. 损失函数是指用于计算标签值和预测值之间差异的函数,在机器学习过程中,有多种损失函数可供选 ...

  9. Pytorch|YOWO原理及代码详解(二)

    Pytorch|YOWO原理及代码详解(二) 本博客上接,Pytorch|YOWO原理及代码详解(一),阅前可看. 1.正式训练 if opt.evaluate:logging('evaluating ...

最新文章

  1. mysql db.opt+ (frm,MYD,MYI)备份与还原数据库
  2. SpringBoot shedlock MongoDb锁配置
  3. Cactoos中的面向对象的声明式输入/输出
  4. 深度学习(三十二)半监督阶梯网络学习笔记
  5. 使用DbVisualizer导出DB2创建序列SQL
  6. 英语secuerity证券
  7. N-Queen Problem
  8. but only one is allowed. 重复处理跨域请求
  9. A Beginner‘s Guide To Understanding Convolutional Neural Networks(part 1)
  10. live2d看板娘一览图
  11. 一个疯子的DK马历程(易中天说:悲剧啊)
  12. 苏州企业拿到商标注册证后,需要注意哪些事项?
  13. 记一个转行程序员的工作经历与感想(一)
  14. CS224N WINTER 2022(一)词向量(附Assignment1答案)
  15. adb 静默安装_Android ROOT下静默安装并打开APP
  16. 统计学“诺贝尔”奖——考普斯总统奖(COPSS Presidents' Award)
  17. 全相位FFT算法matlab的实
  18. 降噪效果好的蓝牙耳机有哪些?降噪耳机降噪效果排名
  19. MT7688学习笔记(17)——OpenWRT与电脑之间SCP文件传输
  20. 【QT 5 学习笔记-学习绘图相关+画图形图片等+绘图设备+基础学习(2)】

热门文章

  1. Java DTO(data transfer object)的理解,为什么要用DTO
  2. MIPI-CPHY、DPHY和MPHY基本介绍
  3. 微信小程序云开发——有数据却拿不到数据
  4. 从开发小白到入职抖音音视频开发岗位技术总结
  5. win10/win11安装qt4.8
  6. 腾讯T1~T9工程师技术剖析以及评定标准、能力要求
  7. Android LayoutInflater原理分析,带你一步步深入了解View
  8. java爬虫实时采集小说+springboot推荐算法+实现在线小说免费阅读推荐系统
  9. 【VOLTE】【ESRVCC】【4】eSRVCC
  10. 治理概念兴起 可望带动内容管理软件市场