PyTorch数据处理工具

概述

PyTorch主要数据处理工具:

  • Dataset:是一个抽象类,其他数据集需要继承这个类,并且覆写其中的两个方法(getitem_、len)。
  • DataLoader:定义一个新的迭代器,实现批量(batch)读取,打乱数据(shuffle)并提供并行加速等功能。
  • random_split:把数据集随机拆分为给定长度的非重叠的新数据集。
  • *sampler:多种采样函数。

视觉处理工具包torchvision包括四个类,功能如下:

  • datasets:提供常用的数据集加载,设计上都是继承自torch.utils.data.Dataset,主要包括MMIST、CIFAR10/100、ImageNet和COCO等。
  • models:提供深度学习中各种经典的网络结构以及训练好的模型(如果选择pretrained=True),包括AlexNet、VGG系列、ResNet系列、Inception系列等。
  • transforms:常用的数据预处理操作,主要包括对Tensor及PIL Image对象的操作。
  • utils:含两个函数,一个是make_grid,它能将多张图片拼接在一个网格中;另一个是save_img,它能将Tensor保存成图片。

utils.data

utils.data包括Dataset和DataLoader。torch.utils.data.Dataset为抽象类。自定义数据集需要继承这个类,并实现两个函数,一个是__len__,另一个是__getitem__,前者提供数据的大小(size),后者通过给定索引获取数据和标签。__getitem__一次只能获取一个数据,所以需要通过torch.utils.data.DataLoader来定义一个新的迭代器,实现batch读取。

import torch
from torch.utils import data
import numpy as np#定义获取数据集的类
#该类继承基类Dataset,自定义一个数据集及对应标签。
class TestDataset(data.Dataset):def __init__(self):self.Data=np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]])#一些由二维向量表示的数据集self.Label=np.asarray([0,1,0,1,2])#这是数据集对应的标签def __getitem__(self,index):#把Numpy转换为Tensortxt=torch.from_numpy(self.Data[index])label=torch.tensor(self.Label[index])return txt,labeldef __len__(self):return len(self.Data)
#获取数据集中数据。
Test=TestDataset()
print(Test[2])#相当于调用__getitem__(2)
print(Test.__len__())

结果

(tensor([2, 1], dtype=torch.int32), tensor(0, dtype=torch.int32))
5

批量处理可使用DataLoader。其格式为:

data.DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=<function default_collate at 0x7f108ee01620>,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,
)

参数说明:

  • dataset:加载的数据集。
  • batch_size:批大小。
  • shuffle:是否将数据打乱。
  • sampler:样本抽样。
  • num_workers:使用多进程加载的进程数,0代表不使用多进程。
  • collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可。
  • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些。
  • drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃。

实例:

test_loader = data.DataLoader(Test,batch_size=2,shuffle=False,num_workers=0)
for i,traindata in enumerate(test_loader):print('i:',i)Data,Label=traindataprint('data:',Data)print('Label:',Label)

结果:

i: 0
data: tensor([[1, 2],[3, 4]], dtype=torch.int32)
Label: tensor([0, 1], dtype=torch.int32)
i: 1
data: tensor([[2, 1],[3, 4]], dtype=torch.int32)
Label: tensor([0, 1], dtype=torch.int32)
i: 2
data: tensor([[4, 5]], dtype=torch.int32)
Label: tensor([2], dtype=torch.int32)

torchvision

torchvision有4个功能模块:model、datasets、transforms和utils,下面将重点介绍transforms及ImageFolder。

transforms

对PIL Image的常见操作:

  • Scale/Resize:调整尺寸,长宽比保持不变。
  • CenterCrop、RandomCrop、RandomSizedCrop:裁剪图片,CenterCrop和RandomCrop在crop时是固定size,RandomResizedCrop则是random size的crop。
  • Pad:填充。
  • ToTensor:把一个取值范围是[0,255]的PIL.Image转换成Tensor。形状为(H,W,C)的Numpy.ndarray转换成形状为[C,H,W],取值范围是[0,1.0]的torch.FloatTensor。
  • RandomHorizontalFlip:图像随机水平翻转,翻转概率为0.5。
  • RandomVerticalFlip:图像随机垂直翻转。
  • ColorJitter:修改亮度、对比度和饱和度。

对Tensor的操作:

  • Normalize:标准化,即,减均值,除以标准差。
  • ToPILImage:将Tensor转为PIL Image。
transforms.Compose([#将给定的 PIL.Image 进行中心切割,得到给定的 size,#size 可以是 tuple,(target_height, target_width)。#size 也可以是一个 Integer,在这种情况下,切出来的图片形状是正方形。transforms.CenterCrop(10),#切割中心点的位置随机选取transforms.RandomCrop(20, padding=0),#把一个取值范围是 [0, 255] 的 PIL.Image 或者 shape 为 (H, W, C) 的 numpy.ndarray,#转换为形状为 (C, H, W),取值范围是 [0, 1] 的 torch.FloatTensortransforms.ToTensor(),#规范化到[-1,1]transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
])

ImageFolder

当文件依据标签处于不同文件下时,如:

可以利用torchvision.datasets.ImageFolder来直接构造出dataset,代码如下:

loader = datasets.ImageFolder(path)
loader = data.DataLoader(dataset)

ImageFolder会将目录中的文件夹名自动转化成序列,当DataLoader载入时,标签自动就是整数序列了。

实例:

from torchvision import transforms,utils
from torchvision import datasets
import torch
import matplotlib.pyplot as plt
%matplotlib inlinemy_trans=transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor()
])
train_data = datasets.ImageFolder('./data/torchvision_data',transform=my_trans)
train_loader = data.DataLoader(train_data,batch_size=8,shuffle=True,)for i_batch,img in enumerate(train_loader):if i_batch == 0:print(img[1])fig = plt.figure()grid = utils.make_grid(img[0])plt.imshow(grid.numpy().transpose((1,2,0)))plt.show()utils.save_image(grid,'test01.png')break

结果:

#打开test01.png文件
from PIL import Image
Image.open('test01.png')

结果:

可视化工具

Tensorboard是Google TensorFlow的可视化工具,它可以记录训练数据、评估数据、网络结构、图像等,并且可以在web上展示,对于观察神经网络训练的过程非常有帮助。PyTorch可以采用tensorboard_logger、visdom等可视化工具,但这些方法比较复杂或不够友好。为解决这一问题,人们推出了可用于PyTorch可视化的新的更强大的工具——tensorboardX。

简介

步骤:

#导入tensorboardX,实例化SummaryWriter类,指明记录日志路径等信息。
from tensorboardX import SummaryWriter
#实例化SummaryWriter,并指明日志存放路径。在当前目录没有logs目录将自动创建。
writer = SummaryWriter(log_dir='logs')
#调用实例
writer.add_xxx()
#关闭writer
writer.close()
  • 如果是Windows环境,log_dir注意路径解析,如:
writer = SummaryWriter(log_dir=r'D:\myboard\test\logs')
  • SummaryWriter的格式为:
SummaryWriter(log_dir=None, comment='', **kwargs)
#其中comment在文件命名加上comment后缀
  • 如果不写log_dir,系统将在当前目录创建一个runs的目录。
  • 调用相应的API接口,接口一般格式为:
add_xxx(tag-name, object, iteration-number)
#即add_xxx(标签,记录的对象,迭代次数)
  • 启动tensorboard服务,cd到logs目录所在的同级目录,在命令行输入如下命令,logdir等式右边可以是相对路径或绝对路径。
tensorboard --logdir=logs --port 6006
#如果是Windows环境,要注意路径解析,如
#tensorboard --logdir=r'D:\myboard\test\logs' --port 6006

用tensorboardX可视化神经网络

import torch.nn.functional as F
import torchvision
from tensorboardX import SummaryWriter
#构建神经网络
class Net(nn.Module):def __init__(self):super(Net,self).__init__()self.conv1 = nn.Conv2d(1,10,kernel_size=5)self.conv2 = nn.Conv2d(10,20,kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320,50)self.fc2 = nn.Linear(50,10)self.bn = nn.BatchNorm2d(20)def forward(self,x):x = F.max_pool2d(self.conv1(x),2)x = F.relu(x) + F.relu(-x)x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)),2))x = self.bn(x)x = x.view(-1,320)x = F.relu(self.fc1(x))x = F.dropout(x,training=self.training)x = self.fc2(x)x = F.softmax(x,dim=1)return x
#把模型保存为graph
#定义输入
input = torch.rand(32,1,28,28)
#实例化神经网络
model = Net()
#将model保存为graph
with SummaryWriter(log_dir='logs',comment='Net' ) as w:w.add_graph(model,(input, ))

结果:

用tensorboardX可视化损失值

可视化损失值,需要使用add_scalar函数,这里利用一层全连接神经网络,训练一元二次函数的参数。

dtype = torch.FloatTensor
writer = SummaryWriter(log_dir='logs',comment='Linear')
np.random.seed(100)
x_train = np.linspace(-1,1,100).reshape(100,1)
y_train = 2*np.power(x_train,2) + 2 + 0.2*np.random.rand(x_train.size).reshape(100,1)
input_size = 1
output_size = 1
learning_rate = 0.01
num_epoches = 60
model = nn.Linear(input_size,output_size)criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate)
for epoch in range(num_epoches):inputs = torch.from_numpy(x_train).type(dtype)targets = torch.from_numpy(y_train).type(dtype)output = model(inputs)loss = criterion(output,targets)optimizer.zero_grad()loss.backward()optimizer.step()#保存loss的数据与epoch数值writer.add_scalar('训练损失值',loss,epoch)
writer.flush()
writer.close()

训练损失值结果:

用tensorboardX可视化特征图

利用tensorboardX对特征图进行可视化,不同卷积层的特征图的抽取程度是不一样的。

import torchvision.utils as vutils
writer = SummaryWriter(log_dir='logs',comment='feature map')
img_grid = vutils.make_grid(x, normalize=True, scale_each=True, nrow=2)
net.eval()
for name, layer in net._modules.items():# 为fc层预处理xx = x.view(x.size(0), -1) if "fc" in name else xprint(x.size())x = layer(x)print(f'{name}')# 查看卷积层的特征图if 'layer' in name or 'conv' in name:x1 = x.transpose(0, 1) # C,B, H, W ---> B,C, H, Wimg_grid = vutils.make_grid(x1, normalize=True, scale_each=True, nrow=4) # normalize进行归一化处理writer.add_image(f'{name}_feature_maps', img_grid, global_step=0)

参考书籍:《Python深度学习:基于PyTorch》

print(f'{name}')
# 查看卷积层的特征图
if 'layer' in name or 'conv' in name:x1 = x.transpose(0, 1) # C,B, H, W ---> B,C, H, Wimg_grid = vutils.make_grid(x1, normalize=True, scale_each=True, nrow=4) # normalize进行归一化处理writer.add_image(f'{name}_feature_maps', img_grid, global_step=0)

参考书籍:《Python深度学习:基于PyTorch》

PyTorch数据处理工具相关推荐

  1. PyTorch数据处理工具箱

    PyTorch 数据处理工具箱 文章目录 PyTorch 数据处理工具箱 1.数据处理工具箱概述 2.utils.data 简介 2.1.自定义一个数据集 3.torchvision 简介 3.1.t ...

  2. 盘点数据处理工具,手把手教你做数据清洗和转换

    导读:原始数据本身没有用.为了使它实际有用,你需要准备它. 作者:Mars Geldard, Jonathon Manning, Paris Buttfield-Addison, Tim Nugent ...

  3. 探秘采云间:全链路数据处理工具直击传统DW/BI痛点

    采云间 近几年来,各行各业的数据增长趋势都非常明显,大数据不再是少数大企业的专属研究领域.如何在数据金矿中挖掘出宝藏.如何做好数字化运营,成为各类企业共同关注的话题.针对企业日益迫切的数据化运营需求, ...

  4. Python 数据处理工具 Pandas(上)

    序列与数据框的构造 外部数据的读取(文本文件读取.电子表格读取.数据库数据读取) 数据类型转换及描述统计 字符与日期数据的处理 数据清洗方法(重复观测处理.缺失值处理.异常值处理) 数据子集的获取 透 ...

  5. [转]开源大数据处理工具汇总

    查询引擎 一.Phoenix 贡献者::Salesforce 简介:这是一个Java中间层,可以让开发者在Apache HBase上执行SQL查询.Phoenix完全使用Java编写,代码位于GitH ...

  6. python数据处理工具-Pandas笔记

    序列与数据框的构造 Pandas模块的核心操作对象就是序列Series和数据框DataFrame序列可以理解为数据集中的一个字段数据框是指含有至少两个字段(或序列)的数据集. 构造序列 可以通过以下几 ...

  7. 1、大道至简的数据处理工具-(Microsoft Power Query入门)

    大道至简的数据处理工具-Microsoft Power Query 告别复杂的excel函数,excel VBA编程,让一切回归简单与职能. 什么样的人群适合这样的一个工具: 1.出纳.会计.统计.仓 ...

  8. 一共81个,开源大数据处理工具汇总(下)转

    作者:大数据女神-诺蓝(微信公号:dashujunvshen).本文是36大数据专稿,转载必须标明来源36大数据. 接上一部分:一共81个,开源大数据处理工具汇总(上),第二部分主要收集整理的内容主要 ...

  9. 数据分析---数据处理工具pandas(二)

    文章目录 数据分析---数据处理工具pandas(二) 一.Pandas数据结构Dataframe:基本概念及创建 1.DataFrame简介 2.创建Dataframe (1)方法一:由数组/lis ...

最新文章

  1. 不修改加密文件名的勒索软件TeslaCrypt 4.0
  2. 使用OpenVINO遇到No name 'IENetwork' in module 'openvino.inference_engine'解决
  3. 在单节点和多节点上的Hadoop设置
  4. vscode括号颜色插件_[VSCode插件推荐] Bracket Pair Colorizer: 为代码中的括号添上一抹亮色...
  5. 2019年在中国每个人都可能拥有百万元收入
  6. 【转载】MySQL innodb_table_stats表不存在的解决方法
  7. PCL中把txt文件转换成.pcd文件(很简单)
  8. python导入第三方库dlib报错解决
  9. paip.gui控件form窗体的原理实现以及easyui的新建以及编辑实现
  10. 绿坝-花季护航 官网论坛
  11. Android 修改屏幕尺寸
  12. 图音80系列车载导航/DVD分体机安装DSA
  13. Charles Error Report
  14. 菊花是哪个城市的市花1_2.html,菊花的季节作文
  15. python教你如何把自己的微信变成机器人
  16. 运行多个mysql service_同时运行多个MySQL服务器的方法
  17. c++语言程序设计教程与实验实验报告,C++程序设计课程设计实验报告—网络五子棋...
  18. 【中电十所】秋招提前批一面、二面面经
  19. 光储并网simulink仿真模型,直流微电网。 光伏系统采用扰动观察法是实现mppt控制,储能可由单独蓄电池构成
  20. Android 3.1更新后的警告

热门文章

  1. Windows更新之后,无法进入系统,重置此电脑,释放空间,再试一次
  2. 离线人脸识别SeetaFace2
  3. byte数组转string
  4. C++蠕虫病毒免疫器 (antiAutoRun)
  5. el-radio单选回显 打印预览不显示问题
  6. MATLAB 面向对象编程 APP Designer基础
  7. 分布式锁 Java常用技术方案
  8. 红帽子认证复习课程-视频分享
  9. 我的世界 苹果 android,我的世界安卓和苹果能联机吗 我的世界安卓和苹果怎么联机...
  10. Ubuntu双系统的安装(有U盘就行)