TFRecord简介,原理分析,代码实现?

在利用深度学习算法搭建完成网络之后,我们要对网络进行训练,要训练网络就要有训练数据,通常我们会直接对硬盘上存放数据进行操作,来fetch到网络中。这样直接从硬盘上读取 数据太慢了,为了加快数据读取,今天我们介绍一种比较好的数据格式 tfrecord,那么什么是tfrecord呢?

什么TFRecord格式的数据?

Tensorflow支持的一种数据格式,内部使用了“Protocol Buffer”二进制数据编码方案,方便我们模型训练,验证,测试数据集的输入。

为什么提出TFRecord格式的数据?

通常情况下,我们使用Tensorflow搭建好网络模型之后,要输入数据进行训练,验证,测试,其对应的文件夹经常为 train,val, test文件夹,这些文件夹内部往往会存着上百万的数据文件,这些文件散列存放在磁盘上,并且读取时候非常慢,繁琐,会有大量的I/O操作。同时,占用大量内存空间)。而TFRecord格式的文件存储形式会很合理的帮我们存储数据,其内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率。

Tensorflow读取数据的机制大致可以分为三种:

  • 直接从硬盘上读取文件数[如下图,来自慕课学习]

上述方式:tensorflow进行模型训练的时候,可以直接从硬盘上去读取数据,并将读出的数据喂给网络,从而完成运算。若数据读取和运算是不同步的【串行的】,那么意味着在完成了运算之后,需要进行IO来对硬盘上的数据进行读取,并将数据放入内存中,此时接着完成后续的运算,由于这个过程中存在IO操作,造成大部分资源处于等待中,造成大量浪费,训练时间比较长。

  • 在内存中开辟读写队列,来读取数据[如下图,来自慕课学习]

上述方式:若先在内存对数据进行缓存,相当于很大的buffer,对于硬盘上的数据,采用读取数据的线程,将硬盘的数据不断的向内存中开辟的buffer中进行搬运,对于计算设备,采用另一个数据读取的线程,每次计算时候,直接从内存中的buffer中读取数据。以此达到同步进行的目的,彼此之间不会发生阻塞,提高了对资源的利用率,也一定程度上加快了 网络的训练。

  • 基于文件队列+内存队列结合的形式读取数据[如下图,来自慕课学习]

上述方式:对于硬盘上存放的数据,首先将硬盘上的数据文件名存放在文件名队列中,内存队列从文件名队列中进行数据的读取,计算设备之间从内存中读取运算所需数据。通过文件队列可以完成对epoch更好的管理,若训练要经过3个epoch,则在文件名队列中会形成A,B,C三个epoch,其中A,B,C分别包含了所有的文件列表,将A,B,C进行合并,最终形成文件名队列,之后内存队列从文件名队列中依次读取。

在Tensorflow中用来构建文件队列的函数主要有如下

  • tf.train.slice_input_producer
import tensorflow as tf# 数据文件的路径
image_files = ["a.jpg","b.jpg","c.jpg","d.jpg"]image_size = len(image_files)
# 数据文件对应的label
image_labels = [1,2,3,4]# 数据读取的轮数
epochs = 2#使用tf.train.slice_input_producer()来构建文件队列,可以很好的num_epoch管理
[image_files,image_labels] = tf.train.slice_input_producer([image_files,image_labels],num_epochs=epochs,shuffle=True)# 使用tf.Session() 完成后段数据读取
with tf.Session() as sess:sess.run(tf.local_variables_initializer())# 启动入队线程,由多个或单个线程,按照设定规则,把文件读入Filename Queue中tf.train.start_queue_runners(sess=sess)for i in range(epochs*image_size):print(sess.run([image_files,image_labels]))

上述代码执行结果 :

  • tf.train.string_input_producer
import tensorflow as tf# 数据文件的路径
image_files = ["cat.0.jpg","cat.1.jpg","cat.2.jpg"]image_size = len(image_files)# 数据读取的轮数
epochs = 2#使用tf.train.string_input_producer()来构建文件队列,可以很好的num_epoch管理
file_queue = tf.train.string_input_producer(image_files,num_epochs=epochs,shuffle=True)# 使用reader读取队列文件中数据
reader = tf.WholeFileReader()# 返回文件名和对应数据
key,value = reader.read(file_queue)# 对读出tfrecord序列化数据进行解码
value = tf.decode_raw(value,tf.uint8)# 使用tf.Session() 完成后段数据读取
with tf.Session() as sess:sess.run(tf.local_variables_initializer())# 启动入队线程,由多个或单个线程,按照设定规则,把文件读入Filename Queue中tf.train.start_queue_runners(sess=sess)for i in range(epochs*image_size):print(sess.run([key,value]))

上述代码执行结果 :

  • tf.data库

上述函数用来完成对文件队列的构造,通过sess.run来完成后端数据读取。而tf.data不使用sess.run,而是基于使用动态图的机制来完成数据的读取。对于slice_input_producer可以通过sess.run直接获取文件队列中的数据。而string_input_producer不可以,需要创建一个文件读写器来进行读取。利用读取器返回文件的名和里面的内容,之后通过sess.run()来进行文件内容的读取。具体操作后续补存。

如何生成TFRecord格式的数据?

 首先数据文件目录如下图:dog和cat文件夹中分别存储对应的狗和猫:

------data

----------train

---------------dog

---------------cat

----------validation

---------------dog

---------------cat

上图为我们此次处理数据目录data为根目录,其下有两个文件夹train和validation,在train和validation下分别有dog和cat两个文件夹,其中存放对应图片数据。具体TFRecord格式数据转换如下代码:

# -*- coding:utf-8 -*-
import os
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile# 定义函数转化变量类型
def _int64_feature(value):return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))def _bytes_feature(value):return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))# 将数据转化为tf.train.Example格式
def _make_example(label, image):image_raw = image.tostring()example = tf.train.Example(features=tf.train.Features(feature={'image/class/label': _int64_feature(label),'image/encoded': _bytes_feature(image_raw}))return example# 读取图片
def read_images(sess,path,flag):# 获取path下所有目录,同时包括path目录sub_dirs = [x[0] for x in os.walk(path)]# 去除path目录is_root_dir = True设置当前label标记为:0current_label = 0print("开始处理训练数据")#开始生成TFRecord格式数据with tf.python_io.TFRecordWriter("./data/dogsVScats_%s_.tfrecord" % flag) as writer:# 读取所有的子目录for sub_dir in sub_dirs:if is_root_dir:is_root_dir = Falsecontinue# 定义图像类型extensions = ['jpg','png']# 存放图像数据file_list = []# 获取文件的名字dir_name = os.path.basename(sub_dir)for extension in extensions:# 文件匹配,类似正则表达式file_glob = os.path.join(path, dir_name, '*.' + extension)#将匹配数据加入列表file_list.extend(glob.glob(file_glob))if not file_list:continueprint("processing:", dir_name)i = 0# 处理图片数据for file_name in file_list:i += 1//读取图像,不过得到的结果是最原始的图像,是没有经过解码的图像,即为是一个字//符串,没法显示,为byte类型image_raw_data = gfile.FastGFile(file_name, 'rb').read()//对上面函数读出的图像进行解码,得到图像的像素值,这个像素值可以用于显示图像image = tf.image.decode_jpeg(image_raw_data)if image.dtype != tf.float32://将image转换为dtype,并根据需要缩放其值.image = tf.image.convert_image_dtype(image, dtype=tf.float32)//在进行tfrecord格式转之前要对所有图像的宽高变的一致image = tf.image.resize(image, [299, 299])image_value = sess.run(image)example = _make_example(current_label, image_value)writer.write(example.SerializeToString())print("正在处理{}中的第{}张图片".format(dir_name,i))current_label += 1print("TFRecord %s 文件已保存" % flag)# 执行产生tfrecord文件
with tf.Session() as sess:read_images(sess,"./data/train","train")read_images(sess,"./data/validation","validation")

将您的数据文件按照上面目录存放,之后会生成dogsVScats_train_.tfrecord和dogsVScats_validation_.tfrecord两个TFRecord格式的数据。

如何读取生成的TFRecord格式的数据进行训练?

# -*- coding:utf-8 -*-
import numpy as np
import tensorflow as tfBATCH_SIZE = 32#tfrecord文件路径
train_tf_path = "./data/dogsVScats_train_.tfrecord"# 使用文件队列来填充数据
filename_queue =
tf.train.string_input_producer([train_tf_path],shuffle=True,num_epochs=None,seed=666)# 创建TFRecoder读写器进行读写打包文件中数据serialized_example,该数据为序列化数据
reader = tf.TFRecordReader()# 返回文件名和序列数据
_,serialized_example = reader.read(filename_queue)feature = {'image/encoded':tf.FixedLenFeature([],tf.string),'image/class/label':tf.FixedLenFeature([],tf.int64)}#将Example原型解析为张量的dict。
features =  tf.parse_single_example(serialized_example,features=feature)image = tf.decode_raw(features['image/encoded'],tf.uint8)
single_image = tf.reshape(image,[229,229,3])single_label = features['image/class/label']# 按顺序文件读取队列中的数据
img_batch,img_labels = tf.train.shuffle_batch([single_image, single_label],batch_size=BATCH_SIZE,capacity=BATCH_SIZE*10,min_after_dequeue=BATCH_SIZE*5)with tf.Session() as sess:sess.run(tf.local_variables_initializer())tf.global_variables_initializer().run()# 启动多线程处理输入数据coord = tf.train.Coordinator()# 启动入队线程,由多个或单个线程,按照设定规则,把文件读入文件队列中。threads = tf.train.start_queue_runners(sess=sess,coord=coord)//读取10个batch_size的数据for i in range(10):batch_data,batch_labels = sess.run([img_batch,img_labels])print(batch_data.shape)print(batch_labels.shape)

上面是使用TFRecord格式数据生成,读取的方式。


另外补存图像解码代码实现,如下代码段

#-*-coding:utf-8-*-
import matplotlib.pyplot as plt
import tensorflow as tf#tf.gfile.FastGFile读出文件数据为序列化数据
image_raw_data_jpg = tf.gfile.FastGFile('1.jpg', 'r').read()
image_raw_data_png = tf.gfile.FastGFile('2.png', 'r').read()with tf.Session() as sess:# 对序列化图像数据进行解码img_data_jpg = tf.image.decode_jpeg(image_raw_data_jpg)  img_data_jpg = tf.image.convert_image_dtype(img_data_jpg, dtype=tf.uint8)  # 改变图像数据的类型img_data_png = tf.image.decode_png(image_raw_data_png)img_data_png = tf.image.convert_image_dtype(img_data_png, dtype=tf.uint8)plt.figure(1)  # 图像显示plt.imshow(img_data_jpg.eval())plt.figure(2)plt.imshow(img_data_png.eval())plt.show()

TFRecord简介,原理分析,代码实现?相关推荐

  1. 【OpenGL】二十一、OpenGL 矩阵压栈与出栈 ( 不同类型矩阵变换先后顺序 | 渲染前不设置单位阵 | 压栈出栈原理分析 | 代码示例 )

    文章目录 一.不同类型矩阵变换先后顺序 二.渲染前不设置单位阵 三.矩阵的压栈和出栈原理分析 四.矩阵的压栈和出栈代码示例 五.相关资源 一.不同类型矩阵变换先后顺序 对 OpenGL 中的 模型视图 ...

  2. SFR 原理分析 代码

    MTF:REF 在表示相机图像解析力时,通常采用MTF50或者MTF50P. MTF50是当MTF数值下降至最大值的50%时,对应的频率(Cycle Per Pixel),它是一个广泛应用的锐利度衡量 ...

  3. (2)咚咚客户端核心设计原理分析 - 代码流程篇 (套接字建立,发送消息流程,接收消息流程)

    京麦Tcp建立连接流程: ConnectTask.run() -> connection.connect()(这里面也注册了一个连接状态的一个回调类) -> tryToConnectOnT ...

  4. 辗转相除法求最大公约数原理分析(附代码实现)

    辗转相除法求最大公约数原理分析(附代码实现) 前言 解释 原理分析 代码 结语 前言 辗转相除法用起来很简单,但是其原理却自己想不明白.于是乎看了几篇有关辗转相除法原理的分析,在这里自己写下自己的理解 ...

  5. 使用EasyExcel导出模板并设置级联下拉及其原理分析

    一.概述 项目中有时会遇到需要导出一个Excel模板,然后在导出的Excel中填充数据,最终再调用接口批量把Excel中的数据导入到数据库当中的需求. 其中级联下拉选择,手机号校验,性别校验等都是比较 ...

  6. 后门BROOTKIT代码学习和原理分析

    周末闲来无事,想找点东西学习一下,随手翻到了之前看到的一篇关于brootkit的文章,知道它是用Bash写的一个后门程序.刚好最近在做Bash相关的工作,就想着学习一下这方面的知识,稍作整理之后就有了 ...

  7. Adaboost算法原理分析和实例+代码(简明易懂)

    Adaboost算法原理分析和实例+代码(简明易懂) [尊重原创,转载请注明出处] http://blog.csdn.net/guyuealian/article/details/70995333   ...

  8. Adaboost算法原理分析和实例+代码(转载)

    [尊重原创,转载请注明出处] http://blog.csdn.net/guyuealian/article/details/70995333     本人最初了解AdaBoost算法着实是花了几天时 ...

  9. python 相关性分析原理及代码详细介绍

    一.相关性分析简介 相关性分析(correlation analysis)是指对两个或多个具备相关关系的变量进行线性相关分析,从而衡量变量间的相关程度或密切程度.相关性程度即为相关性系数R,R的取值范 ...

最新文章

  1. JavaScript两个变量交换值(不使用临时变量)
  2. Tomcat意外宕机分析
  3. 前端学python有什么用-杭州前端需要学习python
  4. HDU - 1875 畅通工程再续
  5. 概率论与数理统计(一)
  6. FTP基础知识及排障方法
  7. MySQL优化之三:SQL语句优化
  8. 数据库-优化-Limit查询的优化
  9. Android开发之百度地图距离判断(判断定位位置是否在圆的范围内)
  10. 管理系统中计算机应用怎么自学,有自学的没,管理系统中计算机应用上机题。...
  11. mysql增错误_使用MySQL练习增删改查时因为版本问题出现连接错误
  12. MIT6.828 boot.S文件分析
  13. js操作动态表格内元素
  14. MFC中将view内容保存为bmp
  15. EPT DATA数据库能查到1998-2013年工业企业数据库
  16. ATV 开发 三 DRM技术简介
  17. 【项目】08年度科创项目“绘图助手工具箱”项目成果发布
  18. python+django+mysql图片分享平台毕业设计毕设开题报告
  19. mac 卸载java 9
  20. ai怎么做盒子效果图_AI制作包装纸盒贴图教程

热门文章

  1. Cisco 交换机修改密码
  2. matlab预处理光谱数据,一种近红外光谱数据预处理方法与流程
  3. 2016蓝桥杯假期任务之《 黄金连分数》
  4. html浏览器获取域账号密码,Firefox/chrome等FTP域内权限限制不严 可通过XSS遍历网站目录并获取源码...
  5. Freenom注册顶级域名
  6. 让web应用正常显示ios输入法表情
  7. 那个智商200以上、14岁上大学、17岁读博、24岁成教授的天才神童,如今怎样了?...
  8. LVS负载均衡+高可用
  9. TensorFlow LSTM 注意力机制图解
  10. 在VMware 10.0上安装NetBSD 1.0