搭建SGC实现引文网络节点预测(PyTorch+PyG)
目录
- 前言
- 数据集
- 模型实现
- PyTorch实现
- PyG实现
- 实验结果
- 完整代码
前言
SGC的原理比较简单,具体请见:ICML 2019 | SGC:简单图卷积网络
数据集
数据集采用节点分类常用的三大引文网络:Citeseer、Cora和PubMed,数据集不再详细介绍。
模型实现
PyTorch实现
由于SGC的原理比较简单,因此用PyTorch手写也十分轻松。
观察SGC的表达式:
首先我们需要计算对称归一化的邻接矩阵SSS:
S=D~−12A~D~−12S=\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}S=D~−21A~D~−21
其中A~=A+I\tilde{A}=A+IA~=A+I,D~=D+I\tilde{D}=D+ID~=D+I。
首先获取数据:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
names = ['CiteSeer', 'Cora', 'PubMed']dataset = Planetoid(root='data', name=names[0])
dataset = dataset[0]
edge_index, _ = add_self_loops(dataset.edge_index)
然后提取邻接矩阵:
# get adj
adj = to_scipy_sparse_matrix(edge_index).todense()
adj = torch.tensor(adj).to(device)
提取度矩阵:
deg = degree(edge_index[0], dataset.num_nodes)
deg = torch.diag_embed(deg)
deg_inv_sqrt = torch.pow(deg, -0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
deg_inv_sqrt = deg_inv_sqrt.to(device)
对邻接矩阵进行对称归一化:
s = torch.mm(torch.mm(deg_inv_sqrt, adj), deg_inv_sqrt)
对特征进行预处理:
k = 2
norm_x = torch.mm(torch.matrix_power(s, k), feature)
最后,搭建模型:
class SGC(nn.Module):def __init__(self, in_feats, out_feats):super(SGC, self).__init__()self.softmax = nn.Softmax(dim=1)self.w = nn.Linear(in_feats, out_feats)def forward(self, x):out = self.w(x)return self.softmax(out)
其中xxx为上面预处理过的特征norm_x。
模型训练:
def train(model):optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-6)loss_function = torch.nn.CrossEntropyLoss().to(device)scheduler = StepLR(optimizer, step_size=10, gamma=0.4)min_epochs = 10min_val_loss = 5final_best_acc = 0model.train()t = perf_counter()for epoch in tqdm(range(100)):out = model(norm_x)optimizer.zero_grad()loss = loss_function(out[train_mask], y[train_mask])loss.backward()optimizer.step()scheduler.step()# validationval_loss, test_acc = test(model)if val_loss < min_val_loss and epoch + 1 > min_epochs:min_val_loss = val_lossfinal_best_acc = test_accmodel.train()print('Epoch{:3d} train_loss {:.5f} val_loss {:.3f} test_acc {:.3f}'.format(epoch, loss.item(), val_loss, test_acc))train_time = perf_counter() - treturn final_best_acc, train_time
PyG实现
首先导入包:
from torch_geometric.nn import SGConv
模型参数:
- in_channels:输入通道,比如节点分类中表示每个节点的特征数。
- out_channels:输出通道,输出通道为节点类别数(节点分类)。
- K:跳数,最远提取到K阶邻居的特征,也就是前面公式中的K。
- cached:如果为True,则只是在第一次执行时才计算预处理后的特征,否则每一次都计算。默认为True。
- add_self_loops:如果为False不再强制添加自环,默认为True。
- bias:默认添加偏置。
于是模型搭建如下:
class PyG_SGC(nn.Module):def __init__(self, in_feats, out_feats):super(PyG_SGC, self).__init__()self.conv = SGConv(in_feats, out_feats, K=k, cached=True)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv(x, edge_index)x = F.softmax(x, dim=1)return x
训练时返回验证集上表现最优的模型:
def pyg_train(model, data):optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-6)loss_function = torch.nn.CrossEntropyLoss().to(device)scheduler = StepLR(optimizer, step_size=10, gamma=0.4)min_epochs = 10min_val_loss = 5final_best_acc = 0model.train()t = perf_counter()for epoch in tqdm(range(100)):out = model(data)optimizer.zero_grad()loss = loss_function(out[train_mask], y[train_mask])loss.backward()optimizer.step()scheduler.step()# validationval_loss, test_acc = pyg_test(model, data)if val_loss < min_val_loss and epoch + 1 > min_epochs:min_val_loss = val_lossfinal_best_acc = test_accmodel.train()print('Epoch{:3d} train_loss {:.5f} val_loss {:.3f} test_acc {:.3f}'.format(epoch, loss.item(), val_loss, test_acc))train_time = perf_counter() - treturn final_best_acc, train_time
实验结果
这里给出Citeseer网络的实验结果:
pytorch train_time: 0.2071399000000005
pytorch best test acc: 0.681pyg train_time: 0.21978220000000004
pyg best test acc: 0.676
可以看出PyTorch手写和PyG的效果类似,耗时也类似。
完整代码
后面统一整理。
搭建SGC实现引文网络节点预测(PyTorch+PyG)相关推荐
- 疫情之下春运人口回流“硬核”预测:往返这些超级网络节点城市有更大感染风险
大数据文摘授权转载自nCoV疫情地图 作者:张海平 修宇璇 经历了各种"硬核"隔离之后,相信大家都明白了"人口的流动性是疫情防控的关键"这个道理. 春节前的人口 ...
- 复杂网络链路预测的研究现状及展望(2010)
前言:做链路预测这个方向有一年多的时间了,有一些收获和体会.一直想写一个综述进行总结,总是希望这个综述尽可能的包括更多更全面的信息,但是新的思想和结果源源不断的涌现,所谓的综述也就无限期的搁置了下来. ...
- 节点表征学习与节点预测和边预测
基于图神经网络的节点表征学习 引言 在图节点预测或边预测任务中,需要先构造节点表征(representation),节点表征是图节点预测和边预测任务成功的关键.在此篇文章中,我们将学习如何基于图神经网 ...
- 节点预测与边预测任务实践
一.使用InMemoryDataset数据集类 import os.path as ospimport torch from torch_geometric.data import (InMemory ...
- 百度论文引用网络节点分类比赛
论文引用网络节点分类比赛 Baseline 赛题介绍 图神经网络(Graph Neural Network)是一种专门处理图结构数据的神经网络,目前被广泛应用于推荐系统.金融风控.生物计算中.图神经网 ...
- 【论文翻译】基于分层关注和时间RNN的动态异构网络链路预测建模
基于分层关注和时间RNN的动态异构网络链路预测建模 摘要 网络嵌入的目的是在获取网络结构信息的同时学习节点的低维表示.它在链路预测.节点分类等网络分析任务中取得了巨大的成功.现有的网络嵌入算法大多集中 ...
- 时空同步图卷积网络:时空网络数据预测的新框架
1.文章信息 <Spatial-Temporal Synchronous Graph Convolutional Networks: A New Framework for Spatial-Te ...
- 【Ryo】SPSS Modeler:贝叶斯网络在预测银行信贷风险中的应用
对银行信贷来说,如何量化客户违约的可能性,对潜在的风险进行预测是管理决策层关注的重中之重.面对复杂的信息结构和庞大的人群数据,运用贝叶斯网络能够理清相关影响因素的关联关系,是现在提高信贷违约风险预测正 ...
- 以太坊(Ethereum) - 网络节点
章节 以太坊(Ethereum) – 是什么 以太坊(Ethereum) – 什么是智能合约 以太坊(Ethereum) – 以太币 以太坊(Ethereum) – 虚拟机(E.V.M.) 以太坊(E ...
最新文章
- frame和bounds
- jsonp react 获取返回值_Django+React全栈开发:文章列表
- mysql rpm 安装6_CentOS 7.6 MySQL 8.0 RPM包方式安装及新特性介绍
- linux 6.4 nfs配置,RHEL6.4 NFS文件共享服务器搭建
- Git 打补丁-- patch 和 diff 的使用(详细)
- Python实现日程表
- 【LWIP】(补充)STM32H743(M7内核)CubeMX配置LWIP并ping通
- Kong Dashboard系列【三】添加插件----rate-limiting
- 商贸零售行业2021年投资策略:市场下沉、渠道效率升级,新品牌新业态乘风而起
- 神舟战神笔记本怎么U盘装Win10系统教学
- 【Linux杂篇】Cron是什么?利用Cron Job自动执行定时任务
- 预编译及预处理的理解
- 厦门考计算机竞赛保送北大清华名单,优秀!厦门这些学生将保送清华北大等名校...
- 悬崖勒马回头是岸——关于玩王者荣耀游戏的一些想法
- 2020年8月-北京-百度度小满面试题(已offer)
- VMware Workstation 无法恢复错误: (vcpu-0) Exception 0xc0000005 (access violation) has occurred.
- VMware虚拟机忘记Linux用户登陆密码,重置密码解决办法
- MCDBA 微软官方考试内容
- 利用installshiled打包.inf和.sys驱动文件到setup.exe的方法
- 安卓逆向——某宝APP抓包之环境对比 (一)
热门文章
- 推荐系统[八]算法实践总结V0:淘宝逛逛and阿里飞猪个性化推荐:召回算法实践总结【冷启动召回、复购召回、用户行为召回等算法实战】
- 《通往财富自由之路》阅读笔记(一)
- 程序员应聘阿里P7岗,面试都过了,结果栽在背景调查!(你请注意了!)
- Linux系统安装JDK1.8 详细流程
- Ubuntu 安装Flash
- 技术变现,到底怎么变?本文或能成为你的“点金石”
- 西雅图的朋友如何知道自己是否可以打疫苗?
- ftp地址,ftp地址的2大作用
- 给 element-ui 表格的表头添加icon图标
- Fortran语言初探及Win7 64位下Fortran开发环境配置