Sampled Softmax
sampled softmax原论文:On Using Very Large Target Vocabulary for Neural Machine Translation
以及tensorflow关于candidate sampling的文档:candidate sampling
1. 问题背景
在神经机器翻译中,训练的复杂度以及解码的复杂度和词汇表的大小成正比。当输出的词汇表巨大时,传统的softmax由于要计算每一个类的logits就会有问题。在论文Neural Machine Translation by Jointly Learning to Align and Translate 中,带有attention的decoder中权重的公式如下:
其中的aaa为一个单层的前馈神经网络,根据αt\alpha_tαt和输入的因状态,我们就可以得到一个context vector ctc_tct。在decoder的t时刻,输出的目标词汇的概率可以使用如下公式计算:
其中,yt−1y_{t-1}yt−1是上一个次的输出,ztz_tzt为当前decoder的隐状态,ctc_tct为context vector。
因为我们输出的是一个概率值,所以(6)式的归一化银子ZZZ的计算就需要将词汇表当中的logits都计算一遍,这个代价是很大的。
基于此,作者提出了一种采样的方法,使得我们在训练的时候,输出为原来输出的一个子集。(关于其它的解决方法,作者也有提,感兴趣的可以看原文,本篇博客只关注Sampled Softmax)
2. 解决方法
上面已经说过,计算归一化的因子ZZZ,因为所用的词太多造成复杂度的上升,那么原文的方法就是使用一个子集V′V'V′来近似的计算出ZZZ, 假设我们现在已经知道的这个子集,那么之前计算输出的概率公式就为:
好了,那么V′V'V′怎么取?
我们看看tensorflow中的文档吧: https://www.tensorflow.org/extras/candidate_sampling.pdf
对于Sampled Softmax的每一个训练样例(xi,{ti})(x_i,\{t_i\})(xi,{ti}),我们根据采样函数Q(y∣x)Q(y|x)Q(y∣x),从所有的输出集合中挑选一个小的子集SiS_{i}Si。要求选择子集的函数和具体的训练样本无关。假设full softmax的输出全集为LLL, 那么在给定xix_ixi的情况下,根据分布QQQ从LLL中抽取的子集似然函数为:
然后我们生成一个包含SiS_iSi和训练目标类的候选集合VVV:
V′=Si∪tiV'=S_i \cup{t_i}V′=Si∪ti
之后我们的训练目标就是找出样本为V′V'V′的哪一个类别了。
(感觉还是tensorflow文档说的清楚一点,最初看论文的时候还以为是相当于把一个单词划分到最近的一个类,那样的话,应该会有不同类别的关系啊不然也不make sense啊,但是看tensorflow源码就只有采样的过程啊,笑cry)
3. tensorflow的实现
def sampled_softmax_loss(weights,biases,labels,inputs,num_sampled, # 每一个batch随机选择的类别num_classes, # 所有可能的类别num_true=1, #每一个sample的类别数量sampled_values=None,remove_accidental_hits=True,partition_strategy="mod",name="sampled_softmax_loss"):
tensorflow对于使用的建议:仅仅在训练阶段使用,在inference或者evaluation的时候还是需要使用full softmax。
原文:
This operation is for training only. It is generally an underestimate of
the full softmax loss.
A common use case is to use this method for training, and calculate the full softmax loss for evaluation or inference.
这个函数的主体主要调用了另外一个函数:
logits, labels = _compute_sampled_logits(weights=weights,biases=biases,labels=labels,inputs=inputs,num_sampled=num_sampled,num_classes=num_classes,num_true=num_true,sampled_values=sampled_values,subtract_log_q=True,remove_accidental_hits=remove_accidental_hits,partition_strategy=partition_strategy,name=name)
上述函数的返回值shape为:[batch_size, num_true + num_sampled]即可能的class为: Si∪tiS_i \cup{t_i}Si∪ti
而这个函数采样集合的代码如下:
sampled_values=candidate_sampling_ops.log_uniform_candidate_sampler(true_classes=labels,# 真实的labelnum_true=num_true,num_sampled=num_sampled, # 需要采样的子集大小unique=True,range_max=num_classes)
而这个函数主要是按照log-uniform distribution(Zipfian distribution)来采样出一个子集,Zipfian distribution
即Zipf法则,以下为Wikipedia关于Zipf’s law的解释:
Zipf’s law states that given some corpus of natural language utterances, the frequency of any word is inversely proportional to its rank in the frequency table.
Sampled Softmax相关推荐
- Sampled Softmax,你真的会用了吗?
作者 | 夜小白 整理 | NewBeeNLP 前面两篇关于文本匹配的博客中,都用到了Sampled-softmax训练方法来加速训练. 基于表征(Representation)的文本匹配.信息检索. ...
- 一文讲懂召回中的 NCE NEG sampled softmax loss
深度学习中与分类相关的问题都会涉及到softmax的计算.当目标类别较少时,直接用标准的softmax公式进行计算没问题,当目标类别特别多时,则需采用估算近似的方法简化softmax中归一化的计算. ...
- 文本匹配开山之作-DSSM论文笔记及源码阅读(类似于sampled softmax训练方式思考)
文章目录 前言 DSSM框架简要介绍 模型结构 输入 Encoder层 相似度Score计算 训练方式解读 训练数据 训练目标 训练方式总结 DSSM源码阅读 训练数据中输入有负样本的情况 输入数据 ...
- 深度模型(七):Sampled Softmax
Softmax 给定softmax的输入(z1,z2,...,zn)(z_1,z_2,...,z_n)(z1,z2,...,zn),则输出为f(z1,f(z2),...,f(zn))f(z_1, ...
- 【机器学习】sampled softmax loss
目录 1.前置知识softmax loss 2.sampled softmax 1.1.问题引入 1.2.如何通俗理解sampled softmax机制? 3.sampled softmax loss ...
- Tensorflow的负采样函数Sampled softmax loss学习笔记
最近阅读了YouTube的推荐系统论文,在代码实现中用到的负采样方法我比较疑惑,于是查了大量资料,总算能够读懂关于负采样的一些皮毛. 本文主要针对tf.nn.sampled_softmax_loss这 ...
- Tensorflow的负采样函数Sampled softmax loss踩坑之旅
谷歌16年出的论文<Deep Neural Networks for Youtube Recommendation>中提到文章采用了负采样的思想来进行extreme multiclass分 ...
- Tensorflow之负采样函数Sampled softmax loss
Tensorflow之负采样函数Sampled softmax loss 谷歌16年出的论文<Deep Neural Networks for Youtube Recommendation> ...
- yolo-mask的损失函数l包含三部分_损失函数总结-应用和trick
常见的损失函数,如交叉熵损失.平方误差损失.Hinge损失等并不是本文的重点,关于这些损失函数的介绍网上很多,可以参考如下几篇文章 机器学习中的 7 大损失函数实战总结 常见的损失函数(loss fu ...
最新文章
- SLAM基础:相机与图像
- php 数组键值分离,array_keys array_values::PHP数组键名于键值分离
- 用c语言设计一个统计字符个数的程序,「第6篇」「C程序上机题」「统计输入的字符个数思路与实现」...
- Mac下Idea安装Git报错Xcrun问题的解决
- zipsys驱动签名工具_全球首发 300系列主板USB WIN7 64位驱动 SMXDIY
- CMakeList.txt中设置一个可变的变量的值(bool)
- c语言bellman算法,求 最短路径中BELLMAN FORD算法实现的C程序
- javascript 对象属性
- CSS你可以不写,但这些规范必须要知道!
- 0603封装 1%贴片电阻代码表示的阻值
- PS常用快捷键操作记录
- [重要]宝塔面板Linux7.4.3/Windows6.8紧急更新
- python结果不能全部显示_numpy矩阵数值太多不能全部显示的解决
- sis防屏蔽程序_屏蔽机房设计方案知识
- Jordan标准形(番外篇)——线性变换可对角化和最小多项式的关系
- 脉冲响应与频率响应的关系
- 阿里云服务器|centos查看并发数调优
- 《Arduino实验》实验二:DHT11温湿度传感器检测当前环境温湿度
- 【转发】微博feed系统的推(push)模式和拉(pull)模式和时间分区拉模式架构探讨
- SignalR 2.0 系列: SignalR 自托管主机