文章目录

  • 1.前言
  • 2.用Keras搭建RNN循环神经网络
    • 2.1.导入必要模块
    • 2.2.超参数设置
    • 2.3.数据预处理
    • 2.4.搭建模型
    • 2.5.激活模型
    • 2.6.训练+测试

1.前言

这次我们用循环神经网络(RNN, Recurrent Neural Networks)进行分类(classification),采用MNIST数据集,主要用到SimpleRNN层。

2.用Keras搭建RNN循环神经网络

2.1.导入必要模块

import numpy as np
from keras.datasets import mnist    #手写体数据集模块
from keras.utils import np_utils
from keras.models import Sequential   #构建网络必需模块
from keras.layers import SimpleRNN, Activation, Dense    #RNN、激活函数、全连接层模块
from keras.optimizers import Adam   #优化器模块
np.random.seed(42)   #随机数种子

2.2.超参数设置

MNIST里面的图像分辨率是28×28,为了使用RNN,我们将图像理解为序列化数据。 每一行作为一个输入单元,所以输入数据大小INPUT_SIZE = 28; 先是第1行输入,再是第2行,第3行,第4行,…,第28行输入, 这就是一张图片也就是一个序列,所以步长TIME_STEPS = 28。

TIME_STEPS = 28     #可理解为每张图片的行数
INPUT_SIZE = 28     #可理解为每张图片的列数
BATCH_SIZE = 50     #批量大小
BATCH_INDEX = 0
OUTPUT_SIZE = 10    #输出维度大小
CELL_SIZE = 50      #经过RNN后的输出大小
LR = 0.001

2.3.数据预处理

训练数据要进行归一化处理,因为原始数据是8bit灰度图像所以需要除以255。

(X_train, y_train),(X_test, y_test) = mnist.load_data()     #拆分训练集与测试集X_train = X_train.reshape(-1,28,28)/255    #满足输入RNN为(-1,28,28)
X_test = X_test.reshape(-1,28,28)/255
y_train = np_utils.to_categorical(y_train,num_classes=10)    #将类别向量转换为二进制(只有0和1)的矩阵类型表示
y_test = np_utils.to_categorical(y_test,num_classes=10)

2.4.搭建模型

首先添加RNN层,输入为训练数据,输出数据大小由CELL_SIZE定义。

然后添加输出层,激励函数选择softmax

model = Sequential()
model.add(SimpleRNN(batch_input_shape = (None,TIME_STEPS,INPUT_SIZE),output_dim = CELL_SIZE,unroll = True
))model.add(Dense(OUTPUT_SIZE))
model.add(Activation('softmax'))

2.5.激活模型

设置优化方法,loss函数和metrics方法之后就可以开始训练了。 每次训练的时候并不是取所有的数据,只是取BATCH_SIZE个序列,或者称为BATCH_SIZE张图片,这样可以大大降低运算时间,提高训练效率。

adam = Adam(LR)
model.compile(optimizer=adam,loss = 'categorical_crossentropy',metrics=['accuracy'])

2.6.训练+测试

for step in range(10001):X_batch = X_train[BATCH_INDEX:BATCH_INDEX+BATCH_SIZE,:,:]y_batch = y_train[BATCH_INDEX:BATCH_INDEX+BATCH_SIZE,:]cost = model.train_on_batch(X_batch,y_batch)BATCH_INDEX += BATCH_SIZEBATCH_INDEX = 0 if BATCH_INDEX >= X_train.shape[0] else BATCH_INDEXif step % 500 == 0:cost, accuracy = model.evaluate(X_test, y_test, batch_size=y_test.shape[0],verbose=False)print('test cost:',cost,'test accuracy:',accuracy)

Keras——用Keras搭建RNN分类循环神经网络相关推荐

  1. Keras——用Keras搭建RNN回归循环神经网络

    文章目录 1.前言 2.用Keras搭建RNN回归循环神经网络 2.1.导入必要模块 2.2.超参数设置 2.3.构造数据 2.4.搭建模型 2.5.激活模型 2.6.训练+测试 1.前言 这次我们用 ...

  2. DL之RNN:循环神经网络RNN的简介、应用、经典案例之详细攻略

    DL之RNN:循环神经网络RNN的简介.应用.经典案例之详细攻略 目录 循环神经网络RNN的简介 1.RNN的分类 1.RNN的常见算法分类 2.RNN的三种分类

  3. 自然语言菜鸟学习笔记(七):RNN(循环神经网络)及变体(LSTM、GRU)理解与实现(TensorFlow)

    目录 前言 RNN(循环神经网络) 为什么要用循环神经网络(RNN)? 循环神经网络(RNN)可以处理什么类型的任务? 多对一问题 一对多问题 多对多问题 循环神经网络结构 单层网络情况 正向传播 反 ...

  4. 【直观理解】一文搞懂RNN(循环神经网络)基础篇

    推荐阅读时间8min~15min 主要内容简介:神经网络基础.为什么需要RNN.RNN的具体结构.以及RNN应用和一些结论 1神经网络基础 神经网络可以当做是能够拟合任意函数的黑盒子,只要训练数据足够 ...

  5. (pytorch-深度学习)使用pytorch框架nn.RNN实现循环神经网络

    使用pytorch框架nn.RNN实现循环神经网络 首先,读取周杰伦专辑歌词数据集. import time import math import numpy as np import torch f ...

  6. 循环取矩阵的某行_一文搞懂RNN(循环神经网络)基础篇

    神经网络基础 神经网络可以当做是能够拟合任意函数的黑盒子,只要训练数据足够,给定特定的x,就能得到希望的y,结构图如下: 将神经网络模型训练好之后,在输入层给定一个x,通过网络之后就能够在输出层得到特 ...

  7. 理论——RNN(循环神经网络)与LSTM(长短期记忆神经网络)

    这里写目录标题 RNN 背景 结构 应用 梯度消失.爆炸 LSTM 长期依赖问题 LSTM网络 结构 RNN 背景 人类的思考具有连续性,我们常联系过去的经验来理解现在.比如阅读时我们常提及的&quo ...

  8. Keras【Deep Learning With Python】RNN Classifier 循环神经网络

    文章目录 1 前言 2 RNN-循环神经网络 2.1 序列数据 2.2 处理序列数据的神经网络 2.3 应用 3 代码实现 4 代码讲解 5 输出 1 前言 本文分为RNN简单讲解,与Keras快速搭 ...

  9. 马里兰大学calce电池循环测试数据集_Keras-建立RNN(循环神经网络)

    循环神经网络(Recurrent Neural Network, RNN)是一类以序列(sequence)数据为输入,在序列的演进方向进行递归(recursion)且所有节点(循环单元)按链式连接的递 ...

最新文章

  1. linux负载均衡(什么是负载均衡)
  2. django 获取环境变量_Django 安装和配置环境变量
  3. 切换器黑屏_景阳华泰科技高清无缝矩阵切换器高端视频会议运用
  4. php代码expl,php – 参数号无效:参数未定义Explination
  5. 译文 | 与TensorFlow的第一次接触 第六章:并发
  6. python对于字典d d.get(x、y)_给定字典 d ,哪个选项对 d.get(x, y) 的描述是正确的?_学小易找答案...
  7. r语言 append_Python爬取近十年TIOBE编程语言热度数据并可视化可视化
  8. Windows登录密码轻松破解
  9. MySQL 常用分库分表方案,都在这里了!
  10. 不确定度用计算机怎么算,算A类不确定度用计算器该怎样按
  11. 硕士研究生培养方案及课程大纲
  12. 杭电计算机考研复试经验
  13. Mac中Homebrew下载指定版本软件的方法
  14. 数据挖掘 --如何有效地进行数据挖掘和分析
  15. ISP - bayer 是什么?
  16. 在线教育学习平台网校系统v2020 html5响应式在线教育培训类企业使用+安装说明
  17. 《名字竞技场 V3.0》 组队功能开放!
  18. 联发科有没有高端处理器_联发科处理器哪些好 2019联发科处理器排名
  19. 计算机电路计数器pl什么意思,计数器的原理为什么1下来是2.而且频率是一样的.它是怎么进位的.它的电路原理是什么...
  20. 3.2 QuickBI可视化分析工具

热门文章

  1. win7 64位系统配置服务器,Tomcat服务器win764位配置方法
  2. java zip ant 密码_java对 zip文件的压缩和解压(ant解决中文乱码)
  3. 【NetApp】exportfs命令的使用
  4. 寻找不到iframe元素
  5. swf php文本,SWFFont - PHP 5 中文文档
  6. Windows 7,无法访问internet,DNS无响应
  7. 手机自带计算机的功能,手机上的这3个小功能,比电脑方便好用,你知道吗?...
  8. ubuntu关闭自动更新、打开 ubuntu 的 apport 崩溃检测报告功能
  9. bilstmcrf词性标注_深度学习--biLSTM_CRF 命名实体识别
  10. 管理历程篇---学会四心