PyTorch的主要组成模块
数据读入
PyTorch数据读入是通过Dataset+Dataloader的方式完成的,Dataset定义好数据的格式和数据变换形式,Dataloader用iterative的方式不断读入批次数据。
我们可以定义自己的Dataset类来实现灵活的数据读取,定义的类需要继承PyTorch自身的Dataset类。主要包含三个函数:
__init__
: 用于向类中传入外部参数,同时定义样本集__getitem__
: 用于逐个读取样本集合中的元素,可以进行一定的变换,并将返回训练/验证所需的数据__len__
: 用于返回数据集的样本数
模型构建
神经网络构建
PyTorch中神经网络构造一般是基于 Module 类的模型来完成的,它让模型构造更加灵活。
Module 类是 nn 模块里提供的一个模型构造类,是所有神经⽹网络模块的基类,我们可以继承它来定义我们想要的模型。下面继承 Module 类构造多层感知机。这里定义的 MLP 类重载了 Module 类的 init 函数和 forward 函数。它们分别用于创建模型参数和定义前向计算。前向计算也即正向传播。
import torch
from torch import nnclass MLP(nn.Module):# 声明带有模型参数的层,这里声明了两个全连接层def __init__(self, **kwargs):# 调用MLP父类Block的构造函数来进行必要的初始化。这样在构造实例例时还可以指定其他函数super(MLP, self).__init__(**kwargs)self.hidden = nn.Linear(784, 256)self.act = nn.ReLU()self.output = nn.Linear(256,10)# 定义模型的前向计算,即如何根据输入x计算返回所需要的模型输出def forward(self, x):o = self.act(self.hidden(x))return self.output(o)
我们可以实例化 MLP 类得到模型变量 net 。下⾯的代码初始化 net 并传入输⼊数据 X 做一次前向计算。其中, net(X) 会调用 MLP 继承⾃自 Module 类的 call 函数,这个函数将调⽤用 MLP 类定义的forward 函数来完成前向计算。
神经网络中常见的层
- 不含模型参数的层
- 含模型参数的层
- 二维卷积层
- 池化层
模型示例
一个神经网络的典型训练过程如下:
- 定义包含一些可学习参数(或者叫权重)的神经网络
- 在输入数据集上迭代
- 通过网络处理输入
- 计算 loss (输出和正确答案的距离)
- 将梯度反向传播给网络的参数
- 更新网络的权重,一般使用一个简单的规则:`weight = weight - learning_rate * gradient
代码:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()# 输入图像channel:1;输出channel:6;5x5卷积核self.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)# an affine operation: y = Wx + bself.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):# 2x2 Max poolingx = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))# 如果是方阵,则可以只使用一个数字进行定义x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, self.num_flat_features(x))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):size = x.size()[1:] # 除去批处理维度的其他所有维度num_features = 1for s in size:num_features *= sreturn num_featuresnet = Net()
print(net)
损失函数
在深度学习广为盛行的今天,我们可以在脑海里清晰的知道,一个模型要可以达到很好的效果需要学习,也就是我们常说的训练。一个好的训练离不开优质的负反馈,这里的损失函数就是模型的负反馈。
所以在PyTorch中,损失函数是必不可少的。它是数据输入到模型当中,产生的结果与真实标签的评价指标,我们的模型可以按照损失函数的目标来做出改进。
下面我们将开始探索pytorch的所拥有的损失函数。这里将列出PyTorch中常用的损失函数(一般通过torch.nn调用),并详细介绍每个损失函数的功能介绍、数学公式和调用代码。当然,PyTorch的损失函数还远不止这些,在解决实际问题的过程中需要进一步探索、借鉴现有工作,或者设计自己的损失函数。
Pytorch优化器
什么是优化器
深度学习的目标是通过不断改变网络参数,使得参数能够对输入做各种非线性变换拟合输出,本质上就是一个函数去寻找最优解,只不过这个最优解使一个矩阵,而如何快速求得这个最优解是深度学习研究的一个重点,以经典的resnet-50为例,它大约有2000万个系数需要进行计算,那么我们如何计算出来这么多的系数,有以下两种方法:
- 第一种是最直接的暴力穷举一遍参数,这种方法的实施可能性基本为0,堪比愚公移山plus的难度。
- 为了使求解参数过程更加快,人们提出了第二种办法,即就是是BP+优化器逼近求解。
因此,优化器就是根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值,使得模型输出更加接近真实标签。。
Pytorch提供的优化器
Pytorch很人性化的给我们提供了一个优化器的库torch.optim,在这里面给我们提供了十种优化器。
- torch.optim.ASGD
- torch.optim.Adadelta
- torch.optim.Adagrad
- torch.optim.Adam
- torch.optim.AdamW
- torch.optim.Adamax
- torch.optim.LBFGS
- torch.optim.RMSprop
- torch.optim.Rprop
- torch.optim.SGD
- torch.optim.SparseAdam
而以上这些优化算法均继承于Optimizer
训练和评估
完成了上述设定后就可以加载数据开始训练模型了。首先应该设置模型的状态:如果是训练状态,那么模型的参数应该支持反向传播的修改;如果是验证/测试状态,则不应该修改模型参数。在PyTorch中,模型的状态设置非常简便,如下的两个操作二选一即可:
model.train() # 训练状态
model.eval() # 验证/测试状态
我们前面在DataLoader构建完成后介绍了如何从中读取数据,在训练过程中使用类似的操作即可,区别在于此时要用for循环读取DataLoader中的全部数据。
for data, label in train_loader:
之后将数据放到GPU上用于后续计算,此处以.cuda()为例
data, label = data.cuda(), label.cuda()
开始用当前批次数据做训练时,应当先将优化器的梯度置零:
optimizer.zero_grad()
之后将data送入模型中训练:
output = model(data)
根据预先定义的criterion计算损失函数:
loss = criterion(output, label)
将loss反向传播回网络:
loss.backward()
使用优化器更新模型参数:
optimizer.step()
这样一个训练过程就完成了,后续还可以计算模型准确率等指标,这部分会在下一节的图像分类实战中加以介绍。
验证/测试的流程基本与训练过程一致,不同点在于:
- 需要预先设置torch.no_grad,以及将model调至eval模式
- 不需要将优化器的梯度置零
- 不需要将loss反向回传到网络
- 不需要更新optimizer
一个完整的训练过程如下所示:
def train(epoch):model.train()train_loss = 0for data, label in train_loader:data, label = data.cuda(), label.cuda()optimizer.zero_grad()output = model(data)loss = criterion(label, output)loss.backward()optimizer.step()train_loss += loss.item()*data.size(0)train_loss = train_loss/len(train_loader.dataset)print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))
对应的,一个完整的验证过程如下所示:
def val(epoch): model.eval()val_loss = 0with torch.no_grad():for data, label in val_loader:data, label = data.cuda(), label.cuda()output = model(data)preds = torch.argmax(output, 1)loss = criterion(output, label)val_loss += loss.item()*data.size(0)running_accu += torch.sum(preds == label.data)val_loss = val_loss/len(val_loader.dataset)print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, val_loss))
PyTorch的主要组成模块相关推荐
- 【Pytorch学习笔记2】Pytorch的主要组成模块
个人笔记,仅用于个人学习与总结 感谢DataWhale开源组织提供的优秀的开源Pytorch学习文档:原文档链接 本文目录 1. Pytorch的主要组成模块 1.1 完成深度学习的必要部分 1.2 ...
- 【小白学习PyTorch教程】三、Pytorch中的NN模块并实现第一个神经网络模型
「@Author:Runsen」 在PyTorch建立模型,主要是NN模块. nn.Linear nn.Linear是创建一个线性层.这里需要将输入和输出维度作为参数传递. linear = nn.L ...
- 【语义分割系列】deeplabv3相关知识点以及pytorch实现(ASSP模块)
1.deeplabv3是在deeplabv2的基础上修改ASSP模块.deeplabv1加上个ASSP就是deeplabv2.deeplabv1相关介绍看我的博文:https://blog.csdn. ...
- pytorch中nn.moudle模块
nn.moudle是所有卷积神经网络的基类,相信一定非常困扰大家的学习,特此出一期详细讲解它,如果大家觉得通透了,那笔者的存在就会有了意义 1.__init(self)__: def __init__ ...
- 使用pytorch动手实现LSTM模块
原文 import torch import torch.nn as nn from torch.nn import Parameter from torch.nn import init from ...
- PyTorch基础之激活函数模块中Sigmoid、Tanh、ReLU、LeakyReLU函数讲解(附源码)
需要源码请点赞关注收藏后评论区留言私信~~~ 激活函数是神经网络中的重要组成部分.在多层神经网络中,上层节点的输出和下层节点的输入之间有一个函数关系.如果这个函数我们设置为非线性函数,深层网络的表达能 ...
- PyTorch 入坑七:模块与nn.Module学习
PyTorch 入坑七 模型创建概述 PyTorch中的模块 torch模块 torch.Tensor模块 torch.sparse模块 torch.cuda模块 torch.nn模块 torch.n ...
- Pytorch实现CT图像正投影(FP)与反投影(FBP)的模块
FP/FBP Modules 有关CT图像重建或图像处理的训练任务有时需要数据在投影域和图像域上进行变换,为了能使梯度在投影域和图像域之间进行传播,需要实现Forward Projection与Bac ...
- Pytorch学习(二)—— nn模块
torch.nn nn.Module 常用的神经网络相关层 损失函数 优化器 模型初始化策略 nn和autograd nn.functional nn和autograd的关系 hooks简介 模型保存 ...
- 【 线性回归 Linear-Regression torch模块实现与源码详解 深度学习 Pytorch笔记 B站刘二大人(4/10)】
torch模块实现与源码详解 深度学习 Pytorch笔记 B站刘二大人 深度学习 Pytorch笔记 B站刘二大人(4/10) 介绍 至此开始,深度学习模型构建的预备知识已经完全准备完毕. 从本章开 ...
最新文章
- Apache2.4部署python3.6+django2.0项目
- pip更换国内镜像源
- android 代码打开权限,android开发权限询问的示例代码
- Nuget Tips
- VTK:baking烘焙阴影贴图用法实战
- brew mysql 无法启动_MAC OSX brew 升级 mysql5.6到5.7无法启动的问题
- 为什么祖国没有农历生日? | 今日最佳
- JS循环执行函数setInterval
- python web 文件管理_我的第一个python web开发框架(23)——代码版本控制管理与接口文档...
- mysql server启动_mysql的启动方式
- 在UNITY中按钮的高亮用POINT灯实现,效果别具一番风味
- Atitit. .net c# web 跟客户端winform 的ui控件结构比较
- [UESTC SC T3] 蛋糕
- Windows 10 喇叭红叉 重装驱动无效 点击喇叭显示无插座信息
- matlab GUI 绘图 坐标轴控件
- 【高分一号影像数据命名规则】
- SM2 SM3 SM4加密java实现
- 绿联扩展坞拆解_绿联最新豪华版3A1C四口多功能扩展坞深度拆解,用料满满
- decodeString
- matlab2017b的破解激活
热门文章
- 2016百度编程题:裁减网格纸
- 蓝牙技术及其系统原理
- 2021年安全员-C证(山东省-2020版)考试报名及安全员-C证(山东省-2020版)考试APP
- Linux线程(3)——pthread_cancel()取消一个线程
- 推荐系统中的i2i,u2u2i,u2i2i和u2tag2i 是什么意思?
- 云原生-DevOps-环境搭建
- ImageButton带有文字(或者Button中带有文字与图片)
- 计算互相关注的SQL怎么写
- mysql tode_【20201007】Python操作MySQL数据库
- 用python实现 斐波那契数列。 3种方法