目录

  • 前言
  • 数据集
  • 模型实现
    • PyTorch实现
    • PyG实现
  • 实验结果
  • 完整代码

前言

SGC的原理比较简单,具体请见:ICML 2019 | SGC:简单图卷积网络

数据集

数据集采用节点分类常用的三大引文网络:Citeseer、Cora和PubMed,数据集不再详细介绍。

模型实现

PyTorch实现

由于SGC的原理比较简单,因此用PyTorch手写也十分轻松。

观察SGC的表达式:

首先我们需要计算对称归一化的邻接矩阵SSS
S=D~−12A~D~−12S=\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}S=D~21A~D~21
其中A~=A+I\tilde{A}=A+IA~=A+ID~=D+I\tilde{D}=D+ID~=D+I

首先获取数据:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
names = ['CiteSeer', 'Cora', 'PubMed']dataset = Planetoid(root='data', name=names[0])
dataset = dataset[0]
edge_index, _ = add_self_loops(dataset.edge_index)

然后提取邻接矩阵:

# get adj
adj = to_scipy_sparse_matrix(edge_index).todense()
adj = torch.tensor(adj).to(device)

提取度矩阵:

deg = degree(edge_index[0], dataset.num_nodes)
deg = torch.diag_embed(deg)
deg_inv_sqrt = torch.pow(deg, -0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
deg_inv_sqrt = deg_inv_sqrt.to(device)

对邻接矩阵进行对称归一化:

s = torch.mm(torch.mm(deg_inv_sqrt, adj), deg_inv_sqrt)

对特征进行预处理:

k = 2
norm_x = torch.mm(torch.matrix_power(s, k), feature)

最后,搭建模型:

class SGC(nn.Module):def __init__(self, in_feats, out_feats):super(SGC, self).__init__()self.softmax = nn.Softmax(dim=1)self.w = nn.Linear(in_feats, out_feats)def forward(self, x):out = self.w(x)return self.softmax(out)

其中xxx为上面预处理过的特征norm_x。

模型训练:

def train(model):optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-6)loss_function = torch.nn.CrossEntropyLoss().to(device)scheduler = StepLR(optimizer, step_size=10, gamma=0.4)min_epochs = 10min_val_loss = 5final_best_acc = 0model.train()t = perf_counter()for epoch in tqdm(range(100)):out = model(norm_x)optimizer.zero_grad()loss = loss_function(out[train_mask], y[train_mask])loss.backward()optimizer.step()scheduler.step()# validationval_loss, test_acc = test(model)if val_loss < min_val_loss and epoch + 1 > min_epochs:min_val_loss = val_lossfinal_best_acc = test_accmodel.train()print('Epoch{:3d} train_loss {:.5f} val_loss {:.3f} test_acc {:.3f}'.format(epoch, loss.item(), val_loss, test_acc))train_time = perf_counter() - treturn final_best_acc, train_time

PyG实现

首先导入包:

from torch_geometric.nn import SGConv

模型参数:

  1. in_channels:输入通道,比如节点分类中表示每个节点的特征数。
  2. out_channels:输出通道,输出通道为节点类别数(节点分类)。
  3. K:跳数,最远提取到K阶邻居的特征,也就是前面公式中的K。
  4. cached:如果为True,则只是在第一次执行时才计算预处理后的特征,否则每一次都计算。默认为True。
  5. add_self_loops:如果为False不再强制添加自环,默认为True。
  6. bias:默认添加偏置。

于是模型搭建如下:

class PyG_SGC(nn.Module):def __init__(self, in_feats, out_feats):super(PyG_SGC, self).__init__()self.conv = SGConv(in_feats, out_feats, K=k, cached=True)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv(x, edge_index)x = F.softmax(x, dim=1)return x

训练时返回验证集上表现最优的模型:

def pyg_train(model, data):optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-6)loss_function = torch.nn.CrossEntropyLoss().to(device)scheduler = StepLR(optimizer, step_size=10, gamma=0.4)min_epochs = 10min_val_loss = 5final_best_acc = 0model.train()t = perf_counter()for epoch in tqdm(range(100)):out = model(data)optimizer.zero_grad()loss = loss_function(out[train_mask], y[train_mask])loss.backward()optimizer.step()scheduler.step()# validationval_loss, test_acc = pyg_test(model, data)if val_loss < min_val_loss and epoch + 1 > min_epochs:min_val_loss = val_lossfinal_best_acc = test_accmodel.train()print('Epoch{:3d} train_loss {:.5f} val_loss {:.3f} test_acc {:.3f}'.format(epoch, loss.item(), val_loss, test_acc))train_time = perf_counter() - treturn final_best_acc, train_time

实验结果

这里给出Citeseer网络的实验结果:

pytorch train_time: 0.2071399000000005
pytorch best test acc: 0.681pyg train_time: 0.21978220000000004
pyg best test acc: 0.676

可以看出PyTorch手写和PyG的效果类似,耗时也类似。

完整代码

后面统一整理。

搭建SGC实现引文网络节点预测(PyTorch+PyG)相关推荐

  1. 疫情之下春运人口回流“硬核”预测:往返这些超级网络节点城市有更大感染风险

    大数据文摘授权转载自nCoV疫情地图 作者:张海平 修宇璇 经历了各种"硬核"隔离之后,相信大家都明白了"人口的流动性是疫情防控的关键"这个道理. 春节前的人口 ...

  2. 复杂网络链路预测的研究现状及展望(2010)

    前言:做链路预测这个方向有一年多的时间了,有一些收获和体会.一直想写一个综述进行总结,总是希望这个综述尽可能的包括更多更全面的信息,但是新的思想和结果源源不断的涌现,所谓的综述也就无限期的搁置了下来. ...

  3. 节点表征学习与节点预测和边预测

    基于图神经网络的节点表征学习 引言 在图节点预测或边预测任务中,需要先构造节点表征(representation),节点表征是图节点预测和边预测任务成功的关键.在此篇文章中,我们将学习如何基于图神经网 ...

  4. 节点预测与边预测任务实践

    一.使用InMemoryDataset数据集类 import os.path as ospimport torch from torch_geometric.data import (InMemory ...

  5. 百度论文引用网络节点分类比赛

    论文引用网络节点分类比赛 Baseline 赛题介绍 图神经网络(Graph Neural Network)是一种专门处理图结构数据的神经网络,目前被广泛应用于推荐系统.金融风控.生物计算中.图神经网 ...

  6. 【论文翻译】基于分层关注和时间RNN的动态异构网络链路预测建模

    基于分层关注和时间RNN的动态异构网络链路预测建模 摘要 网络嵌入的目的是在获取网络结构信息的同时学习节点的低维表示.它在链路预测.节点分类等网络分析任务中取得了巨大的成功.现有的网络嵌入算法大多集中 ...

  7. 时空同步图卷积网络:时空网络数据预测的新框架

    1.文章信息 <Spatial-Temporal Synchronous Graph Convolutional Networks: A New Framework for Spatial-Te ...

  8. 【Ryo】SPSS Modeler:贝叶斯网络在预测银行信贷风险中的应用

    对银行信贷来说,如何量化客户违约的可能性,对潜在的风险进行预测是管理决策层关注的重中之重.面对复杂的信息结构和庞大的人群数据,运用贝叶斯网络能够理清相关影响因素的关联关系,是现在提高信贷违约风险预测正 ...

  9. 以太坊(Ethereum) - 网络节点

    章节 以太坊(Ethereum) – 是什么 以太坊(Ethereum) – 什么是智能合约 以太坊(Ethereum) – 以太币 以太坊(Ethereum) – 虚拟机(E.V.M.) 以太坊(E ...

最新文章

  1. frame和bounds
  2. jsonp react 获取返回值_Django+React全栈开发:文章列表
  3. mysql rpm 安装6_CentOS 7.6 MySQL 8.0 RPM包方式安装及新特性介绍
  4. linux 6.4 nfs配置,RHEL6.4 NFS文件共享服务器搭建
  5. Git 打补丁-- patch 和 diff 的使用(详细)
  6. Python实现日程表
  7. 【LWIP】(补充)STM32H743(M7内核)CubeMX配置LWIP并ping通
  8. Kong Dashboard系列【三】添加插件----rate-limiting
  9. 商贸零售行业2021年投资策略:市场下沉、渠道效率升级,新品牌新业态乘风而起
  10. 神舟战神笔记本怎么U盘装Win10系统教学
  11. 【Linux杂篇】Cron是什么?利用Cron Job自动执行定时任务
  12. 预编译及预处理的理解
  13. 厦门考计算机竞赛保送北大清华名单,优秀!厦门这些学生将保送清华北大等名校...
  14. 悬崖勒马回头是岸——关于玩王者荣耀游戏的一些想法
  15. 2020年8月-北京-百度度小满面试题(已offer)
  16. VMware Workstation 无法恢复错误: (vcpu-0) Exception 0xc0000005 (access violation) has occurred.
  17. VMware虚拟机忘记Linux用户登陆密码,重置密码解决办法
  18. MCDBA 微软官方考试内容
  19. 利用installshiled打包.inf和.sys驱动文件到setup.exe的方法
  20. 安卓逆向——某宝APP抓包之环境对比 (一)

热门文章

  1. 推荐系统[八]算法实践总结V0:淘宝逛逛and阿里飞猪个性化推荐:召回算法实践总结【冷启动召回、复购召回、用户行为召回等算法实战】
  2. 《通往财富自由之路》阅读笔记(一)
  3. 程序员应聘阿里P7岗,面试都过了,结果栽在背景调查!(你请注意了!)
  4. Linux系统安装JDK1.8 详细流程
  5. Ubuntu 安装Flash
  6. 技术变现,到底怎么变?本文或能成为你的“点金石”
  7. 西雅图的朋友如何知道自己是否可以打疫苗?
  8. ftp地址,ftp地址的2大作用
  9. 给 element-ui 表格的表头添加icon图标
  10. Fortran语言初探及Win7 64位下Fortran开发环境配置