GNN Tricks《Bag of Tricks of Semi-Supervised Classification with Graph Neural Networks》
Wang Y. Bag of Tricks of Semi-Supervised Classification with Graph Neural Networks[J]. arXiv preprint arXiv:2103.13355, 2021.
我在浏览OGB排行榜代码的时候偶然发现了一篇关于GNN的Tricks的文章,作者是DGL Team的大佬,这篇貌似还没有被会议接受,不过已经在Arxiv上preprint出来了。本文改进后的几个模型在几个OGB数据集上的表现都不错。所以就赶快拿过来看看,学习一下,还是受到了一些启发。代码的话是用DGL框架和规范写的,现在还看不太懂,等我先把DGL框架学一下再好好拜读一下code。
本文配套的ogbn-arxiv代码:https://github.com/Espylapiza/dgl/tree/master/examples/pytorch/ogb/ogbn-arxiv
目前我按照作者的思路试了一下label use(PyG代码),并没有对我的模型(GCN内核)起到什么效果。ogbn-arxiv排行榜我的是13名,用了label use的模型是16名。。
PS:我记得当年CNN也有一篇《Bag of Tricks…》,很经典,可以参考《深度学习 cnn trick合集》。
文章目录
- Abstract
- 1. Introduction
- 2. Preliminary
- 2.1. Problem Formulation
- 2.2. Existing Tricks
- 3. Methodology
- 3.1. Label Usage
- 3.2. Loss Function
- 3.3. Architecture Design
- 4. Experiments
- 我的总结与思考
Abstract
作者表示,关于GNN的模型结构改进方法现在有不少,但是这些paper里常常会忽略模型实现的一些Tricks(可能觉得太low所以闭口不提),只有当你去看code的时候才会发现一些细节和Tricks。
但是,这些被忽略的技术/trick在GNN的实践中起到了十分重要的作用,并且往往伴随着模型精度的提升。
本文就提出了一些GNN中的新技巧,包括模型设计+标签使用等等。
1. Introduction
从OGB排行榜上可以十分直观的发现,模型精度的提升不仅仅依赖于模型(体系)结构的改变,也就是说并不一定要提出新的GNN模型,从技巧或Tricks上进行改进也能提升模型性能、发paper(比如本文,哈哈哈)。
作者发现,目前的GNN模型缺乏对节点标签信息的使用。虽然说最近出现了LPA相关的算法,比如LP、C&S,但是基于LP的算法的理论动机是要求相邻节点具有相似的label,但是在异构图中貌似并不是这样的,并且LPA算法也不能直接处理加权图。
本文提出的新技术主要涉及标签使用和架构设计,后面会详细说。
2. Preliminary
2.1. Problem Formulation
Label Propagation Algorithm.
LPA算法就是通过边/邻接关系,把节点的标签信息传播给邻居节点。(这里和文章中的不太一样,是加了残差传播的版本)
Y′=α⋅D−1/2AD−1/2Y+(1−α)Y\mathbf{Y}^{\prime} = \alpha \cdot \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2} \mathbf{Y} + (1 - \alpha) \mathbf{Y} Y′=α⋅D−1/2AD−1/2Y+(1−α)Y
Combination of Label and Feature Usages.
LP、C&S等方法可能会导致产生次优解,我在实践中也有体会,有时候用C&S,模型的性能反而会下降。
2.2. Existing Tricks
Data Augmentation.
数据增强方法。在之前的笔记中我也对这方面有所关注,主要是Dropout和FLAG。
Sampling.
采样方法。采样除了被应用于缩小每次训练的图规模(minibatch),还可以被当做是一种训练或正则化技巧。代表性方法有FastGCN(分层采用)、LADIES(层重要性采样)以及NLP(word2vec)当中的负采样技术。
Renormalization.
GCN中提出的重归一化,用于缓解数值不稳定和梯度爆炸。
3. Methodology
3.1. Label Usage
主要是将节点的标签也作为输入。由于节点的标签信息能够提供更多的信息,所以如果能利用好label的话理论上会对模型性能提升有很大帮助。
本文在label use方面提出的Trick主要是通过mask技术,将初始节点特征和label(经过mask后的)拼接后作为输入,以从标签信息中学习到更多的标签信息。
label use大体思路就是:(PyG代码,可自己体会,比较难描述)
def add_labels(feat, labels, idx):onehot = torch.zeros([feat.shape[0], dataset.num_classes]).to(device)onehot[idx, labels[idx, 0]] = 1return torch.cat([feat, onehot], dim=-1)
# 定义训练函数
def train():model.train()mask_rate = 0.5mask = torch.rand(train_idx.shape) < mask_ratetrain_labels_idx = train_idx[mask]train_pred_idx = train_idx[~mask]feat = add_labels(x, y, train_labels_idx)out = model(feat, data.adj_t)loss = criterion(out[train_pred_idx], data.y.squeeze(1)[train_pred_idx])# loss = cross_entropy(out[train_idx], data.y[train_idx])optimizer.zero_grad()loss.backward()optimizer.step()return loss.item()
Augmentation with Label Reuse.
label reuse等什么时候看了作者的code之后再说吧。主要思想是用上一次迭代的预测值来代替常数0。
3.2. Loss Function
主要是对CrossEntropy损失函数进行了小改动(CE–>LCE),并证明了LCE鲁棒性更好。
改进方面是在CE的基础上增加了超参数ϵ\epsilonϵ=1e-2。就不上公式了,pytorch代码更易懂一些。
epsilon = 1 - math.log(2)def cross_entropy(x, labels):y = F.cross_entropy(x, labels[:, 0], reduction="none")y = torch.log(epsilon + y) - math.log(epsilon)return torch.mean(y)
3.3. Architecture Design
Architecture Variant for GCN.
对于GCN内核的改进主要是仿照skip-connection,增加了一个linear层(公式里后面那一项),以保证每个节点的输出都不同,以缓解过平滑。
X(k+1)=σ((D~−12A~D~−12)X(k)W0(k)+X(k)W1(k))X^{(k+1)}=\sigma \left(\left(\tilde D^{-\frac12}\tilde A\tilde D^{-\frac12}\right)X^{(k)}W_0^{(k)}+X^{(k)}W_1^{(k)}\right) X(k+1)=σ((D~−21A~D~−21)X(k)W0(k)+X(k)W1(k))
def forward(self, graph, feat):h = feath = self.input_drop(h)for i in range(self.n_layers):conv = self.convs[i](graph, h)if self.use_linear:linear = self.linear[i](h)h = conv + linearelse:h = convif i < self.n_layers - 1:h = self.norms[i](h)h = self.activation(h)h = self.dropout(h)return h
Architecture Variant for GAT.
GAT的改动就比较大了,并且从实验结果来看,改进策略相当成功!不仅成功优化了训练时的内存占用,还能够和其他策略一起,提升模型的精度,值得我好好学习GAT的代码以及策略。
(我用PyG复现的GAT跑,16GB显卡会出现GPU内存溢出的问题,并且效果贼差,看来是时候好好学习一下DGL了)
从公式上看,改进后的GAT和GCN形式貌似差不多。
X(k+1)=σ((D~−12A~attD~−12)X(k)W0(k)+X(k)W1(k))X^{(k+1)}=\sigma \left(\left(\tilde D^{-\frac12}\tilde A_{att}\tilde D^{-\frac12}\right)X^{(k)}W_0^{(k)}+X^{(k)}W_1^{(k)}\right) X(k+1)=σ((D~−21A~attD~−21)X(k)W0(k)+X(k)W1(k))
当A~att=A\tilde A_{att}=AA~att=A时,GAT就退化成了GCN。A~att\tilde A_{att}A~att是归一化后的注意力矩阵,具体的实现还需要去仔细研究一下代码。
4. Experiments
作者主要在ogbn-arxiv、ogbn-products和ogbn-proteins这三个数据集上进行了实验,结果就不赘述了。
我的总结与思考
近期会研究一下Tricks以及OGB冲榜。看来这次DGL是不得不学了。
我的博客:我向OGB排行榜提交代码的经历。
GNN Tricks《Bag of Tricks of Semi-Supervised Classification with Graph Neural Networks》相关推荐
- 《Bag of Tricks for Node Classification with Graph Neural Networks》阅读笔记
论文地址:Bag of Tricks for Node Classification with Graph Neural Networks 一.概述 本文作者总结了前人关于图上半监督节点分类任务的常用 ...
- 复杂网络论文解析——《Finding Patient Zero:Learning Contagion Source with Graph Neural Networks》
本文为原创,转载需声明出处. 介绍最近看的一篇复杂网络研究流行病传染源的文章,<Finding Patient Zero: LearningContagion Source with Graph ...
- CV:翻译并解读2019《A Survey of the Recent Architectures of Deep Convolutional Neural Networks》第一章~第三章
CV:翻译并解读2019<A Survey of the Recent Architectures of Deep Convolutional Neural Networks>第一章~第三 ...
- 论文精翻《Progressive Tandem Learning for Pattern Recognition With Deep Spiking Neural Networks》
目录 0 摘要/Abstract 1 简介/Introduction 2 相关工作/Related Work 3 重新思考ANN-to-SNN的转换/Rethinking ANN-to-SNN Con ...
- 《How powerful are graph neural networks》论文翻译
作者:Keyulu Xu (MIT),Weihua Hu(Stanford Universtity),Jure Leskovec(Stanford Universtity),Stefanie Jege ...
- 《Poluparity Prediction on Social Platforms with Coupled Graph Neural Networks》阅读笔记
论文地址:Popularity Prediction on Social Platforms with Coupled Graph Neural Networks 文章概览 作者提出了一种耦合图神经网 ...
- Node Classification with Graph Neural Networks(使用GNN进行节点分类)
文章目录 Setup 准备数据集 处理和可视化数据集 拆分数据集为分层训练集和测试集 训练和评估的实现 Feedforward Network(FFN) 构建一个Baseline神经网络模型 为bas ...
- 《A Gentle Introduction to Graph Neural Networks》要点
A Gentle Introduction to Graph Neural Networks 1.架构 该篇文章总共有4块信息:什么数据可以表示成一张图.图和别的数据有什么不一样的地方 为什么要用图神 ...
- 论文阅读《SuperGlue: Learning Feature Matching with Graph Neural Networks》
论文地址:https://arxiv.org/abs/1911.11763 代码地址:https://github.com/magicleap/SuperGluePretrainedNetwork 背 ...
最新文章
- 自建ELK vs 日志服务(SLS)全方位对比
- ERROR: Could not install packages due to an EnvironmentError: [Errno 2] No such file or directory: ‘
- php web 目录遍历,php的目录遍历操作
- Fixjs——显示容器基类DisplayObjectContainer
- SpannableString 给TextView添加不同的显示样式
- 文件资源管理软件EagleFiler for Mac
- [高频电子线路]-避免从第一章开始懵逼
- 面向价值实现的数据资产管理体系构建
- WORD图、表标号——题注
- 比特大陆60天 :夺权、立威下的疯狂裁员
- win7(win10)更改“文件类型显示图标“的终极修改方法
- 天天快充滚动图片android750x379
- win10,ubuntu18.04系统下图像识别YOLOv5菠萝_附菠萝数据集图片标签
- Spring-IoC注解
- linux6 64位,CentOS 6.0 X64官方正式版系统(64位)
- iPhone 开发常用工具
- emoji表情 mysql转移,mysql中emoji表情存储
- 企业微信自动添加手机好友工具
- 《野蛮生长》--冯仑
- Flak模型和应用(一对一,一对多,多对多)
热门文章
- 如何用Stata完成(shui)一篇经济学论文(一):软件安装与语法规范
- Linux服务器宝塔面板7.7破解方法(其他一些常见问题)
- Dell venue 8 pro 打造全功能机
- 单片微机原理与接口技术——8051汇编指令系统与编程基础(3)算术运算与逻辑运算指令
- 后盾网-CI框架实例教程-马振宇 - 学习笔记(4)
- 研究发现,无创连续Masimo PVi®监测对指导术中体液处治的价值优于中央静脉压测量
- numpy创建一个8x8的国际象棋
- 阿里云短信验证-PHP
- java中的adt安装配置,Android SDK 2.3与Eclipse最新版开发环境搭建
- 【技巧】Pandas使用drop后使用reset_index重置索性