本文主要分析Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Network论文中提出的新的图神经网络模型以及新的训练图神经网络的训练方法。
一些经典模型如 GCN 采用了 full-batch 的 SGD 优化算法,要计算整个梯度则需要存储所有中间的 Embedding,因此,其是不可扩展的。此外,虽然每个 epoch 也只能更新一次参数。
GraphSAGE 中提出 mini-batch 的 SGD 优化方法,由于每次更新只基于一个 mini-batch,所以内存的需求降低,并在每个 epoch 中可以进行多次更新,从而收敛速度更快。然而,随着层数加深,每个节点的感受野越来越大,其计算单个节点的计算开销也会越来越大。针对这个问题,GraphSAGE 通过使用固定大小的邻居采样,同时 FastGCN 的重要性采样可以一定程度上解决计算开销,但是随着 GCN 的深度加深,计算开销问题依然没法解决。VR-GCN 提出利用方差来控制邻居的采样节点,尽管减少了采样的大小,但是它需要将所有节点的中间 Embedding 存储于内存中,导致其可扩展性较差。
下表是不同模型的时间复杂度和空间复杂度:

作者在实验中发现mini-batch的算法效率与batch内节点与batch外节点间的连接数量成正比,针对这一现象,作者构建了节点的分区,使同一分区中的节点之间的图连接于不同分区中的节点之间的图连接更多。

为了解决普通方法无法在超大图上做节点表征学习的问题,Cluster-GCN论文提出:

  • 利用图节点聚类算法将一个图的节点划分为c个簇,每一次选择几个组的节点和这些节点对应的边构成一个子图,然后对子图做训练。
  • 由于是利用图节点聚类算法将节点划分为多个簇,所以簇内边的数量要比簇间边的数量多得多。
  • 基于小图进行训练,不会消耗很多内存空间,于是我们可以训练更深的神经网络,从而得到更高的精度。
    我们知道,基于mini-batch的SGD可以在单个epoch中更新多次,从而使得其比full batch具有更快的收敛速度,但是前者每个epoch所花的时间都更长。

原理

对于一个图G而言,将其分为c组,其中只包含组内节点之间的边,对节点进行重组后,邻接矩阵被划分为c的平方个子矩阵,即

其中

对角线每个块都是大小为|Vt|x|Vt|的邻接矩阵,它由Gt内的边构成。 ∆是由A的所有非对角线块组成的矩阵。Xt和Yt分别由Vt中节点的特征和标签组成。
划分簇的意义:

  • 对于每个batch而言,Embedding utilization相当于簇内的连接,每个节点及其相邻节点通常位于同一簇内,因此经过几次后跳跃后,邻接节点大概率还是在簇内;
  • 我们使用它的块对角线近似值来替换了原来的矩阵,并且误差与簇间的的连接成正比,所以需要使得簇间的连接数量尽可能少。
    下图为图的随机分区与聚类分区的对比:

随机多分类

尽管 vanilla Cluster-GCN 能够减少计算开销和内存开销,但仍然存在两个问题:

  • 图被分割后,原来图中的一些连接会被删除,影响性能。
  • 聚类后的分布与原始数据集有区别,从而导致 SGD 更新时有偏差。
    下图为 Reddit 数据集中标签分布不平衡的案例,通过每个簇的标签分布计算其熵值,与随机分割相比,可以清楚的看到聚类分区的簇的熵较小,这表明簇的标签分布偏向于某些特征的标签,所以这会增加不同 batch 的梯度更新的差异,并影响 SGD 的收敛性。

    为了解决这个问题,作者提出随机多聚类方法对簇进行合并,从而减少batch间的差异。
    作者首先将图分为多个小簇,然后随机选择q个簇并到batch中,这样可以减少batch之间的差异。
    下图展示了每个epoch随机组合的batch,相同颜色的块在同一数据batch中:

    两种方式对比,随机多聚类方法收敛速度更快:

    作者提出了一个简单的技术来改进深度 GCN 的训练,核心思想在于放大每个 GCN 层中使用的邻接矩阵 A 的对角部分,并通过这种方式在每个 GCN 层的聚合中对上一层的 Embedding 添加更多的权重:但这种方法有些问题,比如这种方法无视相邻节点的数量,而对所有节点使用相同的权重。此外,当层数增加时,其数值可能会呈现指数型爆炸。所以作者先对邻接矩阵进行标准化。

实现

1.数据集分析

dataset = Reddit('../dataset/Reddit')
data = dataset[0]
# print(dataset.num_classes)
# print(data.num_nodes)
# print(data.num_edges)
# print(data.num_features)

可以看到该数据集包含41个分类任务,232965个节点,114615873条边,节点维度为602维。
2.图节点聚类和数据加载器生成

cluster_data = ClusterData(data, num_parts=1500, recursive=False, save_dir=dataset.processed_dir)
train_loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True, num_workers=0)
subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1], batch_size=1024, shuffle=False, num_workers=0)

train_loader:图节点首先被聚类,返回的一个batch由多个簇组成
subgraph_loader:使用此数据加载器不对图节点聚类,计算一个batch中的节点的嵌入需要计算该batch中所有节点的距离从0到L的邻居节点。
3.构建神经网络

class Net(torch.nn.Module):def __init__(self, in_channels, out_channels):super(Net, self).__init__()self.convs = ModuleList([SAGEConv(in_channels, 128), SAGEConv(128, out_channels)])def forward(self, x, edge_index):for i,conv in enumerate(self.convs):x = conv(x, edge_index)if i != len(self.convs) - 1:x = F.relu(x)x = F.dropout(x, p=0.5, training=self.training)return F.log_softmax(x, dim=-1)def inference(self, x_all):pbar = tqdm(total=x_all.size(0)*len(self.convs))pbar.set_description('Evaluating')for i,conv in enumerate(self.convs):xs = []for batch_size, n_id, adj in subgraph_loader:edge_index, _, size = adj.to(device)x = x_all[n_id].to(device)x_target = x[:size[1]]x = conv((x, x_target), edge_index)if i != len(self.convs) - 1:x = F.relu(x)xs.append(x.cpu())pbar.update(batch_size)x_all = torch.cat(xs, dim=0)pbar.close()return x_all

4.最后我们对网络进行训练,5个epoch验证一次,结果如下:


loss从0.3358降到了0.2313,测试集的准确率大概到了94.56%左右。

作业

将数据集切分成不同数量的簇进行实验,观察结果并进行比较。
1.将数据集分成1000簇,最后的正确率为94.56%。

2.将数据集分成2000簇

参考资料

1.https://cloud.tencent.com/developer/article/1665684
2.超大图上的表征学习

组队学习-图神经网络(fifth)相关推荐

  1. Datawhale组队学习-图神经网络(五)

    Datawhale组队学习-图神经网络(五) 此内容出自Cluster-GCN的论文:Cluster-GCN: An Efficient Algorithm for Training Deep and ...

  2. Datawhale组队学习-图神经网络(四)

    Datawhale组队学习-图神经网络(四) 数据完全存于内存的数据集类 + 节点预测与边预测任务实践 对于占用内存有限的数据集,我们可以将整个数据集的数据都存储到内存里.PyG为我们提供了方便的方式 ...

  3. Datawhale 6月学习——图神经网络:超大图上的节点表征学习

    前情回顾 图神经网络:图数据表示及应用 图神经网络:消息传递图神经网络 图神经网络:基于GNN的节点表征学习 图神经网络:基于GNN的节点预测任务及边预测任务 1 超大图上的节点表征学习 1.1 简述 ...

  4. 深度学习-强化学习-图神经网络-自然语言处理等AI课程超级大列表-最新版

    本篇文章内容整理自网络,汇集了大量关于深度学习.强化学习.机器学习.计算机视觉.语音识别.强化学习.图神经网络和自然语言处理相关的各种课程.之前分享过一次,经过一年的更新,又补充了很多2019.202 ...

  5. 2020人工智能课程超级大列表:深度学习-强化学习-图神经网络-自然语言处理等...

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 本篇博文主要为大家介绍一个课程网站,汇集了机器学习,深度学习.强化学习的各个方面, ...

  6. 学习图神经网络相关内容

    本周学习情况 本周学习任务: 学习图神经网络相关内容 图基本知识(连通分量.度中心性.特征向量中心性.中介中心性.接近中心性.PageRank.HITS)并使用networkx包简单实践. 学习了相关 ...

  7. 图表示学习+图神经网络:破解AI黑盒,揭示万物奥秘的钥匙!

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 从电信网络到社交网络,从经济网络到生物医学网络--图结构的数据无处 ...

  8. 【赠书】图表示学习+图神经网络:破解AI黑盒,揭示万物奥秘的钥匙!

    ‍‍ 从电信网络到社交网络,从经济网络到生物医学网络--图结构的数据无处不在. 如何提取图的特征,表示或编码图的结构,基于图数据进行学习.推理和归纳变得越来越重要. 因为无论是进行数据挖掘.分析社交网 ...

  9. 【图神经网络实战】深入浅出地学习图神经网络GNN(上)

    文章目录 一.图神经网络应用领域 1.1 芯片设计 1.2 场景分析与问题推理 1.3 推荐系统 1.4 欺诈检测与风控相关 1.5 知识图谱 1.6 道路交通的流量预测 1.7 自动驾驶(无人机等场 ...

最新文章

  1. Softmax 函数及其作用(含推导)
  2. ecshop根目录调用_ECSHOP各文件夹功能说明
  3. 7 orm 有批量更新_2020.12.24更新公告
  4. 浏览器数据库 IndexedDB(一) 概述
  5. 输出矩阵的左下半三角
  6. 【iOS】获取应用程序本地路径
  7. Ubuntu server解决不能访问外网问题
  8. 软考 信息安全工程师(第二版)笔记-第1章 网络信息安全概述
  9. CSS - 盒子模型(下)
  10. 全新2021款 Jlink隔离器,ARM仿真器隔离,Jlink,Nu-link,ULINK的隔离,Cortex-M系列隔离仿真
  11. mysql基础学习--day7
  12. 金融工程python报告期权_金融工程专业详细解析
  13. 第十五讲:达索系统锂电池行业解决方案在线直播 | 达索系统百世慧
  14. 计算机大学生三好学生申请书,大学生三好学生申请书
  15. Flutter开发之——Card
  16. Android Studio安装与SDK配置
  17. Android平台移植FFmpeg和x264
  18. edx 4G Network Essentials 4 - Nodes of the control plane (HSS, MME)
  19. windows11任务栏全透明
  20. User-Agent大全

热门文章

  1. 决战Python之巅(十二)
  2. xman的思维导图快捷键_思维导图与xmind快捷键
  3. 边缘智能:边缘计算和人工智能的深度融合
  4. wps打印预览显示不全怎么解决?
  5. 【教程】如何使用Java生成PDF文档?
  6. 小心 laravel 模型的 Soft Delete
  7. 载荷是什么意思?底层原理是什么?
  8. CVE-2022-30190 MSDT远程代码执行漏洞复现
  9. 什么是php面向对象及面向对象的三大特性
  10. SQL Server时间粒度系列----第7节日历数据表详解