PyG Temporal搭建STGCN实现多变量输入多变量输出时间序列预测
目录
- I. 前言
- II. STGCN
- III. PyG Temporal
- IV. 模型训练/测试
- V. 代码
I. 前言
前面已经写过不少时间序列预测的文章:
- 深入理解PyTorch中LSTM的输入和输出(从input输入到Linear输出)
- PyTorch搭建LSTM实现时间序列预测(负荷预测)
- PyTorch中利用LSTMCell搭建多层LSTM实现时间序列预测
- PyTorch搭建LSTM实现多变量时间序列预测(负荷预测)
- PyTorch搭建双向LSTM实现时间序列预测(负荷预测)
- PyTorch搭建LSTM实现多变量多步长时间序列预测(一):直接多输出
- PyTorch搭建LSTM实现多变量多步长时间序列预测(二):单步滚动预测
- PyTorch搭建LSTM实现多变量多步长时间序列预测(三):多模型单步预测
- PyTorch搭建LSTM实现多变量多步长时间序列预测(四):多模型滚动预测
- PyTorch搭建LSTM实现多变量多步长时间序列预测(五):seq2seq
- PyTorch中实现LSTM多步长时间序列预测的几种方法总结(负荷预测)
- PyTorch-LSTM时间序列预测中如何预测真正的未来值
- PyTorch搭建LSTM实现多变量输入多变量输出时间序列预测(多任务学习)
- PyTorch搭建ANN实现时间序列预测(风速预测)
- PyTorch搭建CNN实现时间序列预测(风速预测)
- PyTorch搭建CNN-LSTM混合模型实现多变量多步长时间序列预测(负荷预测)
- PyTorch搭建Transformer实现多变量多步长时间序列预测(负荷预测)
- PyTorch时间序列预测系列文章总结(代码使用方法)
- TensorFlow搭建LSTM实现时间序列预测(负荷预测)
- TensorFlow搭建LSTM实现多变量时间序列预测(负荷预测)
- TensorFlow搭建双向LSTM实现时间序列预测(负荷预测)
- TensorFlow搭建LSTM实现多变量多步长时间序列预测(一):直接多输出
- TensorFlow搭建LSTM实现多变量多步长时间序列预测(二):单步滚动预测
- TensorFlow搭建LSTM实现多变量多步长时间序列预测(三):多模型单步预测
- TensorFlow搭建LSTM实现多变量多步长时间序列预测(四):多模型滚动预测
- TensorFlow搭建LSTM实现多变量多步长时间序列预测(五):seq2seq
- TensorFlow搭建LSTM实现多变量输入多变量输出时间序列预测(多任务学习)
- TensorFlow搭建ANN实现时间序列预测(风速预测)
- TensorFlow搭建CNN实现时间序列预测(风速预测)
- TensorFlow搭建CNN-LSTM混合模型实现多变量多步长时间序列预测(负荷预测)
- PyG搭建图神经网络实现多变量输入多变量输出时间序列预测
- PyTorch搭建GNN-LSTM和LSTM-GNN模型实现多变量输入多变量输出时间序列预测
- PyG Temporal搭建STGCN实现多变量输入多变量输出时间序列预测
- 时序预测中Attention机制是否真的有效?盘点LSTM/RNN中24种Attention机制+效果对比
从第31篇文章起,本系列开始更新时空预测模型,其中前两篇文章都不是属于论文中的模型,今天介绍一个使用较为广泛的用于时序预测的时空图卷积网络STGCN。
II. STGCN
STGCN是北大发表在IJCAI 2018上的论文Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting中提出来的,其目的是用于实时的交通预测。
在该论文中使用的数据集为美国加州PeMSD7数据集,里面包含了分布在不同地方的228个传感器观测到的车流量,文章中使用这228个节点构成了一个无向图,然后根据历史的车流量信息预测未来某个时间段的所有传感器所在地的车流量信息。
可以看出,STGCN要解决的问题与前两篇文章要解决的问题基本一致。前两篇问题中,我们给出了13个变量前24小时的数据,目的是预测13个变量未来某几个小时的数据。在这里13个变量类比于228个传感器。
STGCN的原理也较为简单,STGCN由两个时空图卷积块(ST-Conv Block)和一个输出全连接层(Output Layer组成。其中ST-Conv Block又由两个时间门控卷积和中间的一个空间图卷积组成:
从图右边可知,两个Temporal Gated-Conv使用的是1-D卷积,和CNN处理一维时序信号类似,即进行seq_len
维度上的卷积。Spatial Graph-Conv进行的是空域上的卷积,模型为GCN。
关于STGCN详细的原理可以阅读原论文,原理也比较简单。本篇文章不做太多详细的推导过程,主要讲解如何利用STGCN进行多变量输入多变量输出的时间序列预测。
III. PyG Temporal
PyG Temporal是PyG的一个扩展库,其主要用于处理时空信号数据,里面实现了许多使用较为广泛的时空图卷积模型如STGCN、DCRNN、T-GCN、LRGCN等。
PyG Temporal的安装也比较简单:
pip install torch-geometric-temporal
PyG Temporal中STGCN的实现如下:
参数解释如下:
in_channels
:节点输入特征的维度大小,这里为1,即每个节点都只有一个特征,我们需要预测的也是该特征。hidden_channels
:字面意思。out_channels
:字面意思。kernel_size
:时域卷积时的卷积核大小,类比CNN即可。K
:将切比雪夫多项式作为图卷积核时的卷积核大小,具体可以参考我之前写的一篇文章:ICML 2019 | SGC:简单图卷积网络。normalization
:拉普拉斯矩阵的归一化选项,前面也讲过了。bias
:无需多述。
一个STConv所能接受的输入格式为:
可以看出,一个STConv需要接受三个输入:
X
:维度大小为(batch_size, seq_len, num_nodes, in_channels)
,在本文中即X=(256, 24, 13, 1)
。edge_index
:图的邻接矩阵。edge_weight
:边权重矩阵(可选)。
为此,我们可以首先搭建一个STGCN:
class STGCN(nn.Module):def __init__(self, num_nodes, size, K):super(STGCN, self).__init__()self.conv1 = STConv(num_nodes=num_nodes, in_channels=1, hidden_channels=16,out_channels=32, kernel_size=size, K=K)self.conv2 = STConv(num_nodes=num_nodes, in_channels=32, hidden_channels=16,out_channels=32, kernel_size=size, K=K)def forward(self, x, edge_index):# x(batch_size, seq_len, num_nodes, in_channels)x, edge_index = x.to(device), edge_index.to(device)x = F.elu(self.conv1(x, edge_index))x = self.conv2(x, edge_index)return x
然后一个用于多变量输入多变量输出的STGCN模型搭建如下:
class STGCN_MLP(nn.Module):def __init__(self, args):super(STGCN_MLP, self).__init__()self.args = argsself.out_feats = 128self.stgcn = STGCN(num_nodes=args.input_size, size=3, K=1)self.fcs = nn.ModuleList()for k in range(args.input_size):self.fcs.append(nn.Sequential(nn.Linear(16 * 32, 64),nn.ReLU(),nn.Linear(64, args.output_size)))def forward(self, x, edge_index):# x(batch_size, seq_len, input_size)# x(512, 24, 13)--->(512, 24, 13, 1)x = x.unsqueeze(3)x = self.stgcn(x, edge_index)preds = []for k in range(x.shape[2]):preds.append(self.fcs[k](torch.flatten(x[:, :, k, :], start_dim=1)))pred = torch.stack(preds, dim=0)return pred
照例简单分析一下模型的处理过程:
首先我们有x=(batch_size=256, seq_len=24, input_size=13)
,为了满足STGCN的输入要求(batch_size, seq_len, num_nodes, in_channels=1)
,我们需要将x
扩展一个维度:
x = x.unsqueeze(3)
然后经过STGCN:
x = self.stgcn(x, edge_index)
得到x=(256, 16, 13, 32)
。操作过程与CNN类似,一维卷积作用在seq_len=24
维度,最终变成16。随后,为了得到每个变量的输出,我们简单地将13个变量各自的(16, 32)
经过13个不同的全连接层。
IV. 模型训练/测试
这点与前面一致,不再赘述。
预测效果相当不错:
预测效果示意图(只给出前6个变量):
V. 代码
后续考虑整理公开。
PyG Temporal搭建STGCN实现多变量输入多变量输出时间序列预测相关推荐
- PyG搭建图神经网络实现多变量输入多变量输出时间序列预测
目录 I. 前言 II. 图的建立 III. 数据集构建 IV. 模型搭建(1) V. 模型搭建(2) VI. 模型训练/测试 I. 前言 前面已经写过不少时间序列预测的文章: 深入理解PyTorch ...
- TensorFlow搭建LSTM实现多变量时间序列预测(负荷预测)
目录 I. 前言 II. 数据处理 III. LSTM模型 IV. 训练/测试 V. 源码及数据 I. 前言 在前面的一篇文章TensorFlow搭建LSTM实现时间序列预测(负荷预测)中,我们利用L ...
- TensorFlow搭建CNN实现时间序列预测(风速预测)
目录 I. 数据集 II. 特征构造 III. 一维卷积 IV. 数据处理 1. 数据预处理 2. 数据集构造 V. CNN模型 1. 模型搭建 2. 模型训练及表现 VI. 源码及数据 时间序列预测 ...
- TensorFlow搭建LSTM实现时间序列预测(负荷预测)
目录 I. 前言 II. 数据处理 III. 模型 IV. 训练/测试 V. 源码及数据 I. 前言 前面已经写过不少时间序列预测的文章: 深入理解PyTorch中LSTM的输入和输出(从input输 ...
- TensorFlow搭建双向LSTM实现时间序列预测(负荷预测)
目录 I. 前言 II. 原理 III. 模型定义 IV. 训练和预测 V. 源码及数据 I. 前言 前面几篇文章中介绍的都是单向LSTM,这篇文章讲一下双向LSTM. 系列文章: 深入理解PyTor ...
- 回归预测 | MATLAB实现GWO-LSTM灰狼算法优化长短期记忆神经网络多输入单输出回归预测
回归预测 | MATLAB实现GWO-LSTM灰狼算法优化长短期记忆神经网络多输入单输出回归预测 目录 回归预测 | MATLAB实现GWO-LSTM灰狼算法优化长短期记忆神经网络多输入单输出回归预测 ...
- 回归预测 | MATLAB实现GWO-BiLSTM灰狼算法优化双向长短期记忆神经网络多输入单输出回归预测
回归预测 | MATLAB实现GWO-BiLSTM灰狼算法优化双向长短期记忆神经网络多输入单输出回归预测 目录 回归预测 | MATLAB实现GWO-BiLSTM灰狼算法优化双向长短期记忆神经网络多输 ...
- 回归预测 | MATLAB实现DBN-BP深度置信网络结合BP神经网络多输入单输出回归预测
回归预测 | MATLAB实现DBN-BP深度置信网络结合BP神经网络多输入单输出回归预测 目录 回归预测 | MATLAB实现DBN-BP深度置信网络结合BP神经网络多输入单输出回归预测 预测效果 ...
- 回归预测 | MATLAB实现DBN多层深度置信网络多输入单输出回归预测
回归预测 | MATLAB实现DBN多层深度置信网络多输入单输出回归预测 目录 回归预测 | MATLAB实现DBN多层深度置信网络多输入单输出回归预测 预测效果 基本介绍 模型描述 程序设计 参考资 ...
最新文章
- 吴恩达:机器学习毕业后,如何规划职业生涯?
- 矮个男生不好找对象?某大厂程序员自称太高也难找对象!身高196cm,有房有车,却被嫌太高!...
- QLibrary 动态加载外部库文件
- 负数如何归一化处理_机器学习之数据预处理
- RocketMQ开发指导之一——RocketMQ简介
- 不确定性原理的前世今生 · 数学篇(一)
- 恢复Debian下root用户bash高亮显示
- 手机连接蓝牙扫码枪_扫码枪蓝牙连接电脑 蓝牙扫码枪
- PMP模拟题200道,中英双对照,附答案解析
- ssh远程连接服务器常用命令
- html 英文发音,一些英文字母的发音
- 风变编程python26_风变编程学习Python的切身体会
- 五子棋游戏程序记录和复盘功能设置
- 将其他人物模型动画导入Carla使用
- 生物统计学教材中的统计推断方法
- 服装行业如何利用长尾关键词挖掘推广?
- mysql lag和lead_Oracle的LAG和LEAD分析函数
- (环境搭建+复现)74CMS模版注入+文件包含getshell
- LearnOpenGL学习笔记——OpenGL颜色
- freemarker 模板使用记录
热门文章
- LCD1602的引脚定义
- 排列组合问题Java实现
- 标准c语言程序设计,C语言程序设计标准.doc
- 学平面设计少走弯路,选择平面设计专业培训!
- 视频教程-Redis进阶教程—基础篇-NoSQL
- 面试中 项目遇见的难点答案_盘点产品经理求职面试中“可能”会遇到的十大项目管理问题...
- 计算机网络 一、概述
- 使用 Trace32 对 FLASH 编程摘要及Trace32-ICD和Trace32-ICE的区别
- 用Python代码画一只杰瑞
- fseek()函数的用法及其理解