pytorch实现attention_Self-Attention手动推导及实现
一、前言
问:transformer模型的众多派生BERT,RoBERTa,ALBERT,SpanBERT,DistilBERT,SesameBERT,SemBERT,SciBERT,BioBERT,MobileBERT,TinyBERT和CamemBERT有什么共同点?
答:Self-attention//Transformer架构
使用Transformer架构对NLP任务建模,避免使用递归神经网络,完全依赖Self-Attention机制绘制输入和输出之间的全局依赖关系。
本文要:
- 探究Self-Attention机制背后的数学原理
- 引导完成Self-Attention模块中涉及的数学计算
- 从头带领编写Self-Attention模块代码(pytorch)
二、自注意力机制(Self-Attention)
一个self-attention模块输入为 n,输出也为 n.那么在这个模块内部发生了什么?用门外汉的术语来说,self-attention机制允许输入彼此之间进行交互(“self”)并找出它们应该更多关注的区域(“Attention”)。输出是这些交互作用和注意力得分的总和。
在self-attention中,每个单词有3个不同的向量,它们分别是Query向量(
slef: 自己和自己计算相似度函数,然后进一步进行关注对吧。
计算过程:
假如我们要翻译一个词组Thinking Machines,其中Thinking的输入的embedding vector用
- 输入单词转化成嵌入向量;
- 根据嵌入向量得到
,,三个向量;
- 为每个向量计算自注意力得分,分数决定当我们在某个位置对单词进行编码时,要在输入句子的其他部分上投入多少注意力:
;
- 为了梯度的稳定,对计算的分数进行 Scale,即除以
,原因是如果点乘结果过大,使得经过 softmax 之后的梯度很小,不利于反向传播
- 对score施以softmax激活函数,归一化;
- softmax乘Value值
(每个单词的value),得到加权的每个输入向量的评分;
- 相加之后得到最终的输出结果
:。
矩阵计算:
实际实现时采用的是基于矩阵的计算方式
三、实例演示
步骤(忽略了二中的第四步):
- 准备输入x
- 初始化 K, Q, V的权重矩阵
- x与K, Q, V相乘得到key, query, value的表示
- 计算x的注意力得分(k, v),即求得每个单词的权重weight
- softmax
- weight分别乘value中的每一行,得到的是对应的加权矩阵
- 矩阵按列相加得到分数矩阵Zx
- 将每个输入的分数按列排列得到最终的输出Z
手动计算过程:
参考:https://mp.weixin.qq.com/s/xLI0yY1hAlOZ1c01SexA1A
四、代码实现(pytorch)
具体步骤同三
(矩阵中点乘和乘法不同,具体见参考3)
参考:
动手推导Self-Attentionmp.weixin.qq.com
自定义:Transformer详解zhuanlan.zhihu.com
[Python] numpy中运算符* @ mutiply dot的用法分析blog.csdn.net
pytorch实现attention_Self-Attention手动推导及实现相关推荐
- 手动推导计算AES中的s盒的输出
手动推导计算AES中的s盒的输出 初衷 为了解决一道密码学课后作业: 在AES中,对于字节 "00" 和 "01" 计算S盒的输出. 百度查了很久,很多都是浅尝 ...
- 基于PyTorch实现Seq2Seq + Attention的英汉Neural Machine Translation
NMT(Neural Machine Translation)基于神经网络的机器翻译模型效果越来越好,还记得大学时代Google翻译效果还是差强人意,近些年来使用NMT后已基本能满足非特殊需求了.目前 ...
- 哈夫曼编码原理与Python实现代码(附手动推导过程原稿真迹)
哈夫曼编码依据字符出现概率来构造异字头(任何一个字符的编码都不是其他字符的前缀)的平均长度最短的码字,通过构造二叉树来实现,出现频次越多的字符编码越短,出现频次越少的字符编码越长.为了演示哈夫曼编码原 ...
- BP神经网络反向传播手动推导
BP神经网络过程: 基本思想 BP算法是一个迭代算法,它的基本思想如下: 将训练集数据输入到神经网络的输入层,经过隐藏层,最后达到输出层并输出结果,这就是前向传播过程. 由于神经网络的输出结果与实际结 ...
- 图片化加手动推导深刻记忆冒泡排序全过程
冒泡排序是把最(大/小)值数往后一直"浮动",直到序列全部浮动完成. 时间复杂度:最好情况是O(n),最坏情况和平均情况是O(n2) 空间复杂度:O(1) #!/usr/bin/e ...
- PyTorch 笔记(13)— autograd(0.4 之前和之后版本差异)、Tensor(张量)、Gradient(梯度)
1. 背景简述 torch.autograd 是 PyTorch 中方便用户使用,专门开发的一套自动求导引擎,它能够根据输入和前向传播过程自动构建计算图,并执行反向传播. 计算图是现代深度学习框架 P ...
- PyTorch 的 Autograd详解
↑ 点击蓝字 关注视学算法 作者丨xiaopl@知乎 来源丨https://zhuanlan.zhihu.com/p/69294347 编辑丨极市平台 PyTorch 作为一个深度学习平台,在深度学习 ...
- 一文详解pytorch的“动态图”与“自动微分”技术
前言 众所周知,Pytorch是一个非常流行且深受好评的深度学习训练框架.这与它的两大特性"动态图"."自动微分"有非常大的关系."动态图" ...
- Pytorch autograd.grad与autograd.backward详解
Pytorch autograd.grad与autograd.backward详解 引言 平时在写 Pytorch 训练脚本时,都是下面这种无脑按步骤走: outputs = model(inputs ...
最新文章
- “1天一朵云”,这是如何做到的?
- linux 时间戳计数器,使用TSC(时间戳计数器)计算时间
- 比特币锚定币总锁仓量触及38亿美元 创历史新高
- 【TSP】基于matlab GUI遗传算法求解旅行商问题【含Matlab源码 899期】
- 一个基本开发框架的整合演化之路--9、整合文件服务器fastdfs
- 用matlab读pcap文件,libpcap读取本地pcap文件
- CMMB 全国各地市频率规划表
- linux定时关闭系统at,『学了就忘』Linux系统管理 — 8.系统定时任务(at命令)
- 李沐动手学深度学习V2-BERT预训练和代码实现
- 腾讯优图实验室贾佳亚:加入优图第一年 | 专访
- 什么是云原生?这回终于有人讲明白了
- iphone11右上角信号显示_原来,iOS 11移动信号图标里竟然隐藏着 iPhone 8的巨大秘密...
- 分子对接(docking):蛋白质-蛋白质分子对接
- SSM框架将数据库数据导出为Excel文件
- 网吧服务器用户锁怎么办,如何进行网吧安全模式的锁定与解除
- 江苏省高中计算机课程标准,江苏省普通高中课程设置
- POI打印-----文件下载
- 【医学图像处理】1 (医学)图像及图像处理流程
- 数据可视化,是如何扭曲我们对现实的感知?
- word文档怎么给数字加千分符_word实用技巧:如何为数字添加千分符的3种方法