目录

下载数据集及显示样本

数据集类

建立数据集类及显示部分样本

数据变换

后记


python提供了许多工具简化数据加载,使代码更具可读性。经常用到的包有scikit-image、pandas等,本文通过相关包进行数据加载和预处理相关简要介绍。

从此处(提取码:ilqy)下载数据集,数据存于"data/faces/"的目录中。这个数据集实际上是imagenet数据集标注为face的图片当中在dlib面部检测(dlib's pose estimation)表现良好的图片。下面以该数据集为例,对数据加载即预处理进行简要介绍。

下载数据集及显示样本

下面为下载数据集及显示其中某一样本的相关代码:

#!/usr/bin/env torch
# -*- coding:utf-8 -*-
# @Time  : 2021/2/3, 23:54
# @Author: Lee
# @File  : test.pyimport os, torch
import pandas as pd
from skimage import io, transform  # 用于图像的IO和变换
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasetsimport warnings
warnings.filterwarnings("ignore")plt.ion()  # interactive mode# # 读取数据集  将csv中的标注点数据读入(N,2)数组中,其中N是特征点的数量
landmarks_frame = pd.read_csv("data/faces/face_landmarks.csv")n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].values
landmarks = landmarks.astype('float').reshape(-1, 2)print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))# 展示一张图片和它对应的标注点
def show_landmarks(image, landmarks):plt.imshow(image)plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')plt.pause(0.001)plt.figure()
show_landmarks(io.imread(os.path.join("data/faces/", img_name)), landmarks)
plt.show()
plt.pause(0)

打印结果如下:

Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.][33. 76.][34. 86.][34. 97.]]

显示图形如下:

数据集类

torch.utils.data.Dataset是表示数据集的抽象类,因此自定义数据集应继承Dataset并覆盖以下方法*__len__实现len(dataset)返回数据集的尺寸。*__getitem__用来获取一些索引数据,例如dataset[i]中的(i)。

建立数据集类及显示部分样本

为面部数据集创建一个数据集类。在__init__中读取csv的文件内容,在__getitem__中读取图片。这么做是为了节省内存空间。只有在需要用到图片的时候才读取它而不是一开始就把图片全部放到内存中。

数据样本将按这样一个字典{'image':image, 'landmarks':landmarks}组织。该数据类将添加一个可选参数transform以方便对样本进行预处理,代码如下:

#!/usr/bin/env torch
# -*- coding:utf-8 -*-
# @Time  : 2021/2/4, 0:13
# @Author: Lee
# @File  : dataset_class.pyimport os, torch
import pandas as pd
from skimage import io, transform  # 用于图像的IO和变换
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasetsimport warnings
warnings.filterwarnings("ignore")plt.ion()  # interactive mode# 展示一张图片和它对应的标注点
def show_landmarks(image, landmarks):plt.imshow(image)plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')plt.pause(0.001)# 数据集类
# 数据样本按这样一个字典{'image': image, 'landmarks': landmarks}组织。
# 添加一个可选参数transform 以方便对样本进行预处理
class FaceLandmarksDataset(Dataset):"""人脸标记数据集"""def __init__(self, csv_file, root_dir, transform=None):"""csv_file(string):带注释的csv文件的路径。root_dir(string):包含所有图像的目录。transform(callable, optional):一个样本上的可用的可选变换"""self.landmarks_frame = pd.read_csv(csv_file)self.root_dir = root_dirself.transform = transformdef __len__(self):return len(self.landmarks_frame)def __getitem__(self, idx):img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])image = io.imread(img_name)landmarks = self.landmarks_frame.iloc[idx, 1:]landmarks = np.array([landmarks])landmarks = landmarks.astype('float').reshape(-1, 2)sample = {'image': image, 'landmarks': landmarks}if self.transform:sample = self.transform(sample)return sample# 获取图片并可视化部分图片
face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',root_dir='data/faces/')
fig = plt.figure()
for i in range(len(face_dataset)):sample = face_dataset[i]print(i, sample['image'].shape, sample['landmarks'].shape)ax = plt.subplot(1, 4, i+1)ax.set_title('Sample #{}'.format(i))ax.axis('off')show_landmarks(**sample)if i == 3:plt.show()break
plt.pause(0)

运行结果如下:

0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)

数据变换

通过上面的例子可知数据集中的图片并不是同样的尺寸。绝大多数神经网络都假定图片的尺寸相同。因此需要做预处理。这里以三个转换*Rescale(缩放图片),*RandomCrop(对图片进行随机剪裁),*ToTensor(把numpy格式的图片转换为torch格式图片,需要交换坐标轴)。

可以将它们写成可调用的类的形式而不是简单的函数,这样就不需要每次调用时传递一遍参数。只需要实现__call__方法,必要的时候实现__init__方法。

把这些整合起来以创建一个带组合转换的数据集。每次这个数据集被采样时,*即使地从文件中读取图片*对读取的图片应用转换*,由于其中一步是随机的(Randmpcrop),数据有所增强,现在可以用循环来对所有创建的数据执行同样的操作。

但是,对所有数据简单地使用for循环牺牲了很多功能,尤其是*批量处理数据(指定batch_size)*打乱数据(shuffle置True)*使用多线程(multiprocessingworker)并加载数据。

torch.utils.data.DataLoader是一个提供了上述所有这些功能的迭代器。下面使用的参数必须是清楚的。一个值得关注的参数是collate_fn,可以通过它来决定如何对数据进行批量处理,但是绝大多数情况下默认值就能运行良好。

代码如下:

#!/usr/bin/env torch
# -*- coding:utf-8 -*-
# @Time  : 2021/2/4, 0:28
# @Author: Lee
# @File  : data_preprogress.pyimport os, torch
import pandas as pd
from skimage import io, transform  # 用于图像的IO和变换
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasetsimport warnings
warnings.filterwarnings("ignore")plt.ion()  # interactive mode# 展示一张图片和它对应的标注点
def show_landmarks(image, landmarks):plt.imshow(image)plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')plt.pause(0.001)class FaceLandmarksDataset(Dataset):"""人脸标记数据集"""def __init__(self, csv_file, root_dir, transform=None):"""csv_file(string):带注释的csv文件的路径。root_dir(string):包含所有图像的目录。transform(callable, optional):一个样本上的可用的可选变换"""self.landmarks_frame = pd.read_csv(csv_file)self.root_dir = root_dirself.transform = transformdef __len__(self):return len(self.landmarks_frame)def __getitem__(self, idx):img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])image = io.imread(img_name)landmarks = self.landmarks_frame.iloc[idx, 1:]landmarks = np.array([landmarks])landmarks = landmarks.astype('float').reshape(-1, 2)sample = {'image': image, 'landmarks': landmarks}if self.transform:sample = self.transform(sample)return sampleclass Rescale(object):"""将样本中的图像重新缩放到给定大小Args:output_size(tuple或int):所需的输出大小。如果是元组,则输出为与output_size匹配。如果是int,则匹配较小的边缘到output_size保持横纵比相同"""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))self.output_size =output_sizedef __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]if isinstance(self.output_size, int):if h > w:new_h, new_w = self.output_size * h / w, self.output_sizeelse:new_h, new_w = self.output_size, self.output_size * w / helse:new_h, new_w = self.output_sizenew_h, new_w = int(new_h), int(new_w)img = transform.resize(image, (new_h, new_w))# h and w are swapped for landmarks because for images,# x and y axes are axis 1 and 9 respectivelylandmarks = landmarks * [new_w / w, new_h / h]return {'image': img, 'landmarks': landmarks}class RandomCrop(object):"""随机裁剪样本中的图像Args:output_size(tuple或int):所需的输出大小,如果是int, 方形裁剪是"""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))if isinstance(output_size, int):self.output_size = (output_size, output_size)else:assert len(output_size) == 2self.output_size = output_sizedef __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]new_h, new_w = self.output_sizetop = np.random.randint(0, h-new_h)left = np.random.randint(0, w-new_w)image = image[top: top+new_h, left:left+new_h]landmarks = landmarks - [left, top]return {'image': image, 'landmarks': landmarks}class ToTensor(object):"""将样本中的ndarrays转换为Tensors"""def __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']"""交换颜色轴原因numpy包的图片时H*W*C 而torch包的图片是 C*H*W"""image = image.transpose((2, 0, 1))return {'image': torch.from_numpy(image),'landmarks': torch.from_numpy(landmarks)}# 辅助功能:显示批次
def show_landmark_batch(sample_batched):"""show image with landmarks for a batch of samples"""images_batch, landmarks_batch = sample_batched['image'], sample_batched['landmarks']batch_size = len(images_batch)im_size = images_batch.size(2)grid_border_size = 2grid = utils.make_grid(images_batch)plt.imshow(grid.numpy().transpose((1, 2, 0)))for i in range(batch_size):plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i +1) * grid_border_size,landmarks_batch[i, :, 1].numpy() + grid_border_size, s=10, marker='.', c='r')plt.title('Batch from dataloader')if __name__ == '__main__':# 获取图片并可视化部分图片# 数据变换及torchvision.transforms.Compose组合操作face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',root_dir='data/faces/')scale = Rescale(256)crop = RandomCrop(128)composed = transforms.Compose([Rescale(256),RandomCrop(224)])# 在样本上应用上述的变换fig = plt.figure()sample = face_dataset[65]print('数据变换及torchvision.transforms.Compose组合操作')for i, tsfrm in enumerate([scale, crop, composed]):transformed_sample = tsfrm(sample)ax = plt.subplot(1, 3, i + 1)plt.tight_layout()ax.axis('off')ax.set_title(type(tsfrm).__name__)show_landmarks(**transformed_sample)plt.show()plt.pause(0.5)# 迭代数据集transformed_dataset = FaceLandmarksDataset(csv_file="data/faces/face_landmarks.csv",root_dir="data/faces/",transform=transforms.Compose([Rescale(256), RandomCrop(224), ToTensor()]))for i in range(len(transformed_dataset)):sample = transformed_dataset[i]print(i, sample['image'].size(), sample['landmarks'].size())if i == 3:breakdataloader = DataLoader(transformed_dataset, batch_size=4,shuffle=True, num_workers=4)print('迭代数据集,batch_size=4')for i_batch, sample_batched in enumerate(dataloader):print(i_batch, sample_batched['image'].size(),sample_batched['landmarks'].size())if i_batch == 3:plt.figure()show_landmark_batch(sample_batched)plt.axis('off')plt.ioff()plt.show()breakplt.pause(0)

运行结果如下:

数据变换及torchvision.transforms.Compose组合操作
0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])
迭代数据集,batch_size=4
0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

后记

上面的例子用函数实现了数据的部分预处理操作,主要包括使用数据集类(datasets),转换(transform)和数据加载器(DataLoader)。torchvision包提供了畅通的数据集类datasets和转换transforms,可能并不需要我们自己构造这些类。torchvision中海油一个更常用的数据集类ImageFolder。它假定了数据集是以如下方式构造的:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
...
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

其中'ants','bees'等是分类标签。在PIL.Image中可以使用类似的转换transform,例如RandHorizontalFlip,Scale。利用这些可以按如下方式创建一个数据集加载器(dataloader),以hymenoptera_data(提取码:2rvf)数据集为例:

#!/usr/bin/env torch
# -*- coding:utf-8 -*-
# @Time  : 2021/2/4, 1:06
# @Author: Lee
# @File  : data_transform.pyimport os, torch
import pandas as pd
from skimage import io, transform  # 用于图像的IO和变换
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasetsimport warnings
warnings.filterwarnings("ignore")plt.ion()  # interactive modedata_transform = transforms.Compose([transforms.RandomSizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0,225])])
hymenpotera_dataset = datasets.ImageFolder(root='data/hymenoptera_data/train',transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenpotera_dataset,batch_size=4, shuffle=True,num_workers=4)print(dataset_loader)

打印结果如下:

<torch.utils.data.dataloader.DataLoader object at 0x000001C37596F608>

Debug可获得如下窗口:

Pytorch基础(三)数据集加载及预处理相关推荐

  1. PyTorch基础(四)-----数据加载和预处理

    前言 之前已经简单讲述了PyTorch的Tensor.Autograd.torch.nn和torch.optim包,通过这些我们已经可以简单的搭建一个网络模型,但这是不够的,我们还需要大量的数据,众所 ...

  2. PyTorch 系列 | 数据加载和预处理教程

    图片来源:Unsplash,作者:Damiano Baschiera 2019 年第 66 篇文章,总第 90 篇文章 本文大约 8000 字,建议收藏阅读 原题 | DATA LOADING AND ...

  3. pytorch dataset自定义_PyTorch 系列 | 数据加载和预处理教程

    原题 | DATA LOADING AND PROCESSING TUTORIAL 作者 | Sasank Chilamkurthy 原文 | https://pytorch.org/tutorial ...

  4. pytorch dataset自定义_PyTorch | 数据加载及预处理教程

    原题 | DATA LOADING AND PROCESSING TUTORIAL 作者 | Sasank Chilamkurthy 译者 | kbsc13("算法猿的成长"公众号 ...

  5. pytorch dataset dataloader_PyTorch(五)——数据的加载和预处理

    前言 PyTorch通过torch.utils.data对一般的常用数据进行封装,可以很容易地实现多线程数据预读和批量加载.torchvision已经预先实现了常用的图像数据集,包括CIFAR-10. ...

  6. pytorch中的数据加载(dataset基类,以及pytorch自带数据集)

    目录 pytorch中的数据加载 模型中使用数据加载器的目的 数据集类 Dataset基类介绍 数据加载案例 数据加载器类 pytorch自带的数据集 torchvision.datasets MIN ...

  7. 【自然语言处理入门系列】加载和预处理数据-以Cornell Movie-Dialogs Corpus数据集为例

    [自然语言处理入门系列]加载和预处理数据-以Cornell Movie-Dialogs Corpus数据集为例 Author: Yirong Chen from South China Univers ...

  8. 【 数据集加载 DatasetDataLoader 模块实现与源码详解 深度学习 Pytorch笔记 B站刘二大人 (7/10)】

    数据集加载 Dataset&DataLoader 模块实现与源码详解 深度学习 Pytorch笔记 B站刘二大人 (7/10) 模块介绍 在本节中没有关于数学原理的相关介绍,使用的数据集和类型 ...

  9. Pytorch中的数据加载

    Pytorch中的数据加载 1. 模型中使用数据加载器的目的 在前面的线性回归模型中,使用的数据很少,所以直接把全部数据放到模型中去使用. 但是在深度学习中,数据量通常是都非常多,非常大的,如此大量的 ...

最新文章

  1. linux 进入单用户模式修改root密码
  2. [恢]hdu 1860
  3. TensorFlow文件操作
  4. 开发里程碑计划_项目里程碑你真的会用了吗?(干货)
  5. SQL语句(DQL)
  6. 深入理解javascript函数参数
  7. (38)FPGA原语设计(BUFH)
  8. android9 前台服务通知_Android通知概览
  9. android开发:Android 中自定义属性(attr.xml,TypedArray)的使用
  10. linux目录蓝色,前言linux系统默认目录颜色是蓝色的,在黑背景下看不清楚,可以通过以下2种方法修改ls查看的颜色。方法:1、拷贝/etc/DIR_COLORS文件为...
  11. Lucene2.4.0一般查询结果过滤与排行
  12. 15个基本不定积分公式和分类基本积分表
  13. 串口服务器接无线网桥,AB7006-HMS串口服务器、Anybus-M主站、Anybus-S从站接口模块...
  14. MapGuide open source开发心得一:简介
  15. 通过S2B2C供应链电商平台网站解决方案,实现大宗商品万亿产业数字化转型
  16. CTF-代码审计(2)
  17. git push时 please tell me who you are 或 git fatal: empty ident name (for <>) not llowed
  18. 数据库中什么是内联接、左外联接、右外联接?
  19. 国产单机《我的武林江湖》v1.1.159
  20. 京东轮播图片的静态页面CSS3

热门文章

  1. ssm如何支持热部署_最新Spring Boot实战文档推荐:项目搭建+配置+SSM整合
  2. ES6入门之对象的扩展
  3. 黑客攻击公司化:网络犯罪也有商业模式也有CEO
  4. 判断Sbo的Matrix中是否存在相同数据行
  5. 教你如何配置IIS Rewrite模块写规则
  6. IT人母亲的美国之行(8)
  7. nginx linux 下载安装,Linux(CentOS)下载安装Nginx并配置
  8. 鸿蒙会取代emui,华为称自家手机运行鸿蒙系统正在推进 未来会取代安卓吗?
  9. linux dd 随机文件,Linux之dd工具
  10. python xampp mysql_让XAMPP支持Python及Django