本篇博客将会提高神经网络对MNIST数据集预测的准确率。

1、标准化样本数据

对于MNIST数据集来讲,它们都是1字节的像素,不需要将它们的取值缩放到一个相似的范围之内(特征缩放)。

为了使得样本数据的取值接近于零,我们需要进行标准化处理,对标准化可以理解为:“重新调整输入,使其平均值为0,标准差为1”。标准差衡量的是一个变量“分布”的情况。

2、超参数的调优

本次实例中,超参数是历元epochs,隐藏节点的数量n_hidden_nodes,学习率lr。

选择历元数量的常见方法:从一个比较高的数字开始,找到在准确率趋于平稳的历元数量。

隐藏节点越多,训练速度就越慢,但在处理一些粗糙的数据时,会使得网络模型更加灵活。

学习率越小,训练速度就越慢,但这些小的步长可以帮助网络模型接近最小的损失,可以对一些值进行尝试。对于隐藏节点和学习率,这里编写了一个python程序,锁定一个范围,进行调试。在此,将不进行展示。

3、代码

mnist_standardized.py

import numpy as np
import struct
import gzip# 加载图像
def load_images(filename):# 打开并解压文件with gzip.open(filename, 'rb') as f:# 定义变量存储文件里的标题信息,struct.unpack()函数是根据模式字符串从二进制文件中读取数据_ignored, n_images, columns, rows = struct.unpack('>IIII', f.read(16))# 往Numpy的字节数组中读取所有的像素all_pixels = np.frombuffer(f.read(), dtype=np.uint8)# 将像素重塑为一个矩阵,其每一行都是一个图像,并返回return all_pixels.reshape(n_images, columns * rows)# 计算训练样本数据集的平均值和标准差
# 可以使用这个函数来标准化训练集、验证集和测试集
def standardize(training_set, test_set):average = np.average(training_set)standard_deviation = np.std(training_set)training_set_standardize = (training_set - average) / standard_deviation  # 标准化样本数据,将每个输入变量值减去这些变量的平均值,再除以这些变量的标准差test_set_standardize = (test_set - average) / standard_deviationreturn training_set_standardize, test_set_standardizeX_train_raw = load_images("../data/mnist/train-images-idx3-ubyte.gz")
X_test_raw = load_images("../data/mnist/t10k-images-idx3-ubyte.gz")
X_train, X_test_all = standardize(X_train_raw, X_test_raw)
X_validation, X_test = np.split(X_test_all, 2)# 加载标签
def load_labels(filename):with gzip.open(filename, 'rb') as f:# 跳过标题字节f.read(8)# 将所有的标签放入一个列表all_labels = f.read()# 将标签列表重塑为一列的矩阵return np.frombuffer(all_labels, dtype=np.uint8).reshape(-1, 1)def one_hot_encode(Y):n_labels = Y.shape[0]n_classes = 10encoded_Y = np.zeros((n_labels, n_classes))for i in range(n_labels):label = Y[i]encoded_Y[i][label] = 1return encoded_YY_train_unencoded = load_labels("../data/mnist/train-labels-idx1-ubyte.gz")
Y_train = one_hot_encode(Y_train_unencoded)
Y_test_all = load_labels("../data/mnist/t10k-labels-idx1-ubyte.gz")
Y_validation, Y_test = np.split(Y_test_all, 2)

neural_network.py

import numpy as npdef sigmoid(z):return 1 / (1 + np.exp(-z))def softmax(logits):exponentials = np.exp(logits)return exponentials / np.sum(exponentials, axis=1).reshape(-1, 1)# 从S型函数输出计算S型函数的梯度,帮助计算w1与w2
def sigmoid_gradient(sigmoid):return np.multiply(sigmoid, (1 - sigmoid))def loss(Y, y_hat):return -np.sum(Y * np.log(y_hat)) / Y.shape[0]def prepend_bias(X):return np.insert(X, 0, 1, axis=1)# 将训练样本数据集分成若干个批量
def prepare_batches(X_train, Y_train, batch_size):x_batches = []y_batches = []n_examples = X_train.shape[0]for batch in range(0, n_examples, batch_size):batch_end = batch + batch_sizex_batches.append(X_train[batch:batch_end])y_batches.append(Y_train[batch:batch_end])return x_batches, y_batchesdef forward(X, w1, w2):h = sigmoid(np.matmul(prepend_bias(X), w1))y_hat = softmax(np.matmul(prepend_bias(h), w2))return (y_hat, h)# 反向传播算法
def back(X, Y, y_hat, w2, h):w2_gradient = np.matmul(prepend_bias(h).T, (y_hat - Y)) / X.shape[0]w1_gradient = np.matmul(prepend_bias(X).T, np.matmul(y_hat - Y, w2[1:].T)* sigmoid_gradient(h)) / X.shape[0]return w1_gradient, w2_gradientdef classify(X, w1, w2):y_hat, _ = forward(X, w1, w2)labels = np.argmax(y_hat, axis=1)return labels.reshape(-1, 1)# 初始化权重,采用w=正负根号下r分之一,r是权重矩阵的行数
def initialize_weights(n_input_variables, n_hidden_nodes, n_classes):w1_rows = n_input_variables + 1w1 = np.random.randn(w1_rows, n_hidden_nodes) * np.sqrt(1 / w1_rows)  # 从标准正态分布中抽取一个随机数矩阵w2_rows = n_hidden_nodes + 1w2 = np.random.randn(w2_rows, n_classes) * np.sqrt(1 / w2_rows)return w1, w2def report(epoch, batch, X_train, Y_train, X_test, Y_test, w1, w2):y_hat, _ = forward(X_train, w1, w2)training_loss = loss(Y_train, y_hat)classifications = classify(X_test, w1, w2)accuracy = np.average(classifications == Y_test) * 100.0print("%5d-%d, Loss: %.8f, Accuracy: %.2f%%" %(epoch, batch, training_loss, accuracy))def train(X_train, Y_train, X_test, Y_test, n_hidden_nodes, epochs, batch_size, lr):n_input_variables = X_train.shape[1]n_classes = Y_train.shape[1]w1, w2 = initialize_weights(n_input_variables, n_hidden_nodes, n_classes)x_batches, y_batches = prepare_batches(X_train, Y_train, batch_size)# epoch是历元的意思,遍历训练集中的所有小批量样本数据for epoch in range(epochs):# 对单个小批量样本数据进行梯度下降的一步迭代计算for batch in range(len(x_batches)):y_hat, h = forward(x_batches[batch], w1, w2)w1_gradient, w2_gradient = back(x_batches[batch], y_batches[batch], y_hat, w2, h)w1 = w1 - (w1_gradient * lr)w2 = w2 - (w2_gradient * lr)report(epoch, batch, X_train, Y_train, X_test, Y_test, w1, w2)return w1, w2

开始测试

import mnist_standardized as mns
import neural_network as nnnn.train(mns.X_train, mns.Y_train, mns.X_test, mns.Y_test, n_hidden_nodes=100, epochs=10, batch_size=256, lr=1)

训练的结果:

修改一下 neural_network.py 和测试中的代码,提高准确率:

import mnist_standardized as mns
import neural_network as nnnn.train(mns.X_train, mns.Y_train, mns.X_test, mns.Y_test, n_hidden_nodes=1200, epochs=100, batch_size=600, lr=0.8)

参考文献:

Programming Machine Learning: Form Coding to Deep Learning.[M],Paolo Perrotta,2021.6.

MNIST数据集,图像识别(五)相关推荐

  1. 图像识别:利用KNN实现手写数字识别(mnist数据集)

    图像识别:利用KNN实现手写数字识别(mnist数据集) 步骤: 1.数据的加载(trainSize和testSize不要设置的太大) 2.k值的设定(不宜过大) 3.KNN的核心:距离的计算 4.k ...

  2. [Python图像识别] 五十.Keras构建AlexNet和CNN实现自定义数据集分类详解

    该系列文章是讲解Python OpenCV图像处理知识,前期主要讲解图像入门.OpenCV基础用法,中期讲解图像处理的各种算法,包括图像锐化算子.图像增强技术.图像分割等,后期结合深度学习研究图像识别 ...

  3. autoencoder自编码器原理以及在mnist数据集上的实现

    Autoencoder是常见的一种非监督学习的神经网络.它实际由一组相对应的神经网络组成(可以是普通的全连接层,或者是卷积层,亦或者是LSTMRNN等等,取决于项目目的),其目的是将输入数据降维成一个 ...

  4. 一文读懂经典卷积网络模型——LeNet-5模型(附代码详解、MNIST数据集)

    欢迎关注微信公众号[计算机视觉联盟] 获取更多前沿AI.CV资讯 LeNet-5模型是Yann LeCun教授与1998年在论文Gradient-based learning applied to d ...

  5. MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试

    文章目录 MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试 一.题目要求 二.完整的目录结构说明 三.Mnist数据集及数据格式转换 四.BP神经网络相关知识 4.1 ...

  6. 机器学习之sklearn使用下载MNIST数据集进行分类识别

    机器学习之sklearn使用下载MNIST数据集进行分类识别 一.MNIST数据集 1.MNIST数据集简介 2.获取MNIST数据集 二.训练一个二分类器 1.随机梯度下降(SGD)分类器 2.分类 ...

  7. 基于jupyter notebook的python编程-----MNIST数据集的的定义及相关处理学习

    基于jupyter notebook的python编程-----MNIST数据集的相关处理 一.MNIST定义 1.什么是MNIST数据集 2.python如何导入MNIST数据集并操作 3.接下来, ...

  8. mnist数据集在FATE上应用

    mnist数据集在FATE上应用 ** 一.下载mnist数据集 ** 我用阿里云盘分享了「MNIST」,复制这段内容打开「阿里云盘」App 即可获取 链接:https://www.aliyundri ...

  9. 全面理解主成分分析(PCA)和MNIST数据集的Python降维实现

    注:本博文为原创博文,如需转载请注明原创链接!!!   这篇博文主要讲述主成分分析的原理并用该方法来实现MNIST数据集的降维. 一.引言   主成分分析是一种降维和主成分解释的方法.举一个比较容易理 ...

  10. Python神经网络识别手写数字-MNIST数据集

    Python神经网络识别手写数字-MNIST数据集 一.手写数字集-MNIST 二.数据预处理 输入数据处理 输出数据处理 三.神经网络的结构选择 四.训练网络 测试网络 测试正确率的函数 五.完整的 ...

最新文章

  1. 浅析中科红旗的生与死
  2. 软件测试组织与管理思维导图
  3. Hadoop学习之以伪分布模式部署Hadoop及常见问题
  4. android 数组赋值字符串_c语言中的字符数组与字符串
  5. 滴滴Uber合并?光大是不行的
  6. 为Ubuntu Linux安装Docker CE Edge
  7. 获取页面可见区域,屏幕区域的尺寸
  8. 微信小程序中带参数返回上一页的方法总结(三种)
  9. 斐波那契数列(Fibonacci)递归和非递归实现
  10. Android平台的通话计时源码
  11. GPS研究---GPS 数据格式
  12. 根据word标题结构转换为excel的方法
  13. 思科交换机设置端口 trunk 模式报错
  14. Linux设置小红点键盘,debian linux上安装thinkpad小红点驱动/Installing Debian On Thinkpad – Trackpoint...
  15. 超链接一般有两种表现形式_超级链接有哪些常见的表现形式?
  16. 域控可以改计算机用户名,如何修改ActiveDirectory域控制器计算机名称
  17. 间谍用GAN生成“红发美女”!潜入美国政坛,全网广钓政客
  18. 【转】SCI论文写法攻略
  19. mysql数据库服务器的超级用户名是,MYSQL数据库的用户帐号管理基础知识 (2)
  20. springboot集成邮箱配置ssl或tls协议

热门文章

  1. 使用Java生成一个爱心图案
  2. 信息学奥赛一本通-2068:【例2.6】鸡兔同笼
  3. cesium添加环形扩散波纹
  4. cento6.8 升级glibc库 glibc2.14教程
  5. 如何制定能源项目管理计划?
  6. Java鼠标事件绘制简单图形-直线、矩形、椭圆
  7. 常见的网络欺诈风险类型有哪些?
  8. 多张图片合成gif怎么操作?教你一键在线gif动画制作
  9. JavaScript中的空数组和空对象布尔值是true还是false?
  10. 支持向量机 的三种境界