直接计算CrossEntropy

import torch
import torch.nn.functional as F

先按照流程手动计算CrossEntropy

class_dim = 3
z = torch.Tensor([[3, 1, -3]])
z

tensor([[ 3., 1., -3.]])


softmax过程,图片来自这里

y = torch.nn.Softmax(dim=1)(z)
y

tensor([[0.8789, 0.1189, 0.0022]])

注意:交叉熵在信息论中log是以2为底,在pytorch中log是以e为底

计算CrossEntropy过程,图片来自这里

y_log = torch.log(y)
y_log

tensor([[-0.1291, -2.1291, -6.1291]])

y_hat = torch.tensor([1], dtype=int)
y_hat

tensor([1])

y_1_hot = torch.nn.functional.one_hot(y_hat, class_dim)
y_1_hot

tensor([[0, 1, 0]])

l = torch.tensor(0, dtype=torch.float32)
for y_log_, y_1_hot_ in zip(y_log, y_1_hot):l += torch.dot(-y_log_.to(torch.float32), y_1_hot_.to(torch.float32))
l = torch.div(l, len(y_log))
l

tensor(2.1291)

最终对多笔数据求的是平均交叉熵

NLLLoss

The negative log likelihood loss.

def nll_loss(y_log, y_hat):y_1_hot = torch.nn.functional.one_hot(y_hat, class_dim)l = torch.tensor(0, dtype=torch.float32)for y_log_, y_1_hot_ in zip(y_log, y_1_hot):l += torch.dot(-y_log_.to(torch.float32), y_1_hot_.to(torch.float32))l = torch.div(l, len(y_log))return l
F.nll_loss(y_log, y_hat)

tensor(2.1291)

nll_loss(y_log, y_hat)

tensor(2.1291)

可以发现,nll_loss输入为经过了softmax和log后值,nll_loss所做的操作就是对y_log取负号,然后对y_index进行one hot编码,最后取真实类标位置上的-y_log值,即体现在点乘上

CrossEntropyLoss

def cross_entropy(z, y_hat):y = torch.nn.Softmax(dim=1)(z)y_log = torch.log(y)return nll_loss(y_log, y_hat)
F.cross_entropy(z, y_hat)

tensor(2.1291)

cross_entropy(z, y_hat)

tensor(2.1291)

可以发现,cross_entropy输入是未经过softmax和log后值,cross_entropy所做的操作就是对y进行softmax和取log,最后对y_log进行nll_loss

往往数据不会只有1笔,由于计算过程已经对多笔数据支持,所以只需在输入数据上增加

class_dim = 3
data_num = 5
z = torch.randn(data_num, class_dim)
y_hat = torch.ones(data_num, dtype=int).random_(class_dim)
z, y_hat

(tensor([[ 0.2319, 0.2875, -0.2994],
[ 0.7351, -1.3286, -0.4470],
[ 0.9836, -0.5633, -0.3552],
[ 0.8043, -0.3892, 1.2848],
[ 0.9196, 2.2589, -1.3184]]),
tensor([2, 0, 2, 1, 1]))

cross_entropy(z, y_hat)

tensor(1.2223)

F.cross_entropy(z, y_hat)

tensor(1.2223)

CrossEntropy

上面计算过程有个疑问,就是交叉熵为什么要这么求?

首先需要理解交叉熵含义,可以查看这个知乎回答加以理解,简单来说,交叉熵损失核心是熵的计算。所以有取log运算和取负号,而−lnx-lnx−lnx函数长这样:

可以发现经过softmax函数后的值域为[0,1],很好的满足了−lnx-lnx−lnx取值范围。上图也很好体现了信息量关系:概率越小,信息量越大。而信息熵是同分布下信息量在其概率下的期望。

交叉熵求的是非真实分布的信息量在真实分布概率下期望:
∑k=1Npklog⁡21qk\sum_{k=1}^{N} p_{k} \log _{2} \frac{1}{q_{k}}k=1∑N​pk​log2​qk​1​
或者是机器学习中常用表示(e为低,取负号)
∑k=1N−pkln⁡qk\sum_{k=1}^{N} -p_{k} \ln {q_{k}}k=1∑N​−pk​lnqk​
其中pkp_{k}pk​ 表示真实分布, qk\quad q_{k}qk​ 表示非真实分布(预测分布)。
交叉熵在非真实分布与真实分布一样时取得最小。怎么样更好解释这个结论暂时没找到。

那么为什么要求交叉熵最小呢?
实践上我们求的相对熵(KL散度)要最小,即真实分布和非真实分布差异,只是其中的真实分布的信息熵我们已经知道,求相对熵最小即求交叉熵最小。

pytorch中的NLLLoss和CrossEntropy相关推荐

  1. Pytorch中的NLLLoss代码解释

    在分类以及语义分割任务中,CrossEntropy是十分常用的一个损失函数,pytorch也对其进行了实现用于直接使用. 但本人在阅读其源码时,发现nn.CrossEropyLoss并不是直接按照交叉 ...

  2. pytorch中CrossEntropyLoss和NLLLoss的区别与联系

    pytorch中CrossEntropyLoss和NLLLoss的区别与联系 CrossEntropyLoss和NLLLoss主要是用在多分类问题的损失函数,他们两个既有不同,也有不浅的联系.先分别看 ...

  3. pytorch中实现Balanced Cross-Entropy

    当你明白了pytorch中F.cross_entropy以及F.binary_cross_entropy是如何实现的之后,你再基于它们做改进重新实现一个损失函数就很容易了. 1.背景 变化检测中,往往 ...

  4. Pytorch中的分类损失函数比较NLLLoss与CrossEntropyLoss

    参考来源,仅作为学习笔记 二分类 对于一个二分类问题,比如我们有一个样本,有两个不同的模型对他进行分类,那么它们的输出都应该是一个二维向量,比如: 模型一的输出为:pred_y1=[0.8,0.2] ...

  5. 机器学习花朵图像分类_在PyTorch中使用转移学习进行图像分类

    想了解更多好玩的人工智能应用,请关注公众号"机器AI学习 数据AI挖掘","智能应用"菜单中包括:颜值检测.植物花卉识别.文字识别.人脸美妆等有趣的智能应用.. ...

  6. Pytorch中的梯度知识总结

    文章目录 1.叶节点.中间节点.梯度计算 2.叶子张量 leaf tensor (叶子节点) (detach) 2.1 为什么需要叶子节点? 2.2 detach()将节点剥离成叶子节点 2.3 什么 ...

  7. 损失函数-负对数似然和交叉熵(Pytorch中的应用)

    文章目录 1.负对数似然损失函数 1.1.似然 1.2.似然函数 1.3.极大似然估计 1.4.对数似然 1.5.负对数似然 1.6.pytorch中的应用 2.交叉熵损失函数 2.1.信息量 2.2 ...

  8. torch中的NLLLoss与CrossEntropyLoss

    0.先搞清楚几种分类问题 图片分类问题中通常一张图片中同时有多个目标,所以会有多个标签,通常分类问题可以如下划分 表格1 样本单标签 样本多标签 类别数量=2 简单二分类 当作多个二分类 类别数量&g ...

  9. pytorch中的二分类及多分类交叉熵损失函数

    本文主要记录一下pytorch里面的二分类及多分类交叉熵损失函数的使用. import torch import torch.nn as nn import torch.nn.functional a ...

最新文章

  1. 数学图形(1.40)T_parameter
  2. 顺序查找计时函数C语言,用C语言编二分查找
  3. FreeBSD portupgrade升级你的FreeBSD软件[zt]
  4. scala学习-Linux命令行运行jar包传入main方法参数
  5. 输电线路巡检机器人PPT_国网泰安供电公司开展输电线路无人机精细化巡检
  6. FreeSwitch之拨号计划~简单例子(二)
  7. H5唤起APP客户端
  8. Typora免费版(Typora最后一个版本下载)
  9. 手动设置ip 访问内网地址
  10. Unity鼠标控制相机上下左右环视360度旋转(Quaternion.AngleAxis)
  11. 树莓派正式开售CM4以及CM4 Lite,32个不同配置,最低25美元起售!
  12. 设置chrome浏览器在一个标签页中打开链接自动跳转到新标签页
  13. 太过伤心,小王被这 10 道 Java 面试题虐哭了
  14. Ardupilot 绕圈模式分析
  15. 2023元旦倒计时代码
  16. C语言练习,利用求阶乘函数Fact(),编程计算并输出从1到n之间所有数的阶乘值。
  17. 树莓派4B连接KY008激光头
  18. Elasticsearch实战(四)---中英文分词及拼音搜索
  19. g4600黑苹果efi_黑苹果硬盘引导的两种方式
  20. CodeVS2495 水叮当的舞步

热门文章

  1. 麒麟处理器是基于arm的吗_为何华为海思麒麟处理器、高通、联发科等都要用到ARM的构架?...
  2. java毕业设计开题报告基于SSM考试在线报名管理系统
  3. mysql too many_Mysql错误:Too many connections的解决方法
  4. 单机服务器模型,reactor的5种实现方式,单线程、多线程、多核、多进程的实现
  5. Adobe终于放大招联网玩Cloud技术了么?
  6. 4、Macbook2015 A1502 笔记本的换屏过程
  7. 如何将音频转换成mp3?
  8. java 两个点球面距离_计算球面两点间距离实现Vincenty+Haversine
  9. 连发数枪,海尔离“世界中心”仅一步之遥
  10. 下载大文件报SocketTimeoutException