目录

0. 数据准备

1. 按公式

2. softmax + log + nll(one-hot)

3. 直接cross_entropy

4. 计算四维预测值损失


0. 数据准备

计算二维预测值的损失。

先生成pred.shape==(4,5). label.shape==(4,)

import torch.nn.functional as F
import torch# softmax + log: softmax输出范围是(0,1),log输出范围是(负无穷,0)
pred = torch.Tensor(4,5)  # (4,5)
softmax_log_pred= F.log_softmax(pred, dim=1)  # (4,5). softmax + log# one-hot
target = torch.Tensor([1, 4, 3, 2])  # (4,)
one_hot_target = F.one_hot(target.long())  # (4,5). 不同的数值被编码成不同的01序列
pred = torch.Tensor(4,5)
pred
tensor([[0., 0., 0., 0., 0.],[0., 0., 1., 1., 0.],[1., 3., 2., 3., 0.],[0., 0., 0., 0., 0.]])softmax_log_pred= F.log_softmax(pred, dim=1)  # 先softmax,再log
softmax_log_pred
tensor([[-1.6094, -1.6094, -1.6094, -1.6094, -1.6094],[-2.1326, -2.1326, -1.1326, -1.1326, -2.1326],[-2.9373, -0.9373, -1.9373, -0.9373, -3.9373],[-1.6094, -1.6094, -1.6094, -1.6094, -1.6094]])target = torch.Tensor([1, 4, 3, 2])
target
tensor([1., 4., 3., 2.])  # (4,)
one_hot_target = F.one_hot(target.long())  # (4,5) 增加一个维度
one_hot_target
tensor([[0, 1, 0, 0, 0],  # 对应lable 1[0, 0, 0, 0, 1],  # 对应label 4[0, 0, 0, 1, 0],  # label 3[0, 0, 1, 0, 0]])

1. 按公式

CELoss等价于先x=log_softmax(pred),再y=one_hot(target),再利用公式avg_loss = -sum(x*y)/n.

# batch_size=4, 有4个数据,除以4,求所有数据的平均分数,
# softmax_log_pred (4,5); one_hot_target (4,5)
res=-torch.sum(softmax_log_pred*one_hot_target)/target.shape[0]  # tensor(1.6094)

2. softmax + log + nll(one-hot)

CELoss等价于log_softmax + nll. 相比较前面,nll内置了热编码操作。

# 第二种方式,softmax + log + nll(one-hot)
# softmax_log_pred (4,5); target (4,)
F.nll_loss(softmax_log_pred, target.long())  # tensor(1.6094)  # 负的log likelihood损失函数

3. 直接cross_entropy

# pred (4,5); target (4,)
import torch.nn.functional as F
F.cross_entropy(pred, target.long())  # tensor(1.6094)

4. 计算四维预测值损失

pred = torch.Tensor(4,3,256,256)  # shape: [4, 3, 256, 256]
softmax_log_pred= F.log_softmax(pred, dim=1)  # shape: [4, 3, 256, 256]target = torch.randint(0, 3, (4, 256, 256))
target.shape
torch.Size([4, 256, 256])# 热编码
one_hot_target = F.one_hot(target.long())  # [4, 256, 256, 3]
one_hot_target = torch.transpose(one_hot_target, 1, 3)  # [4, 3, 256, 256]# 1,公式
# 每个数据有3个分数,有4*256*256个数据,求平均分数。
# softmax_log_pred:[4, 3, 256, 256];one_hot_target:[4, 3, 256, 256]
res=-torch.sum(softmax_log_pred*one_hot_target)/
(target.shape[0]*target.shape[2].target.shape[3])  # tensor(1.0986)# 2, log_softmax + nll
# softmax_log_pred: [4, 3, 256, 256], target: [4, 256, 256]
F.nll_loss(softmax_log_pred, target)
tensor(1.0965)# 3,直接CELoss
# pred.shape: [4, 3, 256, 256]; target.shape: [4, 256, 256]
F.cross_entropy(pred, target.long())
tensor(1.0965)

softmax,log_softmx,nll_loss和CELoss之间的关系相关推荐

  1. 利用pytorch来深入理解CELoss、BCELoss和NLLLoss之间的关系

    利用pytorch来深入理解CELoss.BCELoss和NLLLoss之间的关系 损失函数为为计算预测值与真实值之间差异的函数,损失函数越小,预测值与真实值间的差异越小,证明网络效果越好.对于神经网 ...

  2. 论文阅读课3-GraphRel: Modeling Text as Relational Graphs for(实体关系联合抽取,重叠关系,关系之间的关系,自动提取特征)

    文章目录 abstract 1.Introduction 2.相关工作 3.回顾GCN 4.方法 4.1第一阶段 4.1.1 Bi-LSTM 4.1.2 Bi_GCN 4.1.3 实体关系抽取 4.2 ...

  3. 手搓GPT系列之 - Logistic Regression模型,Softmax模型的损失函数与CrossEntropyLoss的关系

    笔者在学习各种分类模型和损失函数的时候发现了一个问题,类似于Logistic Regression模型和Softmax模型,目标函数都是根据最大似然公式推出来的,但是在使用pytorch进行编码的时候 ...

  4. Day-16 面向对象03 类与类之间的关系

    一.类与类之间的依赖关系 我用着你,但是你不属于我,这种关系是最弱的,比如,公司和雇员之间,对于正式员工,肯定要签订劳动合同,还得小心伺候着,但是如果是兼职,那无所谓,需要了你就来,不需要你就可以拜拜 ...

  5. 【linux】图形界面基础知识(X、X11、GNOME、Xorg、KDE的概念和它们之间的关系)

    转载自:https://blog.csdn.net/zhangxinrun/article/details/7332049 简介 LINUX初学者经常分不清楚linux和X之间,X和Xfree86之间 ...

  6. 嵌入式开发之信号采集同步---VSYNC和HSYNC的作用以及它们两者之间的关系

    VSYNC和HSYNC的作用以及它们两者之间的关系 VSYNC和HSYNC的作用以及它们两者之间的关系 VSYNC和HSYNC是什么 VSYNC: vertical synchronization,指 ...

  7. 通过构建城市来解释HTML,CSS和JavaScript之间的关系

    by Kevin Kononenko 凯文·科诺年科(Kevin Kononenko) 通过构建城市来解释HTML,CSS和JavaScript之间的关系 (The relationship betw ...

  8. Python中怎样改变集合之间的关系?

    Python中怎样改变集合之间的关系?数学中,两个集合关系的常见操作包括:交集.并集.差集.补集.设A,B是两个集合,集合关系的操作介绍如下: 交集是指属于集合A且属于集合B的元素所组成的集合, 并集 ...

  9. 当支持向量机遇上神经网络:这项研究揭示了SVM、GAN、Wasserstein距离之间的关系...

    选自arXiv 作者:Alexia Jolicoeur-Martineau 编辑:小舟.蛋酱 转载自公众号:机器之心 SVM 是机器学习领域的经典算法之一.如果将 SVM 推广到神经网络,会发生什么呢 ...

最新文章

  1. Android开发学习笔记:数据存取之SQLite浅析
  2. VC对话框禁止关闭按钮和禁止任务管理中关闭进程
  3. Tool之ADB:ADB工具的简介、安装、使用方法之详细攻略
  4. 运维更简单、更智能,让运维人不再 “拼命”
  5. 亲测!这本 Python 书销量超过13W+原来是这样
  6. Android+Jquery Mobile学习系列(3)-创建Android项目
  7. 判断操作系统多久没有任何操作.e
  8. 找到指针的奇数位置 c语言,(ppt)【C语言程序设计】上机作业2010.ppt
  9. 六个好用的程序员开发在线工具
  10. Pix4D生成正射影像记录
  11. win10右键一直转圈_win10 系统 桌面点右键经常转圈圈卡住。
  12. matlab 优化 小于,科学网—matlab全局优化与局部优化 - 张凌的博文
  13. 唐朝一体机屏幕显示变红
  14. ubuntu为软件设定图标
  15. php使用redis在windows下配置方法
  16. 组建无线网络的六条思路
  17. 在UE4中实现锥体下雨效果
  18. Mura缺陷检测【1】:svd
  19. 第16课:Spring Cloud 实例详解——基础框架搭建(三)
  20. 思科无边界ip电话配置方案

热门文章

  1. python猪肉价格预测_如果现在生猪期货上市,猪肉价格会下降吗?
  2. 谱估计(三)DFT与正交分解
  3. php识别脸型代码,PHP人脸识别为你的颜值打分
  4. 源码阅读神器Sourcetrail
  5. 软件开发环境、生产环境、测试环境的基本理解和区别
  6. 阿里云IP地址AS37963 CNNIC-ALIBABA-CN-NET-AP
  7. 附件上传到文件服务器,文件服务器 上传附件
  8. Python项目实战 3.2:验证码.短信验证码
  9. Scrapy执行crawl命令报错:ModuleNotFoundError: No module named 'win32api'
  10. 【转】理想主义者--理查德.马修.斯托曼(GNU的传奇)