Transformer详解(附代码)
引言
Transformer\mathrm{Transformer}Transformer模型是Google\mathrm{Google}Google团队在201720172017年666月由AshishVaswani\mathrm{Ashish\text{ }Vaswani}AshishVaswani等人在论文《AttentionIsAllYouNeed\mathrm{Attention\text{ }Is\text{ }All \text{ }You \text{ } Need}AttentionIsAllYouNeed》所提出,当前它已经成为NLP\mathrm{NLP}NLP领域中的首选模型。Transformer\mathrm{Transformer}Transformer抛弃了RNN\mathrm{RNN}RNN的顺序结构,采用了Self\mathrm{Self}Self-Attention\mathrm{Attention}Attention机制,使得模型可以并行化训练,而且能够充分利用训练资料的全局信息,加入Transformer\mathrm{Transformer}Transformer的Seq2seq\mathrm{Seq2seq}Seq2seq模型在NLP\mathrm{NLP}NLP的各个任务上都有了显著的提升。本文做了大量的图示目的是能够更加清晰地讲解Transformer\mathrm{Transformer}Transformer的运行原理,以及相关组件的操作细节,文末还有完整可运行的代码示例。
注意力机制
Transformer\mathrm{Transformer}Transformer中的核心机制就是Self\mathrm{Self}Self-Attention\mathrm{Attention}Attention。Self\mathrm{Self}Self-Attention\mathrm{Attention}Attention机制的本质来自于人类视觉注意力机制。当人视觉在感知东西时候往往会更加关注某个场景中显著性的物体,为了合理利用有限的视觉信息处理资源,人需要选择视觉区域中的特定部分,然后集中关注它。注意力机制主要目的就是对输入进行注意力权重的分配,即决定需要关注输入的哪部分,并对其分配有限的信息处理资源给重要的部分。
Self-Attention
Self\mathrm{Self}Self-Attention\mathrm{Attention}Attention工作原理如上图所示,给定输入wordembedding\mathrm{word\text{ }embedding}wordembedding向量a1,a2,a3∈Rdl×1a^1,a^2,a^3 \in \mathbb{R}^{d_l \times 1}a1,a2,a3∈Rdl×1,然后对于输入向量ai,i∈{1,2,3}a^i,i\in \{1,2,3\}ai,i∈{1,2,3}通过矩阵Wq∈Rdk×dl,Wk∈Rdk×dl,Wv∈Rdl×dlW^q\in \mathbb{R}^{d_k \times d_l},W^k\in \mathbb{R}^{d_k \times d_l},W^v\in \mathbb{R}^{d_l\times d_l}Wq∈Rdk×dl,Wk∈Rdk×dl,Wv∈Rdl×dl进行线性变换得到Query\mathrm{Query}Query向量qi∈Rdk×1q^i\in\mathbb{R}^{d_k \times 1}qi∈Rdk×1,Key\mathrm{Key}Key向量ki∈Rdk×1k^i\in \mathbb{R}^{d_k \times 1}ki∈Rdk×1,以及Value\mathrm{Value}Value向量vi∈Rdl×1v^i\in \mathbb{R}^{d_l \times 1}vi∈Rdl×1,即{qi=Wq⋅aiki=Wk⋅ai,i∈{1,2,3}vi=Wv⋅ai\left\{\begin{aligned}q^i&=W^q \cdot a^i\\k^i&=W^k \cdot a^i,\quad i\in\{1,2,3\}\\v^i&=W^v \cdot a^i\end{aligned}\right.⎩
Multi-Head Attention
Multi\mathrm{Multi}Multi-HeadAttention\mathrm{Head\text{ }Attention}HeadAttention的工作原理与Self\mathrm{Self}Self-Attention\mathrm{Attention}Attention的工作原理非常类似。为了方便图解可视化将Multi\mathrm{Multi}Multi-Head\mathrm{Head}Head设置为222-Head\mathrm{Head}Head,如果Multi\mathrm{Multi}Multi-Head\mathrm{Head}Head设置为888-Head\mathrm{Head}Head,则上图的qi,ki,vi,i∈{1,2,3}q^i,k^i,v^i,i\in\{1,2,3\}qi,ki,vi,i∈{1,2,3}的下一步的分支数为888。给定输入wordembedding\mathrm{word\text{ }embedding}wordembedding向量a1,a2,a3∈Rdl×1a^1,a^2,a^3 \in \mathbb{R}^{d_l \times 1}a1,a2,a3∈Rdl×1,然后对于输入向量ai,i∈{1,2,3}a^i,i\in \{1,2,3\}ai,i∈{1,2,3}通过矩阵Wq∈Rdk×dl,Wk∈Rdk×dl,Wv∈Rdl×dlW^q\in \mathbb{R}^{d_k \times d_l},W^k\in \mathbb{R}^{d_k \times d_l},W^v\in \mathbb{R}^{d_l\times d_l}Wq∈Rdk×dl,Wk∈Rdk×dl,Wv∈Rdl×dl进行第一次线性变换得到Query\mathrm{Query}Query向量qi∈Rdk×1q^i\in\mathbb{R}^{d_k \times 1}qi∈Rdk×1,Key\mathrm{Key}Key向量ki∈Rdk×1k^i \in\mathbb{R}^{d_k \times 1}ki∈Rdk×1,以及Value\mathrm{Value}Value向量vi∈Rdl×1v^i \in\mathbb{R}^{d_l \times 1}vi∈Rdl×1。然后再对Query\mathrm{Query}Query向量qiq^iqi通过矩阵Wq1∈Rdm×dkW^{q1}\in \mathbb{R}^{d_m \times d_k}Wq1∈Rdm×dk和Wq2∈Rdm×dkW^{q2}\in \mathbb{R}^{d_m\times d_k}Wq2∈Rdm×dk进行第二次线性变换得到qi1∈Rdm×1q^{i1}\in \mathbb{R}^{d_m \times 1}qi1∈Rdm×1和qi2∈Rdm×1q^{i2}\in \mathbb{R}^{d_m\times 1}qi2∈Rdm×1,同理对Key\mathrm{Key}Key向量kik^iki通过矩阵Wk1∈Rdm×dkW^{k1}\in \mathbb{R}^{d_m \times d_k}Wk1∈Rdm×dk和Wk2∈Rdm×dkW^{k2}\in \mathbb{R}^{d_m\times d_k}Wk2∈Rdm×dk进行第二次线性变换得到ki1∈Rdm×1k^{i1}\in \mathbb{R}^{d_m\times 1}ki1∈Rdm×1和ki2∈Rdm×1k^{i2}\in \mathbb{R}^{d_m\times 1}ki2∈Rdm×1,对Value\mathrm{Value}Value向量viv^ivi通过矩阵Wv1∈Rdl2×dlW^{v1}\in \mathbb{R}^{\frac{d_l}{2}\times d_l}Wv1∈R2dl×dl和Wv2∈Rdl2×dlW^{v2}\in \mathbb{R}^{\frac{d_l}{2}\times d_l}Wv2∈R2dl×dl进行第二次线性变换得到vi1∈Rdl2×1v^{i1}\in \mathbb{R}^{\frac{d_l}{2}\times 1}vi1∈R2dl×1和vi2∈Rdl2×1v^{i2}\in \mathbb{R}^{\frac{d_l}{2}\times 1}vi2∈R2dl×1,具体的计算公式如下所示:{qih=Wqh⋅Wq⋅aikih=Wkh⋅Wk⋅ai,i={1,2,3},h={1,2}vih=Wvh⋅Wv⋅ai\left\{\begin{aligned}q^{ih}&=W^{qh}\cdot W^{q} \cdot a^i\\ k^{ih}&=W^{kh}\cdot W^{k} \cdot a^i,\quad i=\{1,2,3\},\quad h=\{1,2\}\\v^{ih}&=W^{vh}\cdot W^{v} \cdot a^i\end{aligned}\right.⎩
Mask Self-Attention
如下图左半部分所示,Self\mathrm{Self}Self-Attention\mathrm{Attention}Attention的输出向量bi,i∈{1,2,3,4}b^i, i \in \{1,2,3,4\}bi,i∈{1,2,3,4}综合了输入向量ai,i∈{1,2,3,4}a^i, i \in \{1,2,3,4\}ai,i∈{1,2,3,4}的全部信息,由此可见,Self\mathrm{Self}Self-Attention\mathrm{Attention}Attention在实际编程中支持并行运算。如下图右半部分所示,MaskSelf\mathrm{Mask \text{ } Self}MaskSelf-Attention\mathrm{Attention}Attention的输出向量bib^ibi只利用了已知部分输入的向量aia^iai的信息。例如,b1b1b1只是与a1a^1a1有关;b2b^2b2与a1a^1a1和a2a^2a2有关;b3b^3b3与a1a^1a1,a2a^2a2和a3a^3a3有关;b4b^4b4与a1a^1a1,a2a^2a2,a3a^3a3和a4a^4a4有关。MaskSelf\mathrm{Mask \text{ } Self}MaskSelf-Attention\mathrm{Attention}Attention在Transformer\mathrm{Transformer}Transformer中被用到过两次。
- Transformer\mathrm{Transformer}Transformer的Encoder\mathrm{Encoder}Encoder中如果输入一句话的word\mathrm{word}word长度小于指定的长度,为了能够让长度一致往往会用000进行填充,此时则需要用MaskSelf\mathrm{Mask \text{ } Self}MaskSelf-Attention\mathrm{Attention}Attention来计算注意力分布。
- Transformer\mathrm{Transformer}Transformer的Decoder\mathrm{Decoder}Decoder的输出是有时序关系的,当前的输出只与之前的输入有关,所以此时算注意力分布时需要用到MaskSelf\mathrm{Mask \text{ } Self}MaskSelf-Attention\mathrm{Attention}Attention。
Transformer模型
以上对Transformer\mathrm{Transformer}Transformer中的核心内容即自注意力机制进行了详细解剖,接下来会对Transformer\mathrm{Transformer}Transformer模型架构进行介绍。Transformer\mathrm{Transformer}Transformer模型是由Encoder\mathrm{Encoder}Encoder和Decoder\mathrm{Decoder}Decoder两个模块组成,具体的示意图如下所示,为了能够对Transformer\mathrm{Transformer}Transformer内部的操作细节进行更清晰的展示,下图以矩阵运算的视角对Transformer\mathrm{Transformer}Transformer的原理进行讲解。
Encoder\mathrm{Encoder}Encoder模块操作的具体流程如下所示:
- Encoder\mathrm{Encoder}Encoder的输入由两部分组成分别是词编码矩阵I∈Rn×l×dI \in \mathbb{R}^{n \times l \times d}I∈Rn×l×d和位置编码矩阵P∈Rn×l×dP \in \mathbb{R}^{n \times l \times d}P∈Rn×l×d,其中nnn表示句子数目,lll表示一句话单词的最大数目,ddd表示的是词向量的维度。位置编码矩阵PPP表示的是每个单词在一句里的所有位置信息,因为Self\mathrm{Self}Self-Attention\mathrm{Attention}Attention计算注意力分布的时候只能给出输出向量和输入向量之间的权重关系,但是不能给出词在一句话里的位置信息,所以需要在输入里引入位置编码矩阵PPP。位置编码向量生成方法有很多。一种比较简单粗暴的方式就是根据单词在句子中的位置生成一个one\mathrm{one}one-hot\mathrm{hot}hot的位置编码;还有的方法是将位置编码当成参数进行训练学习;在该论文里是利用三角函数对位置进行编码,具体的公式如下所示PE(pos,2i)=sin(pos10002i/d),PE(pos,2i+1)=cos(pos10002i/d)\mathrm{PE}(pos,2i)=\sin(\frac{pos}{1000^{2i/d}}),\quad \mathrm{PE}(pos,2i+1)=\cos(\frac{pos}{1000^{2i/d}})PE(pos,2i)=sin(10002i/dpos),PE(pos,2i+1)=cos(10002i/dpos)其中PE\mathrm{PE}PE表示的是位置编码向量,pospospos表示词在句子中的位置,iii表示编码向量的位置索引。
- 输入矩阵I+PI+PI+P通过线性变换生成矩阵QQQ,KKK,VVV。在实际编程中是将输入I+PI+PI+P直接赋值给QQQ,KKK,VVV。如果输入单词长度小于最大长度并000来填充的时候,还要相应引入Mask\mathrm{Mask}Mask矩阵。
- 将矩阵QQQ,KKK,VVV输入到Multi\mathrm{Multi}Multi-HeadAttention\mathrm{Head\text{ }Attention}HeadAttention模块中进行注意分布的计算得到矩阵I′∈Rn×l×dI^{\prime}\in \mathbb{R}^{n \times l \times d}I′∈Rn×l×d,计算公式为I′=MultiHead(Q,K,V)I^{\prime}=\mathrm{MultiHead}(Q,K,V)I′=MultiHead(Q,K,V)具体的计算细节参考上文关于Multi\mathrm{Multi}Multi-HeadAttention\mathrm{Head\text{ }Attention}HeadAttention原理的讲解不在这里赘述。然后将原始输入I+PI+PI+P与注意力分布I′I^{\prime}I′进行残差计算得到输出矩阵I+P+I′∈Rn×l×dI+P+I^{\prime}\in \mathbb{R}^{n \times l \times d}I+P+I′∈Rn×l×d。
- 对矩阵I+P+I′={xijk}nldI+P+I^{\prime}=\{x_{ijk}\}^{nld}I+P+I′={xijk}nld进行层归一化操作得到I′′∈Rn×l×dI^{\prime\prime}\in\mathbb{R}^{n \times l \times d}I′′∈Rn×l×d,具体的计算公式为{μij=∑k=1dxijkσij=∑k=1d(xijk−μij)2⟹x^ijk=xijk−uijσij,i∈{1,⋯,n},j∈{1,⋯,l},k∈{1,⋯,d}\left\{\begin{aligned}\mu^{ij}&=\sum\limits_{k=1}^d x_{ijk}\\\sigma^{ij}&=\sqrt{\sum\limits_{k=1}^d\left(x_{ijk}-\mu^{ij}\right)^2}\end{aligned}\right. \Longrightarrow \hat{x}_{ijk}=\frac{x_{ijk}-u^{ij}}{\sigma^{ij}},\quad i\in\{1,\cdots,n\},j\in\{1,\cdots,l\},k\in\{1,\cdots,d\}⎩⎨⎧μijσij=k=1∑dxijk=k=1∑d(xijk−μij)2⟹x^ijk=σijxijk−uij,i∈{1,⋯,n},j∈{1,⋯,l},k∈{1,⋯,d}
- 将I′′I^{\prime\prime}I′′输入到全连接神经网络中得到I′′′∈Rn×l×dI^{\prime\prime\prime}\in \mathbb{R}^{n \times l \times d}I′′′∈Rn×l×d ,然后再让全连接神经网络的输入I′′I^{\prime\prime}I′′与输出I′′′I^{\prime\prime\prime}I′′′进行残差计算得到I′′+I′′′I^{\prime\prime}+I^{\prime\prime\prime}I′′+I′′′,接着对I′′+I′′′I^{\prime\prime}+I^{\prime\prime\prime}I′′+I′′′进行层归一化操作。
- 以上是一个Block\mathrm{Block}Block的操作原理,将NNN个Block\mathrm{Block}Block进行堆叠就组成了Encoder\mathrm{Encoder}Encoder的模块,得到的最后输出为IN∈Rn×l×dI^N \in \mathbb{R}^{n \times l \times d}IN∈Rn×l×d。这里需要注意的是Encoder\mathrm{Encoder}Encoder模块中的各个组件的操作顺序并不是固定的,也可以先进行归一化操作,然后再计算注意力分布,再归一化,再预测等。
Decoder\mathrm{Decoder}Decoder模块操作的具体流程如下所示:
- Decoder\mathrm{Decoder}Decoder的输入也由两部分组成分别是词编码矩阵O∈Rn1×l1×dO \in \mathbb{R}^{n_1 \times l_1 \times d}O∈Rn1×l1×d和位置编码矩阵PO∈Rn1×l1×dP^O \in \mathbb{R}^{n_1 \times l_1 \times d}PO∈Rn1×l1×d。因为Decoder\mathrm{Decoder}Decoder的输入是具有时顺序关系的(即上一步的输出为当前步输入)所以还需要输入Mask\mathrm{Mask}Mask矩阵MMM以便计算注意力分布。
- 输入矩阵O+POO+P^OO+PO通过线性变换生成矩阵Q^\hat{Q}Q^,K^\hat{K}K^,V^\hat{V}V^。在实际编程中是将输入O+POO+P^OO+PO直接赋值给Q^\hat{Q}Q^,K^\hat{K}K^,V^\hat{V}V^。如果输入单词长度小于最大长度并000来填充的时候,还要相应引入Mask\mathrm{Mask}Mask矩阵。
- 将矩阵Q^\hat{Q}Q^,K^\hat{K}K^,V^\hat{V}V^以及Mask\mathrm{Mask}Mask矩阵MMM输入到MaskMulti\mathrm{Mask\text{ }Multi}MaskMulti-HeadAttention\mathrm{Head\text{ }Attention}HeadAttention模块中进行注意分布的计算得到矩阵O′∈Rn1×l1×dO^{\prime}\in \mathbb{R}^{n_1 \times l_1 \times d}O′∈Rn1×l1×d,计算公式为O′=MaskMultiHead(Q^,K^,V^,M)O^{\prime}=\mathrm{MaskMultiHead}(\hat{Q},\hat{K},\hat{V},M)O′=MaskMultiHead(Q^,K^,V^,M)具体的计算细节参考上文关于MaskSelf\mathrm{Mask \text{ }Self}MaskSelf-Attention\mathrm{Attention}Attention的讲解不在这里赘述。然后将原始输入O+POO+P^OO+PO与注意力分布O′O^{\prime}O′进行残差计算得到输出矩阵O+PO+O′∈Rn1×l1×dO+P^O+O^{\prime}\in \mathbb{R}^{n_1 \times l_1 \times d}O+PO+O′∈Rn1×l1×d。接着再对矩阵O+PO+O′O+P^O+O^{\prime}O+PO+O′进行层归一化操作得到O′′∈Rn1×l1×dO^{\prime\prime}\in\mathbb{R}^{n_1 \times l_1 \times d}O′′∈Rn1×l1×d。
- Encoder\mathrm{Encoder}Encoder的输出INI^NIN通过线性变换得到QNQ^NQN和KNK^NKN,O′O^{\prime}O′进行线性变换得到V^′\hat{V}^{\prime}V^′,利用矩阵QNQ^NQN和KNK^NKN和V^′\hat{V}^{\prime}V^′进行交叉注意力分布的计算得到O′′′O^{\prime\prime\prime}O′′′,计算公式为O′′′=MultiHead(QN,KN,V^′)O^{\prime\prime\prime}=\mathrm{MultiHead}(Q^N,K^N,\hat{V}^{\prime})O′′′=MultiHead(QN,KN,V^′)这里的交叉注意力分布综合Encoder\mathrm{Encoder}Encoder输出结果和Decoder\mathrm{Decoder}Decoder中间结果的信息。实际编程编程中将INI^NIN直接赋值给Q^\hat{Q}Q^和K^\hat{K}K^,O′O^{\prime}O′直接赋值给V^′\hat{V}^{\prime}V^′。然后将O′′O^{\prime\prime}O′′与注意力分布O′′′O^{\prime\prime\prime}O′′′进行残差计算得到输出矩阵O′′+O′′′O^{\prime\prime}+O^{\prime\prime\prime}O′′+O′′′。
- 接着对O′′+O′′′O^{\prime\prime}+O^{\prime\prime\prime}O′′+O′′′进行层归一操作得到O′′′′O^{\prime\prime\prime\prime}O′′′′,再将O′′′′O^{\prime\prime\prime\prime}O′′′′输入到全连接神经网络中得到O′′′′′O^{\prime\prime\prime\prime\prime}O′′′′′,接着再做一步残差操作得到O′′′′+O′′′′′O^{\prime\prime\prime\prime}+O^{\prime\prime\prime\prime\prime}O′′′′+O′′′′′,最后再进行一层归一化操作。
- 以上是一个Block\mathrm{Block}Block的操作原理,将NNN个Block\mathrm{Block}Block进行堆叠就组成了Decoder\mathrm{Decoder}Decoder的模块,得到的输出为ON∈Rn1×l1×dO^N \in \mathbb{R}^{n_1 \times l_1 \times d}ON∈Rn1×l1×d。然后在词汇字典中找到当前预测最大概率的单词,并将该单词词向量作为下一阶段的输入,重复以上步骤,直到输出“end\mathrm{end}end”字符为止。
代码示例
Transformer\mathrm{Transformer}Transformer具体的代码示例如下所示为一个国外博主视频里的代码,并根据上文对代码的一些细节进行了探讨。根据上文中Multi\mathrm{Multi}Multi-HeadAttention\mathrm{Head\text{ }Attention}HeadAttention原理示例图可知,严格来看Multi\mathrm{Multi}Multi-HeadAttention\mathrm{Head\text{ }Attention}HeadAttention在求注意分布的时候中间其实是有两步线性变换。给定输入向量x∈R256×1x\in \mathbb{R}^{256\times 1}x∈R256×1 第一步线性变换直接让向量xxx赋值给qqq,kkk,vvv,这一过程以下程序中有所体现,在这里并不会产生歧义。第二步线性变换产生多Head\mathrm{Head}Head,假设Head=8\mathrm{Head}=8Head=8的时候,按理说qqq要与888个矩阵Wq1,⋯,Wq8W^{q1},\cdots,W^{q8}Wq1,⋯,Wq8进行线性变换得到888个q1,⋯,q8q^{1},\cdots,q^{8}q1,⋯,q8,同理kkk要与888个矩阵Wk1,⋯,Wk8W^{k1},\cdots,W^{k8}Wk1,⋯,Wk8进行线性变换得到888个k1,⋯,k8k^{1},\cdots,k^{8}k1,⋯,k8,vvv要与888个矩阵Wv1,⋯,Wv8W^{v1},\cdots,W^{v8}Wv1,⋯,Wv8进行线性变换得到888个v1,⋯,v8v^{1},\cdots,v^{8}v1,⋯,v8,如果按照这个方式在程序实现则需要定义24个权重矩阵,非常的麻烦。以下程序中有一个简单的权重定义方法,通过该方法也可以实现以上多Head\mathrm{Head}Head的线性变换,以向量q=(q1,⋯,q256)⊤∈R256×1q = (q_1,\cdots, q_{256})^{\top}\in \mathbb{R}^{256 \times 1}q=(q1,⋯,q256)⊤∈R256×1为例:
- 首先将向量qqq进行截断分成Head=8\mathrm{Head}=8Head=8个向量,即为{q(1)=(E,0,0,0,0,0,0,0)⋅qq(2)=(0,E,0,0,0,0,0,0)⋅qq(3)=(0,0,E,0,0,0,0,0)⋅qq(4)=(0,0,0,E,0,0,0,0)⋅qq(5)=(0,0,0,0,E,0,0,0)⋅qq(6)=(0,0,0,0,0,E,0,0)⋅qq(7)=(0,0,0,0,0,0,E,0)⋅qq(8)=(0,0,0,0,0,0,0,E)⋅q\left\{\begin{aligned}q^{(1)}&=({\bf{E},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})\cdot q\\q^{(2)}&=({\bf{0},\bf{E},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})\cdot q\\q^{(3)}&=({\bf{0},\bf{0},\bf{E},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})\cdot q\\q^{(4)}&=({\bf{0},\bf{0},\bf{0},\bf{E},\bf{0},\bf{0},\bf{0},\bf{0}})\cdot q\\q^{(5)}&=({\bf{0},\bf{0},\bf{0},\bf{0},\bf{E},\bf{0},\bf{0},\bf{0}})\cdot q\\q^{(6)}&=({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{E},\bf{0},\bf{0}})\cdot q\\q^{(7)}&=({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{E},\bf{0}})\cdot q\\q^{(8)}&=({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{E}})\cdot q \end{aligned}\right.⎩⎨⎧q(1)q(2)q(3)q(4)q(5)q(6)q(7)q(8)=(E,0,0,0,0,0,0,0)⋅q=(0,E,0,0,0,0,0,0)⋅q=(0,0,E,0,0,0,0,0)⋅q=(0,0,0,E,0,0,0,0)⋅q=(0,0,0,0,E,0,0,0)⋅q=(0,0,0,0,0,E,0,0)⋅q=(0,0,0,0,0,0,E,0)⋅q=(0,0,0,0,0,0,0,E)⋅q其中q(i)∈R32×1q^{(i)}\in \mathbb{R}^{32\times 1}q(i)∈R32×1是qqq的第iii个截断向量,E∈R32×32{\bf{E}}\in \mathbb{R}^{32 \times 32}E∈R32×32是单位矩阵,0∈R32×32{\bf{0}}\in \mathbb{R}^{32 \times 32}0∈R32×32是零矩阵。
- 然后对q(i),i∈{1,⋯,8}q^{(i)},i\in \{1,\cdots,8\}q(i),i∈{1,⋯,8}用相同的权重矩阵W∈R32×32W \in \mathbb{R}^{32 \times 32}W∈R32×32进行线性变换,此时可以发现,训练过程的时候只需要更新权重矩阵WWW即可,而且可以进行多Head\mathrm{Head}Head线性变换,888个权重矩阵可以表示为:{Wq1=W⋅(E,0,0,0,0,0,0,0)=(W,0,0,0,0,0,0,0)Wq2=W⋅(0,E,0,0,0,0,0,0)=(0,W,0,0,0,0,0,0)Wq3=W⋅(0,0,E,0,0,0,0,0)=(0,0,W,0,0,0,0,0)Wq4=W⋅(0,0,0,E,0,0,0,0)=(0,0,0,W,0,0,0,0)Wq5=W⋅(0,0,0,0,E,0,0,0)=(0,0,0,0,W,0,0,0)Wq6=W⋅(0,0,0,0,0,E,0,0)=(0,0,0,0,0,W,0,0)Wq7=W⋅(0,0,0,0,0,0,E,0)=(0,0,0,0,0,0,W,0)Wq8=W⋅(0,0,0,0,0,0,0,E)=(0,0,0,0,0,0,0,W)\left\{\begin{aligned}W^{q1}&=W\cdot ({\bf{E},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})=(W,{\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})\\W^{q2}&=W\cdot ({\bf{0},\bf{E},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})=({\bf{0},}W{,\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})\\W^{q3}&=W\cdot ({\bf{0},\bf{0},\bf{E},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})=({\bf{0},\bf{0},}W{,\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})\\W^{q4}&=W\cdot ({\bf{0},\bf{0},\bf{0},\bf{E},\bf{0},\bf{0},\bf{0},\bf{0}})=({\bf{0},\bf{0},\bf{0},}W{,\bf{0},\bf{0},\bf{0},\bf{0}})\\W^{q5}&=W\cdot ({\bf{0},\bf{0},\bf{0},\bf{0},\bf{E},\bf{0},\bf{0},\bf{0}})=({\bf{0},\bf{0},\bf{0},\bf{0},}W{,\bf{0},\bf{0},\bf{0}})\\W^{q6}&=W\cdot ({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{E},\bf{0},\bf{0}})=({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},}W{,\bf{0},\bf{0}})\\W^{q7}&=W\cdot ({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{E},\bf{0}})=({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},}W{,\bf{0}})\\W^{q8}&=W\cdot ({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{E}})=({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},}W{})\end{aligned}\right.⎩⎨⎧Wq1Wq2Wq3Wq4Wq5Wq6Wq7Wq8=W⋅(E,0,0,0,0,0,0,0)=(W,0,0,0,0,0,0,0)=W⋅(0,E,0,0,0,0,0,0)=(0,W,0,0,0,0,0,0)=W⋅(0,0,E,0,0,0,0,0)=(0,0,W,0,0,0,0,0)=W⋅(0,0,0,E,0,0,0,0)=(0,0,0,W,0,0,0,0)=W⋅(0,0,0,0,E,0,0,0)=(0,0,0,0,W,0,0,0)=W⋅(0,0,0,0,0,E,0,0)=(0,0,0,0,0,W,0,0)=W⋅(0,0,0,0,0,0,E,0)=(0,0,0,0,0,0,W,0)=W⋅(0,0,0,0,0,0,0,E)=(0,0,0,0,0,0,0,W)其中权重矩阵Wqi∈R32×256,i∈{1,⋯,8}W^{qi}\in\mathbb{R}^{32 \times 256},i\in\{1,\cdots,8\}Wqi∈R32×256,i∈{1,⋯,8}。
import torch
import torch.nn as nn
import osclass SelfAttention(nn.Module):def __init__(self, embed_size, heads):super(SelfAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads * self.head_dim, embed_size)def forward(self, values, keys, query, mask):N =query.shape[0]value_len , key_len , query_len = values.shape[1], keys.shape[1], query.shape[1]# split embedding into self.heads piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)# queries shape: (N, query_len, heads, heads_dim)# keys shape : (N, key_len, heads, heads_dim)# energy shape: (N, heads, query_len, key_len)if mask is not None:energy = energy.masked_fill(mask == 0, float("-1e20"))attention = torch.softmax(energy/ (self.embed_size ** (1/2)), dim=3)out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)# attention shape: (N, heads, query_len, key_len)# values shape: (N, value_len, heads, heads_dim)# (N, query_len, heads, head_dim)out = self.fc_out(out)return outclass TransformerBlock(nn.Module):def __init__(self, embed_size, heads, dropout, forward_expansion):super(TransformerBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm1 = nn.LayerNorm(embed_size)self.norm2 = nn.LayerNorm(embed_size)self.feed_forward = nn.Sequential(nn.Linear(embed_size, forward_expansion*embed_size),nn.ReLU(),nn.Linear(forward_expansion*embed_size, embed_size))self.dropout = nn.Dropout(dropout)def forward(self, value, key, query, mask):attention = self.attention(value, key, query, mask)x = self.dropout(self.norm1(attention + query))forward = self.feed_forward(x)out = self.dropout(self.norm2(forward + x))return outclass Encoder(nn.Module):def __init__(self,src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length,):super(Encoder, self).__init__()self.embed_size = embed_sizeself.device = deviceself.word_embedding = nn.Embedding(src_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([TransformerBlock(embed_size,heads,dropout=dropout,forward_expansion=forward_expansion,)for _ in range(num_layers)])self.dropout = nn.Dropout(dropout)def forward(self, x, mask):N, seq_length = x.shapepositions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))for layer in self.layers:out = layer(out, out, out, mask)return outclass DecoderBlock(nn.Module):def __init__(self, embed_size, heads, forward_expansion, dropout, device):super(DecoderBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm = nn.LayerNorm(embed_size)self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)self.dropout = nn.Dropout(dropout)def forward(self, x, value, key, src_mask, trg_mask):attention = self.attention(x, x, x, trg_mask)query = self.dropout(self.norm(attention + x))out = self.transformer_block(value, key, query, src_mask)return outclass Decoder(nn.Module):def __init__(self,trg_vocab_size,embed_size,num_layers,heads,forward_expansion,dropout,device,max_length,):super(Decoder, self).__init__()self.device = deviceself.word_embedding = nn.Embedding(trg_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([DecoderBlock(embed_size, heads, forward_expansion, dropout, device)for _ in range(num_layers)])self.fc_out = nn.Linear(embed_size, trg_vocab_size)self.dropout = nn.Dropout(dropout)def forward(self, x ,enc_out , src_mask, trg_mask):N, seq_length = x.shapepositions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))for layer in self.layers:x = layer(x, enc_out, enc_out, src_mask, trg_mask)out =self.fc_out(x)return outclass Transformer(nn.Module):def __init__(self,src_vocab_size,trg_vocab_size,src_pad_idx,trg_pad_idx,embed_size = 256,num_layers = 6,forward_expansion = 4,heads = 8,dropout = 0,device="cuda",max_length=100):super(Transformer, self).__init__()self.encoder = Encoder(src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length)self.decoder = Decoder(trg_vocab_size,embed_size,num_layers,heads,forward_expansion,dropout,device,max_length)self.src_pad_idx = src_pad_idxself.trg_pad_idx = trg_pad_idxself.device = devicedef make_src_mask(self, src):src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)# (N, 1, 1, src_len)return src_mask.to(self.device)def make_trg_mask(self, trg):N, trg_len = trg.shapetrg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len)return trg_mask.to(self.device)def forward(self, src, trg):src_mask = self.make_src_mask(src)trg_mask = self.make_trg_mask(trg)enc_src = self.encoder(src, src_mask)out = self.decoder(trg, enc_src, src_mask, trg_mask)return outif __name__ == '__main__':device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(device)x = torch.tensor([[1,5,6,4,3,9,5,2,0],[1,8,7,3,4,5,6,7,2]]).to(device)trg = torch.tensor([[1,7,4,3,5,9,2,0],[1,5,6,2,4,7,6,2]]).to(device)src_pad_idx = 0trg_pad_idx = 0src_vocab_size = 10trg_vocab_size = 10model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)out = model(x, trg[:, : -1])print(out.shape)
Transformer详解(附代码)相关推荐
- 动态规划---01背包问题--Dp(详解附代码)
一.动态规划 代表一类问题(最优子结构或子问题最优性)的一般解法,是设计方法或者策略,不是具体算法 本质:递推,核心是找到状态转移的方式,写出dp方程. 解决问题:交叉,重叠子问题(最优子问题) 形式 ...
- 【排序】堆排序详解 附代码
按照国际惯例,开篇前先简单介绍(吹一波)堆排序(Heapsort).Heapsort是一种优秀的排序算法(个人感觉基本排序算法中仅次于快速排序),时间复杂度为O(nlgn),同时,Heapsort具有 ...
- 各种进制转换(二,八,十,十六进制间转换)详解附代码
进制转换 原理 进制转换是人们利用符号来计数的方法.进制转换由一组数码符号和两个基本因素"基数"与"位权"构成. 基数是指,进位计数制中所采用的数码(数制中用来 ...
- 前序遍历、中序遍历、后序遍历层序遍历详解附代码(数据结构C语言)
目录 (1)前序遍历 (DLR) 递归算法 (2)中序遍历 (LDR) 递归算法 (3)后序遍历 (LRD) 递归算法 (4)层序遍历 队列实现方法 层序遍历的定义: 实现方法: 代码实现 结果截图 ...
- Numpy学习笔记(二):argmax参数中axis=0,axis=1,axis=-1详解附代码
文章目录 1.argmax和max函数区别 2.axis=0/axis=1/axis=-1的区别 3.具体代码分析 ---3.1一维数组 ---3.2二维数组 ---3.3三维数组 1.argmax和 ...
- 随机分布嵌入(RDE)框架详解附代码
介绍 研究了好一阵子马欢飞老师在PNAS上发的文章,下面附上个人的研究心得与代码与大家讨论. 在基于非线性系统的理论基础上,延迟嵌入理论以及广义嵌入理论等相空间重构的理论基础上,观察者便有可能从一个观 ...
- c++实现贪吃蛇详解(附代码)
文章目录 前言 一.运行界面 二.类的大致抽象 三.关于一些问题的思考 四.最后一些想说的 五.代码 前言 经过一个多月的学习,又加深了对c++的理解,所以接下来,就和大家分享一下,一个月学习c++的 ...
- 【大道至简】机器学习算法之EM算法(Expectation Maximization Algorithm)详解(附代码)---通俗理解EM算法。
☕️ 本文来自专栏:大道至简之机器学习系列专栏
- 目标检测模型的评估指标mAP详解(附代码)
https://zhuanlan.zhihu.com/p/37910324 对于使用机器学习解决的大多数常见问题,通常有多种可用的模型.每个模型都有自己的独特之处,并随因素变化而表现不同. 每个模型在 ...
- Transformer 详解(上) — 编码器【附pytorch代码实现】
Transformer 详解(上)编码器 Transformer结构 文本嵌入层 位置编码 注意力机制 编码器之多头注意力机制层 编码器之前馈全连接层 规范化层和残差连接 代码实现Transforme ...
最新文章
- swift笔记——环境搭建及Hello,Swift!
- 背景图片等比缩放的写法background-size简写法
- (20)PDE_PTE属性(U/S PS A D 有效位)
- Java学习小程序(10)三个等级的才字母游戏
- c++面向对象高级编程 学习七 转换函数
- python后台框架_我的第一个python web开发框架(14)——后台管理系统登录功能
- Dubbo面试 - Dubbo通信协议
- display: inline-block;水平居中
- ElasticSearch核心基础之索引管理
- php置顶文章,zblogphp不同情况置顶文章调用方法
- [转载] AUML——FIPA Modeling Technical Committee
- MATLAB读取图片时报错:“错误使用 fopen 找不到文件,确保文件存在且路径” 的原因及解决方法
- Jenkins Pipeline 手记(1)—— 什么是CPS编程
- virtualenvs error: deactivate must be sourced. Run 'source deactivate' instead of 'deactivate'
- 块存储、文件存储、对象存储三者的区别
- 《UNIX/LINUX系统管理I》课程学习总结
- Python视频制作 MoviePy框架afx音频效果示例
- 【Python打卡2019】20190406之货币兑换
- 阿里云免费SSL证书配置(图文详解)
- 动画效果html5,HTML5动画效果
热门文章
- 王文彬(淘宝网首席架构师)等关注探讨的问题
- 二十九、进阶之项目数据请求
- 计算机信息安全攻防大赛,2018年度信息安全攻防大赛圆满收官
- 杨致远:雅虎的华裔酋长(附图)
- PL/SQL编程基础(五):异常处理(EXCEPTION)
- PHP+mysql 入门级通讯录(一)
- JavaScript shells
- java并发编程源码世界大师_求咕泡学院Java架构师第三期的完整版资料源码+视频,注(完整无解压密码)...
- 03 | 论文中的「文献综述」应该怎么写?
- Android开发以来所记载最全的有关项目的网址