一、Cluster-GCN

论文 Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Network 针对普通训练方法无法训练超大图的问题,提出了解决方法:

  • 利用图节点聚类算法将一个图的节点划分为ccc个簇,每一次选择几个簇的节点和这些节点对应的边构成一个子图,然后对子图做训练。
  • 由于是利用图节点聚类算法将节点划分为多个簇,所以簇内边的数量要比簇间边的数量多得多,所以可以提高表征利用率,并提高图神经网络的训练效率。
  • 每一次随机选择多个簇来组成一个batch,这样不会丢失簇间的边,同时也不会有batch内类别分布偏差过大的问题。
  • 基于小图进行训练,不会消耗很多内存空间,于是我们可以训练更深的神经网络,进而可以达到更高的精度。

二、Cluster-GCN实践

加载数据集

from torch_geometric.datasets import Reddit
from torch_geometric.data import ClusterData, ClusterLoader, NeighborSamplerdataset = Reddit('../dataset/Reddit')
data = dataset[0]
print(dataset.num_classes)
print(data.num_nodes)
print(data.num_edges)
print(data.num_features)

图节点聚类与数据加载器生成

cluster_data = ClusterData(data, num_parts=1500, recursive=False, save_dir=dataset.processed_dir)
# 此数据加载器返回的一个batch由多个簇组成
train_loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True, num_workers=12)
# 使用此数据加载器对图节点聚类
subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1], batch_size=1024, shuffle=False, num_workers=12)

图神经网络的构建

class Net(torch.nn.Module):def __init__(self, in_channels, out_channels):super(Net, self).__init__()self.convs = ModuleList([SAGEConv(in_channels, 128),SAGEConv(128, out_channels)])def forward(self, x, edge_index):for i, conv in enumerate(self.convs):x = conv(x, edge_index)if i != len(self.convs) - 1:x = F.relu(x)x = F.dropout(x, p=0.5, training=self.training)return F.log_softmax(x, dim=-1)# inference方法用于推理阶段,获取更高的预测精度def inference(self, x_all):pbar = tqdm(total=x_all.size(0) * len(self.convs))pbar.set_description('Evaluating')# Compute representations of nodes layer by layer, using *all*# available edges. This leads to faster computation in contrast to# immediately computing the final representations of each batch.for i, conv in enumerate(self.convs):xs = []for batch_size, n_id, adj in subgraph_loader:edge_index, _, size = adj.to(device)x = x_all[n_id].to(device)x_target = x[:size[1]]x = conv((x, x_target), edge_index)if i != len(self.convs) - 1:x = F.relu(x)xs.append(x.cpu())pbar.update(batch_size)x_all = torch.cat(xs, dim=0)pbar.close()return x_all

训练、验证与测试

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)def train():model.train()total_loss = total_nodes = 0for batch in train_loader:batch = batch.to(device)optimizer.zero_grad()out = model(batch.x, batch.edge_index)loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])loss.backward()optimizer.step()nodes = batch.train_mask.sum().item()total_loss += loss.item() * nodestotal_nodes += nodesreturn total_loss / total_nodes@torch.no_grad()
def test():  # Inference should be performed on the full graph.model.eval()out = model.inference(data.x)y_pred = out.argmax(dim=-1)accs = []for mask in [data.train_mask, data.val_mask, data.test_mask]:correct = y_pred[mask].eq(data.y[mask]).sum().item()accs.append(correct / mask.sum().item())return accsfor epoch in range(1, 31):loss = train()if epoch % 5 == 0:train_acc, val_acc, test_acc = test()print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, 'f'Val: {val_acc:.4f}, test: {test_acc:.4f}')else:print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')

原文地址

超大图上的节点表征学习相关推荐

  1. Datawhale 6月学习——图神经网络:超大图上的节点表征学习

    前情回顾 图神经网络:图数据表示及应用 图神经网络:消息传递图神经网络 图神经网络:基于GNN的节点表征学习 图神经网络:基于GNN的节点预测任务及边预测任务 1 超大图上的节点表征学习 1.1 简述 ...

  2. Datawhale 图神经网络 Task05 超大图上的节点表征学习

    学习课程:gitee_Datawhale_GNN 学习论坛:Datawhale CLUB 公众号:Datawhale 本次学习的内容是有关于超大图的,具体的论文是Cluster-GCN: An Eff ...

  3. 图神经网络基础--基于图神经网络的节点表征学习

    图神经网络基础–基于图神经网络的节点表征学习 引言 在图节点预测或边预测任务中,首先需要生成节点表征(Node Representation).我们使用图神经网络来生成节点表征,并通过基于监督学习的对 ...

  4. 图神经网络/GNN(三)-基于图神经网络的节点表征学习

    Task3概览: 在图任务当中,首要任务就是要生成节点特征,同时高质量的节点表征也是用于下游机器学习任务的前提所在.本次任务通过GNN来生成节点表征,并通过基于监督学习对GNN的训练,使得GNN学会产 ...

  5. 基于图神经网络的节点表征学习

    我们使用图神经网络来生成节点表征,并通过基于监督学习的对图神经网络的训练,使得图神经网络学会产生高质量的节点表征.高质量的节点表征能够用于衡量节点的相似性,同时高质量的节点表征也是准确分类节点的前提. ...

  6. 图神经网络GNN(三):基于图神经网络的节点表征学习

    1. 写在前面 这个系列整理的关于GNN的相关基础知识, 图深度学习是一个新兴的研究领域,将深度学习与图数据连接了起来,推动现实中图预测应用的发展. 之前一直想接触这一块内容,但总找不到能入门的好方法 ...

  7. 节点表征学习与节点预测和边预测

    基于图神经网络的节点表征学习 引言 在图节点预测或边预测任务中,需要先构造节点表征(representation),节点表征是图节点预测和边预测任务成功的关键.在此篇文章中,我们将学习如何基于图神经网 ...

  8. NeurIPS 2021 | 图上不均衡表示学习新视野:基于拓扑结构的不均衡学习

    论文标题: Topology-Imbalance Learning for Semi-Supervised Node Classification 论文链接: https://arxiv.org/ab ...

  9. 腾讯游戏自研学术成果:基于图分割的网络表征学习初始化技术

    图是一种通用的数据表现形式,图算法逐渐在大数据处理中展现其价值.网络表征学习算法作为目前比较主流的一种图数据处理算法,引起学术界和工业界的极大兴趣. 本文介绍了 IEG 在网络表征学习方面的一个自研学 ...

最新文章

  1. 解决hal.dll丢失问题 调试方法启动XP
  2. rest-framework 响应器(渲染器)
  3. C++:录入班级数学成绩,计算最大值、平均值、不及格人数
  4. 引用参考文献_引用参考文献时应注意些什么
  5. Winform 自定义窗体皮肤组件
  6. 使用 Apache Lucene 搜索文本——轻松为应用程序构建搜索和索引功能
  7. C语言宏使用常见问题
  8. 在Java 8中使用Stream API解析文件
  9. 别再Prompt了!谷歌提出tuning新方法,强力释放GPT-3潜力!
  10. cad二次开发 java_应用Java语言进行AutoCAD2000二次开发.PDF
  11. vue设置多选框默认勾选_vue中复选框怎么默认全选,至少选择4个才可以点击下一步...
  12. Intel保护机制:特权级别:Protection Rings
  13. win7桌面图标计算机打不开,win7系统下双击电脑桌面图标打不开的解决方法
  14. 如何使用iMazing备份、恢复《暴力飞车》游戏存档
  15. 数学建模2-美国人口增长模型的确定
  16. 通过命名空间快速定位SAP UI5工程名称
  17. MFC选择文件和文件夹对话框
  18. 电子科技大学信通2018级学生上早自习缺勤率情况分析
  19. 勇敢的心——感动内德
  20. vim显示空格和tab符号

热门文章

  1. request技巧-utils的功能-cookie对象与字典的转换-URL编码与解码-关掉SSL验证
  2. 草稿 爬虫-访问登陆可见的页面-请求时带上cookie数据
  3. python-虚拟环境的创建与使用-针对linu系统
  4. (原創) 如何将字符串前后的空白去除? (使用string.find_first_not_of, string.find_last_not_of) (C/C++)...
  5. 阿里云服务器ECS挑选什么样的网站环境
  6. AC日记——Count on a tree bzoj 2588
  7. android模拟按键问题总结[使用IWindowManager.injectKeyEvent方法](转)
  8. oracle的redo与undio
  9. 【转】Unix环境高级程序设计入门----文件系统的相关编程(上)
  10. 面试:Java分派机制