原文链接: mnist 转换为record 使用tf data 转换 读取 训练

上一篇: js 数组 堆栈 和 buckets 效率 对比

下一篇: tf data 切换数据集 使用并行提高效率

参考,这篇有用text和record两种方式实现读取

https://www.jianshu.com/p/eec32f6c5503

下载官网

http://yann.lecun.com/exdb/mnist/

下载四个文件

record文件下载

链接:https://pan.baidu.com/s/1q07w7JRZUwH2s43ua5dRtg
提取码:at2h

转换

创建record文件

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_dataMNIST_DIR = 'd:/data/mnist/raw'# 载入数据集
mnist = input_data.read_data_sets(MNIST_DIR)print('训练集数据总数', mnist.train.num_examples)  # 55000
print('验证集数据总数', mnist.validation.num_examples)  # 5000
print('测试集数据总数', mnist.test.num_examples)  # 55000
# (55000, 784) (55000,)
print('训练数据集形式', mnist.train.images.shape, mnist.train.labels.shape)
image = mnist.train.images[0]# 0-1之间的数,表示灰度值
print(image.max(), image.min())def save_record(path, data):with tf.python_io.TFRecordWriter(path) as writer:for image, label in zip(data.images, data.labels):image = image.reshape((28 * 28))image = image.tobytes()  # 将图片转化为二进制格式example = tf.train.Example(features=tf.train.Features(feature={"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))}))  # example对象对label和image数据进行封装writer.write(example.SerializeToString())  # 序列化为字符串if __name__ == '__main__':RECORD_PATH = 'd:/data/mnist/record/mnist.record'train_path = 'd:/data/mnist/record/mnist_train.record'test_path = 'd:/data/mnist/record/mnist_test.record'save_record(train_path, mnist.train)save_record(test_path, mnist.validation)

读取

使用tf.data读取record文件,并显示图像和标签

import tensorflow as tf
import matplotlib.pyplot as pltdef parser(record):keys_to_features = {"image": tf.FixedLenFeature((), tf.string),"label": tf.FixedLenFeature((), tf.int64),}parsed = tf.parse_single_example(record, keys_to_features)images = tf.decode_raw(parsed["image"], tf.float32)images = tf.reshape(images, [28, 28, 1])labels = tf.cast(parsed['label'], tf.int64)print("IMAGES", images.shape) # IMAGES (28, 28, 1)print("LABELS", labels.shape) # LABELS ()return images, labelsfilenames = ['d:/data/mnist/record/mnist.record']
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(parser).shuffle(32)
dataset = dataset.repeat(-1).batch(4)iterator = dataset.make_initializable_iterator()
batch = iterator.get_next()with tf.Session() as sess:sess.run(tf.global_variables_initializer())sess.run(iterator.initializer)image_batch, label_batch = sess.run(batch)print(image_batch.shape, label_batch.shape) # (4, 28, 28, 1) (4,)for image, label in zip(image_batch, label_batch):plt.imshow(image.reshape((28, 28)))plt.show()print(label)

训练

训练可以直接使用placehold的形式,不需要使用sess先获取数值,直接在一次run中完成数据读取和训练

import tensorflow as tf
import tensorflow.contrib.slim as slimdef parser(record):keys_to_features = {"image": tf.FixedLenFeature((), tf.string),"label": tf.FixedLenFeature((), tf.int64),}parsed = tf.parse_single_example(record, keys_to_features)images = tf.decode_raw(parsed["image"], tf.float32)images = tf.reshape(images, [28, 28, 1])labels = tf.cast(parsed['label'], tf.int64)# labels = tf.reshape(labels, (None,))print("IMAGES", images.shape)  # IMAGES (28, 28, 1)print("LABELS", labels.shape)  # LABELS ()return images, labelsfilenames = ['d:/data/mnist/record/mnist.record']
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(parser).shuffle(1024)
dataset = dataset.repeat(-1).batch(32)iterator = dataset.make_initializable_iterator()
image, label = iterator.get_next()
print('image, label ', image.shape, label.shape)
# in_x = tf.placeholder(tf.float32, (None, 28, 28, 1))
# in_y = tf.placeholder(tf.int64, (None,))net = slim.conv2d(image, 32, 3, 2)
net = slim.conv2d(net, 32, 3, 2)
net = slim.conv2d(net, 32, 3, 2)
net = slim.flatten(net)
net = slim.fully_connected(net, 10)
print('net ', net.shape)
predict = tf.argmax(net, axis=1)
print('predict ', predict.shape)
accuracy = tf.reduce_mean(tf.cast(tf.equal(predict, label), tf.float32))
print('accuracy ', accuracy.shape)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=net, labels=label))
print('loss ', loss.shape)
train = tf.train.AdamOptimizer().minimize(loss)train_epoch = 100000
show_epoch = 500
with tf.Session() as sess:sess.run(tf.global_variables_initializer())sess.run(iterator.initializer)for i in range(1, 1 + train_epoch):sess.run(train)if not i % show_epoch:loss_val, accuracy_val = sess.run([loss, accuracy])print(i, loss_val, accuracy_val)

使用handle 切换数据集

import tensorflow as tf
import tensorflow.contrib.slim as slimdef parser(record):keys_to_features = {"image": tf.FixedLenFeature((), tf.string),"label": tf.FixedLenFeature((), tf.int64),}parsed = tf.parse_single_example(record, keys_to_features)images = tf.decode_raw(parsed["image"], tf.float32)images = tf.reshape(images, [28, 28, 1])labels = tf.cast(parsed['label'], tf.int64)# labels = tf.reshape(labels, (None,))print("IMAGES", images.shape)  # IMAGES (28, 28, 1)print("LABELS", labels.shape)  # LABELS ()return images, labelstrain_paths = ['d:/data/mnist/record/mnist_train.record']
test_paths = ['d:/data/mnist/record/mnist_test.record']
train_data = tf.data.TFRecordDataset(train_paths)
train_data = train_data.map(parser).shuffle(1024)
train_data = train_data.repeat(-1).batch(32)test_data = tf.data.TFRecordDataset(test_paths)
test_data = test_data.map(parser).shuffle(1024)
test_data = test_data.repeat(-1).batch(32)handle = tf.placeholder(tf.string, [])
# iterator = tf.data.Iterator.from_structure(train_data.output_types,
#                                            train_data.output_shapes)iterator = tf.data.Iterator.from_string_handle(handle,train_data.output_types,train_data.output_shapes
)train_iterator = train_data.make_initializable_iterator()
test_iterator = test_data.make_initializable_iterator()# test_data = tf.data.TFRecordDataset(test_paths)image, label = iterator.get_next()
print('image, label ', image.shape, label.shape)
# in_x = tf.placeholder(tf.float32, (None, 28, 28, 1))
# in_y = tf.placeholder(tf.int64, (None,))net = slim.conv2d(image, 32, 3, 2)
net = slim.conv2d(net, 32, 3, 2)
net = slim.conv2d(net, 32, 3, 2)
net = slim.flatten(net)
net = slim.fully_connected(net, 10)
print('net ', net.shape)
predict = tf.argmax(net, axis=1)
print('predict ', predict.shape)
accuracy = tf.reduce_mean(tf.cast(tf.equal(predict, label), tf.float32))
print('accuracy ', accuracy.shape)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=net, labels=label))
print('loss ', loss.shape)
train = tf.train.AdamOptimizer().minimize(loss)train_epoch = 100000
show_epoch = 500
with tf.Session() as sess:sess.run(tf.global_variables_initializer())sess.run([train_iterator.initializer, test_iterator.initializer])train_handle = sess.run(train_iterator.string_handle())test_handle = sess.run(test_iterator.string_handle())for i in range(1, 1 + train_epoch):sess.run(train, {handle: train_handle})if not i % show_epoch:loss_val, accuracy_val = sess.run([loss, accuracy], {handle: test_handle})print(i, loss_val, accuracy_val)

mnist 转换为record 使用tf data 转换 读取 训练相关推荐

  1. tf data 切换数据集 使用并行提高效率

    原文链接: tf data 切换数据集 使用并行提高效率 上一篇: mnist 转换为record 使用tf data 转换 读取 训练 下一篇: tf 风格迁移 固定内容 固定风格 vgg19 输入 ...

  2. Tensorflow数据预处理之tf.data.TFRecordDataset---TFRecords详解\TFRecords图像预处理

    目录 1.概述 2.预处理数据 2.1.常量定义 2.2.导入库 2.3.从train.txt文件中读取图片-标签对 2.4.预处理图片并保存 2.5.调用main函数 3.读取预处理后的数据 3.1 ...

  3. TF2.0使用tf.data处理数据建模Demo

    目录 背景 数据集 特征处理 模型构建及评估 背景: 很多TF模型的例子都是使用dataframe进行数据处理及读取的,在部署及大任务处理时可能会遇到需要特征额外处理及内存不足等问题,所以想直接使用t ...

  4. tf.data 加载 pandas dataframes

    tf.data 加载 pandas dataframes code # -*- coding: utf-8 -*- """ Created on 2020/11/20 1 ...

  5. TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和 tf.data.Dataset机制

    TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和tf.data.Dataset机制 之前写了一篇博客,关于<Tensorflow生成自己的 ...

  6. Tensorflow读取数据-tf.data.TFRecordDataset

    tensorflow TFRecords文件的生成和读取方法 文章目录 tensorflow TFRecords文件的生成和读取方法 1. TFRecords说明 2.关键API 2.1 tf.io. ...

  7. tensorflow的数据读取 tf.data.DataSet、tf.data.Iterator

    tensorflow的工程有使用python的多进程读取数据,然后给feed给神经网络进行训练. 也有tensorflow中的 tf.data.DataSet的使用.并且由于是tensorflow框架 ...

  8. tf.contrib.data.Dataset 读取数据的原理--buffer

    在用tf.contrib.data.Dataset读取数据集数据时,会遇到一个概念/参数: buffer buffer的含义是缓存 buffer_size的含义是:放入缓存的样本的个数 prefech ...

  9. TensorFlow学习笔记02:使用tf.data读取和保存数据文件

    TensorFlow学习笔记02:使用tf.data读取和保存数据文件 使用`tf.data`读取和写入数据文件 读取和写入csv文件 写入csv文件 读取csv文件 读取和保存TFRecord文件 ...

最新文章

  1. .net应用程序如何批上XP的外衣?
  2. 游戏运行时报0xc000007b错的解决办法
  3. Python: 绝对导入 Absolute Imports
  4. 中国癌症大数据出来了!每年126万例癌症死亡本可避免
  5. 关于framework4.5的相关介绍
  6. windows打开cmd的几种方式
  7. 考试周来临。。蓦然回首
  8. 陈安之超级成功法则(1)
  9. 教你如何关闭Win7视频预览节约资源
  10. 利用Python批量将csv文件转化成xml文件
  11. AE不能导入mov等格式文件
  12. 中国电信3G业务抢先发 3G终端国产占到六席
  13. 《Norwegain Wood》—— The Beatles
  14. FTK1000与FTK2000机型差异对比
  15. Unity3D ugui获取ui控件屏幕坐标
  16. 面试官:你期望薪资多少?你真的会答吗?你的回答是否是面试官想要的呢?
  17. eclipse 下载
  18. AdobePhotoshopCS快捷键
  19. 榆熙教育:影响拼多多DSR的三点因素介绍
  20. 如何实现软件自动重启

热门文章

  1. [附源码]Python计算机毕业设计Django医院门诊管理信息系统
  2. c语言中的比较大小问题
  3. 基于C++实现的股票大数据的统计分析与可视化
  4. java获取内容为空_Java使用POI读取Word文档时如果文档内容为空时出现异常
  5. 属性选择器、结构伪类选择器、伪元素选择器
  6. XMind中让分支显示在同一侧
  7. Datawhale零基础入门NLP赛事 - Task5 基于深度学习的文本分类2
  8. 钉钉自定义机器人无法指定正向代理问题解决
  9. HTML基础学习笔记合集
  10. Linux(CentOS)学习笔记