介绍

在TensorFlow的官方入门课程中,多次用到mnist数据集。mnist数据集是一个数字手写体图片库,但它的存储格式并非常见的图片格式,所有的图片都集中保存在四个扩展名为idx*-ubyte.gz的二进制文件。

可以直接从官网进行下载
http://yann.lecun.com/exdb/mnist/


如果我们想要知道大名鼎鼎的mnist手写体数字都长什么样子,就需要从mnist数据集中导出手写体数字图片。了解这些手写体的总体形状,也有助于加深我们对TensorFlow入门课程的理解。

训练数据集

当我们下载了数据集后,需要对数据集进行训练。并保存训练的模型

#!/usr/bin/python3.5
# -*- coding: utf-8 -*-
from tensorflow.examples.tutorials.mnist import input_dataimport tensorflow as tfmnist = input_data.read_data_sets('MNIST_data', one_hot=True)x = tf.placeholder(tf.float32, [None, 784])y_ = tf.placeholder(tf.float32, [None, 10])def weight_variable(shape):initial = tf.truncated_normal(shape, stddev=0.1)return tf.Variable(initial)def bias_variable(shape):initial = tf.constant(0.1, shape=shape)return tf.Variable(initial)def conv2d(x, W):return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')def max_pool_2x2(x):return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])x_image = tf.reshape(x, [-1, 28, 28, 1])h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))saver = tf.train.Saver()with tf.Session() as sess:sess.run(tf.global_variables_initializer())for i in range(20000):batch = mnist.train.next_batch(50)if i % 100 == 0:train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})print('step %d, training accuracy %g' % (i, train_accuracy))train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})saver.save(sess, 'WModel/model.ckpt')print('test accuracy %g' % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

对应的模型文件如图所示

用画图手写数字

通过电脑自带画图工具,手写一个数字,像素为28,如图所示

识别手写数字

把上面生成的图片保存为bmp或png
然后通过程序调用,在使用之前需要先加载前面保存的模型

#!/usr/bin/python3.5
# -*- coding: utf-8 -*-
from PIL import Image, ImageFilter
import tensorflow as tf
import matplotlib.pyplot as plt
import timedef imageprepare():"""This function returns the pixel values.The imput is a png file location."""file_name='result/4.bmp'#导入自己的图片地址#in terminal 'mogrify -format png *.jpg' convert jpg to pngim = Image.open(file_name)# plt.imshow(im)# plt.show()im = im.convert('L')im.save("sample.png")tv = list(im.getdata()) #get pixel values#normalize pixels to 0 and 1. 0 is pure white, 1 is pure black.tva = [ (255-x)*1.0/255.0 for x in tv] #print(tva)return tva"""This function returns the predicted integer.The imput is the pixel values from the imageprepare() function."""# Define the model (same as when creating the model file)result=imageprepare()x = tf.placeholder(tf.float32, [None, 784])y_ = tf.placeholder(tf.float32, [None, 10])def weight_variable(shape):initial = tf.truncated_normal(shape,stddev = 0.1)return tf.Variable(initial)def bias_variable(shape):initial = tf.constant(0.1,shape = shape)return tf.Variable(initial)def conv2d(x,W):return tf.nn.conv2d(x, W, strides = [1,1,1,1], padding = 'SAME')def max_pool_2x2(x):return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])x_image = tf.reshape(x,[-1,28,28,1])h_conv1 = tf.nn.relu(conv2d(x_image,W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))saver = tf.train.Saver()
with tf.Session() as sess:sess.run(tf.global_variables_initializer())saver.restore(sess, "./WModel/model.ckpt")#这里使用了之前保存的模型参数print ("Model restored.")prediction=tf.argmax(y_conv,1)predint=prediction.eval(feed_dict={x: [result],keep_prob: 1.0}, session=sess)print(h_conv2)print('识别结果:')print(predint[0])

识别结果如图所示:

基于MNIST数据集实现手写数字识别相关推荐

  1. matlab朴素贝叶斯手写数字识别_基于MNIST数据集实现手写数字识别

    介绍 在TensorFlow的官方入门课程中,多次用到mnist数据集.mnist数据集是一个数字手写体图片库,但它的存储格式并非常见的图片格式,所有的图片都集中保存在四个扩展名为idx*-ubyte ...

  2. 【机器学习】基于mnist数据集的手写数字识别

    文章目录 第1关:创建训练样本批量生成器 第2关:创建卷积神经网络

  3. 基于TensorFlow和mnist数据集的手写数字识别系统 ,可识别电话号码,识别准确率高,有对比实验,两组模型,可讲解代码

    基于TensorFlow和mnist数据集的手写数字识别系统 ,可识别电话号码,识别准确率高,有对比实验,两组模型,可讲解代码

  4. 【MLP实战】001:基于Minist数据集的手写数字识别

    本文又是一篇基于Minist数据集的手写数字识别. 首先,mnist数据集: 链接:https://pan.baidu.com/s/1z7R7_jnDKZm9F7M6n8hiIw 提取码:rn8z 首 ...

  5. DL之CNN:自定义SimpleConvNet【3层,im2col优化】利用mnist数据集实现手写数字识别多分类训练来评估模型

    DL之CNN:自定义SimpleConvNet[3层,im2col优化]利用mnist数据集实现手写数字识别多分类训练来评估模型 目录 输出结果 设计思路 核心代码 更多输出 输出结果 设计思路 核心 ...

  6. MNIST数据集实现手写数字识别(基于tensorflow)

    ------------先看看别人的博客--------------------- Tensorflow 实现 MNIST 手写数字识别         用这个的代码跑通了 使用Tensorflow和 ...

  7. 神经网络——实现MNIST数据集的手写数字识别

    由于官网下载手写数字的数据集较慢,因此提供便捷下载地址如下 手写数字的数据集MNIST下载:https://download.csdn.net/download/gaoyu1253401563/108 ...

  8. 北京大学曹健——Tensorflow笔记 05 MNIST数据集输出手写数字识别准确率

              # 前向传播:描述了网络结构 minist_forward.py # 反向传播:描述了模型参数的优化方法 mnist_backward.py # 测试输出准确率minist_tes ...

  9. MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试

    文章目录 MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试 一.题目要求 二.完整的目录结构说明 三.Mnist数据集及数据格式转换 四.BP神经网络相关知识 4.1 ...

  10. 基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明)

    基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明) 配置环境 1.前言 2.问题描述 3.解决方案 4.实现步骤 4.1数据集选择 4.2构建网络 4.3训练网络 4.4测试网络 4.5图 ...

最新文章

  1. 成功解决win10下dos中运行tensorboard --logdir=logs和调用events.out.tfevents一闪而过的问题
  2. 从源码编译Chrome(chromium)
  3. Python用MySQLdb, pymssql 模块通过sshtunnel连接远程数据库
  4. 【MySQL】【备份】mydumper安装与使用细节
  5. Spring源代码分析-Persist--JdbcTemplate
  6. ubuntu报警邮件服务简单搭建
  7. 如何在 iPhone、iPad 和 Mac 上压缩照片?
  8. 我们是如何解决偶发性的 502 错误的
  9. 同城滴滴啦啦啦啦啦啦啦啦
  10. vue获取上级路由地址
  11. 如何拥有一个免费云服务器
  12. Java LTS版本——Java 11新特性
  13. 2021全球与中国自动导引车市场现状及未来发展趋势
  14. csbte路点机器人_《cs1.6》awp地图
  15. win7声卡驱动不能安装和系统激活完美解决方案
  16. gif录屏与gif图片合成工具
  17. 依乌《一个土著的下午》新书分享会
  18. python的代码块使用什么控制类、函数以及其他逻辑判断_一篇文章教会你什么是Python模仿强类型...
  19. html控制台源码,可以在浏览器控制台中执行源码字符动画的js插件
  20. 资源分享:230个Proteus仿真原理图和经典案例

热门文章

  1. 应用程序正常初始化(0xc0000034)失败
  2. 文本文件与二进制文件区别 r 与 rb 方式 w 与 wb方式(windows)—————— 开开开山怪
  3. §1 打开百度地图的大门——注册百度地图开发者账户与创建应用
  4. python自动发送微信文件_Python脚本定期发送微信文件,定时
  5. IDEA下载源码报错 Cannot reconnect.
  6. 勇者斗恶龙 java实现
  7. YAML格式与Three dashes(hyphen) ---
  8. 【论文笔记】Deep Learning on Graphs: A Survey
  9. 怎样测试企业级SSD
  10. 7-1 大師と仙人との奇遇 (20 分)