every blog every motto: You can do more than you think.

0. 前言

在训练模型时,我们往往不一次将数据全部加载进内存中,而是将数据分批次加载到内存中。


  • 一种方法是用 while True 遍历数据,用yeid产生,具体可参考语义分割代码讲解部分
  • 另一种方法是本文即将讲解的tf.keras.utils.Sequence方法

1. 正文

1.1 基础用法

__ len __ 中返回的即1个epoch迭代的次数,即:
总样本数/ batch_size

__ getitem __ 根据len中的迭代次数,生成数据


注意: __ len __ ,__ getitem __ 必须要实现

"""
测试
__getitem__
"""
import osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tfclass Date(tf.keras.utils.Sequence):def __init__(self):print('初始化相关参数')def __len__(self):"""此方法要实现,否则会报错正常程序中返回1个epoch迭代的次数:return:"""return 5def __getitem__(self, index):"""生成一个batch的数据"""print('index:', index)x_batch = ['x1', 'x2', 'x3', 'x4']y_batch = ['y1', 'y2', 'y3', 'y4']print('-'*20)return x_batch, y_batch# 实例化数据
date = Date()for batch_number, (x, y) in enumerate(date):print('正在进行第{} batch'.format(batch_number))print('x_batch:', x)print('y_batcxh:', y)

结果:

1.2 扩展(2020.11.12 15:37增补)

可以在类中实现on_epoch_end方法,保证在每个epoch后打乱原有数据的顺序

1.2.1 训练样例:

测试代码,如下:

import osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as npprint('tensorflow version: ', tf.__version__)class ZerosFirstEpochOnesAfter(tf.keras.utils.Sequence):def __init__(self):self.shuffle = Truedef __len__(self):return 2def on_epoch_end(self):print('---------------on_epoch_end------------')# 打乱索引# if self.shuffle:#     print('==============================================================shuffle')#     np.random.shuffle(self.indices)def __getitem__(self, item):return np.zeros((16, 1)), np.zeros((16,))def main():model = tf.keras.Sequential()model.add(tf.keras.layers.Dense(1, input_dim=1, activation="softmax"))model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])model.fit(ZerosFirstEpochOnesAfter(), epochs=3, )if __name__ == '__main__':main()

tensorflow 2.0:

tensorflow 2.1:

tesorflow 2.3:

由以上三个版本的训练结果,我们可以发现,

  • 在2.0和2.1版本中,是没有进行on_epoch_end方法调用的,即没有实现on_epoch_end方法内注释部分的打乱顺序,这是tensorflow早期版本的一个bug,具体可参考文后第4个链接。
  • 在2.3版本中已得到改进

1.2.2 循环遍历:

1.2.2.1 原始版测试

循环遍历,如下所示:

import osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as npprint('tensorflow version: ', tf.__version__)class Date(tf.keras.utils.Sequence):def __init__(self):print('初始化相关参数')self.lines = [1,2,3,4,5]self.shuffle = Truedef __len__(self):"""此方法要实现,否则会报错正常程序中返回1个epoch迭代的次数:return:"""return 2def on_epoch_end(self):print('=======================')if self.shuffle == True:print('------------一个epoch结束,打乱了顺序---')np.random.shuffle(self.lines)def __getitem__(self, index):"""生成一个batch的数据"""print('index:', index)x_batch = ['x1', 'x2', 'x3', 'x4']y_batch = ['y1', 'y2', 'y3', 'y4']print('-' * 20)return x_batch, y_batch# 实例化数据
date = Date()for epoch in range(2):for batch_number, (x, y) in enumerate(date):print('正在进行第{} batch'.format(batch_number))print('x_batch:', x)print('y_batcxh:', y)print('一个epoch结束=============================')

结果:

如上图所示,通过循环遍历这种方法仍然不能调用on_epoch_end,即无法打乱顺序

1.2.2.2 改进版

import osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as npprint('tensorflow version: ', tf.__version__)class Date(tf.keras.utils.Sequence):def __init__(self):print('初始化相关参数')self.lines = [1,2,3,4,5]self.shuffle = Truedef __len__(self):"""此方法要实现,否则会报错正常程序中返回1个epoch迭代的次数:return:"""return 2def on_epoch_end(self):print('=======================')if self.shuffle == True:print('------------一个epoch结束,打乱了顺序---')np.random.shuffle(self.lines)def __getitem__(self, index):"""生成一个batch的数据"""print('index:', index)x_batch = ['x1', 'x2', 'x3', 'x4']y_batch = ['y1', 'y2', 'y3', 'y4']print('-' * 20)return x_batch, y_batch# 实例化数据
date = Date()for epoch in range(2):print(date.lines)for batch_number, (x, y) in enumerate(date):print('正在进行第{} batch'.format(batch_number))print('x_batch:', x)print('y_batcxh:', y)np.random.shuffle(date.lines)print('一个epoch结束=============================')

如下图所示,我们发现已经打乱了“样本”顺序,

参考文献

[1] https://blog.csdn.net/weixin_39190382/article/details/105808830
[2] https://blog.csdn.net/weixin_43198141/article/details/89926262
[3] https://blog.csdn.net/u011311291/article/details/80991330
[4] https://github.com/tensorflow/tensorflow/issues/35911
[5] https://colab.research.google.com/gist/bfs15/fd18263f788a071225c60cedaf126748/35911.ipynb

【tf.keras.utils.Sequence】构建自己的数据集生成器相关推荐

  1. tensorflow tf.keras.utils.plot_model 画深度学习神经网络拓扑图

    tensorflow tf.keras.utils.plot_model 画网络拓扑图 # pip install graphviz # pip install pydot # 下载 graphviz ...

  2. 垃圾分类、EfficientNet模型、数据增强(ImageDataGenerator)、混合训练Mixup、Random Erasing随机擦除、标签平滑正则化、tf.keras.Sequence

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 垃圾分类.EfficientNet模型.数据增强(ImageD ...

  3. tf.Keras.Model类总结

    文章目录 tf.keras.Model类 1. 创建一个tf.keras.Model类实例的方法 1.1 通过指定输入输出进行实例化 1.2 通过继承Model类进行实例化 2. tf.Keras.M ...

  4. 使用估算器、tf.keras 和 tf.data 进行多 GPU 训练

    文 / Zalando Research 研究科学家 Kashif Rasul 来源 | TensorFlow 公众号 与大多数 AI 研究部门一样,Zalando Research 也意识到了对创意 ...

  5. Tensorflow学习之tf.keras(一) tf.keras.layers.Model(另附compile,fit)

    模型将层分组为具有训练和推理特征的对象. 继承自:Layer, Module tf.keras.Model(*args, **kwargs ) 参数 inputs 模型的输入:keras.Input ...

  6. 日月光华深度学习(一、二)深度学习基础和tf.keras

    日月光华深度学习(一.二)深度学习基础和tf.keras [2.2]--tf.keras实现线性回归 [2.5]--多层感知器(神经网络)的代码实现 [2.6]--逻辑回归与交叉熵 [2.7]--逻辑 ...

  7. python怎么导入数据集keras_keras使用Sequence类调用大规模数据集进行训练的实现

    使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里 ...

  8. 从零开始,手把手教你使用Keras和TensorFlow构建自己的CNN模型

    最近学习CNN,搭建CNN模型时看网上鱼龙混杂的博客走了不少歪路,决定自己来总结一下. 注意本教程未必对所有版本有效,请根据需要的版本适当调整.文章中配置的环境是Python 3.8.12 ,Tens ...

  9. 深度学习-Tensorflow2.2-深度学习基础和tf.keras{1}-softmax多分类-06

    softmax分类 Fashion MNIST数据集 import tensorflow as tf import pandas as pd import numpy as np import mat ...

  10. TensorFlow高阶 API: keras教程-使用tf.keras搭建mnist手写数字识别网络

    TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字识别网络 目录 TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字 ...

最新文章

  1. 异构计算架构师眼中的AI算法(object detection)
  2. 任务31:课时介绍 任务32:Cookie-based认证介绍 任务33:34课 :AccountController复制过来没有移除[Authorize]标签...
  3. CCF201412-2 Z字形扫描(模拟)
  4. mfc控件随框变化(EasySize的用法,仔细看绝对有用)
  5. 软件工程学习进度第八周暨暑期学习进度之第八周汇总
  6. HashMap分别按照key和value进行排序的快捷方法
  7. 微软发布企业安全进度报告 云应用安全服务即将面世
  8. python函数定义及调用-python函数基础(函数的定义和调用)
  9. 2018/5/7~2018/5/11 周记
  10. 路由器刷固件——斐讯路由器FIR300M刷OpenWrt固件教程
  11. html自我介绍5页模板,关于个人自我介绍模板6篇
  12. python不是5的倍数_查找所有低于1000的数字之和,这是Python中3或5的倍数
  13. 《手把手教你构建自己的 Linux 系统》学习笔记(9)
  14. 两台服务器联通如何配置文件,两个服务器之间数据库怎么连接
  15. 静雅学校有高中吗有计算机,涿州靖雅中学
  16. git 报错解决方法:Your branch is ahead of ‘origin/dev‘ by 65 commits.
  17. 深入浅出matplotlib(101):研究最有名的滤波函数:sinc函数
  18. 关于overflow适配IE的问题
  19. 静态测试 vs 动态测试
  20. ie6下z-index不起作用?

热门文章

  1. php+模版取余,PHP取余函数介绍MOD(x,y)与x%y_php技巧
  2. mybatis-plus配置日志
  3. html日期判断程序,javascript – HTML5日期验证
  4. dubbo的基于java的路由_1 | Dubbo:探讨标签路由的实现
  5. python 数据结构面试_【Python排序面试题】面试问题:所谓数据结构,… - 看准网...
  6. oracle删sequ_Oracle序列(Sequence)创建、使用、修改、删除
  7. 面向对象java试题_经典面向对象试题,用Java做,要详细点的!先谢过了
  8. 从零开始搭二维激光SLAM --- 前言
  9. Windows远程访问Linux (Ubuntu)服务器
  10. 浅谈混合精度训练imagenet