使用 全连接神经网络 训练MNIST数据分类模型
(一) 实验目的
使用简单的全连接层神经网络对MNIST手写数字图片进行分类。通过本次实验,可以掌握如下知识点:
- 学习 TensorFlow2 神经网络模型构建方式;
- 学习
tf.keras.layers.Flatten()
、tf.keras.layers.Dense()
、tf.keras.layers.Dropout()
三种神经网络层; - 学习
relu
和softmax
两种激活函数;(另写) - 学习
adam
优化算法。 (另写)
(二) 实验过程
1. 导入 TensorFlow 模块
import tensorflow as tfprint(tf.__version__)
2.3.0
2. 读取 MNIST 数据集
(1) Tensorflow2 版本已经集成了包括MNIST在内的几种常见数据集,可以通过 tf.keras.datasets
模块进行下载和读取。
(train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()
(2) 查看数据的维度信息。训练集数据是由60000张手写数字图片组成,每张图片为 28x28 矩阵。
print("训练集的图片数据维度:", train_x.shape)
print("训练集的标签数据维度:", train_y.shape)print("测试集的图片数据维度:", test_x.shape)
print("测试集的标签数据维度:", test_y.shape)
训练集的图片数据维度: (60000, 28, 28)
训练集的标签数据维度: (60000,)
测试集的图片数据维度: (10000, 28, 28)
测试集的标签数据维度: (10000,)
(3) 查看前5张图片
for i in range(5):plt.subplot(1,5,i+1)plt.xticks([])plt.yticks([])plt.imshow(train_x[i], cmap=plt.cm.binary)plt.xlabel(train_y[i])
plt.show()
(4) 对图片数据进行归一化处理,取值范围从原 [0-255] 转换为 [0-1]。
归一化的目的是为了减少各维度数据因取值范围差异而带来干扰。例如,有两个维度的特征数据A和B。其中,A的取值范围是(0,10),而B的取值范围是(0,10000)。这时就会发现,在B面前,A的取值范围变化是可以忽略的,从而导致A的信息在噪声中被淹没。
train_x, test_x = train_x / 255.0, test_x / 255.0
3. 构建神经网络模型
(1) 实例化一个网络模型。 tf.keras.Sequential()
方法可以让我们将神经网络层进行线性组合形成神经网络结构。
model = tf.keras.Sequential()
(2) 添加 tf.keras.layers.Flatten()
作为第1层神经网络。
Flatten层用来将输入“压平”,即把多维的输入一维化。还需要注意的一点是:作为第一层,需要设置
input_shape
参数(参数值为元组形式),用以说明喂入模型的训练集数据的形状。
model.add(tf.keras.layers.Flatten(input_shape = (28,28)))
(3) 添加 tf.keras.layers.Dense
作为第2层神经网络。
Dense层,即全连接神经网络层,第一个参数为units,这里设为128。理解为这一层的输出神经元为128个。
知乎上有一篇介绍全连接层原理的文章。
model.add(tf.keras.layers.Dense(128, activation = "relu"))
(4) 添加tf.keras.layers.Dropout()
作为第3层神经网络。
Dropout层的工作机制就是每步训练时,按照一定的概率随机使神经网络的神经元失效,这样可以极大降低连接的复杂度。同时,由于每次训练都是由不同的神经元协同工作的,这样也可以很好地避免数据带来的过拟合,提高了神经网络的泛化性。
在使用Dropout时,需要配置的参数如下:
- rate:配置神经元失效的概率
- noise_shape:配置Dropout的神经元
- seed:生成随机数
model.add(tf.keras.layers.Dropout(0.2))
(5) 添加 tf.keras.layers.Dense
作为第4层神经网络。相比于第2层的全连接神经网络,这里输出的神经元为10,正好对应label标签的数目(即数字0-9)。此外,激活函数也是选择softmax
。
model.add(tf.keras.layers.Dense(10, activation = "softmax"))
(6) 构建好神经网络后,需要调用 model.compile()
方法对模型进行编译。
optimizer 参数用来配置模型的优化器,可以通过名称
tf.keras.optimizers
API调用定义好的优化器。
loss 参数用来配置模型的损失函数,可以通过名称
tf.losses
API调用已经定义好的loss函数。
metrics 参数用来配置模型评价的方法,如 accuracy、mse 等。
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
(7) 查看神经网络模型
通过 summary()
方法,查看构建的神经网络模型信息。如下图所示,我们的模型一共包含4层,依次为Flatten、 Dense、 Dropout和Dense。 整个神经网络模型共有101770个参数。
model.summary()
4. 训练神经网络模型
在神经网络模型编译后,可以使用准备好的训练数据对模型进行训练。Sequential().fit()
方法提供了神经网络模型的训练功能。其中主要的配置参数如下(epoch和batch_size的概念):
- x:配置训练的输入数据,可以是array或者tensor类型。
- y:配置训练的标注数据,可以是array或者tensor类型。
- batch_size:配置批大小,默认值是32。
- epochs:配置训练的epochs的数量。
- verbose:配置训练过程信息输出的级别,共有三个级别,分别是0、1、2。0表示不输出任何训练过程信息;1表示以进度条的方式输出训练过程信息;2表示每个epoch输出一条训练过程信息。
- validation_split:配置验证数据集占训练数据集的比例,取值范围为0-1.
- validation_data:配置验证数据集。如果已经配置validation_split参数,则可以不配置该参数。如果同时配置validation_split和validation_data,那么validation_split参数的配置将会失效。
- shuffle:配置是否随机打乱训练数据集。当配置steps_per_epoch参数为None时,本参数失效。
- initial_epoch:配置进行fine-tune时,新的训练周期是从指定的epoch开始继续训练的。
- steps_per_epoch:配置每个epoch训练的步数。
model.fit(train_x, train_y, epochs=5)
5. 模型评估
利用测试数据集和evaluate()
方法对已经训练好的模型进行评估。
model.evaluate(test_x, test_y, verbose=2)
输出结果如下所示:
313/313 - 0s - loss: 0.0724 - accuracy: 0.9769
[0.07237865030765533, 0.9768999814987183]
损失函数值为0.0724; 正确率为0.9769
6. 保存训练得到的模型
通过save()
或者save_weights()
方法保存并导出训练得到的模型,在使用这两个方法是需要分别配置一下参数。
- save() 方法的参数配置
- filepath:配置模型文件保存的路径。
- overwrite:配置是否覆盖重名的HDF5文件。
- include_optimizer:配置是否保存优化器的参数。
- save_weights()方法的参数配置
- filepath:配置模型文件保存的路径。
- overwrite:配置是否覆盖重名的模型文件。
- save_format:配置保存文件的格式
model.save("mnist_dense")
可以看到在当前文件目录下,出现了一个名为 mnist_dense 的文件夹,里面的东西就是这次训练好的神经网络模型。
7. 神经模型的加载
通过 tf.keras.models.load_model()
方法加载一个已经训练好的模型。需要配置的参数如下:
- filepath:加载模型文件的路径
- custom_objects:配置神经网络模型自定义的对象。如果自定义了神经网络层级,则需要进行配置,否则在加载时会出现无法找到自定义对象的错误。
- compile:配置加载模型之后是否需要进行重新编译。
mnist_load = tf.keras.models.load_model("./mnist_dense")
查看重新加载的模型信息:
mnist_load.summary()
8. 神经网络的预测
(1) 作为例子,对测试数据集中的第一个数据进行预测。首先,提取该数据并转换为模型相对应的格式,即将二维数据压缩成一维数据。
pre = test_x[1].reshape(1,-1)
(2) 利用模型的predict()
方法对数据进行预测。
res = mnist_load.predict(pre)
print(res)
预测结果如下,对应label标签(即数字0-9)的概率值,其中概率最大的就是预测的最有可能的结果。
tf.Tensor(
[[2.0393419e-07 1.3130091e-05 9.9991751e-01 6.6462977e-05 2.1617651e-15
8.9421718e-07 6.9612319e-08 2.9972663e-13 1.5803761e-06 1.0543532e-13]], shape=(1, 10), dtype=float32)
(3) 获得预测结果标签为2,这与真实数据是一致的。
print("预测的数字为:", np.argmax(res))
print("预测正确的概率为:", res.max())print("实际数字为:", test_y[1])
预测的数字为: 2
预测正确的概率为: 0.9999175
实际数字为: 2
使用 全连接神经网络 训练MNIST数据分类模型相关推荐
- 深度学习3—用三层全连接神经网络训练MNIST手写数字字符集
上一篇文章:深度学习2-任意结点数的三层全连接神经网络 距离上篇文章过去了快四个月了,真是时光飞逝,之前因为要考博所以耽误了更新,谁知道考完博后之前落下的接近半个学期的工作是如此之多,以至于弄到现在才 ...
- python神经网络案例——FC全连接神经网络实现mnist手写体识别
全栈工程师开发手册 (作者:栾鹏) python教程全解 FC全连接神经网络的理论教程参考 http://blog.csdn.net/luanpeng825485697/article/details ...
- [转载] python bp神经网络 mnist_Python利用全连接神经网络求解MNIST问题详解
参考链接: Python中的单个神经元神经网络 本文实例讲述了Python利用全连接神经网络求解MNIST问题.分享给大家供大家参考,具体如下: 1.单隐藏层神经网络 人类的神经元在树突接受刺激信息后 ...
- 全连接神经网络实现MNIST手写数字识别
在对全连接神经网络的基本知识(全连接神经网络详解)学习之后,通过MNIST手写数字识别这个小项目来学习如何实现全连接神经网络. MNIST数据集 对于深度学习的任何项目来说,数据集是其中最为关键的部分 ...
- 基于PyTorch框架的多层全连接神经网络实现MNIST手写数字分类
多层全连接神经网络实现MNIST手写数字分类 1 简单的三层全连接神经网络 2 添加激活函数 3 添加批标准化 4 训练网络 5 结论 参考资料 先用PyTorch实现最简单的三层全连接神经网络,然后 ...
- PyTorch基础入门五:PyTorch搭建多层全连接神经网络实现MNIST手写数字识别分类
)全连接神经网络(FC) 全连接神经网络是一种最基本的神经网络结构,英文为Full Connection,所以一般简称FC. FC的准则很简单:神经网络中除输入层之外的每个节点都和上一层的所有节点有连 ...
- 深度学习2---任意结点数的三层全连接神经网络
上一篇文章:深度学习1-最简单的全连接神经网络 我们完成了一个三层(输入+隐含+输出)且每层都具有两个节点的全连接神经网络的原理分析和代码编写.本篇文章将进一步探讨如何把每层固定的两个节点变成任意个节 ...
- 图像识别python cnn_MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(一)...
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 全连接神经网络是深度学习的基础,理解它就可以掌握深度学习的核心概念:前向传播.反向误差传递.权重.学习 ...
- 深度学习初级阶段——全连接神经网络(MLP、FNN)
在前面的数学原理篇中,已经提到了各种深度学习的数学知识基本框架,那么从这篇文章开始,我将和大家一起走进深度学习的大门(部分图片和描述取自其他文章). 目录 一.首先我们需要知道什么是深度学习? 二.了 ...
最新文章
- 收藏:JavaScript
- Sword STL之map效率问题
- 使用ABAP批量下载Markdown源文件里的图片到本地
- Comcast以纯文本泄露客户Wi-Fi登录信息,立即更改密码
- H.264/AVC技术进展及其务实发展策略思考
- 2017将转行进行到底
- 计算机信息技术基础学的是什么内容,计算机信息技术基础练习题及答案(许骏)...
- 动态规划(6)——NYOJ469擅长排列的小明II*
- 2017.8.14 分手是祝愿 失败总结
- GLSL Optimizer
- linux fsck命令,Linux中fsck命令起什么作用呢?
- 电子数字计算机和电子模拟计算机区别,电子数字计算机和电子模拟计算机的区别在哪里?...
- Windows最强ssh客户端推荐 —— Bitvise SSH Client(一)
- 麒麟案例 | 传统企业偶遇“麒麟计划” 相见恨晚 ,却恰逢其时!
- 有没免费云桌面,免费桌面虚拟化软件? 确实有的
- PHP slideup,vue+原生JavaScript实现slideDown与slideUp[简单思路]
- 美赛数模论文之公式写作
- MRCTF 2021 8bit adventure
- 看过这样一个纪录片吗《父亲》,令人深思
- 圣诞节购物软件测试,圣诞节心理测试-鑫海软件
热门文章
- 达芬奇编解码引擎Codec Engine(CE)【转】ceapp.cfg
- python基础 输入圆的半径,求圆的周长
- wifidog php,WifiDog-ng是新一代的WifiDog
- 【论文写作】之LaTeX中插入Visio图文件
- java mediator模式_设计模式之中介者模式(mediator模式)
- 一个摄影爱好者眼中的PRESSon
- xmapp启动Tomcat时报Jdk、Jre未安装错误的解决方法
- 鼠标控制物体移动旋转缩放
- Android _ MVVM 设计模式的一种实现方式,最新BAT大厂面试者整理的Android面试题目
- 【彩彩只能变身队】第二次会议