时间序列预测——DA-RNN模型

作者:梅昊铭

1. 背景介绍

传统的用于时间序列预测的非线性自回归模型(NRAX)很难捕捉到一段较长的时间内的数据间的时间相关性并选择相应的驱动数据来进行预测。本文将介绍一种基于 Seq2Seq 模型(Encoder-Decoder 模型)并结合 Attention 机制的时间序列预测方法。作者提出了一种双阶段的注意力机制循环神经网络模型(DA-RNN),能够很好的解决上述两个问题。

模型的第一部分,我们引入输入注意力机制在每个时间步选择相应的输入特征。模型的第二部分,我们使用时间注意力机制在整个时间步长中选择相应的隐藏层状态。通过这种双阶段注意力机制,我们能够有效地解决一些时序预测方面的问题。我们将这两个注意力机制模型集成在基于 LSTM 的循环神经网络中,使用标准反向传播进行联合训练。

2. DA-RNN 模型

2.1 输入与输出

输入:给定 n 个驱动序列(输入特征),X=(x1,x2,...,xn)T=(x1,x2,...,xT)∈Rn×TX = (x^1,x^2,...,x^n)^T = (x_1,x_2,...,x_T) \in R^{n \times T}X=(x1,x2,...,xn)T=(x1​,x2​,...,xT​)∈Rn×T ,TTT 表示时间步长,nnn 表示输入特征的维度。

输出:y^T=F(y1,...,yT−1,x1,...,xT)\hat{y}_{T}= F(y_1,...,y_{T-1},x_1,...,x_T)y^​T​=F(y1​,...,yT−1​,x1​,...,xT​)。(y1,...,yT−1)(y_1,...,y_{T-1})(y1​,...,yT−1​)表示预测目标过去的值,其中 yt∈Ry_t\in Ryt​∈R;(x1,...,xT)(x_1,...,x_T)(x1​,...,xT​) 为时间 TTT 内 nnn 维的外源驱动输入序列,xt∈Rnx_t \in R^nxt​∈Rn;F(⋅)F(\cdot)F(⋅) 为模型需要学习的非线性映射函数。

2.2 模型结构

DA-RNN 模型是一种基于注意力机制的 Encoder-Decoder 模型。在编码器部分,我们引入了输入注意力机制来选择相应的驱动序列;在解码器部分,我们使用时间注意力机制来选择整个儿时间步长中相应的隐藏层状态。通过这个两种注意力机制,DA-RNN 模型能够选择最相关的输入特征,并且捕捉到较长时间内的时间序列之间的依赖关系,如图1所示。


图 1:DA-RNN 模型结构

2.3 编码器

编码器本质上是一个 RNN 模型,它能够将输入序列转换为一种特征表示,我们称之为隐藏层状态。对于时间序列预测问题,给定输入 X=(x1,x2,...,xT)∈Rn×T,xt∈RnX = (x_1,x_2,...,x_T) \in R^{n \times T},x_t \in R^nX=(x1​,x2​,...,xT​)∈Rn×T,xt​∈Rn,在时刻 ttt ,编码器将 xtx_txt​ 映射为 hth_tht​:ht=f1(ht−1,xt)h_t = f_1(h_{t-1},x_t)ht​=f1​(ht−1​,xt​),ht∈Rmh_t \in R^mht​∈Rm 表示编码器隐藏层在时刻 ttt 的状态,mmm 表示隐藏层的维度,KaTeX parse error: Expected group after '_' at position 2: f_̲ 为非线性激活函数,本文中我们使用 LSTM。

本文中,我们提出了一种输入注意力机制编码器。它能够适当地选择相应的驱动序列,这对时间序列预测是至关重要的。我们通过确定性注意力模型来构建一个输入注意力层。它需要将之前的隐藏层状态ht−1h_{t-1}ht−1​ 和** LSTM** 单元的** cell **状态 st−1s_{t-1}st−1​ 作为该层的输入得到:
etk=veTtanh(We[ht−1;st−1]+Uexk)e^k_t = v^T_etanh(W_e[h_{t-1};s_{t-1}]+U_ex^k)etk​=veT​tanh(We​[ht−1​;st−1​]+Ue​xk),其中ve∈RT,We∈RT×2m,Ue∈RT×Tv_e \in R^T,W_e \in R^{T \times 2m},U_e \in R^{T \times T}ve​∈RT,We​∈RT×2m,Ue​∈RT×T是需要学习的参数。
输入注意力层的输出 (et1,et2,...,etn)(e^1_t,e^2_t,...,e^n_t)(et1​,et2​,...,etn​) 输入到 softmax 层得到 αtk\alpha_t^kαtk​ 以确保所有的注意力权重的和为1,αtk\alpha_t^kαtk​ 表示在时刻 ttt 第 kkk 个输入特征的重要性。

得到注意权重后,我们可以自适应的提取驱动序列 x~t=(αt1xt1,αt2xt2,...,αtnxtn)\tilde x_t = (\alpha^1_tx^1_t,\alpha^2_tx^2_t,...,\alpha^n_tx^n_t)x~t​=(αt1​xt1​,αt2​xt2​,...,αtn​xtn​),此时我们更新隐藏层的状态为 ht=f1(ht−1,x~t)h_t = f_1(h_{t-1},\tilde x_t)ht​=f1​(ht−1​,x~t​)。

2.4 解码器

为了预测输出 y^T\hat y_Ty^​T​,我们使用另外一个 LSTM 网络层来解码编码器的信息,即 隐藏层状态 KaTeX parse error: Expected group after '_' at position 2: h_̲。当输入序列过长时,传统的Encoder-Decoder 模型效果会急速恶化。因此,在解码器部分,我们引入了时间注意力机制来选择相应的隐藏层状态。

与编码器中注意力层类似,解码器的注意力层也需要将之前的隐藏层状态dt−1d_{t-1}dt−1​ 和LSTM 单元的cell状态 st−1′s'_{t-1}st−1′​ 作为该层的输入得到该层的输出:
lti=vdTtanh(Wd[dt−1;st−1′]+Udhi)l^i_t = v^T_dtanh(W_d[d_{t-1};s'_{t-1}]+U_dh_i)lti​=vdT​tanh(Wd​[dt−1​;st−1′​]+Ud​hi​),其中vd∈Rm,Wd∈Rm×2p,Ue∈Rm×mv_d \in R^m,W_d \in R^{m \times 2p},U_e \in R^{m \times m}vd​∈Rm,Wd​∈Rm×2p,Ue​∈Rm×m是需要学习的参数。通过 softmax 层,我们可以得到第 iii 个编码器隐藏状态 hih_ihi​ 对于最终预测的重要性 βti\beta^i_tβti​。解码器将所有的编码器隐藏状态按照权重求和得到文本向量 ct=∑i=1Tβtihic_t = \sum_{i=1}^T \beta_t^ih_ict​=∑i=1T​βti​hi​,注意 ctc_tct​ 在不同的时间步是不同的。

在得到文本向量之后,我们将其和目标序列结合起来得到 y~t−1=w~T[yt−1;ct−1]+b~\tilde y_{t-1} = \tilde w^T[y_{t-1};c_{t-1}]+\tilde by~​t−1​=w~T[yt−1​;ct−1​]+b~。利用新计算得到的 y~t−1\tilde y_{t-1}y~​t−1​,我们来更新解码器隐藏状态 dt=f2(dt−1,y~t−1)d_t=f_2(d_{t-1},\tilde y_{t-1})dt​=f2​(dt−1​,y~​t−1​),我们使用 LSTM 来作为激活函数 f2f_2f2​。
通过 DA-RNN 模型,我们预测 y^T=F(y1,...,yT−1,x1,...,xT)=vyT(Wy[dT;cT]+bw)+bv\hat y_T = F(y_1,...,y_{T-1},x_1,...,x_T) = v_y^T(W_y[d_T;c_T]+b_w)+b_vy^​T​=F(y1​,...,yT−1​,x1​,...,xT​)=vyT​(Wy​[dT​;cT​]+bw​)+bv​。

2.5 训练过程

在该模型中,作者使用平均方差作为目标函数,利用 Adam 优化器,min-batch 为128来进行参数优化。
目标函数:
O(yT,y~T)=1N∑i=1N(y^Ti−yTi)2O(y_T,\tilde y_T)=\frac{1}{N}\sum_{i=1}^N(\hat y^i_T-y_T^i)^2O(yT​,y~​T​)=N1​i=1∑N​(y^​Ti​−yTi​)2

3. 实验

3.1 数据集

本文的作者采用了,两种不同的数据集来测试验证 DA-RNN 模型的效果。这里我们仅对 NASDAQ 100 Stock 数据集进行介绍。作者根据 NASDAQ 100 Stock 收集了 81 家主要公司的股票价格作为驱动时间序列,NASDAQ 100 的股票指数做目标序列。数据收集的频率为一分钟一次。该数据集包含了从2016年7月26日至2016年12月22日总共105天的数据。在本实验中,作者使用 35100 条数据作为训练集,2730条数据作为验证集,以及最后2730条数据作为测试集。

3.2 参数设置和评价指标

时间窗口的大小 T∈{3,5,10,15,25}T \in \{3,5,10,15,25\}T∈{3,5,10,15,25}。实验表明 :T=10 时,模型在验证集上的效果最好。编码器和解码器隐藏层的大小 m,p∈{16,32,64,128,256}m ,p\in\{16,32,64,128,256\}m,p∈{16,32,64,128,256}。当m=p=64,128m=p=64,128m=p=64,128 时,实验效果最好。

为评估模型的效果,我们考虑了三种不同的评价指标:RSME,MAE,MAPE。

3.3 模型预测

为展示 DA-RNN 模型的效果,作者将该模型和其他的模型在两个不同的数据集上的预测效果进行了对比,如表1所示。由表1可以看出,DA-RNN模型相对于其他模型,误差更小一些。DA-RNN模型在时间序列预测方面具有良好的表现。

表 1:SML 2010数据集和纳斯达克100股票数据集的时间序列预测结果

为了更好的视觉比较,我们将Encoder-Decoder 模型,Attention RNN 和 DA-RNN 模型的在纳斯达克100股票数据集上的预测结果在图2中展示出来。我们不难看出DA-RNN模型能更好地反映真实情况。

图 3:三种模型在纳斯达克100股票数据集上的预测结果

4. 总结

在本文中,我们介绍了一种基于注意力机制的双阶段循环神经网络模型。该模型由两部分组成:Encoder 和 Decoder。在编码器部分,我们引入了输入注意力机制来对输入特征进行特征提取,为相关性较高的特征变量赋予更高的权重;在解码器部分,我们通过时间注意力机制为不同时间 ttt 的隐藏状态赋予不同的权重,不断地更新文本向量,来找出时间相关性最大的隐藏层状态。Encoder 和 Decode 中的注意力层分别从空间和时间上来寻找特征表示和目标序列之间的相关性,为不同的特征变量赋予不同的权重,以此来更准确地预测目标序列。
项目源码地址:https://momodel.cn/workspace/5da8cc2ccfbef78329c117ed?type=app

5. 参考资料

  1. 论文:A Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction
  2. 注意力机制详解:https://blog.csdn.net/BVL10101111/article/details/78470716
  3. 项目源码:https://github.com/chensvm/A-Dual-Stage-Attention-Based-Recurrent-Neural-Network-for-Time-Series-Prediction
  4. 数据集:https://cseweb.ucsd.edu/~yaq007/NASDAQ100_stock_data.html

欢迎关注我们的微信公众号:MomodelAI

同时,欢迎使用 「Mo AI编程」 微信小程序

以及登录官网,了解更多信息:Mo 平台

Mo,发现意外,创造可能

【Mo 人工智能技术博客】时间序列预测——DA-RNN模型相关推荐

  1. 【Mo 人工智能技术博客】采用 Python 机器学习预测足球比赛结果

    采用 Python 机器学习预测足球比赛结果 足球是世界上最火爆的运动之一,世界杯期间也往往是球迷们最亢奋的时刻.比赛狂欢季除了炸出了熬夜看球的铁杆粉丝,也让足球竞猜也成了大家茶余饭后最热衷的话题.甚 ...

  2. 【Mo 人工智能技术博客】胶囊网络——Capsule Network

    胶囊网络--Capsule Network 作者:林泽龙 1. 背景介绍 CNN 在处理图像分类问题上表现非常出色,已经完成了很多不可思议的任务,并且在一些项目上超过了人类,对整个机器学习的领域产生了 ...

  3. 【Mo 人工智能技术博客】StarGAN——生成你的明星脸

    1 GAN 介绍 GAN,叫做生成对抗网络 (Generative Adversarial Network) .其基本原理是生成器网络 G(Generator) 和判别器网络 D(Discrimina ...

  4. 【Mo 人工智能技术博客】多标准中文分词 Multi-Criteria-CWS

    多标准中文分词 Multi-Criteria-CWS 作者:宋彤彤 自然语言处理(NLP)是人工智能中很重要且具有挑战性的方向,而自然语言处理的第一步就是分词,分词的效果直接决定和影响后续工作的效率. ...

  5. 【Mo 人工智能技术博客】基于耦合网络的推荐系统

    基于耦合网络的推荐系统 作者:陈东瑞 1.复杂网络基础知识 当我们拿起手机给家人.朋友或者同事拨打电话时,就不知不觉中参与到了社交网络形成的过程中:当我们登上高铁或者飞机时,就可以享受交通网络给我们带 ...

  6. 【Mo 人工智能技术博客】利用Logistic函数和LSTM分析疫情数据

    利用Logistic函数和LSTM分析疫情数据 作者:林泽龙 Mo 1. 背景 2019 新型冠状病毒 (SARS-CoV-2),曾用名 2019-nCoV,通用简称新冠病毒,是一种具有包膜的正链单股 ...

  7. 【Mo 人工智能技术博客】使用 Seq2Seq 实现中英文翻译

    1. 介绍 1.1 Deep NLP 自然语言处理(Natural Language Processing,NLP)是计算机科学.人工智能和语言学领域交叉的分支学科,主要让计算机处理或理解自然语言,如 ...

  8. 【Mo 人工智能技术博客】现在最流行的图神经网络库 pytorch geometric 上手教学

    简介 Graph Neural Networks 简称 GNN,称为图神经网络.近年来 GNN 在学术界受到的关注越来越多,与之相关的论文数量呈上升趋势,GNN 通过对信息的传递,转换和聚合实现特征的 ...

  9. 【Mo 人工智能技术博客】基于 Python 和 NLTK 的推特情感分析

    基于 Python 和 NLTK 的推特情感分析 作者:宋彤彤 1. 导读 NLTK 是 Python 的一个自然语言处理模块,其中实现了朴素贝叶斯分类算法.这次 Mo 来教大家如何通过 python ...

最新文章

  1. 怎么DIY一个粒子检测器
  2. 2020年球云计算市值或将达4490亿欧元
  3. Neutron 架构 - 每天5分钟玩转 OpenStack(67)
  4. MacOS 的软件包管理工具 MacPorts
  5. 关卡 动画 蓝图 运行_UE4无缝过场动画
  6. python线程池操作_python线程池和进程池
  7. Pessimistic and Optimistic locking
  8. 语言生日创意代码_BlenderOSL代码编程
  9. python多态_记录学习python第9天-继承/多态
  10. Asp.Net Core 发布IIS报错 HTTP Error 500.30 - ASP.NET Core app failed to start
  11. Unable to round-trip http request to upstream: EOF
  12. [XA]读书感想:个人对敏捷软件开发宣言的理解
  13. Atitit 高性能架构之道 attilax著 艾龙 著 1. 应用服务与数据隔离 2 2. 负载均衡你问题 2 2.1. 用户的请求由谁来转发到到具体的应用服务器 2 2.2. 有什么转发的算法
  14. 调用百度语音合成API,Qt实现语音合成,Qt语音合成
  15. 新加坡10月新人扎堆结婚,只要生娃政府就发3000新币
  16. (转)国企,私企与外企利弊通观--关键时刻给应届毕业生及时点拨5
  17. 这些雷达书籍,你需要收藏~(终极大汇总)
  18. 发送邮件服务器连接错误什么意思,SMTP 错误(-1) :连接服务器失败
  19. 进行域名解析时,递归和迭代查询方式是什么意思?
  20. Camtasia Studio2021喀秋莎激活下载如何录制屏幕教程

热门文章

  1. 【Machine Learning】19.多分类实践:手写数字分类
  2. 【官方标准】- 交通运输领域元数据标准规范
  3. DHCP Lease Time - 动态 IP 使用时限
  4. 重庆邮电学院计算机学院彭凯,感知重邮丨重庆邮电大学计算机学院稳固学科建设核心 提升人才培养质量...
  5. 重建分区表主键 - Recreate Primary Key on a partition table
  6. 七夕情人节礼物:爱情花园 v3.2 bug
  7. android 余额宝收益列表,Android 仿支付宝中的余额宝收益进度条
  8. 2021-2027全球与中国单光束紫外可见分光光度计市场现状及未来发展趋势
  9. Vue - 全局组件之间传值(中间件传值)
  10. 【JDBC】The new driver class is `com.mysql.cj.jdbc.Driver‘. The driver is automatically