使用pytorch训练图像分类模型需要加载数据集,关于train_set,train_loader的写法介绍如下。

首先参考MNIST数据集的train_set,train_loader的写法。

# 训练集
trainset = torchvision.datasets.MNIST(root='./datasets/ch08/pytorch',     # 选择数据的根目录train=True,download=True,    # 不从网络上download图片transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)
# 测试集
testset = torchvision.datasets.MNIST(root='./datasets/ch08/pytorch',     # 选择数据的根目录train=False,download=True,    # 不从网络上download图片transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)

MNIST直接使用的是torchvision.dataset.MNIST类加载数据集,train_loader的使用都是一样的。所以核心问题是如何书写这个Dataset类。

Dataset的核心是需要返回包含所有数据集的列表,以及每个数据集对应的标签。
我的数据集存放方式为:train_print目录下有10个文件夹,里面有10个类别的数据集。

比如打开‘01’文件夹如下所示。

我采用的思路是,将这10个文件夹所有图片的绝对路径存放到一个dataset-text.txt文件中。下一步遍历该dataset-text.txt文件。这个文件长这样。

很容易能够发现标签存放在倒数第二个位置啦。每张图片的标签就存放在文件绝对路径中,也就是文件夹‘01’的文件名。
因此,有了这个txt文件,我们遍历的时候,读取一行路径,将该路径存放到images中,把倒数第二个位置上的标签存放到labels中,核心工作就完成了。Dataset类的编写没有固定的,但核心都是不管你通过什么方法,把数据集和对应的标签存放到images,labels中。如果你仔细看到这里还不明白,就私信我。因为说实话这个问题我也困扰了很久,因为我也是入门。

class MyDataset(Dataset):def __init__(self, dataset_path, num_class, transforms=None):super(MyDataset,self).__init__()images = []labels = []txt_path = self.dataset2txt(dataset_path,num_class)with open(txt_path, 'r') as f:for line in f:if int(line.split('/')[-2]) > args.num_class:breakline = line.strip('\n')images.append(line)labels.append(int(line.split('/')[-2]))self.images = imagesself.labels = labelsself.transforms = transformsdef __getitem__(self, index):image = Image.open(self.images[index])label = self.labels[index]if self.transforms is not None:image = self.transforms(image)return image, labeldef __len__(self):return len(self.labels)def dataset2txt(self,dataset_path, class_num=None):'''transform dataset into a txt file which contain every Image:param In_path: path of dataset:param num_class: classes:return:path of txt file'''# 1.创建文件# 一下两行代码目的是与数据集同级目录下新建dataset-text.txt文件txt_path = os.path.abspath(os.path.dirname(dataset_path))txt_path = txt_path + '/dataset-text.txt'# 删除已经存在的文件,要保证每次操作的文件是一个空的txt文件if os.path.exists(txt_path):os.remove(txt_path)f = open(txt_path, 'w')f.close()# 2.写入文件# 打开数据集,将主目录下所有文件夹放入list中dirs = os.listdir(dataset_path)# 将文件夹按从小到大排序,文件夹的名字是按照数字命名的,01,02,03...dirs.sort()# 打开第二级每个文件夹,将并将每个文件的绝对路径写入到上面新建的txt文件for i, dir in enumerate(dirs):file = os.path.abspath(dataset_path) + '/' + dirs[i]DIRLIST = os.listdir(file)for j, d in enumerate(DIRLIST):content = file + '/' + d + '\n'# 每次执行前一定确保要写入的文件是空的with open(txt_path, 'a') as f:f.write(content)return txt_path

剩下的就是书写train_set,train_loader.

train_set = MyDataset(dataset_path=args.dataset, num_class=args.num_class, transforms=transform)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)

卷积神经网络-加载数据集相关推荐

  1. [转载] 卷积神经网络做mnist数据集识别

    参考链接: 卷积神经网络在mnist数据集上的应用 Python TensorFlow是一个非常强大的用来做大规模数值计算的库.其所擅长的任务之一就是实现以及训练深度神经网络. 在本教程中,我们将学到 ...

  2. 【深度学习】——利用pytorch搭建一个完整的深度学习项目(构建模型、加载数据集、参数配置、训练、模型保存、预测)

    目录 一.深度学习项目的基本构成 二.实战(猫狗分类) 1.数据集下载 2.dataset.py文件 3.model.py 4.config.py 5.predict.py 一.深度学习项目的基本构成 ...

  3. python_torch_加载数据集_构建模型_构建训练循环_保存和调用训练好的模型

    以下代码均来自bilibili:[适用于初学者的Pytorch编程教学] 以下为完整代码,复制即可运行. import torch import time import json import tor ...

  4. R语言构建xgboost模型:使用xgb.DMatrix保存、加载数据集、使用getinfo函数抽取xgb.DMatrix结构中的数据

    R语言构建xgboost模型:使用xgb.DMatrix保存.加载数据集.使用getinfo函数抽取xgb.DMatrix结构中的数据 目录

  5. pytorch 入门学习加载数据集-8

    pytorch 入门学习加载数据集 import torch import numpy as np import torchvision import numpy as np from torch.u ...

  6. Pytorch深度学习(五):加载数据集以及mini-batch的使用

    Pytorch深度学习(五):加载数据集以及mini-batch的使用 参考B站课程:<PyTorch深度学习实践>完结合集 传送门:<PyTorch深度学习实践>完结合集 一 ...

  7. pytorch创建自己的Dataset加载数据集

    文章目录 创建一个类并继承torch.utils.data.dataset.Datase类 创建__getitem__方法 加载数据集 创建一个类并继承torch.utils.data.dataset ...

  8. python加载数据集卡住 dmesg报错Nvidia xid31

    在一次运维中发现客户加载数据集会卡住,物理机总共是4块显卡.使用k8s独占显卡进行任务训练,其中有三块显卡在跑任务训练加载数据集时卡住,同时查看dmesg报错 (xid 31). [Tue Apr 1 ...

  9. URLError: <urlopen error [Errno 11004] getaddrinfo failed>关于使用seabron加载数据集报错的解决方案

    在使用seaborn加载内置数据集时,出现以下错误: dataset = sns.load_dataset("iris") dataset.head() 解决方案: 一.原因需要连 ...

  10. FlexCell控件初始化以及加载数据集[原创]

    '================================写在之前的话 抱歉,一直没有时间,所以FlexCell作者给我的几种加载数据集方法的代码一直没有发出来. 同时再次感谢FlexCell ...

最新文章

  1. BasicLSTMCell中num_units参数解释
  2. 【 Notes 】WLLS Algorithm of TOA - Based Positioning (include the two - step WLS estimator)
  3. python实现表格_零基础小白怎么用Python做表格?
  4. linux find d,Linux find命令傻瓜入门
  5. 6.29 Vue 第二天 学习笔记
  6. leetcode-最大子序和(动态规划讲解)
  7. 30销售是让用户开心的购买和消费
  8. 二叉搜索树(HDU3791)
  9. 小米电视4a刷鸿蒙,小米电视4A 删除内置应用及其去广告攻略
  10. zabbix登陆拒绝报没有权限
  11. 如何在地址栏显示图标
  12. 西瓜视频 iOS 播放器技术重构
  13. html 倒计时小工具
  14. matlab解常微分方程——符号解法
  15. 简历学习课程:1-9课
  16. VSCode插件,TODO标记
  17. c语言编程仓鼠吃豆子,动态规划之仓鼠吃豆子 - osc_8quu62cg的个人空间 - OSCHINA - 中文开源技术交流社区...
  18. Opencv学习——LSD直线检测
  19. 12.利用API抓取数据
  20. 安卓系统无线投屏到win10

热门文章

  1. 域猫(域名分享平台)
  2. 【信号处理第十二章】转置卷积
  3. My thoughts after NOIP 2018(2)
  4. mysql中计算日期整数差
  5. tensorflow中slim模块api介绍
  6. 第一次冲刺-个人工作总结06
  7. python基础(1)——简介与安装
  8. 字符串部分函数的实现
  9. mongodb数据库命令操作(转)
  10. selenium窗口截图操作