(一) 实验目的

使用简单的全连接层神经网络对MNIST手写数字图片进行分类。通过本次实验,可以掌握如下知识点:

  1. 学习 TensorFlow2 神经网络模型构建方式;
  2. 学习 tf.keras.layers.Flatten()tf.keras.layers.Dense()tf.keras.layers.Dropout() 三种神经网络层;
  3. 学习 relusoftmax 两种激活函数;(另写)
  4. 学习 adam 优化算法。 (另写)

(二) 实验过程

1. 导入 TensorFlow 模块

import tensorflow as tfprint(tf.__version__)

2.3.0

2. 读取 MNIST 数据集

(1) Tensorflow2 版本已经集成了包括MNIST在内的几种常见数据集,可以通过 tf.keras.datasets 模块进行下载和读取。

(train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()

(2) 查看数据的维度信息。训练集数据是由60000张手写数字图片组成,每张图片为 28x28 矩阵。

print("训练集的图片数据维度:", train_x.shape)
print("训练集的标签数据维度:", train_y.shape)print("测试集的图片数据维度:", test_x.shape)
print("测试集的标签数据维度:", test_y.shape)

训练集的图片数据维度: (60000, 28, 28)
训练集的标签数据维度: (60000,)
 
测试集的图片数据维度: (10000, 28, 28)
测试集的标签数据维度: (10000,)

(3) 查看前5张图片

for i in range(5):plt.subplot(1,5,i+1)plt.xticks([])plt.yticks([])plt.imshow(train_x[i], cmap=plt.cm.binary)plt.xlabel(train_y[i])
plt.show()

(4) 对图片数据进行归一化处理,取值范围从原 [0-255] 转换为 [0-1]。

归一化的目的是为了减少各维度数据因取值范围差异而带来干扰。例如,有两个维度的特征数据A和B。其中,A的取值范围是(0,10),而B的取值范围是(0,10000)。这时就会发现,在B面前,A的取值范围变化是可以忽略的,从而导致A的信息在噪声中被淹没。

train_x, test_x = train_x / 255.0, test_x / 255.0

3. 构建神经网络模型

(1) 实例化一个网络模型。 tf.keras.Sequential() 方法可以让我们将神经网络层进行线性组合形成神经网络结构。

model = tf.keras.Sequential()

(2) 添加 tf.keras.layers.Flatten() 作为第1层神经网络。

Flatten层用来将输入“压平”,即把多维的输入一维化。还需要注意的一点是:作为第一层,需要设置input_shape 参数(参数值为元组形式),用以说明喂入模型的训练集数据的形状。

model.add(tf.keras.layers.Flatten(input_shape = (28,28)))

(3) 添加 tf.keras.layers.Dense 作为第2层神经网络。

Dense层,即全连接神经网络层,第一个参数为units,这里设为128。理解为这一层的输出神经元为128个。
 
知乎上有一篇介绍全连接层原理的文章。

model.add(tf.keras.layers.Dense(128, activation = "relu"))

(4) 添加tf.keras.layers.Dropout() 作为第3层神经网络。

Dropout层的工作机制就是每步训练时,按照一定的概率随机使神经网络的神经元失效,这样可以极大降低连接的复杂度。同时,由于每次训练都是由不同的神经元协同工作的,这样也可以很好地避免数据带来的过拟合,提高了神经网络的泛化性。
 
在使用Dropout时,需要配置的参数如下:

  • rate:配置神经元失效的概率
  • noise_shape:配置Dropout的神经元
  • seed:生成随机数
model.add(tf.keras.layers.Dropout(0.2))

(5) 添加 tf.keras.layers.Dense 作为第4层神经网络。相比于第2层的全连接神经网络,这里输出的神经元为10,正好对应label标签的数目(即数字0-9)。此外,激活函数也是选择softmax

model.add(tf.keras.layers.Dense(10, activation = "softmax"))

(6) 构建好神经网络后,需要调用 model.compile() 方法对模型进行编译。

optimizer 参数用来配置模型的优化器,可以通过名称tf.keras.optimizers API调用定义好的优化器。

loss 参数用来配置模型的损失函数,可以通过名称tf.losses API调用已经定义好的loss函数。

metrics 参数用来配置模型评价的方法,如 accuracy、mse 等。

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

(7) 查看神经网络模型

通过 summary() 方法,查看构建的神经网络模型信息。如下图所示,我们的模型一共包含4层,依次为Flatten、 Dense、 Dropout和Dense。 整个神经网络模型共有101770个参数。

model.summary()

4. 训练神经网络模型

在神经网络模型编译后,可以使用准备好的训练数据对模型进行训练。Sequential().fit() 方法提供了神经网络模型的训练功能。其中主要的配置参数如下(epoch和batch_size的概念):

  • x:配置训练的输入数据,可以是array或者tensor类型。
  • y:配置训练的标注数据,可以是array或者tensor类型。
  • batch_size:配置批大小,默认值是32。
  • epochs:配置训练的epochs的数量。
  • verbose:配置训练过程信息输出的级别,共有三个级别,分别是0、1、2。0表示不输出任何训练过程信息;1表示以进度条的方式输出训练过程信息;2表示每个epoch输出一条训练过程信息。
  • validation_split:配置验证数据集占训练数据集的比例,取值范围为0-1.
  • validation_data:配置验证数据集。如果已经配置validation_split参数,则可以不配置该参数。如果同时配置validation_split和validation_data,那么validation_split参数的配置将会失效。
  • shuffle:配置是否随机打乱训练数据集。当配置steps_per_epoch参数为None时,本参数失效。
  • initial_epoch:配置进行fine-tune时,新的训练周期是从指定的epoch开始继续训练的。
  • steps_per_epoch:配置每个epoch训练的步数。
model.fit(train_x, train_y, epochs=5)

5. 模型评估

利用测试数据集和evaluate() 方法对已经训练好的模型进行评估。

model.evaluate(test_x, test_y, verbose=2)

输出结果如下所示:

313/313 - 0s - loss: 0.0724 - accuracy: 0.9769
[0.07237865030765533, 0.9768999814987183]
 
损失函数值为0.0724; 正确率为0.9769

6. 保存训练得到的模型

通过save()或者save_weights()方法保存并导出训练得到的模型,在使用这两个方法是需要分别配置一下参数。

  • save() 方法的参数配置
  1. filepath:配置模型文件保存的路径。
  2. overwrite:配置是否覆盖重名的HDF5文件。
  3. include_optimizer:配置是否保存优化器的参数。
  • save_weights()方法的参数配置
  1. filepath:配置模型文件保存的路径。
  2. overwrite:配置是否覆盖重名的模型文件。
  3. save_format:配置保存文件的格式
model.save("mnist_dense")

可以看到在当前文件目录下,出现了一个名为 mnist_dense 的文件夹,里面的东西就是这次训练好的神经网络模型。

7. 神经模型的加载

通过 tf.keras.models.load_model() 方法加载一个已经训练好的模型。需要配置的参数如下:

  1. filepath:加载模型文件的路径
  2. custom_objects:配置神经网络模型自定义的对象。如果自定义了神经网络层级,则需要进行配置,否则在加载时会出现无法找到自定义对象的错误。
  3. compile:配置加载模型之后是否需要进行重新编译。
mnist_load = tf.keras.models.load_model("./mnist_dense")

查看重新加载的模型信息:

mnist_load.summary()

8. 神经网络的预测

(1) 作为例子,对测试数据集中的第一个数据进行预测。首先,提取该数据并转换为模型相对应的格式,即将二维数据压缩成一维数据。

pre = test_x[1].reshape(1,-1)

(2) 利用模型的predict()方法对数据进行预测。

res = mnist_load.predict(pre)
print(res)

预测结果如下,对应label标签(即数字0-9)的概率值,其中概率最大的就是预测的最有可能的结果。

tf.Tensor(
[[2.0393419e-07 1.3130091e-05 9.9991751e-01 6.6462977e-05 2.1617651e-15
8.9421718e-07 6.9612319e-08 2.9972663e-13 1.5803761e-06 1.0543532e-13]], shape=(1, 10), dtype=float32)

(3) 获得预测结果标签为2,这与真实数据是一致的。

print("预测的数字为:", np.argmax(res))
print("预测正确的概率为:", res.max())print("实际数字为:", test_y[1])

预测的数字为: 2
预测正确的概率为: 0.9999175
 
实际数字为: 2

使用 全连接神经网络 训练MNIST数据分类模型相关推荐

  1. 深度学习3—用三层全连接神经网络训练MNIST手写数字字符集

    上一篇文章:深度学习2-任意结点数的三层全连接神经网络 距离上篇文章过去了快四个月了,真是时光飞逝,之前因为要考博所以耽误了更新,谁知道考完博后之前落下的接近半个学期的工作是如此之多,以至于弄到现在才 ...

  2. python神经网络案例——FC全连接神经网络实现mnist手写体识别

    全栈工程师开发手册 (作者:栾鹏) python教程全解 FC全连接神经网络的理论教程参考 http://blog.csdn.net/luanpeng825485697/article/details ...

  3. [转载] python bp神经网络 mnist_Python利用全连接神经网络求解MNIST问题详解

    参考链接: Python中的单个神经元神经网络 本文实例讲述了Python利用全连接神经网络求解MNIST问题.分享给大家供大家参考,具体如下: 1.单隐藏层神经网络 人类的神经元在树突接受刺激信息后 ...

  4. 全连接神经网络实现MNIST手写数字识别

    在对全连接神经网络的基本知识(全连接神经网络详解)学习之后,通过MNIST手写数字识别这个小项目来学习如何实现全连接神经网络. MNIST数据集 对于深度学习的任何项目来说,数据集是其中最为关键的部分 ...

  5. 基于PyTorch框架的多层全连接神经网络实现MNIST手写数字分类

    多层全连接神经网络实现MNIST手写数字分类 1 简单的三层全连接神经网络 2 添加激活函数 3 添加批标准化 4 训练网络 5 结论 参考资料 先用PyTorch实现最简单的三层全连接神经网络,然后 ...

  6. PyTorch基础入门五:PyTorch搭建多层全连接神经网络实现MNIST手写数字识别分类

    )全连接神经网络(FC) 全连接神经网络是一种最基本的神经网络结构,英文为Full Connection,所以一般简称FC. FC的准则很简单:神经网络中除输入层之外的每个节点都和上一层的所有节点有连 ...

  7. 深度学习2---任意结点数的三层全连接神经网络

    上一篇文章:深度学习1-最简单的全连接神经网络 我们完成了一个三层(输入+隐含+输出)且每层都具有两个节点的全连接神经网络的原理分析和代码编写.本篇文章将进一步探讨如何把每层固定的两个节点变成任意个节 ...

  8. 图像识别python cnn_MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(一)...

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 全连接神经网络是深度学习的基础,理解它就可以掌握深度学习的核心概念:前向传播.反向误差传递.权重.学习 ...

  9. 深度学习初级阶段——全连接神经网络(MLP、FNN)

    在前面的数学原理篇中,已经提到了各种深度学习的数学知识基本框架,那么从这篇文章开始,我将和大家一起走进深度学习的大门(部分图片和描述取自其他文章). 目录 一.首先我们需要知道什么是深度学习? 二.了 ...

最新文章

  1. 收藏:JavaScript
  2. Sword STL之map效率问题
  3. 使用ABAP批量下载Markdown源文件里的图片到本地
  4. Comcast以纯文本泄露客户Wi-Fi登录信息,立即更改密码
  5. H.264/AVC技术进展及其务实发展策略思考
  6. 2017将转行进行到底
  7. 计算机信息技术基础学的是什么内容,计算机信息技术基础练习题及答案(许骏)...
  8. 动态规划(6)——NYOJ469擅长排列的小明II*
  9. 2017.8.14 分手是祝愿 失败总结
  10. GLSL Optimizer
  11. linux fsck命令,Linux中fsck命令起什么作用呢?
  12. 电子数字计算机和电子模拟计算机区别,电子数字计算机和电子模拟计算机的区别在哪里?...
  13. Windows最强ssh客户端推荐 —— Bitvise SSH Client(一)
  14. 麒麟案例 | 传统企业偶遇“麒麟计划” 相见恨晚 ,却恰逢其时!
  15. 有没免费云桌面,免费桌面虚拟化软件? 确实有的
  16. PHP slideup,vue+原生JavaScript实现slideDown与slideUp[简单思路]
  17. 美赛数模论文之公式写作
  18. MRCTF 2021 8bit adventure
  19. 看过这样一个纪录片吗《父亲》,令人深思
  20. 圣诞节购物软件测试,圣诞节心理测试-鑫海软件

热门文章

  1. 达芬奇编解码引擎Codec Engine(CE)【转】ceapp.cfg
  2. python基础 输入圆的半径,求圆的周长
  3. wifidog php,WifiDog-ng是新一代的WifiDog
  4. 【论文写作】之LaTeX中插入Visio图文件
  5. java mediator模式_设计模式之中介者模式(mediator模式)
  6. 一个摄影爱好者眼中的PRESSon
  7. xmapp启动Tomcat时报Jdk、Jre未安装错误的解决方法
  8. 鼠标控制物体移动旋转缩放
  9. Android _ MVVM 设计模式的一种实现方式,最新BAT大厂面试者整理的Android面试题目
  10. 【彩彩只能变身队】第二次会议