1. nn.MSELoss()

模型的预测值与标签的L2距离。一般用于回归问题。之所以不用于分类问题,可能原因为:使用sigmoid之后,函数形式不是凸函数 ,不容易求解 ,容易进入局部最优。

loss = nn.MSELoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
output = loss(input, target)
output.backward()

2. nn.BCELoss()

交叉熵损失函数,衡量两个分布之间的差异,一般用于分类问题。输入的x值需先经过sigmoid压缩到(0,1)之间。标签形式为[0, 0, 1], [0, 1, 1]等,各个类别预测概率独立,类与类之间不互斥,可见不仅能用于二分类问题,也能用于多标签分类问题。

m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(m(input), target)
output.backward()

3. nn.BCEWithLogitsLoss()

交叉熵损失函数, 与nn.BCELoss()不同的是网络的输出无需用sigmoid压缩,函数内部整合了nn.sigmoid()和nn.BCELoss(),并且使用log-sum-exp trick提高了数值稳定性。同样可用于二分类及多标签分类。
这里简单介绍一下log-sum-exp trick:
原始的log-sum-exp公式为:

y = l o g ∑ i e x i y = log\sum_{i}^{}e^{x_{i}} y=logi∑​exi​
在 x i x_{i} xi​都很小时, ∑ i e x i \sum_{i}^{}e^{x_{i}} ∑i​exi​趋近于0,导致数值计算问题。
而如果 x i x_{i} xi​很大, ∑ i e x i \sum_{i}^{}e^{x_{i}} ∑i​exi​也会很大,同样会导致数值计算问题。
而log-sum-exp trick的计算公式为:
a = m a x ( x i ) a = max(x_{i}) a=max(xi​)

y = a + l o g ∑ i e x i − a y = a + log\sum_{i}^{}e^{x_{i}-a} y=a+logi∑​exi​−a
其中 e x i − a ≤ 1 e^{x_{i}-a}\leq 1 exi​−a≤1,使 ∑ i e x i \sum_{i}^{}e^{x_{i}} ∑i​exi​既不会趋于0也不会很大,避免数值溢出。

loss = nn.BCEWithLogitsLoss()
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(input, target)
output.backward()

4. nn.LogSoftmax()

将输入softmax后取对数

m = nn.LogSoftmax()
input = torch.randn(2, 3)
output = m(input)

5. nn.NLLLoss()

NLLLoss需要配合nn.LogSoftmax()使用,网络输出的值经过nn.LogSoftmax()后传入NLLLoss()。其标签实际上就是网络输出里对应数值的索引,通过这个索引取得对应的值并去掉负号。如果是一批次的话,还需要求均值。由于网络输出需先经过softmax,各个类别预测概率之和为1,所以类与类之间互斥,用于多类别问题。

loss = nn.NLLLoss()
m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
# input is of size N x C = 3 x 5
input = torch.randn(3, 5, requires_grad=True)
# each element in target has to have 0 <= value < C
target = torch.tensor([1, 0, 4])
output = loss(m(input), target)
output.backward()
# 2D loss example (used, for example, with image inputs)
N, C = 5, 4
loss = nn.NLLLoss()
# input is of size N x C x height x width
data = torch.randn(N, 16, 10, 10)
conv = nn.Conv2d(16, C, (3, 3))
m = nn.LogSoftmax(dim=1)
# each element in target has to have 0 <= value < C
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
output = loss(m(conv(data)), target)
output.backward()

6. nn.CrossEntropyLoss()

nn.CrossEntropyLoss()是nn.logSoftmax()和nn.NLLLoss()的整合,适用于多类别问题。

loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()

分类问题思维导图:

torch.nn里的损失函数:MSE、BCE、BCEWithLogits、NLLLoss、CrossEntropyLoss的用法相关推荐

  1. pytorch torch.nn.MSELoss(size_average=True)(均方误差【损失函数】)Mean Squared Error(MSE)、SSE(和方差)

    class torch.nn.MSELoss(size_average=True)[source] 创建一个衡量输入x(模型预测输出)和目标y之间均方误差标准. x 和 y 可以是任意形状,每个包含n ...

  2. 【MSE/BCE/CE】均方差、交叉熵损失函数理解

    文章目录 1 均方误差(Mean Squared Error, MSE) 1.1 MSE介绍 1.2 MSE为何不常用于分类 1.3 那什么常用于分类呢? 2 二值交叉熵损失(Binary Cross ...

  3. PyTorch里面的torch.nn.Parameter()

    在刷官方Tutorial的时候发现了一个用法self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)),看了官方教程里面的解释也是云里雾里, ...

  4. 深入理解Pytorch负对数似然函数(torch.nn.NLLLoss)和交叉熵损失函数(torch.nn.CrossEntropyLoss)

    在看Pytorch的交叉熵损失函数torch.nn.CrossEntropyLoss官方文档介绍中,给出的表达式如下.不免有点疑惑为何交叉熵损失的表达式是这个样子的 loss ⁡ ( y , clas ...

  5. Pytorch损失函数torch.nn.NLLLoss()详解

    在各种深度学习框架中,我们最常用的损失函数就是交叉熵(torch.nn.CrossEntropyLoss),熵是用来描述一个系统的混乱程度,通过交叉熵我们就能够确定预测数据与真是数据之间的相近程度.交 ...

  6. PyTorch中的torch.nn.Parameter() 详解

    PyTorch中的torch.nn.Parameter() 详解 今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实 ...

  7. pytorch分布式训练(二):torch.nn.parallel.DistributedDataParallel

      之前介绍了Pytorch的DataParallel方法来构建分布式训练模型,这种方法最简单但是并行加速效果很有限,并且只适用于单节点多gpu的硬件拓扑结构.除此之外Pytorch还提供了Distr ...

  8. pytorch使用torch.nn.Sequential构建网络

    以一个线性回归的例子为例: 全部代码 import torch import numpy as npdef get_x_y():x = np.random.randint(0, 50, 300)y_v ...

  9. Pytorch的自定义拓展:torch.nn.Module和torch.autograd.Function

    参考链接:pytorch的自定义拓展之(一)--torch.nn.Module和torch.autograd.Function_LoveMIss-Y的博客-CSDN博客_pytorch自定义backw ...

最新文章

  1. TypeScript 泛型
  2. 阿里云 centos 远程可视化桌面部署
  3. opencv视频模糊处理
  4. [JavaWeb-MySQL]约束(非空约束,唯一约束,主键约束,外键约束_级联操作)
  5. STM32之串口例程
  6. 1039 字符消除 java_Hihocoder 1039 字符消除
  7. Thrift之代码生成器Compiler原理及源码详细解析2
  8. 微软正式提供免费杀毒软件下载 仅限7.5万份
  9. 树算法系列之四:XGBoost
  10. 特种作业人员题库及答案
  11. linux下mysql命令大全_linux下mysql命令大全
  12. linux如何设置显示器亮度调节软件,为 Linux 启用色温和亮度调节工具
  13. python 图像检索系统_python-计算机视觉 - 图像检索
  14. VR球类游戏填坑总结
  15. Linux嵌入式开发入门(一)——初探嵌入式开发板的基本使用
  16. 计算机网络中tdm是什么,8.1 计算机网络FDM TDM计算机网络自学笔记.pdf
  17. 架构师之路---面向过程和面向对象 王泽宾
  18. 微软打印机驱动服务器,无法通过 Windows Server 中的 Windows 更新来安装打印机驱动程序 - Windows Server | Microsoft Docs...
  19. 新能源汽车补贴监管升级,“信息/网络安全+OTA”勒紧“紧箍咒”
  20. 数据库SQL语句 快速入门(一)

热门文章

  1. 【转载】鼻炎的中医治疗
  2. 在商场,自助收银比人工收银更欢迎?是为什么呢?
  3. “算法有偏见,比人强就行?”其实影响很广泛!
  4. 实例讲解用.NET技术将Excel表格中的数据导入到特定的SQL Server数据库中
  5. 留住员工的七个“秘诀”(zt)
  6. 计算机主机的认识500字,电脑的说明文500字
  7. Mac iOS Mac Watch 应用和游戏编程开发工具推荐
  8. Linux 应用编程之strerror函数
  9. 百度地图实现只展示某一个省份地图,点击市以后高亮
  10. NLPCamp-SpellCorrection