选自arXiv

机器之心编译

常见的最优化器,如 Adam、AdaGrad、SGD+Momentum 等,都是一阶的。但是二阶梯度的收敛速度相比它们就快了太多。近日,谷歌研究者联合普林斯顿大学等,提出了真正应用的二阶梯度最优化器 Shampoo,让这个理论上颇有前景的设想变为现实。

目前,无论是从理论还是应用层面来说,机器学习中的优化都是以随机梯度下降等一阶梯度方法为主。囊括二阶梯度和/或二阶数据统计的二阶优化方法虽然理论基础更强,但受限于计算量、内存和通信花销等因素,二阶梯度优化方法的普及度不高。

可能你用各种框架搭建过各种神经网络,同时也尝试过调用 Adam、AdaGrad、SGD+Momentum 等形形色色的最优化器。但是你会发现,它们都采用一阶梯度,二阶梯度似乎仍然遥遥无期。

然而在谷歌大脑与普林斯顿大学等研究者的努力下,二阶梯度优化终于在实战大模型上展现出独特的优势。

研究者表示,为了缩短理论和实际优化效果之间的差距,该论文提出了一种二阶优化的概念性验证,并通过一系列重要的算法与数值计算提升,证明它在实际深度模型中能有非常大的提升。

论文地址:https://arxiv.org/abs/2002.09018

具体而言,在训练深度模型过程中,二阶梯度优化 Shampoo 能高效利用由多核 CPU 和多加速器单元组成的异构硬件架构。并且在大规模机器翻译、图像识别等领域实现了非常优越的性能,要比现有的顶尖一阶梯度下降方法还要好。

我们可以先看看它在 WMT 英-法翻译数据集上的效果,它采用的是标准的 Transformer。除了这一个实验,研究者还测试了 Big Transformer 以及 ImageNet 上的 ResNet,这些实验结果都展示在后文中。

WMT 14 英法翻译数据集上的 Transformer,二阶梯度算法 Shampoo 在迭代数上快了 1.95 倍,且就算要计算二阶梯度,每一次迭代也只慢了 16%,总体上来说节省了 40% 的执行时间。

从上图可以清楚地看到,如果 Adam 和 Shampoo 要训练到相同的准确度,Adam 需要迭代 30 万次,而 Shampoo 大概只需要迭代 11 万到 12 万次就差不多了。二阶梯度,果然收敛速度快了太多。

二阶梯度是什么

想象一下,如果我们希望找到「谷底」,那么沿着坡度一步一步往下走就行了。对于机器学习来说,「谷底」就是最优解,一步一步就是迭代过程。之前,我们采用一阶梯度,也就是坡度的陡和缓来确定步子要迈多大。而当坡度是有变化的,即逐渐变陡或变缓,根据当前坡度来确定步子大小就有一些问题。

之前我们可以慢慢多走几步,就能根据坡度的变化直接调整。现在如果能用二阶梯度,相当于梯度的梯度,那么也就知道坡度变化的趋势,因此一步就能走到位。所以二阶梯度本质上比一阶梯度多出一些信息,模型收敛也就会更快。

对于凸优化,二阶梯度一次就能找到最优解,而一阶梯度找到的方向必定垂直于当前点的等高线,因此出现这种「锯齿」现象。

二阶梯度策略无疑是数学优化中性能最有保障的算法之一。在这种算法中,我们使用预条件算子(preconditioner)矩阵转换梯度,然后应用到每个迭代步骤上。一般来说,这包括计算/近似估计二阶导矩阵,如 Hessian。另一方面,AdaGrad 和其他相关的算法主要针对随机优化,使用二阶梯度的方差矩阵来构建预条件算子。

虽然二阶方法在收敛性上比一阶好很多,但是其计算量限制了实际的应用。因为在每次梯度更新的时候,这种算法需要平方级别的存储和立方级别的计算时间。因此,这些方法在现在的机器学习优化方法中并不常见。

现代优化策略面对的最大的一个挑战是在理论和实际优化方法中搭建一个桥梁,使得二阶优化方法能够更合理地被应用和部署。

在这篇论文中,研究者真正提出了一种二阶梯度改进方法,它与 Adam 等算法一样是适应性梯度,但它能利用二阶梯度信息加速收敛,甚至在大型机器翻译模型中收敛快了一半。

二阶梯度,超越 Adam 的最优化

为了完成这项工作,研究者仔细思考了二阶优化存在的问题与困难,并改进了名为 Shampoo 的二阶梯度方法。研究者表示,二阶优化最无解的是目前深度学习库对一阶梯度已经有了大量优化,它们对计算量与内存的要求都不高。然而,对于 Shampoo,每一次迭代的密集计算,都对应用大模型产生了不可逾越的阻拦。

和一阶梯度优化方法相比,Shampoo 要走向实践,还有如下三大挑战。

1. 算法上的挑战

现代机器学习架构通常使用很大的嵌入层,维度可能多达百万级别。Shampoo 需要对每个维度计算一个预条件子,但是不管是计算还是存储,百万次的百万维度矩阵都几乎是不可处理的。因此从算法上,我们首先就要设计一种新方法来解决这类问题。

2. 计算上的挑战

Shampoo 的权重更新式大概是如下这样的,其中 L 和 R 都是矩阵,它们需要求逆与求根,在计算上会显得非常复杂,这也会拖慢整个迭代的速度。

之前矩阵求逆与根是可以使用 SVD 计算的,但是它们太慢了。因此可以考虑如 Schur-Newton 等一些算法,可以将逆 P 次根问题转换为一系列矩阵-向量和矩阵-矩阵的积,所以对于优化很有帮助。

图 1:在不同维度上,计算某矩阵逆 P 次根的基准对比。Schur-Newton 迭代方法能够在 CPU 上高效运行,而且相比 SVD 有很大的冗余用于提升。

3. 基础设施挑战

神经网络加速器通常是定制的,用来让机器学习程序运行地更快,开销更少。加速器设计倾向于低精度(8bit/16bit),能够满足现有的基准。研究者的方法需要双精度运算,因此已有的加速器甚至都不会启动。

此外,TensorFlow 等深度学习库提供的最优化器 API 适应于随机、一阶梯度下降那种模式。而二阶优化器需要与训练循环做交互,因此从实现上需要对框架底层做出修正。

虽然难,但还是能攻破

三大挑战使实现二阶梯度优化器异常复杂,研究者针对不同的问题提出了一系列优化组件、优化算法。最终搞定的分布式 Shampoo 在 CPU、TPU、GPU 等硬件上高效运行,这也是我们第一次看到二阶优化器在大模型上能 Work 的新研究。

研究者们首先分析了标准的数据并行方法。每个加速器的核都会在一个分批数据上进行前向和反向传播。然后算法会对数据批进行梯度聚合,并使用 all-reduction 的方法获得平均梯度。聚合的梯度被用来进行权重更新。前向和反向传播在所有核中并行计算。

为什么使用 All-reduction?这增加了一个屏障,使得所有核同步聚合批的梯度,并进行权重更新。图 3 中,研究者评价了 Transformer 模型每一步的计算开销。

图 3:使用 Diagonal AdaGrad 优化器的 Transformer 模型的每迭代步延迟时间为 134 毫秒,其中(1)前向传播 57 毫秒;后向传播 71 毫秒;all reduction:4 毫秒;权重更新:2 毫秒。

本文分布式系统实现的整体设计如下图 4 的时间轴所示,具体分布式架构可查阅原论文。

图 4:本文优化算法设计的时间轴。在每一步上计算所有张量的预调节器统计数据。预调节器只在每 N 步上计算,并且计算会分配给训练系统上可用的所有 CPU 核心。运算进行流水线处理,这样就实现了开销均摊。

可实战的二阶梯度优化

研究者在包含 3630 万个句对的 WMT 14 英法标准机器翻译数据集上验证了分布式系统实现方法的有效性。此外,他们在实验中使用了当前 SOTA Transformer 架构,该架构包含 9330 万个参数和 6 层的编码器-解码器结构。

实验在 32 核谷歌 Cloud TPU v3 Pod 上运行,结果如下图 6 所示,本文提出的 Shampoo 算法只需一半迭代数就能实现与 AdaGrad 和 Adam 相同的准确度。

图 6:WMT 14 英法翻译数据集上的 Transformer 模型,Shampoo 二阶梯度算法的收敛速度在迭代数上快了 1.95 倍,且就算要计算二阶梯度,每一次迭代也只慢了 16%,总体上来说节省了 40% 的执行时间。

研究者还利用一个大型 Transformer 模型进行实验,该模型包含 3.754 亿个参数和 6 层的编码器-解码器结构。实验结果如下图 12 所示,端到端的执行时间实现了提升。

图 12:WMT 14 英法翻译数据集上的 Transformer-Big 模型,Shampoo 二阶梯度算法的收敛速度在迭代数上快了 2 倍,且就算要计算二阶梯度,每一次迭代也只慢了 40%,总体上来说节省了 30% 的执行时间。

最后,研究者在 ImageNet-2012 数据集上训练了 ResNet-50 模型,并对使用 SGD+Momentum 的 SOTA 基准方法、本文提出的 Shampoo 二阶梯度算法以及 Adagrad 方法的测试结果进行对比,结果如下图 14 和表 2 所示。

研究者发现,Shampoo 二阶梯度算法虽然未能在测试损失或准确度方面实现任何改进,但与调整好的 SGD+Momentum 基准方法相比,该方法能够更快地减少训练损失。

图 14:在 ImageNet-2012 数据集上训练 ResNet-50 模型时的训练(图左)和测试(图右)交叉熵变化图。

表 2:当训练 ResNet-50 模型的 batch size=4096 时,三种方法在 ImageNet-2012 数据集上的准确度测试结果对比。

相信很多读者都学过最优化方法这类理论课程,我们会发现梯度下降,或者称之为最速下降法是最简单的方法,「锯齿」现象令它在很多领域上都存在问题。我们也会发现各种二阶优化、拟二阶优化在理论上性质远远超过它。然而在深度学习领域,由于数据与模型的规模,我们采用的都是「最速下降法」这个大家庭。

很多时候,我们会想,之前累积的那么多优秀方法,完全应用不到深度学习吗?而这篇论文至少告诉我们,计算、内存等各种困难,我们都是有机会克服的,相信那些具有强硬理论支持的最优化方法,最终会在深度学习展现它们的魅力。

推荐阅读
中文版开源! 一份来自亚马逊工程师写的 Google 面试指南,太火了10个必会的 PyCharm 技巧
为了追到小姐姐,我用 Python 制作了一个机器人
青出于蓝而胜于蓝,这是一款脱胎于Jupyter Notebook的新型编程环境
【中文教程】简单粗暴入门TensorFlow 2.0 | 北大学霸出品

二阶梯度优化新崛起,超越 Adam,Transformer 只需一半迭代量相关推荐

  1. 地理新教材降难度-小学生只需明白地球是圆的-人教社-新课改-教材

    地理新教材降难度:小学生只需明白地球是圆的|人教社|新课改|教材 本报长沙讯 人教社建社62年以来编写出版的第11套中小学通用教科书今年秋季与学生见面了.众多教研专家和优秀教师在研读新修订的人教版教科 ...

  2. 【深度学习】——梯度下降优化算法(批量梯度下降、随机梯度下降、小批量梯度下降、Momentum、Adam)

    目录 梯度 梯度下降 常用的梯度下降算法(BGD,SGD,MBGD) 梯度下降的详细算法 算法过程 批量梯度下降法(Batch Gradient Descent) 随机梯度下降法(Stochastic ...

  3. 深度研究自然梯度优化,从入门到放弃 | Deep Reading

    参加 2019 Python开发者日,请扫码咨询 ↑↑↑ 作者 | Cold Marie Wild 译者 | 刘畅 责编 | Jane 出品 | AI科技大本营(公众号id:rgznai100) [导 ...

  4. 超越Swin Transformer!谷歌提出了收敛更快、鲁棒性更强、性能更强的NesT

    [导读]谷歌&罗格斯大学的研究员对ViT领域的分层结构设计进行了反思与探索,提出了一种简单的结构NesT,方法凭借68M参数取得了超越Swin Transformer的性能. 文章链接:htt ...

  5. 训练过程--梯度下降算法(SGD、adam等)

    SGD系列 1)Batch gradient descent(批量梯度下降)   在整个数据集上   每更新一次权重,要遍历所有的样本,由于样本集过大,无法保存在内存中,无法线上更新模型.对于损失函数 ...

  6. 菜鸟de深度学习之路——(2)张量运算和梯度优化

    一,引言 上一节https://blog.csdn.net/zzl1060549268/article/details/88675915通过一个具体的例子,从整体上鸟瞰了一下一个三层神经网络,但是对于 ...

  7. 卡地亚搜索引擎_「AF厂卡地亚猎豹」网站SEO优化新方向

    网站SEO优化新方向,因为百度算法的不断更新,网站SEO优化也越来越有难度,作为SEOer咱们需要的是不断试错,不断的加强咱们的优化技巧,不断的去了解百度的算法,那么网站SEO优化新方向有哪些? 高质 ...

  8. GNN手绘草图识别新架构:Multi-Graph Transformer 网络

    点击我爱计算机视觉标星,更快获取CVML新技术 本文介绍一篇比较小众但非常有意思的手绘草图识别的新文章<Multi-Graph Transformer for Free-Hand Sketch ...

  9. ResNet被全面超越了,是Transformer干的:轻量版优于MobileNet

    作者丨Happy    审稿|邓富城    编辑丨极市平台 极市导读 又一篇Transformer来了!本文在ViT方面进行了一次突破性探索,提出了首次全面超越ResNet,甚至轻量化版本优于Mobi ...

  10. 拯救万千学子于水深火热之中!Facebook开源无梯度优化工具

    乾明 发自 凹非寺 量子位 出品 | 公众号 QbitAI 机器学习啥最苦?十有八九找参数! 不少研究生,都被卡在这个环节上,久久不能毕业. 现在,圣诞节前,有了一个好消息! Facebook宣布,开 ...

最新文章

  1. wcf ria中主从表绑定treeview
  2. 详解音视频直播中的低延时
  3. 权限管理,pymysql模块
  4. kali2020提高权限到root
  5. C语言源代码展示:常用转换函数实现原理
  6. .jar文件如何打开_ofd发票文件如何打开
  7. python矩阵对角化_numpy创建单位矩阵和对角矩阵的实例
  8. Python科学计算库numpy中的add运算
  9. php 字符串hash比较,分析两个 url 查询字符串和 hash 的区别
  10. thinkphp验证要插入数据库
  11. Datawhale编程学习之算法思想(7)
  12. 爬虫如何爬取微信公众号文章
  13. 51单片机最小系统原理分析
  14. 从键盘输入n个数 求其中的最大数
  15. 8个最好用的H5页面制作工具
  16. 使用云效应用交付平台 AppStack进行应用管理
  17. python检测键盘输入termios、等待按键超时检测
  18. 管理中第一可怕之事(2) .
  19. jsp与servlet数据交互出现null或???解决方案
  20. Python(私有变量)类中的特殊方法

热门文章

  1. 肇庆学院计算机论文选题,肇庆学院本科毕业论文(设计)写作与印制规范
  2. 提交到dockerHub
  3. 什么是服务器CC攻击,被CC攻击了服务器怎么防护?
  4. router-view显示不出来的原因
  5. Hola Stduio导入RUBE配置的过程
  6. Linux安装mysql 开启bingo日志
  7. 关于C++中的随机数生成器
  8. Redis系列-生产应用篇-分布式锁(5)-单进程Redis分布式锁的Java实现(Redisson使用与底层实现)-原子锁类
  9. Android自定义Transition动画
  10. 形如in (‘111,222,333‘) 的 ,Oracle的in函数(报错:无效数字)