摘要:

在前面的文章里面,RNN训练与BP算法,我们提到了RNN的训练算法。但是回头看的时候在时间的维度上没有做处理,所以整个推导可能存在一点问题。

那么,在这篇文章里面,我们将介绍bptt(Back Propagation Through Time)算法如在训练RNN。

关于bptt

这里首先解释一下所谓的bptt,bptt的思路其实很简单,就是把整个RNN按时间的维度展开成一个“多层的神经网络”。具体来说比如下图:

既然RNN已经按时间的维度展开成一个看起来像多层的神经网络,这个时候用普通的bp算法就可以同样的计算,只不过这里比较复杂的是权重共享。比如上图中每一根线就是一个权重,而我们可以看到在RNN由于权重是共享的,所以三条红线的权重是一样的,这在运用链式法则的时候稍微比较复杂。

正文:

首先,和以往一样,我们先做一些定义。
hti=f(netthi)h_i^t=f(net_{hi}^t)

netthi=∑m(vimxtm)+∑s(uisht−1s)net_{hi}^t=\sum_m{(v_{im}x_m^t)}+\sum_s{(u_{is}h_s^{t-1})}

nettyk=∑mwkmhtmnet_{yk}^t=\sum_m{w_{km}h_m^t}
最后一层经过softmax的转化
otk=enettyk∑k′enettyk′o_k^t=\frac{e^{net_{yk}^t}}{\sum_{k'}{e^{net_{y{k'}}^t}}}
在这里我们使用交叉熵作为Loss Function
Et=−∑kztklnotkE_t=-\sum_k{z_k^tlno_k^t}

我们的任务同样也是求∂E∂wkm\left.\frac{\partial E}{\partial w_{km}}\right.、∂E∂vim\left.\frac{\partial E}{\partial v_{im}}\right.、∂E∂uim\left.\frac{\partial E}{\partial u_{im}}\right.。
注意,这里的EE没有时间的下标。因为在RNN里,这些梯度分别为各个时刻的梯度之和。
即:
∂E∂wkm=∑stept=0∂Et∂wkm\left.\frac{\partial E}{\partial w_{km}}\right.=\sum_{t=0}^{step}\left.\frac{\partial E_t}{\partial w_{km}}\right.
∂E∂vim=∑stept=0∂Et∂vim\left.\frac{\partial E}{\partial v_{im}}\right.=\sum_{t=0}^{step}\left.\frac{\partial E_t}{\partial v_{im}}\right.
∂E∂uim=∑stept=0∂Et∂uim\left.\frac{\partial E}{\partial u_{im}}\right.=\sum_{t=0}^{step}\left.\frac{\partial E_t}{\partial u_{im}}\right.。

所以下面我们推导的是∂Et∂wkm\left.\frac{\partial E_t}{\partial w_{km}}\right.、∂Et∂vim\left.\frac{\partial E_t}{\partial v_{im}}\right.、∂Et∂uim\left.\frac{\partial E_t}{\partial u_{im}}\right.。

我们先推导∂Et∂wkm\left.\frac{\partial E_t}{\partial w_{km}}\right.。
∂Et∂wkm=∑k′∂Et∂otk′∂otk′∂nettyk∂nettyk∂wkm=(otk−ztk)∗htm\left.\frac{\partial E_t}{\partial w_{km}}\right.=\sum_{k'}{\left.\frac{\partial E_t}{\partial o_{k'}^t}\right.\left.\frac{\partial o_{k'}^t}{\partial net_{yk}^t}\right.\left.\frac{\partial net_{yk}^t}{\partial w_{km}}\right.}=(o_k^t-z_k^t)*h_m^t。(这一部分的推导在前面的文章已经讨论过了)。
在这里,记误差信号:
δ(output,t)k=∂Et∂nettyk=∑k′∂Et∂otk′∂otk′∂nettyk=(otk−ztk)\delta_k^{(output,t)}=\left.\frac{\partial E_t}{\partial net_{yk}^t}\right.=\sum_{k'}{\left.\frac{\partial E_t}{\partial o_{k'}^t}\right.\left.\frac{\partial o_{k'}^t}{\partial net_{yk}^t}\right.}=(o_k^t-z_k^t)(后面会用到)

对于∂Et∂vim\left.\frac{\partial E_t}{\partial v_{im}}\right.、∂Et∂uim\left.\frac{\partial E_t}{\partial u_{im}}\right.其实是差不多的,所以这里详细介绍其中一个。这两个导数也是RNN里面最复杂的。

推导:∂Et∂vim\left.\frac{\partial E_t}{\partial v_{im}}\right.

∂Et∂vim=∑tt′=0∂Et∂nett′hi∂nett′hi∂vim\left.\frac{\partial E_t}{\partial v_{im}}\right.=\sum_{t'=0}^{t}{\left.\frac{\partial E_{t}}{\partial net_{hi}^{t'}}\right.\left.\frac{\partial net_{hi}^{t'}}{\partial v_{im}}\right.}
对于这个式子第一次看可能有点懵逼,这里稍微解释一下:
从式:hti=f(∑m(vimxtm)+∑s(uisht−1s))h_i^t=f(\sum_m{(v_{im}x_m^t)}+\sum_s{(u_{is}h_s^{t-1})})中我们可以看到,vimv_{im}影响的是所有时刻的netthi,t=0,1,2,....stepnet_{hi}^{t},t=0,1,2,....step。所以当EtE_t对vimv_{im}求偏导的时候,由于链式法则需要考虑到所有时刻的netthinet_{hi}^{t}。

下面分成两部分来求∂Et∂nett′hi\left.\frac{\partial E_{t}}{\partial net_{hi}^{t'}}\right.,∂nett′hi∂vim.\left.\frac{\partial net_{hi}^{t'}}{\partial v_{im}}\right..。
第一部分:∂Et∂nett′hi\left.\frac{\partial E_{t}}{\partial net_{hi}^{t'}}\right.。
这里我们记δ(t′,t)i=∂Et∂nett′hi\delta_i^{(t',t)}=\left.\frac{\partial E_{t}}{\partial net_{hi}^{t'}}\right.(误差信号,和前面文章一样)。


(由于带着符号去求这两个导数会让人看起来非常懵逼,所以下面指定具体的值,后面抽象给出通式)
假设共3个时刻,即t=0,1,2。
对于t=2t=2,t′=2{t'}=2时:
(E2E_{2}表示第2个时刻(也是最后一个时刻)的误差)
(net2hinet_{hi}^{2}表示第2个时刻隐藏层第i个神经元的净输入)
具体来说:∂E2∂net2hi=∂E2∂h2i∂h2i∂net2hi\left.\frac{\partial E_{2}}{\partial net_{hi}^{2}}\right.=\left.\frac{\partial E_{2}}{\partial h_i^2}\right.\left.\frac{\partial h_i^2}{\partial net_{hi}^{2}}\right.

对于∂E2∂h2i=∑k′∂E2∂net2yk′∂net2yk′∂h2i\left.\frac{\partial E_{2}}{\partial h_i^2}\right.=\sum_{k'}{\left.\frac{\partial E_{2}}{\partial net_{yk'}^2}\right.\left.\frac{\partial net_{yk'}^2}{\partial h_i^2}\right.}
由于δ(output,t)k=∂Et∂nettyk\delta_k^{(output,t)}=\left.\frac{\partial E_t}{\partial net_{yk}^t}\right.
所以,我们有:
∂E2∂h2i=∑k′∂E2∂net2yk′∂net2yk′∂h2i=∑k′δ(output,2)k′∂net2yk′∂h2i=∑k′δ(output,2)k′wk′i\left.\frac{\partial E_{2}}{\partial h_{i}^{2}}\right.=\sum_{k'}{\left.\frac{\partial E_{2}}{\partial net_{yk'}^2}\right.\left.\frac{\partial net_{yk'}^2}{\partial h_i^2}\right.}=\sum_{k'}{\delta_{k’}^{(output,2)}\left.\frac{\partial net_{yk'}^2}{\partial h_i^2}\right.}=\sum_{k'}{\delta_{k’}^{(output,2)}w_{k'i}}
综上:
δ(2,2)i=∂E2∂net2hi=∂E2∂h2i∂h2i∂net2hi=(∑k′δ(output,2)k′wk′i)∗f′(net2hi)\delta_i^{(2,2)}=\left.\frac{\partial E_{2}}{\partial net_{hi}^{2}}\right.=\left.\frac{\partial E_{2}}{\partial h_i^2}\right.\left.\frac{\partial h_i^2}{\partial net_{hi}^{2}}\right.=(\sum_{k'}{\delta_{k’}^{(output,2)}w_{k'i}})*f{'}(net_{hi}^2)

对于t=1t=1,t′=2{t'}=2时:
(E2E_{2}表示第2个时刻的误差)
(net1hinet_{hi}^1表示第1个时刻隐藏层第i个神经元的净输入)
具体来说:∂E2∂net1hi=∂E2∂h1i∂h1i∂net1hi\left.\frac{\partial E_{2}}{\partial net_{hi}^{1}}\right.=\left.\frac{\partial E_{2}}{\partial h_i^1}\right.\left.\frac{\partial h_i^1}{\partial net_{hi}^{1}}\right.
那么∂E2∂h1i=∑k′∂E2∂net1yk′∂net1yk′∂h1i+∑j∂E2∂net2hj∂net2hj∂h1i\left.\frac{\partial E_{2}}{\partial h_{i}^{1}}\right.=\sum_{k'}{\left.\frac{\partial E_{2}}{\partial net_{yk'}^1}\right.\left.\frac{\partial net_{yk'}^1}{\partial h_i^1}\right.}+\sum_{j}{\left.\frac{\partial E_{2}}{\partial net_{hj}^2}\right.\left.\frac{\partial net_{hj}^2}{\partial h_{i}^{1}}\right.}。请对比这个式子和上面t=2t=2,t′=2{t'}=2时的区别,区别在于多了一项∑j∂E2∂net2hj∂net2hj∂h1i\sum_{j}{\left.\frac{\partial E_{2}}{\partial net_{hj}^2}\right.\left.\frac{\partial net_{hj}^2}{\partial h_{i}^{1}}\right.}。这个原因我们已经在RNN与bp算法中讨论过,这里简单的说就是由于t=1t=1时刻有t=2t=2时刻反向传播回来的误差,所以要考虑上这一项,但是对于t=2t=2已经是最后一个时刻了,没有反向传播回来的误差。

对于第一项∑k′∂E2∂net1yk′∂net1yk′∂h1i\sum_{k'}{\left.\frac{\partial E_{2}}{\partial net_{yk'}^1}\right.\left.\frac{\partial net_{yk'}^1}{\partial h_i^1}\right.}其实是0。下面简单分析下原因:
上式进一步可以化为:∑k′(∑k″∂E2∂o1k″∂o1k″∂net1yk′)∂net1yk′∂h1i\sum_{k'}(\sum_{k''}{\left.\frac{\partial E_{2}}{\partial o_{k''}^1}\right.\left.\frac{\partial o_{k''}^1}{\partial net_{yk'}^1}\right.})\left.\frac{\partial net_{yk'}^1}{\partial h_i^1}\right.而E2E_2与第1个时刻输出o1k″o_{k''}^{1}无关。所以为0。

对于第二项∑j∂E2∂net2hj∂net2hj∂h1i\sum_{j}{\left.\frac{\partial E_{2}}{\partial net_{hj}^2}\right.\left.\frac{\partial net_{hj}^2}{\partial h_{i}^{1}}\right.},我们带入δ(t′,t)i=∂Et∂nett′hi\delta_i^{(t',t)}=\left.\frac{\partial E_{t}}{\partial net_{hi}^{t'}}\right.有:
∑j∂E2∂net2hj∂net2hj∂h1i=∑jδ(2,2)j∂net2hj∂h1i\sum_{j}{\left.\frac{\partial E_{2}}{\partial net_{hj}^2}\right.\left.\frac{\partial net_{hj}^2}{\partial h_{i}^{1}}\right.}=\sum_{j}{\delta_j^{(2,2)}\left.\frac{\partial net_{hj}^2}{\partial h_{i}^{1}}\right.}。
同时明显有∂net2hj∂h1i=uji\left.\frac{\partial net_{hj}^2}{\partial h_{i}^{1}}\right.=u_{ji}
即:∂E2∂h1i=∑jδ(2,2)juji\left.\frac{\partial E_{2}}{\partial h_{i}^{1}}\right.=\sum_{j}{\delta_j^{(2,2)}u_{ji}}

综上:
δ(1,2)i=∂E2∂net1hi=∂E2∂h1i∂h1i∂net1hi=(∑jδ(2,2)j∂net2hj∂h1i)∗f′(net1hi)=(∑jδ(2,2)juji)∗f′(net1hi)\delta_i^{(1,2)}=\left.\frac{\partial E_{2}}{\partial net_{hi}^{1}}\right.=\left.\frac{\partial E_{2}}{\partial h_i^1}\right.\left.\frac{\partial h_i^1}{\partial net_{hi}^{1}}\right.=(\sum_{j}{\delta_j^{(2,2)}\left.\frac{\partial net_{hj}^2}{\partial h_{i}^{1}}\right.})*f{'}(net_{hi}^1)=(\sum_{j}{\delta_j^{(2,2)}u_{ji}})*f{'}(net_{hi}^1)

对于t=0t=0,t′=2{t'}=2时:
(E2E_{2}表示第2个时刻的误差)
(net0hinet_{hi}^0表示第0个时刻隐藏层第i个神经元的净输入)。
和上面的思路一样,我们容易得到:
δ(0,2)i=∂E2∂net0hi=(∑jδ(1,2)juji)∗f′(net0hi)\delta_i^{(0,2)}=\left.\frac{\partial E_{2}}{\partial net_{hi}^0}\right.=(\sum_{j}{\delta_j^{(1,2)}u_{ji})*f{'}(net_{hi}^0)}。

至此,我们求完了∂Et∂nett′hi\left.\frac{\partial E_{t}}{\partial net_{hi}^{t'}}\right.。下面我们来总结一下其通式:

∂Et∂nett′hi=δ(t′,t)i={(∑k′δ(output,t)k′wk′i)∗f′(nett′hi),(∑jδ(t′+1,t)juji)∗f′(nett′hi),t=t′t≠t′

\left.\frac{\partial E_{t}}{\partial net_{hi}^{t'}}\right.=\delta_i^{(t',t)}=\begin{cases} (\sum_{k'}{\delta_{k’}^{(output,t)}w_{k'i}})*f{'}(net_{hi}^{t'}), & t=t'\\ (\sum_{j}{\delta_j^{(t'+1,t)}u_{ji})*f{'}(net_{hi}^{t'})}, & t\neq t' \end{cases}

另外,对于δ(output,t)k\delta_k^{(output,t)}有以下表达式:
δ(output,t)k=∂Et∂nettyk=∑k′∂Et∂otk′∂otk′∂nettyk=(otk−ztk)\delta_k^{(output,t)}=\left.\frac{\partial E_t}{\partial net_{yk}^t}\right.=\sum_{k'}{\left.\frac{\partial E_t}{\partial o_{k'}^t}\right.\left.\frac{\partial o_{k'}^t}{\partial net_{yk}^t}\right.}=(o_k^t-z_k^t)


最后只要求出∂nett′hi∂vim\left.\frac{\partial net_{hi}^{t'}}{\partial v_{im}}\right.,其值具体为∂nett′hi∂vim=xtm\left.\frac{\partial net_{hi}^{t'}}{\partial v_{im}}\right.=x_m^t

最后,对于∂Et∂uim\left.\frac{\partial E_t}{\partial u_{im}}\right.其实和上面的差不多,主要是后面的部分不一样,具体来说:
∂Et∂uim=∑tt′=0∂Et∂nett′hi∂nett′hi∂uim\left.\frac{\partial E_t}{\partial u_{im}}\right.=\sum_{t'=0}^{t}{\left.\frac{\partial E_{t}}{\partial net_{hi}^{t'}}\right.\left.\frac{\partial net_{hi}^{t'}}{\partial u_{im}}\right.},可以看到就只有等式右边的第二项不一样,关键部分是一样的。∂nett′hi∂uim=ht′−1m\left.\frac{\partial net_{hi}^{t'}}{\partial u_{im}}\right.=h_m^{t'-1}

细节-1

上面提到,当只有3个时刻时,t=0,1,2。
对于误差E2E_2(最后一个时刻的误差),没有再下一个时刻反向传回的误差。
那么对于E1E_1(第1个时刻的误差)存在下一个时刻反向传回的误差,但是在∂E1∂h1i\left.\frac{\partial E_1}{\partial h_i^1}\right.中的第二项∑j∂E1∂net2hj∂net2hj∂h1i\sum_{j}{\left.\frac{\partial E_{1}}{\partial net_{hj}^{2}}\right.\left.\frac{\partial net_{hj}^{2}}{\partial h_{i}^{1}}\right.}仍然为0。是因为∂E1∂net2hj=0\left.\frac{\partial E_{1}}{\partial net_{hj}^{2}}\right.=0,因为E1E_1的误差和下一个时刻隐藏层的输出没有任何关系。

总结

看起来bptt和我们之前讨论的bp本质上是一样的,只是在一些细节的处理上由于权重共享的原因有所不同,但是基本上还是一样的。

下面这篇文章是有一个简单的rnn代码,大家可以参考一下
参考文章1
代码的bptt中每一步的迭代公式其实就是上面的公式。希望对大家有帮助~

RNN-bptt简单推导相关推荐

  1. RNN BPTT算法推导

    目录 BPTT算法推导 注1:激活函数tanh(x)求导 注2 softmax求导 BPTT算法推导 对于一个普通的RNN来说,其前向传播过程为: 先介绍一下等下计算过程中会用到的偏导数:   关于t ...

  2. RNN BPTT算法详细推导

    BPTT算法推导 BPTT全称:back-propagation through time.这里以RNN为基础,进行BPTT的推导. BPTT的推导比BP算法更难,同时所涉及的数学知识更多,主要用到了 ...

  3. DL之RNN:基于TF利用RNN实现简单的序列数据类型(DIY序列数据集)的二分类(线性序列随机序列)

    DL之RNN:基于TF利用RNN实现简单的序列数据类型(DIY序列数据集)的二分类(线性序列&随机序列) 目录 序列数据类型&输出结果 设计思路 序列数据类型&输出结果 1.t ...

  4. 多普勒效应及多普勒频移的简单推导

    多普勒效应及多普勒频移的简单推导 fd≡fR−fT(1)f_d\equiv f_R-f_T \tag{1} fd​≡fR​−fT​(1)   式中,fdf_dfd​表示多普勒频移,fRf_RfR​表示 ...

  5. 深度学习中反向传播算法简单推导笔记

    反向传播算法简单推导笔记 1.全连接神经网络 该结构的前向传播可以写成: z(1)=W(1)x+b(1)z^{(1)} = W^{(1)}x+b^{(1)}z(1)=W(1)x+b(1) a(1)=σ ...

  6. 【应试技巧】格林公式记忆方法及简单推导

    视频讲解:格林公式记忆方法及简单推导 大家在学格林公式的时候会发现其实书本上给的形式并不容易记忆. 大家可能会产生下述的问题 忘记了逆时针和顺时针哪个是正方向? 忘记了P,Q该对谁求偏导? 忘记了求偏 ...

  7. RNN中BPTT的推导和可能的问题

    最近开始啃LSTM,发现BPTT这块还是不是很清晰,结合RNN,把这块整理整理 RNN 前馈神经网络(feedforward neural networks)如下图所示(这块内容可见我的博客神经网络B ...

  8. 贝叶斯公式的理解及简单推导

    1. 贝叶斯概念理解 如下为贝叶斯公式,其中P(A)为事件A的先验概率,P(A|B)为事件A的后验概率,且后验概率的计算融合了先验概率的值. P ( A ∣ B ) = P ( B ∣ A ) ∗ P ...

  9. 傅里叶级数的理解与简单推导

    本文重点在于理解傅里叶级数的思想,对于公式的推导仅作泛泛而谈,在正式推导公式之前我想先简单的谈一谈对函数进行傅里叶级数展开有何意义,也即为什么我们要将一个函数进行傅氏变换? 分解与合成是一种朴素的思想 ...

最新文章

  1. 2022-2028年中国滑雪产业投资分析及前景预测报告(全卷)
  2. 解决set /p yn= 接受键盘输入导致ECHO 处于关闭状态的问题
  3. 正则表达式引擎执行原理——从未如此清晰!
  4. JS -- Unexpected trailing comma
  5. 二叉树的后序遍历—leetcode145
  6. 23种设计模式之《单例模式》
  7. SpringCloud Greenwich(一)注册中心之nacos、Zuul和 gateway网关配置
  8. 快期末考试了好烦躁啊来写点东西
  9. 如何将ListT转换相应的Html(xsl动态转换)(一)
  10. oracle19c 安装权限_Oracle 数据库安装系列一:19C 软件安装和补丁升级
  11. GnuTLS传输层安全性库
  12. Http通信(HttpClient)
  13. 二总线芯片RF601
  14. Invenio 数字图书馆框架
  15. python match函数返回值_Python中re.match函数起什么作用呢?
  16. IJCAI 2021 投稿安排出来了!新审稿机制体验一下?
  17. 【C语言】解决 “address of stack memory associated with local variable ‘num‘ returned”
  18. 菜鸟的mongoDB学习---(二)MongoDB 数据库,对象,集合
  19. Top10 ProxyClient 支持指定进程的代理客户端软件
  20. 【21SR】Designing a Practical Degradation Model for Deep BlindImage Super-Resolution

热门文章

  1. 有了这篇文章, Python 中的编码不再是噩梦
  2. MySQL安装与卸载详细教程
  3. IDC发布视频云市场报告:公有云厂商占主导,腾讯云持续领跑
  4. 树莓派变身影音播放器
  5. 微信小程序 onReachBottom 上拉触底加载 没触发
  6. 【学生作业】数字图像处理之MATLAB大作业:自制图像处理小工具
  7. OEM简介及按钮乱码问题
  8. C++根据点的类构造线三角形类,并测试(江苏大学平时作业)
  9. OpenGL编译着色器
  10. 请求方法+super+枚举+包装类+正则表达式+学习资料