转自:http://blog.csdn.net/lujiandong1/article/details/53376802

这篇文章写数据读取,包含了线程以及batch的概念


1、准备数据,构造三个文件,A.csv,B.csv,C.csv

$ echo -e "Alpha1,A1\nAlpha2,A2\nAlpha3,A3" > A.csv
$ echo -e "Bee1,B1\nBee2,B2\nBee3,B3" > B.csv
$ echo -e "Sea1,C1\nSea2,C2\nSea3,C3" > C.csv 

2、从数据里生成样本和标签

2.1、单个Reader,每次生成一个样本

#-*- coding:utf-8 -*-
import tensorflow as tf
# 生成一个先入先出队列和一个QueueRunner,生成文件名队列
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
# 定义Reader
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# 定义Decoder
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
#example_batch, label_batch = tf.train.shuffle_batch([example,label], batch_size=1, capacity=200, min_after_dequeue=100, num_threads=2)
# 运行Graph
with tf.Session() as sess:  coord = tf.train.Coordinator()  #创建一个协调器,管理线程  threads = tf.train.start_queue_runners(coord=coord)  #启动QueueRunner, 此时文件名队列已经进队。  for i in range(10):  print example.eval(),label.eval()  coord.request_stop()  coord.join(threads) 

结果:这里生成的样本和label之间对应不上,乱序了。生成结果如下:
Alpha1 A2
Alpha3 B1
Bee2 B3
Sea1 C2
Sea3 A1
Alpha2 A3
Bee1 B2
Bee3 C1
Sea2 C3
Alpha1 A2


2.2、用tf.train.shuffle_batch,生成的结果就能够对应上

#-*- coding:utf-8 -*-
import tensorflow as tf
# 生成一个先入先出队列和一个QueueRunner,生成文件名队列
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
# 定义Reader
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# 定义Decoder
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
example_batch, label_batch = tf.train.shuffle_batch([example,label], batch_size=1, capacity=200, min_after_dequeue=100, num_threads=2)
# 运行Graph
with tf.Session() as sess:  coord = tf.train.Coordinator()  #创建一个协调器,管理线程  threads = tf.train.start_queue_runners(coord=coord)  #启动QueueRunner, 此时文件名队列已经进队。  for i in range(10):  e_val,l_val = sess.run([example_batch, label_batch])  print e_val,l_val  coord.request_stop()  coord.join(threads)  

运行结果


2.3、单个Reader,每次生成一个batch,主要也是通过tf.train.shuffle_batch来实现

#-*- coding:utf-8 -*-
import tensorflow as tf
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
# 使用tf.train.batch()会多加了一个样本队列和一个QueueRunner。
#Decoder解后数据会进入这个队列,再批量出队。
# 虽然这里只有一个Reader,但可以设置多线程,相应增加线程数会提高读取速度,但并不是线程越多越好。
example_batch, label_batch = tf.train.batch(  [example, label], batch_size=5)
with tf.Session() as sess:  coord = tf.train.Coordinator()  threads = tf.train.start_queue_runners(coord=coord)  for i in range(10):  e_val,l_val = sess.run([example_batch,label_batch])  print e_val,l_val  coord.request_stop()  coord.join(threads)

运行结果


2.4、下面这种写法,提取batch_size个样本,特征和label之间也是不同步的

#-*- coding:utf-8 -*-
import tensorflow as tf
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
# 使用tf.train.batch()会多加了一个样本队列和一个QueueRunner。
#Decoder解后数据会进入这个队列,再批量出队。
# 虽然这里只有一个Reader,但可以设置多线程,相应增加线程数会提高读取速度,但并不是线程越多越好。
example_batch, label_batch = tf.train.batch(  [example, label], batch_size=5)
with tf.Session() as sess:  coord = tf.train.Coordinator()  threads = tf.train.start_queue_runners(coord=coord)  for i in range(10):  print example_batch.eval(), label_batch.eval()  coord.request_stop()  coord.join(threads)  

运行结果


2.5、多个reader,生成batch。通过调用batch_join函数

#-*- coding:utf-8 -*-
import tensorflow as tf
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
record_defaults = [['null'], ['null']]
#定义了多种解码器,每个解码器跟一个reader相连
example_list = [tf.decode_csv(value, record_defaults=record_defaults)  for _ in range(2)]  # Reader设置为2
# 使用tf.train.batch_join(),可以使用多个reader,并行读取数据。每个Reader使用一个线程。
example_batch, label_batch = tf.train.batch_join(  example_list, batch_size=5)
with tf.Session() as sess:  coord = tf.train.Coordinator()  threads = tf.train.start_queue_runners(coord=coord)  for i in range(10):  e_val,l_val = sess.run([example_batch,label_batch])  print e_val,l_val  coord.request_stop()  coord.join(threads)  

运行结果:


3、总结
tf.train.batch与tf.train.shuffle_batch函数使用单个Reader读取文件,但是可以使用多个线程(num_thread>1),这些线程同时读取同一个文件里的不同example。这种方法的优点:

  1. 如果读的线程比文件数多,这种方法可以避免两个线程同时读取同一个文件里的同一个example
  2. tf.train.batch_join方法所使用的同时读硬盘里的N个不同文件会花费多余的disk seek 时间

tf.train.batch_join与tf.train.shuffle_batch_join可设置多Reader读取,每个Reader使用一个线程、且每个reader读取各自的文件,直到这次epoch里的文件全部读取完。

至于两种方法的效率,单Reader时,2个线程就达到了速度的极限。多Reader时,2个Reader就达到了极限。所以并不是线程越多越快,甚至更多的线程反而会使效率下降。

TensorFlow batch相关推荐

  1. tensorflow --batch内负采样

    class NegativeCosineLayer():""" 自定义batch内负采样并做cosine相似度的层 """"&qu ...

  2. 如何用一套引擎搞定机器学习全流程?

    作者:陈戊超(仲卓) 深度学习技术在当代社会发挥的作用越来越大.目前深度学习被广泛应用于个性化推荐.商品搜索.人脸识别.机器翻译.自动驾驶等多个领域,此外还在向社会各个领域迅速渗透. 背景 当前,深度 ...

  3. 光流估计(三) PWC-Net 模型介绍

    一.PWC-Net 概述 PWC-Net 的网络模型在 CVPR,2018 由 NVIDIA 提出,发表文章为 <PWC-Net: CNNs for Optical Flow Using Pyr ...

  4. Kubernetes 和 Kubeflow 学习笔记

    Kubernetes Kubernetes是一个完备的分布式系统支撑平台,具有完备的集群管理能力,多扩多层次的安全防护和准入机制.多租户应用支撑能力.透明的服务注册和发现机制.內建智能负载均衡器.强大 ...

  5. Multi-task Learning

    Deep Learning 回顾之多任务学习 https://www.52ml.net/20775.html?utm_source=tuicool&utm_medium=referral 深度 ...

  6. Tensorflow BatchNormalization详解:4_使用tf.nn.batch_normalization函数实现Batch Normalization操作...

    使用tf.nn.batch_normalization函数实现Batch Normalization操作 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 吴恩达deeplearnin ...

  7. 3.1 Tensorflow: 批标准化(Batch Normalization)

    ##BN 简介 背景 批标准化(Batch Normalization )简称BN算法,是为了克服神经网络层数加深导致难以训练而诞生的一个算法.根据ICS理论,当训练集的样本数据和目标样本集分布不一致 ...

  8. tensorflow 的 Batch Normalization 实现(tf.nn.moments、tf.nn.batch_normalization)

    tensorflow 在实现 Batch Normalization(各个网络层输出的归一化)时,主要用到以下两个 api: tf.nn.moments(x, axes, name=None, kee ...

  9. tensorflow dataset.shuffle dataset.batch dataset.repeat 理解 注意点

    batch很好理解,就是batch size.注意在一个epoch中最后一个batch大小可能小于等于batch size  dataset.repeat就是俗称epoch,但在tf中与dataset ...

最新文章

  1. NodeJs教程(介绍总结!)终于在网上找到一个靠谱点的了T_T
  2. (chap4 IP协议) CIDR协议
  3. 密码篇——非对称加密
  4. 使用web3.js进行开发
  5. 在移动互联网上赚钱,行不行
  6. Python导入全局、局部模块以及如何让避免循环导入
  7. 分布式系统面试 - 常见问题
  8. RedHat中敲sh-copy-id命令报错:-bash: ssh-copy-id: command not found
  9. C语言小知识---特殊的字符串
  10. oracle test 卡死,oracle11g plsql调试存储过程卡死的处理技巧
  11. vc6.0快捷键小结收藏
  12. 也许你需要在 Antergos 与 Arch Linux 中查看印度语和梵文?
  13. 基础篇——树莓派远程连接工具VNC不显示视频或摄像头画面解决方式
  14. 用BWA进行序列比对
  15. 用ul、li做横向导航
  16. 技术社区、相关论坛推荐汇总(持续更新)
  17. Linux编译时如何减小so库文件的大小
  18. 大数据分析培训课程有哪些?初级阶段主要学习什么?
  19. 干货|Python爬虫如何设置代理IP
  20. HTML中input标签和button标签区别

热门文章

  1. 2020ICPR-化妆演示攻击
  2. ThreeJS FBXLoader 加载3D文件,材质消失,已解决
  3. 计算机操作系统知识框架要点复习,不包含习题!如有错误可以留言。
  4. java截取视频空间的中间段
  5. camunda如何插入以及获取流程审批
  6. 《复杂网络》复杂网络的结构及特点
  7. C语言fscanf/fprintf函数(格式化读写文件)的用法(%[]和%n说明符)
  8. 关于socket-error-10054的一点认知
  9. 怎么在线快速将多张CAD图纸转换成低版本DXF格式?
  10. java outer的使用