前文:循环神经网络——初学RNN https://blog.csdn.net/weixin_38522681/article/details/109129490

循环神经网络——RNN的训练算法:BPTT

  • 基本步骤
  • 前向计算
  • 误差项的计算
  • 权重梯度的计算
  • RNN的梯度爆炸和消失问题

基本步骤

BPTT算法是针对循环层的训练算法,它的基本原理和BP算法是一样的,也包含同样的三个步骤:

1.前向计算每个神经元的输出值;
2.反向计算每个神经元的误差项δj\delta_jδj​值,它是误差函数E对神经元j的加权输入netjnet_jnetj​的偏导数;
3.计算每个权重的梯度。
最后再用随机梯度下降算法更新权重。
循环层如下图所示:

前向计算

st=f(Uxt+Wst−1)s_t =f(Ux_t+Ws_{t-1}) st​=f(Uxt​+Wst−1​)
上面的sts_tst​、xtx_txt​、st−1s_{t-1}st−1​都是向量,用黑体字母表示;而U、V是矩阵,用大写字母表示。向量的下标表示时刻,例如,sts_tst​表示在t时刻向量s的值。
我们假设输入向量x的维度是m,输出向量s的维度是n,则矩阵U的维度是n×mn\times mn×m,矩阵W的维度是n×nn\times nn×n。下面是上式展开成矩阵的样子,看起来更直观一些:

在这里我们用手写体字母表示向量的一个元素,它的下标表示它是这个向量的第几个元素,它的上标表示第几个时刻。例如,sjts^t_jsjt​表示向量s的第j个元素在t时刻的值。ujiu_{ji}uji​表示输入层第i个神经元到循环层第j个神经元的权重。wjiw_{ji}wji​表示循环层第t-1时刻的第i个神经元到循环层第t个时刻的第j个神经元的权重。

误差项的计算

BTPP算法将第l层t时刻的误差项值δjl\delta^l_jδjl​沿两个方向传播,一个方向是其传递到上一层网络,得到δjl−1\delta^{l-1}_jδjl−1​,这部分只和权重矩阵U有关;另一个是方向是将其沿时间线传递到初始时刻t1t_1t1​,得到δj1\delta^1_jδj1​,这部分只和权重矩阵W有关。
我们用向量nettnet_tnett​表示神经元在t时刻的加权输入,因为:


我们用a表示列向量,用aTa^TaT表示行向量。两项结果为Jacobian矩阵:


其中,diag[a]表示根据向量a创建一个对角矩阵,即

两项结合,得

上式描述了将δ\deltaδ沿时间往前传递一个时刻的规律,有了这个规律,我们就可以求得任意时刻k的误差项δk\delta_kδk​

式3就是将误差项沿时间反向传播的算法。
循环层将误差项反向传递到上一层网络,与普通的全连接层是完全一样的
循环层的加权输入netlnet^lnetl与上一层netl−1net^{l-1}netl−1的加权输入关系如下:

所以,

式4就是将误差项传递到上一层算法。

权重梯度的计算

现在,我们终于来到了BPTT算法的最后一步:计算每个权重的梯度。

首先,我们计算误差函数E对权重矩阵W的梯度∂E∂W\frac{\partial E}{\partial W}∂W∂E​。

只要知道了任意一个时刻的误差项δt\delta_tδt​,以及上一个时刻循环层的输出值st−1s_{t-1}st−1​,就可以按照下面的公式求出权重矩阵在t时刻的梯度∇WtE\nabla _{W_t}E∇Wt​​E:

式5推导如下

因为对W求导与UxtUx_tUxt​无关,我们不再考虑。现在,我们考虑对权重项wjiw_{ji}wji​求导。通过观察上式我们可以看到wjiw_{ji}wji​只与netjtnet^t_jnetjt​有关,所以:

我们已经求得了权重矩阵W在t时刻的梯度∇WtE\nabla _{W_t}E∇Wt​​E,最终的梯度∇WE\nabla _WE∇W​E是各个时刻的梯度之和,详解见参考

式6就是计算循环层权重矩阵W的梯度的公式。

RNN的梯度爆炸和消失问题

实践中前面介绍的几种RNNs并不能很好的处理较长的序列。一个主要的原因是,RNN在训练中很容易发生梯度爆炸和梯度消失,这导致训练时梯度不能在较长序列中一直传递下去,从而使RNN无法捕捉到长距离的影响。

根据式3可得

上式的β\betaβ定义为矩阵的模的上界。因为上式是一个指数函数,如果t-k很大的话(也就是向前看很远的时候),会导致对应的误差项的值增长或缩小的非常快,这样就会导致相应的梯度爆炸梯度消失问题(取决于β\betaβ大于1还是小于1)。

通常来说,梯度爆炸更容易处理一些。因为梯度爆炸的时候,我们的程序会收到NaN错误。我们也可以设置一个梯度阈值,当梯度超过这个阈值的时候可以直接截取。

梯度消失更难检测,而且也更难处理一些。总的来说,我们有三种方法应对梯度消失问题:

  1. 合理的初始化权重值。初始化权重,使每个神经元尽可能不要取极大或极小值,以躲开梯度消失的区域。
  2. 使用relu代替sigmoid和tanh作为激活函数。
  3. 使用其他结构的RNNs,比如长短时记忆网络(LTSM)和Gated Recurrent Unit(GRU),这是最流行的做法。

参考: https://zybuluo.com/hanbingtao/note/541458

循环神经网络——RNN的训练算法:BPTT相关推荐

  1. keras 多层lstm_机器学习100天-Day2403 循环神经网络RNN(训练多层RNN)

    说明:本文依据<Sklearn 与 TensorFlow 机器学习实用指南>完成,所有版权和解释权均归作者和翻译成员所有,我只是搬运和做注解. 进入第二部分深度学习 第十四章循环神经网络 ...

  2. 深度学习 --- 循环神经网络RNN详解(BPTT)

    今天开始深度学习的最后一个重量级的神经网络即RNN,这个网络在自然语言处理中用处很大,因此需要掌握它,同时本人打算在深度学习总结完成以后就开始自然语言处理的总结,至于强化学习呢,目前不打算总结了,因为 ...

  3. 循环神经网络(RNN, Recurrent Neural Networks)介绍

    循环神经网络(RNN, Recurrent Neural Networks)介绍   循环神经网络(Recurrent Neural Networks,RNNs)已经在众多自然语言处理(Natural ...

  4. 循环神经网络(RNN)相关知识

    文章目录 RNN概述 前向传播公式 通过时间反向传播(BPTT) RNN确定序列长度方式 其他RNN结构 基于RNN的应用 1,序列数据的分析 2,序列数据的转换 3,序列数据的生成 RNN的不足 1 ...

  5. 第六章_循环神经网络(RNN)

    文章目录 第六章 循环神经网络(RNN) CNN和RNN的对比 http://www.elecfans.com/d/775895.html 6.1 为什么需要RNN? 6.1 RNN种类? RNN t ...

  6. 循环神经网络(RNN)知识入门

    循环神经网络(RNN)知识入门 原创:方云 一. RNN的发展历史 1986年,Elman等人提出了用于处理序列数据的循环神经网络(Recurrent Neural Networks).如同卷积神经网 ...

  7. 「NLP」 深度学习NLP开篇-循环神经网络(RNN)

    https://www.toutiao.com/a6714260714988503564/ 从这篇文章开始,有三AI-NLP专栏就要进入深度学习了.本文会介绍自然语言处理早期标志性的特征提取工具-循环 ...

  8. 【NLP】 深度学习NLP开篇-循环神经网络(RNN)

    从这篇文章开始,有三AI-NLP专栏就要进入深度学习了.本文会介绍自然语言处理早期标志性的特征提取工具-循环神经网络(RNN).首先,会介绍RNN提出的由来:然后,详细介绍RNN的模型结构,前向传播和 ...

  9. 花书+吴恩达深度学习(十五)序列模型之循环神经网络 RNN

    目录 0. 前言 1. RNN 计算图 2. RNN 前向传播 3. RNN 反向传播 4. 导师驱动过程(teacher forcing) 5. 不同序列长度的 RNN 如果这篇文章对你有一点小小的 ...

最新文章

  1. cron计划任务使用
  2. Hibernate的数据查找,添加!
  3. Crontab使用详解
  4. 枚举集合的EnumSet
  5. 接口入口在什么地方_弱电工程施工图审查要点?有哪些地方需要审核?审核要求是什么?...
  6. python面向对象三大特性_Python面向对象之多态原理与用法案例分析
  7. php无重复字符的最长子串,无重复字符的最长字串问题
  8. python start
  9. 小程序 长按转发_小程序转发分享
  10. json mysql乱码问题_读写json中文ASCII乱码问题的解决方法
  11. matlab 实用快捷键
  12. python实现sql盲注
  13. python去重算法_python去重算法
  14. Exp4 恶意代码分析 20154328 常城
  15. Python组合数据类型
  16. Ubuntu 3D桌面完全教程
  17. DO447协调滚动更新--委派任务和事实
  18. python 集合的基本操作
  19. 需求调研中要注意的三点
  20. 《Microduino实战》——2.7 总结

热门文章

  1. MySQL安装与卸载详细教程
  2. IDEA 中某个项目 pom 文件灰色且有删除线
  3. 【Linux命令】nohup执行脚本文件
  4. 多线程核心知识:原子性
  5. 求职陷阱:Lazarus组织以日本瑞穗銀行等招聘信息为诱饵的攻击活动分析
  6. 信息化 + 个性化再造学习生命力
  7. sqlserver 分组合并列_夺冠!中国队国际奥数大赛再称雄,满分选手已保送清华姚班,“中国二队”并列第一...
  8. java什么是递归_JAVA的递归是什么意思?
  9. java基础(n个人围成一圈)
  10. 为什么很多人赚不到钱?赚了钱又存不了钱呢