之前比较常用的 sequence to sequence 学习方法大多数都利用了 RNN,但是 RNN 需要逐个处理序列数据,速度相对较慢。FaceBook 在 2017 年提出了一种使用 CNN 实现的 Seq2Seq 学习算法,该算法完全使用卷积模型,并在翻译任务中取得比以往更好的效果。

1.前言

RNN 的链式结构使其适合用于序列学习任务,RNN 通过一个隐藏向量保存序列信息。但是 RNN 需要按照顺序逐个处理,不容易并行化,因此通常速度比较慢。

FaceBook 在论文《Convolutional Sequence to Sequence Learning》中提出了一种用卷积实现的 Seq2Seq 模型,并在翻译任务上超过了之前的模型。

与 RNN 不同,CNN 的卷积核通常只能感受一个固定窗口的上下文信息,例如大小为 3 的卷积核,只能生成 3 个时刻的单词对应的特征。给定句子 "常回家看看",大小为 3 的卷积核可以生成 "常回家","回家看","家看看" 对应的特征。但是 CNN 可以通过堆叠多层的卷积核扩大感受野,使模型学习更长的依赖关系。如下图所示,使用两层卷积,则最上层的蓝色节点可以处理长度为 5 的句子。

多层卷积得到更大感受野

CNN 不依赖于前一时刻的输出,能够在整个序列上并行化处理,因此具有更快的速度。

CNN 可以生成句子层级的表示,相隔近的单词之间的关系会在低层级的卷积核处理,而相隔远的单词会在高层级的卷积核处理。这种层级结构使 CNN 可以利用更短的路径获取序列长期的信息。如果句子长度为 n,卷积核大小为 k,则 CNN 可以使用 O(n/k) 个卷积操作得到整个句子信息,而 RNN 需要 O(n) 个操作。

CNN 对于输入的每一个 token 都会经过相同数量的卷积操作,而 RNN 中第一个 token 会经过 n 次 RNN 操作,而最后一个 token 只会经过一次 RNN 操作。对输入采用固定数量的操作也会让模型更容易训练。

2.CNN Seq2Seq 结构

上图是 CNN Seq2Seq 模型,Encoder 和 Decoder 都是由 CNN 组成。上方的部分为 Encoder,下方左侧为 Decoder,下方右侧是 Decoder 进行 Attention 之后预测下一时刻输出。下文用 h 表示 Decoder输出,z 表示 Encoder 输出。

2.1 Position Embedding

RNN 是按照顺序输入的,因此一般不需要加上位置编码。但是 CNN 不能感知每个 token 的顺序,需要加上位置编码。

假设输入序列 x 包括 m 个 token,x = (x1, x2, ..., xm)。序列中每一个 token 的词向量是 w = (w1, w2, ..., wm),每一个位置的编码是 p = (p1, p2, ..., pm),则最终序列的表征 e 如下:

Encoder 和 Decoder 均是采用这样的编码。

2.2 Convolutional Structure

CNN Seq2Seq 模型的 Encoder 和 Decoder 都包含 L 个 Convolutional block (即 L 层卷积)。每一层都是采用多个一维卷积核,模型中的卷积核采用了 Gated Linear Units (GLU) 作为线性变换,即在原本卷积的输出上加上一个门结构控制其输出。接下来介绍 GLU 操作。

  • 每一个卷积核接收的输入X 的维度是 (k×d),k 是卷积核的尺寸,d是输入的 token 表示向量的维度。
  • 每一层包含 2d 个卷积核,则每一层卷积核的参数包括权重矩阵 W (2d×kd) 和偏置 b (2d)。
  • 卷积核处理完输入X (k×d) 后得到的输出是 Y (2d),即 Y 是维度等于 2d 的向量,可以分为两个维度是 d 的向量 A,B。在 A,B 加上门机制就是 GLU,如下公式。
  • σ 表示 sigmoid 函数,将 B 每一位转为 [0, 1] 之间的数,经过 GLU 之后的输出 v([A,B]) 维度与原来的 X 一致,均为 d 维。
  • 在输入序列中加上合适尺寸的 PAD,卷积后的序列长度可以保持和输入长度一致 (注意不是维度 d),这样可以加上残差结构,使模型可以更深。
  • Decoder 利用 i 时刻最后一层卷积层的输出预测 i+1 时刻的输出,如下

2.3 Multi-step Attention

CNN Seq2Seq 模型的 Decoder 采用了 Multi-step Attention,对于 Decoder 的每一层输出都计算 Attention。类似 Transformer,计算 attention 需要 query、key 和 value 向量。

对于第 l 层第 i 时刻的输出 hl(i),我们可以计算出其对应的 query 向量。

而 key 向量是 Encoder 最后一层卷积层 u 的所有时刻输出 zu(1), ..., zu(m),使用下面的公式计算 attention 值。

得到 attention 值之后将 value 向量加权结合在一起,value 向量是 Encoder 的输出 zu 和 Encoder 输入的 token 向量表示 e,如下所示。

将 attention 之后的向量 c 和 Decoder 的输出 h 相加就是最终的输出向量。

3.归一化和初始化

为了保证网络的方差不会大幅度的变化,论文中采用了一些归一化和初始化方法。

归一化

  • 对于残差结构的输入和输出均乘上 sqrt(0.5),即根号 0.5,使总和的方差减少一半。
  • 对于 attention 对 m 个向量加权得到的向量 c,需要乘上 m/sqrt(m)。其中乘上 m 是为了使向量放大回原来的尺寸,除以 sqrt(m) 主要为了抵消方差的变化。
  • 模型的 Decoder 采用了多重 attention 的机制,因此作者根据 attention 的个数对 Encoder 的梯度进行缩放,避免 Encoder 接收过多的梯度。

初始化

  • 所有的 Embedding 均采用均值为 0,标准差为 0.1 的正态分布初始化。
  • 对于不直接传入 GLU 的层,使用 N(0, sqrt(1/n)) 进行初始化,n 表示连接到该神经元的连接数量,这可以保证输入的正态分布方差得到保留。
  • 对于 GLU 激活之后的层,如果 GLU 的输入的分布均值为 0 且方差很小,输出的方差约为输入方差的 1/4。因此,需要初始化 GLU 输入的方差为后续层的 4 倍,即 N(0, sqrt(4/n))。
  • bias 统一初始化为 0。
  • 如果设置了概率为 p 的 dropout,会导致方差缩放 1/p,因此上述的初始化要分别改为 N(0, sqrt(p/n)) 和 N(0, sqrt(4p/n))。

4.参考文献

Convolutional Sequence to Sequence Learning

seq2seq模型_用 CNN 实现 Seq2Seq 模型相关推荐

  1. seq2seq模型_对话生成:seq2seq模型原理及优化

    更多干货内容敬请关注「平安寿险PAI」(公众号ID:PAL-AI),文末有本期分享内容资料获取方式. 人机对话作为人机交互系统的核心功能之一,发挥着十分重要的作用.目前,生成式的人机对话存在内容把控性 ...

  2. lr模型和dnn模型_建立ML或DNN模型的技巧

    lr模型和dnn模型 机器学习 (Machine Learning) Everyone can fit data into any model machine learning or deep lea ...

  3. 网页怎么预先加载模型_使用预先训练的模型进行转移学习

    网页怎么预先加载模型 深度学习 (Deep Learning) 什么是转学? (What is Transfer Learning?) Transfer learning is a research ...

  4. 全面理解java内存模型_深入理解Java内存模型(八)——总结

    处理器内存模型 顺序一致性内存模型是一个理论参考模型,JVM和处理器内存模型在设计时通常会把顺序一致性内存模型作为参照.JVM和处理器内存模型在设计时会对顺序一致性模型做一些放松,因为如果完全按照顺序 ...

  5. 怎么做 空间杜宾模型_面板数据空间杜宾模型

    4.3.1 模型及估计 (1) 无固定效应模型 当SAR和SEM模型在一定的显著性水平下同时成立时,我们需要进一步考虑面板数据空间杜宾模型,即解释变量的空间滞后项影响被解释变量时,就应该考虑建立空间杜 ...

  6. python knn模型_使用Python训练KNN模型并进行分类

    K临近分类算法是数据挖掘中较为简单的一种分类方法,通过计算不同数据点间的距离对数据进行分类,并对新的数据进行分类预测.我们在之前的文章<K邻近(KNN)分类和预测算法的原理及实现>和< ...

  7. fluent p1模型_干货 | ANSYS Fluent燃烧模型简介

    燃烧是一种相当复杂的化学反应,通常还伴随着流体流动.离散相颗粒扩散.传热.污染物产生等多种物理情况.为尽可能详细仿真多种化学反应,ANSYS Fluent提供了多种化学反应模型如EDC,EDM,PDF ...

  8. utxo模型_墨客UTXO和account模型 |技术教程

    来源:雪球App,作者: 一笑奈何君,(https://xueqiu.com/9803210374/134555099) 在当前区块链世界中,主要有两种记录保存方式,UTXO 模式(Unspent T ...

  9. seq2seq模型_推断速度达seq2seq模型的100倍,谷歌开源文本生成新方法LaserTagger

    使用 seq2seq 模型解决文本生成任务伴随着一些重大缺陷,谷歌研究人员提出新型文本生成方法 LaserTagger,旨在解决这些缺陷,提高文本生成的速度和效率. 选自arXiv,作者:Eric M ...

最新文章

  1. Kafka设计解析(三) : Kafka High Availability (下)
  2. 细数那些年我用过的前端开发工具
  3. jQuery LigerUI 使用教程入门篇
  4. python 多窗口编辑
  5. 《Troubleshooting Windows 7 Inside Out》文摘-1
  6. LeetCode—数据库简单题(三)
  7. cl.g4r.win index.php,win7 wamp环境配置Oracle数据库连接
  8. QNX系统上用Berkeley Packet Filter直接进行原始数据的收发
  9. 两向量点乘坐标运算_向量点乘(内积)和叉乘(外积、向量积)概念及几何意义解读...
  10. MyBatisX插件没有出现蓝色鸟
  11. javascript判断文本语言类型
  12. “小米汽车”商标被纺织品公司抢注
  13. 企业公众号运营见效难,如何突围?
  14. thinkphp6 websocket-room的加入房间+离开房间+房间消息发送
  15. 伸缩自如的ElasticSearch——通过bboss操作和访问elasticsearch模式
  16. 【附源码】计算机毕业设计SSM宁夏旅游信息管理系统
  17. 2023年,哪些Web3赛道的表现最值得期待?(文末有奖)
  18. scrollTop兼容性问题
  19. 江西彩礼到底有多高?我问了问身边的朋友们……
  20. EDKII实现bmp图片加载并显示的应用程序

热门文章

  1. 类里面没有参缺省构造函数 的带来的问题
  2. CCNP实验4-2:配置多区域和NBMA OSPF
  3. thymeleaf常用语法
  4. Ubuntu 通过apt安装VSCode
  5. JavaScript多继承(转载)
  6. netty源码解解析(4.0)-5 线程模型-EventExecutorGroup框架
  7. 链表基础操作及其逆置
  8. C++内存管理变革(3):另类内存管理
  9. BCB 连接数据库和查询数据
  10. mysql在centos下用命令批量导入报错_Variable ‘character_set_client‘ can‘t be set to the value of ‘---linux工作笔记042