转载

目录

引言

多进程方案

Graph Store

Distributed Sampler

后话


引言

本文为GNN教程的DGL框架之大规模分布式训练,前面的文章中我们介绍了图神经网络框架DGL如何利用采样的技术缩小计算图的规模来通过mini-batch的方式训练模型,当图特别大的时候,非常多的batches需要被计算,因此运算时间又成了问题,一个容易想到解决方案是采用并行计算的技术,很多worker同时采样,计算并且更新梯度。这篇博文重点介绍DGL的并行计算框架。

多进程方案

概括而言,目前DGL(version 0.3)采用的是多进程的并行方案,分布式的方案正在开发中。见下图,DGL的并行计算框架分为两个主要部分:Graph StoreSampler

  • Sampler被用来从大图中构建许多计算子图(NodeFlow),DGL能够自动得在多个设备上并行运行多个Sampler的实例。

  • Graph Store存储了大图的embedding信息和结构信息,到目前为止,DGL提供了内存共享式的Graph Store,以用来支持多进程,多GPU的并行训练。DGL未来还将提供分布式的Graph Store,以支持超大规模的图训练。

下面来分别介绍它们。

Graph Store

graph store 包含两个部分,server和client,其中server需要作为守护进程(daemon)在训练之前运行起来。比如如下脚本启动了一个graph store server 和 4个worker,并且载入了reddit数据集:

python3 run_store_server.py --dataset reddit --num-workers 4

在训练过程中,这4个worker将会和client交互以取得训练样本。用户需要做的仅仅是编写训练部分的代码。首先需要创建一个client对象连接到对应的server。下面的脚本中用shared_memory初始化store_type表明client连接的是一个内存共享式的server。

g = dgl.contrib.graph_store.create_graph_from_store("reddit", store_type="shared_mem")

g.update_all(fn.copy_src(src='features', out='m'),fn.sum(msg='m', out='preprocess'),lambda node : {'preprocess': node.data['preprocess'] * node.data['norm']})

初看这段代码和矩阵计算没有任何关系啊,其实这段代码要从语义上理解,在语义上表示邻接矩阵和特征矩阵的乘法,即对于每个节点的特征跟新为邻居特征的和。那么再看上面这段代码就容易了,copy_src将节点特征取出来,并发送出去, sum接受到来自邻居的特征并求和,求和结果再发给节点,最后节点自身进行一下renormalize。

update_all在graph store中是分布式进行的,每个trainer都会分派到一部分节点进行更新。

节点和边的数据现在全部存储在graph store中,因此访问他们不再像以前那样用 g.ndata/g.edata那样简单,因为这两个方法会读取整个节点和边的数据,而这些数据在graph store中并不存在(他们可能是分开存储的),因此用户只能通过g.nodes[node_ids].data[embed_name]来访问特定节点的Embedding数据。(注意:这种读数据的方式是通用的,并不是graph store特有的,g.ndata即是g.nodes[:].data的缩写)。

为了高效地初始化节点和边tensor,DGL提供了init_ndatainit_edata这两种方法。这两种方法都会讲初始化的命令发送到graph store server上,由server来代理初始化工作,下面展示了一个例子:

for i in range(n_layers):g.init_ndata('h_{}'.format(i), (features.shape[0], args.n_hidden), 'float32')g.init_ndata('agg_h_{}'.format(i), (features.shape[0], args.n_hidden), 'float32')

其中h_i存储i层节点Embedding,agg_h_i存储i节点邻居Embedding的聚集后的结果。

初始化节点数据之后,我们可以通过control-variate sampling的方法来训练GCN),这个方法在之前的博文中介绍过

for nf in NeighborSampler(g, batch_size, num_neighbors,neighbor_type='in', num_hops=L-1,seed_nodes=labeled_nodes):for i in range(nf.num_blocks):# aggregate history on the original graphg.pull(nf.layer_parent_nid(i+1),fn.copy_src(src='h_{}'.format(i), out='m'),lambda node: {'agg_h_{}'.format(i): node.data['m'].mean(axis=1)})# We need to copy data in the NodeFlow to the right context.nf.copy_from_parent(ctx=right_context)nf.apply_layer(0, lambda node : {'h' : layer(node.data['preprocess'])})h = nf.layers[0].data['h']for i in range(nf.num_blocks):prev_h = nf.layers[i].data['h_{}'.format(i)]# compute delta_h, the difference of the current activation and the historynf.layers[i].data['delta_h'] = h - prev_h# refresh the old historynf.layers[i].data['h_{}'.format(i)] = h.detach()# aggregate the delta_hnf.block_compute(i,fn.copy_src(src='delta_h', out='m'),lambda node: {'delta_h': node.data['m'].mean(axis=1)})delta_h = nf.layers[i + 1].data['delta_h']agg_h = nf.layers[i + 1].data['agg_h_{}'.format(i)]# control variate estimatornf.layers[i + 1].data['h'] = delta_h + agg_hnf.apply_layer(i + 1, lambda node : {'h' : layer(node.data['h'])})h = nf.layers[i + 1].data['h']# update historynf.copy_to_parent()

和原来代码稍有不同的是,这里right_context表示数据在哪个设备上,通过将数据调度到正确的设备上,我们就可以完成多设备的分布式训练。

Distributed Sampler

因为我们有多个设备可以进行并行计算(比如说多GPU,多CPU),那么需要不断地给每个设备提供nodeflow(计算子图实例)。DGL采用的做法是分出一部分设备专门负责采样,将采样作为服务提供给计算设备,计算设备只负责在采样后的子图上进行计算。DGL支持同时在多个设备上运行多个采样程序,每个采样程序都可以将采样结果发到计算设备上。

一个分布式采样的示例可以这样写,首先,在训练之前用户需要创建一个分布式SamplerReceiver对象:

sampler = dgl.contrib.sampling.SamplerReceiver(graph, ip_addr, num_sampler)
SamplerReceiver`类用来从其他设备上接收采样出来的子图,这个API的三个参数分别为`parent_graph`, `ip_address`, 和`number_of_samplers

然后,用户只需要在单机版的训练代码中改变一行:

for nf in sampler:for i in range(nf.num_blocks):# aggregate history on the original graphg.pull(nf.layer_parent_nid(i+1),fn.copy_src(src='h_{}'.format(i), out='m'),lambda node: {'agg_h_{}'.format(i): node.data['m'].mean(axis=1)})...

其中,代码for nf in sampler用来代替原单机采样代码:

for nf in NeighborSampler(g, batch_size, num_neighbors,neighbor_type='in', num_hops=L-1,seed_nodes=labeled_nodes):

其他所有的部分都可以保持不变。

因此,额外的开发工作主要是要编写运行在采样设备上的采样逻辑。对于邻居采样来说,开发者只需要拷贝单机采样的代码就可以了:

sender = dgl.contrib.sampling.SamplerSender(trainer_address)...for n in num_epoch:for nf in dgl.contrib.sampling.NeighborSampler(graph, batch_size, num_neighbors,neighbor_type='in',shuffle=shuffle,num_workers=num_workers,num_hops=num_hops,add_self_loop=add_self_loop,seed_nodes=seed_nodes):sender.send(nf, trainer_id)# tell trainer I have finished current epochsender.signal(trainer_id)

后话

本篇博文重点介绍了DGL的并行计算框架,其主要由采样层-计算层-存储层三层构建而来,采样和计算分布在不同的机器上,可以并行执行。通过这种方式,在存储充足的情况下,DGL可以处理数以亿计节点和边的大图。

GNN教程:大规模分布式训练相关推荐

  1. 如何像用MNIST一样来用ImageNet?这里有一份加速TensorFlow分布式训练的梯度压缩指南

    作者 | 王佐 今年的 NIPS 出现 "Imagenet is the new MNIST" 口号,宣告使用 MNIST 数据集检验网络模型性能已经成为过去式.算法工程师们早就意 ...

  2. 新手手册:Pytorch分布式训练

    文 | 花花@机器学习算法与自然语言处理 单位 | SenseTime 算法研究员 目录 0X01 分布式并行训练概述 0X02 Pytorch分布式数据并行 0X03 手把手渐进式实战 A. 单机单 ...

  3. MXNet结合kubeflow进行分布式训练

    GPU集群配置MXNet+CUDA 为方便控制集群,写了脚本cmd2all.sh #!/bin/bash if [ $# -lt 3 ]; thenecho "usage: $0 [type ...

  4. ztree局部刷新节点_神经网络训练的世界记录是怎样被刷新的 -- 总结分布式训练的计算场景...

    还是在今年(2018年)11月美国感恩节放假期间,我无意点开我的新论文搜索关注,假期的懈怠顿时被一扫而空.一篇谷歌的新论文跳入眼帘,声称打破了几天前刚建立的分布式训练速度的记录.各大公司训练速度记录上 ...

  5. VLDB 2023 | 北大河图发布分布式训练神器Galvatron,一键实现大模型高效自动并行...

    ©作者 | 北京大学河图团队 单位 | 北京大学数据与智能实验室 北大河图团队提出了一套面向大模型的自动并行分布式训练系统 Galvatron,相比于现有工作在多样性.复杂性.实用性方面均具有显著优势 ...

  6. VLDB 2023 | 北大河图发布分布式训练神器Galvatron, 一键实现大模型高效自动并行...

    关注公众号,发现CV技术之美 本文转自机器之心. 北大河图团队提出了一套面向大模型的自动并行分布式训练系统Galvatron,相比于现有工作在多样性.复杂性.实用性方面均具有显著优势,论文成果已经被  ...

  7. MindSpore Reinforcement新特性:分布式训练和蒙特卡洛树搜索

    MindSpore Reinforcement MindSpore Reinforcement v0.5 版本提供了基于Dataflow Fragment的分布式训练能力,通过扩展新的Fragment ...

  8. 张量模型并行详解 | 深度学习分布式训练专题

    随着模型规模的扩大,单卡显存容量无法满足大规模模型训练的需求.张量模型并行是解决该问题的一种有效手段.本文以Transformer结构为例,介绍张量模型并行的基本原理. 模型并行的动机和现状 我们在上 ...

  9. 从分布式训练到大模型训练

    要了解大模型训练难,我们得先看看从传统的分布式训练,到大模型的出现,需要大规模分布式训练的原因.接着第二点去了解下大规模训练的挑战. 从分布式训练到大规模训练 常见的训练方式是单机单卡,也就是一台服务 ...

最新文章

  1. Android String.xml 批量翻译工具 | Android string.xml 各国语言转换
  2. Linux防火墙配置—访问外网WEB
  3. hdu 4739 状压DP
  4. python第三天习题
  5. Razor视图引擎浅析
  6. python 属性描述符
  7. Fresh for Mac(文件管理软件)
  8. Spring Boot开发框架优点诠释
  9. (50)System Verilog类静态变量实例
  10. protobuf android ndk,直接在Android NDK端使用tensorflow(不使用JAVA api)
  11. 支付宝sdk java对接_java后台支付宝app支付调用sdk进行支付
  12. should, could, would, will, be going to, may, might到底有甚麼不同,又該怎麼用?
  13. ArrayList学习[常用方法|源码]
  14. 面向对象的系统分析(一)-系统分析方法
  15. 完美解决Python 发送邮件126,136,QQ等,都会报•554 DT:SPM 发送的邮件内容包含了未被许可的信息,或被系统识别为垃圾邮件。请检查是否有用户发送病毒或者垃圾邮件
  16. 旌扬机器人_“http://club.liangchanba.com/”搜索蜘蛛、机器人模拟抓取结果--站长工具...
  17. win7锁定计算机自动关机,windows7怎么设置电脑自动关机_win7如何自动关机
  18. 垃圾分类游戏HTML,垃圾分类宣传进村居,趣味游戏中学分类
  19. 厦门大学的【软件工程专业】被撤销!
  20. LabVIEW读海康网络摄像头问题

热门文章

  1. JS时间戳和时间之间转换
  2. 当同事用 Root 权限输入rm -rf 后,鬼知道我经历了什么
  3. 建筑八大员考试武汉施工员考试公路路面裂缝的养护施工技术
  4. termux使用教程python-神器Termux(二)——如何用安卓手机舒服地写Python
  5. CL210描述OPENSTACK控制平面--识别overclound控制平台服务+章节实验
  6. 正则掌握程度测试题——参考答案
  7. 超简易实现电脑微信多开
  8. 天大2021年秋学期考试《土力学与基础工程》离线作业考核试题
  9. thinkphp日志泄漏漏洞_ThinkPHP漏洞分析与利用
  10. 手动修改android系统模拟器dpi