通常做法是使用Tensorflow的Dataset来读取我们的tfRecord,但是老的版本也有通过TFRecordReader进行解析,这里我们先介绍使用Dataset方式读取

  • 加载TFRecord文件
  • 通过parse_fn方法对每条样本机型解析
  • 重复N epochs
  • batch
def parse_fn(example_proto):features = {"state": tf.FixedLenFeature((), tf.string),"action": tf.FixedLenFeature((), tf.int64),"reward": tf.FixedLenFeature((), tf.int64)}parsed_features = tf.parse_single_example(example_proto, features)return tf.decode_raw(parsed_features['state'], tf.float32), parsed_features['action'], parsed_features['reward']with tf.Session() as sess:dataset = tf.data.TFRecordDataset(output_file)  # 加载TFRecord文件dataset = dataset.map(parse_fn)  # 解析data到Tensordataset = dataset.repeat(1)  # 重复N epochsdataset = dataset.batch(3)  # batch sizeiterator = dataset.make_one_shot_iterator()next_data = iterator.get_next()while True:try:state, action, reward = sess.run(next_data)print(state)print(action)print(reward)except tf.errors.OutOfRangeError:break

遍历结果:

解析tfrecord的2种方式

for example in tf.io.tf_record_iterator(output_file):print("first method")print(tf.train.Example.FromString(example))# 或者用下面的方法print("second method")from google.protobuf.json_format import MessageToJsonjsonMessage = MessageToJson(tf.train.Example.FromString(example))print(jsonMessage)

解析结果:

first method
features {feature {key: "action"value {int64_list {value: 1}}}feature {key: "reward"value {int64_list {value: 90}}}feature {key: "state"value {bytes_list {value: "\037\205\277B\341zD@\217\302\247A\270\036\303B\205\353\237Aff\230A33\031B\315\314\300A\nW\202B\244p/B"}}}
}second method
{"features": {"feature": {"action": {"int64List": {"value": ["1"]}},"state": {"bytesList": {"value": ["H4W/QuF6RECPwqdBuB7DQoXrn0FmZphBMzMZQs3MwEEKV4JCpHAvQg=="]}},"reward": {"int64List": {"value": ["90"]}}}}
}

完整代码:

"""
本程序演示了如何保存numpy array为TFRecords文件,并将其读取出来。
"""
import randomimport numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()def save_tfrecords(state_data, action_data, reward_data, dest_file):"""保存numpy array到TFRecord文件中。这里输入了三个不同的numpy array来做演示,它们含有不同类型的元素。Args:state_data: 要保存到TFRecord文件的第1个numpy array,每一个 state_data[i] 是一个 numpy.ndarray(数组里的每个元素又是一个浮点数),因此不能用 Int64List 或 FloatList 来存储,只能用 BytesList。action_data: 要保存到TFRecord文件的第2个numpy array,每一个 action_data[i] 是一个整数,使用 Int64List 来存储。reward_data: 要保存到TFRecord文件的第3个numpy array,每一个 reward_data[i] 是一个整数,使用 Int64List 来存储。dest_file: 输出文件的路径。Returns:不返回任何值"""with tf.io.TFRecordWriter(dest_file) as writer:for i in range(len(state_data)):features = tf.train.Features(feature={"state": tf.train.Feature(bytes_list=tf.train.BytesList(value=[state_data[i].astype(np.float32).tobytes()])),# "state": tf.train.Feature(#     float_list=tf.train.FloatList(value=state_data[i].astype(np.float))),# "action": tf.train.Feature(#     int64_list=tf.train.Int64List(value=[action_data[i]])),# "reward": tf.train.Feature(#     int64_list=tf.train.Int64List(value=[reward_data[i]]))"action": tf.train.Feature(int64_list=tf.train.Int64List(value=action_data[i].astype(np.int))),"reward": tf.train.Feature(int64_list=tf.train.Int64List(value=reward_data[i].astype(np.int)))})tf_example = tf.train.Example(features=features)serialized = tf_example.SerializeToString()writer.write(serialized)def parse_fn(example_proto):features = {"state": tf.FixedLenFeature((), tf.string),"action": tf.FixedLenFeature((), tf.int64),"reward": tf.FixedLenFeature((), tf.int64)}parsed_features = tf.parse_single_example(example_proto, features)return tf.decode_raw(parsed_features['state'], tf.float32), parsed_features['action'], parsed_features['reward']if __name__ == '__main__':buffer_s, buffer_a, buffer_r = [], [], []# 随机生成一些数据for i in range(30):state = [round(random.random() * 100, 2) for _ in range(0, 10)]  # 一个数组,里面有10个数,每个都是一个浮点数action = random.randrange(0, 2)  # 一个数,值为 0 或 1reward = random.randrange(0, 100)  # 一个数,值域 [0, 100)# 把生成的数分别添加到3个list中buffer_s.append(state)buffer_a.append(action)buffer_r.append(reward)# 查看生成的数据print(buffer_s)print(buffer_a)print(buffer_r)# 在水平方向把各个list堆叠起来,堆叠的结果:得到3个矩阵s_stacked = np.vstack(buffer_s)a_stacked = np.vstack(buffer_a)r_stacked = np.vstack(buffer_r)print(s_stacked.shape)  # (3, 10)print(a_stacked.shape)  # (3, 1)print(r_stacked.shape)  # (3, 1)print(s_stacked)print(a_stacked)print(r_stacked)print("data generate sucess!")# 写入TFRecord文件output_file = './data.tfrecord'  # 输出文件的路径save_tfrecords(s_stacked, a_stacked, r_stacked, output_file)# 读取TFRecord文件并打印出其内容for example in tf.io.tf_record_iterator(output_file):print("first method")print(tf.train.Example.FromString(example))# 或者用下面的方法print("second method")from google.protobuf.json_format import MessageToJsonjsonMessage = MessageToJson(tf.train.Example.FromString(example))print(jsonMessage)# 读取TFRecord文件并还原成numpy array,再打印出来with tf.Session() as sess:dataset = tf.data.TFRecordDataset(output_file)  # 加载TFRecord文件dataset = dataset.map(parse_fn)  # 解析data到Tensordataset = dataset.repeat(1)  # 重复N epochsdataset = dataset.batch(3)  # batch sizeiterator = dataset.make_one_shot_iterator()next_data = iterator.get_next()while True:try:print("get next")state, action, reward = sess.run(next_data)print(state)print(action)print(reward)except tf.errors.OutOfRangeError:break

TFRecordReader方式

  • tf.train.string.input_producer 读取序列化后的的TFRecord记录,生成一个QueueRunner,它包含一个FIFOQueue队列
  • 通过tf.TFRecordReader() 依据定义的模式,进行反序列化parse,可以附带一些转换操作
  • batch,通过tf.train.shuffle_batch生成了RandomShuffleQueue
  • 通过 tf.train.Coordinator() tf.train.start_queue_runners 载入数据训练

附录:
使用tensorflow中的Dataset来读取制作好的tfrecords文件

tfrecord读取过程简介相关推荐

  1. 理解tfrecord读取数据——错误OutOfRangeError (see above for traceback)的解决

    转载自:TFRecord读取数据 前言 关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. ...

  2. python读docx文件_python-docx文件定位读取过程(尝试替换)

    python-docx文件定位读取过程(尝试替换) 以上是开头,安装完后需要导入转载的代码读取所有docx文件中的内容发现没有读取到表格数据: from docx import Document de ...

  3. tensorflow分类任务——TFRecord读取自己制作的数据集

    一.TensorFlow的数据读取机制 注意:这个地址是TensorFlow的数据读取机制,如果了解请跳过. 原博客地址:https://zhuanlan.zhihu.com/p/27238630 建 ...

  4. UA MATH565C 随机微分方程II Wiener过程简介

    UA MATH565C 随机微分方程II Wiener过程简介 Wiener过程的简单性质 Wiener过程的定义 在上一讲我们定义了WtW_tWt​: dWt=ηtdt⇔Wt=∫0tηsdsdW_t ...

  5. HDFS写入和读取过程

    HDFS写入和读取过程 一.HDFS HDFS全称是Hadoop Distributed System.HDFS是为以流的方式存取大文件而设计的.适用于几百MB,GB以及TB,并写一次读多次的场合.而 ...

  6. 【Hadoop】HDFS文件写入与文件读取过程

    HDFS文件写入与文件读取过程 1. 文件读取过程 2. 文件写入过程 1. 文件读取过程 详细过程: 客户端通过调用FileSystem对象的open()来读取希望打开的文件. Client向Nam ...

  7. 最详细的php使用com读取word文件,并且解决读取过程中乱码问题,doc/docx都适用,适用于thinkphp,laravel应该也可以

    一,首先要确认php版本,最好是高于5.6 二,将以下两行代码放入php.ini中并且重启 //这个是开启扩展 extension=php_com_dotnet.dll //这个是COM扩展里自带的, ...

  8. BetaFlight飞控启动运行过程简介

    BetaFlight飞控启动&运行过程简介 1. 源由 2. 启动过程 2.1 main(主程序) 2.2 init (初始化) 2.3 run 3. 任务调度 3.1 任务定义 3.2 sc ...

  9. RISC-V嵌入式开发准备篇1:编译过程简介

    原文出处:https://mp.weixin.qq.com/s/-syKN0DibKGGPCllaeNqMg 随着国内第一本RISC-V中文书籍<手把手教你设计CPU--RISC-V处理器篇&g ...

最新文章

  1. REUSE_ALV_GRID_DISPLAY事件子过程和cl_gui_grid类的事件对应关系
  2. 老男孩36期运维脱产班---- 决心书
  3. 16张图带你吃透高性能 Redis 集群
  4. Windows 7 正在走 XP 系统的老路
  5. ffmpeg 源代码简单分析 : avcodec_decode_video2()
  6. apktool反编译生成java_apktool反编译工具下载|apktool反编译工具 v3.0.1 最新版-520下载站...
  7. 敏捷开发案例:用白板解决项目管理和团队沟通
  8. [lammps教程]OVITO绘制原子应力云图
  9. 绘制自己的人际关系图_如何系统的绘制自己的人际关系网络图?
  10. C++——判身份证号码真伪
  11. UEFI开发,记录第一场胜利——调用一个自己编写的protocol
  12. for /f 用法详解
  13. 二叉树的镜像(递归非递归)
  14. Codeforces 1610C Keshi Is Throwing a Party
  15. 路由器与交换机的工作原理(转)
  16. “移”步到位:一站式移动应用研发体系
  17. 新版本MySQL的安装教程,非免安装版本。超详细!!!
  18. HTML event
  19. 线性代数---之正交向量
  20. 智慧公安警务系统开发,智慧公安行业解决方案

热门文章

  1. ideaeclipse快捷键
  2. 将3Dmax的模型导入到unity中(带材质)
  3. ThingsBoard CE添加Excel导出功能-优化篇
  4. 常用GIT代码托管平台
  5. 活动 | 注册即半价,治愈你第二杯半价时的孤单
  6. C++中sstream类
  7. 利用 ChatGPT 解决某些网站文章不允许复制粘贴的限制
  8. 基于RHEL8/CentOS8的网络IP配置详解
  9. 毁灭还是生存?业务连续性管理考验公司基业常青-系统体系风险防范
  10. 【NodeJS】如何安装淘宝cnpm