背景
当不同类型数据的数量差别巨大的时候,比如猫有200张训练图片,而狗有2000张,很容易出现模型只能学到狗的特征,导致准确率无法提升的情况。

这时候,一种可行的方法就是对原始数据集进行采样,从而生成猫、狗图片数量接近的新数据集。这个新数据集中可能猫、狗图片都各有500张,其中猫的图片有一部分重复的,而狗的2000张图片中有一部分没有被采样到,但是这时候新数据集的数据分布是均衡的,就可以比较好的训练了。

操作方法
我们知道pytorch训练一般都是用的DataLoader加载数据的,我们可以通过给Dataloader传入一个sampler的采样器进行采样操作。

train_loader = DataLoader( train_dataset, batch_size=256, num_workers=2, sampler=sampler)

采样器sampler有多种,大家可以根据自己需要研究一下,这里我们使用一个按权重采样的WeightedRandomSampler。其作用是:我们可以人为的给每张图片定一个被抽取到的概率,一般每一类的所有图片的概率可以一样,然后就按每个图片的这个概率对整个数据集进行重新采样。

比如:猫只有200张图片,我们设置取到每张猫的图片的概率为1/200,而狗有2000张图片,我们设置取到每张狗的概率为1/2000。这样虽然狗的图片比较多,但我们取到猫和狗的概率是一样的,只是猫会有一些重复,而狗有一些不会取到,最终形成的新数据集就平衡了。

参考 https://blog.csdn.net/tyfwin/article/details/108435756


注意
注意上图的replacement参数,为True表示有放回的采样,也就是我们上边说的那种采样,有部分数据重复,有部分数据没有出现;为False表示不放回的采样,即采样后的数据集跟原来一样,只是内部数据的顺序有些变化,概率大的可能会在前边,这主要作用于有序的数据。num_samples定义采样的次数,也即采样后的数据集数目,一般设为跟原来一样。

代码

参考 https://www.cnblogs.com/king-lps/p/11004653.html

# 定义每个类别采样的权重,这个只做参考,可以根据自己需要随便定义
target = train_dataset.targets
class_sample_count = np.array([len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])
samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
# 读取原始数据为datasets对象
dataset_train = datasets.ImageFolder(traindir)                                                                                                         # 在DataLoader的时候传入采样器即可
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, sampler = sampler)

pytorch对数据集进行重新采样相关推荐

  1. Pytorch自定义数据集

    简述 Pytorch自定义数据集方法,应该是用pytorch做算法的最基本的东西. 往往网络上给的demo都是基于torch自带的MNIST的相关类.所以,为了解决使用其他的数据集,在查阅了torch ...

  2. PyTorch Upsample() 函数实现上采样

    PyTorch Upsample() 函数实现上采样 import torch import torch.nn as nninput = torch.arange(1, 5, dtype=torch. ...

  3. ML之FE:数据随机抽样之利用pandas的sample函数对超大样本的数据集进行随机采样,并另存为csv文件

    ML之FE:数据随机抽样之利用pandas的sample函数对超大样本的数据集进行随机采样,并另存为csv文件 目录 数据随机抽样之利用pandas的sample函数对超大样本的数据集进行随机采样,并 ...

  4. DataScience:对严重不均衡数据集进行多种采样策略(随机过抽样、SMOTE过采样、SMOTETomek综合采样、改变样本权重等)简介、经验总结之详细攻略

    DataScience:对严重不均衡数据集进行多种采样策略(随机过抽样.SMOTE过采样.SMOTETomek综合采样.改变样本权重等)简介.经验总结之详细攻略 目录

  5. 从零开始的图像语义分割:FCN快速复现教程(Pytorch+CityScapes数据集)

    从零开始的图像语义分割:FCN复现教程(Pytorch+CityScapes数据集) 前言 一.图像分割开山之作FCN 二.代码及数据集获取 1.源项目代码 2.CityScapes数据集 三.代码复 ...

  6. pytorch 读取数据集(LiTS-肝肿瘤分割挑战数据集)

    pytorch 读取数据集 我的数据集长这样: xx.png和xx_mask.png是对应的待分割图像和ground truth 读取数据集 数据集对象被抽象为Dataset类,实现自定义的数据集需要 ...

  7. pytorch自定义数据集DataLoder

    pytorch官方例程: DATA LOADING AND PROCESSING TUTORIAL torch.utils.data.Dataset 是dataset的抽象类,我们可以同过继承Data ...

  8. 数据集制作_轻松学Pytorch自定义数据集制作与使用

    点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 大家好,这是轻松学Pytorch系列的第六篇分享,本篇你将学会如何从头开始制作自己的数据集,并通过DataLo ...

  9. 【小白学习PyTorch教程】十七、 PyTorch 中 数据集torchvision和torchtext

    @Author:Runsen 对于PyTorch加载和处理不同类型数据,官方提供了torchvision和torchtext. 之前使用 torchDataLoader类直接加载图像并将其转换为张量. ...

最新文章

  1. 前端怎么通过后台来判断已读状态_类目图片支持商家后台设置 | 前端设计
  2. 表达式树 php,Linux_LINQ学习笔记:表达式树,构建查询表达式 本节中, 我们 - phpStudy...
  3. Java IO: FileReader和FileWriter
  4. Linux运维基础命令笔试题
  5. 创新!谷歌大改Transformer注意力
  6. Msql自学日志01---基本操作增,删,改,查,建
  7. “嘲羊群众”词条视频惹怒粉丝 百度知道向张艺兴道歉...
  8. 面试中 项目遇见的难点答案_5月6日周一晚八点CCtalk直播2019年江苏省考公务员面试冲刺类型题难点解析突破举一反三...
  9. 托管项目到github
  10. Salesforce和SAP HANA的元数据访问加速
  11. 智能合约语言Solidity教程系列2 - 地址类型介绍
  12. VMware网络配置基础
  13. 环境化学试题及答案大全
  14. 我的大脑越来越喜欢那些碎片化的、不用思考的文章了!
  15. iOS App上架遇到的错误(ERRORITMS-90096: )
  16. 电脑的学名为电子计算机
  17. nodemon:运行提示错误:无法加载文件 xxxx
  18. Qt开发北斗定位系统融合百度地图API及Qt程序打包发布
  19. Modern Data Stack 下 Data Integration 生态(下)
  20. 【ICCV19 超分辨】Deep SR-ITM: Joint Learning of Super-Resolution and Inverse Tone-Mapping for 4K UHD HDR

热门文章

  1. 验证控件jQuery Validation Engine调用外部函数验证
  2. 排查 CI Unable to load the requested file
  3. [Git]Git远程仓库
  4. 通过命令行编译器来编译运行程序
  5. 一个站点存在多个web.config时如何管理?
  6. 【UDP通过多线程改进,在一个窗口中同时接收又发送】
  7. 【Leetcode_easy】1078. Occurrences After Bigram
  8. html/js/css资源
  9. EDS之后的block
  10. Spring Data JPA教程,第一部分: Configuration(翻译)