GNN教程:大规模分布式训练
转载
目录
引言
多进程方案
Graph Store
Distributed Sampler
后话
引言
本文为GNN教程的DGL框架之大规模分布式训练,前面的文章中我们介绍了图神经网络框架DGL如何利用采样的技术缩小计算图的规模来通过mini-batch的方式训练模型,当图特别大的时候,非常多的batches需要被计算,因此运算时间又成了问题,一个容易想到解决方案是采用并行计算的技术,很多worker同时采样,计算并且更新梯度。这篇博文重点介绍DGL的并行计算框架。
多进程方案
概括而言,目前DGL(version 0.3)采用的是多进程的并行方案,分布式的方案正在开发中。见下图,DGL的并行计算框架分为两个主要部分:Graph Store
和Sampler
Sampler
被用来从大图中构建许多计算子图(NodeFlow
),DGL能够自动得在多个设备上并行运行多个Sampler
的实例。Graph Store
存储了大图的embedding信息和结构信息,到目前为止,DGL提供了内存共享式的Graph Store,以用来支持多进程,多GPU的并行训练。DGL未来还将提供分布式的Graph Store,以支持超大规模的图训练。
下面来分别介绍它们。
Graph Store
graph store 包含两个部分,server和client,其中server需要作为守护进程(daemon)在训练之前运行起来。比如如下脚本启动了一个graph store server 和 4个worker,并且载入了reddit数据集:
python3 run_store_server.py --dataset reddit --num-workers 4
在训练过程中,这4个worker将会和client交互以取得训练样本。用户需要做的仅仅是编写训练部分的代码。首先需要创建一个client对象连接到对应的server。下面的脚本中用shared_memory
初始化store_type
表明client连接的是一个内存共享式的server。
g = dgl.contrib.graph_store.create_graph_from_store("reddit", store_type="shared_mem")
g.update_all(fn.copy_src(src='features', out='m'),fn.sum(msg='m', out='preprocess'),lambda node : {'preprocess': node.data['preprocess'] * node.data['norm']})
初看这段代码和矩阵计算没有任何关系啊,其实这段代码要从语义上理解,在语义上表示邻接矩阵和特征矩阵的乘法,即对于每个节点的特征跟新为邻居特征的和。那么再看上面这段代码就容易了,
copy_src
将节点特征取出来,并发送出去, sum
接受到来自邻居的特征并求和,求和结果再发给节点,最后节点自身进行一下renormalize。
update_all
在graph store中是分布式进行的,每个trainer都会分派到一部分节点进行更新。
节点和边的数据现在全部存储在graph store中,因此访问他们不再像以前那样用 g.ndata/g.edata
那样简单,因为这两个方法会读取整个节点和边的数据,而这些数据在graph store中并不存在(他们可能是分开存储的),因此用户只能通过g.nodes[node_ids].data[embed_name]
来访问特定节点的Embedding数据。(注意:这种读数据的方式是通用的,并不是graph store特有的,g.ndata
即是g.nodes[:].data
的缩写)。
为了高效地初始化节点和边tensor,DGL提供了init_ndata
和init_edata
这两种方法。这两种方法都会讲初始化的命令发送到graph store server上,由server来代理初始化工作,下面展示了一个例子:
for i in range(n_layers):g.init_ndata('h_{}'.format(i), (features.shape[0], args.n_hidden), 'float32')g.init_ndata('agg_h_{}'.format(i), (features.shape[0], args.n_hidden), 'float32')
其中h_i
存储i
层节点Embedding,agg_h_i
存储i
节点邻居Embedding的聚集后的结果。
初始化节点数据之后,我们可以通过control-variate sampling的方法来训练GCN),这个方法在之前的博文中介绍过
for nf in NeighborSampler(g, batch_size, num_neighbors,neighbor_type='in', num_hops=L-1,seed_nodes=labeled_nodes):for i in range(nf.num_blocks):# aggregate history on the original graphg.pull(nf.layer_parent_nid(i+1),fn.copy_src(src='h_{}'.format(i), out='m'),lambda node: {'agg_h_{}'.format(i): node.data['m'].mean(axis=1)})# We need to copy data in the NodeFlow to the right context.nf.copy_from_parent(ctx=right_context)nf.apply_layer(0, lambda node : {'h' : layer(node.data['preprocess'])})h = nf.layers[0].data['h']for i in range(nf.num_blocks):prev_h = nf.layers[i].data['h_{}'.format(i)]# compute delta_h, the difference of the current activation and the historynf.layers[i].data['delta_h'] = h - prev_h# refresh the old historynf.layers[i].data['h_{}'.format(i)] = h.detach()# aggregate the delta_hnf.block_compute(i,fn.copy_src(src='delta_h', out='m'),lambda node: {'delta_h': node.data['m'].mean(axis=1)})delta_h = nf.layers[i + 1].data['delta_h']agg_h = nf.layers[i + 1].data['agg_h_{}'.format(i)]# control variate estimatornf.layers[i + 1].data['h'] = delta_h + agg_hnf.apply_layer(i + 1, lambda node : {'h' : layer(node.data['h'])})h = nf.layers[i + 1].data['h']# update historynf.copy_to_parent()
和原来代码稍有不同的是,这里right_context
表示数据在哪个设备上,通过将数据调度到正确的设备上,我们就可以完成多设备的分布式训练。
Distributed Sampler
因为我们有多个设备可以进行并行计算(比如说多GPU,多CPU),那么需要不断地给每个设备提供nodeflow
(计算子图实例)。DGL采用的做法是分出一部分设备专门负责采样,将采样作为服务提供给计算设备,计算设备只负责在采样后的子图上进行计算。DGL支持同时在多个设备上运行多个采样程序,每个采样程序都可以将采样结果发到计算设备上。
一个分布式采样的示例可以这样写,首先,在训练之前用户需要创建一个分布式SamplerReceiver
对象:
sampler = dgl.contrib.sampling.SamplerReceiver(graph, ip_addr, num_sampler)
SamplerReceiver`类用来从其他设备上接收采样出来的子图,这个API的三个参数分别为`parent_graph`, `ip_address`, 和`number_of_samplers
然后,用户只需要在单机版的训练代码中改变一行:
for nf in sampler:for i in range(nf.num_blocks):# aggregate history on the original graphg.pull(nf.layer_parent_nid(i+1),fn.copy_src(src='h_{}'.format(i), out='m'),lambda node: {'agg_h_{}'.format(i): node.data['m'].mean(axis=1)})...
其中,代码for nf in sampler
用来代替原单机采样代码:
for nf in NeighborSampler(g, batch_size, num_neighbors,neighbor_type='in', num_hops=L-1,seed_nodes=labeled_nodes):
其他所有的部分都可以保持不变。
因此,额外的开发工作主要是要编写运行在采样设备上的采样逻辑。对于邻居采样来说,开发者只需要拷贝单机采样的代码就可以了:
sender = dgl.contrib.sampling.SamplerSender(trainer_address)...for n in num_epoch:for nf in dgl.contrib.sampling.NeighborSampler(graph, batch_size, num_neighbors,neighbor_type='in',shuffle=shuffle,num_workers=num_workers,num_hops=num_hops,add_self_loop=add_self_loop,seed_nodes=seed_nodes):sender.send(nf, trainer_id)# tell trainer I have finished current epochsender.signal(trainer_id)
后话
本篇博文重点介绍了DGL的并行计算框架,其主要由采样层-计算层-存储层三层构建而来,采样和计算分布在不同的机器上,可以并行执行。通过这种方式,在存储充足的情况下,DGL可以处理数以亿计节点和边的大图。
GNN教程:大规模分布式训练相关推荐
- 如何像用MNIST一样来用ImageNet?这里有一份加速TensorFlow分布式训练的梯度压缩指南
作者 | 王佐 今年的 NIPS 出现 "Imagenet is the new MNIST" 口号,宣告使用 MNIST 数据集检验网络模型性能已经成为过去式.算法工程师们早就意 ...
- 新手手册:Pytorch分布式训练
文 | 花花@机器学习算法与自然语言处理 单位 | SenseTime 算法研究员 目录 0X01 分布式并行训练概述 0X02 Pytorch分布式数据并行 0X03 手把手渐进式实战 A. 单机单 ...
- MXNet结合kubeflow进行分布式训练
GPU集群配置MXNet+CUDA 为方便控制集群,写了脚本cmd2all.sh #!/bin/bash if [ $# -lt 3 ]; thenecho "usage: $0 [type ...
- ztree局部刷新节点_神经网络训练的世界记录是怎样被刷新的 -- 总结分布式训练的计算场景...
还是在今年(2018年)11月美国感恩节放假期间,我无意点开我的新论文搜索关注,假期的懈怠顿时被一扫而空.一篇谷歌的新论文跳入眼帘,声称打破了几天前刚建立的分布式训练速度的记录.各大公司训练速度记录上 ...
- VLDB 2023 | 北大河图发布分布式训练神器Galvatron,一键实现大模型高效自动并行...
©作者 | 北京大学河图团队 单位 | 北京大学数据与智能实验室 北大河图团队提出了一套面向大模型的自动并行分布式训练系统 Galvatron,相比于现有工作在多样性.复杂性.实用性方面均具有显著优势 ...
- VLDB 2023 | 北大河图发布分布式训练神器Galvatron, 一键实现大模型高效自动并行...
关注公众号,发现CV技术之美 本文转自机器之心. 北大河图团队提出了一套面向大模型的自动并行分布式训练系统Galvatron,相比于现有工作在多样性.复杂性.实用性方面均具有显著优势,论文成果已经被 ...
- MindSpore Reinforcement新特性:分布式训练和蒙特卡洛树搜索
MindSpore Reinforcement MindSpore Reinforcement v0.5 版本提供了基于Dataflow Fragment的分布式训练能力,通过扩展新的Fragment ...
- 张量模型并行详解 | 深度学习分布式训练专题
随着模型规模的扩大,单卡显存容量无法满足大规模模型训练的需求.张量模型并行是解决该问题的一种有效手段.本文以Transformer结构为例,介绍张量模型并行的基本原理. 模型并行的动机和现状 我们在上 ...
- 从分布式训练到大模型训练
要了解大模型训练难,我们得先看看从传统的分布式训练,到大模型的出现,需要大规模分布式训练的原因.接着第二点去了解下大规模训练的挑战. 从分布式训练到大规模训练 常见的训练方式是单机单卡,也就是一台服务 ...
最新文章
- Android String.xml 批量翻译工具 | Android string.xml 各国语言转换
- Linux防火墙配置—访问外网WEB
- hdu 4739 状压DP
- python第三天习题
- Razor视图引擎浅析
- python 属性描述符
- Fresh for Mac(文件管理软件)
- Spring Boot开发框架优点诠释
- (50)System Verilog类静态变量实例
- protobuf android ndk,直接在Android NDK端使用tensorflow(不使用JAVA api)
- 支付宝sdk java对接_java后台支付宝app支付调用sdk进行支付
- should, could, would, will, be going to, may, might到底有甚麼不同,又該怎麼用?
- ArrayList学习[常用方法|源码]
- 面向对象的系统分析(一)-系统分析方法
- 完美解决Python 发送邮件126,136,QQ等,都会报•554 DT:SPM 发送的邮件内容包含了未被许可的信息,或被系统识别为垃圾邮件。请检查是否有用户发送病毒或者垃圾邮件
- 旌扬机器人_“http://club.liangchanba.com/”搜索蜘蛛、机器人模拟抓取结果--站长工具...
- win7锁定计算机自动关机,windows7怎么设置电脑自动关机_win7如何自动关机
- 垃圾分类游戏HTML,垃圾分类宣传进村居,趣味游戏中学分类
- 厦门大学的【软件工程专业】被撤销!
- LabVIEW读海康网络摄像头问题