前面学习了关于使用MNIST数据集中的数据进行训练和测试。现在要用自己的手写数字进行识别

使用自己的手写数字进行识别

主要部分如下

def application():testNum = int(input("input the number of test pictures:"))for i in range(testNum):testPic = input("the path of test picture:")testPicArr = pre_pic(testPic)preValue = restore_model(testPicArr)print("prediction num is",preValue)

整体代码如下

1、先是导包

import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import mnist_backward
import mnist_forward

2、模型加载

def restore_model(testPicArr):with tf.Graph().as_default() as tg:x = tf.placeholder(tf.float32,[None,mnist_forward.INPUT_NODE])y = mnist_forward.forward(x,None)preValue = tf.argmax(y,1)variable_averages = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)variables_to_restore = variable_averages.variables_to_restore()saver = tf.train.Saver(variables_to_restore)with tf.Session() as sess:ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)if (ckpt and ckpt.model_checkpoint_path):saver.restore(sess,ckpt.model_checkpoint_path)preValue = sess.run(preValue, feed_dict={x:testPicArr})return preValueelse:print("No file found")return -1

3、将自己的手写数字图片转换为要求的格式

def pre_pic(picName):#打开传入的原始图片img = Image.open(picName)#为符合要求,把图片resize成28*28,用消除锯齿的方法relm = img.resize((28,28),Image.ANTIALIAS)#变成灰度图并转换为矩阵形式im_arr = np.array(relm.convert("L"))threshold = 110#给图片反色,因为要求输入黑底白字,输入的是白底黑字,并进行二值化处理for i in range(28):for j in range(28):im_arr[i][j]= 255-im_arr[i][j]if(im_arr[i][j]<threshold):im_arr[i][j] = 0else:im_arr[i][j] = 255nm_arr = im_arr.reshape([1,784])nm_arr = nm_arr.astype(np.float32)img_ready = np.multiply(nm_arr,1.0/255.0)return img_ready

4、运行部分

def application():testNum = int(input("input the number of test pictures:"))for i in range(testNum):testPic = input("path of picture:")im = plt.imread(testPic)plt.imshow(im)plt.show()testPicArr = pre_pic(testPic)preValue=restore_model(testPicArr)print("the prediction number is:",preValue)def main():application()if __name__ =='__main__':main()

数据集的制作

tfrecords文件

tfrecords:是一种二进制文件,可先将图片和标签制作成该格式的文件。 使用 tfrecords 进行数据读取,会提高内存利用率。

tf.train.Example: 用来存储训练数据。训练数据的特征用键值对的形式表示。

SerializeToString( ):把数据序列化成字符串存储。

生成tfrecords文件

#新建一个writer
writer = tf.python_io.TFRecordWriter(tfRecordName)
for in range(): #循环遍历每张图和标签#在Features中特征会以字典的形式给出example = tf.train.Example(features = tf.trainFeatures(feature={ #img_raw放入二进制图片'img_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value = [img_raw])),#labels放入该图片对应的标签'label':tf.train.Feature(int64_list=train.Int64List(value=labels))}))#把每张图片封装在examples中writer.write(example.SerializeToString())#把example进行序列化
writer.close()

解析tfrecords文件

#先建立文件队列名
filename_queue = tf.train.string_input_producer([tfRecord_path])
reader = tf.TFRecordReader()
#读出的每一个样本保存在serialized_exapmle中
_,serialized_exapmle = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,features={'img_raw':tf.FixedLenFeature([],tf.string)'label':tf.FixedLenFeature([10],tf.int64)})
#恢复img_row到img
img = tf.decode_raw(features['img_raw'],tf.uint8)
img.set_shape([784])
#把每个元素变为浮点数
img = tf.cast(img,tf.float32)*(1./255)
label = tf.cast(features['label'],tf.float32)

mnist.generateds.py

import tensorflow as tf
import numpy as np
from PIL import Image
import osimage_train_path='./mnist_data_jpg/mnist_train_jpg_60000/'
label_train_path = './mnist_data_jpg/mnist_train_jpg_60000.txt'
tfRecord_train='./data/mnist_train.tfrecords'
image_test_path='./mnist_data_jpg/mnist_test_jpg_10000/'
label_test_path='./mnist_data_jpg/mnist_test_jpg_10000.txt'
tfRecord_test='./data/mnist_test.tfrecords'
data_path='./data'
resize_height=28
resize_width = 28def write_tfRecord(tfRecordName,image_path,label_path):writer = tf.python_io.TFRecordWriter(tfRecordName)num_pic = 0f = open(label_path,'r')contents = f.readlines()f.close()for content in contents:value = content.split()img_path = img_path + value[0]img = Image.open(img_path)img_raw = img.tobytes()labels = [0]*10labels[int(value[1])] = 1example = tf.train.Example(features=tf.train.Features({'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))}))writer.write(example.SerializeToString())num_pic +=1print('the number of picture:',num_pic)writer.close()print("write tfrecord successful")def generate_tfRecord():isExists = os.path.exists(data_path)if not isExists:os.makedirs(data_path)print('dir was created successfully')else:print( 'dir already exists')write_tfRecord(tfRecord_train, image_train_path,label_train_path)write_tfRecord(tfRecord_test,image_test_path,label_test_path)def read_tfRecord(tfRecord_path):filename_queue = tf.train.string_input_producer([tfRecord_path])reader = tf.TFReacordReader()_, serialized_example = reader.read(filename_queue)features =tf.parse_single_example(serialized_example,features={'labels':tf.FixedLenFeature([10],tf.int64),'img_raw':tf.FixedLenFeature([],tf.string)})img = tf.decode_raw(features['img_raw'],tf.uint8)img.set_shape([784])img = tf.cast(img,tf.float32)*(1./255)label = tf.cast(features['label'],tf.float32)return img, label
def get_tfrecord(num,isTrain=True):if isTrain:tfRecord_path = tfRecord_trainelse:tfRecord_path = tfRecord_testimg,label = read_tfRecord(tfRecord_path)img_batch,label_batch = tf.train.shuffle_batch([img,label],batch_size = num,num_thread=1000,min_after_dequeue = 700)return img_batch,label_batch
def main():generate_tfRecord()if __name__ == '__main__':main()

在反向传播和测试程序中,修改图片和标签的批获取接口,可以使用多线程提高图片和标签的提取效率

#开启线程协调器
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
#图片和标签的批获取
#关闭
coord.request_stop()
coord.join(threads)

Tensorflow 自己的手写数字实践与数据集制作相关推荐

  1. tensorflow saver_机器学习入门(6):Tensorflow项目Mnist手写数字识别-分析详解

    本文主要内容:Ubuntu下基于Tensorflow的Mnist手写数字识别的实现 训练数据和测试数据资料:http://yann.lecun.com/exdb/mnist/ 前面环境都搭建好了,直接 ...

  2. mnist手写数字识别python_Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】...

    本文实例讲述了Python tensorflow实现mnist手写数字识别.分享给大家供大家参考,具体如下: 非卷积实现 import tensorflow as tf from tensorflow ...

  3. tensorflow网页版手写数字识别-使用flask进行网络部署

    tensorflow网页版手写数字识别-使用flask进行网络部署 tensorflow如何将训练好的模型部署在网页中呢,在python中可以很方便的使用django或者flask框架来进行搭建.这里 ...

  4. DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测

    DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 下边两张 ...

  5. TF之DNN:利用DNN【784→500→10】对MNIST手写数字图片识别数据集(TF自带函数下载)预测(98%)+案例理解DNN过程

    TF之DNN:利用DNN[784→500→10]对MNIST手写数字图片识别数据集(TF自带函数下载)预测(98%)+案例理解DNN过程 目录 输出结果 案例理解DNN过程思路 代码设计 输出结果 案 ...

  6. TF之CNN:利用sklearn(自带手写数字图片识别数据集)使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线

    TF之CNN:利用sklearn(自带手写数字图片识别数据集)使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线 目录 输出结果 设计代码 输出结果 设计代码 ...

  7. DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测

    DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 1.10 ...

  8. DL之DNN:利用DNN【784→50→100→10】算法对MNIST手写数字图片识别数据集进行预测、模型优化

    DL之DNN:利用DNN[784→50→100→10]算法对MNIST手写数字图片识别数据集进行预测.模型优化 导读 目的是建立三层神经网络,进一步理解DNN内部的运作机制 目录 输出结果 设计思路 ...

  9. Dataset之Handwritten Digits:Handwritten Digits(手写数字图片识别)数据集简介、安装、使用方法之详细攻略

    Dataset之Handwritten Digits:Handwritten Digits(手写数字图片识别)数据集简介.安装.使用方法之详细攻略 目录 Handwritten Digits数据集的简 ...

最新文章

  1. 请求拦截_实战SpringCloud通用请求字段拦截处理
  2. 清华大学人工智能研究院成立智能信息获取研究中心
  3. 异常--自定义异常类
  4. jetson nano 系统镜像制作_2.Jetson Nano烧写系统镜像
  5. 申请购买计算机的报告,关于申请购买电脑的请示(最新整理)
  6. 用post方式获取html,httpclient中怎么使用post方法获取html的源码
  7. 高阶函数-参数与返回值
  8. nginx编译包含perl模块
  9. STL之pair及其非成员函数make_pair()
  10. 总结:服务网格(Service Mesh)
  11. 微信好友只有昵称没有微信号_只知道昵称怎么查他的微信号
  12. 《用户体验要素——以用户为中心的产品设计》读书笔记
  13. MacBook Pro维修过程
  14. HTML 入门基础教程
  15. 联想拯救者y7000怎么配置Java环境_联想拯救者Y7000性能配置如何 用起来怎么样...
  16. python爬猫眼电影影评,Python系列爬虫之爬取并简单分析猫眼电影影评
  17. 液晶显示器c语言编程,51驱动1602液晶显示器c程序
  18. python 踩坑之解决django.core.exceptions.ImproperlyConfigured: Error loading MySQLdb module.Did you insta
  19. Diagnosing symbol problems
  20. 【自然语言实战】·第二章(1.1)——获取词语首字字母

热门文章

  1. 云栖深圳峰会开幕 阿里云高调宣布全面进军物联网领域
  2. 过桥问题(Java递归)
  3. 腾讯云副总裁黄俊洪:驭“云原生”之力,驱动产业互联网持续发展
  4. 计算机网络性能指标:速率、带宽、吞吐量、时延、时延带宽积、RTT、利用率
  5. 网络层之IP协议,它带来了哪些功能,真的能顺着网线找到?
  6. 手机免费wifi上网,且看【三招】
  7. Unity性能优化技巧
  8. 深度学习之:对均方误差 mse 的理解
  9. 基于微信小程序的预约挂号系统
  10. 相位谱的matlab程序,基于相位谱视觉注意机制matlab代码