Tensorflow 自己的手写数字实践与数据集制作
前面学习了关于使用MNIST数据集中的数据进行训练和测试。现在要用自己的手写数字进行识别
使用自己的手写数字进行识别
主要部分如下
def application():testNum = int(input("input the number of test pictures:"))for i in range(testNum):testPic = input("the path of test picture:")testPicArr = pre_pic(testPic)preValue = restore_model(testPicArr)print("prediction num is",preValue)
整体代码如下
1、先是导包
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import mnist_backward
import mnist_forward
2、模型加载
def restore_model(testPicArr):with tf.Graph().as_default() as tg:x = tf.placeholder(tf.float32,[None,mnist_forward.INPUT_NODE])y = mnist_forward.forward(x,None)preValue = tf.argmax(y,1)variable_averages = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)variables_to_restore = variable_averages.variables_to_restore()saver = tf.train.Saver(variables_to_restore)with tf.Session() as sess:ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)if (ckpt and ckpt.model_checkpoint_path):saver.restore(sess,ckpt.model_checkpoint_path)preValue = sess.run(preValue, feed_dict={x:testPicArr})return preValueelse:print("No file found")return -1
3、将自己的手写数字图片转换为要求的格式
def pre_pic(picName):#打开传入的原始图片img = Image.open(picName)#为符合要求,把图片resize成28*28,用消除锯齿的方法relm = img.resize((28,28),Image.ANTIALIAS)#变成灰度图并转换为矩阵形式im_arr = np.array(relm.convert("L"))threshold = 110#给图片反色,因为要求输入黑底白字,输入的是白底黑字,并进行二值化处理for i in range(28):for j in range(28):im_arr[i][j]= 255-im_arr[i][j]if(im_arr[i][j]<threshold):im_arr[i][j] = 0else:im_arr[i][j] = 255nm_arr = im_arr.reshape([1,784])nm_arr = nm_arr.astype(np.float32)img_ready = np.multiply(nm_arr,1.0/255.0)return img_ready
4、运行部分
def application():testNum = int(input("input the number of test pictures:"))for i in range(testNum):testPic = input("path of picture:")im = plt.imread(testPic)plt.imshow(im)plt.show()testPicArr = pre_pic(testPic)preValue=restore_model(testPicArr)print("the prediction number is:",preValue)def main():application()if __name__ =='__main__':main()
数据集的制作
tfrecords文件
tfrecords:是一种二进制文件,可先将图片和标签制作成该格式的文件。 使用 tfrecords 进行数据读取,会提高内存利用率。
tf.train.Example: 用来存储训练数据。训练数据的特征用键值对的形式表示。
SerializeToString( ):把数据序列化成字符串存储。
生成tfrecords文件
#新建一个writer
writer = tf.python_io.TFRecordWriter(tfRecordName)
for in range(): #循环遍历每张图和标签#在Features中特征会以字典的形式给出example = tf.train.Example(features = tf.trainFeatures(feature={ #img_raw放入二进制图片'img_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value = [img_raw])),#labels放入该图片对应的标签'label':tf.train.Feature(int64_list=train.Int64List(value=labels))}))#把每张图片封装在examples中writer.write(example.SerializeToString())#把example进行序列化
writer.close()
解析tfrecords文件
#先建立文件队列名
filename_queue = tf.train.string_input_producer([tfRecord_path])
reader = tf.TFRecordReader()
#读出的每一个样本保存在serialized_exapmle中
_,serialized_exapmle = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,features={'img_raw':tf.FixedLenFeature([],tf.string)'label':tf.FixedLenFeature([10],tf.int64)})
#恢复img_row到img
img = tf.decode_raw(features['img_raw'],tf.uint8)
img.set_shape([784])
#把每个元素变为浮点数
img = tf.cast(img,tf.float32)*(1./255)
label = tf.cast(features['label'],tf.float32)
mnist.generateds.py
import tensorflow as tf
import numpy as np
from PIL import Image
import osimage_train_path='./mnist_data_jpg/mnist_train_jpg_60000/'
label_train_path = './mnist_data_jpg/mnist_train_jpg_60000.txt'
tfRecord_train='./data/mnist_train.tfrecords'
image_test_path='./mnist_data_jpg/mnist_test_jpg_10000/'
label_test_path='./mnist_data_jpg/mnist_test_jpg_10000.txt'
tfRecord_test='./data/mnist_test.tfrecords'
data_path='./data'
resize_height=28
resize_width = 28def write_tfRecord(tfRecordName,image_path,label_path):writer = tf.python_io.TFRecordWriter(tfRecordName)num_pic = 0f = open(label_path,'r')contents = f.readlines()f.close()for content in contents:value = content.split()img_path = img_path + value[0]img = Image.open(img_path)img_raw = img.tobytes()labels = [0]*10labels[int(value[1])] = 1example = tf.train.Example(features=tf.train.Features({'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))}))writer.write(example.SerializeToString())num_pic +=1print('the number of picture:',num_pic)writer.close()print("write tfrecord successful")def generate_tfRecord():isExists = os.path.exists(data_path)if not isExists:os.makedirs(data_path)print('dir was created successfully')else:print( 'dir already exists')write_tfRecord(tfRecord_train, image_train_path,label_train_path)write_tfRecord(tfRecord_test,image_test_path,label_test_path)def read_tfRecord(tfRecord_path):filename_queue = tf.train.string_input_producer([tfRecord_path])reader = tf.TFReacordReader()_, serialized_example = reader.read(filename_queue)features =tf.parse_single_example(serialized_example,features={'labels':tf.FixedLenFeature([10],tf.int64),'img_raw':tf.FixedLenFeature([],tf.string)})img = tf.decode_raw(features['img_raw'],tf.uint8)img.set_shape([784])img = tf.cast(img,tf.float32)*(1./255)label = tf.cast(features['label'],tf.float32)return img, label
def get_tfrecord(num,isTrain=True):if isTrain:tfRecord_path = tfRecord_trainelse:tfRecord_path = tfRecord_testimg,label = read_tfRecord(tfRecord_path)img_batch,label_batch = tf.train.shuffle_batch([img,label],batch_size = num,num_thread=1000,min_after_dequeue = 700)return img_batch,label_batch
def main():generate_tfRecord()if __name__ == '__main__':main()
在反向传播和测试程序中,修改图片和标签的批获取接口,可以使用多线程提高图片和标签的提取效率
#开启线程协调器
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
#图片和标签的批获取
#关闭
coord.request_stop()
coord.join(threads)
Tensorflow 自己的手写数字实践与数据集制作相关推荐
- tensorflow saver_机器学习入门(6):Tensorflow项目Mnist手写数字识别-分析详解
本文主要内容:Ubuntu下基于Tensorflow的Mnist手写数字识别的实现 训练数据和测试数据资料:http://yann.lecun.com/exdb/mnist/ 前面环境都搭建好了,直接 ...
- mnist手写数字识别python_Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】...
本文实例讲述了Python tensorflow实现mnist手写数字识别.分享给大家供大家参考,具体如下: 非卷积实现 import tensorflow as tf from tensorflow ...
- tensorflow网页版手写数字识别-使用flask进行网络部署
tensorflow网页版手写数字识别-使用flask进行网络部署 tensorflow如何将训练好的模型部署在网页中呢,在python中可以很方便的使用django或者flask框架来进行搭建.这里 ...
- DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测
DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 下边两张 ...
- TF之DNN:利用DNN【784→500→10】对MNIST手写数字图片识别数据集(TF自带函数下载)预测(98%)+案例理解DNN过程
TF之DNN:利用DNN[784→500→10]对MNIST手写数字图片识别数据集(TF自带函数下载)预测(98%)+案例理解DNN过程 目录 输出结果 案例理解DNN过程思路 代码设计 输出结果 案 ...
- TF之CNN:利用sklearn(自带手写数字图片识别数据集)使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线
TF之CNN:利用sklearn(自带手写数字图片识别数据集)使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线 目录 输出结果 设计代码 输出结果 设计代码 ...
- DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测
DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 1.10 ...
- DL之DNN:利用DNN【784→50→100→10】算法对MNIST手写数字图片识别数据集进行预测、模型优化
DL之DNN:利用DNN[784→50→100→10]算法对MNIST手写数字图片识别数据集进行预测.模型优化 导读 目的是建立三层神经网络,进一步理解DNN内部的运作机制 目录 输出结果 设计思路 ...
- Dataset之Handwritten Digits:Handwritten Digits(手写数字图片识别)数据集简介、安装、使用方法之详细攻略
Dataset之Handwritten Digits:Handwritten Digits(手写数字图片识别)数据集简介.安装.使用方法之详细攻略 目录 Handwritten Digits数据集的简 ...
最新文章
- 请求拦截_实战SpringCloud通用请求字段拦截处理
- 清华大学人工智能研究院成立智能信息获取研究中心
- 异常--自定义异常类
- jetson nano 系统镜像制作_2.Jetson Nano烧写系统镜像
- 申请购买计算机的报告,关于申请购买电脑的请示(最新整理)
- 用post方式获取html,httpclient中怎么使用post方法获取html的源码
- 高阶函数-参数与返回值
- nginx编译包含perl模块
- STL之pair及其非成员函数make_pair()
- 总结:服务网格(Service Mesh)
- 微信好友只有昵称没有微信号_只知道昵称怎么查他的微信号
- 《用户体验要素——以用户为中心的产品设计》读书笔记
- MacBook Pro维修过程
- HTML 入门基础教程
- 联想拯救者y7000怎么配置Java环境_联想拯救者Y7000性能配置如何 用起来怎么样...
- python爬猫眼电影影评,Python系列爬虫之爬取并简单分析猫眼电影影评
- 液晶显示器c语言编程,51驱动1602液晶显示器c程序
- python 踩坑之解决django.core.exceptions.ImproperlyConfigured: Error loading MySQLdb module.Did you insta
- Diagnosing symbol problems
- 【自然语言实战】·第二章(1.1)——获取词语首字字母