pytorch构建自定义dataset及一些经验
目录
- 1.前言
- 2. 构建 datasets
- 3.Dataloader的使用
- 4. 后记
1.前言
此篇文章是作者在搭建数据集碰壁,改错过程中,不断疯狂地CSDN、博客园、torch官网总结出来的,希望能带给读者一些帮助。
先理清一下逻辑脉络:我们要 构建DataLoader 将处理过(如shuffle,batch)的数据提供给要训练或验证的模型, 而Dataloader的数据需要从数据集中获取,所以最先要构建好数据集。
2. 构建 datasets
在实现自定义数据集前,首先要了解一下torch.utils.data.Dataset
这个类。 官方文档
下面是我结合源码和官方文档给出的理解。
一言以蔽之,想用索引的方式来让DataLoader获取数据的话,你的自定义dataset需要继承torch.utils.data.Dataset
。另一种Iterable-style的 datasets可以上官网自己查看。
默认情况下,DataLoader构造一个生成整数索引的索引采样器(sampler)来获取Dataset里面的数据。
而继承Dataset这个类时,它要求你必须重写__getitem__(self, index)
、 __len__(self)
两个方法,前者通过提供索引返回数据,也就是提供 DataLoader获取数据的方式;后者返回数据集的长度,DataLoader依据 len 确定自身索引采样器的长度。
所以需要我们动手做的是搭建好索引到数据文件之间的映射,然后继承Dataset类,重写两个函数。最后让DataLoader能够通过索引采样器来取数据。
下面给出框架:
import torch.utils.data as data
class Mydataset(data.Dataset):def __init__(self):passdef __getitem__(self, index):pass return ##返回你要提供给Dataloader的一个样本(数据+标签)def __len__(self):return ## 返回数据集的长度
tips: 可以定义一个函数:
遍历原始数据文件,将文件的路径存到一个列表中。如果一个文件只含一个样本,那么列表长度就是数据集长度了。然后将用列表的索引作为DataLoader读取Dataset使用的索引。 在getitem时,用索引获取列表中的文件路径,提取并返回数据,如果是小的数据集的话可能还需要进行在线数据增强来减少过拟合的可能性。可以看下实例:
自定义Dataset例子
3.Dataloader的使用
Dataloader类:提供了加载数据集的多种方式
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, *, prefetch_factor=2, persistent_workers=False)
啰嗦一下Dataloader获取数据的过程(详解可见后记中的2):
① Dataloader根据你提供的dataset,生成一个长度为数据集大小的采样器(sampler)。
② 若未自定义sampler,会根据你是否打乱(shuffle), 选择顺序采样器(sequential sampler)或随机采样器(random sampler)。
③ sampler首先根据Dataset的大小n形成一个可迭代的序号列表[0~n-1]。
④ batch_sampler 根据DataLoader的batch_size将sampler提供的序列划分成多个batch大小的可迭代序列组,drop_last参数决定是否保留最后一组。
下面就讲一下常见的几个参数。
最重要的是dataset,它是Dataloader获取数据的来源。Dataloader支持两种数据集类型,一种是上面讲述的 Map-style datasets(索引到样本的映射式的数据集),另一种为iterable-style datasets(迭代式的数据集)。
batch_size: 设置批训练的大小。
shuffle: 是否打乱数据集。 训练集正常是要打乱的,为的是尽可能避免过拟合。简单来说,就是不让训练数据的分布规律成为一种特征被牛X的神经网络学习,从而使模型在泛化能力上大打折扣,也即过拟合。
num_workers: 设置num_workers个线程,batch_sampler将指定batch分配给指定worker,worker将它负责的batch加载进RAM。若Dataloader需要的指定batch若不存在,就让每个worker继续加载batch到内存,直至找到指定batch。 好处是可以提前将一些batch载入内存,加快寻batch速度,但是会导致内存开销过大。
pin_memory:设置是否使用锁页内存。True的话,在锁页内存的Tensors会直接映射一份到GPU显存(锁页内存)上,省掉了数据从CPU到RAM再到GPU的时间数据传输时间。但是吃内存啊QAQ。
drop_last: 设定为 True时, 如果数据集大小不能被batch_size整除的话, 将丢掉最后一个不完整的batch。
tips:
- num_workers和pin_memory的设置要根据自己机子或服务器的实际情况,设置好的话可以提高GPU利用率。num_workers默认情况下是0,你不设置的话,Dataloader就从RAM找batch,找不到再去慢悠悠地加载指定batch,从而导致GPU利用率过低。
上图是跑项目时设置了num_workers为32的情况。第一次调Dataloader时我没设置num_workers,虽然GPU内存占用率(Memory usage)很高, 但是GPU利用率(GPU-Util)是0,那时候纳闷了好久,这也是此文诞生的原因吧。
4. 后记
若有错误,请指正。
最后,十分感谢下方大佬们文章给予的帮助!orz
- 董小姐~:pytorch技巧 五: 自定义数据集 torch.utils.data.DataLoader 及Dataset的使用
- 一步徐龙的浪:Sampler, DataLoader和数据batch的形成
- 小塞: 【Q&A】pytorch中的worker如何工作的
- 吨吨不打野:GPU显存占满利用率GPU-util为0
pytorch构建自定义dataset及一些经验相关推荐
- 在pytorch中自定义dataset读取数据2021-1-8学习笔记
在pytorch中自定义dataset读取数据 utils import os import json import pickle import randomimport matplotlib.pyp ...
- 深度学习-Pytorch:项目标准流程【构建、保存、加载神经网络模型;数据集构建器Dataset、数据加载器DataLoader(线性回归案例、手写数字识别案例)】
1.拿到文本,分词,清晰数据(去掉停用词语): 2.建立word2index.index2word表 3.准备好预训练好的word embedding 4.做好DataSet / Dataloader ...
- Pytorch自定义Dataset和DataLoader去除不存在和空的数据
Pytorch自定义Dataset和DataLoader去除不存在和空的数据 [源码GitHub地址]:https://github.com/PanJinquan/pytorch-learning-t ...
- 使用pytorch自定义DataSet,以加载图像数据集为例,实现一些骚操作
使用pytorch自定义DataSet,以加载图像数据集为例,实现一些骚操作 总共分为四步 构造一个my_dataset类,继承自torch.utils.data.Dataset 重写__getite ...
- rcnn代码实现_轻松学Pytorch实现自定义对象检测器
点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 大家好,今天来继续更新轻松学Pytorch专栏,这个是系列文章我会一直坚持写下去的,希望大家转发.点赞.留言支 ...
- PyTorch基础-自定义数据集和数据加载器(2)
处理数据样本的代码可能会变得混乱且难以维护: 理想情况下,我们想要数据集代码与模型训练代码解耦,以获得更好的可读性和模块化.PyTorch 域库提供了许多预加载的数据(例如 FashionMNIST) ...
- 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下
文章目录 1 测试鉴别器 2 建立生成器 3 测试生成器 4 训练生成器 5 使用生成器 6 内存查看 上一节,我们已经建立好了模型所必需的鉴别器类与Dataset类. 使用PyTorch构建GAN生 ...
- Pytorch之DataLoader Dataset、datasets、models、transforms的认识和学习
文章目录 利用PyTorch框架来开发深度学习算法时几个基础的模块 Dataset & DataLoader 基础概念 自定义数据集 1 读取自定义数据集 1 自定义数据集 2 自定义数据集3 ...
- ASP.NET性能优化之构建自定义文件缓存
ASP.NET的输出缓存(即静态HTML)在.NET4.0前一直是基于内存的.这意味着如果我们的站点含有大量的缓存,则很容易消耗掉本机内存.现在,借助于.NET4.0中的OutputCacheProv ...
最新文章
- 【视觉SLAM14讲】ch3课后题答案
- vb.net与matlab的混合编程
- [15] 星星(Star)图形的生成算法
- C语言处理字符串及内存操作
- 为什么建议大家使用 Linux 开发
- 到底能不能做一辈子的程序员——大龄程序员将何去何从
- Windows Phone中Wallet钱包的使用
- php dimage加上域名,PHP全功能无变形图片裁剪操作类与用法示例
- Ubuntu查看系统任务管理器(cpu+内存资源占用)+查看虚拟机分配核心数
- 最详细的linux下的磁盘分区及格式化
- python操作界面_Python使用PyQt5的Designer工具创建UI界面
- LeetCode 797. 所有可能的路径(DFS)
- php容器原理,容器与依赖注入的原理
- 页面发送请求到后台报错“Empty or invalid anti forgery header token.”问题解决
- Linux 如何打开pyo文件,Python的文件类型
- C语言程序设计教程(第三版)课后习题8.2
- vue 获取当前本机ip_Vue项目启动时自动获取本机IP地址
- 生活里不能只有苦涩,不堪,适当露出一条缝隙,让光透进来
- Windows7 VS2015 下编译 PythonQt3.2
- 建武28a对讲机最大距离_健伍TH-26A,TG-28A,TH-28A和TK208对讲机检修实例说明