读写tfrecord文件
目录
- 1 写tfrecor方式
- 1.1 变长特征转tfrecord
- 1.2 定长特征转tfrecord
- 2 读tfrecord
- 2.1 变长方式读tfrecord
- 2.2 定长方式读tfrecord
- 3 从hdfs中读取批量tfrecord文件
在训练模型的时候,一般会将数据预处理转换成tfrecord格式,负责I/O操作的CPU和进行数值运行计算的GPU相互之间可以并行工作,保证GPU高的利用率。以下是对特征是定长和变长读写tfrecord方式。
1 写tfrecor方式
一般会将数据按照模型训练所需要的方式对输入x和label标签进行tfrecord格式转换。主要有定长和变长两种方式,根据实际应用和需求决定。若输入的每个example的input 是变长的,比如每个example的输入特征索引个数不是相同的,则可以按照变长的方式转换,否则按照定长的方式转换。
1.1 变长特征转tfrecord
import collections
writer = tf.python_io.TFRecordWriter('data.tfrecord')
def toTF(data):''' data是一个dict,假设其中key有input_x和input_y,对应的value是索引list'''features = collections.OrderedDict()input_x = tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_x"])))features["input_x"] = tf.train.FeatureList(feature=input_x)input_y = tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_y"])))features["input_y"] = tf.train.FeatureList(feature=input_y)sequence_example = tf.train.SequenceExample(feature_lists=tf.train.FeatureLists(feature_list=features))writer.write(sequence_example.SerializeToString())
以下方式实现与上面方式等价:
def toTF_v2(data)sequence_example = tf.train.SequenceExample()input_x = sequence_example.feature_lists.feature_list["input_x"]input_y = sequence_example.feature_lists.feature_list["input_y"]for x in data["input_x"]:input_x.feature.add().int64_list.value.append(x)for y in data["input_y"]:input_y.feature.add().int64_list.value.append(y)writer.write(sequence_example.SerializeToString())
1.2 定长特征转tfrecord
def toTF_fixed(data):features = collections.OrderedDict()features["input_x"]= tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_x")))features["input_y"]= tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_y")))example = tf.train.Example(features=tf.train.Features(feature=features))write.write(example.SerializeToString())
2 读tfrecord
和写trrecord一样,也分定长和变长方式,如果写tfrecord是定长方式,则读tfrecord也需要定长方式。读写方式需要保持一致。
2.1 变长方式读tfrecord
需要定义特征的格式,如果是变长则定义tf.FixedLenSequenceFeature类型特征
import tensorflow as tf
features = {'input_x': tf.FixedLenSequenceFeature([], tf.int64)'input_y': tf.FixedLenSequenceFeature([], tf.int64)}
2.2 定长方式读tfrecord
定长方式用tf.FixedLenFeature类型
seq_length = 10
features = {'input_x': tf.FixedLenFeature([seq_length], tf.int64).'input_y': tf.FixedLenFeature([seq_length], tf.int64}
3 从hdfs中读取批量tfrecord文件
当训练数据量级很大时,一般转tfrecord试用分布式方式处理数据,提高效率。训练模型的时候,可以从远程,例如hdfs上读取批量文件。以下是从hdfs上批量读取tfrecord文件。
def input_fn_builder(file_path, num_cpu_threads, seq_length, num_class, batch_size):'''其中file_path是hdfs上文件的路径,比如data目录下的所有tfrecord文件读的是定长的feature'''features = {'input_x': tf.FixedLenFeature([seq_length], tf.int64),'input_y': tf.FixedLenFeature([seq_length], tf.int64),}def _decode_record(record):# 一个样本解析example = tf.io.parse_single_example(record, features)multi_label_enc = tf.one_hot(indices=example["input_y"], depth=num_class)example["input_y"] = tf.reduce_sum(multi_label_enc, axis=0)return exampledef _decode_batch_record(batch_record):# 一个batch样本解析batch_example = tf.io.parse_example(serialized=batch_record, features=features)multi_label_enc = tf.one_hot(indices=batch_example["input_y"], depth=num_class)batch_example["input_y"] = tf.reduce_sum(multi_label_enc, axis=1)return batch_exampledef input_fn(params):# d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))d = tf.data.Dataset.list_files(file_path)d = d.repeat()d = d.shuffle(buffer_size=100)d = d.appley(tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset,sloppy=True,cycle_length=num_cpu_threads))d = d.apply(tf.contrib.data.map_and_batch(lambda record: _decode_record(record),batch_size = batch_size,num_parallel_batches=num_cpu_threads,drop_remainder=True))return ddef input_fn_v2(params):d = tf.data.Dataset.list_files(file_path)d = d.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=num_cpu_threads, block_length=128).\batch(batch_size).map(_decode_batch_record, num_parallel_calls=tf.data.experimental.AUTOTRUE).prefetch(tf.data.experimental.AUTOTUNE).repeat()return dreturn input_fn#return input_fn_v2
上面提供了两个解析函数,input_fn和input_fn_v2两种方式都可行,配合estimator方式训练,可以使得CPU读取数据与GPU训练数据之间可以并行处理,减少等待时间,提高GPU的利用率,加快训练速度。解析tfrecord文件时,有下面四种方式,根据自己具体的数据格式进行选择:
- 解析单个样本,定长特征:tf.io.parse_single_example()
- 解析单个样本,变长特征:tf.io.parse_single_sequence_example()
- 解析批量样本,定长特征:tf.io.parse_example()
- 解析批量样本,定长特征:tf.io.parse_sequence_example()
读写tfrecord文件相关推荐
- Tensorflow—TFRecord文件生成与读取
Tensorflow-TFRecord文件生成与读取 微信公众号:幼儿园的学霸 个人的学习笔记,关于OpenCV,关于机器学习, -.问题或建议,请公众号留言; 目录 文章目录 Tensorflow- ...
- 利用pandas读写HDF5文件
一.简介 HDF5(Hierarchical Data Formal)是用于存储大规模数值数据的较为理想的存储格式,文件后缀名为h5,存储读取速度非常快,且可在文件内部按照明确的层次存储数据,同一个H ...
- OpenCV读写视频文件解析(二)
OpenCV读写视频文件解析(二) VideoCapture::set 设置视频捕获中的属性. C++: bool VideoCapture::set(int propId, double value ...
- OpenCV读写视频文件解析
OpenCV读写视频文件解析 一.视频读写类 视频处理的是运动图像,而不是静止图像.视频资源可以是一个专用摄像机.网络摄像头.视频文件或图像文件序列. 在OpenCV 中,VideoCapture 类 ...
- dom4j读写xml文件
dom4j读写xml文件 首先我们给出一段示例程序: import java.io.File; import java.io.FileWriter; import java.util.Iterator ...
- 使用WinPcap和libpcap类库读写pcap文件(002)PCAP文件格式
本文基本翻译自https://wiki.wireshark.org/Development/LibpcapFileFormat,主要分析pcap文件的格式. 其中一些字段可能和现在的WinPcap类库 ...
- 使用WinPcap和libpcap类库读写pcap文件(001)开发环境配置
最近的项目要求写一个读写pcap文件的小程序,用来修改pcap中的部分信息,实现pcap的定制. 所以必须学会使用wireshark并能有利用WinPcap库和libpcap库进行开发. 虽然本文记录 ...
- java如何读写json文件
java如何读写json文件 在实际项目开发中,有时会遇到一些全局的配置缓存,最好的做法是配置redis数据库作为数据缓存,而当未有配置redis服务器时,读取静态资源文件(如xml.json等)也是 ...
- python读写压缩文件使用gzip和bz2
python读写压缩文件使用gzip和bz2 #读取压缩文件 # gzip compression import gzip with gzip.open('somefile.gz', 'rt') as ...
最新文章
- ECMAScript6——Set数据结构
- CVPR 2017 全部及部分论文解读集锦
- C++输入/输出文件
- 矩阵论基础知识4——强大的矩阵奇异值分解(SVD)及其应用
- 部署 Node.js 应用以完成服务器端渲染 Server Side Rendering 的性能调优
- 开课吧python全栈靠谱么-杭州Web全栈
- 如何快速学从零开始学习3d建模?
- apache ii评分怎么评_apache ii评分多少分为危重患者
- Windows系统压缩卷时可压缩空间远小于实际剩余空间解决方法
- 数据扒一扒《隐秘的角落》到底怎么火的?
- c语言 求矩阵各行元素之和
- POJ1753 Flip Game
- 安装CentOS7时选择install后直接黑屏的解决办法
- 微信小程序 仿朋友圈
- Ristretto—SqueezeNet示例详解
- http://www.dewen.net.cn/q/15807/java byte 疑问
- OpenCV mat类实现水平投影和垂直投影
- java毕业设计 基于vue的小区停车场停车位短租管理系统ssm源码介绍
- 全连接层介绍以及简单实现
- 渗透测试常用工具总结——DAST、SAST、IAST