0 前言

Softmax在机器学习中应用非常广泛,尤其在处理多分类问题,分类器最后的输出单元需要Softmax 函数进行数值处理。但是刚刚接触机器学习的同学可能对Softmax的特点及好处并不理解,当你了解以后会发现,Softmax计算简单,效果显著。

我们先来直观看一下,Softmax究竟是什么意思:我们知道max,假如说我有两个数,a和b,并且a>b,如果取max,那么就直接取a,没有第二种可能。然而,有的时候我们并不想这样,因为这样会造成分值小的那个饥饿。所以我希望分值大的那一项经常取到,分值小的那一项偶尔也可以取到,那么用softmax就可以实现这一操作。

现在还是a和b,a>b,如果我们取按照softmax来计算取a和b的概率,那a的softmax值大于b的,所以a会经常被取到,而b也会偶尔取到,概率跟它们本来的大小有关。所以说不是max,而是 Soft max 那各自的概率究竟是多少呢,这里我们引入softmax的概念及表达式。

1 Softmax 表达式

关于Softmax 函数的定义如下所示:

其中,Vi 是分类器线性输出单元的输出。i表示类别索引,总的类别个数为C。Si表示当前元素的指数与所有元素指数和的比值。Softmax 将多分类的输出数值转化为相对概率,更容易理解和比较。我们来看下面这个例子。

一个多分类问题,C = 4。线性分类器模型最后输出层包含了四个输出值,分别是:

经过Softmax处理后,数值转化为相对概率:

很明显,Softmax的输出表征了不同类别之间的相对概率。可以明显看出,S1 = 0.8390,对应的概率最大,因此,判断预测为第1类的可能性更大。Softmax 将连续数值转化成相对概率,更有利于理解。

实际应用中,使用 Softmax 需要注意数值溢出的问题。因为有指数运算,如果V值很大,经过指数运算后的数值有溢出的可能。所以,需要对V做一些数值处理:即V中的每个元素减去V中的最大值。

相应的python示例代码如下:

scores = np.array([123, 456, 789])    # example with 3 classes and each having large scoresscores -= np.max(scores)    # scores becomes [-666, -333, 0]p = np.exp(scores) / np.sum(np.exp(scores))

2 Softmax 损失函数

我们知道,线性分类器的输出是输入x与权重系数矩阵w的相乘,即s = Wx。对于多分类问题,使用 Softmax 对线性输出进行处理。于是,我们来探讨下 Softmax 的损失函数。

其中,Syi是正确类别对应的线性得分函数,Si 是正确类别对应的 Softmax输出。

由于 log 运算符不会影响函数的单调性,我们对 Si 进行 log 操作:

我们希望 Si 越大越好,即正确类别对应的相对概率越大越好,那么就可以对 Si 前面加个负号,来表示损失函数:

对上式进一步处理,把指数约去:

这样,Softmax 的损失函数就转换成了简单的形式。

举个简单的例子,上一小节中得到的线性输出为:

假设 i = 1 为真实样本,计算其损失函数为:

若令 i = 0 为真实样本,计算其损失函数为:

3 softmax 反向梯度

推导了 Softmax 的损失函数之后,接下来继续对权重参数进行反向求导。

Softmax 线性分类器中,线性输出为:

其中,下标 i 表示第 i 个样本。

求导过程的程序设计分为两种方法:一种是使用嵌套 for 循环,另一种是直接使用矩阵运算。

使用嵌套 for 循环,对权重 W 求导函数定义如下:

def softmax_loss_naive(W, X, y, reg):
 """
 Softmax loss function, naive implementation (with loops)
 Inputs have dimension D, there are C classes, and we operate on minibatches
 of N examples.
 Inputs:
 - W: A numpy array of shape (D, C) containing weights.
 - X: A numpy array of shape (N, D) containing a minibatch of data.
 - y: A numpy array of shape (N,) containing training labels; y[i] = c means
   that X[i] has label c, where 0 <= c < C.
 - reg: (float) regularization strength
 Returns a tuple of:
 - loss as single float
 - gradient with respect to weights W; an array of same shape as W
 """
 # Initialize the loss and gradient to zero.
 loss = 0.0
 dW = np.zeros_like(W)
 num_train = X.shape[0]
 num_classes = W.shape[1]
 for i in xrange(num_train):
   scores = X[i,:].dot(W)
   scores_shift = scores - np.max(scores)
   right_class = y[i]
   loss += -scores_shift[right_class] + np.log(np.sum(np.exp(scores_shift)))
   for j in xrange(num_classes):
     softmax_output = np.exp(scores_shift[j]) / np.sum(np.exp(scores_shift))
     if j == y[i]:
       dW[:,j] += (-1 + softmax_output) * X[i,:]
     else:
       dW[:,j] += softmax_output * X[i,:]
 loss /= num_train
 loss += 0.5 * reg * np.sum(W * W)
 dW /= num_train
 dW += reg * W
 return loss, dW

使用矩阵运算,对权重 W 求导函数定义如下:

def softmax_loss_vectorized(W, X, y, reg):
 """
 Softmax loss function, vectorized version.
 Inputs and outputs are the same as softmax_loss_naive.
 """
 # Initialize the loss and gradient to zero.
 loss = 0.0
 dW = np.zeros_like(W)
 num_train = X.shape[0]
 num_classes = W.shape[1]
 scores = X.dot(W)
 scores_shift = scores - np.max(scores, axis = 1).reshape(-1,1)
 softmax_output = np.exp(scores_shift) / np.sum(np.exp(scores_shift), axis=1).reshape(-1,1)
 loss = -np.sum(np.log(softmax_output[range(num_train), list(y)]))
 loss /= num_train
 loss += 0.5 * reg * np.sum(W * W)
 dS = softmax_output.copy()
 dS[range(num_train), list(y)] += -1
 dW = (X.T).dot(dS)
 dW = dW / num_train + reg * W  
 return loss, dW

实际验证表明,矩阵运算速度要比嵌套循环快很多,特别是在训练样本数量多的情况下。我们使用 CIFAR-10 数据集中约5000个样本对两种求导方式进行测试对比:

tic = time.time()
loss_naive, grad_naive = softmax_loss_naive(W, X_train, y_train, 0.000005)
toc = time.time()
print('naive loss: %e computed in %fs' % (loss_naive, toc - tic))
tic = time.time()
loss_vectorized, grad_vectorized = softmax_loss_vectorized(W, X_train, y_train, 0.000005)
toc = time.time()
print('vectorized loss: %e computed in %fs' % (loss_vectorized, toc - tic))
grad_difference = np.linalg.norm(grad_naive - grad_vectorized, ord='fro')
print('Loss difference: %f' % np.abs(loss_naive - loss_vectorized))
print('Gradient difference: %f' % grad_difference)

结果显示为:

>> naive loss: 2.362135e+00 computed in 14.680000s

>> vectorized loss: 2.362135e+00 computed in 0.242000s

>> Loss difference: 0.000000

>> Gradient difference: 0.000000

显然,此例中矩阵运算的速度要比嵌套循环快60倍。所以,当我们在编写机器学习算法模型时,尽量使用矩阵运算,少用嵌套循环,以提高运算速度。

4 Softmax 与 SVM

Softmax线性分类器的损失函数计算相对概率,又称交叉熵损失(Cross Entropy Loss)。线性 SVM 分类器和 Softmax 线性分类器的主要区别在于损失函数不同。SVM 使用的是 hinge loss,更关注分类正确样本和错误样本之间的距离「Δ = 1」,只要距离大于 Δ,就不在乎到底距离相差多少,忽略细节。而 Softmax 中每个类别的得分函数都会影响其损失函数的大小。例如,类别个数 C=3,两个样本的得分函数分别为[10, -10, -10],[10, 9, 9],真实标签为第0类。对于SVM来说,这两个 Li 都为0;对于Softmax来说,这两个 Li 分别为0和0.55,差别很大。

5 Softmax的正则化参数 λ

接下来,谈一下正则化参数 λ 对 Softmax 的影响。我们知道正则化的目的是限制权重参数 W 的大小,防止过拟合。正则化参数 λ 越大,对 W 的限制越大。例如,某3分类的线性输出为 [1, -2, 0],相应的 Softmax 输出为[0.7, 0.04, 0.26]。假设,第0类是正确类别,显然0.7远大于0.04和0.26。若使用正则化参数 λ,由于 λ 限制了 W 的大小,得到的线性输出也会等比例缩小,初始的线性输出[1, -2, 0] 就会变成类似于 [0.5, -1, 0],相应的 Softmax 输出为[0.55, 0.12, 0.33]。显然,正确类别和错误类别之间的相对概率差距变小了。

也就是说,正则化参数 λ 越大,Softmax 各类别输出越接近。大的 λ 实际上是“均匀化”正确样本与错误样本之间的相对概率。但是,概率大小的相对顺序并没有改变,因此也不会影响到对 Loss 的优化算法。

Softmax的通俗讲解相关推荐

  1. 人工智能算法通俗讲解系列(二):逻辑回归

    2019独角兽企业重金招聘Python工程师标准>>> 今天,我们介绍的机器学习算法叫逻辑回归.它英语名称是Logistic Regression,简称LR. 跟之前一样,介绍这个算 ...

  2. Hadoop平台K-Means聚类算法分布式实现+MapReduce通俗讲解

        Hadoop平台K-Means聚类算法分布式实现+MapReduce通俗讲解 在Hadoop分布式环境下实现K-Means聚类算法的伪代码如下: 输入:参数0--存储样本数据的文本文件inpu ...

  3. 冲突域、广播域的通俗讲解

    冲突域.广播域的通俗讲解 1.冲突域(物理分段)         连接在同一导线上的所有工作站的集合,或者说是同一物理网段上所有节点的集合或以太网上竞争同一带宽的节点集合.这个域代表了冲突在其中发生并 ...

  4. 通俗讲解:图像傅里叶变换

    转自某乎:通俗讲解:图像傅里叶变换 - 知乎 这里我们主要要讲的是二维图像傅里叶变换,但是我们首先来看一张很厉害的一维傅里叶变换动图. 妈耶~厉害哇!它把时域和频域解释的很清楚! 什么!你看不懂! 简 ...

  5. 关于CSS浮动(float,clear)的通俗讲解(经验分享)

    很早以前就接触过CSS,但对于浮动始终非常迷惑,可能是自身理解能力差,也可能是没能遇到一篇通俗的教程. 前些天小菜终于搞懂了浮动的基本原理,迫不及待的分享给大家. 写在前面的话: 由于CSS内容比较多 ...

  6. 二维小波变换_小波变换完美通俗讲解系列之 (一)

    声明:该篇文章转自csdn,原始博主已经找不到了,在这里给出转载博主地址,如有侵权,请私信我删除. https://blog.csdn.net/liusandian/article/details/5 ...

  7. vue标准时间改为时间戳_区块链科普005:什么是时间戳?白话通俗讲解时间戳是什么意思?...

    原标题:区块链科普005:什么是时间戳?白话通俗讲解时间戳是什么意思? 八宝饭区块链:什么是时间戳 时间戳是什么?可能很多人都没有听过,但是作为办公一族应该掌握这个技能.时间戳技术就是数字签名技术一种 ...

  8. Java多数据源最通俗讲解

    Java多数据源最通俗讲解 before after 理论 实操 编码 小总结 before 项目中可能会用到很多的数据源,例如目前这个项目中用到了五个数据源,那么数据源的 配置和数据源的切换就成为了 ...

  9. 通俗讲解 依概率收敛,大数定理和中心极限定理

    通俗讲解 依概率收敛,大数定理和中心极限定理 依概率收敛 首先说一下结论,依概率收敛是一种基础证明工具,可以类比到高数中的极限定义,将一种直觉上的 "逼近某个数" 用数学公式来定义 ...

最新文章

  1. 阿铭每日一题 day 6 20180116
  2. 研究之路的秘密花园-个人经验分享-台湾朝阳科技大学陈金铃教授
  3. C#.NET常见问题(FAQ)-Combobox如何设置不可以编辑
  4. 《图解CSS3:核心技术与案例实战》——1.3节渐进增强
  5. 工具杂记-notepad++正则表达式匹配替换
  6. Java包的命名规则
  7. SAP License:SAP FI/CO—Questions and Answers
  8. 未在本地计算机上注册“Microsoft.ACE.OLEDB.12.0”提供程序 解决方案
  9. java jsonobject date_如何将Json Passed Date Value分配给Java Date Object
  10. 电脑系统及软件安装日期查看
  11. python安卓手机编程入门自学_编程入门学习路线(附教程推荐)
  12. 郭天祥的10天学会51单片机_第十一节
  13. android手机如何拥有苹果表情包,怎样让安卓emoji显示iPhone的emoji样式
  14. 竞品分析 | 荔枝VS喜马拉雅FM:有声音频APP的发展与社交
  15. 形而上学 “形而上者谓之道,形而下者谓之器”
  16. 单元测试中Assert详解-xUnit
  17. canal安装最详细教程
  18. 关于树莓派4B安装桌面控件wbar和conky解决报错的一种方案
  19. 【MySQL】幻读是什么?如何避免幻读?
  20. 关于WirelessKey的一些说明

热门文章

  1. 火热的“互联网+医疗” 究竟是谁的菜?
  2. 如何利用excel和jupyter 编程,对身高体重的数据做线性回归
  3. 《卡车模拟器3D》用户隐私政策
  4. 单片机中的液晶显示器
  5. 【每日新闻】微软悄然删除世界上最大的公共人脸识别数据库
  6. 网店美工之装修基础篇
  7. 中国黑客元老倡议自律 拒绝网络犯罪
  8. 重磅回归丨2020云和恩墨大讲堂,线上线下同步开讲!
  9. 如何将png转换为其他格式?格式转换器工具怎么用?
  10. DenseNet: Densely Connected Convolutional Networks