sampled softmax原论文:On Using Very Large Target Vocabulary for Neural Machine Translation
以及tensorflow关于candidate sampling的文档:candidate sampling


1. 问题背景

在神经机器翻译中,训练的复杂度以及解码的复杂度和词汇表的大小成正比。当输出的词汇表巨大时,传统的softmax由于要计算每一个类的logits就会有问题。在论文Neural Machine Translation by Jointly Learning to Align and Translate 中,带有attention的decoder中权重的公式如下:

其中的aaa为一个单层的前馈神经网络,根据αt\alpha_tαt​和输入的因状态,我们就可以得到一个context vector ctc_tct​。在decoder的t时刻,输出的目标词汇的概率可以使用如下公式计算:

其中,yt−1y_{t-1}yt−1​是上一个次的输出,ztz_tzt​为当前decoder的隐状态,ctc_tct​为context vector。
因为我们输出的是一个概率值,所以(6)式的归一化银子ZZZ的计算就需要将词汇表当中的logits都计算一遍,这个代价是很大的。
基于此,作者提出了一种采样的方法,使得我们在训练的时候,输出为原来输出的一个子集。(关于其它的解决方法,作者也有提,感兴趣的可以看原文,本篇博客只关注Sampled Softmax)

2. 解决方法

上面已经说过,计算归一化的因子ZZZ,因为所用的词太多造成复杂度的上升,那么原文的方法就是使用一个子集V′V'V′来近似的计算出ZZZ, 假设我们现在已经知道的这个子集,那么之前计算输出的概率公式就为:

好了,那么V′V'V′怎么取?

我们看看tensorflow中的文档吧: https://www.tensorflow.org/extras/candidate_sampling.pdf
对于Sampled Softmax的每一个训练样例(xi,{ti})(x_i,\{t_i\})(xi​,{ti​}),我们根据采样函数Q(y∣x)Q(y|x)Q(y∣x),从所有的输出集合中挑选一个小的子集SiS_{i}Si​。要求选择子集的函数和具体的训练样本无关。假设full softmax的输出全集为LLL, 那么在给定xix_ixi​的情况下,根据分布QQQ从LLL中抽取的子集似然函数为:

然后我们生成一个包含SiS_iSi​和训练目标类的候选集合VVV:
V′=Si∪tiV'=S_i \cup{t_i}V′=Si​∪ti​
之后我们的训练目标就是找出样本为V′V'V′的哪一个类别了。
(感觉还是tensorflow文档说的清楚一点,最初看论文的时候还以为是相当于把一个单词划分到最近的一个类,那样的话,应该会有不同类别的关系啊不然也不make sense啊,但是看tensorflow源码就只有采样的过程啊,笑cry)

3. tensorflow的实现

def sampled_softmax_loss(weights,biases,labels,inputs,num_sampled, # 每一个batch随机选择的类别num_classes, # 所有可能的类别num_true=1, #每一个sample的类别数量sampled_values=None,remove_accidental_hits=True,partition_strategy="mod",name="sampled_softmax_loss"):

tensorflow对于使用的建议:仅仅在训练阶段使用,在inference或者evaluation的时候还是需要使用full softmax。

原文:
This operation is for training only. It is generally an underestimate of
the full softmax loss.
A common use case is to use this method for training, and calculate the full softmax loss for evaluation or inference.

这个函数的主体主要调用了另外一个函数:

logits, labels = _compute_sampled_logits(weights=weights,biases=biases,labels=labels,inputs=inputs,num_sampled=num_sampled,num_classes=num_classes,num_true=num_true,sampled_values=sampled_values,subtract_log_q=True,remove_accidental_hits=remove_accidental_hits,partition_strategy=partition_strategy,name=name)

上述函数的返回值shape为:[batch_size, num_true + num_sampled]即可能的class为: Si∪tiS_i \cup{t_i}Si​∪ti​
而这个函数采样集合的代码如下:

sampled_values=candidate_sampling_ops.log_uniform_candidate_sampler(true_classes=labels,# 真实的labelnum_true=num_true,num_sampled=num_sampled, # 需要采样的子集大小unique=True,range_max=num_classes)

而这个函数主要是按照log-uniform distribution(Zipfian distribution)来采样出一个子集,Zipfian distribution
即Zipf法则,以下为Wikipedia关于Zipf’s law的解释:

Zipf’s law states that given some corpus of natural language utterances, the frequency of any word is inversely proportional to its rank in the frequency table.

Sampled Softmax相关推荐

  1. Sampled Softmax,你真的会用了吗?

    作者 | 夜小白 整理 | NewBeeNLP 前面两篇关于文本匹配的博客中,都用到了Sampled-softmax训练方法来加速训练. 基于表征(Representation)的文本匹配.信息检索. ...

  2. 一文讲懂召回中的 NCE NEG sampled softmax loss

    深度学习中与分类相关的问题都会涉及到softmax的计算.当目标类别较少时,直接用标准的softmax公式进行计算没问题,当目标类别特别多时,则需采用估算近似的方法简化softmax中归一化的计算. ...

  3. 文本匹配开山之作-DSSM论文笔记及源码阅读(类似于sampled softmax训练方式思考)

    文章目录 前言 DSSM框架简要介绍 模型结构 输入 Encoder层 相似度Score计算 训练方式解读 训练数据 训练目标 训练方式总结 DSSM源码阅读 训练数据中输入有负样本的情况 输入数据 ...

  4. 深度模型(七):Sampled Softmax

    Softmax 给定softmax的输入(z1,z2,...,zn)(z_1,z_2,...,z_n)(z1​,z2​,...,zn​),则输出为f(z1,f(z2),...,f(zn))f(z_1, ...

  5. 【机器学习】sampled softmax loss

    目录 1.前置知识softmax loss 2.sampled softmax 1.1.问题引入 1.2.如何通俗理解sampled softmax机制? 3.sampled softmax loss ...

  6. Tensorflow的负采样函数Sampled softmax loss学习笔记

    最近阅读了YouTube的推荐系统论文,在代码实现中用到的负采样方法我比较疑惑,于是查了大量资料,总算能够读懂关于负采样的一些皮毛. 本文主要针对tf.nn.sampled_softmax_loss这 ...

  7. Tensorflow的负采样函数Sampled softmax loss踩坑之旅

    谷歌16年出的论文<Deep Neural Networks for Youtube Recommendation>中提到文章采用了负采样的思想来进行extreme multiclass分 ...

  8. Tensorflow之负采样函数Sampled softmax loss

    Tensorflow之负采样函数Sampled softmax loss 谷歌16年出的论文<Deep Neural Networks for Youtube Recommendation> ...

  9. yolo-mask的损失函数l包含三部分_损失函数总结-应用和trick

    常见的损失函数,如交叉熵损失.平方误差损失.Hinge损失等并不是本文的重点,关于这些损失函数的介绍网上很多,可以参考如下几篇文章 机器学习中的 7 大损失函数实战总结 常见的损失函数(loss fu ...

最新文章

  1. SLAM基础:相机与图像
  2. php 数组键值分离,array_keys array_values::PHP数组键名于键值分离
  3. 用c语言设计一个统计字符个数的程序,「第6篇」「C程序上机题」「统计输入的字符个数思路与实现」...
  4. Mac下Idea安装Git报错Xcrun问题的解决
  5. zipsys驱动签名工具_全球首发 300系列主板USB WIN7 64位驱动 SMXDIY
  6. CMakeList.txt中设置一个可变的变量的值(bool)
  7. c语言bellman算法,求 最短路径中BELLMAN FORD算法实现的C程序
  8. javascript 对象属性
  9. CSS你可以不写,但这些规范必须要知道!
  10. 0603封装 1%贴片电阻代码表示的阻值
  11. PS常用快捷键操作记录
  12. [重要]宝塔面板Linux7.4.3/Windows6.8紧急更新
  13. python结果不能全部显示_numpy矩阵数值太多不能全部显示的解决
  14. sis防屏蔽程序_屏蔽机房设计方案知识
  15. Jordan标准形(番外篇)——线性变换可对角化和最小多项式的关系
  16. 脉冲响应与频率响应的关系
  17. 阿里云服务器|centos查看并发数调优
  18. 《Arduino实验》实验二:DHT11温湿度传感器检测当前环境温湿度
  19. 【转发】微博feed系统的推(push)模式和拉(pull)模式和时间分区拉模式架构探讨
  20. SignalR 2.0 系列: SignalR 自托管主机

热门文章

  1. C语言笔记--代码学习笔记--C语言语法--基本操作运算-basic-logorithm
  2. React引入ant-design实现正在加载效果
  3. 2023自动化毕业设计选题
  4. PG10 Vacuum监控进度
  5. 调度算法的定义与使用价值
  6. 使用windows10内置的OpenSSH密钥登录Linux服务器
  7. Android 实现答题、做题功能包含(多选、单选、材料、填空 、判断 、问答 )以及题卡交卷查看解析功能
  8. 调用百度地图显示周围方圆100米、500米、1000米附近的餐馆宾馆酒店及公交站点API接口
  9. 用Qt写一个简单的音乐播放器(四):歌曲浏览、上一曲、下一曲
  10. MVC的理解和优缺点的总结