目录

  • 1 写tfrecor方式
    • 1.1 变长特征转tfrecord
    • 1.2 定长特征转tfrecord
  • 2 读tfrecord
    • 2.1 变长方式读tfrecord
    • 2.2 定长方式读tfrecord
  • 3 从hdfs中读取批量tfrecord文件

在训练模型的时候,一般会将数据预处理转换成tfrecord格式,负责I/O操作的CPU和进行数值运行计算的GPU相互之间可以并行工作,保证GPU高的利用率。以下是对特征是定长和变长读写tfrecord方式。

1 写tfrecor方式

一般会将数据按照模型训练所需要的方式对输入x和label标签进行tfrecord格式转换。主要有定长和变长两种方式,根据实际应用和需求决定。若输入的每个example的input 是变长的,比如每个example的输入特征索引个数不是相同的,则可以按照变长的方式转换,否则按照定长的方式转换。

1.1 变长特征转tfrecord

import collections
writer = tf.python_io.TFRecordWriter('data.tfrecord')
def toTF(data):''' data是一个dict,假设其中key有input_x和input_y,对应的value是索引list'''features = collections.OrderedDict()input_x = tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_x"])))features["input_x"] = tf.train.FeatureList(feature=input_x)input_y = tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_y"])))features["input_y"] = tf.train.FeatureList(feature=input_y)sequence_example = tf.train.SequenceExample(feature_lists=tf.train.FeatureLists(feature_list=features))writer.write(sequence_example.SerializeToString())

以下方式实现与上面方式等价:

def toTF_v2(data)sequence_example = tf.train.SequenceExample()input_x = sequence_example.feature_lists.feature_list["input_x"]input_y = sequence_example.feature_lists.feature_list["input_y"]for x in data["input_x"]:input_x.feature.add().int64_list.value.append(x)for y in data["input_y"]:input_y.feature.add().int64_list.value.append(y)writer.write(sequence_example.SerializeToString())

1.2 定长特征转tfrecord

def toTF_fixed(data):features = collections.OrderedDict()features["input_x"]= tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_x")))features["input_y"]= tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_y")))example = tf.train.Example(features=tf.train.Features(feature=features))write.write(example.SerializeToString())

2 读tfrecord

和写trrecord一样,也分定长和变长方式,如果写tfrecord是定长方式,则读tfrecord也需要定长方式。读写方式需要保持一致。

2.1 变长方式读tfrecord

需要定义特征的格式,如果是变长则定义tf.FixedLenSequenceFeature类型特征

import tensorflow as tf
features = {'input_x': tf.FixedLenSequenceFeature([], tf.int64)'input_y': tf.FixedLenSequenceFeature([], tf.int64)}

2.2 定长方式读tfrecord

定长方式用tf.FixedLenFeature类型

seq_length = 10
features = {'input_x': tf.FixedLenFeature([seq_length], tf.int64).'input_y': tf.FixedLenFeature([seq_length], tf.int64}

3 从hdfs中读取批量tfrecord文件

当训练数据量级很大时,一般转tfrecord试用分布式方式处理数据,提高效率。训练模型的时候,可以从远程,例如hdfs上读取批量文件。以下是从hdfs上批量读取tfrecord文件。

def input_fn_builder(file_path, num_cpu_threads, seq_length, num_class, batch_size):'''其中file_path是hdfs上文件的路径,比如data目录下的所有tfrecord文件读的是定长的feature'''features = {'input_x': tf.FixedLenFeature([seq_length], tf.int64),'input_y': tf.FixedLenFeature([seq_length], tf.int64),}def _decode_record(record):# 一个样本解析example = tf.io.parse_single_example(record, features)multi_label_enc = tf.one_hot(indices=example["input_y"], depth=num_class)example["input_y"] = tf.reduce_sum(multi_label_enc, axis=0)return exampledef _decode_batch_record(batch_record):# 一个batch样本解析batch_example = tf.io.parse_example(serialized=batch_record, features=features)multi_label_enc = tf.one_hot(indices=batch_example["input_y"], depth=num_class)batch_example["input_y"] = tf.reduce_sum(multi_label_enc, axis=1)return batch_exampledef input_fn(params):# d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))d = tf.data.Dataset.list_files(file_path)d = d.repeat()d = d.shuffle(buffer_size=100)d = d.appley(tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset,sloppy=True,cycle_length=num_cpu_threads))d = d.apply(tf.contrib.data.map_and_batch(lambda record: _decode_record(record),batch_size = batch_size,num_parallel_batches=num_cpu_threads,drop_remainder=True))return ddef input_fn_v2(params):d = tf.data.Dataset.list_files(file_path)d = d.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=num_cpu_threads, block_length=128).\batch(batch_size).map(_decode_batch_record, num_parallel_calls=tf.data.experimental.AUTOTRUE).prefetch(tf.data.experimental.AUTOTUNE).repeat()return dreturn input_fn#return input_fn_v2

上面提供了两个解析函数,input_fn和input_fn_v2两种方式都可行,配合estimator方式训练,可以使得CPU读取数据与GPU训练数据之间可以并行处理,减少等待时间,提高GPU的利用率,加快训练速度。解析tfrecord文件时,有下面四种方式,根据自己具体的数据格式进行选择:

  • 解析单个样本,定长特征:tf.io.parse_single_example()
  • 解析单个样本,变长特征:tf.io.parse_single_sequence_example()
  • 解析批量样本,定长特征:tf.io.parse_example()
  • 解析批量样本,定长特征:tf.io.parse_sequence_example()

读写tfrecord文件相关推荐

  1. Tensorflow—TFRecord文件生成与读取

    Tensorflow-TFRecord文件生成与读取 微信公众号:幼儿园的学霸 个人的学习笔记,关于OpenCV,关于机器学习, -.问题或建议,请公众号留言; 目录 文章目录 Tensorflow- ...

  2. 利用pandas读写HDF5文件

    一.简介 HDF5(Hierarchical Data Formal)是用于存储大规模数值数据的较为理想的存储格式,文件后缀名为h5,存储读取速度非常快,且可在文件内部按照明确的层次存储数据,同一个H ...

  3. OpenCV读写视频文件解析(二)

    OpenCV读写视频文件解析(二) VideoCapture::set 设置视频捕获中的属性. C++: bool VideoCapture::set(int propId, double value ...

  4. OpenCV读写视频文件解析

    OpenCV读写视频文件解析 一.视频读写类 视频处理的是运动图像,而不是静止图像.视频资源可以是一个专用摄像机.网络摄像头.视频文件或图像文件序列. 在OpenCV 中,VideoCapture 类 ...

  5. dom4j读写xml文件

    dom4j读写xml文件 首先我们给出一段示例程序: import java.io.File; import java.io.FileWriter; import java.util.Iterator ...

  6. 使用WinPcap和libpcap类库读写pcap文件(002)PCAP文件格式

    本文基本翻译自https://wiki.wireshark.org/Development/LibpcapFileFormat,主要分析pcap文件的格式. 其中一些字段可能和现在的WinPcap类库 ...

  7. 使用WinPcap和libpcap类库读写pcap文件(001)开发环境配置

    最近的项目要求写一个读写pcap文件的小程序,用来修改pcap中的部分信息,实现pcap的定制. 所以必须学会使用wireshark并能有利用WinPcap库和libpcap库进行开发. 虽然本文记录 ...

  8. java如何读写json文件

    java如何读写json文件 在实际项目开发中,有时会遇到一些全局的配置缓存,最好的做法是配置redis数据库作为数据缓存,而当未有配置redis服务器时,读取静态资源文件(如xml.json等)也是 ...

  9. python读写压缩文件使用gzip和bz2

    python读写压缩文件使用gzip和bz2 #读取压缩文件 # gzip compression import gzip with gzip.open('somefile.gz', 'rt') as ...

最新文章

  1. ECMAScript6——Set数据结构
  2. CVPR 2017 全部及部分论文解读集锦
  3. C++输入/输出文件
  4. 矩阵论基础知识4——强大的矩阵奇异值分解(SVD)及其应用
  5. 部署 Node.js 应用以完成服务器端渲染 Server Side Rendering 的性能调优
  6. 开课吧python全栈靠谱么-杭州Web全栈
  7. 如何快速学从零开始学习3d建模?
  8. apache ii评分怎么评_apache ii评分多少分为危重患者
  9. Windows系统压缩卷时可压缩空间远小于实际剩余空间解决方法
  10. 数据扒一扒《隐秘的角落》到底怎么火的?
  11. c语言 求矩阵各行元素之和
  12. POJ1753 Flip Game
  13. 安装CentOS7时选择install后直接黑屏的解决办法
  14. 微信小程序 仿朋友圈
  15. Ristretto—SqueezeNet示例详解
  16. http://www.dewen.net.cn/q/15807/java byte 疑问
  17. OpenCV mat类实现水平投影和垂直投影
  18. java毕业设计 基于vue的小区停车场停车位短租管理系统ssm源码介绍
  19. 全连接层介绍以及简单实现
  20. 渗透测试常用工具总结——DAST、SAST、IAST

热门文章

  1. 计算机组成原理之高速缓冲存储器(Cache)
  2. WWDC 2017后果:最重要的公告
  3. idea tomcat启动不来
  4. Building Secure Environments for Microservices
  5. 上海亚商投顾:沪指冲高回落 纺织服装股午后集体走强
  6. python画绿叶_python画一片绿叶给你
  7. 数字未来:世界正走向新的“破茧时刻”
  8. 网络层IP协议:IP网段划分(A类 B类 C类 D类 E类)
  9. 力软敏捷开发框架,快速搭建企业级应用系统
  10. 一个刚刚踏入IT界的新人的自我介绍