实现手写体识别的原理和代码实现

1.原理简单介绍
本次实现的手写体识别,是通过卷积神经网络来实现的。
卷积神经网络的主要结构可以分成

卷积层
池化层
全连接层
卷积层: 使用卷积层,可以保持图像数据的形状不变,输入图像数据时,卷积层会以三维数据的形式接收输入数据,并且同样以三维数据输出至下一层。卷积层的作用类似图像处理之中的滤波处理
池化层: 池化层有三个特点:

  1. 池化层和卷积层不同,没用要学习的参数,只是从目标区域中取最大值(平均值)
  2. 池化层中数据数据和输出数据的通道数不会发生变化,计算是按通道独立进行的
  3. 池化层对微小的位置变化具有鲁棒性

全连接层: 是网络最后在通过多次的卷积,池化操作之后,对提取出来的特征通过全连接层来实现分类输出。

下图是一张卷积神经网络的结构原理图:

首先输入的图像是3232 的图像,通过第一层卷积层,提到了2828的图像,将这张图像经过第一层池化层得到1414的图像,我们把卷积,池化层的组合叫做卷积组,通过两个卷积组,得到55的图像特征,放到后面的全连接层中,进行分类,得到其不同的类别。

代码:
import os
os.environ[“CUDA_VISIBLE_DEVICES”] = “-1”
import tensorflow as tf

#import numpy as np # 习惯加上这句,但这边没有用到
from tensorflow.examples.tutorials.mnist import input_data
#import matplotlib.pyplot as plt
mnist = input_data.read_data_sets(‘MNIST_data/’, one_hot=True)

sess = tf.InteractiveSession()

#1、权重初始化,偏置初始化
#为了创建这个模型,我们需要创建大量的权重和偏置项
#为了不在建立模型的时候反复操作,定义两个函数用于初始化
def weight_variable(shape):
initial = tf.truncated_normal(shape,stddev=0.1)#正太分布的标准差设为0.1
return tf.Variable(initial)
def bias_variable(shape):
initial = tf.constant(0.1,shape=shape)
return tf.Variable(initial)

#2、卷积层和池化层也是接下来要重复使用的,因此也为它们定义创建函数
#tf.nn.conv2d是Tensorflow中的二维卷积函数,参数x是输入,w是卷积的参数
#strides代表卷积模块移动的步长,都是1代表会不遗漏地划过图片的每一个点,padding代表边界的处理方式
#padding = ‘SAME’,表示padding后卷积的图与原图尺寸一致,激活函数relu()
#tf.nn.max_pool是Tensorflow中的最大池化函数,这里使用2 * 2 的最大池化,即将2 * 2 的像素降为1 * 1的像素
#最大池化会保留原像素块中灰度值最高的那一个像素,即保留最显著的特征,因为希望整体缩小图片尺寸
#ksize:池化窗口的大小,取一个四维向量,一般是[1,height,width,1]
#因为我们不想再batch和channel上做池化,一般也是[1,stride,stride,1]
def conv2d(x, w):
return tf.nn.conv2d(x, w, strides=[1,1,1,1],padding=‘SAME’) # 保证输出和输入是同样大小
def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1],padding=‘SAME’)

#3、参数
#这里的x,y_并不是特定的值,它们只是一个占位符,可以在TensorFlow运行某一计算时根据该占位符输入具体的值
#输入图片x是一个2维的浮点数张量,这里分配给它的shape为[None, 784],784是一张展平的MNIST图片的维度
#None 表示其值的大小不定,在这里作为第1个维度值,用以指代batch的大小,means x 的数量不定
#输出类别y_也是一个2维张量,其中每一行为一个10维的one_hot向量,用于代表某一MNIST图片的类别
x = tf.placeholder(tf.float32, [None,784], name=“x-input”)
y_ = tf.placeholder(tf.float32,[None,10]) # 10列

#4、第一层卷积,它由一个卷积接一个max pooling完成
#张量形状[5,5,1,32]代表卷积核尺寸为5 * 5,1个颜色通道,32个通道数目
w_conv1 = weight_variable([5,5,1,32])
b_conv1 = bias_variable([32]) # 每个输出通道都有一个对应的偏置量
#我们把x变成一个4d 向量其第2、第3维对应图片的宽、高,最后一维代表图片的颜色通道数(灰度图的通道数为1,如果是RGB彩色图,则为3)
x_image = tf.reshape(x,[-1,28,28,1])
#因为只有一个颜色通道,故最终尺寸为[-1,28,28,1],前面的-1代表样本数量不固定,最后的1代表颜色通道数量
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1) # 使用conv2d函数进行卷积操作,非线性处理
h_pool1 = max_pool_2x2(h_conv1) # 对卷积的输出结果进行池化操作

#5、第二个和第一个一样,是为了构建一个更深的网络,把几个类似的堆叠起来
#第二层中,每个5 * 5 的卷积核会得到64个特征
w_conv2 = weight_variable([5,5,32,64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)# 输入的是第一层池化的结果
h_pool2 = max_pool_2x2(h_conv2)

#6、密集连接层
#图片尺寸减小到7 * 7,加入一个有1024个神经元的全连接层,
#把池化层输出的张量reshape(此函数可以重新调整矩阵的行、列、维数)成一些向量,加上偏置,然后对其使用Relu激活函数
w_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1,7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

#7、使用dropout,防止过度拟合
#dropout是在神经网络里面使用的方法,以此来防止过拟合
#用一个placeholder来代表一个神经元的输出
#tf.nn.dropout操作除了可以屏蔽神经元的输出外,
#还会自动处理神经元输出值的scale,所以用dropout的时候可以不用考虑scale
keep_prob = tf.placeholder(tf.float32, name=“keep_prob”)# placeholder是占位符
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

#8、输出层,最后添加一个softmax层
w_fc2 = weight_variable([1024,10])
b_fc2 = bias_variable([10])
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, w_fc2) + b_fc2, name=“y-pred”)

#9、训练和评估模型
#损失函数是目标类别和预测类别之间的交叉熵
#参数keep_prob控制dropout比例,然后每100次迭代输出一次日志
cross_entropy = tf.reduce_sum(-tf.reduce_sum(y_ * tf.log(y_conv),reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
#预测结果与真实值的一致性,这里产生的是一个bool型的向量
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
#将bool型转换成float型,然后求平均值,即正确的比例
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#初始化所有变量,在2017年3月2号以后,用 tf.global_variables_initializer()替代tf.initialize_all_variables()
sess.run(tf.initialize_all_variables())

#保存最后一个模型
saver = tf.train.Saver(max_to_keep=1)

for i in range(10):
batch = mnist.train.next_batch(16)
if i % 10 == 0:
train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1],keep_prob: 1.0})
print(“Step %d ,training accuracy %g” % (i, train_accuracy))
train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
print("test accuracy %f " % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

保存模型于文件夹
saver.save(sess,“save/model”)
import tensorflow as tf
import numpy as np
import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageTk
from tkinter import filedialog
import time

def creat_windows():
win = tk.Tk() # 创建窗口
sw = win.winfo_screenwidth()
sh = win.winfo_screenheight()
ww, wh = 400, 450
x, y = (sw-ww)/2, (sh-wh)/2
win.geometry("%dx%d+%d+%d"%(ww, wh, x, y-40)) # 居中放置窗口
win.title(‘手写体识别’) # 窗口命名
bg1_open = Image.open(“timg.jpg”).resize((300, 300))
bg1 = ImageTk.PhotoImage(bg1_open)
canvas = tk.Label(win, image=bg1)
canvas.pack()
var = tk.StringVar() # 创建变量文字
var.set(’’)
tk.Label(win, textvariable=var, bg=’#C1FFC1’, font=(‘宋体’, 21), width=20, height=2).pack()
tk.Button(win, text=‘选择图片’, width=20, height=2, bg=’#FF8C00’, command=lambda:main(var, canvas), font=(‘圆体’, 10)).pack()
win.mainloop()
def main(var, canvas):
file_path = filedialog.askopenfilename()
bg1_open = Image.open(file_path).resize((28, 28))
pic = np.array(bg1_open).reshape(784,)
bg1_resize = bg1_open.resize((300, 300))
bg1 = ImageTk.PhotoImage(bg1_resize)
canvas.configure(image=bg1)
canvas.image = bg1
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
saver = tf.train.import_meta_graph(‘save/model.meta’) # 载入模型结构
saver.restore(sess, ‘save/model’) # 载入模型参数
graph = tf.get_default_graph() # 加载计算图
x = graph.get_tensor_by_name(“x-input:0”) # 从模型中读取占位符变量
keep_prob = graph.get_tensor_by_name(“keep_prob:0”)
y_conv = graph.get_tensor_by_name(“y-pred:0”) # 关键的一句 从模型中读取占位符变量
prediction = tf.argmax(y_conv, 1)
predint = prediction.eval(feed_dict={x: [pic], keep_prob: 1.0}, session=sess) # feed_dict输入数据给placeholder占位符
answer = str(predint[0])
var.set(“预测的结果是:” + answer)

if name_ == “main_”:
creat_windows()

分析Mnist数据集中一些歧义数据:



通过训练出来的网络来识别,发现其中1的正确率还是挺高的,是因为1这样的数据,数据的特征比较单一,所以在识别时候,虽然有形状上的差异,但是因为特征容易识别,所以就能够很好的识别。在0的识别上,出现了两个奇怪的现象,3和4这两个识别,一般看来3是更接近6。然而实验上刚好相反,原因是4中的连接处太过粗重,让机器识别错误。

计算机视觉--Tensorflow对Mnist手写体数据集做手写体识别相关推荐

  1. 计算机视觉(十)——Tensorflow对Mnist手写体数据集做手写体识别

    博文主要内容 分析Mnist手写体数据集 实现手写体识别的原理和代码实现 分析Mnist数据集中一些歧义数据 实验中遇到的一些问题 分析Mnist手写体数据集 MNIST 数据集来自美国国家标准与技术 ...

  2. TensorFlow的MNIST手写数字分类问题

    一.简介MNIST TensorFlow编程学习的入门一般都是基于MNIST手写数字数据集和Cifar(包括cifar-10和cifar-100)数据集,因为它们都比较小,一般的设备即可进行训练和测试 ...

  3. 利用TCN网络实现MNIST手写体数据集的识别

    利用TCN网络实现MNIST手写体数据集的识别 TCN识别MNIST的GitHub网址 https://github.com/locuslab/TCN 论文来源 https://arxiv.org/p ...

  4. CNN卷积神经网络—LeNet原理以及tensorflow实现mnist手写体训练

    CNN卷积神经网络-LeNet原理以及tensorflow实现minst手写体训练 1. LeNet原理 2.tensorflow实现Mnist手写体识别 1.安装tensorflow 2.代码实现手 ...

  5. TensorFlow笔记(3)——利用TensorFlow和MNIST数据集训练一个最简单的手写数字识别模型...

    前言 当我们开始学习编程的时候,第一件事往往是学习打印"Hello World".就好比编程入门有Hello World,机器学习入门有MNIST. MNIST是一个入门级的计算机 ...

  6. 【机器学习】SVM支持向量机在手写体数据集上进行二分类、采⽤ hinge loss 和 cross-entropy loss 的线性分类模型分析和对比、网格搜索

    2022Fall 机器学习 1. 实验要求 考虑两种不同的核函数:i) 线性核函数; ii) ⾼斯核函数 可以直接调⽤现成 SVM 软件包来实现 ⼿动实现采⽤ hinge loss 和 cross-e ...

  7. TensorFlow读取MNIST数据集错误的问题

    TensorFlow读取mnist数据集错误的问题 运行程序出现"URLError"错误的问题 可能是服务器或路径的原因,可以自行下载数据集后,将数据集放到代码所在的文件夹下,并将 ...

  8. 机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

    机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

  9. CASIA手写体数据集HWDB1.0 gnt和dgrl格式解析

    目录 引言 Gnt格式解析代码 dgrl格式解析代码 相关资料 引言 最近用到了CASIA这个手写体数据集,但是HWDB1.0~1.2系列其存储格式为gnt 虽说官网也给了读取方式,但是仍然具有一定门 ...

最新文章

  1. jq的ajax和模块引擎
  2. 广州.NET俱乐部活动通知(11月17日)
  3. 【Kaggle-MNIST之路】两层的神经网络Pytorch(改进版)(二)
  4. ACM入门之【ST表/RMQ】
  5. chrome设置微信ua_Chrome谷歌浏览器模拟微信内置浏览器的方法(电脑上)
  6. linux下载b站的视频+ffmpeg抽取出mp3
  7. 图【数据结构F笔记】
  8. 版本控制:集中式(SVN) vs 分布式(GIT)
  9. 5场直播丨PostgreSQL、openGauss、Oracle、GoldenDB、EsgynDB
  10. 数据库三范式,轻松理解
  11. Linux 中Vim 命令大全
  12. (原创)使用TimeStamp控制并发问题[示例]-.cs脚本
  13. 19.04.13--指针笔记
  14. Java语言程序设计(基础篇)第十版 第一章复习题答案
  15. 项目立项,项目经理需要做什么
  16. android 音效下载地址,V4A+Dolby Atmos安卓全局音效
  17. 外地人排北京新能源指标需要什么条件?需要摇号吗?
  18. 华为机试 素数伴侣 匹配匈牙利算法
  19. 史上最全Java开发手册!!!阿里出版
  20. 清华大学计算机博后,科学网—我在清华做博士后的收获 - 喻海良的博文

热门文章

  1. 计算机组装课堂作业,《电脑组装、使用与维护》公选课大作业
  2. 下列哪个不属于CRF模型对于HMM和MEMM模型的优势( )
  3. 时间的定义,时间接口包括 1PPS+ToD、DCLS、IRIG-B、NTP、PTP、串行口 ASCII 字符串
  4. 实现高效率、精准化的管理方式​-​兼职APP开发
  5. TL-WR841N V8 新人openwrt入门刷机经验 含固件及资料
  6. C++仿照string类,写一个my_string类并实现相关功能
  7. 【MySQL数据库】——多表联查
  8. BEA WebLogic Platform 8.1 Single Sign-On Enablement:概述
  9. 短视频去水印小程序源码+支持图集/自带多平台解析API
  10. Blender进行DEM数据3D制图(一)