一、前言

问:transformer模型的众多派生BERT,RoBERTa,ALBERT,SpanBERT,DistilBERT,SesameBERT,SemBERT,SciBERT,BioBERT,MobileBERT,TinyBERT和CamemBERT有什么共同点?

答:Self-attention//Transformer架构

使用Transformer架构对NLP任务建模,避免使用递归神经网络,完全依赖Self-Attention机制绘制输入和输出之间的全局依赖关系。

本文要:

  1. 探究Self-Attention机制背后的数学原理
  2. 引导完成Self-Attention模块中涉及的数学计算
  3. 从头带领编写Self-Attention模块代码(pytorch)

二、自注意力机制(Self-Attention)

一个self-attention模块输入为 n,输出也为 n.那么在这个模块内部发生了什么?用门外汉的术语来说,self-attention机制允许输入彼此之间进行交互(“self”)并找出它们应该更多关注的区域(“Attention”)。输出是这些交互作用和注意力得分的总和。

v2-058a9fc42c50efa1a62a2fe124fb8dca_b.jpg

在self-attention中,每个单词有3个不同的向量,它们分别是Query向量(

),Key向量(
)和Value向量(
),长度均是64。它们是通过3个不同的权值矩阵由嵌入向量
乘以三个不同的权值矩阵
得到,其中三个矩阵的尺寸也是相同的。均是
。(Transformer中使用的词嵌入的维度为
。)

slef: 自己和自己计算相似度函数,然后进一步进行关注对吧。

计算过程:

假如我们要翻译一个词组Thinking Machines,其中Thinking的输入的embedding vector用

表示,Machines的embedding vector用
表示,以
举例,整个过程可以分成7步:
  1. 输入单词转化成嵌入向量;
  2. 根据嵌入向量得到
    三个向量;
  3. 为每个向量计算自注意力得分,分数决定当我们在某个位置对单词进行编码时,要在输入句子的其他部分上投入多少注意力:
  4. 为了梯度的稳定,对计算的分数进行 Scale,即除以
    ,原因是如果点乘结果过大,使得经过 softmax 之后的梯度很小,不利于反向传播
  5. 对score施以softmax激活函数,归一化
  6. softmax乘Value值
    (每个单词的value),得到加权的每个输入向量的评分
  7. 相加之后得到最终的输出结果
v2-22f4c75f7f79b9204640e42ca705a0d4_b.jpg
q, k, v 的自我理解:q:当前词作为搜索词,此时对应query_vector k: 句子中的所有词(包括query,去和query匹配搜寻相关度,此时对应key_vector。故qk相乘可以决定在句子每个单词上投入多少注意力。v: 句子中每个词自身的价值value,将求得的注意力得分与v相乘得到最终每个单词的得分。

矩阵计算:

实际实现时采用的是基于矩阵的计算方式

v2-41b14c93fed52b579e775ba6d7483af7_b.jpg

三、实例演示

步骤(忽略了二中的第四步):

  1. 准备输入x
  2. 初始化 K, Q, V的权重矩阵
  3. x与K, Q, V相乘得到key, query, value的表示
  4. 计算x的注意力得分(k, v),即求得每个单词的权重weight
  5. softmax
  6. weight分别乘value中的每一行,得到的是对应的加权矩阵
  7. 矩阵按列相加得到分数矩阵Zx
  8. 将每个输入的分数按列排列得到最终的输出Z

手动计算过程:

参考:https://mp.weixin.qq.com/s/xLI0yY1hAlOZ1c01SexA1A

四、代码实现(pytorch)

具体步骤同三

(矩阵中点乘和乘法不同,具体见参考3)


参考:

动手推导Self-Attention​mp.weixin.qq.com

v2-69bac2b7865fadb993cc3afffde8153c_180x120.jpg

自定义:Transformer详解​zhuanlan.zhihu.com

[Python] numpy中运算符* @ mutiply dot的用法分析​blog.csdn.net

v2-2a5027b5bff83f50a189c6146b4f7548_ipico.jpg

pytorch实现attention_Self-Attention手动推导及实现相关推荐

  1. 手动推导计算AES中的s盒的输出

    手动推导计算AES中的s盒的输出 初衷 为了解决一道密码学课后作业: 在AES中,对于字节 "00" 和 "01" 计算S盒的输出. 百度查了很久,很多都是浅尝 ...

  2. 基于PyTorch实现Seq2Seq + Attention的英汉Neural Machine Translation

    NMT(Neural Machine Translation)基于神经网络的机器翻译模型效果越来越好,还记得大学时代Google翻译效果还是差强人意,近些年来使用NMT后已基本能满足非特殊需求了.目前 ...

  3. 哈夫曼编码原理与Python实现代码(附手动推导过程原稿真迹)

    哈夫曼编码依据字符出现概率来构造异字头(任何一个字符的编码都不是其他字符的前缀)的平均长度最短的码字,通过构造二叉树来实现,出现频次越多的字符编码越短,出现频次越少的字符编码越长.为了演示哈夫曼编码原 ...

  4. BP神经网络反向传播手动推导

    BP神经网络过程: 基本思想 BP算法是一个迭代算法,它的基本思想如下: 将训练集数据输入到神经网络的输入层,经过隐藏层,最后达到输出层并输出结果,这就是前向传播过程. 由于神经网络的输出结果与实际结 ...

  5. 图片化加手动推导深刻记忆冒泡排序全过程

    冒泡排序是把最(大/小)值数往后一直"浮动",直到序列全部浮动完成. 时间复杂度:最好情况是O(n),最坏情况和平均情况是O(n2) 空间复杂度:O(1) #!/usr/bin/e ...

  6. PyTorch 笔记(13)— autograd(0.4 之前和之后版本差异)、Tensor(张量)、Gradient(梯度)

    1. 背景简述 torch.autograd 是 PyTorch 中方便用户使用,专门开发的一套自动求导引擎,它能够根据输入和前向传播过程自动构建计算图,并执行反向传播. 计算图是现代深度学习框架 P ...

  7. PyTorch 的 Autograd详解

    ↑ 点击蓝字 关注视学算法 作者丨xiaopl@知乎 来源丨https://zhuanlan.zhihu.com/p/69294347 编辑丨极市平台 PyTorch 作为一个深度学习平台,在深度学习 ...

  8. 一文详解pytorch的“动态图”与“自动微分”技术

    前言 众所周知,Pytorch是一个非常流行且深受好评的深度学习训练框架.这与它的两大特性"动态图"."自动微分"有非常大的关系."动态图" ...

  9. Pytorch autograd.grad与autograd.backward详解

    Pytorch autograd.grad与autograd.backward详解 引言 平时在写 Pytorch 训练脚本时,都是下面这种无脑按步骤走: outputs = model(inputs ...

最新文章

  1. “1天一朵云”,这是如何做到的?
  2. linux 时间戳计数器,使用TSC(时间戳计数器)计算时间
  3. 比特币锚定币总锁仓量触及38亿美元 创历史新高
  4. 【TSP】基于matlab GUI遗传算法求解旅行商问题【含Matlab源码 899期】
  5. 一个基本开发框架的整合演化之路--9、整合文件服务器fastdfs
  6. 用matlab读pcap文件,libpcap读取本地pcap文件
  7. CMMB 全国各地市频率规划表
  8. linux定时关闭系统at,『学了就忘』Linux系统管理 — 8.系统定时任务(at命令)
  9. 李沐动手学深度学习V2-BERT预训练和代码实现
  10. 腾讯优图实验室贾佳亚:加入优图第一年 | 专访
  11. 什么是云原生?这回终于有人讲明白了
  12. iphone11右上角信号显示_原来,iOS 11移动信号图标里竟然隐藏着 iPhone 8的巨大秘密...
  13. 分子对接(docking):蛋白质-蛋白质分子对接
  14. SSM框架将数据库数据导出为Excel文件
  15. 网吧服务器用户锁怎么办,如何进行网吧安全模式的锁定与解除
  16. 江苏省高中计算机课程标准,江苏省普通高中课程设置
  17. POI打印-----文件下载
  18. 【医学图像处理】1 (医学)图像及图像处理流程
  19. 数据可视化,是如何扭曲我们对现实的感知?
  20. word文档怎么给数字加千分符_word实用技巧:如何为数字添加千分符的3种方法

热门文章

  1. Latex个人常用清单--不断更新
  2. Matlab之程序的暂停与中止
  3. async与defer
  4. win10+anaconda3+python3.6+opencv3.1.0
  5. SpringMVC框架使用注解执行定时任务(转)
  6. vs2005/vs2008 快捷键【转】
  7. .NET中三种获取当前路径的代码
  8. ie下的firebug
  9. echo, print, printf 和 sprintf 区别(PHP)
  10. 圣诞节PPT模板制作技巧分析