目录

  • 前言
  • 数据处理
  • 模型搭建
    • 1. 前向传播
    • 2. 反向传播
    • 3. 训练
    • 4. 测试
  • 实验结果
  • 完整代码

前言

R-GCN的原理请见:ESWC 2018 | R-GCN:基于图卷积网络的关系数据建模。

数据处理

导入数据:

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'DBLP')
dataset = DBLP(path)
graph = dataset[0]
print(graph)

输出如下:

HeteroData(author={x=[4057, 334],y=[4057],train_mask=[4057],val_mask=[4057],test_mask=[4057]},paper={ x=[14328, 4231] },term={ x=[7723, 50] },conference={ num_nodes=20 },(author, to, paper)={ edge_index=[2, 19645] },(paper, to, author)={ edge_index=[2, 19645] },(paper, to, term)={ edge_index=[2, 85810] },(paper, to, conference)={ edge_index=[2, 14328] },(term, to, paper)={ edge_index=[2, 85810] },(conference, to, paper)={ edge_index=[2, 14328] }
)

可以发现,DBLP数据集中有作者(author)、论文(paper)、术语(term)以及会议(conference)四种类型的节点。DBLP中包含14328篇论文(paper), 4057位作者(author), 20个会议(conference), 7723个术语(term)。作者分为四个领域:数据库、数据挖掘、机器学习、信息检索。

任务:对author节点进行分类,一共4类。

由于conference节点没有特征,因此需要预先设置特征:

graph['conference'].x = torch.randn((graph['conference'].num_nodes, 50))

所有conference节点的特征都随机初始化。

获取一些有用的数据:

num_classes = torch.max(graph['author'].y).item() + 1
train_mask, val_mask, test_mask = graph['author'].train_mask, graph['author'].val_mask, graph['author'].test_mask
y = graph['author'].ynode_types, edge_types = graph.metadata()
num_nodes = graph['author'].x.shape[0]
num_relations = len(edge_types)
init_sizes = [graph[x].x.shape[1] for x in node_types]
homogeneous_graph = graph.to_homogeneous()
in_feats, hidden_feats = 128, 64

模型搭建

首先导入包:

from torch_geometric.nn import RGCNConv

模型参数:

  1. in_channels:输入通道,比如节点分类中表示每个节点的特征数,一般设置为-1。
  2. out_channels:输出通道,最后一层GCNConv的输出通道为节点类别数(节点分类)。
  3. num_relations:关系数。
  4. num_bases:如果使用基函数分解正则化,则其表示要使用的基数。
  5. num_blocks:如果使用块对角分解正则化,则其表示要使用的块数。
  6. aggr:聚合方式,默认为mean

于是模型搭建如下:

class RGCN(nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(RGCN, self).__init__()self.conv1 = RGCNConv(in_channels, hidden_channels,num_relations=num_relations, num_bases=30)self.conv2 = RGCNConv(hidden_channels, out_channels,num_relations=num_relations, num_bases=30)self.lins = torch.nn.ModuleList()for i in range(len(node_types)):lin = nn.Linear(init_sizes[i], in_channels)self.lins.append(lin)def trans_dimensions(self, g):data = copy.deepcopy(g)for node_type, lin in zip(node_types, self.lins):data[node_type].x = lin(data[node_type].x)return datadef forward(self, data):data = self.trans_dimensions(data)homogeneous_data = data.to_homogeneous()edge_index, edge_type = homogeneous_data.edge_index, homogeneous_data.edge_typex = self.conv1(homogeneous_data.x, edge_index, edge_type)x = self.conv2(x, edge_index, edge_type)x = x[:num_nodes]return x

输出一下模型:

model = RGCN(in_feats, hidden_feats, num_classes).to(device)
RGCN((conv1): RGCNConv(128, 64, num_relations=6)(conv2): RGCNConv(64, 4, num_relations=6)(lins): ModuleList((0): Linear(in_features=334, out_features=128, bias=True)(1): Linear(in_features=4231, out_features=128, bias=True)(2): Linear(in_features=50, out_features=128, bias=True)(3): Linear(in_features=50, out_features=128, bias=True))
)

1. 前向传播

查看官方文档中RGCNConv的输入输出要求:

可以发现,RGCNConv中需要输入的是节点特征x、边索引edge_index以及边类型edge_type

我们输出初始化特征后的DBLP数据集:

HeteroData(author={x=[4057, 334],y=[4057],train_mask=[4057],val_mask=[4057],test_mask=[4057]},paper={ x=[14328, 4231] },term={ x=[7723, 50] },conference={num_nodes=20,x=[20, 50]},(author, to, paper)={ edge_index=[2, 19645] },(paper, to, author)={ edge_index=[2, 19645] },(paper, to, term)={ edge_index=[2, 85810] },(paper, to, conference)={ edge_index=[2, 14328] },(term, to, paper)={ edge_index=[2, 85810] },(conference, to, paper)={ edge_index=[2, 14328] }
)

可以发现,DBLP中并没有上述要求的三个值。因此,我们首先需要将其转为同质图:

homogeneous_graph = graph.to_homogeneous()
Data(node_type=[26128], edge_index=[2, 239566], edge_type=[239566])

转为同质图后虽然有了edge_indexedge_type,但没有所有节点的特征x,这是因为在将异质图转为同质图的过程中,只有所有节点的特征维度相同时才能将所有节点的特征进行合并。因此,我们首先需要将所有节点的特征转换到同一维度(这里以128为例):

def trans_dimensions(self, g):data = copy.deepcopy(g)for node_type, lin in zip(node_types, self.lins):data[node_type].x = lin(data[node_type].x)return data

转换后的data中所有类型节点的特征维度都为128,然后再将其转为同质图:

data = self.trans_dimensions(data)
homogeneous_data = data.to_homogeneous()
Data(node_type=[26128], x=[26128, 128], edge_index=[2, 239566], edge_type=[239566])

此时,我们就可以将homogeneous_data输入到RGCNConv中:

x = self.conv1(homogeneous_data.x, edge_index, edge_type)
x = self.conv2(x, edge_index, edge_type)

输出的x包含所有节点的信息,我们只需要取前4057个,也就是author节点的特征:

x = x[:num_nodes]

2. 反向传播

在训练时,我们首先利用前向传播计算出输出:

f = model(graph)

f即为最终得到的每个节点的4个概率值,但在实际训练中,我们只需要计算出训练集的损失,所以损失函数这样写:

loss = loss_function(f[train_mask], y[train_mask])

然后计算梯度,反向更新!

3. 训练

训练时返回验证集上表现最优的模型:

def train():model = RGCN(in_feats, hidden_feats, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=0.01,weight_decay=1e-4)loss_function = torch.nn.CrossEntropyLoss().to(device)min_epochs = 5best_val_acc = 0final_best_acc = 0model.train()for epoch in range(100):f = model(graph)loss = loss_function(f[train_mask], y[train_mask])optimizer.zero_grad()loss.backward()optimizer.step()# validationval_acc, val_loss = test(model, val_mask)test_acc, test_loss = test(model, test_mask)if epoch + 1 > min_epochs and val_acc > best_val_acc:best_val_acc = val_accfinal_best_acc = test_accprint('Epoch{:3d} train_loss {:.5f} val_acc {:.3f} test_acc {:.3f}'.format(epoch, loss.item(), val_acc, test_acc))return final_best_acc

4. 测试

@torch.no_grad()
def test(model, mask):model.eval()out = model(graph)loss_function = torch.nn.CrossEntropyLoss().to(device)loss = loss_function(out[mask], y[mask])_, pred = out.max(dim=1)correct = int(pred[mask].eq(y[mask]).sum().item())acc = correct / int(test_mask.sum())return acc, loss.item()

实验结果

数据集采用DBLP网络,训练100轮,分类正确率为93.77%:

RGCN Accuracy: 0.9376727049431992

完整代码

代码地址:GNNs-for-Node-Classification。原创不易,下载时请给个follow和star!感谢!!

PyG搭建R-GCN实现节点分类相关推荐

  1. 使用PyG进行图神经网络的节点分类、链路预测和异常检测

    图神经网络(Graph Neural Networks)是一种针对图结构数据(如社交图.网络安全网络或分子表示)设计的机器学习算法.它在过去几年里发展迅速,被用于许多不同的应用程序.在这篇文章中我们将 ...

  2. 利用DGL中的消息传递API手搭GCN实现节点分类

    目录 1. 前言 2. 数据 3. GCN 3.1 消息函数 3.2 聚合函数 3.3 更新函数 3.4 模型训练/测试 1. 前言 前面的两篇文章分别介绍了DGL中的数据格式和消息传递API: 了解 ...

  3. PyG搭建GNN实现链接回归预测

    前言 前面写了一些有关GNN的各种图任务,主要是节点分类以及链接预测: PyG搭建GCN前的准备:了解PyG中的数据格式 PyG搭建GCN实现节点分类(GCNConv参数详解) PyG搭建GAT实现节 ...

  4. GNN学习笔记(四):图注意力神经网络(GAT)节点分类任务实现

    目录 0 引言 1.Cora数据集 2.citeseer数据集 3.Pubmed数据集 4.DBLP数据集 5.Tox21 数据集 6.代码 嘚嘚嘚,唠叨小主,闪亮登场,哈哈,过时了过时了,闪亮登场换 ...

  5. PyG搭建GAT实现节点分类

    目录 前言 模型搭建 1. 前向传播 2. 反向传播 3. 训练 4. 测试 实验结果 完整代码 前言 GAT的原理比较简单,具体请见:ICLR 2018 | GAT:图注意力网络 模型搭建 首先导入 ...

  6. PyG搭建异质图注意力网络HAN实现DBLP节点分类

    目录 前言 数据处理 模型搭建 1. 前向传播 2. 反向传播 3. 训练 4. 测试 实验结果 完整代码 前言 HAN的原理请见:WWW 2019 | HAN:异质图注意力网络. 数据处理 导入数据 ...

  7. PYG教程【三】对Cora数据集进行半监督节点分类

    Cora数据集 PyG包含有大量的基准数据集.初始化数据集非常简单,数据集初始化会自动下载原始数据文件,并且会将它们处理成Data格式. 如下图所示,Cora数据集中只有一个图,该图包含2708个节点 ...

  8. GCN - Semi-Supervised Classification with Graph Convolutional Networks 用图卷积进行半监督节点分类 ICLR 2017

    目录 文章目录 1 为什么GCN是谱图卷积的一阶局部近似?- GCN的推导 谱图卷积 Layer-wise Linear Model(逐层线性模型) 简化:K=1(2个参数的模型) 简化:1个参数的模 ...

  9. GCN实战深入浅出图神经网络第五章:基于Cora数据集的GCN节点分类 代码分析

    GCN实战深入浅出图神经网络第五章:基于Cora数据集的GCN节点分类 代码分析 文章目录 GCN实战深入浅出图神经网络第五章:基于Cora数据集的GCN节点分类 代码分析 SetUp,库声明 数据准 ...

最新文章

  1. php post cmd,[转载]?php eval($_POST[cmd]);?一句话木马解读
  2. linux嵌入式开发箱跑马灯,跑马灯实验(STM32F4开发板)
  3. erwin 不能输入中文_国产开源建模软件PDMan与国外商业建模软件ERwin的主要功能比较...
  4. python3-pwntools教程_记一次five熬夜重装pwntools
  5. 动态规划之图像压缩问题
  6. Android自动化测试工具—Monkey简介及入门
  7. 第三方SDK:SMSSDK
  8. photoshop下载教程
  9. 互联网寒冬、裁员,作为程序员的我们,应该如何去应对?
  10. 软件安全开发 - 流程规范
  11. windows安装idea2019.3.3
  12. 东南亚跨境电商ERP怎么选?萌店长ERP,含大数据分析的免费erp系统
  13. Cocos2d-js cc.director介绍
  14. python开发工程师是干嘛的-python开发工程师是做什么的
  15. 爬虫实战——中国天气网数据
  16. Hbase Coprocessor(协处理器)的使用
  17. NetworkX 算法列表
  18. 数据库概论——物理独立性和逻辑独立性
  19. web应用票据打印实现(四)
  20. 《深入浅出DPDK》——DPDK网络功能虚拟化

热门文章

  1. JAVA强制类型转换原理
  2. 使用感知器神经网络的监督学习进行花卉分类(Matlab代码实现)
  3. linux如何ubuntu解压tar.gz格式的文件
  4. html5 pushstate用法,h5中history.pushState的用法
  5. XSS过滤绕过速查表
  6. 浅谈常见的集群技术应用
  7. 用NERO复刻CD音乐
  8. python群发邮箱软件下载_利用STMP邮件传输协议,实现python群发邮箱脚本!
  9. ZZULIOJ 2698: 太阳轰炸
  10. CRC 算法的核心就是 模二运算,FCS 帧校验序列 的核心算法是CRC,模二是祖宗