卷积神经网络-加载数据集
使用pytorch训练图像分类模型需要加载数据集,关于train_set,train_loader的写法介绍如下。
首先参考MNIST数据集的train_set,train_loader的写法。
# 训练集
trainset = torchvision.datasets.MNIST(root='./datasets/ch08/pytorch', # 选择数据的根目录train=True,download=True, # 不从网络上download图片transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)
# 测试集
testset = torchvision.datasets.MNIST(root='./datasets/ch08/pytorch', # 选择数据的根目录train=False,download=True, # 不从网络上download图片transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)
MNIST直接使用的是torchvision.dataset.MNIST类加载数据集,train_loader的使用都是一样的。所以核心问题是如何书写这个Dataset类。
Dataset的核心是需要返回包含所有数据集的列表,以及每个数据集对应的标签。
我的数据集存放方式为:train_print目录下有10个文件夹,里面有10个类别的数据集。
比如打开‘01’文件夹如下所示。
我采用的思路是,将这10个文件夹所有图片的绝对路径存放到一个dataset-text.txt文件中。下一步遍历该dataset-text.txt文件。这个文件长这样。
很容易能够发现标签存放在倒数第二个位置啦。每张图片的标签就存放在文件绝对路径中,也就是文件夹‘01’的文件名。
因此,有了这个txt文件,我们遍历的时候,读取一行路径,将该路径存放到images中,把倒数第二个位置上的标签存放到labels中
,核心工作就完成了。Dataset类的编写没有固定的,但核心都是不管你通过什么方法,把数据集和对应的标签存放到images,labels中。如果你仔细看到这里还不明白,就私信我。因为说实话这个问题我也困扰了很久,因为我也是入门。
class MyDataset(Dataset):def __init__(self, dataset_path, num_class, transforms=None):super(MyDataset,self).__init__()images = []labels = []txt_path = self.dataset2txt(dataset_path,num_class)with open(txt_path, 'r') as f:for line in f:if int(line.split('/')[-2]) > args.num_class:breakline = line.strip('\n')images.append(line)labels.append(int(line.split('/')[-2]))self.images = imagesself.labels = labelsself.transforms = transformsdef __getitem__(self, index):image = Image.open(self.images[index])label = self.labels[index]if self.transforms is not None:image = self.transforms(image)return image, labeldef __len__(self):return len(self.labels)def dataset2txt(self,dataset_path, class_num=None):'''transform dataset into a txt file which contain every Image:param In_path: path of dataset:param num_class: classes:return:path of txt file'''# 1.创建文件# 一下两行代码目的是与数据集同级目录下新建dataset-text.txt文件txt_path = os.path.abspath(os.path.dirname(dataset_path))txt_path = txt_path + '/dataset-text.txt'# 删除已经存在的文件,要保证每次操作的文件是一个空的txt文件if os.path.exists(txt_path):os.remove(txt_path)f = open(txt_path, 'w')f.close()# 2.写入文件# 打开数据集,将主目录下所有文件夹放入list中dirs = os.listdir(dataset_path)# 将文件夹按从小到大排序,文件夹的名字是按照数字命名的,01,02,03...dirs.sort()# 打开第二级每个文件夹,将并将每个文件的绝对路径写入到上面新建的txt文件for i, dir in enumerate(dirs):file = os.path.abspath(dataset_path) + '/' + dirs[i]DIRLIST = os.listdir(file)for j, d in enumerate(DIRLIST):content = file + '/' + d + '\n'# 每次执行前一定确保要写入的文件是空的with open(txt_path, 'a') as f:f.write(content)return txt_path
剩下的就是书写train_set,train_loader.
train_set = MyDataset(dataset_path=args.dataset, num_class=args.num_class, transforms=transform)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
卷积神经网络-加载数据集相关推荐
- [转载] 卷积神经网络做mnist数据集识别
参考链接: 卷积神经网络在mnist数据集上的应用 Python TensorFlow是一个非常强大的用来做大规模数值计算的库.其所擅长的任务之一就是实现以及训练深度神经网络. 在本教程中,我们将学到 ...
- 【深度学习】——利用pytorch搭建一个完整的深度学习项目(构建模型、加载数据集、参数配置、训练、模型保存、预测)
目录 一.深度学习项目的基本构成 二.实战(猫狗分类) 1.数据集下载 2.dataset.py文件 3.model.py 4.config.py 5.predict.py 一.深度学习项目的基本构成 ...
- python_torch_加载数据集_构建模型_构建训练循环_保存和调用训练好的模型
以下代码均来自bilibili:[适用于初学者的Pytorch编程教学] 以下为完整代码,复制即可运行. import torch import time import json import tor ...
- R语言构建xgboost模型:使用xgb.DMatrix保存、加载数据集、使用getinfo函数抽取xgb.DMatrix结构中的数据
R语言构建xgboost模型:使用xgb.DMatrix保存.加载数据集.使用getinfo函数抽取xgb.DMatrix结构中的数据 目录
- pytorch 入门学习加载数据集-8
pytorch 入门学习加载数据集 import torch import numpy as np import torchvision import numpy as np from torch.u ...
- Pytorch深度学习(五):加载数据集以及mini-batch的使用
Pytorch深度学习(五):加载数据集以及mini-batch的使用 参考B站课程:<PyTorch深度学习实践>完结合集 传送门:<PyTorch深度学习实践>完结合集 一 ...
- pytorch创建自己的Dataset加载数据集
文章目录 创建一个类并继承torch.utils.data.dataset.Datase类 创建__getitem__方法 加载数据集 创建一个类并继承torch.utils.data.dataset ...
- python加载数据集卡住 dmesg报错Nvidia xid31
在一次运维中发现客户加载数据集会卡住,物理机总共是4块显卡.使用k8s独占显卡进行任务训练,其中有三块显卡在跑任务训练加载数据集时卡住,同时查看dmesg报错 (xid 31). [Tue Apr 1 ...
- URLError: <urlopen error [Errno 11004] getaddrinfo failed>关于使用seabron加载数据集报错的解决方案
在使用seaborn加载内置数据集时,出现以下错误: dataset = sns.load_dataset("iris") dataset.head() 解决方案: 一.原因需要连 ...
- FlexCell控件初始化以及加载数据集[原创]
'================================写在之前的话 抱歉,一直没有时间,所以FlexCell作者给我的几种加载数据集方法的代码一直没有发出来. 同时再次感谢FlexCell ...
最新文章
- BasicLSTMCell中num_units参数解释
- 【 Notes 】WLLS Algorithm of TOA - Based Positioning (include the two - step WLS estimator)
- python实现表格_零基础小白怎么用Python做表格?
- linux find d,Linux find命令傻瓜入门
- 6.29 Vue 第二天 学习笔记
- leetcode-最大子序和(动态规划讲解)
- 30销售是让用户开心的购买和消费
- 二叉搜索树(HDU3791)
- 小米电视4a刷鸿蒙,小米电视4A 删除内置应用及其去广告攻略
- zabbix登陆拒绝报没有权限
- 如何在地址栏显示图标
- 西瓜视频 iOS 播放器技术重构
- html 倒计时小工具
- matlab解常微分方程——符号解法
- 简历学习课程:1-9课
- VSCode插件,TODO标记
- c语言编程仓鼠吃豆子,动态规划之仓鼠吃豆子 - osc_8quu62cg的个人空间 - OSCHINA - 中文开源技术交流社区...
- Opencv学习——LSD直线检测
- 12.利用API抓取数据
- 安卓系统无线投屏到win10