背景

mse均方误差、mae绝对值平均误差用于拟合回归,公式已经熟悉了,但交叉熵的每次都只是应用,没有了解公式,这对于自己写交叉熵损失函数以及分析损失函数不利。

公式详解

C是损失值;
n是求平均用的,所以是样本数量,也就是batchsize;
x是预测向量维度,因为需要在输出的特征向量维度上一个个计算并求和;
y是onehot编码后的真实值 对应x维度上的标签,是1或0;
a是onehot格式输出的预测标签,是0~1的值,a经过了softmax激活,所以a的和值为1

对于某个维度xix_ixi​,y=1y=1y=1时a越大越好,相反a越小越好,C值为两者和的负值,所以越好→C↓ 所以可以最优化C(神经网络需要最小化损失函数)

公式计算举例

C((0.8,0.1,0.1),(1,0,0))=−1/1∗(1ln(0.8)+0+0+1ln(0.9)+0+1ln(0.9))C((0.8,0.1,0.1),(1,0,0))=-1/1*(1ln(0.8)+0+0+1ln(0.9)+0+1ln(0.9))C((0.8,0.1,0.1),(1,0,0))=−1/1∗(1ln(0.8)+0+0+1ln(0.9)+0+1ln(0.9))

公式编程实现

计算举例是为了理解计算过程,最终还是要落到编程实现上:

def cross_entropy(y_true,y_pred):C=0# one-hot encodingfor col in range(y_true.shape[-1]):y_pred[col] = y_pred[col] if y_pred[col] < 1 else 0.99999y_pred[col] = y_pred[col] if y_pred[col] > 0 else 0.00001C+=y_true[col]*np.log(y_pred[col])+(1-y_true[col])*np.log(1-y_pred[col])return -C# 没有考虑样本个数 默认=1
num_classes = 3
label=1#设定是哪个类别 真实值y_true = np.zeros((num_classes))
# y_pred = np.zeros((num_classes))
# preset
y_true[label]=1
y_pred = np.array([0.0,1.0,0.0])
C = cross_entropy(y_true,y_pred)
print(y_true,y_pred,"loss:",C)
y_pred = np.array([0.1,0.8,0.1])
C = cross_entropy(y_true,y_pred)
print(y_true,y_pred,"loss:",C)
y_pred = np.array([0.2,0.6,0.2])
C = cross_entropy(y_true,y_pred)
print(y_true,y_pred,"loss:",C)
y_pred = np.array([0.3,0.4,0.3])
C = cross_entropy(y_true,y_pred)
print(y_true,y_pred,"loss:",C)

执行结果:

[0. 1. 0.] [1.0000e-05 9.9999e-01 1.0000e-05] loss: 3.0000150000863473e-05
[0. 1. 0.] [0.1 0.8 0.1] loss: 0.43386458262986227
[0. 1. 0.] [0.2 0.6 0.2] loss: 0.9571127263944101
[0. 1. 0.] [0.3 0.4 0.3] loss: 1.62964061975162Process finished with exit code 0

结论

  1. 分类任务神经网络的输出层往往是经过了softmax激活,所以最后一层输出的预测向量各个维度的值均为0~1范围内的数。
  2. python的引用类型,如array,在函数传参后是传的引用,所以函数内的修改会影响到实际的值,通过控制台输出的第一条信息即可知。
  3. 计算过程可能出现溢出NaN的报错,所以需要进行近似处理。

【神经网络笔记】——多分类交叉熵损失函数公式及代码实现相关推荐

  1. 交叉熵损失函数分类_PyTorch学习笔记——多分类交叉熵损失函数

    理解交叉熵 关于样本集的两个概率分布p和q,设p为真实的分布,比如[1, 0, 0]表示当前样本属于第一类,q为拟合的分布,比如[0.7, 0.2, 0.1]. 按照真实分布p来衡量识别一个样本所需的 ...

  2. LESSON 10.110.210.3 SSE与二分类交叉熵损失函数二分类交叉熵损失函数的pytorch实现多分类交叉熵损失函数

    在之前的课程中,我们已经完成了从0建立深层神经网络,并完成正向传播的全过程.本节课开始,我们将以分类深层神经网络为例,为大家展示神经网络的学习和训练过程.在介绍PyTorch的基本工具AutoGrad ...

  3. pytorch_lesson10 二分类交叉熵损失函数及调用+多分类交叉熵损失函数及调用

    注:仅仅是学习记录笔记,搬运了学习课程的ppt内容,本意不是抄袭!望大家不要误解!纯属学习记录笔记!!!!!! 文章目录 一.机器学习中的优化思想 二.回归:误差平方和SSE 三.二分类交叉熵损失函数 ...

  4. pytorch中的二分类及多分类交叉熵损失函数

    本文主要记录一下pytorch里面的二分类及多分类交叉熵损失函数的使用. import torch import torch.nn as nn import torch.nn.functional a ...

  5. 交叉熵损失函数公式_交叉熵损失函数对其参数求导

    1.Sigmoid 二分类交叉熵 交叉熵公式: 其中y是laebl:0 或1. hθ(xi)是经过sigmoid得到的预测概率.θ为网络的参数, m为样本数. hθ()函数如下所示, J(θ) 对参数 ...

  6. 二分类交叉熵损失函数python_【深度学习基础】第二课:softmax分类器和交叉熵损失函数...

    [深度学习基础]系列博客为学习Coursera上吴恩达深度学习课程所做的课程笔记. 本文为原创文章,未经本人允许,禁止转载.转载请注明出处. 1.线性分类 如果我们使用一个线性分类器去进行图像分类该怎 ...

  7. 二分类交叉熵损失函数python_二分类问题的交叉熵损失函数多分类的问题的函数交叉熵损失函数求解...

    二分类问题的交叉熵损失函数; 在二分类问题中,损失函数为交叉熵损失函数.对于样本(x,y)来讲,x为样本 y为对应的标签.在二分类问题中,其取值的集合可能为{0,1},我们假设某个样本的真实标签为yt ...

  8. AI学习[随堂笔记1109]_交叉熵损失函数_方差损失函数_基础

    交叉熵损失函数 一种用于分类问题1的损失函数2,原理为:将模型输出的概率,与标准答案3的值对比. 和正确答案越接近,则计算结果:交叉熵越低,模型质量越好 和错误答案越接近,则交叉熵越大,模型质量越差 ...

  9. 多分类交叉熵损失函数的梯度计算过程推导

    Softmax函数公式: Si 代表的是第i个神经元的输出 其中wij 是第i个神经元的第 j 个权重,b是偏移值.zi 表示该网络的第i个输出 隐藏层输出经过softmax: 具体过程如下图所示: ...

最新文章

  1. C++:名字空间的使用
  2. C++和python先学哪个
  3. 数据结构与算法分析资源总结
  4. 算法模板-对称性递归
  5. ES6-2 块级作用域与嵌套、let、暂行性死区
  6. 【c# 学习笔记】所有类的父类:System.object
  7. python坐标定位_如何利用Python识别并定位图片中某一个色块的坐标?
  8. 实战 | WebMagic 爬取某保险经纪人网站经纪人列表之网站列表爬取
  9. shell中source的作用
  10. linux实时信号的优势,c/c++开发分享Linux和实时信号
  11. 麻省理工学生令计算机系统升级不需重启
  12. entity framework 调用 oracle 序列_Weblogic T3 反序列化漏洞(CVE20192890 )分析
  13. 活动推荐 | 首届云原生编程挑战赛开始报名啦~
  14. SLAM学习笔记-------------(九)后端1
  15. 计算机组成原理微指令课程设计,计算机组成原理课程设计(微程序设计) New.doc...
  16. SACD ISO提取DSF文件及添加封面
  17. 来自和府捞面的信任,一起见证「客户的成功就是璞华的成功」
  18. VCRedist.exe静默安装方法(转)
  19. dfs 访问拒绝_DFS文件夹无法访问
  20. 电容笔和Apple pencil的区别?适合ipad画画的电容笔推荐

热门文章

  1. MasteringOpenCV实战源码学习笔记 章节一
  2. 数云原力大会 李扬:数据资产的首要问题是确定权益
  3. c语言分号怎么打,问什么C程序里总是提示缺少分号;,而明明有分号?
  4. 分支覆盖率 代码覆盖率_100%代码覆盖率神话
  5. 丝雨学姐小灶班——Week 3
  6. 多任务学习模型MTL: MMoE、PLE
  7. 【无人机路径规划】基于非线性最小二乘和DOP实现无人机的路径规划问题附matlab代码
  8. 本地html在线打包apk,HTML一键打包APK工具使用及配置方法
  9. 清理系统垃圾(windows批处理脚本)
  10. JavaScript 赋值运算符 、运算符 优先级