1 seq2seq模型简介

seq2seq 模型是一种基于【 Encoder-Decoder】(编码器-解码器)框架的神经网络模型,广泛应用于自然语言翻译、人机对话等领域。目前,【seq2seq+attention】(注意力机制)已被学者拓展到各个领域。seq2seq于2014年被提出,注意力机制于2015年被提出,两者于2017年进入疯狂融合和拓展阶段。

1.1 seq2seq原理

通常,编码器和解码器可以是一层或多层 RNN、LSTM、GRU 等神经网络。为方便讲述原理,本文以 RNN 为例。seq2seq模型的输入和输出长度可以不一样。如图,Encoder 通过编码输入序列获得语义编码 C,Decoder 通过解码 C 获得输出序列。

seq2seq网络结构图

 Encoder

Decoder

说明:xi、hi、C、h'i 都是列向量

1.2 seq2seq+attention原理

普通的 seq2seq 模型中,Decoder 每步的输入都是相同的语义编码 C,没有针对性的学习,导致解码效果不佳。添加注意力机制后,使得每步输入的语义编码不一样,捕获的信息更有针对性,解码效果更佳。

seq2seq+attention网络结构图

Encoder

Decoder

(1)标准 attention

其中 ,v、W、U 都是待学习参数,v 为列向量,W、U 为矩阵

(2)attention 扩展

扩展的 attention 机制有3种方法,如下。其中,v、W 都是待学习参数,v 为列向量,W为矩阵。相较于标准的 attention,待学习的参数明显减少了些。

说明:xi、hi、Ci、h'i、wi 、ei 都是列向量,h 是矩阵

2 安装seq2seq

  • 下载【https://github.com/farizrahman4u/recurrentshop】,解压,通过cmd进入文件,输入 python setup.py install

  • 下载【https://github.com/farizrahman4u/seq2seq】,解压,通过cmd进入文件,输入 python setup.py install

  • 重启编译器

若下载比较慢,可以先通过【码云】导入,再在码云上下载,如下:

本文以MNIST手写数字分类为例,讲解 seq2seq 模型和 AtttionSeq2seq 模型的实现。关于MNIST数据集的说明,见使用TensorFlow实现MNIST数据集分类。

笔者工作空间如下:

代码资源见-->seq2seq模型和基于注意力机制的seq2seq模型

3 SimpleSeq2Seq

SimpleSeq2Seq(input_length, input_dim, hidden_dim, output_length, output_dim, depth=1)
  • input_length:输入序列长度
  • input_dim:输入序列维度
  • output_length:输出序列长度
  • output_dim:输出序列维度
  • depth:Encoder 和 Decoder 的深度,取值可以为整数或元组。如 depth=3,表示 Encoder 和 Decoder 都有 3 层;depth=(3, 4) 表示 Encoder 有3层和 Decoder 有4层

SimpleSeq2Seq.py

from tensorflow.examples.tutorials.mnist import input_data
from seq2seq.models import SimpleSeq2Seq
from keras.models import Sequential
from keras.layers import Dense,Flatten#载入数据
def read_data(path):mnist=input_data.read_data_sets(path,one_hot=True)train_x,train_y=mnist.train.images.reshape(-1,28,28),mnist.train.labels,valid_x,valid_y=mnist.validation.images.reshape(-1,28,28),mnist.validation.labels,test_x,test_y=mnist.test.images.reshape(-1,28,28),mnist.test.labelsreturn train_x,train_y,valid_x,valid_y,test_x,test_y#SimpleSeq2Seq模型
def seq2Seq(train_x,train_y,valid_x,valid_y,test_x,test_y):#创建模型model=Sequential()seq=SimpleSeq2Seq(input_dim=28,hidden_dim=32,output_length=10,output_dim=10)model.add(seq)model.add(Flatten())  #扁平化model.add(Dense(10,activation='softmax'))#查看网络结构model.summary()#编译模型model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])#训练模型model.fit(train_x,train_y,batch_size=500,nb_epoch=25,verbose=2,validation_data=(valid_x,valid_y))    #评估模型pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2)print('test_loss:',pre[0],'- test_acc:',pre[1])train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
seq2Seq(train_x,train_y,valid_x,valid_y,test_x,test_y)

网络各层输出尺寸:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
model_14 (Model)             (None, 10, 10)            10368
_________________________________________________________________
flatten_1 (Flatten)          (None, 100)               0
_________________________________________________________________
dense_23 (Dense)             (None, 10)                1010
=================================================================
Total params: 11,378
Trainable params: 11,378
Non-trainable params: 0

网络训练结果:

Epoch 23/25- 17s - loss: 0.1521 - acc: 0.9563 - val_loss: 0.1400 - val_acc: 0.9598
Epoch 24/25- 17s - loss: 0.1545 - acc: 0.9553 - val_loss: 0.1541 - val_acc: 0.9536
Epoch 25/25- 17s - loss: 0.1414 - acc: 0.9594 - val_loss: 0.1357 - val_acc: 0.9624
test_loss: 0.14208583533763885 - test_acc: 0.9567999958992004

4 AttentionSeq2Seq

AttentionSeq2Seq(input_length, input_dim, hidden_dim, output_length, output_dim, depth=1)
  • input_length:输入序列长度
  • input_dim:输入序列维度
  • output_length:输出序列长度
  • output_dim:输出序列维度
  • depth:Encoder 和 Decoder 的深度,取值可以为整数或元组。如 depth=3,表示 Encoder 和 Decoder 都有 3 层;depth=(3, 4) 表示 Encoder 有3层和 Decoder 有4层

AttentionSeq2Seq.py

from tensorflow.examples.tutorials.mnist import input_data
from seq2seq.models import AttentionSeq2Seq
from keras.models import Sequential
from keras.layers import Dense,Flatten#载入数据
def read_data(path):mnist=input_data.read_data_sets(path,one_hot=True)train_x,train_y=mnist.train.images.reshape(-1,28,28),mnist.train.labels,valid_x,valid_y=mnist.validation.images.reshape(-1,28,28),mnist.validation.labels,test_x,test_y=mnist.test.images.reshape(-1,28,28),mnist.test.labelsreturn train_x,train_y,valid_x,valid_y,test_x,test_y#AttentionSeq2Seq模型
def seq2Seq(train_x,train_y,valid_x,valid_y,test_x,test_y):#创建模型model=Sequential()seq=AttentionSeq2Seq(input_length=28,input_dim=28,hidden_dim=32,output_length=10,output_dim=10)model.add(seq)model.add(Flatten())  #扁平化model.add(Dense(10,activation='softmax'))#查看网络结构model.summary()#编译模型model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])#训练模型model.fit(train_x,train_y,batch_size=500,nb_epoch=25,verbose=2,validation_data=(valid_x,valid_y))    #评估模型pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2)print('test_loss:',pre[0],'- test_acc:',pre[1])train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
seq2Seq(train_x,train_y,valid_x,valid_y,test_x,test_y)

网络各层输出尺寸:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
model_102 (Model)            (None, 10, 10)            24459
_________________________________________________________________
flatten_6 (Flatten)          (None, 100)               0
_________________________________________________________________
dense_176 (Dense)            (None, 10)                1010
=================================================================
Total params: 25,469
Trainable params: 25,469
Non-trainable params: 0

网络训练结果:

Epoch 23/25- 36s - loss: 0.0533 - acc: 0.9835 - val_loss: 0.0719 - val_acc: 0.9794
Epoch 24/25- 37s - loss: 0.0511 - acc: 0.9843 - val_loss: 0.0689 - val_acc: 0.9800
Epoch 25/25- 37s - loss: 0.0473 - acc: 0.9860 - val_loss: 0.0700 - val_acc: 0.9802
test_loss: 0.06055343023035675 - test_acc: 0.9825000047683716

SimpleSeq2Seq 模型和 AttentionSeq2Seq 模型的预测精度分别为 0.9568、0.9825,说明添加注意力机制后,预测精度有了明显的提示。

seq2seq模型案例分析相关推荐

  1. R语言Logistic回归模型案例:分析吸烟、饮酒与食管癌的关系

    R语言Logistic回归模型案例:分析吸烟.饮酒与食管癌的关系 目录 R语言Logistic回归模型案例分析吸烟.饮酒与食管癌的关系 #样例数据

  2. 单方程误差修正模型案例分析

    单方程误差修正模型案例分析 数据的生成 set.seed(12345) u<-rnorm(500) x<-cumsum(u) y<-x+u E-G协整估计及检验 model.lm&l ...

  3. SWAT模型案例分析

    SWAT模型的产生 SWAT模型的最直接前身是SWRRB模型.而SWRRB模型则起始于20世纪70年代美国农业部农业研究中心开发的CREAMS(Chemicals, Runoff, and Erosi ...

  4. 卡诺模型案例分析_3个维度看竞品分析!

    谁都想站在巨人的肩膀上,问题是怎么上去? ABC分享会线下24期回顾 时间:10月24日 下午13:00-17:30 地点:上海嘉定U-CUBE创意空间 参与人数:18人 主题:怎样做竞品分析 这次活 ...

  5. 卡诺模型案例分析_AMOS案例分析 | 结构方程模型(二)

    使用结构方程的方法进行模型的验证 1. 替换缺失值 在读取数据文件之前,对数据的完整性问题做适当处理.利用SPSS进行数据缺失值的处理.操作过程:转换→替换缺失值 输出结果:形成新的数据集.从下表中可 ...

  6. 卡诺模型案例分析_质量管理神器,Kano模型您可知道?

    在六西格玛中,倾听.分析.理解顾客的需求 (voice of customer)是非常重要的,而Kano模型就是这个环节中关键的工具之一. 而笔者觉得更为重要的是,Kano模型不仅仅是质量领域的重要工 ...

  7. 3sigma模型案例分析彻底搞懂置信度与置信区间

    学习机器学习算法时,经常会碰到数理统计中置信区间.置信度,虽然学习过相关课程,但是每次遇到它总是懵懵懂懂,似懂非懂.为了对这两个概念有深入的了解,这里做了相关的介绍.为了不老是纠缠于数理统计理论,或者 ...

  8. python做马尔科夫模型预测法_python 日常笔记 hmmlearn 隐性马尔科夫模型案例分析...

    问题: 什么是马尔科夫模型?用来干什么? 大家可以参考这篇简书 python 实现 关于HMM有两个主要问题: 已知上述三个参数,和当前观测序列,求解隐藏状态的变化 所有参数未知,只有数据,如何获得三 ...

  9. GARCH模型案例分析

    read data library(quantmod) # 加载包 getSymbols('^HSI', from='1989-12-01',to='2013-11-30') # 从Yahoo网站下载 ...

最新文章

  1. 小A与任务 (贪心 优先队列)
  2. 成功检测远距离目标,将点云与RGB图像结合,谷歌Waymo提出新算法:4D-Net
  3. NodeJS”热部署“代码,实现动态调试
  4. Web 标准实战的评论
  5. Linux 笔记(持续补充)
  6. iOS 各种系统文件目录 临时,缓存,document,lib,归档,序列化
  7. Short-Session的推荐如何做?
  8. 计算机在线采集数据注意,全站仪数据采集和传输中的常见问题解决方案
  9. 39.数组中数值和下标相等的元素
  10. POJ 3660 Cow Contest(传递闭包floyed算法)
  11. HashMap与ConcurrentHashMap的测试报告
  12. 如何往linux上面上传东西
  13. 计算机组成原理第四章中,计算机组成原理第四章..ppt
  14. 等保测评--网络安全等级保护定级指南
  15. Xshell远程连接配置 Ubuntu 18.04.6 + Anaconda + CUDA + Cudnn + Pytorch(GPU+CPU)
  16. Jenkins linux 操作系统一键部署多节点
  17. 第三阶段应用层——1.12 数码相册—interval_page设置时间间隔界面的显存管理、页面规划、输入控制
  18. 优雅发送HTTP请求
  19. Unity游戏开发程序员学习线路图及技能提升指南
  20. div用css显示隐藏的效果

热门文章

  1. 目标检测DOTA数据集预处理相关函数
  2. jvm 崩溃日志设置_JVM崩溃的原因及解决过程
  3. Python分析今年的月饼之王花落谁家?
  4. debug断点调试中,查看request中的parameter值
  5. 三毛6-- 沙漠观浴记
  6. Alfred4配置信息复制到其他电脑
  7. 解释型语言-shell
  8. 【SVN】Clean failed to process the following paths:
  9. 在终端显示bash:/home/this/catkin_ws/setup.bash:没有那个文件或目录 的解决方法
  10. (22)STM32——RTC时钟笔记(基于正点原子探索者)