PyG搭建R-GCN实现节点分类
目录
- 前言
- 数据处理
- 模型搭建
- 1. 前向传播
- 2. 反向传播
- 3. 训练
- 4. 测试
- 实验结果
- 完整代码
前言
R-GCN的原理请见:ESWC 2018 | R-GCN:基于图卷积网络的关系数据建模。
数据处理
导入数据:
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'DBLP')
dataset = DBLP(path)
graph = dataset[0]
print(graph)
输出如下:
HeteroData(author={x=[4057, 334],y=[4057],train_mask=[4057],val_mask=[4057],test_mask=[4057]},paper={ x=[14328, 4231] },term={ x=[7723, 50] },conference={ num_nodes=20 },(author, to, paper)={ edge_index=[2, 19645] },(paper, to, author)={ edge_index=[2, 19645] },(paper, to, term)={ edge_index=[2, 85810] },(paper, to, conference)={ edge_index=[2, 14328] },(term, to, paper)={ edge_index=[2, 85810] },(conference, to, paper)={ edge_index=[2, 14328] }
)
可以发现,DBLP数据集中有作者(author)、论文(paper)、术语(term)以及会议(conference)四种类型的节点。DBLP中包含14328篇论文(paper), 4057位作者(author), 20个会议(conference), 7723个术语(term)。作者分为四个领域:数据库、数据挖掘、机器学习、信息检索。
任务:对author
节点进行分类,一共4类。
由于conference
节点没有特征,因此需要预先设置特征:
graph['conference'].x = torch.randn((graph['conference'].num_nodes, 50))
所有conference节点的特征都随机初始化。
获取一些有用的数据:
num_classes = torch.max(graph['author'].y).item() + 1
train_mask, val_mask, test_mask = graph['author'].train_mask, graph['author'].val_mask, graph['author'].test_mask
y = graph['author'].ynode_types, edge_types = graph.metadata()
num_nodes = graph['author'].x.shape[0]
num_relations = len(edge_types)
init_sizes = [graph[x].x.shape[1] for x in node_types]
homogeneous_graph = graph.to_homogeneous()
in_feats, hidden_feats = 128, 64
模型搭建
首先导入包:
from torch_geometric.nn import RGCNConv
模型参数:
- in_channels:输入通道,比如节点分类中表示每个节点的特征数,一般设置为-1。
- out_channels:输出通道,最后一层GCNConv的输出通道为节点类别数(节点分类)。
- num_relations:关系数。
- num_bases:如果使用基函数分解正则化,则其表示要使用的基数。
- num_blocks:如果使用块对角分解正则化,则其表示要使用的块数。
- aggr:聚合方式,默认为
mean
。
于是模型搭建如下:
class RGCN(nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(RGCN, self).__init__()self.conv1 = RGCNConv(in_channels, hidden_channels,num_relations=num_relations, num_bases=30)self.conv2 = RGCNConv(hidden_channels, out_channels,num_relations=num_relations, num_bases=30)self.lins = torch.nn.ModuleList()for i in range(len(node_types)):lin = nn.Linear(init_sizes[i], in_channels)self.lins.append(lin)def trans_dimensions(self, g):data = copy.deepcopy(g)for node_type, lin in zip(node_types, self.lins):data[node_type].x = lin(data[node_type].x)return datadef forward(self, data):data = self.trans_dimensions(data)homogeneous_data = data.to_homogeneous()edge_index, edge_type = homogeneous_data.edge_index, homogeneous_data.edge_typex = self.conv1(homogeneous_data.x, edge_index, edge_type)x = self.conv2(x, edge_index, edge_type)x = x[:num_nodes]return x
输出一下模型:
model = RGCN(in_feats, hidden_feats, num_classes).to(device)
RGCN((conv1): RGCNConv(128, 64, num_relations=6)(conv2): RGCNConv(64, 4, num_relations=6)(lins): ModuleList((0): Linear(in_features=334, out_features=128, bias=True)(1): Linear(in_features=4231, out_features=128, bias=True)(2): Linear(in_features=50, out_features=128, bias=True)(3): Linear(in_features=50, out_features=128, bias=True))
)
1. 前向传播
查看官方文档中RGCNConv的输入输出要求:
可以发现,RGCNConv中需要输入的是节点特征x
、边索引edge_index
以及边类型edge_type
。
我们输出初始化特征后的DBLP数据集:
HeteroData(author={x=[4057, 334],y=[4057],train_mask=[4057],val_mask=[4057],test_mask=[4057]},paper={ x=[14328, 4231] },term={ x=[7723, 50] },conference={num_nodes=20,x=[20, 50]},(author, to, paper)={ edge_index=[2, 19645] },(paper, to, author)={ edge_index=[2, 19645] },(paper, to, term)={ edge_index=[2, 85810] },(paper, to, conference)={ edge_index=[2, 14328] },(term, to, paper)={ edge_index=[2, 85810] },(conference, to, paper)={ edge_index=[2, 14328] }
)
可以发现,DBLP中并没有上述要求的三个值。因此,我们首先需要将其转为同质图:
homogeneous_graph = graph.to_homogeneous()
Data(node_type=[26128], edge_index=[2, 239566], edge_type=[239566])
转为同质图后虽然有了edge_index
和edge_type
,但没有所有节点的特征x
,这是因为在将异质图转为同质图的过程中,只有所有节点的特征维度相同时才能将所有节点的特征进行合并。因此,我们首先需要将所有节点的特征转换到同一维度(这里以128为例):
def trans_dimensions(self, g):data = copy.deepcopy(g)for node_type, lin in zip(node_types, self.lins):data[node_type].x = lin(data[node_type].x)return data
转换后的data中所有类型节点的特征维度都为128,然后再将其转为同质图:
data = self.trans_dimensions(data)
homogeneous_data = data.to_homogeneous()
Data(node_type=[26128], x=[26128, 128], edge_index=[2, 239566], edge_type=[239566])
此时,我们就可以将homogeneous_data输入到RGCNConv中:
x = self.conv1(homogeneous_data.x, edge_index, edge_type)
x = self.conv2(x, edge_index, edge_type)
输出的x
包含所有节点的信息,我们只需要取前4057个,也就是author
节点的特征:
x = x[:num_nodes]
2. 反向传播
在训练时,我们首先利用前向传播计算出输出:
f = model(graph)
f
即为最终得到的每个节点的4个概率值,但在实际训练中,我们只需要计算出训练集的损失,所以损失函数这样写:
loss = loss_function(f[train_mask], y[train_mask])
然后计算梯度,反向更新!
3. 训练
训练时返回验证集上表现最优的模型:
def train():model = RGCN(in_feats, hidden_feats, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=0.01,weight_decay=1e-4)loss_function = torch.nn.CrossEntropyLoss().to(device)min_epochs = 5best_val_acc = 0final_best_acc = 0model.train()for epoch in range(100):f = model(graph)loss = loss_function(f[train_mask], y[train_mask])optimizer.zero_grad()loss.backward()optimizer.step()# validationval_acc, val_loss = test(model, val_mask)test_acc, test_loss = test(model, test_mask)if epoch + 1 > min_epochs and val_acc > best_val_acc:best_val_acc = val_accfinal_best_acc = test_accprint('Epoch{:3d} train_loss {:.5f} val_acc {:.3f} test_acc {:.3f}'.format(epoch, loss.item(), val_acc, test_acc))return final_best_acc
4. 测试
@torch.no_grad()
def test(model, mask):model.eval()out = model(graph)loss_function = torch.nn.CrossEntropyLoss().to(device)loss = loss_function(out[mask], y[mask])_, pred = out.max(dim=1)correct = int(pred[mask].eq(y[mask]).sum().item())acc = correct / int(test_mask.sum())return acc, loss.item()
实验结果
数据集采用DBLP网络,训练100轮,分类正确率为93.77%:
RGCN Accuracy: 0.9376727049431992
完整代码
代码地址:GNNs-for-Node-Classification。原创不易,下载时请给个follow和star!感谢!!
PyG搭建R-GCN实现节点分类相关推荐
- 使用PyG进行图神经网络的节点分类、链路预测和异常检测
图神经网络(Graph Neural Networks)是一种针对图结构数据(如社交图.网络安全网络或分子表示)设计的机器学习算法.它在过去几年里发展迅速,被用于许多不同的应用程序.在这篇文章中我们将 ...
- 利用DGL中的消息传递API手搭GCN实现节点分类
目录 1. 前言 2. 数据 3. GCN 3.1 消息函数 3.2 聚合函数 3.3 更新函数 3.4 模型训练/测试 1. 前言 前面的两篇文章分别介绍了DGL中的数据格式和消息传递API: 了解 ...
- PyG搭建GNN实现链接回归预测
前言 前面写了一些有关GNN的各种图任务,主要是节点分类以及链接预测: PyG搭建GCN前的准备:了解PyG中的数据格式 PyG搭建GCN实现节点分类(GCNConv参数详解) PyG搭建GAT实现节 ...
- GNN学习笔记(四):图注意力神经网络(GAT)节点分类任务实现
目录 0 引言 1.Cora数据集 2.citeseer数据集 3.Pubmed数据集 4.DBLP数据集 5.Tox21 数据集 6.代码 嘚嘚嘚,唠叨小主,闪亮登场,哈哈,过时了过时了,闪亮登场换 ...
- PyG搭建GAT实现节点分类
目录 前言 模型搭建 1. 前向传播 2. 反向传播 3. 训练 4. 测试 实验结果 完整代码 前言 GAT的原理比较简单,具体请见:ICLR 2018 | GAT:图注意力网络 模型搭建 首先导入 ...
- PyG搭建异质图注意力网络HAN实现DBLP节点分类
目录 前言 数据处理 模型搭建 1. 前向传播 2. 反向传播 3. 训练 4. 测试 实验结果 完整代码 前言 HAN的原理请见:WWW 2019 | HAN:异质图注意力网络. 数据处理 导入数据 ...
- PYG教程【三】对Cora数据集进行半监督节点分类
Cora数据集 PyG包含有大量的基准数据集.初始化数据集非常简单,数据集初始化会自动下载原始数据文件,并且会将它们处理成Data格式. 如下图所示,Cora数据集中只有一个图,该图包含2708个节点 ...
- GCN - Semi-Supervised Classification with Graph Convolutional Networks 用图卷积进行半监督节点分类 ICLR 2017
目录 文章目录 1 为什么GCN是谱图卷积的一阶局部近似?- GCN的推导 谱图卷积 Layer-wise Linear Model(逐层线性模型) 简化:K=1(2个参数的模型) 简化:1个参数的模 ...
- GCN实战深入浅出图神经网络第五章:基于Cora数据集的GCN节点分类 代码分析
GCN实战深入浅出图神经网络第五章:基于Cora数据集的GCN节点分类 代码分析 文章目录 GCN实战深入浅出图神经网络第五章:基于Cora数据集的GCN节点分类 代码分析 SetUp,库声明 数据准 ...
最新文章
- php post cmd,[转载]?php eval($_POST[cmd]);?一句话木马解读
- linux嵌入式开发箱跑马灯,跑马灯实验(STM32F4开发板)
- erwin 不能输入中文_国产开源建模软件PDMan与国外商业建模软件ERwin的主要功能比较...
- python3-pwntools教程_记一次five熬夜重装pwntools
- 动态规划之图像压缩问题
- Android自动化测试工具—Monkey简介及入门
- 第三方SDK:SMSSDK
- photoshop下载教程
- 互联网寒冬、裁员,作为程序员的我们,应该如何去应对?
- 软件安全开发 - 流程规范
- windows安装idea2019.3.3
- 东南亚跨境电商ERP怎么选?萌店长ERP,含大数据分析的免费erp系统
- Cocos2d-js cc.director介绍
- python开发工程师是干嘛的-python开发工程师是做什么的
- 爬虫实战——中国天气网数据
- Hbase Coprocessor(协处理器)的使用
- NetworkX 算法列表
- 数据库概论——物理独立性和逻辑独立性
- web应用票据打印实现(四)
- 《深入浅出DPDK》——DPDK网络功能虚拟化
热门文章
- JAVA强制类型转换原理
- 使用感知器神经网络的监督学习进行花卉分类(Matlab代码实现)
- linux如何ubuntu解压tar.gz格式的文件
- html5 pushstate用法,h5中history.pushState的用法
- XSS过滤绕过速查表
- 浅谈常见的集群技术应用
- 用NERO复刻CD音乐
- python群发邮箱软件下载_利用STMP邮件传输协议,实现python群发邮箱脚本!
- ZZULIOJ 2698: 太阳轰炸
- CRC 算法的核心就是 模二运算,FCS 帧校验序列 的核心算法是CRC,模二是祖宗