mnist 转换为record 使用tf data 转换 读取 训练
原文链接: mnist 转换为record 使用tf data 转换 读取 训练
上一篇: js 数组 堆栈 和 buckets 效率 对比
下一篇: tf data 切换数据集 使用并行提高效率
参考,这篇有用text和record两种方式实现读取
https://www.jianshu.com/p/eec32f6c5503
下载官网
http://yann.lecun.com/exdb/mnist/
下载四个文件
record文件下载
链接:https://pan.baidu.com/s/1q07w7JRZUwH2s43ua5dRtg
提取码:at2h
转换
创建record文件
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_dataMNIST_DIR = 'd:/data/mnist/raw'# 载入数据集
mnist = input_data.read_data_sets(MNIST_DIR)print('训练集数据总数', mnist.train.num_examples) # 55000
print('验证集数据总数', mnist.validation.num_examples) # 5000
print('测试集数据总数', mnist.test.num_examples) # 55000
# (55000, 784) (55000,)
print('训练数据集形式', mnist.train.images.shape, mnist.train.labels.shape)
image = mnist.train.images[0]# 0-1之间的数,表示灰度值
print(image.max(), image.min())def save_record(path, data):with tf.python_io.TFRecordWriter(path) as writer:for image, label in zip(data.images, data.labels):image = image.reshape((28 * 28))image = image.tobytes() # 将图片转化为二进制格式example = tf.train.Example(features=tf.train.Features(feature={"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))})) # example对象对label和image数据进行封装writer.write(example.SerializeToString()) # 序列化为字符串if __name__ == '__main__':RECORD_PATH = 'd:/data/mnist/record/mnist.record'train_path = 'd:/data/mnist/record/mnist_train.record'test_path = 'd:/data/mnist/record/mnist_test.record'save_record(train_path, mnist.train)save_record(test_path, mnist.validation)
读取
使用tf.data读取record文件,并显示图像和标签
import tensorflow as tf
import matplotlib.pyplot as pltdef parser(record):keys_to_features = {"image": tf.FixedLenFeature((), tf.string),"label": tf.FixedLenFeature((), tf.int64),}parsed = tf.parse_single_example(record, keys_to_features)images = tf.decode_raw(parsed["image"], tf.float32)images = tf.reshape(images, [28, 28, 1])labels = tf.cast(parsed['label'], tf.int64)print("IMAGES", images.shape) # IMAGES (28, 28, 1)print("LABELS", labels.shape) # LABELS ()return images, labelsfilenames = ['d:/data/mnist/record/mnist.record']
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(parser).shuffle(32)
dataset = dataset.repeat(-1).batch(4)iterator = dataset.make_initializable_iterator()
batch = iterator.get_next()with tf.Session() as sess:sess.run(tf.global_variables_initializer())sess.run(iterator.initializer)image_batch, label_batch = sess.run(batch)print(image_batch.shape, label_batch.shape) # (4, 28, 28, 1) (4,)for image, label in zip(image_batch, label_batch):plt.imshow(image.reshape((28, 28)))plt.show()print(label)
训练
训练可以直接使用placehold的形式,不需要使用sess先获取数值,直接在一次run中完成数据读取和训练
import tensorflow as tf
import tensorflow.contrib.slim as slimdef parser(record):keys_to_features = {"image": tf.FixedLenFeature((), tf.string),"label": tf.FixedLenFeature((), tf.int64),}parsed = tf.parse_single_example(record, keys_to_features)images = tf.decode_raw(parsed["image"], tf.float32)images = tf.reshape(images, [28, 28, 1])labels = tf.cast(parsed['label'], tf.int64)# labels = tf.reshape(labels, (None,))print("IMAGES", images.shape) # IMAGES (28, 28, 1)print("LABELS", labels.shape) # LABELS ()return images, labelsfilenames = ['d:/data/mnist/record/mnist.record']
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(parser).shuffle(1024)
dataset = dataset.repeat(-1).batch(32)iterator = dataset.make_initializable_iterator()
image, label = iterator.get_next()
print('image, label ', image.shape, label.shape)
# in_x = tf.placeholder(tf.float32, (None, 28, 28, 1))
# in_y = tf.placeholder(tf.int64, (None,))net = slim.conv2d(image, 32, 3, 2)
net = slim.conv2d(net, 32, 3, 2)
net = slim.conv2d(net, 32, 3, 2)
net = slim.flatten(net)
net = slim.fully_connected(net, 10)
print('net ', net.shape)
predict = tf.argmax(net, axis=1)
print('predict ', predict.shape)
accuracy = tf.reduce_mean(tf.cast(tf.equal(predict, label), tf.float32))
print('accuracy ', accuracy.shape)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=net, labels=label))
print('loss ', loss.shape)
train = tf.train.AdamOptimizer().minimize(loss)train_epoch = 100000
show_epoch = 500
with tf.Session() as sess:sess.run(tf.global_variables_initializer())sess.run(iterator.initializer)for i in range(1, 1 + train_epoch):sess.run(train)if not i % show_epoch:loss_val, accuracy_val = sess.run([loss, accuracy])print(i, loss_val, accuracy_val)
使用handle 切换数据集
import tensorflow as tf
import tensorflow.contrib.slim as slimdef parser(record):keys_to_features = {"image": tf.FixedLenFeature((), tf.string),"label": tf.FixedLenFeature((), tf.int64),}parsed = tf.parse_single_example(record, keys_to_features)images = tf.decode_raw(parsed["image"], tf.float32)images = tf.reshape(images, [28, 28, 1])labels = tf.cast(parsed['label'], tf.int64)# labels = tf.reshape(labels, (None,))print("IMAGES", images.shape) # IMAGES (28, 28, 1)print("LABELS", labels.shape) # LABELS ()return images, labelstrain_paths = ['d:/data/mnist/record/mnist_train.record']
test_paths = ['d:/data/mnist/record/mnist_test.record']
train_data = tf.data.TFRecordDataset(train_paths)
train_data = train_data.map(parser).shuffle(1024)
train_data = train_data.repeat(-1).batch(32)test_data = tf.data.TFRecordDataset(test_paths)
test_data = test_data.map(parser).shuffle(1024)
test_data = test_data.repeat(-1).batch(32)handle = tf.placeholder(tf.string, [])
# iterator = tf.data.Iterator.from_structure(train_data.output_types,
# train_data.output_shapes)iterator = tf.data.Iterator.from_string_handle(handle,train_data.output_types,train_data.output_shapes
)train_iterator = train_data.make_initializable_iterator()
test_iterator = test_data.make_initializable_iterator()# test_data = tf.data.TFRecordDataset(test_paths)image, label = iterator.get_next()
print('image, label ', image.shape, label.shape)
# in_x = tf.placeholder(tf.float32, (None, 28, 28, 1))
# in_y = tf.placeholder(tf.int64, (None,))net = slim.conv2d(image, 32, 3, 2)
net = slim.conv2d(net, 32, 3, 2)
net = slim.conv2d(net, 32, 3, 2)
net = slim.flatten(net)
net = slim.fully_connected(net, 10)
print('net ', net.shape)
predict = tf.argmax(net, axis=1)
print('predict ', predict.shape)
accuracy = tf.reduce_mean(tf.cast(tf.equal(predict, label), tf.float32))
print('accuracy ', accuracy.shape)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=net, labels=label))
print('loss ', loss.shape)
train = tf.train.AdamOptimizer().minimize(loss)train_epoch = 100000
show_epoch = 500
with tf.Session() as sess:sess.run(tf.global_variables_initializer())sess.run([train_iterator.initializer, test_iterator.initializer])train_handle = sess.run(train_iterator.string_handle())test_handle = sess.run(test_iterator.string_handle())for i in range(1, 1 + train_epoch):sess.run(train, {handle: train_handle})if not i % show_epoch:loss_val, accuracy_val = sess.run([loss, accuracy], {handle: test_handle})print(i, loss_val, accuracy_val)
mnist 转换为record 使用tf data 转换 读取 训练相关推荐
- tf data 切换数据集 使用并行提高效率
原文链接: tf data 切换数据集 使用并行提高效率 上一篇: mnist 转换为record 使用tf data 转换 读取 训练 下一篇: tf 风格迁移 固定内容 固定风格 vgg19 输入 ...
- Tensorflow数据预处理之tf.data.TFRecordDataset---TFRecords详解\TFRecords图像预处理
目录 1.概述 2.预处理数据 2.1.常量定义 2.2.导入库 2.3.从train.txt文件中读取图片-标签对 2.4.预处理图片并保存 2.5.调用main函数 3.读取预处理后的数据 3.1 ...
- TF2.0使用tf.data处理数据建模Demo
目录 背景 数据集 特征处理 模型构建及评估 背景: 很多TF模型的例子都是使用dataframe进行数据处理及读取的,在部署及大任务处理时可能会遇到需要特征额外处理及内存不足等问题,所以想直接使用t ...
- tf.data 加载 pandas dataframes
tf.data 加载 pandas dataframes code # -*- coding: utf-8 -*- """ Created on 2020/11/20 1 ...
- TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和 tf.data.Dataset机制
TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和tf.data.Dataset机制 之前写了一篇博客,关于<Tensorflow生成自己的 ...
- Tensorflow读取数据-tf.data.TFRecordDataset
tensorflow TFRecords文件的生成和读取方法 文章目录 tensorflow TFRecords文件的生成和读取方法 1. TFRecords说明 2.关键API 2.1 tf.io. ...
- tensorflow的数据读取 tf.data.DataSet、tf.data.Iterator
tensorflow的工程有使用python的多进程读取数据,然后给feed给神经网络进行训练. 也有tensorflow中的 tf.data.DataSet的使用.并且由于是tensorflow框架 ...
- tf.contrib.data.Dataset 读取数据的原理--buffer
在用tf.contrib.data.Dataset读取数据集数据时,会遇到一个概念/参数: buffer buffer的含义是缓存 buffer_size的含义是:放入缓存的样本的个数 prefech ...
- TensorFlow学习笔记02:使用tf.data读取和保存数据文件
TensorFlow学习笔记02:使用tf.data读取和保存数据文件 使用`tf.data`读取和写入数据文件 读取和写入csv文件 写入csv文件 读取csv文件 读取和保存TFRecord文件 ...
最新文章
- .net应用程序如何批上XP的外衣?
- 游戏运行时报0xc000007b错的解决办法
- Python: 绝对导入 Absolute Imports
- 中国癌症大数据出来了!每年126万例癌症死亡本可避免
- 关于framework4.5的相关介绍
- windows打开cmd的几种方式
- 考试周来临。。蓦然回首
- 陈安之超级成功法则(1)
- 教你如何关闭Win7视频预览节约资源
- 利用Python批量将csv文件转化成xml文件
- AE不能导入mov等格式文件
- 中国电信3G业务抢先发 3G终端国产占到六席
- 《Norwegain Wood》—— The Beatles
- FTK1000与FTK2000机型差异对比
- Unity3D ugui获取ui控件屏幕坐标
- 面试官:你期望薪资多少?你真的会答吗?你的回答是否是面试官想要的呢?
- eclipse 下载
- AdobePhotoshopCS快捷键
- 榆熙教育:影响拼多多DSR的三点因素介绍
- 如何实现软件自动重启