注意力机制:

父母在学校门口接送孩子的时候,可以在人群中一眼的发现自己的孩子,这就是一种注意力机制。
为什么父母可以在那么多的孩子中,找到自己的孩子?
比如现在有100个孩子,要被找的孩子发型是平头,个子中等,不戴眼镜,穿着红色上衣,牛仔裤
通过对这些特征,就可以对这100个孩子进行筛选,最后剩下的孩子数量就很少了,就是这些特征的存在,使得父母的注意力会主要放在有这些特征的孩子身上,这就是注意力机制。

注意力机制
Query 被找孩子的特征
Key 100个孩子,通过特征进行筛选,得到这100个孩子的可能性
Value 100个孩子中,找到自己孩子的可能性

attention = softmax(Q、K之间进行计算) * V
Q、K之间的计算方式不同,这就导致了不同的注意力机制。

最后一种就是Transformer中的一种注意力的计算机制。

实际应用中的理解

一般在自然语言处理应用里会把Attention模型看作是输出Target句子中某个单词和输入Source句子每个单词的对齐模型。
目标句子的每个单词 与输入句子中的每个单词 计算权重,计算注意力权重
类似于机器翻译中的短语对齐步骤

可以看到里面的 Q K V
QK之间的计算就是计算QK之间的相关性,或者说特征的相似性
这样就可以得到每个key对应的value的权重系数,然后与V相乘

Lx=||Source||代表Source的长度

计算过程

1.计算QK之前的相似度

2.softmax 归一化
3.对value进行加权求和**

代码实现

第一步:根据注意力计算规则,对Q,K,V进行相应的计算.
第二步:根据第一步采用的计算方法,如果是拼接方法,则需要将Q与第二步的计算结果再进行拼接,如果是转置点积,一般是自注意力,Q与V相同,则不需要进行与Q的拼接.
第三步:最后为了使整个attention机制按照指定尺寸输出,使用线性层作用在第二步的结果上做一个线性变换,得到最终对Q的注意力表示

第一步就是使用第一种计算的方式,获取注意力机制的权重,就是上边所说的孩子的特征占100个孩子权重
第三部就是为了获得指定尺寸的输出

import torch
from torch import nn
import torch.nn.functional as F# input = torch.randn(10, 3, 4)
# mat2 = torch.randn(10, 4, 5)
# res = torch.bmm(input, mat2)
# x = res.size()
# print(x)class Attention(nn.Module):def __init__(self,query_size, key_size, value_size1, value_size2, output_size):super(Attention, self).__init__()self.query_size = query_sizeself.key_size = key_sizeself.value_size1 = value_size1self.value_size2 = value_size2self.output_size = output_size# 第一种方式# 先拼接 然后进行线性变换 然后softmax# 最后乘V# 初始化注意力机制第一步# 两个size相加,是直接把矩阵拼接# 拼接后进行线性变换使用 (64,32)self.attn = nn.Linear(self.query_size + self.key_size, self.value_size1)# 最后乘V后 输出使用# 初始化注意力机制第三步# 线性变换 (96,64)self.attn_combine = nn.Linear(self.query_size + self.value_size2, self.output_size)#     Q K 都是三维数据 维度是相同的 (1,1,32)def forward(self, Q, K, V):# 采用上述第一种计算规则# 先进性QK的拼接以及线性变换,再经过softmax处理获得结果# 这里QKV都是三维张量# (1,32) (1,32) 在维度1 上的cat的维度为(1,64)# (1,64) * (64,32)# 结果为(1,32) 这是第一个线性变换的结果 然后在32这个维度上进行 softmax 最后的维度还是不变的attn_weights = F.softmax(self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)# 然后将结果 与 V相乘 (1,1,32) @ (1,32,64)= (1,1,64)attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)# 第二步,将Q与第一步的结果再进行拼接# (1,32) (1,64) 在第1个维度进行拼接 结果为(1,96)output = torch.cat((Q[0], attn_applied[0]), 1)# (1,96) * (96,64)=(1,64)# 经过unsqueeze (1,1,64)# 第三步,得到输出output = self.attn_combine(output).unsqueeze(0)return output, attn_weights# (1, 1, 64) (1,32)query_size = 32
key_size = 32
value_size1 = 32
value_size2 = 64
output_size = 64
attn = Attention(query_size, key_size, value_size1, value_size2, output_size)
Q = torch.randn(1, 1, query_size)
print("---")
K = torch.randn(1, 1, key_size)
print(torch.cat((Q[0], K[0]), 1).shape)
V = torch.randn(1, value_size1, value_size2)
out = attn(Q, K, V)

pytorch 注意力机制相关推荐

  1. 64注意力机制 10章

    自主性提示:  我想....看想看的东西 非自主提示:随意一看     .看环境中现眼的东西      环境的东西就是键和值  我想要是查询 #@save def show_heatmaps(matr ...

  2. 各种注意力机制PyTorch实现

    给出了整个系列的PyTorch的代码实现,以及使用方法. 各种注意力机制 Pytorch implementation of "Beyond Self-attention: External ...

  3. 机器翻译注意力机制及其PyTorch实现

    前面阐述注意力理论知识,后面简单描述PyTorch利用注意力实现机器翻译 Effective Approaches to Attention-based Neural Machine Translat ...

  4. 神经网络中的注意力机制总结及PyTorch实战

    技术交流 QQ 群:1027579432,欢迎你的加入! 欢迎关注我的微信公众号:CurryCoder的程序人生 0.概述 当神经网络来处理大量的输入信息时,也可以借助人脑的注意力机制,只选择一些关键 ...

  5. 【深度学习】基于Pytorch多层感知机的高级API实现和注意力机制(三)

    [深度学习]基于Pytorch多层感知机的高级API实现和注意力机制(三) 文章目录 [深度学习]基于Pytorch多层感知机的高级API实现和注意力机制(三) 1 权重衰减 1.1 范数 1.2 L ...

  6. 【深度学习】基于Pytorch多层感知机的高级API实现和注意力机制(二)

    [深度学习]基于Pytorch多层感知机的高级API实现和注意力机制(二) 文章目录1 代码实现 2 训练误差和泛化误差 3 模型复杂性 4 多项式回归4.1 生成数据集4.2 对模型进行训练和测试4 ...

  7. 【深度学习】基于Pytorch多层感知机的高级API实现和注意力机制(一)

    [深度学习]基于Pytorch多层感知机的高级API实现和注意力机制(一) 文章目录 1 概述 2 从线性到非线性-激活函数2.1 ReLU函数2.2 sigmoid函数2.3 tanh函数 3 注意 ...

  8. keras cnn注意力机制_TensorFlow、PyTorch、Keras:NLP框架哪家强

    全文共3412字,预计学习时长7分钟 在对TensorFlow.PyTorch和Keras做功能对比之前,先来了解一些它们各自的非竞争性柔性特点吧. 非竞争性特点 下文介绍了TensorFlow.Py ...

  9. Attention 扫盲:注意力机制及其 PyTorch 应用实现

    点击上方"MLNLP",选择"星标"公众号 重磅干货,第一时间送达 来自 | 知乎 作者 | Lucas 地址 | https://zhuanlan.zhihu ...

  10. 【Pytorch神经网络理论篇】 20 神经网络中的注意力机制

    注意力机制可以使神经网络忽略不重要的特征向量,而重点计算有用的特征向量.在抛去无用特征对拟合结果于扰的同时,又提升了运算速度. 1 注意力机制 所谓Attention机制,便是聚焦于局部信息的机制,比 ...

最新文章

  1. Hugging Face官方NLP课程来了!Transformers库维护者之一授课,完全免费
  2. python | gtts 将文字转化为语音内容
  3. 【杂谈】野生在左 科班在右——数据结构学习誓师贴
  4. boost::intrusive::treap_set用法的测试程序
  5. IXDC 2018 | 打动人心的互联网保险设计
  6. 错误之Only one usage of each socket address (protocol/network address/port)解决办法
  7. 循环序列模型 —— 1.3循环神经网络
  8. C雨涵课后习题(18)
  9. (clion 安装插件联网络失败,pycharm pip联网失败)当电脑选择拨号上网时,解决系统代理被篡改/pip提示“目标计算机积极拒绝,无法连接”的方法! [ 此方法绝对解决系统代理被篡改问题 ]
  10. 十一月份英语学习总结—积累
  11. 红linux系统,红帽子linux系统
  12. windows下解压tar.gz文件
  13. emui 4.1 基于android 6.0,【荣耀V8评测】基于Android 6.0的EMUI 4.1_荣耀 V8_手机评测-中关村在线...
  14. mui下拉刷新 ,无法滑动
  15. 数据库同步有哪些方式?【怎么保障目标和源数据一致性】
  16. 42岁大龄程序员的看法
  17. 域适应(domain adaptation)
  18. httpclient设置代理
  19. java8新特性-stream学习
  20. dnn降噪_EdiCall通话降噪黑科技-漫步者蓝牙耳机技术有多强?

热门文章

  1. 深圳名校最新出炉 学校学区房房价飙升-查查吧深圳学区房地图
  2. 新手练字又快又好的方法
  3. python可以用于工业机器人编程与操作_非常实用的工业机器人编程语言有哪些?这些编程好用吗?...
  4. 计算机桌面组成部分教案,三年级第6课 《认识桌面》优秀教案
  5. sql server windows nt 64bit 占内存解决方法
  6. linux编译 __stdcall,Linux下的stdcall 约定格式
  7. Windows和Mac下的_stdcall
  8. 电容的耐压值选择---陶瓷电容、钽电容、电解电容
  9. LabVIEW控制Arduino实现红外测距(进阶篇—6)
  10. ultraISO方式制作win10安装U盘