为什么 dot-product attention 需要被 scaled?
前言
注意力机制也有很多种类,不同的注意力机制对应着不同的对齐分数(alignment score)计算方式。有关注意力机制的总结,大家可以看看这篇博客:Attention? Attention!
在 Attention Is All You Need 这篇论文中,有提到两种较为常见的注意力机制:additive attention 和 dot-product attention。并讨论到,当 query 和 key 向量维度 dkd_kdk 较小时,这两种注意力机制效果相当,但当 dkd_kdk 较大时,additive attention 要优于 dot-product attention. 但是 dot-product attention 在计算方面更具有优势。为了利用 dot-product attention 的优势且消除当 dkd_kdk 较大时 dot-product attention 的不足,原文采用 scaled dot-product attention。
正文
那造成这种情况(但当 dkd_kdk 较大时,additive attention 要优于 dot-product attention)的原因是什么?下面是原论文中的解释(当 dkd_kdk 较大时,向量内积的值也会容易变得很大,这时 softmax 函数的梯度会非常的小)。
We suspect that for large values of dkd_kdk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely samll gradients.
我们知道,计算完各个 key 的对齐分数后需要将所有 key 的对齐分数输入到 softmaxsoftmaxsoftmax 激活函数中,得到规范化的注意力权重。
dot-product attention 中的对齐分数的计算公式为:
score(q,k)=qTkscore(q, k) = q^T k score(q,k)=qTk
先解释:为什么当 dkd_kdk 较大时,向量内积容易取很大的值(借用原论文的注释)
假设 query 和 key 向量中的元素都是相互独立的均值为 0,方差为 1 的随机变量,那么这两个向量的内积 qTk=∑i=1dkqikiq^T k = \sum_{i=1}^{d_k} q_ik_iqTk=∑i=1dkqiki 的均值为 0,而方差为 dkd_kdk.
证明:
已知 E[qi]=E[ki]=0,Var(qi)=Var(ki)=1\text{E}[q_i] = \text{E}[k_i] = 0,\ \text{Var}(q_i)=\text{Var}(k_i)=1E[qi]=E[ki]=0, Var(qi)=Var(ki)=1.
由于 qiq_iqi 与 kik_iki 相互独立,则两者的协方差为 0:
Cov(qi,ki)=E[(qi−E[qi])(ki−E[ki])]=E[qiki]−E[qi]E[ki]=0\begin{aligned} \text{Cov}(q_i,k_i) &= \text{E}\left[\left(q_i-\text{E}[q_i]\right)\left(k_i-\text{E}[k_i]\right)\right] \\ &= \text{E}[q_ik_i] - \text{E}[q_i] \text{E}[k_i] \\ &= 0 \end{aligned} Cov(qi,ki)=E[(qi−E[qi])(ki−E[ki])]=E[qiki]−E[qi]E[ki]=0
故得 E[qiki]=E[qi]E[ki]=0\text{E}[q_ik_i] = \text{E}[q_i] \text{E}[k_i] = 0E[qiki]=E[qi]E[ki]=0.
对于方差,有:
Var(qi)=E[qi2]−(E[qi])2=E[qi2]=1Var(ki)=E[ki2]=1\begin{aligned} \text{Var}(q_i) &= \text{E}[q_i^2] - (\text{E}[q_i])^2\\ &= \text{E}[q_i^2] \\ &= 1 \\ \text{Var}(k_i) &= \text{E}[k_i^2] = 1 \end{aligned} Var(qi)Var(ki)=E[qi2]−(E[qi])2=E[qi2]=1=E[ki2]=1
故得:
Var(qiki)=E[(qiki)2]−(E[qiki])2=E[qi2]E[ki2]−(E[qi]E[ki])2=Var(qi)Var(ki)=1\begin{aligned} \text{Var}(q_ik_i) &= \text{E}[(q_ik_i)^2] - (\text{E}[q_ik_i])^2 \\ &= \text{E}[q_i^2]\text{E}[k_i^2] - (\text{E}[q_i] \text{E}[k_i])^2 \\ & = \text{Var}(q_i)\text{Var}(k_i) \\ & = 1 \end{aligned} Var(qiki)=E[(qiki)2]−(E[qiki])2=E[qi2]E[ki2]−(E[qi]E[ki])2=Var(qi)Var(ki)=1
由于对于两个相互独立的随机变量有如下定义:
E[X+Y]=E[X]+E[Y]Var(X+Y)=Var(X)+Var(Y)+2Cov(X,Y)=Var(X)+Var(Y)\begin{aligned} &\text{E}[X+Y] = \text{E}[X] +\text{E}[Y]\\ &\text{Var(X+Y)} = \text{Var(X)} + \text{Var(Y)} + 2\text{Cov}(X,Y) \\ &\qquad \qquad \ \ \ =\text{Var(X)} + \text{Var(Y)} \end{aligned} E[X+Y]=E[X]+E[Y]Var(X+Y)=Var(X)+Var(Y)+2Cov(X,Y) =Var(X)+Var(Y)
综上,可得:
E[qTk]=∑i=1dkE[qiki]=0Var(qTk)=∑i=1dkVar(qiki)=dk\begin{aligned} &\text{E}[q^T k ] = \sum_{i=1}^{d_k} \text{E}[q_ik_i] = 0\\ &\text{Var}(q^T k) = \sum_{i=1}^{d_k} \text{Var}(q_ik_i) = d_k \end{aligned} E[qTk]=i=1∑dkE[qiki]=0Var(qTk)=i=1∑dkVar(qiki)=dk
所以,可以看出,当 dkd_kdk 较大时,qTkq^TkqTk 的方差较大,不同的 key 与同一个 query 算出的对齐分数可能会相差很大,有的远大于 0,有的则远小于 0.
再解释:向量内积的值(对齐分数)较大时,softmax 函数梯度很小
先介绍一下 softmax 函数:
softmaxsoftmaxsoftmax 函数是 logistic (或 sigmoid)函数在多类问题上的引申(有关于 sigmoid 函数的信息可查看我的另一篇博客),记为 SSS,其公式为:
S(xi)=exi∑j=0nexjS(x_i) = \frac{e^{x_i}}{\sum_{j=0}^n e^{x_j}} S(xi)=∑j=0nexjexi
对 S(xi)S(x_i)S(xi) 求偏导,可得:
∂∂xiS(xi)=S(xi)(1−S(xi))∂∂xjS(xi)=−S(xi)S(xj)\begin{aligned} \frac{\partial}{\partial x_i} S(x_i) &= S(x_i)(1-S(x_i)) \\ \frac{\partial}{\partial x_j} S(x_i) &= -S(x_i)S(x_j) \end{aligned} ∂xi∂S(xi)∂xj∂S(xi)=S(xi)(1−S(xi))=−S(xi)S(xj)
从上面的结果可以看出:
- 当 xix_ixi 相对于其他的 xj(j≠i)x_j(j \neq i)xj(j=i) 特别大时,S(xi)S(x_i)S(xi) 趋近于 1,则 ∂∂xiS(xi)\frac{\partial}{\partial x_i} S(x_i)∂xi∂S(xi) 和 ∂∂xiS(xj)\frac{\partial}{\partial x_i} S(x_j)∂xi∂S(xj) 都趋近于 0.
- 当 xix_ixi 相对较小时,S(xi)S(x_i)S(xi) 趋近于 0,则 ∂∂xiS(xi)\frac{\partial}{\partial x_i} S(x_i)∂xi∂S(xi) 和 ∂∂xiS(xj)\frac{\partial}{\partial x_i} S(x_j)∂xi∂S(xj) 也都趋近于 0.
也就是,当 xix_ixi 趋于 0 或 1 时,上述的两种偏导数都趋于零。
现在,我们就可以把这里的 xix_ixi 替换成前一部分讲到的 query 和 key 向量的内积 qTkq^T kqTk 了。
在前一部分我们有得出结论:当 dkd_kdk 较大时,qTkq^TkqTk 的方差较大,不同的 key 与同一个 query 算出的对齐分数可能会相差很大,有的远大于 0,有的则远小于 0.
所以,当 dkd_kdk 较大时,很有可能存在某个 key,其与 query 计算出来的对齐分数远大于其他的 key 与该 query 算出的对齐分数。这时, softmaxsoftmaxsoftmax 函数对各个 qTkq^TkqTk 的偏导数都趋于 0.
其结果就是, softmaxsoftmaxsoftmax 函数梯度过低(趋于零),使得模型误差反向传播(back-propagation)经过 softmaxsoftmaxsoftmax 函数后无法继续传播到模型前面部分的参数上,造成这些参数无法得到更新,最终影响模型的训练效率。
那么如何消除如上 dot-product attention 的问题呢?一种方法就是论文中的对 dot-product attention 进行缩放(除以 dk\sqrt{d_k}dk),获得 scaled dot-product attention。其对齐分数的计算公式为:
score(q,k)=qTkdkscore(q, k) = \frac{q^T k}{\sqrt{d_k}} score(q,k)=dkqTk
根据方差的计算法则:Var(kx)=k2Var(x)\text{Var}(kx) = k^2\text{Var}(x)Var(kx)=k2Var(x),可知缩放后,score(q,k)score(q,k)score(q,k) 的方差由原来的 dkd_kdk 缩小到了 1. 这就消除了 dot-product attention 在 dkd_kdk 较大时遇到的问题。这时,softmax 函数的梯度就不容易趋近于零了。
这就是为什么 dot-product attention 需要被 scaled.
总结
本博客基于随机变量的期望和方差以及 softmaxsoftmaxsoftmax 函数的性质,详细说明了——为什么 dot-product attention 需要被 scaled.
参考源
- Attention Is All You Need
- Attention? Attention!
推荐资源(Transformer 相关)
- The Illustrated Transformer(概念上)
- The Annotated Transformer(代码实现上)
为什么 dot-product attention 需要被 scaled?相关推荐
- 【源码解读】Transformer的Scaled dot product部分详解
def attention(query, key, value, mask=None, dropout=None):# shape:query=key=value---->[batch_size ...
- CUDA Samples: dot product(使用零拷贝内存)
以下CUDA sample是分别用C++和CUDA实现的点积运算code,CUDA包括普通实现和采用零拷贝内存实现两种,并对其中使用到的CUDA函数进行了解说,code参考了<GPU高性能编程C ...
- FB面经Prepare: Dot Product
Conduct Dot Product of two large Vectors 1. two pointers 2. hashmap 3. 如果没有额外空间,如果一个很大,一个很小,适合scan小的 ...
- 向量点积(Dot Product),向量叉积(Cross Product)
参考的是<游戏和图形学的3D数学入门教程>,非常不错的书,推荐阅读,老外很喜欢把一个东西解释的很详细. 1.向量点积(Dot Product) 向量点积的结果有什么意义?事实上,向量的点积 ...
- dot product【点积】
(1)概念 点积在数学中,又称数量积(dot product; scalar product),是指接受在实数R上的两个向量并返回一个实数值标量的二元运算. 两个向量a = [a1, a2,-, an ...
- 15.计算几何:点积(Dot product)与叉积(Cross product)
向量的内容只是前置知识,现在要讲的点积与叉积才是重点! 向量的基本运算是点积和叉积.计算几何的各种操作,几乎都基于这两种运算 文章目录 1.点积(Dot product) 1.1点积的定义[dot() ...
- CUDA Samples: Dot Product
以下CUDA sample是分别用C++和CUDA实现的两个非常大的向量实现点积操作,并对其中使用到的CUDA函数进行了解说,各个文件内容如下: common.hpp: #ifndef FBC_CUD ...
- 向量点积(Dot Product)
http://www.evernote.com/shard/s146/sh/e0d95bd1-68df-49d9-87c8-e21647d94e18/4d7af393bd986fd0c462ebd13 ...
- Scaled dot-product Attention、Self-Attention辨析
一.Scaled dot-product Attention 有两个序列X.YX.YX.Y:序列XXX提供查询信息QQQ,序列YYY提供键.值信息K.VK.VK.V.Q∈Rx_len×in_dimQ\ ...
最新文章
- python项目开发:ftp server开发
- C语言 解决4996警告
- sjms-3 结构型模式
- CSS:你真的会用 z-index 吗?
- Linux的system和popen的差异
- WIKIOI 1519 过路费
- Educational Codeforces Round 18
- Tomcat的startup.bat启动闪退解决办法
- python接口自动化(三十六)-封装与调用--流程类接口关联续集(详解)
- 测试方法-等价类划分法
- 计算机二级lookup函数,LOOKUP函数用法全解(下)——LOOKUP函数的二分法原理
- 万用表怎么测量电池容量_万用表怎么测量12v电瓶(用万用表测电瓶电量怎么测?)...
- BTA 常问的 Java基础39道常见面试题
- C++笔记:奇葩排序之猴子排序、珠排序、面条排序
- cad2017单段线_AutoCAD2017命令总结
- 请查收!顶会AAAI 2020录用论文之自然语言处理篇
- 重卡自动驾驶进入“正规战”
- 重装系统后安装的软件
- NLP系列(4)_朴素贝叶斯实战与进阶
- 使用Assimp库读取mtl文件数据
热门文章
- 视频教程-HTML5+CSS3项目实战详解-HTML5/CSS
- 大学计算机基础毕业论文答案,计算机本科论文范文
- 如何将MP4转换为MP3?四种简单易行的方法!
- linux系统调整屏幕亮的时间,Linux系统的电脑上调整屏幕亮度的方法
- 余弦知乎living
- 用VB.NET(Visual Basic 2010)封装EXCEL VBA为DLL_COM组件(一)
- 机器学习(聚类四)——K-Means的优化算法
- CleanMyMac X4.12.1苹果电脑系统优化软件更新功能介绍
- 数字信号处理——Python实现快速傅里叶变换FFT
- 国内首例视频聚合App侵权案仅赔1.4万元