多分类问题的交叉熵

  在多分类问题中,损失函数(loss function)为交叉熵(cross entropy)损失函数。对于样本点(x,y)来说,y是真实的标签,在多分类问题中,其取值只可能为标签集合labels. 我们假设有K个标签值,且第i个样本预测为第k个标签值的概率为pi,kpi,kp_{i,k}, 即pi,k=Pr(ti,k=1)pi,k=Pr⁡(ti,k=1)p_{i,k} = \operatorname{Pr}(t_{i,k} = 1), 一共有N个样本,则该数据集的损失函数为

Llog(Y,P)=−logPr(Y|P)=−1N∑i=0N−1∑k=0K−1yi,klogpi,kLlog(Y,P)=−log⁡Pr⁡(Y|P)=−1N∑i=0N−1∑k=0K−1yi,klog⁡pi,k

L_{\log}(Y, P) = -\log \operatorname{Pr}(Y|P) = - \frac{1}{N} \sum_{i=0}^{N-1} \sum_{k=0}^{K-1} y_{i,k} \log p_{i,k}

一个例子

  在Python的sklearn模块中,提供了一个函数log_loss()来计算多分类问题的交叉熵。再根据我们在博客Sklearn中二分类问题的交叉熵计算对log_loss()函数的源代码的分析,我们不难利用上面的计算公式用自己的方法来实现交叉熵的求值。
  我们给出的例子如下:

y_true = ['1', '4', '5'] # 样本的真实标签
y_pred = [[0.1, 0.6, 0.3, 0, 0, 0, 0, 0, 0, 0],[0, 0.3, 0.2, 0, 0.5, 0, 0, 0, 0, 0],[0.6, 0.3, 0, 0, 0, 0.1, 0, 0, 0, 0]]               # 样本的预测概率
labels = ['0','1','2','3','4','5','6','7','8','9'] # 所有标签

在这个例子中,一个有3个样本,标签为1,4,5,一共是10个标签,y_pred是对每个样本的所有标签的预测值。
  接下来我们将会用log_loss()函数和自己的方法分别来实现这个例子的交叉熵的计算,完整的Python代码如下:

from sklearn.metrics import log_loss
from sklearn.preprocessing import LabelBinarizer
from math import logy_true = ['1', '4', '5'] # 样本的真实标签
y_pred = [[0.1, 0.6, 0.3, 0, 0, 0, 0, 0, 0, 0],[0, 0.3, 0.2, 0, 0.5, 0, 0, 0, 0, 0],[0.6, 0.3, 0, 0, 0, 0.1, 0, 0, 0, 0]]               # 样本的预测概率
labels = ['0','1','2','3','4','5','6','7','8','9'] # 所有标签# 利用sklearn中的log_loss()函数计算交叉熵
sk_log_loss = log_loss(y_true, y_pred, labels=labels)
print("Loss by sklearn is:%s." %sk_log_loss)# 利用公式实现交叉熵
# 交叉熵的计算公式网址为:
# http://scikit-learn.org/stable/modules/model_evaluation.html#log-loss# 对样本的真实标签进行标签二值化
lb = LabelBinarizer()
lb.fit(labels)
transformed_labels = lb.transform(y_true)
# print(transformed_labels)N = len(y_true)  # 样本个数
K = len(labels)  # 标签个数eps = 1e-15      # 预测概率的控制值
Loss = 0         # 损失值初始化for i in range(N):for k in range(K):# 控制预测概率在[eps, 1-eps]内,避免求对数时出现问题if y_pred[i][k] < eps:y_pred[i][k] = epsif y_pred[i][k] > 1-eps:y_pred[i][k] = 1-eps# 多分类问题的交叉熵计算公式Loss -= transformed_labels[i][k]*log(y_pred[i][k])Loss /= N
print("Loss by equation is:%s." % Loss)

输出的结果如下:

Loss by sklearn is:1.16885263244.
Loss by equation is:1.16885263244.

这说明我们能够用公式来自己实现交叉熵的计算了,是不是很神奇呢?
  多分类问题的交叉熵计算是建立在二分类问题的交叉熵计算的基础上,有了我们对log_loss()函数的源代码的研究,那就用自己的方法来实现多(二)分类问题的交叉熵计算就不是问题了~~
  本次分享到此结束,欢迎大家交流~~

注意:本人现已开通两个微信公众号: 因为Python(微信号为:python_math)以及轻松学会Python爬虫(微信号为:easy_web_scrape), 欢迎大家关注哦~~

多分类问题的交叉熵计算相关推荐

  1. 交叉熵损失函数分类_逻辑回归(Logistic Regression)二分类原理,交叉熵损失函数及python numpy实现...

    本文目录: 1. sigmoid function (logistic function) 2. 逻辑回归二分类模型 3. 神经网络做二分类问题 4. python实现神经网络做二分类问题 ----- ...

  2. 二分类交叉熵损失函数python_二分类问题的交叉熵损失函数多分类的问题的函数交叉熵损失函数求解...

    二分类问题的交叉熵损失函数; 在二分类问题中,损失函数为交叉熵损失函数.对于样本(x,y)来讲,x为样本 y为对应的标签.在二分类问题中,其取值的集合可能为{0,1},我们假设某个样本的真实标签为yt ...

  3. 均方误差越大越好_直观理解为什么分类问题用交叉熵损失而不用均方误差损失?...

    交叉熵损失与均方误差损失 常规分类网络最后的softmax层如下图所示,传统机器学习方法以此类比, 一共有\(K\)类,令网络的输出为\([\hat{y}_1,\dots, \hat{y}_K]\), ...

  4. 机器学习中交叉熵cross entropy是什么,怎么计算?

    项目背景:人体动作识别(分类),CNN或者RNN网络,softmax分类输出,输出为one-hot型标签. loss可以理解为预测输出pred与实际输出Y之间的差距,其中pred和Y均为one-hot ...

  5. 分类交叉熵Cross-Entropy

    一.简介 在二分类问题中,你可以根据神经网络节点的输出,通过一个激活函数如Sigmoid,将其转换为属于某一类的概率,为了给出具体的分类结果,你可以取0.5作为阈值,凡是大于0.5的样本被认为是正类, ...

  6. 交叉熵损失函数分类_交叉熵损失函数

    我们先从逻辑回归的角度推导一下交叉熵(cross entropy)损失函数. 从逻辑回归到交叉熵损失函数 这部分参考自 cs229-note1 part2. 为了根据给定的 预测 (0或1),令假设函 ...

  7. 一文了解机器学习中的交叉熵

    https://www.toutiao.com/a6654435108105224712/ 2019-02-05 16:46:49 熵 在介绍交叉熵之前首先介绍熵(entropy)的概念.熵是信息论中 ...

  8. 神经网络学习中的SoftMax与交叉熵

    简 介: 对于在深度学习中的两个常见的函数SoftMax,交叉熵进行的探讨.在利用paddle平台中的反向求微分进行验证的过程中,发现结果 与数学定义有差别.具体原因还需要之后进行查找. 关键词: 交 ...

  9. 深度学习中softmax交叉熵损失函数的理解

    1. softmax层的作用 通过神经网络解决多分类问题时,最常用的一种方式就是在最后一层设置n个输出节点,无论在浅层神经网络还是在CNN中都是如此,比如,在AlexNet中最后的输出层有1000个节 ...

最新文章

  1. 一位海外华人的质问:谁在误导中国人艳羡美国?
  2. ORACLE导入TXT文件数据的解决思路
  3. C.【转】C语言字符串与数字相互转换
  4. 对着爬虫网页HTML学习Python正则表达式re
  5. 面试题之Java内存区域
  6. 一加9 Pro渲染图曝光:6.55英寸曲面屏 左上角打孔
  7. 苹果传出放弃研发自动驾驶,因iPhone销量不佳收紧支出
  8. 高级着色语言HLSL入门(7)
  9. 谈薪资被 HR 怼了:估计你一辈子就是个程序员!气不过啊。。。
  10. 永磁同步电机dq坐标系中转矩公式中系数3/2的由来
  11. Go语言 —— 前景
  12. Obsidian关系图谱如何让节点可以手动拖动
  13. opencv半透明填充不规则区域
  14. 信息系统安全 总结提纲
  15. python 跳过_python怎么跳过异常继续执行
  16. 【Captain America Sentinel of Liberty HD】美国队长:自由哨兵 v1.0.2
  17. 千亿级宠物赛道,卖蚊香的朝云能“掘金”多少?
  18. YII2 数据库常用操作案例
  19. HackPwn:TCL智能洗衣机破解细节分析
  20. 一个人的思想:漫谈技术社区

热门文章

  1. STlink、Jlink驱动一直安装失败的解决办法
  2. 语言特征与模式- λ演算
  3. pytorch使用visdom可视化loss
  4. 学术英语 | (8) WordList7
  5. 人工智能学习(一)newff函数介绍
  6. 脑疾病患者福音,又一家脑机接口公司完成首次人体试验
  7. 两个主机之间如何通信
  8. 很多TAP-Windows adapter V9#
  9. 我国超级计算机历代,神威计算机图片_神威太湖之光_神威计祘机
  10. 新塘系列linux_【一点资讯】重磅!增城这6条村又被广州点名!涉及中新、正果、小楼、新塘… www.yidianzixun.com...