文章目录

  • 前言
  • 一、传统RNN
  • 二、RNN带来的缺陷
    • 1、梯度爆炸和梯度弥散
    • 2、memory记忆不足
  • 三、LSTM理解
    • 1、LSTM原理
    • 2、LSTM公式
      • 前向传播
      • 反向传播
  • 四、LSTM实践(Python)
    • 训练minist数据集
    • IMDB电影评论数据集进行文本分类
  • 五、参考

前言

参照网上资料对LSTM的理解和总结,如文章内容有错误和不足之处,烦请读者联系作者修改。


一、传统RNN

循环神经网络(Recurrent Neural Network,RNN)是一种用于处理序列数据的神经网络,可以解决训练样本输入是连续且长短不一的序列的问题,比如基于时间序列的问题。

基础的神经网络只在层与层之间建立了权连接,RNN最大的不同之处就是在层与层之间的神经元之间也建立了权连接。

RNN神经网络的结构如下:


二、RNN带来的缺陷

1、梯度爆炸和梯度弥散

虽然在某些情况下,RNN神经网络的参数比卷积网络要少很多。但是,随着循环次数的叠加,很容易出现梯度爆炸或梯度弥散。而导致这个缺陷产生的主要原因是,传统RNN在计算梯度时,其公式中存在一个 ωhhω_{hh}ωhh 的k次方。

由于其梯度求解公式中有 ωhhω_{hh}ωhh 的k次方的存在,所以会出现下面的极限情况:
ωhh>1ω_{hh}>1ωhh>1ωhhkω_{hh}^kωhhk 接近于 ∞∞ ——出现梯度爆炸
ωhh<1ω_{hh}<1ωhh<1ωhhkω_{hh}^kωhhk 接近于 000 ——出现梯度弥散

2、memory记忆不足

虽然RNN使用了一个全局的memory去记录全局的语境信息,但实际上,memory只能记住很短的全局信息,随着迭代次数的增加,memory会逐渐遗忘前面的语境信息。


三、LSTM理解

长短期记忆(Long short-term memory, LSTM)是一种特殊的RNN,主要是为了解决长序列训练过程中的梯度爆炸或梯度弥散问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。

1、LSTM原理

LSTM 结构如下图,不同于单一神经网络层,这里是有四个,以一种非常特殊的方式进行交互。

LSTM 的关键就是细胞状态,水平线在图上方贯穿运行。
细胞状态类似于传送带。直接在整个链上运行,只有一些少量的线性交互。信息在上面流传保持不变会很容易。

若只有上面的那条水平线是没办法实现添加或者删除信息的。LSTM网络在传统RNN网络中设置了三道闸门——输入门(Input Gate)、遗忘门(Forget Gate)、输出门(Output Gate)用于控制不同对象的输出量,达到选择性记忆的目的。

2、LSTM公式

前向传播

遗忘门:
ft=δ(Wf⋅[ht−1,xt]+bf)f_t=δ(W_f·[h_{t-1},x_t]+b_f)ft=δ(Wf[ht1,xt]+bf)

遗忘门决定上一时刻细胞状态 Ct−1C_{t-1}Ct1 中的多少信息(由 ftf_tft 控制,值域为 (0,1)(0,1)(0,1) )可以传递到当前时刻 CtC_{t}Ct 中,1 表示“完全保留”,0 表示“完全舍弃”。

输入门:
it=δ(Wi⋅[ht−1,xt]+bi)i_t=δ(W_i·[h_{t-1},x_t]+b_i)it=δ(Wi[ht1,xt]+bi)
Ct~=tanh(WC⋅[ht−1,xt]+bC)\tilde{C_t}=tanh(W_C·[h_{t-1},x_t]+b_C)Ct~=tanh(WC[ht1,xt]+bC)

输入门用来控制当前输入新生成的信息 CtC_{t}Ct 中有多少信息(由 iti_tit 控制,值域为 (0,1)(0,1)(0,1) )可以加入到细胞状态 CtC_{t}Ct 中。 tanhtanhtanh 层用来产生当前时刻新的信息, δδδ 层用来控制有多少新信息可以传递给细胞状态。

Ct=ft∗Ct−1+it∗Ct~C_t=f_t*C_{t-1}+i_t*\tilde{C_t}Ct=ftCt1+itCt~

基于遗忘门和输入门的输出,来更新细胞状态。更新后的细胞状态有两部分构成:一、来自上一时刻旧的细胞状态信息 Ct−1C_{t-1}Ct1 ;二、当前输入新生成的信息 CtC_tCt
把旧状态Ct−1C_{t-1}Ct1ftf_tft相乘,丢弃掉确定需要丢弃的信息,接着加上it∗C~ti_t * \tilde{C}_titC~t,这就是新的状态。

输出门:
ot=δ(Wo⋅[ht−1,xt]+bo)o_t=δ(W_o·[h_{t-1},x_t]+b_o)ot=δ(Wo[ht1,xt]+bo)
ht=ot∗tanh(Ct)h_t=o_t*tanh(C_t)ht=ottanh(Ct)

最后,基于更新的细胞状态,输出隐藏状态 hth_tht 。这里依然用 δδδ 层 (输出门,oto_tot ) 来控制有多少细胞状态信息(tanh(Ct)tanh(C_t)tanh(Ct),将细胞状态缩放至 (0,1)(0,1)(0,1) ) 可以作为隐藏状态的输出 hth_tht

反向传播

如有不懂之处可前往
https://blog.csdn.net/lrs1353281004/article/details/81188250
博客查看细节。


四、LSTM实践(Python)

训练minist数据集

import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import pandas as pd
import numpy as np
import time
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline# Mnist数据集加载
(x_train_all, y_train_all), (x_test, y_test) = keras.datasets.mnist.load_data()# Mnist数据集简单归一化
x_train_all, x_test = x_train_all / 255.0, x_test / 255.0x_train, x_test = x_train_all[:50000], x_train_all[50000:]
y_train, y_test = y_train_all[:50000], y_train_all[50000:]
print(x_train.shape)# 构建模型
inputs = layers.Input(shape=(x_train.shape[1], x_train.shape[2]), name='inputs')
print(inputs.shape)
lstm = layers.LSTM(units=128, return_sequences=False)(inputs)
print(lstm.shape)
outputs = layers.Dense(10, activation='softmax')(lstm)
print(outputs.shape)
lstm = keras.Model(inputs, outputs)# 查看模型
lstm.compile(optimizer=keras.optimizers.Adam(0.001),loss='sparse_categorical_crossentropy',metrics=['accuracy'])
lstm.summary()# 训练模型
history = lstm.fit(x_train, y_train, batch_size=32, epochs=50, validation_split=0.1)# 绘制准确率图像
data = {}
data['accuracy'] = history.history['accuracy']
data['val_accuracy'] = history.history['val_accuracy']
pd.DataFrame(data).plot(figsize=(8, 5))
plt.grid(True)
plt.axis([0, 30, 0, 1])
plt.show()# 绘制损失图像
data = {}
data['loss'] = history.history['loss']
data['val_loss'] = history.history['val_loss']
pd.DataFrame(data).plot(figsize=(8, 5))
plt.grid(True)
plt.axis([0, 30, 0, 1])
plt.show()



IMDB电影评论数据集进行文本分类

#加载数据
#参数 num_words=10000 保留了训练数据中最常出现的 10000 个单词。为了保持数据规模的可管理性,低频词将被丢弃。
imdb = keras.datasets.imdb(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)print("Training entries: {}, labels: {}".format(len(train_data), len(train_labels)))#首条评论
print(train_data[0])#将整数转换回单词# 一个映射单词到整数索引的词典
word_index = imdb.get_word_index()# 保留第一个索引
word_index = {k:(v+3) for k,v in word_index.items()}
word_index["<PAD>"] = 0
word_index["<START>"] = 1
word_index["<UNK>"] = 2  # unknown
word_index["<UNUSED>"] = 3reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])def decode_review(text):return ' '.join([reverse_word_index.get(i, '?') for i in text])
#显示首条评论的文本
decode_review(train_data[0])#将训练和测试数据处理成相同长度
train_data = keras.preprocessing.sequence.pad_sequences(train_data,value=word_index["<PAD>"],padding='post',maxlen=256)test_data = keras.preprocessing.sequence.pad_sequences(test_data,value=word_index["<PAD>"],padding='post',maxlen=256)
#构建模型# 输入形状是用于电影评论的词汇数目(10,000 词)
vocab_size = 10000model = keras.Sequential()
model.add(keras.layers.Embedding(vocab_size, 16))
model.add(keras.layers.GlobalAveragePooling1D())
model.add(keras.layers.Dense(16, activation='relu'))
model.add(keras.layers.Dense(1, activation='sigmoid'))model.summary()#编译模型
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])
#从原始训练数据中分离 10,000 个样本来创建一个验证集
x_val = train_data[:10000]
partial_x_train = train_data[10000:]y_val = train_labels[:10000]
partial_y_train = train_labels[10000:]#训练模型
history = model.fit(partial_x_train,partial_y_train,epochs=40,batch_size=512,validation_data=(x_val, y_val),verbose=1)
#模型评估
results = model.evaluate(test_data,  test_labels, verbose=2)print(results)def plot_learning_curves(history):pd.DataFrame(history.history).plot(figsize=(8,5))plt.show()plot_learning_curves(history)




五、参考

  • Christopher Olah 的博文:http://colah.github.io/posts/2015-08-Understanding-LSTMs/
  • https://zhuanlan.zhihu.com/p/32085405
  • https://blog.csdn.net/flyinglittlepig/article/details/72229041
  • https://blog.csdn.net/lrs1353281004/article/details/81188250

理解LSTM网络+Python实现相关推荐

  1. 从Tensorflow代码中理解LSTM网络

    目录 RNN LSTM 参考文档与引子 缩略词  RNN (Recurrent neural network) 循环神经网络  LSTM (Long short-term memory) 长短期记忆人 ...

  2. (译)理解 LSTM 网络 (Understanding LSTM Networks by colah)

    前言:其实之前就已经用过 LSTM 了,是在深度学习框架 keras 上直接用的,但是到现在对LSTM详细的网络结构还是不了解,心里牵挂着难受呀!今天看了 tensorflow 文档上面推荐的这篇博文 ...

  3. 从任务到可视化,如何理解LSTM网络中的神经元 By 机器之心2017年7月03日 14:29 对人类而言,转写是一件相对容易并且可解释的任务,所以它比较适合用来解释神经网络做了哪些事情,以及神经网

    从任务到可视化,如何理解LSTM网络中的神经元 By 机器之心2017年7月03日 14:29 对人类而言,转写是一件相对容易并且可解释的任务,所以它比较适合用来解释神经网络做了哪些事情,以及神经网络 ...

  4. TensorFlow2.0(十一)--理解LSTM网络

    理解LSTM网络 前言 1. 循环神经网络 2. 长期依赖问题 3. LSTM网络 4. LSTM背后的核心思想 5. 单步解析LSTM网络结构 5.1 遗忘门结构 5.2 输入门结构 5.3 输出门 ...

  5. 【译】深入理解LSTM网络

    递归神经网络 人类不会每时每刻都开始思考. 当你阅读这篇文章时,你会根据你对之前单词的理解来理解每个单词. 你不要扔掉所有东西,然后再从头开始思考. 你的想法有持久性. 传统的神经网络无法做到这一点, ...

  6. 理解LSTM网络(翻译)

    Translated on December 19, 2015 本文为博客<Understanding LSTM Networks>的翻译文章 原文链接: http://colah.git ...

  7. lstm网络python代码实现

    LSTM的宏观讲解推荐这篇博客,以动图的形式展示特别容易理解https://blog.csdn.net/dQCFKyQDXYm3F8rB0/article/details/82922386 LSTM的 ...

  8. 理解LSTM 网络Understanding LSTM Networks

    Recurrent Neural Networks Humans don't start their thinking from scratch every second. As you read t ...

  9. TensorFlow2.0(十二)--实现简单RNN与LSTM网络

    实现简单RNN与LSTM网络 前言 1. 导入相应的库 2. 加载与构建数据集 2.1 加载数据集 2.2 构建词表 2.3 处理数据 3. 构建简单的RNN模型 3.1 单向RNN模型 3.2 双向 ...

最新文章

  1. unity桌面设置vnc_win7系统通过VNCViewer访问Ubuntu桌面环境的操作方法
  2. css样式 数据展示,教程:使用CSS设置数据样式
  3. ROS2学习(十六).ROS概念 - 构建系统
  4. php中upload函数,PHP中文件的上传和下载常用函数
  5. 如何禁止特定用户使用sqlplus或PL/SQL Developer等工具登陆?
  6. 你的设备中缺少重要的安全和质量修复_2020华富管道非开挖修复工程施工欢迎前来咨询...
  7. 最强面试题整理第二弹:Python 进阶面试题(附答案)
  8. 布谷鸟优化算法 matlab,布谷鸟算法(Cuckoo Search,CS)MATLAB案例详细解析
  9. 【asp】aspUpload
  10. VMware ESXi定制版(OEM ISO)资源下载(包含5.1\5.5\6.0)
  11. 智能合约审计之整形溢出攻击
  12. 怎么做有内容的二维码?二维码在线制作教程
  13. 泰勒级数(Taylor Series)和利用python计算自然常数
  14. 【Python】python 字符串转数组
  15. `spyder总是闪退?spyder打不开?spyder又又又又又出错啦?
  16. 淘宝关键词API接口、1688、京东、拼多多平台商品信息获取采集
  17. 2022年全国最新中级消防设施操作员考试模拟题库及答案
  18. ES源码学习之--Get API的实现逻辑
  19. 清空UIWebView历史网页
  20. 前端面试 | JavaScript知识点 | 课程笔记

热门文章

  1. python+uiautomator2 UI自动化
  2. 闲鱼架构专家,详解亿级C2C电商平台,商品体系架构如何搭建?
  3. C++水电管理信息系统
  4. 如何创建对搜索引擎更加友好的内容
  5. Wamp环境安装redis扩展
  6. 工商局爬虫 商标网爬虫
  7. 三冲IPO,亨达海天能否敲开美股上市大门?
  8. 【华为云技术分享】最终,我决定将代码迁出x86架构!
  9. 支付宝:APP支付接口2.0(alipay.trade.app.pay)
  10. 升级CentOS 7.4内核版本的三种方案