本文源自于SPACES:“抽取-生成”式长文本摘要(法研杯总结),原文其实是对一个比赛的总结,里面提到了很多Trick,其中有一个叫做稀疏Softmax(Sparse Softmax)的东西吸引了我的注意,查阅了很多资料以后,汇总在此

Sparse Softmax的思想源于《From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification》、《Sparse Sequence-to-Sequence Models》等文章。里边作者提出了将Softmax稀疏化的做法来增强其解释性乃至提升效果

不够稀疏的Softmax

前面提到Sparse Softmax本质上是将Softmax的结果稀疏化,那么为什么稀疏化之后会有效呢?我们认稀疏化可以避免Softmax过度学习的问题。假设已经成功分类,那么我们有smax=sts_{\text{max}}=s_tsmax​=st​(目标类别的分数最大),此时我们可以推导原始交叉熵的一个不等式:

log⁡(∑i=1nesi)−smax=log⁡(est+∑i≠tesi)−smax=log⁡(esmax+∑i≠tesi)−log⁡(esmax)=log⁡(esmax+∑i≠tesiesmax)=log⁡(1+∑i≠tesi−smax)≥log⁡(1+(n−1)esmin−smax)(1)\begin{aligned} \log (\sum_{i=1}^n e^{s_i})-s_{\text{max}} &= \log (e^{s_t}+\sum_{i\neq t}e^{s_i})-s_{\text{max}}\\ &= \log (e^{s_{\text{max}}} + \sum_{i\neq t}e^{s_i})-\log (e^{s_{\text{max}}})\\ &= \log (\frac{e^{s_{\text{max}}} + \sum_{i\neq t}e^{s_i}}{e^{s_{\text{max}}}})\\ &= \log (1+ \sum_{i \neq t}e^{s_i - s_{\text{max}}})\\ & \ge \log (1+ (n - 1)e^{s_{\text{min}}-s_{\text{max}}}) \end{aligned}\tag{1} log(i=1∑n​esi​)−smax​​=log(est​+i​=t∑​esi​)−smax​=log(esmax​+i​=t∑​esi​)−log(esmax​)=log(esmax​esmax​+∑i​=t​esi​​)=log(1+i​=t∑​esi​−smax​)≥log(1+(n−1)esmin​−smax​)​(1)

假设当前交叉熵值为ε\varepsilonε,那么有

ε≥log⁡(1+(n−1)esmin−smax)(2)\varepsilon \ge \log (1+ (n - 1)e^{s_{\text{min}}-s_{\text{max}}})\tag{2} ε≥log(1+(n−1)esmin​−smax​)(2)

解得

smax−smin≥log⁡(n−1)−log⁡(eε−1)(3)s_{\text{max}} - s_{\text{min}} \ge \log (n - 1) - \log (e^{\varepsilon} - 1)\tag{3} smax​−smin​≥log(n−1)−log(eε−1)(3)

我们以ε=ln⁡2=0.69...\varepsilon = \ln2 = 0.69...ε=ln2=0.69...为例,这时候log⁡(eε−1)=0\log (e^{\varepsilon} - 1)=0log(eε−1)=0,那么smax−smin≥log⁡(n−1)s_{\text{max}} - s_{\text{min}}\ge \log (n-1)smax​−smin​≥log(n−1)。也就是说,为了要loss降到0.69,那么最大的logit与最小的logit的差就必须大于log⁡(n−1)\log (n-1)log(n−1),当nnn比较大时,对于分类问题来说这是一个没有必要的过大的间隔,因为我们只希望目标类的logit比所有非目标类都要大一点就行,但是并不一定需要大log⁡(n−1)\log (n-1)log(n−1)那么多,因此常规的交叉熵容易过度学习从而导致过拟合

稀疏的Sparsemax

前面说了这么多关于Softmax的内容,那么Sparse Softmax或者说Sparsemax是如何做到稀疏化分布的呢?原文内容大家可以直接去看论文,写的非常复杂,这里我给出苏剑林大佬设计的一个更简单的版本

OriginSparseSoftmaxpi=esi∑j=1nesjpi={esi∑j∈Ωkesj,i∈Ωk0,i∉ΩkCrossEntropylog⁡(∑i=1nesi)−stlog⁡(∑i∈Ωkesi)−st\begin{array}{c|c|c} \hline & \text{Origin} & \text{Sparse} \\ \hline \text{Softmax} & p_i = \frac{e^{s_i}}{\sum\limits_{j=1}^{n} e^{s_j}} & p_i=\left\{\begin{aligned}&\frac{e^{s_i}}{\sum\limits_{j\in\Omega_k} e^{s_j}},\,i\in\Omega_k\\ &\quad 0,\,i\not\in\Omega_k\end{aligned}\right.\\ \hline \text{CrossEntropy} & \log\left(\sum\limits_{i=1}^n e^{s_i}\right) - s_t & \log\left(\sum\limits_{i\in\Omega_k} e^{s_i}\right) - s_t\\ \hline \end{array} SoftmaxCrossEntropy​Originpi​=j=1∑n​esj​esi​​log(i=1∑n​esi​)−st​​Sparsepi​=⎩⎪⎪⎨⎪⎪⎧​​j∈Ωk​∑​esj​esi​​,i∈Ωk​0,i​∈Ωk​​log(i∈Ωk​∑​esi​)−st​​​

其中Ωk\Omega_kΩk​是将s1,s2,...,sns_1,s_2,...,s_ns1​,s2​,...,sn​从大到小排列后前kkk个元素的下标集合。说白了,苏剑林大佬提出的Sparse Softmax就是在计算概率的时候,只保留前kkk个,后面的直接置零,kkk是人为选择的超参数

代码

首先我根据苏剑林大佬的思路,给出一个简单版本的PyTorch代码

import torch
import torch.nn as nnclass Sparsemax(nn.Module):"""Sparsemax loss"""def __init__(self, k_sparse=1):super(Sparsemax, self).__init__()self.k_sparse = k_sparsedef forward(self, preds, labels):"""Args:preds (torch.Tensor):  [batch_size, number_of_logits]labels (torch.Tensor): [batch_size] index, not ont-hotReturns:torch.Tensor"""preds = preds.reshape(preds.size(0), -1) # [batch_size, -1]topk = preds.topk(self.k_sparse, dim=1)[0] # [batch_size, k_sparse]# log(sum(exp(topk)))pos_loss = torch.logsumexp(topk, dim=1)# s_tneg_loss = torch.gather(preds, 1, labels[:, None].expand(-1, preds.size(1)))[:, 0]return (pos_loss - neg_loss).sum()

再给出一个Github上找到的一个PyTorch原版代码

"""Sparsemax activation function.
Pytorch implementation of Sparsemax function from:
-- "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification"
-- André F. T. Martins, Ramón Fernandez Astudillo (http://arxiv.org/abs/1602.02068)
"""import torch
import torch.nn as nndevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class Sparsemax(nn.Module):"""Sparsemax function."""def __init__(self, dim=None):"""Initialize sparsemax activationArgs:dim (int, optional): The dimension over which to apply the sparsemax function."""super(Sparsemax, self).__init__()self.dim = -1 if dim is None else dimdef forward(self, input):"""Forward function.Args:input (torch.Tensor): Input tensor. First dimension should be the batch sizeReturns:torch.Tensor: [batch_size x number_of_logits] Output tensor"""# Sparsemax currently only handles 2-dim tensors,# so we reshape to a convenient shape and reshape back after sparsemaxinput = input.transpose(0, self.dim)original_size = input.size()input = input.reshape(input.size(0), -1)input = input.transpose(0, 1)dim = 1number_of_logits = input.size(dim)# Translate input by max for numerical stabilityinput = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)# Sort input in descending order.# (NOTE: Can be replaced with linear time selection method described here:# http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html)zs = torch.sort(input=input, dim=dim, descending=True)[0]range = torch.arange(start=1, end=number_of_logits + 1, step=1, device=device, dtype=input.dtype).view(1, -1)range = range.expand_as(zs)# Determine sparsity of projectionbound = 1 + range * zscumulative_sum_zs = torch.cumsum(zs, dim)is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())k = torch.max(is_gt * range, dim, keepdim=True)[0]# Compute threshold functionzs_sparse = is_gt * zs# Compute taustaus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / ktaus = taus.expand_as(input)# Sparsemaxself.output = torch.max(torch.zeros_like(input), input - taus)# Reshape back to original shapeoutput = self.outputoutput = output.transpose(0, 1)output = output.reshape(original_size)output = output.transpose(0, self.dim)return outputdef backward(self, grad_output):"""Backward function."""dim = 1nonzeros = torch.ne(self.output, 0)sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros, dim=dim)self.grad_input = nonzeros * (grad_output - sum.expand_as(grad_output))return self.grad_input

*补充

经过苏剑林大佬的许多实验发现,Sparse Softmax只适用于有预训练的场景,因为预训练模型已经训练得很充分了,因此finetune阶段要防止过拟合;但是如果从零训练一个模型,那么Sparse Softmax会造成性能下降,因为每次只有kkk个类别被学习到,反而会存在学习不充分的情况(欠拟合)

References

  • SPACES:“抽取-生成”式长文本摘要(法研杯总结)
  • 稀疏序列到序列模型
  • 深度学习激活函数从Softmax到Sparsemax
  • GLU, sparsemax, GELU激活函数

稀疏Softmax(Sparse Softmax)相关推荐

  1. softmax、softmax损失函数、cross-entropy损失函数

    softmax softmax ,顾名思义,就是 soft 版本的 max. 在了解 softmax 之前,先看看什么是 hardmax. hardmax 就是直接选出一个最大值,例如 [1,2,3] ...

  2. ML之SR:Softmax回归(Softmax Regression)的简介、使用方法、案例应用之详细攻略

    ML之SR:Softmax回归(Softmax Regression)的简介.使用方法.案例应用之详细攻略 目录 Softmax回归的简介 Softmax回归的使用方法 Softmax回归的案例应用 ...

  3. c++稀疏表sparse table的实现算法(附完整源码)

    C++稀疏表sparse table的实现算法 C++稀疏表sparse table的实现算法完整源码(定义,实现,main函数测试) C++稀疏表sparse table的实现算法完整源码(定义,实 ...

  4. 层次softmax (hierarchical softmax)理解

    目录 1 前言 2 CBOW(Continuous Bag-of-Word) 2.1 One-word context 2.2 Multi-word context 3 Skip-gram 4 hie ...

  5. 卷积神经网络系列之softmax,softmax loss和cross entropy的讲解

    我们知道卷积神经网络(CNN)在图像领域的应用已经非常广泛了,一般一个CNN网络主要包含卷积层,池化层(pooling),全连接层,损失层等.虽然现在已经开源了很多深度学习框架(比如MxNet,Caf ...

  6. softmax,softmax loss和cross entropy

    我们知道卷积神经网络(CNN)在图像领域的应用已经非常广泛了,一般一个CNN网络主要包含卷积层,池化层(pooling),全连接层,损失层等.虽然现在已经开源了很多深度学习框架(比如MxNet,Caf ...

  7. 稀疏性(sparse)知识点

    稀疏性(sparse) 定义:Sparse表示为模型内的参数中,只用很少的几个非零元素或只有很少的几个远大于零的元素. WHY: 为什么模型中要包含稀疏性的特征呢? 例子:考研学霸有10000的词汇量 ...

  8. softmax和softmax loss详细解析

    本文转载于以下博文地址:https://blog.csdn.net/u014380165/article/details/77284921 如有冒犯,还望谅解! 我们知道卷积神经网络(CNN)在图像领 ...

  9. Softmax和softmax loss的理解

    转载博客链接:https://blog.csdn.net/u014380165/article/details/77284921 下图展示的是全连接层的计算: 这张图的等号左边部分就是全连接层做的事, ...

最新文章

  1. python主要运用于-python主要应用领域有哪些?看这一篇就够了
  2. 一个非常棒的jQuery 评分插件--好东西要分享
  3. 蓝桥杯练习系统习题-算法训练3
  4. 思科路由器全局、接口、协议调试(下)
  5. php 正则特殊字符转义,php 正则特殊字符转义的方法
  6. Android之Fragment(二)
  7. python运维开发培训_运维架构师-Python 自动化运维开发-014
  8. FlashPaper安装及使用方法
  9. Spring Boot 中文乱码问题解决方案汇总
  10. Google 向平板电脑彻底说再见!
  11. Open3d之点云顶点法线估计
  12. Gradient Boosting and GBDT
  13. Laravel 在哪些地方使用了 trait ?
  14. vb 获取设备音量_自制 Windows 10X 启动盘,提前体验微软折叠设备新系统
  15. JAVA获取汉字拼音首字母
  16. 视频分辨率过高,导致部分手机播放失败
  17. linux配置mac地址命令是什么,Linux环境下如何配置IP地址、MAC地址
  18. 为小米4与小米3 Mi3 Mi4编译Cyanogenmod 12.1与13.0 (CM12与CM13) 的步骤以及错误解决
  19. 百度飞桨小白逆袭大神之鲤鱼跃龙门
  20. 【强化记忆】生物选修三填空题考点强化记忆2-胚胎工程、安全伦理问题、生态工程——2017年2月25日...

热门文章

  1. Python语言是解释性语言还是编译性语言?
  2. centos5 双网卡重启network服务提升IP占用 ping不通
  3. 一只小蜜蜂,一万天纪念日,杨辉三角,洗牌
  4. 年会的槽,吐着吐着就习惯了……
  5. ElasticSearch immense term错误
  6. 自古成功在尝试 jzoj 2017.8.21 B组
  7. 防蓝光护眼灯哪个牌子比较好?目前比较好用的护眼灯推荐
  8. pandas错位计算
  9. IIS 设置文件传输大小限制
  10. deepin20系统选择手动安装盘_手把手教你安装Mac双系统