尽管有各种深度学习加速器,神经网络的大小依然受限于计算平台的能力。百度硅谷人工智能实验室高级研究员Greg Diamos在最近的ICML 2016上发表了一篇PRNN(Persistent RNNs)的论文(相关英文访谈),介绍了他在深度学习平台GPU可扩展性方面的最新工作。但在此之前,Greg Diamos已经在Github上发布一篇博客文章简要解释了PRNN的工作和效果,本文为这篇文章的译文。PRNN已经在Github上开源,感兴趣的读者可以自行下载。

1. 简介

在SVAIL(百度硅谷人工智能实验室),我们的使命是创造能够对数以亿计的人们产生深远影响的AI技术。我们相信,达到这个目标的一种良好方式是提高语音识别的准确性,这将通过在更大数据集上使用深度学习算法实现。这些算法需要大量的运算,所以系统的内存大小和计算吞吐量会限制数据量以及我们可以训练的神经网络大小。所以搞清楚如何更有效地运行深度学习是一大挑战。这么做可以让我们在更大数据集上训练更大的模型,目前已经提高了语音识别的准确性。在这里,我们将要讨论一项新技术,它能加快深度递归神经网络(Recurrent Neural Networks)的训练。

2. 将递归层映射到硬件上

我们两个语音识别模型的密集计算集中于递归层(上图中蓝色部分),所以这种优化直接针对这部分网络。

2.1 用矩阵乘法实现RNN

通常实现递归神经网络的方式是进行一系列的矩阵乘法操作,参见前一篇博文以及上图。这包括从片外存储器中为每一个时间步长加载递归权重矩阵U和活化向量H。

在高性能存储器上,比如GPU,片外存储器要比片上存储器(如寄存器文件和高速缓存)慢的多,效率更低。所以当小批量数据相对较大(每GPU大约64或更高)时,矩阵乘法最高效,因为递归权重可以从片外存储器上一次性加载,并在小批量的每个样例上重复使用。

然而,使用较大的小批量有几个缺点:

  • 在训练网络的时候会增加内存占用量。
  • 在单GPU中会并行地耗尽可用的数据。
  • 它会使模型部署复杂化。

应该清楚,在每块GPU上使用较大的小批量会占用更多的内存。在许多时间步长上训练RNN时,存储活化向量需要的内存要比网络权重多许多。

例如,存储一个有1200个单元,每个单元是32位浮点数的简单RNN的权重,大约需要5.7MB的内存,但是存储一个小批量大小为64和700个时间步长的活化向量,则需要215MB的内存。所以,增加小批量的大小会直接导致训练模型所需内存的增加。

还需要明确的是,每个GPU上使用较大的小批量会并行地耗尽可用的数据,这些数据可能已经用于多GPU上传播计算。使用512小批量大小的算法,每GPU小批量大小为4,可以使用128块GPU,但是每GPU小批量大小为64的话,只能使用16块GPU。使用128块GPU以最高效率来训练单个模型对许多读者来说似乎有些苛刻,但这对我们很重要,因为这可以让我们测试语音识别精度是否会随着网络大小和数据容量的增长而持续提高。

最后,在每块GPU上使用大批量,会复杂化模型部署,因为在同一GPU上,多用户流需要被同时调度处理。这在嵌入式应用(比如运行在手机上的语音识别)中非常困难,因为通常只有单个用户。在云服务上部署也同样困难,因为多用户流必须在单个服务器上同时调度以达到良好的能源效率,但是不能有太多的流,因为会增加时延。

2.2 使用持久内核

所以,我们想找到一种方法来一次性加载递归权重并且多次使用它们,而不增加小批量大小。

这将允许我们:

  • 使用同样的硬件训练更大或更深的网络。
  • 在更多的GPU上大规模进行网络训练。
  • 使少并发用户的模型部署更有效。

对于GPU,片上存储器上最大的资源分布于数千个线程的各个寄存器文件中。例如,英伟达TitanX GPU的寄存器文件内存有6.3MB,足够存储约有1200个活化向量的递归层。持久内核利用这个寄存器文件内存来缓存递归权重并且在多个时间步长上重复使用它们。

然而,如果单个线程在网络权重的各个子集上工作,那么它们必须进行通信,对于当前的时间步长,需要将每个部分的计算结果结合在一起,同时还要为其他线程创建的下一个时间步长读取更新的活化向量。这意味着,数千个GPU线程需要在每个时间步长内保持通信并且相互同步。这种类型的同步不受CUDA或OpenCL的支持,因为相比单线程块而言,GPU不能保证更多的线程会同时运行。然而,它们通常是许多线程同时运行,特别是在像TitanX这样的更大GPU上。我们可以通过在GPU上实现一种形式的抢先多任务处理来解决这个限制,各个线程使用全局障碍直接进行同步,但最终会超时并退出。CPU上的一个运行时系统监控着过早退出的线程,然后重启内核(重新从内存中加载权重)直到所有任务都成功。在实践过程中,如果只有这个程序使用GPU,并且我们不启动太多的线程,那么全局障碍几乎不会超时。


比起基于矩阵乘法的实现,这种方法在小批量上显著提高了性能。批量大小为4的情况下,性能从90GFLOPs变成了2.8TFLOPs,大约提升了30倍。

2.3 可能的替代方法

为解决性能问题,通常会有多种可能的解决方案。本节会介绍除了持久RNN以外的技术来减少训练RNN的内存使用量以及每块GPU小批量的大小。我们已经发现,对于我们的语音识别模型,它们没有持久RNN有效。然而这种技术可能在其他情况下效果不错。

  • 随时间截短反向传播
  • 使用CPU内存缓存活化向量
  • 并行化RNN模型

2.3.1 随时间截短反向传播

随时间截短反向传播能够减少反向传播过程中存储活化向量所需的内存,在处理语义剩余部分之前,它会以固定数量的时间步长进行前向和反向传播。这种方法会显著减少训练网络所需的内存,因为只有固定数量的时间步长的活化向量需要被存储,但这样做会在长时间依赖上丢失梯度信息。在我们的系统中,我们发现,相比于在整个语义上使用反向传播,使用随时间截短反向传播会丢失大约20%的语音识别精度。所以对于我们的语音识别模型,我们不采用随时间截短反向传播这种方式。

2.3.2 使用CPU内存缓存活化向量

另外一种规避GPU内存限制的方式是在CPU的片外存储器上缓存用于反向传播的活化数据,CPU片外存储器通常比GPU片外存储器大很多。这是以系统分级存储体系(例如CPU DRAM,SSD缓存,磁盘等)的更高级别,存储用于反向传播的活化向量一般策略的一种特殊情况。如果从CPU中向后、向前拷贝数据所需的时间比网络中前向传播和反向传播中算术运算所需的时间少,那么这种方法有效。然而,对于我们的情况,单个节点上,网络在8块GPU的速度已经够快了,从CPU DRAM中向后、向前传播数据会导致整个系统2-4倍的性能下降。此外,它还会占用我们用来实现节点之间并行化数据规约操作的处理器内连接带宽。所以对我们的系统来说这么做没有意义。一般而言,这种方法更适合于每块CPU上有更少GPU的系统,或者更慢的RNN实现的系统。

2.3.3 并行化RNN模型

最后一种减少每块GPU所需内存的方式是使用并行化模型对多GPU上的各个递归层进行分区。这种方式仍然使用矩阵乘法来进行RNN的前向和反向传播操作,但在多GPU上对各个矩阵乘法进行分布式计算。仍然需要使用较大的小批量来使矩阵乘法更高效,但它分配于多个GPU上,所以每块GPU上小批量的效率会降低。这种方法在每个时间步长内进行高昂的GPU之间的同步操作,用来结合分布式乘法的结果,所以当递归层非常大时,这么做才有意义,比如在4块TitanX GPU上每个递归层有5000个活化向量。这种方法可以结合使用持久RNN,持久RNN在中等大小的递归层上效率更佳,模型并行RNN在超大递归层上效果更好。

3. 结论

这项工作表明,RNN权重可以高效地存储在GPU寄存器中,并可以通过这种方法来进行高吞吐量计算。这大大提高了在低小批量大小上的性能,从而能够在同样硬件上训练更深的模型,并在更多的GPU上大规模训练模型。

3.1 训练更深的模型

将小批量大小从64降低到4,活化内存的占用节省了16倍。训练语音识别所需的大部分内存是用来为反向传播存储活化向量的,而不是用来存储网络权重,所以这种节省直接增加了我们可以训练的模型大小。我们现在可以在GPU内存中使用110层的深度递归网络,这比我们之前的7层模型深了一个数量级。

3.2 并行化扩展数据到多个GPU

减少每块GPU的小批量大小同样可以将数据并行化扩展到多个GPU上,而不需要改变算法批量大小。我们通常使用512或1024的算法小批量大小来训练网络,因为更大的值会减慢模型训练速度。在每块GPU上使用小批量大小为64可以在8到16块GPU上训练一个模型,而使用大小为4的小批量可以在128到256块GPU上训练同样的模型。

3.3 展望未来

随着未来GPU有更多的硬件线程,更大的片上存储容量,以及更低精度的浮点运算操作,使用这种方式训练出模型的层的大小会随之增加。这项工作重点在于深度语音网络中的RNN,但这种方法同样可以应用于GRU和LSTM,这将通过分布式GPU线程中权重矩阵实现。

原则上,这种方法也可以应用于其他类型的处理器上,例如AMD或Intel的大型GPU可以在线程寄存器文件中缓存递归权重。多核处理器,比如Intel Xeon和Xeon PHI可以将递归权重缓存在L1和L2缓存中。FPGA可以将权重分布在芯片RAM块上。

我们希望社区中的其他人可以将持久RNN运用于在更大数据集上训练更大、更深的递归神经网络。

原文:Persistent RNNs - 30 times faster RNN layers at small mini-batch sizes
译者:刘翔宇 审校:刘帝伟
责编:周建丁(zhoujd@csdn.net)


百度PRNN:增强GPU伸缩性,RNN训练最高提速30倍(源码下载)相关推荐

  1. asp写的百度ocr识别文字-通用文字识别(高精度版)源码下载

    今天接到一个客户需求,需要用ASP写一个百度OCR文字识别代码,他的程序都是用ASP写的,所以我们也需要用ASP给他实现百度OCR文字识别,我们在百度AI网站上启用了通用文字识别高精度版,利用高精度板 ...

  2. 基于TensorFlow训练花朵识别模型的源码和Demo

    基于TensorFlow训练花朵识别模型的源码和Demo 转发来源: https://blog.csdn.net/Anymake_ren/article/details/80550684 下面就通过对 ...

  3. TensorFlow (RNN)深度学习 双向LSTM(BiLSTM)+CRF 实现 sequence labeling 序列标注问题 源码下载...

    http://blog.csdn.net/scotfield_msn/article/details/60339415 在TensorFlow (RNN)深度学习下 双向LSTM(BiLSTM)+CR ...

  4. cesium加载百度地图_Cesium专栏-百度地图加载(附源码下载)

    Cesium 是一款面向三维地球和地图的,世界级的JavaScript开源产品.它提供了基于JavaScript语言的开发包,方便用户快速搭建一款零插件的虚拟地球Web应用,并在性能,精度,渲染质量以 ...

  5. 【从线性回归到 卷积神经网络CNN 循环神经网络RNN Pytorch 学习笔记 目录整合 源码解读 B站刘二大人 绪论(0/10)】

    深度学习 Pytorch 学习笔记 目录整合 数学推导与源码详解 B站刘二大人 目录传送门: 线性模型 Linear-Model 数学原理分析以及源码详解 深度学习 Pytorch笔记 B站刘二大人( ...

  6. 深度增强学习PPO(Proximal Policy Optimization)算法源码走读

    原文地址:https://blog.csdn.net/jinzhuojun/article/details/80417179 OpenAI出品的baselines项目提供了一系列deep reinfo ...

  7. 深度学习大模型训练--分布式 deepspeed PipeLine Parallelism 源码解析

    deepspeed PipeLine Parallelism 源码解析 basic concept PipeDream abstract 1F1B 4 steps Code comprehension ...

  8. 干货|Pytorch弹性训练极简实现( 附源码)

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨颜挺帅@知乎(已授权) 来源丨https://zhuanlan ...

  9. YOLO-V3-SPP 训练时正样本筛选源码解析之build_targets

    前言 理论详解:YOLO-V3-SPP详细解析 build_targets 讲解形式主要以流程图形式,逐流程详细解读每一行代码 代码以pytorch框架为基础 targets处理整体流程 这里主要介绍 ...

最新文章

  1. java 学到什么实习_我如何获得外展实习机会以及到目前为止所学到的知识
  2. Apache Spark学习:利用Eclipse构建Spark集成开发环境
  3. 关于html5的7个传说
  4. 【论文解读】OneNet:一阶段的端到端物体检测器,无需NMS
  5. 高手经验:一个新手的verilog学习经验
  6. 新一代人工智能发展规划_助力人工智能创新发展,新华三任合肥市新一代人工智能产业发展联盟理事单位...
  7. Sphinx全文检索引擎测试
  8. html选中列表整列变色,excel选中行变色完整代码和动画效果
  9. c语言mergesort 参数,归并排序C语言兑现MergeSort
  10. 【C语言进阶深度学习记录】十一 C语言中enum,sizeof,typedef分析
  11. 6种不同画法画平行线_6种电视背景墙,不同材质做法,价格是多少,你都了解嘛?...
  12. mysql5.715 安装在d盘_mysql5.7.15在windows环境下的安装设置图文详细教程
  13. Head First设计模式读书笔记十 第十一章 代理模式
  14. windows上cppcheck检查不出任何错误解决
  15. [转]用了docker是否还有必要使用openstack?
  16. Gym 100963B
  17. 基于C++和AStar算法求解八数码问题的方案
  18. DIY一个正弦表计算器,用于单片机查表生成正弦波
  19. 什么情况下选用mysql_在MySQL中,‘%’可以用在什么情况下?
  20. 禅道 mysql 错误

热门文章

  1. java 抽象类继承抽象类_Java之继承、抽象类、接口篇
  2. java kafka api_kafka java API的使用
  3. 局域网抓包分析工具_[源码和文档分享]基于Libpcap实现的局域网嗅探抓包发包解析工具...
  4. leetcode算法题--“气球” 的最大数量
  5. leetcode算法题--树的子结构
  6. tab-pane 怎么家点击事件_有好转?辛巴燕窝事件新进展曝光。二子爷老婆首次回应银行行长送奥迪!二子爷分析小样你家老铁太精...
  7. Linux 用户被差别对待?无法通过 apple.com 管理 Apple ID
  8. 函数(复习),闭包,DOM
  9. MySQL Workbench 怎么创建数据库
  10. tomcat配置与应用(2)