深入浅出PyTorch

模型部署及定义


使用seqtoseq模型预测时序数据

Pytorch模型定义

  • 深入浅出PyTorch
  • 1.数据集
    • 1.1数据读入
    • 1.2数据集预处理
  • 2模块化搭建模型

1.数据集

时间序列就是以时间为自变量的一系列数据。本文使用seabon包中的数据。

1.1数据读入

import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScalersns.get_dataset_names()
flight_data = sns.load_dataset("flights")

1.2数据集预处理

主要有两个部分:1)将数据转为浮点数以方便梯度计算;2)将数据进行标准化/归一化以防止过拟合。

# 归一化
scaler = MinMaxScaler(feature_range=(-1, 1))
train_data_normalized = scaler.fit_transform(train_data .reshape(-1, 1))
# 转为浮点形式的tensor
train_data_normalized = torch.FloatTensor(train_data_normalized).view(-1)
train_data_normalized

2模块化搭建模型

对于大部分模型结构(比如ResNet、DenseNet等),我们仔细观察就会发现,虽然模型有很多层, 但是其中有很多重复出现的结构。考虑到每一层有其输入和输出,若干层串联成的”模块“也有其输入和输出,如果我们能将这些重复出现的层定义为一个”模块“,每次只需要向网络中添加对应的模块来构建模型,这样将会极大便利模型构建的过程。在自然语言处理及时序数据处理中,具有重要作用的seqtoseq模型就是典型的例子。
seqtoseq模型主要包含编码器与解码器两个部分,其工作原理图如下:

关于该模型的理论详述,可参考https://zhuanlan.zhihu.com/p/57623148

第一部分编码器:

class Encoder(nn.Module):def __init__(self,input_size = 2,embedding_size = 128,hidden_size = 256,n_layers = 4,dropout = 0.5):super().__init__()self.hidden_size = hidden_sizeself.n_layers = n_layersself.linear = nn.Linear(input_size, embedding_size)self.rnn = nn.LSTM(embedding_size, hidden_size, n_layers,dropout = dropout)self.dropout = nn.Dropout(dropout)def forward(self, x):embedded = self.dropout(F.relu(self.linear(x)))output, (hidden, cell) = self.rnn(embedded)return hidden, cell

第二部分解码器:

class Decoder(nn.Module):def __init__(self,output_size = 2,embedding_size = 128,hidden_size = 256,n_layers = 4,dropout = 0.5):super().__init__()self.output_size = output_sizeself.hidden_size = hidden_sizeself.n_layers = n_layersself.embedding = nn.Linear(output_size, embedding_size)self.rnn = nn.LSTM(embedding_size, hidden_size, n_layers, dropout = dropout)self.linear = nn.Linear(hidden_size, output_size)self.dropout = nn.Dropout(dropout)def forward(self, x, hidden, cell):x = x.unsqueeze(0)embedded = self.dropout(F.relu(self.embedding(x)))prediction = self.linear(output.squeeze(0))return prediction, hidden, cell

使用写好的模型块,可以非常方便地组装seqtoseq模型。可以看到,通过模型块的方式实现了代码复用,整个模型结构定义所需的代码总行数明显减少,代码可读性也得到了提升。
第三部分seqtoseq:

class Seq2Seq(nn.Module):def __init__(self, encoder, decoder, device):super().__init__()self.encoder = encoderself.decoder = decoderself.device = deviceassert encoder.hidden_size == decoder.hidden_size, \"Hidden dimensions of encoder and decoder must be equal!"assert encoder.n_layers == decoder.n_layers, \"Encoder and decoder must have equal number of layers!"def forward(self, x, y, teacher_forcing_ratio = 0.5):batch_size = x.shape[1]target_len = y.shape[0]outputs = torch.zeros(y.shape).to(self.device)hidden, cell = self.encoder(x)decoder_input = x[-1, :, :]for i in range(target_len):output, hidden, cell = self.decoder(decoder_input, hidden, cell)outputs[i] = outputteacher_forcing = random.random() < teacher_forcing_ratiodecoder_input = y[i] if teacher_forcing else outputreturn outputs

该部分详述可参考https://curow.github.io/blog/LSTM-Encoder-Decoder/
模型训练完成后,可将训练好的参数进行保存,方便下次直接使用,储存方式主要有两种::存储整个模型(包括结构和权重),和只存储模型权重。更推荐使用第二种,因为若训练保存模型的pytorch版本与迁移后直接使用的pytorch版本相差较大时,直接加载第一种方法保存的模型可能会报错,且保存整个模型所需的空间明显更大。

# 保存整个模型
torch.save(model, save_dir)
# 保存模型权重
torch.save(model.state_dict, save_dir)

深入浅出PyTorch - Pytorch模型定义相关推荐

  1. PyTorch:模型定义

    PyTorch模型定义的方式 模型在深度学习中扮演着重要的角色,好的模型极大地促进了深度学习的发展进步,比如CNN的提出解决了图像.视频处理中的诸多问题,RNN/LSTM模型解决了序列数据处理的问题, ...

  2. pytorch卷积模型定义

    1.代码 class Net(nn.Module):def __init__(self):#对所有的层初始化super(Net, self).__init__()#父类的所有的属性self.conv1

  3. PyTorch学习笔记(五):模型定义、修改、保存

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

  4. task03 Pytorch模型定义

    task03 Pytorch模型定义 2022/6/19 雾切凉宫 先引入必要的包 import os import numpy as np import collections import tor ...

  5. c++list遍历_小白学PyTorch | 6 模型的构建访问遍历存储(附代码)

    关注一下不迷路哦~喜欢的点个星标吧~<> 小白学PyTorch | 5 torchvision预训练模型与数据集全览 小白学PyTorch | 4 构建模型三要素与权重初始化 小白学PyT ...

  6. datawhale深入浅出Pytorch02——Pytorch各个模块组件

    Task02 Pytorch各个模块组件 本文主要参考DataWhale开源学习--深入浅出Pytorch,GitHub地址:https://github.com/datawhalechina/tho ...

  7. pytorch保存模型pth_Day159:模型的保存与加载

    网络结构和参数可以分开的保存和加载,因此,pytorch保存模型有两种方法: 保存 整个模型 (结构+参数) 只保存模型参数(官方推荐) # 保存整个网络torch.save(model, check ...

  8. 《Pytorch - CNN模型》

    2020年10月5号,依然在家学习. 今天是我写的第四个 Pytorch程序, 这一次我想把之前基于PyTorch实现的简易的传统的BP全连接神经网络改写成CNN网络,想看看对比和效果差异. 这一次我 ...

  9. 《Pytorch - 线性回归模型》

    2020年10月4号,依然在家学习. 今天是我写的第一个 Pytorch程序,从今天起也算是入门了. 就从简单的线性回归开始吧. 话不多说,我就直接上代码实例,代码的注释我都是用中文直接写的. imp ...

最新文章

  1. 知乎热议:周志华弟子 旷视南京负责人跳槽高校
  2. php 二维数组排序详解: array_multisort
  3. 异常org.xml.sax.SAXParseException; lineNumber: 5; columnNumber: 11; 注释中不允许出现字符串 --。的原因...
  4. ubuntu oracle 10g 安装,Ubuntu 12.04 安装Oracle 10g 全过程(完美)及问题解决办法
  5. C#数学计算包 Math.NET
  6. Understanding .NET Code Access Security
  7. 煤矿行业设备管理系统
  8. 使用GDB进行调试 -- 1 应用场景
  9. spring的依赖注入的方式(待更新)
  10. C++排列组合及应用
  11. java的多态是什么意思_【Java】基础18:什么叫多态?
  12. 5G系统——连接管理CM
  13. linux无法添加网络连接到服务器地址,ubuntu9.1服务器版局域网IP设置 网络无法连接(急)...
  14. 在线医疗系统(毕设)
  15. 用MVC写的查询,添加,删除,修改,登录。
  16. 常见app抓包软件对比
  17. c++ 箭头符号怎么打_C++语言中的标识符只能由字母、数字、下划线三种字符组成,且第一个字符_____。...
  18. ARM系列之ARM 平台安全架构PSA和Trustzone区别 浅析
  19. 手机摄像头的相关知识
  20. axis2 axiom_深入了解Axis2:AXIOM

热门文章

  1. 哈佛大学发表光子颜料技术,具有永不褪色、完全无毒等特点
  2. Python篇:面向对象练习及基础练习(制作简单双人游戏,自定规则排序练习等)
  3. 电气装备计算机控制实验,电气装备计算机控制系统设计实验报告(DOC)
  4. 感谢飞书放过幕布!GPT-4平替Poe;100个GPT-4实战案例;AI绘画新手指南之SD篇;new Bing靠谱教程;AI生成视频摘要神器 | ShowMeAI日报
  5. 广东全国计算机等级2018,2018年广东计算机等级考试报考简章
  6. Linux中case的用法
  7. 西门子SITRANS FM MAG 8000电磁水表基于IRDA红外收发器通讯无线数据采集对接云平台方案 ​​​​
  8. 如何在局域网使用手机访问电脑文件?
  9. int最大值java_java中int型最大值是多少?
  10. java 实现根据学号搜索学生信息