训练一个分类器

关于数据?
一般情况下处理图像、文本、音频和视频数据时,可以使用标准的Python包来加载数据到一个numpy数组中。 然后把这个数组转换成 torch.*Tensor。

图像可以使用 Pillow, OpenCV
音频可以使用 scipy, librosa
文本可以使用原始Python和Cython来加载,或者使用 NLTK或 SpaCy 处理
特别的,对于图像任务,我们创建了一个包 torchvision,它包含了处理一些基本图像数据集的方法。这些数据集包括 Imagenet, CIFAR10, MNIST 等。除了数据加载以外,torchvision 还包含了图像转换器, torchvision.datasets 和 torch.utils.data.DataLoader。

torchvision包不仅提供了巨大的便利,也避免了代码的重复。

在这个教程中,我们使用CIFAR10数据集,它有如下10个类别 :‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’。CIFAR-10的图像都是 3x32x32大小的,即,3颜色通道,32x32像素。

训练一个图像分类器
依次按照下列顺序进行:

使用torchvision加载和归一化CIFAR10训练集和测试集
定义一个卷积神经网络
定义损失函数
在训练集上训练网络
在测试集上测试网络
读取和归一化 CIFAR10
使用torchvision可以非常容易地加载CIFAR10。

import torch
import torchvision
import torchvision.transforms as transforms

torchvision的输出是[0,1]的PILImage图像,我们把它转换为归一化范围为[-1, 1]的张量。

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')

我们展示一些训练图像。

import matplotlib.pyplot as plt
import numpy as np# 展示图像的函数def imshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))# 获取随机数据
dataiter = iter(trainloader)
images, labels = dataiter.next()# 展示图像
imshow(torchvision.utils.make_grid(images))
# 显示图像标签
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

定义一个卷积神经网络
从之前的神经网络一节复制神经网络代码,并修改为输入3通道图像。

import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xnet = Net()

定义损失函数和优化器

import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

训练网路
有趣的时刻开始了。 我们只需在数据迭代器上循环,将数据输入给网络,并优化。

for epoch in range(2):  # 多批次循环running_loss = 0.0for i, data in enumerate(trainloader, 0):# 获取输入inputs, labels = data# 梯度置0optimizer.zero_grad()# 正向传播,反向传播,优化outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 打印状态信息running_loss += loss.item()if i % 2000 == 1999:    # 每2000批次打印一次print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('Finished Training')

在测试集上测试网络
我们在整个训练集上进行了2次训练,但是我们需要检查网络是否从数据集中学习到有用的东西。 通过预测神经网络输出的类别标签与实际情况标签进行对比来进行检测。 如果预测正确,我们把该样本添加到正确预测列表。 第一步,显示测试集中的图片并熟悉图片内容。

dataiter = iter(testloader)
images, labels = dataiter.next()# 显示图片
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

让我们看看神经网络认为以上图片是什么。

outputs = net(images)

输出是10个标签的能量。 一个类别的能量越大,神经网络越认为它是这个类别。所以让我们得到最高能量的标签。

_, predicted = torch.max(outputs, 1)print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]for j in range(4)))

结果看来不错。

接下来让看看网络在整个测试集上的结果如何。

correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

结果看起来不错,至少比随机选择要好,随机选择的正确率为10%。 似乎网络学习到了一些东西。

在识别哪一个类的时候好,哪一个不好呢?

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()for i in range(4):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

在GPU上训练

把一个神经网络移动到GPU上训练就像把一个Tensor转换GPU上一样简单。并且这个操作会递归遍历有所模块,并将其参数和缓冲区转换为CUDA张量。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 确认我们的电脑支持CUDA,然后显示CUDA信息:print(device)

本节的其余部分假定device是CUDA设备。

然后这些方法将递归遍历所有模块并将模块的参数和缓冲区 转换成CUDA张量:

net.to(device)

记住:inputs, targets 和 images 也要转换。

    inputs, labels = inputs.to(device), labels.to(device)

为什么我们没注意到GPU的速度提升很多?那是因为网络非常的小。

实践: 尝试增加你的网络的宽度(第一个nn.Conv2d的第2个参数,第二个nn.Conv2d的第一个参数,它们需要是相同的数字),看看你得到了什么样的加速。

pytorch学习(四)相关推荐

  1. 【深度学习】翻译:60分钟入门PyTorch(四)——训练一个分类器

    前言 原文翻译自:Deep Learning with PyTorch: A 60 Minute Blitz 翻译:林不清(https://www.zhihu.com/people/lu-guo-92 ...

  2. PyTorch框架学习四——计算图与动态图机制

    PyTorch框架学习四--计算图与动态图机制 一.计算图 二.动态图与静态图 三.torch.autograd 1.torch.autograd.backward() 2.torch.autogra ...

  3. PyTorch学习笔记(四):PyTorch基础实战

    PyTorch实战:以FashionMNIST时装分类为例: 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本 ...

  4. Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用

    Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用 官方参考链接 1. 损失函数 (1)BCELoss 二分类 计算公式 小例子: (2) BCEWithLogitsLoss ...

  5. pytorch 学习笔记目录

    1 部分内容 pytorch笔记 pytorch模型中的parameter与buffer_刘文巾的博客-CSDN博客 pytorch学习笔记 torchnn.ModuleList_刘文巾的博客-CSD ...

  6. Pytorch学习笔记总结

    往期Pytorch学习笔记总结: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 Pytorch系列目录: PyTorch学习笔记( ...

  7. PyTorch学习笔记(七):PyTorch可视化

    PyTorch可视化 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一) ...

  8. PyTorch学习笔记(六):PyTorch进阶训练技巧

    PyTorch实战:PyTorch进阶训练技巧 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: P ...

  9. PyTorch学习笔记(五):模型定义、修改、保存

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

  10. PyTorch学习笔记(三):PyTorch主要组成模块

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

最新文章

  1. 2020-11-22(工作集与常驻集)
  2. MySQL数据库中导入导出方法以及工具介绍
  3. oc5480十六进制_oc 数据类型 | 学步园
  4. ARM 之九 Cortex-M/R 内核启动过程 / 程序启动流程(基于ARMCC、Keil)
  5. cad渐变线怎么画_怎么画压力线和支撑线
  6. 为什么强烈推荐你使用单表查询?(续篇)
  7. python xml.dom模块解析xml
  8. 人生永无止境的意思是什么_《永无止境》中艾迪真的成功改进了NZT吗?
  9. 重构:如何去掉代码中的S味
  10. 小D课堂 - 新版本微服务springcloud+Docker教程_5-01分布式核心知识之熔断、降级
  11. HTTP请求时POST参数到底应该怎么传?
  12. [词根词缀]cre/cred/crit/cult字根由来及词源C的故事
  13. html 怎么让他变成圆角,html让图片变圆角
  14. lzg_ad:下载资料必读
  15. Floyd-Warshall算法过程中矩阵计算方法—十字交叉法(转)
  16. 杂谈(20210405)
  17. Element.closest() 兼容IE
  18. 判定两个点是否在一条直线的同一侧_帮帮学堂丨高中物理的常用方法、题型特点及应用注意点!建议收藏!...
  19. xcode和macos对应版本参考
  20. 阿里最全面试116题整理

热门文章

  1. 初中能自学c语言么,我是初中生想自学计算机编程,需要什么
  2. 百度惊现区块链项目“莱茨狗”,俄罗斯财长称不会允许自由地交易数字货币 | 区块链日报
  3. 计算机电脑制作社团,电脑制作小社团活动计划
  4. autojs免root脚本引擎编写的本地音乐播放器源代码开源
  5. 硬核干货:葡萄城 SpreadJS 前端表格技术分享
  6. Android中MediaStore的介绍
  7. 车站计算机的在线运行模式,车站计算机联锁仿真设计(一).pdf
  8. 医院计算机,医院计算机五大应用系统
  9. Minute Commander:像百战天虫的大海战
  10. 简单垃圾邮件过滤系统