2. LSTM

2.1. 长期依赖问题

标准 RNN 结构在理论上完全可以实现将最初的信息保留到即使很远的时刻,但是在实践中发现 RNN 会受到短时记忆的影响。如果一条序列足够长,那它们将很难将信息从较早的时刻传送到后面的时刻。 因此,如果正在尝试处理一段文本进行预测,RNN 可能从一开始就会遗漏重要信息。比如我们尝试预测 “I grew up in France … I speak fluent French” 这句话的最后一个词 ”French“ 。当前的信息(“I speak fluent”)表明接下来的单词是很可能是语言的名字。但是需要哪种语言,我们就要根据离当前位置很远的 “France” 来确定。这就说明相关信息和当前预测词的位置之间的间隔可能非常大,随着这间隔不断变大,RNN 就会失去学习连接如此远的信息的能力。 这就是我们上面提到的 RNN 最致命的缺点。

为了解决这个问题,提出了 LSTM 。

2.2. 网络结构

LSTM 属于 RNN 的扩展模型,二者的区别仅在于每个单元内部结构不同。LSTM 单元结构如下。

其中,黄色矩形表示一层神经网络,包含权重和激活函数,矩形中的符号表明激活函数的类型,σ\sigmaσ 对应 sigmoid 函数,tanh\rm tanhtanh 对应 tanh 函数;粉色(椭)圆表示逐元素操作,比如粉色(椭)圆中为乘号表明矩阵进行对应元素相乘(点乘)操作,tanh\rm tanhtanh 表明进行逐元素取 tanh 值。

下图展示了 LSTM 单元的完整前向传播过程。

从”遗忘门“、”输入门“和”输出门“,这三个”门“的角度来理解 LSTM 单元。

之所以称之为”门“,是考虑到生活中的”门“存在”开/闭“两种状态。LSTM 单元中的”门“也是存在”开/闭“两种状态,”开“表示全部(绝大部分)信息都可以经过”门“流出,”闭“表示全部(绝大部分)信息都不能经过”门“流出,而是被”门“过滤掉。由于 sigmoid 函数非常适合二分类,所以该函数在 LSTM 单元中起到”门“过滤的作用,用于控制信息是否流出(流出量)。

  • 遗忘门

    ”遗忘门“决定了前一个单元的状态 ct−1c_{t-1}ct−1​ 有多少信息保留到当前单元状态 ctc_tct​ 中。对应图中过程 [ht−1,xt]→ft[h_{t-1},x_t]\rightarrow f_t[ht−1​,xt​]→ft​ 。

  • 输入门

    ”输入门“决定了当前单元的输入 xtx_txt​ 有多少信息保存到单元状态 ctc_tct​ 。对应图中过程 [ht−1,xt]→it[h_{t-1},x_t]\rightarrow i_t[ht−1​,xt​]→it​ 。

  • 输出门

    ”输出门“用于控制当前单元的状态 ctc_tct​ 有多少信息输出到当前输出值 hth_tht​ 。对应图中过程 [ht−1,xt]→ot[h_{t-1},x_t]\rightarrow o_t[ht−1​,xt​]→ot​ 。​

模型单元的思想可以理解为,[ht−1,xt][h_{t-1},x_t][ht−1​,xt​] 经过遗忘门确定保留多少前一个单元的信息,ct−1c_{t-1}ct−1​ 和 σ(Wxfxt+Whfht−1+bf)\sigma(W_{xf}x_t+W_{hf}h_{t-1}+b_f)σ(Wxf​xt​+Whf​ht−1​+bf​) 按位点乘实现筛选出要保留的信息;σ(Wxixt+Whiht−1+bi)\sigma(W_{xi}x_t+W_{hi}h_{t-1}+b_i)σ(Wxi​xt​+Whi​ht−1​+bi​) 和 tanh(Wxgxt+Whght−1+bg){\rm tanh}(W_{xg}x_t+W_{hg}h_{t-1}+b_g)tanh(Wxg​xt​+Whg​ht−1​+bg​) 按位点乘实现从外部输入信息 xtx_txt​ 中筛选出需要保留的信息,过滤到无用信息;将保留的原始信息和保留的外部信息按位相加,得到当前单元包含的信息 ctc_tct​; tanh(ct){\rm tanh}(c_t)tanh(ct​) 用于将每个单元的信息统一到一定范围内,再与 σ(Wxoxt+Whoht−1+bo)\sigma(W_{xo}x_t+W_{ho}h_{t-1}+b_o)σ(Wxo​xt​+Who​ht−1​+bo​) 按位点乘筛选出当前单元的全部信息中可以用于评估单元优劣的信息 hth_tht​,对全部 hth_tht​ 进一步处理可以得到用于评估模型优劣的损失函数,同时也会直接传入到下一个单元,循环往复。

总结一下,整个流程是分为三个大部分,对应着三个”门“的操作。遗忘门部分筛选有用的内部信息,输入门筛选有用的外部信息,将两部分信息整合,输出门筛选用于评估单元优劣的信息。可以看到,每次的筛选操作都是通过 sigmoid 函数对 [ht−1,xt][h_{t-1},x_t][ht−1​,xt​] 的线性映射进行非线性激活完成的。

2.3. 前向传播与反向传播

  • 前向传播

    前面已经讲解了。

  • 反向传播

    还是以计算图的形式说明反向传播过程。存在如下公式:
    ft=σ(Wxfxt+Whfht−1+bf)it=σ(Wxixt+Whiht−1+bi)gt=tanh(Wxgxt+Whght−1+bg)ot=σ(Wxoxt+Whoht−1+bo)ct=ct−1⊙ft+gt⊙itht=tanh(ct)⊙otL=∑loss(ht,yt)\begin{align} f_t&=\sigma(W_{xf}x_t+W_{hf}h_{t-1}+b_f) \tag{2.1}\\ i_t&=\sigma(W_{xi}x_t+W_{hi}h_{t-1}+b_i) \tag{2.2}\\ g_t&={\rm tanh}(W_{xg}x_t+W_{hg}h_{t-1}+b_g) \tag{2.3}\\ o_t&=\sigma(W_{xo}x_t+W_{ho}h_{t-1}+b_o) \tag{2.4}\\ c_t&=c_{t-1}\odot f_t+g_t\odot i_t \tag{2.5}\\ h_t&={\rm tanh}(c_t)\odot o_t \tag{2.6} \\ L&=\sum loss(h_t,y_t) \tag{2.7} \\ \end{align} ft​it​gt​ot​ct​ht​L​=σ(Wxf​xt​+Whf​ht−1​+bf​)=σ(Wxi​xt​+Whi​ht−1​+bi​)=tanh(Wxg​xt​+Whg​ht−1​+bg​)=σ(Wxo​xt​+Who​ht−1​+bo​)=ct−1​⊙ft​+gt​⊙it​=tanh(ct​)⊙ot​=∑loss(ht​,yt​)​(2.1)(2.2)(2.3)(2.4)(2.5)(2.6)(2.7)​
    一个单元的计算图如下。灰色框圈出的是一个单元涉及的计算关系,其他单元都可以类似地画出。

    我们引入 Lt=loss(ht,yt)(t=1,2,…,T)L_t=loss(h_t,y_t)\space (t=1,2,\dots,T)Lt​=loss(ht​,yt​) (t=1,2,…,T) ,因此 LLL 可以表示为 L=∑t=1TLtL=\sum_{t=1}^TL_tL=∑t=1T​Lt​ 。反向传播过程如下。

    以计算 ∂L∂Whf\frac{\partial L}{\partial W_{hf}}∂Whf​∂L​ 为例推导公式,其他参数类似,推导的思路是根据反向传播过程按顺序推导每个结点代表的链式偏导。

    考虑最特别的 TTT 时刻,计算出损失函数(值)关于 TTT 时刻各个变量的偏导
    ∂L∂LT∂L∂hT=∂L∂LT∂LT∂hT∂L∂oT=∂L∂hT∂hT∂oT∂L∂cT=∂L∂hT∂hT∂cT∂L∂fT=∂L∂cT∂cT∂fT∂L∂iT=∂L∂cT∂cT∂iT∂L∂gT=∂L∂cT∂cT∂gT∂L∂Whf⟨T⟩=∂L∂fT∂fT∂Whf+∂L∂iT∂iT∂Whf+∂L∂gT∂gT∂Whf\begin{align} \frac{\partial L}{\partial L_T} &\notag \\\notag \\ \frac{\partial L}{\partial h_T} &= \frac{\partial L}{\partial L_T}\frac{\partial L_T}{\partial h_T} \notag\\\notag \\ \frac{\partial L}{\partial o_T} &= \frac{\partial L}{\partial h_T} \frac{\partial h_T}{\partial o_T} \notag \\\notag \\ \frac{\partial L}{\partial c_T} &= \frac{\partial L}{\partial h_T} \frac{\partial h_T}{\partial c_T} \notag \\\notag \\ \frac{\partial L}{\partial f_T}&=\frac{\partial L}{\partial c_T}\frac{\partial c_T}{\partial f_T} \notag\\\notag \\ \frac{\partial L}{\partial i_T}&=\frac{\partial L}{\partial c_T}\frac{\partial c_T}{\partial i_T} \notag\\\notag \\ \frac{\partial L}{\partial g_T}&=\frac{\partial L}{\partial c_T}\frac{\partial c_T}{\partial g_T} \notag\\\notag \\ \frac{\partial L}{\partial W_{hf}^{\left\langle T \right\rangle}} &= \frac{\partial L}{\partial f_T} \frac{\partial f_T}{\partial W_{hf}} + \frac{\partial L}{\partial i_T} \frac{\partial i_T}{\partial W_{hf}}+\frac{\partial L}{\partial g_T} \frac{\partial g_T}{\partial W_{hf}} \notag \end{align} ∂LT​∂L​∂hT​∂L​∂oT​∂L​∂cT​∂L​∂fT​∂L​∂iT​∂L​∂gT​∂L​∂Whf⟨T⟩​∂L​​=∂LT​∂L​∂hT​∂LT​​=∂hT​∂L​∂oT​∂hT​​=∂hT​∂L​∂cT​∂hT​​=∂cT​∂L​∂fT​∂cT​​=∂cT​∂L​∂iT​∂cT​​=∂cT​∂L​∂gT​∂cT​​=∂fT​∂L​∂Whf​∂fT​​+∂iT​∂L​∂Whf​∂iT​​+∂gT​∂L​∂Whf​∂gT​​​

    其中,∂L∂Whf⟨T⟩\frac{\partial L}{\partial W_{hf}^{\left\langle T \right\rangle}}∂Whf⟨T⟩​∂L​ 表示 TTT 时刻对损失函数(值)关于 WhfW_{hf}Whf​ 偏导的贡献,满足 ∂L∂Whf=∑t=1T∂L∂Whf⟨t⟩\frac{\partial L}{\partial W_{hf}} = \sum\limits_{t=1}^T \frac{\partial L}{\partial W_{hf}^{\left\langle t \right\rangle}}∂Whf​∂L​=t=1∑T​∂Whf⟨t⟩​∂L​ 。

    根据式 (2.1)∼(2.7)(2.1)\sim (2.7)(2.1)∼(2.7) 将上面各式计算出来。TTT 时刻各个变量的偏导总结如下。
    ∂L∂LT=1∂L∂hT=∂LT∂hT∂L∂oT=∂LT∂hTtanh(cT)∂L∂cT=∂LT∂hToTtanh′(⋅)∂L∂fT=∂LT∂hToTtanh′(⋅)ct−1∂L∂iT=∂LT∂hToTtanh′(⋅)gT∂L∂gT=∂LT∂hToTtanh′(⋅)iT∂L∂Whf⟨T⟩=∂L∂fT∂fT∂Whf=∂L∂hToTtanh′(⋅)cT−1σ′(⋅)hT−1\begin{align} \frac{\partial L}{\partial L_T} &=1\notag \\\notag \\ \frac{\partial L}{\partial h_T} &= \frac{\partial L_T}{\partial h_T} \notag\\\notag \\ \frac{\partial L}{\partial o_T} &= \frac{\partial L_T}{\partial h_T} {\rm tanh}(c_T) \notag \\\notag \\ \frac{\partial L}{\partial c_T} &= \frac{\partial L_T}{\partial h_T}o_T{\rm tanh'(·)} \notag \\\notag \\ \frac{\partial L}{\partial f_T}&=\frac{\partial L_T}{\partial h_T}o_T {\rm tanh'(·)}c_{t-1} \notag\\\notag \\ \frac{\partial L}{\partial i_T}&=\frac{\partial L_T}{\partial h_T}o_T {\rm tanh'(·)}g_T \notag\\\notag \\ \frac{\partial L}{\partial g_T}&=\frac{\partial L_T}{\partial h_T}o_T {\rm tanh'(·)}i_T \notag\\\notag \\ \frac{\partial L}{\partial W_{hf}^{\left\langle T \right\rangle}} &= \frac{\partial L}{\partial f_T} \frac{\partial f_T}{\partial W_{hf}} =\frac{\partial L}{\partial h_T}o_T {\rm tanh'(·)}c_{T-1}\sigma'(·) h_{T-1} \notag \end{align} ∂LT​∂L​∂hT​∂L​∂oT​∂L​∂cT​∂L​∂fT​∂L​∂iT​∂L​∂gT​∂L​∂Whf⟨T⟩​∂L​​=1=∂hT​∂LT​​=∂hT​∂LT​​tanh(cT​)=∂hT​∂LT​​oT​tanh′(⋅)=∂hT​∂LT​​oT​tanh′(⋅)ct−1​=∂hT​∂LT​​oT​tanh′(⋅)gT​=∂hT​∂LT​​oT​tanh′(⋅)iT​=∂fT​∂L​∂Whf​∂fT​​=∂hT​∂L​oT​tanh′(⋅)cT−1​σ′(⋅)hT−1​​

    当 t=1,2,…,T−1t=1,2,\dots,T-1t=1,2,…,T−1 时,计算出损失函数(值)关于 ttt 时刻刻个变量的偏导
    ∂L∂Lt∂L∂ht=∂L∂Lt∂Lt∂ht+∂L∂ot+1∂ot+1∂ht+∂L∂ft+1∂ft+1∂ht+∂L∂it+1∂it+1∂ht+∂L∂gt+1∂gt+1∂ht∂L∂ot=∂L∂ht∂ht∂ot∂L∂ct=∂L∂ht∂ht∂ct+∂L∂ct+1∂ct+1∂ct∂L∂ft=∂L∂ct∂ct∂ft∂L∂it=∂L∂ct∂ct∂it∂L∂gt=∂L∂ct∂ct∂gt∂L∂Whf⟨t⟩=∂L∂ft∂ft∂Whf+∂L∂it∂it∂Whf+∂L∂gt∂gt∂Whf\begin{align} \frac{\partial L}{\partial L_t} &\notag \\\notag \\ \frac{\partial L}{\partial h_t} &= \frac{\partial L}{\partial L_t}\frac{\partial L_t}{\partial h_t} + \frac{\partial L}{\partial o_{t+1}}\frac{\partial o_{t+1}}{\partial h_t} +\frac{\partial L}{\partial f_{t+1}} \frac{\partial f_{t+1}}{\partial h_{t}} + \frac{\partial L}{\partial i_{t+1}} \frac{\partial i_{t+1}}{\partial h_{t}}+\frac{\partial L}{\partial g_{t+1}} \frac{\partial g_{t+1}}{\partial h_{t}} \notag\\\notag \\ \frac{\partial L}{\partial o_t} &= \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial o_t} \notag \\\notag \\ \frac{\partial L}{\partial c_t} &= \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial c_t} + \frac{\partial L}{\partial c_{t+1}} \frac{\partial c_{t+1}}{\partial c_t} \notag \\\notag \\ \frac{\partial L}{\partial f_t}&=\frac{\partial L}{\partial c_t}\frac{\partial c_t}{\partial f_t} \notag\\\notag \\ \frac{\partial L}{\partial i_t}&=\frac{\partial L}{\partial c_t}\frac{\partial c_t}{\partial i_t} \notag\\\notag \\ \frac{\partial L}{\partial g_t}&=\frac{\partial L}{\partial c_t}\frac{\partial c_t}{\partial g_t} \notag\\\notag \\ \frac{\partial L}{\partial W_{hf}^{\left\langle t \right\rangle}} &= \frac{\partial L}{\partial f_t} \frac{\partial f_t}{\partial W_{hf}} + \frac{\partial L}{\partial i_t} \frac{\partial i_t}{\partial W_{hf}}+\frac{\partial L}{\partial g_t} \frac{\partial g_t}{\partial W_{hf}} \notag \end{align} ∂Lt​∂L​∂ht​∂L​∂ot​∂L​∂ct​∂L​∂ft​∂L​∂it​∂L​∂gt​∂L​∂Whf⟨t⟩​∂L​​=∂Lt​∂L​∂ht​∂Lt​​+∂ot+1​∂L​∂ht​∂ot+1​​+∂ft+1​∂L​∂ht​∂ft+1​​+∂it+1​∂L​∂ht​∂it+1​​+∂gt+1​∂L​∂ht​∂gt+1​​=∂ht​∂L​∂ot​∂ht​​=∂ht​∂L​∂ct​∂ht​​+∂ct+1​∂L​∂ct​∂ct+1​​=∂ct​∂L​∂ft​∂ct​​=∂ct​∂L​∂it​∂ct​​=∂ct​∂L​∂gt​∂ct​​=∂ft​∂L​∂Whf​∂ft​​+∂it​∂L​∂Whf​∂it​​+∂gt​∂L​∂Whf​∂gt​​​
    根据式 (2.1)∼(2.7)(2.1)\sim (2.7)(2.1)∼(2.7) 将上面各式计算出来。t(t=1,2,…,T−1)t\space (t=1,2,\dots,T-1)t (t=1,2,…,T−1) 时刻各个变量的偏导总结如下(部分等式由于展开过长而不代入展开)。
    ∂L∂Lt=1∂L∂ht=∂Lt∂ht+∂Lt+1∂ht+1tanh(ct+1)σ′(⋅)Who+∂Lt+1∂ht+1ot+1tanh′(⋅)ctσ′(⋅)Whf+∂Lt+1∂ht+1ot+1tanh′(⋅)gt+1σ′(⋅)Whi+∂Lt+1∂ht+1ot+1tanh′(⋅)it+1σ′(⋅)Whg∂L∂ot=∂L∂httanh(ct)∂L∂ct=∂L∂htottanh′(⋅)+∂L∂ct+1ft+1∂L∂ft=∂L∂ctct−1∂L∂it=∂L∂ctgt∂L∂gt=∂L∂ctit∂L∂Whf⟨t⟩=∂L∂ft∂ft∂Whf=∂L∂ftht−1=∂L∂ctct−1σ′(⋅)ht−1\begin{align} \frac{\partial L}{\partial L_t} &=1\notag \\\notag \\ \frac{\partial L}{\partial h_t} &= \frac{\partial L_t}{\partial h_t} + \frac{\partial L_{t+1}}{\partial h_{t+1}}{\rm tanh}(c_{t+1})\sigma'(·)W_{ho} +\frac{\partial L_{t+1}}{\partial h_{t+1}} o_{t+1}{\rm tanh'(·)}c_t\sigma'(·)W_{hf} + \frac{\partial L_{t+1}}{\partial h_{t+1}}o_{t+1}{\rm tanh'(·)}g_{t+1}\sigma'(·)W_{hi}+\frac{\partial L_{t+1}}{\partial h_{t+1}}o_{t+1}{\rm tanh'(·)}i_{t+1}\sigma'(·)W_{hg} \notag\\\notag \\ \frac{\partial L}{\partial o_t} &= \frac{\partial L}{\partial h_t} {\rm tanh} (c_t) \notag \\\notag \\ \frac{\partial L}{\partial c_t} &= \frac{\partial L}{\partial h_t} o_t{\rm tanh'(·)} + \frac{\partial L}{\partial c_{t+1}} f_{t+1} \tag{*} \\\notag \\ \frac{\partial L}{\partial f_t}&=\frac{\partial L}{\partial c_t}c_{t-1} \notag\\\notag \\ \frac{\partial L}{\partial i_t}&=\frac{\partial L}{\partial c_t}g_t \notag\\\notag \\ \frac{\partial L}{\partial g_t}&=\frac{\partial L}{\partial c_t}i_t \notag\\\notag \\ \frac{\partial L}{\partial W_{hf}^{\left\langle t \right\rangle}} &=\frac{\partial L}{\partial f_t} \frac{\partial f_t}{\partial W_{hf}}=\frac{\partial L}{\partial f_t} h_{t-1}=\frac{\partial L}{\partial c_t}c_{t-1}\sigma'(·) h_{t-1} \tag{**} \end{align} ∂Lt​∂L​∂ht​∂L​∂ot​∂L​∂ct​∂L​∂ft​∂L​∂it​∂L​∂gt​∂L​∂Whf⟨t⟩​∂L​​=1=∂ht​∂Lt​​+∂ht+1​∂Lt+1​​tanh(ct+1​)σ′(⋅)Who​+∂ht+1​∂Lt+1​​ot+1​tanh′(⋅)ct​σ′(⋅)Whf​+∂ht+1​∂Lt+1​​ot+1​tanh′(⋅)gt+1​σ′(⋅)Whi​+∂ht+1​∂Lt+1​​ot+1​tanh′(⋅)it+1​σ′(⋅)Whg​=∂ht​∂L​tanh(ct​)=∂ht​∂L​ot​tanh′(⋅)+∂ct+1​∂L​ft+1​=∂ct​∂L​ct−1​=∂ct​∂L​gt​=∂ct​∂L​it​=∂ft​∂L​∂Whf​∂ft​​=∂ft​∂L​ht−1​=∂ct​∂L​ct−1​σ′(⋅)ht−1​​(*)(**)​
    上面式 (∗)(*)(∗) 没有计算出 ∂L∂ct\frac{\partial L}{\partial c_t}∂ct​∂L​ 的通项公式,只是给出了递推公式,对其归纳后得
    ∂L∂ct=∑t=1T∂L∂hioitanh′(ci)(1+∏j=2ifj)\frac{\partial L}{\partial c_t}=\sum_{t=1}^T\frac{\partial L}{\partial h_i}o_i{\rm tanh'}(c_i)\left( 1+\prod_{j=2}^i f_j\right) ∂ct​∂L​=t=1∑T​∂hi​∂L​oi​tanh′(ci​)(1+j=2∏i​fj​)
    进而计算出式 (∗∗)(**)(∗∗)
    ∂L∂Whf⟨t⟩=ct−1σ′(Wxfxt+Whfht−1+bf)ht−1∑t=1T∂L∂hioitanh′(ci)(1+∏j=2ifj)\frac{\partial L}{\partial W_{hf}^{\left\langle t \right\rangle}}= c_{t-1}\sigma'(W_{xf}x_t+W_{hf}h_{t-1}+b_f)h_{t-1}\sum_{t=1}^T\frac{\partial L}{\partial h_i}o_i{\rm tanh'}(c_i)\left( 1+\prod_{j=2}^i f_j\right) ∂Whf⟨t⟩​∂L​=ct−1​σ′(Wxf​xt​+Whf​ht−1​+bf​)ht−1​t=1∑T​∂hi​∂L​oi​tanh′(ci​)(1+j=2∏i​fj​)
    最后将全部的梯度贡献值相加,得
    ∂L∂Whf=∂L∂hToTtanh′(cT)cT−1σ′(WxfxT+WhfhT−1+bf)hT−1+∑t=1T−1ct−1σ′(Wxfxt+Whfht−1+bf)ht−1∑t=1T∂L∂hioitanh′(ci)(1+∏j=2ifj)\frac{\partial L}{\partial W_{hf}} = \frac{\partial L}{\partial h_T}o_T {\rm tanh'}(c_T)c_{T-1}\sigma'(W_{xf}x_T+W_{hf}h_{T-1}+b_f) h_{T-1} + \sum_{t=1}^{T-1} c_{t-1}\sigma'(W_{xf}x_t+W_{hf}h_{t-1}+b_f)h_{t-1}\sum_{t=1}^T\frac{\partial L}{\partial h_i}o_i{\rm tanh'}(c_i)\left( 1+\prod_{j=2}^i f_j\right) ∂Whf​∂L​=∂hT​∂L​oT​tanh′(cT​)cT−1​σ′(Wxf​xT​+Whf​hT−1​+bf​)hT−1​+t=1∑T−1​ct−1​σ′(Wxf​xt​+Whf​ht−1​+bf​)ht−1​t=1∑T​∂hi​∂L​oi​tanh′(ci​)(1+j=2∏i​fj​)
    也可以不体现函数的参数,得到更简洁的形式
    ∂L∂Whf=∂L∂hToTtanh′(⋅)cT−1σ′(⋅)hT−1+∑t=1T−1ct−1σ′(⋅)ht−1∑t=1T∂L∂hioitanh′(⋅)(1+∏j=2ifj)\frac{\partial L}{\partial W_{hf}} = \frac{\partial L}{\partial h_T}o_T {\rm tanh'}(·)c_{T-1}\sigma'(·) h_{T-1} + \sum_{t=1}^{T-1} c_{t-1}\sigma'(·)h_{t-1}\sum_{t=1}^T\frac{\partial L}{\partial h_i}o_i{\rm tanh'}(·)\left( 1+\prod_{j=2}^i f_j\right) ∂Whf​∂L​=∂hT​∂L​oT​tanh′(⋅)cT−1​σ′(⋅)hT−1​+t=1∑T−1​ct−1​σ′(⋅)ht−1​t=1∑T​∂hi​∂L​oi​tanh′(⋅)(1+j=2∏i​fj​)

    由于无法将 TTT 时刻的梯度贡献值与其他时刻的梯度贡献值统一表示,因此,对应上式中加号左右的两部分。

    上面计算出了 ∂L∂Whf\frac{\partial L}{\partial W_{hf}}∂Whf​∂L​ ,类似地也可以计算出 LLL 对 WxfW_{xf}Wxf​、WhiW_{hi}Whi​、WxiW_{xi}Wxi​、WhgW_{hg}Whg​、WxgW_{xg}Wxg​、WhoW_{ho}Who​、WxoW_{xo}Wxo​、bfb_fbf​、bib_ibi​、bgb_gbg​、bob_obo​ 。

    以下在讨论引入 LtL_tLt​ 的原因,选读。

    不同于 RNN 反向传播公式的推导,RNN 并没有特意地引入 LtL_tLt​ ,而 LSTM 反向传播公式的推导中却需要引入。我们不妨先不引入该符号,当计算 ∂L∂ht(t=1,2,…,T−1)\frac{\partial L}{\partial h_t}\space (t=1,2,\dots,T-1)∂ht​∂L​ (t=1,2,…,T−1) 时,我们可以找到两条从 LLL 到 hth_tht​ 的路径,分别是 L→htL\rightarrow h_tL→ht​ 和 L→ht+1→ot+1→htL\rightarrow h_{t+1}\rightarrow o_{t+1}\rightarrow h_tL→ht+1​→ot+1​→ht​ ,因此 ∂L∂ht\frac{\partial L}{\partial h_t}∂ht​∂L​ 可以表示为 ∂L∂ht=∂L∂ht+∂L∂ht+1∂ht+1∂ot+1∂ot+1∂ht\frac{\partial L}{\partial h_t}=\frac{\partial L}{\partial h_t}+\frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial o_{t+1}}\frac{\partial o_{t+1}}{\partial h_t}∂ht​∂L​=∂ht​∂L​+∂ht+1​∂L​∂ot+1​∂ht+1​​∂ht​∂ot+1​​ ,观察等式两边会发现,这显然不合理。

    出现这种情况的原因很好理解。 ∂L∂ht\frac{\partial L}{\partial h_t}∂ht​∂L​ 只是一个符号,表示全部的从 LLL 到 hth_tht​ 的路径(直接到达或经过其他任意结点中转到达)对应的链式求导之和;∂ot+1∂ht\frac{\partial o_{t+1}}{\partial h_t}∂ht​∂ot+1​​ 也只是符号,表达全部的从 ot+1o_{t+1}ot+1​ 到 hth_tht​ 的路径对应的链式求导之和,不过由于只存在一条路径,这使得 ∂ot+1∂ht\frac{\partial o_{t+1}}{\partial h_t}∂ht​∂ot+1​​ 能够唯一地代表一条路径,所以我们也就不需要继续将 ∂ot+1∂ht\frac{\partial o_{t+1}}{\partial h_t}∂ht​∂ot+1​​ 化为偏导连乘的形式了;类似的道理,∂L∂ht+1∂ht+1∂ot+1\frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial o_{t+1}}∂ht+1​∂L​∂ot+1​∂ht+1​​ 可以由 ∂L∂ot+1\frac{\partial L}{\partial o_{t+1}}∂ot+1​∂L​ 代替,即 ∂L∂ot+1=∂L∂ht+1∂ht+1∂ot+1\frac{\partial L}{\partial o_{t+1}}=\frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial o_{t+1}}∂ot+1​∂L​=∂ht+1​∂L​∂ot+1​∂ht+1​​,这正是因为从 LLL 到 ot+1o_{t+1}ot+1​ 的路径唯一。综上,只有路径唯一时才能用符号 ∂∂\frac{\partial}{\partial}∂∂​ 表示完整的链式偏导。

    重新考虑不引入符号 LtL_tLt​ 出现的问题,等式左侧的符号 ∂L∂ht\frac{\partial L}{\partial h_t}∂ht​∂L​ 对应了多条从 LLL 到 hth_tht​ 的路径,等式右侧需要详细地将每条路径对应的链式偏导表达出来。如果想要唯一地表达路径 L→htL\rightarrow h_tL→ht​ (直接到达)则必须要引入另一个中间结点 LtL_tLt​ ,从而构成新的路径 L→Lt→htL\rightarrow L_t\rightarrow h_tL→Lt​→ht​,对应的链式偏导为 ∂L∂Lt∂Lt∂ht\frac{\partial L}{\partial L_t}\frac{\partial L_t}{\partial h_t}∂Lt​∂L​∂ht​∂Lt​​。

    形象地理解一下,我从家走到学校告诉同学:“放学的时候小心从我家到学校路上的狗”,同学傻了“那么多道,我怎么知道是哪条有狗啊!”,我细说“从我家先到布达拉宫,再到天安门,再到曹县,最后到学校的那条路上有狗;还有,从我家直通学校的路上也有,你可要小心啊!”,同学一听既害怕又感激,于是决定坐飞机回家。

    从这个例子中可以看出 ⌈\lceil⌈ “家 →⋯→\rightarrow \dots\rightarrow→⋯→ 学校”有狗 ⌋⇔⌈\rfloor\Leftrightarrow \lceil⌋⇔⌈ “家 →\rightarrow→ 布达拉宫 →\rightarrow→ 天安门 →\rightarrow→ 曹县 →\rightarrow→ 学校”有狗,并且“家 →\rightarrow→ 学校”有狗 ⌋\rfloor⌋ ,对应于等式的左侧和等式的右侧。

    缓解所谓的“梯度消失”

    令 ki=∂L∂hik_i=\frac{\partial L}{\partial h_i}ki​=∂hi​∂L​ ,将处理后的式 (∗∗)(**)(∗∗) 展开,得
    ∂L∂Whf⟨t⟩=ct−1σ′(⋅)ht−1[(k1o1)+(k2o2f2)+(k3o3f3f2)+⋯+(kToTfT…f3f2)]\frac{\partial L}{\partial W_{hf}^{\left\langle t \right\rangle}}= c_{t-1}\sigma'(·)h_{t-1} \left[ (k_1o_1)+(k_2o_2f_2) + (k_3o_3f_3f_2) + \dots + (k_To_Tf_T\dots f_3f_2) \right] ∂Whf⟨t⟩​∂L​=ct−1​σ′(⋅)ht−1​[(k1​o1​)+(k2​o2​f2​)+(k3​o3​f3​f2​)+⋯+(kT​oT​fT​…f3​f2​)]
    其中,fif_ifi​ 为 sigmoid 函数,通过监督训练,这些函数的取值将起到“门”的作用,即非 000 即 111 。上式中显然不存在激活函数导数连乘的形式,这降低了梯度消失发生的可能,另外还通过多个 sigmoid 函数连乘实现对远距离的信息进行筛选,弥补了 RNN 无法解决长期依赖的问题。

2.4. 训练过程

根据上面的动态传播过程图我们知道,每个 LSTM 单元的四个神经网络(结构图中的黄色部件)的输入都是向量 ht−1h_{t-1}ht−1​ 和 xtx_txt​ 经过拼接(concatenate)后的向量,输出到下一个单元的向量为 hth_tht​,当然,这里无需考虑 ct−1c_{t-1}ct−1​,因为 ct−1c_{t-1}ct−1​ 不经过神经网络,也就不存在维度变化。假设 ht−1h_{t-1}ht−1​ 是 hidden_size\rm hidden\_sizehidden_size 维向量,xtx_txt​ 是 x_size\rm x\_sizex_size 维向量,每个神经网络的输出均为 hidden_size\rm hidden\_sizehidden_size 维向量,相当于将 hidden_size+x_size\rm hidden\_size+x\_sizehidden_size+x_size 维向量映射到 hidden_size\rm hidden\_sizehidden_size 维向量,所以每个神经网络对应的参数可以表示为 (hidden_size+x_size,hidden_size)(\rm hidden\_size+x\_size,\rm hidden\_size)(hidden_size+x_size,hidden_size) 的矩阵。四个神经网络,将 hidden_size+x_size\rm hidden\_size+x\_sizehidden_size+x_size 维向量映射到 4×hidden_size\rm 4\times hidden\_size4×hidden_size 维向量,一个 LSTM 单元完整的参数矩阵为 (hidden_size+x_size,4×hidden_size)(\rm hidden\_size+x\_size,4\times hidden\_size)(hidden_size+x_size,4×hidden_size)。由于 LSTM 每个单元共享参数矩阵,所以整个 LSTM 的参数矩阵即为 (hidden_size+x_size,4×hidden_size)(\rm hidden\_size+x\_size,4\times hidden\_size)(hidden_size+x_size,4×hidden_size)。注意,将 444 个神经网络对应的参数矩阵合并只是为了进行矩阵乘法时更简便,所以计算完之后还是要拆开,再进行不同的运算。

举个简单的例子,训练 batch_size=64\rm batch\_size=64batch_size=64 的一组语句,每个语句 202020 个词,每个词向量 200200200 维,隐藏层向量 hth_tht​ 128128128 维,ctc_tct​ 与 hth_tht​ 同维。LSTM 的输入张量为 (64,20,200)(64, 20, 200)(64,20,200),LSTM 的参数矩阵为 (128+200,4×128)(128+200,4\times 128)(128+200,4×128)。对于某一个 LSTM 单元来说,输入为 (64,200)(64, 200)(64,200) 的矩阵,和 hth_tht​ 拼接得到 (64,200+128)(64, 200+128)(64,200+128),输入矩阵与参数矩阵相乘得到 (64,4×128)(64,4\times 128)(64,4×128),即每个神经网络的输出为 (64,128)(64, 128)(64,128)。神经网络的输出会进行一些不影响矩阵维度的位操作,所以该单元输出的 ctc_tct​ 和 hth_tht​ 仍然为 (64,128)(64,128)(64,128) 的矩阵。每个单元都重复进行相同的操作,202020 次操作(时间步)后,最终全部单元的输出为 (20,64,128)(20,64, 128)(20,64,128) 的矩阵。

如此我们得到了 LSTM 的输出矩阵为 (time_step,batch_size,hidden_size)\rm(time\_step, batch\_size, hidden\_size)(time_step,batch_size,hidden_size)。根据下游任务的不同,会定义不同的损失函数,比如分类任务,那么我们仅会选择最后一个时刻的这批样本的交叉熵作为损失函数;当然,对于其他的一些任务,也可以选择对全部时刻的交叉熵进行加和或求均值作为最终的损失函数。

这里我们仅讲解将 LSTM 最后一个单元(时刻)输出结果的交叉熵作为损失函数,其他情况类似。假设全部单词数为 vocabulary_size\rm vocabulary\_sizevocabulary_size,我们需要先定义一个可训练的矩阵,大小为 (hidden_size,vocabulary_size)\rm (hidden\_size, vocabulary\_size)(hidden_size,vocabulary_size),作用是将 LSTM 最后一个单元的输出为 (batch_size,hidden_size)\rm (batch\_size, hidden\_size)(batch_size,hidden_size) 的矩阵映射到大小为 (batch_size,vocabulary_size)\rm (batch\_size, vocabulary\_size)(batch_size,vocabulary_size) 的矩阵上。这样,矩阵的每一行代表一个样本(单词),按行 softmax 后每行均为概率分布。每个样本根据对应的独热“标签”计算对应的交叉熵,再将 batch_size\rm batch\_sizebatch_size 个交叉熵加和或者求均值作为目标函数。采用梯度下降等方法对模型参数进行更新。

注意区别 softmax 和交叉熵。softmax 只是一种将一般向量化为同维概率分布的手法,而交叉熵则是一种将两组概率分布变为标量的计算。

LSTM 作为语言模型,任务是根据输入的若干个单词预测下一个单词。因此,每个 LSTM 的“标签”是该条输入语句当前单词的下一个单词对应的独热编码。对于单词处于语句末尾的情况,一般会在句末引入特殊的语句结束符号;还有一些其他的与具体实现有关的特殊情况,在这里不详细展开。

REF

[1] Understanding LSTM Networks - colah’s blog

[2] LSTM神经网络详解 - CSDN博客

[3] 详解LSTM - 知乎 - 仅参考图片

[4] 《神经网络的梯度推导与代码验证》之LSTM的前向传播和反向梯度推导 - 博客园

[5] 4.RNN梯度消失回顾(公式推导)- bilibili

[6] LSTM如何来避免梯度弥散和梯度爆炸? - 知乎 - 用户Quokka的回答

[7] LSTM如何解决RNN带来的梯度消失问题 - CSDN博客

[8] LSTM训练过程与参数解读 - CSDN

[9] 使用LSTM实现语言模型 - 知乎

[10] 关于LSTM的输入和训练过程的理解 - 博客园

[11] tf.nn.dynamic_rnn的输出outputs和state含义 - CSDN

[12] tf.nn.softmax_cross_entropy_with_logits函数 - CSDN

[13] LSTM每一个时间步都有一个损失函数吗? - 知乎

【机器学习】LSTM 讲解相关推荐

  1. 使用Keras进行深度学习:(六)LSTM和双向LSTM讲解及实践

    ###欢迎大家关注我们的网站和系列教程:http://www.tensorflownews.com/,学习更多的机器学习.深度学习的知识! 介绍 长短期记忆(Long Short Term Memor ...

  2. 最通俗易懂的LSTM讲解,一个例子理解通透!!

    刚开始做程序开发时,提交代码前需要让大佬review,大佬们看完,总会在评论区打出一串"LGTM".作为小白的我,天真的以为大佬觉得我的代码质量很好,在开玩笑的夸我说"老 ...

  3. LSTM和双向LSTM讲解及实践

    目录: RNN的长期依赖问题 LSTM原理讲解 双向LSTM原理讲解 keras实现LSTM和双向LSTM RNN 的长期依赖问题 在上篇文章中介绍的循环神经网络RNN在训练的过程中会有长期依赖的问题 ...

  4. 最全机器学习种类讲解:监督、无监督、在线和批量学习都讲明白了

    导读:现有的机器学习系统种类繁多,根据以下内容将它们进行分类有助于我们理解: 是否在人类监督下训练(监督式学习.无监督式学习.半监督式学习和强化学习) 是否可以动态地进行增量学习(在线学习和批量学习) ...

  5. 很好的关于机器学习入门讲解(深入浅出)

    在本篇文章中,我将对机器学习做个概要的介绍.本文的目的是能让即便完全不了解机器学习的人也能了解机器学习,并且上手相关的实践.这篇文档也算是EasyPR开发的番外篇,从这里开始,必须对机器学习了解才能进 ...

  6. 磐创AI - 专注机器学习技术分享

    微信公众号推荐 磐创AI 编辑 安可 今天给大家推荐一个微信公众号「磐创AI」,是一个从三大深度学习框架Tensorflow.Keras与PyTorch的角度剖析AI行业最新动态,机器学习干货文章,深 ...

  7. matlab图像信息熵交叉熵,【机器学习】信息量,信息熵,交叉熵,KL散度和互信息(信息增益)...

    首先先强烈推荐一篇外文博客Visual Information Theory这个博客的博主colah是个著名的计算机知识科普达人,以前很是著名的那篇LSTM讲解的文章也是他写的.这篇文章详细讲解了信息 ...

  8. 人工智能之机器学习算法体系汇总

    https://www.toutiao.com/i6638371599303049731/ 2018-12-24 09:52:12 此处梳理出面向人工智能的机器学习方法体系,主要体现机器学习方法和逻辑 ...

  9. 【机器学习入门笔记0:OpenCV+TensorFlow学习内容(目录)】20190122

    2019-01-22 by 崔斐然 计划在三月初完成如下内容 第一章 课程导学 1-1 计算机视觉导学(解释机器视觉可以做什么,既然已经看到这篇文章,估计心里已经有谱了) 第二章 计算机视觉入门 2- ...

最新文章

  1. Python控制流:顺序结构、分支结构、循环结构+for、if ... else、while、if... elif...elif..else、if
  2. CSS: HTML 和 Body 的区别
  3. HDU 6129 Just do it 组合数学 杨辉三角 递推
  4. Android P(1)---Android P 上手:十年来的最大升级,你想知道的都在这里
  5. Extjs5.0 学习之路【资源篇】
  6. 2020数学建模C题摘要
  7. qq拼音 搜狗拼音的对比
  8. webkit 学习笔记 - 1. 浏览器与浏览器内核
  9. SAP license key developer access key
  10. 编程序,输出1/3-3/5+5/7-7/9…+19/21的结果
  11. centos下查看服务器型号命令
  12. 每日一数——梅森数(Mersenne number)
  13. 【JAVA 求差集】
  14. android闹钟 推迟功能,Android闹钟事件被延迟
  15. Prompt IDE来了…是划时代AI神器还是鸡肋?
  16. 刷机-3.71M33升级到3.71M33-4的方法
  17. linux 查看u盘文件,详细介绍Linux系统下检测U盘是不是已经连接的方法
  18. 安全认证框架Shiro (一)- ini配置文件
  19. GNSS高边坡在线监测预警系统解决方案
  20. matlab GUI学习笔记1 如何生成一个GUI文件,及如何修改保存后的.fig文件

热门文章

  1. 【华为ICT备赛】生成树基础
  2. Android 设置延时操作
  3. 打造一款语音菜谱软件,人人都能是大厨!
  4. 神经网络梯度爆炸、梯度消失原因及解决方案
  5. 通达信欧奈尔RPS指标公式详解
  6. LaTex常见数学符号与示例
  7. Typora自定义主题样式
  8. UndefinedTable: relation “ ****_id_seq“ does not exist解决方法
  9. 基于ETest的车身控制系统硬件在环测试技术研究
  10. configure 配置参数说明