softmax,log_softmx,nll_loss和CELoss之间的关系
目录
0. 数据准备
1. 按公式
2. softmax + log + nll(one-hot)
3. 直接cross_entropy
4. 计算四维预测值损失
0. 数据准备
计算二维预测值的损失。
先生成pred.shape==(4,5). label.shape==(4,)
import torch.nn.functional as F
import torch# softmax + log: softmax输出范围是(0,1),log输出范围是(负无穷,0)
pred = torch.Tensor(4,5) # (4,5)
softmax_log_pred= F.log_softmax(pred, dim=1) # (4,5). softmax + log# one-hot
target = torch.Tensor([1, 4, 3, 2]) # (4,)
one_hot_target = F.one_hot(target.long()) # (4,5). 不同的数值被编码成不同的01序列
pred = torch.Tensor(4,5)
pred
tensor([[0., 0., 0., 0., 0.],[0., 0., 1., 1., 0.],[1., 3., 2., 3., 0.],[0., 0., 0., 0., 0.]])softmax_log_pred= F.log_softmax(pred, dim=1) # 先softmax,再log
softmax_log_pred
tensor([[-1.6094, -1.6094, -1.6094, -1.6094, -1.6094],[-2.1326, -2.1326, -1.1326, -1.1326, -2.1326],[-2.9373, -0.9373, -1.9373, -0.9373, -3.9373],[-1.6094, -1.6094, -1.6094, -1.6094, -1.6094]])target = torch.Tensor([1, 4, 3, 2])
target
tensor([1., 4., 3., 2.]) # (4,)
one_hot_target = F.one_hot(target.long()) # (4,5) 增加一个维度
one_hot_target
tensor([[0, 1, 0, 0, 0], # 对应lable 1[0, 0, 0, 0, 1], # 对应label 4[0, 0, 0, 1, 0], # label 3[0, 0, 1, 0, 0]])
1. 按公式
CELoss等价于先x=log_softmax(pred),再y=one_hot(target),再利用公式avg_loss = -sum(x*y)/n.
# batch_size=4, 有4个数据,除以4,求所有数据的平均分数,
# softmax_log_pred (4,5); one_hot_target (4,5)
res=-torch.sum(softmax_log_pred*one_hot_target)/target.shape[0] # tensor(1.6094)
2. softmax + log + nll(one-hot)
CELoss等价于log_softmax + nll. 相比较前面,nll内置了热编码操作。
# 第二种方式,softmax + log + nll(one-hot)
# softmax_log_pred (4,5); target (4,)
F.nll_loss(softmax_log_pred, target.long()) # tensor(1.6094) # 负的log likelihood损失函数
3. 直接cross_entropy
# pred (4,5); target (4,)
import torch.nn.functional as F
F.cross_entropy(pred, target.long()) # tensor(1.6094)
4. 计算四维预测值损失
pred = torch.Tensor(4,3,256,256) # shape: [4, 3, 256, 256]
softmax_log_pred= F.log_softmax(pred, dim=1) # shape: [4, 3, 256, 256]target = torch.randint(0, 3, (4, 256, 256))
target.shape
torch.Size([4, 256, 256])# 热编码
one_hot_target = F.one_hot(target.long()) # [4, 256, 256, 3]
one_hot_target = torch.transpose(one_hot_target, 1, 3) # [4, 3, 256, 256]# 1,公式
# 每个数据有3个分数,有4*256*256个数据,求平均分数。
# softmax_log_pred:[4, 3, 256, 256];one_hot_target:[4, 3, 256, 256]
res=-torch.sum(softmax_log_pred*one_hot_target)/
(target.shape[0]*target.shape[2].target.shape[3]) # tensor(1.0986)# 2, log_softmax + nll
# softmax_log_pred: [4, 3, 256, 256], target: [4, 256, 256]
F.nll_loss(softmax_log_pred, target)
tensor(1.0965)# 3,直接CELoss
# pred.shape: [4, 3, 256, 256]; target.shape: [4, 256, 256]
F.cross_entropy(pred, target.long())
tensor(1.0965)
softmax,log_softmx,nll_loss和CELoss之间的关系相关推荐
- 利用pytorch来深入理解CELoss、BCELoss和NLLLoss之间的关系
利用pytorch来深入理解CELoss.BCELoss和NLLLoss之间的关系 损失函数为为计算预测值与真实值之间差异的函数,损失函数越小,预测值与真实值间的差异越小,证明网络效果越好.对于神经网 ...
- 论文阅读课3-GraphRel: Modeling Text as Relational Graphs for(实体关系联合抽取,重叠关系,关系之间的关系,自动提取特征)
文章目录 abstract 1.Introduction 2.相关工作 3.回顾GCN 4.方法 4.1第一阶段 4.1.1 Bi-LSTM 4.1.2 Bi_GCN 4.1.3 实体关系抽取 4.2 ...
- 手搓GPT系列之 - Logistic Regression模型,Softmax模型的损失函数与CrossEntropyLoss的关系
笔者在学习各种分类模型和损失函数的时候发现了一个问题,类似于Logistic Regression模型和Softmax模型,目标函数都是根据最大似然公式推出来的,但是在使用pytorch进行编码的时候 ...
- Day-16 面向对象03 类与类之间的关系
一.类与类之间的依赖关系 我用着你,但是你不属于我,这种关系是最弱的,比如,公司和雇员之间,对于正式员工,肯定要签订劳动合同,还得小心伺候着,但是如果是兼职,那无所谓,需要了你就来,不需要你就可以拜拜 ...
- 【linux】图形界面基础知识(X、X11、GNOME、Xorg、KDE的概念和它们之间的关系)
转载自:https://blog.csdn.net/zhangxinrun/article/details/7332049 简介 LINUX初学者经常分不清楚linux和X之间,X和Xfree86之间 ...
- 嵌入式开发之信号采集同步---VSYNC和HSYNC的作用以及它们两者之间的关系
VSYNC和HSYNC的作用以及它们两者之间的关系 VSYNC和HSYNC的作用以及它们两者之间的关系 VSYNC和HSYNC是什么 VSYNC: vertical synchronization,指 ...
- 通过构建城市来解释HTML,CSS和JavaScript之间的关系
by Kevin Kononenko 凯文·科诺年科(Kevin Kononenko) 通过构建城市来解释HTML,CSS和JavaScript之间的关系 (The relationship betw ...
- Python中怎样改变集合之间的关系?
Python中怎样改变集合之间的关系?数学中,两个集合关系的常见操作包括:交集.并集.差集.补集.设A,B是两个集合,集合关系的操作介绍如下: 交集是指属于集合A且属于集合B的元素所组成的集合, 并集 ...
- 当支持向量机遇上神经网络:这项研究揭示了SVM、GAN、Wasserstein距离之间的关系...
选自arXiv 作者:Alexia Jolicoeur-Martineau 编辑:小舟.蛋酱 转载自公众号:机器之心 SVM 是机器学习领域的经典算法之一.如果将 SVM 推广到神经网络,会发生什么呢 ...
最新文章
- Android开发学习笔记:数据存取之SQLite浅析
- VC对话框禁止关闭按钮和禁止任务管理中关闭进程
- Tool之ADB:ADB工具的简介、安装、使用方法之详细攻略
- 运维更简单、更智能,让运维人不再 “拼命”
- 亲测!这本 Python 书销量超过13W+原来是这样
- Android+Jquery Mobile学习系列(3)-创建Android项目
- 判断操作系统多久没有任何操作.e
- 找到指针的奇数位置 c语言,(ppt)【C语言程序设计】上机作业2010.ppt
- 六个好用的程序员开发在线工具
- Pix4D生成正射影像记录
- win10右键一直转圈_win10 系统 桌面点右键经常转圈圈卡住。
- matlab 优化 小于,科学网—matlab全局优化与局部优化 - 张凌的博文
- 唐朝一体机屏幕显示变红
- ubuntu为软件设定图标
- php使用redis在windows下配置方法
- 组建无线网络的六条思路
- 在UE4中实现锥体下雨效果
- Mura缺陷检测【1】:svd
- 第16课:Spring Cloud 实例详解——基础框架搭建(三)
- 思科无边界ip电话配置方案
热门文章
- python猪肉价格预测_如果现在生猪期货上市,猪肉价格会下降吗?
- 谱估计(三)DFT与正交分解
- php识别脸型代码,PHP人脸识别为你的颜值打分
- 源码阅读神器Sourcetrail
- 软件开发环境、生产环境、测试环境的基本理解和区别
- 阿里云IP地址AS37963 CNNIC-ALIBABA-CN-NET-AP
- 附件上传到文件服务器,文件服务器 上传附件
- Python项目实战 3.2:验证码.短信验证码
- Scrapy执行crawl命令报错:ModuleNotFoundError: No module named 'win32api'
- 【转】理想主义者--理查德.马修.斯托曼(GNU的传奇)