1. 前言

作为一个对三种深度学习框架(Tensorflow,Keras,Pytorch)刚刚完成入门学习的菜鸟,在实战的过程中,遇到了一些菜鸟常见问题,即图片数据加载与预处理。在刚刚接触深度学习的时候,并不懂数据生成器的作用,后面随着数据量的增大,终于感受到了数据生成器的魅力了。我相信有部分新手和我一样,在训练的过程中,喜欢将所有图片数据读取到内存后,然后进行图片的预处理。那么会出现一个很大的问题 - Out of memory! 你可能会觉得是硬件的问题,实际上还是我们太菜了,所以这篇文章我想简单的说一下三种深度学习框架如何搭建数据生成器。

2. Pytorch下的图片数据生成器

首先还是介绍下当下最火热的Pytorch深度学习框架,本来坚持Keras+Tensorflow的我,还是逃不过“真香”警告呀。

不得不说,Pytorch真好用,看Pytorch代码总有种赏心悦目的感觉,Pytorch中大量使用类定义,使得代码看起来更加清晰明了,不论从模型定义到数据生成,都是一气呵成。话不多说,来看一下Pytorch关于数据迭代器的生成,下面的例子都是以最常见的猫狗数据来进行定义的。这个数据集的结构如下图所示。

2.1 图片数据集结构

Dogs_cats_data文件下包含三个子文件夹,分别放置训练集,验证集和测试集

我们打开training_set文件夹,有两个子文件夹dogs和cats,类似地,test和validation里面也有这两个子文件夹

Cats文件夹下和dogs文件夹下的文件如下图所示。命名格式为“cat.xxx.jpg”和“dog.xxx.jpg”。

cats子文件夹的内容

dogs子文件夹的内容

2.2 Pytorch图片数据生成器解读

介绍完了数据集的大致情况,我们就来写一个pytorch关于猫狗数据生成器。这里主要分别两步:

  1. 写一个torch.utils.data.Dataset的子类,继承父类Dataset的一些属性和方法。一般来说,修改该类的三个专有方法initgetitemlen)就可以完成一个新子类的生成。

  2. 使用torch.utils.data.DataLoader结合上面定义的Dataset子类来创建一个数据生成器

为了不占用过多内存,我们需要将图片的所有地址(并不是所有数字化图片)加载到内存中,需要多少图片数据的时候就从内存中解析多少图片地址,这样有效且合理地使用内存,也不会耽误时间。

下面为数据生成器代码,定义了一个Dataset的子类Dogcat,实例化类的时候,需要输入

  1. 图片的根目录root

  2. 是否对图片数据进行转换(标准化、resize、张量化

  3. Train还是val

from torchvision import transforms as T
import os
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
import os
import numpy as np
import pandas as pd
import glob# 搭建一个猫狗数据集生成器
class DogCat(Dataset):def __init__(self, root, transforms=None, train=True, val=False):"""get images and execute transforms."""self.val = valimgs=glob.glob(os.path.join(root,'*/*/*.jpg'))imgs = sorted(imgs, key=lambda x: x.split('.')[-2]) #对图片进行排序,按照图片的索引self.imgs = imgsif transforms is None:# normalizenormalize = T.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])# trainset and valset have different data transform# trainset need data augmentation but valset don't.# valsetif self.val:# T是组合的意思,将下面的操作组合起来self.transforms = T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor(),normalize])# trainsetelse:self.transforms = T.Compose([T.Resize(256),T.RandomResizedCrop(224), #随机大小裁剪T.RandomHorizontalFlip(),T.ToTensor(),normalize])def __getitem__(self, index):"""return data and label"""img_path = self.imgs[index]label = 1 if 'dog' in img_path.split('/')[-1] else 0 #根据图片名称给与label,这里是二分类问题data = Image.open(img_path)data = self.transforms(data)return data, labeldef __len__(self):"""return images size."""return len(self.imgs)if __name__ == "__main__":train_dataset = DogCat('E:\Image_Dataset\dogs_cats_data', train=True)data_loader = DataLoader(dataset=train_dataset, batch_size=10, shuffle=True)for data,label in data_loader:print('label',label)break

这里,getitem是类专有方法,给予了实例化类的索引方法,即可以通过索引访问实例化的Dataset子类(DogCat类)的元素。比如对上述代码中,我可以使用train_dataset[0] (这里的train_dataset是实例化的DogCat类) 来获取第一张图片的信息。在__getitem__方法中,根据图片的索引,对相应位置的图片地址进行解析,读取图片,对其进行transforms,将其转换为图片数据并返回。这样一来,就完成了图片数据的处理。

再将实例化的Dataset作为Dataloader输入的时候,根据设置的batch_size,不断地使用__getitem__方法获取对应batch_size数量图片数据后,作为Dataloader返回值。

实践过程中,使用for循环不断地读取每批处理完后的图片数据,这样可以有效地缓解内存不足问题

3 Keras下的图片数据生成器

当初在学《Python深度学习》的时候,在猫狗分类的时候,使用了一个数据迭代器,代码是这样的,

from keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

上面这个是实例化一个图片数据生成器,这里包含了和pytorch类似的transforms,但是注意到我们这里没有数据地址索引,所以,我们需要使用实例化类的flow_from_directory方法,将图片的地址train_dir和target_size等作为输入,这样就完成了一个数据生成器的生成了,如下

train_generator = train_datagen.flow_from_directory(train_dir,target_size=(150, 150),batch_size=20,class_mode='binary')
validation_generator = test_datagen.flow_from_directory(validation_dir,target_size=(150, 150),batch_size=20,class_mode='binary')

在使用的时候,使用循环就可以不断地生成数据了,每次循环生成batch_size个数据了。下面代码,每次生成20个数据。

for data_batch, labels_batch in train_generator:print('data batch shape:', data_batch.shape)print('labels batch shape:', labels_batch.shape)break

训练过程中,使用fit_generator就可以了

history = model.fit_generator(train_generator,steps_per_epoch=100,epochs=30,validation_data=validation_generator,validation_steps=50)

但是按照上述方法生成生成器,最大的弊端就是不灵活,我只要输入一个地址端口train_dir,他就生成了一个数据生成器,这其实是一个高度封装的函数,那么输入的地址结构信息一定要固定,想用自己的数据集做点事,就很不灵活了。

参考以下这篇文章,我们也可以自定义一个keras的数据生成器。这里的定义就有点像Pytorch中数据生成器的定义,但是这里是继承Sequence生成一个新的子类不像Pytorch是继承Dataset的。

Keras Sequence方法用于拟合一个数据序列,每一个Sequence必须提供__getitem__和__len__方法,这跟Torch的Dataset模块类似。Sequence是进行多进程处理的更安全的方法,这种结构保证网络在每个时期每个样本只训练一次,这与生成器不同。

这里我们仍然使用上述猫狗数据,来新建一个Keras数据生成器。代码如下:

from skimage.io import imread
from skimage.transform import resize
import numpy as np
from keras.utils import Sequence
import cv2
import glob
import osclass Dogs_Cats_DataGenerator(Sequence):"""基于Sequence的自定义Keras数据生成器"""def __init__(self, filepath, batch_size=8, imgshape=(256, 472),n_channels=3, n_classes=13, shuffle=True):""" 初始化方法:param filepath :数据文件地址:param batch_size: batch size:param imgshape: 图像大小:param n_channels: 图像通道:param n_classes: 标签类别:param shuffle: 每一个epoch后是否打乱数据"""self.filepath=filepathself.pathlist=glob.glob(os.path.join(self.filepath,'*/*.jpg'))self.batch_size = batch_sizeself.imgshape = imgshapeself.n_channels = n_channelsself.n_classes = n_classesself.shuffle = shuffle# 每个epoch之后更新索引self.on_epoch_end()# 文件地址listdef __getitem__(self, index):"""生成每一批次训练数据:param index: 批次索引:return: 训练图像和标签"""# 生成批次索引indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]# 索引列表batch_pathlist = [self.pathlist[k] for k in indexes]# 生成数据X = self._generate_X(batch_pathlist)y = self._generate_y(batch_pathlist)return X, ydef __len__(self):"""每个epoch下的批次数量,也就是每个epoch的iteration"""return int(np.floor(len(self.pathlist) / self.batch_size))def _generate_X(self, batch_pathlist):"""生成每一批次的图像:param list_IDs_temp: 批次数据索引列表:return: 一个批次的图像"""# 初始化X = np.empty((self.batch_size, *self.imgshape, self.n_channels))# 生成数据for i, path in enumerate(batch_pathlist):# 存储一个批次X[i,] = self._load_image(path)return Xdef _generate_y(self, batch_pathlist):"""生成每一批次的标签:param list_IDs_temp: 批次数据索引列表:return: 一个批次的标签"""y = np.empty((self.batch_size, ), dtype=int)# Generate datafor i, path in enumerate(batch_pathlist):# Store sampley[i,]= 1 if 'dog' in path.split('/')[-1] else 0return ydef on_epoch_end(self):"""每个epoch之后更新索引"""self.indexes = np.arange(len(self.pathlist))if self.shuffle == True:np.random.shuffle(self.indexes)def _load_image(self, image_path):"""cv2读取图像"""# img = cv2.imread(image_path)img = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)w, h, _ = img.shapeif w>h:img = np.rot90(img)img = cv2.resize(img, self.imgshape)return imgif __name__=='__main__':# Parametersparams = {'batch_size': 10,'n_classes': 2,'n_channels': 3,'shuffle': True,'imgshape':(224,224)}train_filepath=r'E:\Image_Dataset\dogs_cats_data\training_set'val_filepath=r'E:\Image_Dataset\dogs_cats_data\validation'# Generatorstraining_generator = Dogs_Cats_DataGenerator(train_filepath, **params)validation_generator =  Dogs_Cats_DataGenerator(val_filepath, **params)for x,y in training_generator:print(y)break

到这里,我们就完成了对keras生成器的搭建了,其实和Pytorch的生成器还是很像,只不过继承的父类不一样罢了。

4 Tensorflow下的图片数据生成器

Tensorflow在我心中是老大哥,也是让我真正感受到计算图美妙的深度学习框架,虽然不如某torch简单易用(这确实是事实),如果不是大家都用pytorch,导致一些开源代码都是基于pytorch的,只用tensorflow和keras就能完全满足我的深度学习进程了。好了,话不多说,下面就用我的“前任”深度学习框架来创造一个图片数据生成器。

在Tensorflow中,使用tf.data.Dataset直接就可以实例化一个图片数据生成器。我将创建步骤分为以下几步:

  1. 根据读取数据的类型实例化一个Dataset

  2. 使用map函数对实例化的Dataset进行数据的预处理

  3. 创建一个数据迭代器,根据自身需求,创建不同类型的迭代器

  4. get_next()函数依次获取数据

讲了大致的思路后,接着我们开始定义实例化Dataset,当然还是以猫狗数据集为例,我们的思路仍然是将猫狗的数据的地址作为数据集的输入需要多少个图片数据的时候,就解析多少个图片地址。代码如下

    data_dir = config.dataset_dirdata_root = pathlib.Path(data_dir)all_image_path = list(data_root.glob('*/*'))all_image_path = [str(path) for path in all_image_path]label_names = sorted(item.name for item in data_root.glob('*/'))# # dict: name->index ,如{'cats': 0, 'dogs': 1}label_to_index = dict((index, name) for name, index in enumerate(label_names))print('train dic index',label_to_index)# # get all images' labelsall_image_label = [label_to_index[pathlib.Path(p).parent.name] for p in all_image_path]# # load dataset and preprocess images(the preprocess function in map  can preprocess all previous pictures)image_dataset = tf.data.Dataset.from_tensor_slices(all_image_path).map(load_and_preprocess_image_for_train)# print(image_dataset)label_dataset = tf.data.Dataset.from_tensor_slices(all_image_label)dataset = tf.data.Dataset.zip((image_dataset, label_dataset))

上面分别对image和label各实例化一个数据集,再组合成一个数据集。我们注意到,这里面使用了map函数,用来解析batch_size个图片地址,并加以预处理

定义如下:

def load_and_preprocess_image_for_train(img_path):# read picturesimg_raw = tf.io.read_file(img_path)# decode picturesimg_tensor = tf.image.decode_jpeg(img_raw, channels=channels)# resizeimg_tensor = tf.image.resize(img_tensor, [image_height, image_width])#tf.cast() function is a type conversion function that converts the data format of x into dtypeimg_tensor = tf.cast(img_tensor, tf.float32)# normalizationimg_tensor = img_tensor / 255.0# flip left or rightimg_tensor = tf.image.random_flip_left_right(img_tensor)# change color randomlyimg_tensor=change_color(img_tensor,NUM=np.random.randint(4))return img_tensor

定义完了相关的数据集后,还有一步,就是设定:

  1. batch_size,

  2. shuffle

  3. epoch

只需一步就到位了

train_dataset =dataset.shuffle(buffer_size=10000).batch(batch_size=config.BATCH_SIZE).repeat(config.Epoches)

到此就完成了Tensorflow数据集的定义和预处理,接着我们需要定义迭代器,如下:

 iteration_train = train_dataset.make_initializable_iterator()train_image_batch, train_label_batch = iteration_train.get_next()

然后使用sess.run()来获取每组生成数据,用来参与训练就好了

 for i in range(config.training_step):images,labels=sess.run([train_image_batch,train_label_batch])_,train_loss,acc=sess.run([train_op,losses,accuracy],feed_dict={image:images,label:labels,is_training:True})train_loss_list.append(train_loss)train_acc_list.append(acc)

到这里,使用Tensorflow来创建数据迭代器就完成了,没有多少花哨的地方,简单又直接。

5.总结

到这里,共计耗费我一天的时间,终于将基于三个深度学习框架图片数据迭代器(又可以叫数据生成器)总结写完了,目的是为了我以后回头看的时候更加方便,也为了可以让广大网友们学习、借鉴,甚至批评指正

建议阅读:

高考失利之后,属于我的大学本科四年

【资源分享】对于时间序列,你所能做的一切.

【时空序列预测第一篇】什么是时空序列问题?这类问题主要应用了哪些模型?主要应用在哪些领域?

【AI蜗牛车出品】手把手AI项目、时空序列、时间序列、白话机器学习、pytorch修炼

公众号:AI蜗牛车

保持谦逊、保持自律、保持进步

个人微信

备注:昵称+学校/公司+方向

如果没有备注不拉群!

拉你进AI蜗牛车交流群

基于Pytorch、Keras、Tensorflow的图片数据生成器搭建相关推荐

  1. 【项目实战课】人人免费可学!基于Pytorch的图像分类简单任务数据增强实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的图像分类简单任务数据增强实战>.所谓项目实战课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行代码级的 ...

  2. 【项目实战课】基于Pytorch的Pix2Pix黑白图片上色实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的黑白图像上色实战>.所谓项目实战课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行代码级的实战讲解. ...

  3. PyTorch 使用torchvision进行图片数据增广

    使用torchvision来进行图片的数据增广 数据增强就是增强一个已有数据集,使得有更多的多样性.对于图片数据来说,就是改变图片的颜色和形状等等.比如常见的: 左右翻转,对于大多数数据集都可以使用: ...

  4. 【项目实战课】基于Pytorch的UGATIT人脸动漫风格化实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的UGATIT人脸动漫风格化实战>.所谓项目课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行代码级的实 ...

  5. 【项目实战课】基于Pytorch的SRGAN图像超分辨实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的SRGAN图像超分辨实战>.所谓项目实战课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行代码级的实战 ...

  6. 【项目实战课】基于Pytorch的InceptionNet花卉图像分类实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的InceptionNet花卉图像分类实战>.所谓项目课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行 ...

  7. 【项目实战课】基于Pytorch的Semantic_Human_Matting(人像软分割)实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的Semantic_Human_Matting(人像软分割)实战>.所谓项目实战课,就是以简单的原理回顾+详细的项目实战的模式, ...

  8. 【项目实战课】基于Pytorch的DANet自然图像降噪实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的DANet自然图像降噪实战>.所谓项目课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行代码级的实战讲 ...

  9. 【项目实战课】基于Pytorch的EnlightenGAN自然图像增强实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的EnlightenGAN自然图像增强实战>.所谓项目课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行 ...

最新文章

  1. Cookie注入是怎样产生的
  2. P3321 [SDOI2015]序列统计
  3. python是什么编程教程-python教程看完了,还是不会编程?
  4. echarts toolbox 自定义工具
  5. MySql入门使用:登录及简单创建查询表
  6. 2021-04-11面试
  7. java网络编程实例_关于java网络编程的实例代码
  8. 拦截器、过滤器、@Aspect 区别
  9. 《大数据》2020年第3期目次摘要
  10. Activating Browser Modes with Doctype
  11. java 图文生成图片_java生成图片
  12. VS2015启动遇到的一些问题和解决方法
  13. Away3d学习笔记(1)
  14. 用户故事与敏捷方法—优秀用户故事准则
  15. U956(MTK6589系列)移植乐蛙教程
  16. 迄今为止最完整的DDD实践
  17. 信息熵与两种编码基础
  18. java商城管理系统_带数据库_带界面化可用来毕业设计
  19. 【Swift 60秒】51 - Closures as parameters
  20. css3探测光圈_CSS3光圈散开提示效果

热门文章

  1. Qualcomm ADK6 EARBUD APPLICATION
  2. 前端笔试面试题--1
  3. Excel如何从记录信息里批量提取出QQ号码
  4. win7打开桌面计算机很慢,Win7电脑反应慢如何解决?Win7电脑反应慢的解决方法
  5. 大厂前端面试分享:如何让6000万数据包和300万数据包在仅50M内存环境中求交集
  6. 四舍五入(多种方法)
  7. ROS 机器人双路视频手机监控
  8. 阿里资深架构师推荐:企业数字化转型私房菜
  9. tableau prep builder etl工具使用注意事项
  10. 首页被改为so8.zj.cn修复工具