使用Apex进行混合精度训练

转自:https://fyubang.com/2019/08/26/fp16/

你想获得双倍训练速度的快感吗?
你想让你的显存空间瞬间翻倍吗?
如果我告诉你只需要三行代码即可实现,你信不?

在这篇博客里,瓦砾会详解一下混合精度计算(Mixed Precision),并介绍一款Nvidia开发的基于PyTorch的混合精度训练加速神器—Apex,最近Apex更新了API,可以用短短三行代码就能实现不同程度的混合精度加速,训练时间直接缩小一半。

话不多说,直接先教你怎么用。

PyTorch实现

from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # 这里是“欧一”,不是“零一”
with amp.scale_loss(loss, optimizer) as scaled_loss:scaled_loss.backward()

对,就是这么简单,如果你不愿意花时间深入了解,读到这基本就可以直接使用起来了。

但是如果你希望对FP16和Apex有更深入的了解,或是在使用中遇到了各种不明所以的“Nan”的同学,可以接着读下去,后面会有一些有趣的理论知识和瓦砾最近一个月使用Apex遇到的各种bug,不过当你深入理解并解决掉这些bug后,你就可以彻底摆脱“慢吞吞”的FP32啦。

理论部分

为了充分理解混合精度的原理,以及API的使用,先补充一点基础的理论知识。

1. 什么是FP16?

半精度浮点数是一种计算机使用的二进制浮点数数据类型,使用2字节(16位)存储。下图是FP16和FP32表示的范围和精度对比。

其中,sign位表示正负,exponent位表示指数(2n−15+1(n=0)2n−15+1(n=0)),fraction位表示的是分数(m1024m1024)。其中当指数为零的时候,下图加号左边为0,其他情况为1。下图是FP16表示范例。

2. 为什么需要FP16?

在使用FP16之前,我想再赘述一下为什么我们使用FP16。

  1. 减少显存占用
    现在模型越来越大,当你使用Bert这一类的预训练模型时,往往显存就被模型及模型计算占去大半,当想要使用更大的Batch Size的时候会显得捉襟见肘。由于FP16的内存占用只有FP32的一半,自然地就可以帮助训练过程节省一半的显存空间。
  2. 加快训练和推断的计算
    与普通的空间时间Trade-off的加速方法不同,FP16除了能节约内存,还能同时节省模型的训练时间。在大部分的测试中,基于FP16的加速方法能够给模型训练带来多一倍的加速体验(爽感类似于两倍速看肥皂剧)。
  3. 张量核心的普及
    硬件的发展同样也推动着模型计算的加速,随着Nvidia张量核心(Tensor Core)的普及,16bit计算也一步步走向成熟,低精度计算也是未来深度学习的一个重要趋势,再不学习就out啦。

3. FP16带来的问题:量化误差

这个部分是整个博客最重要的理论核心
讲了这么多FP16的好处,那么使用FP16的时候有没有什么问题呢?当然有。FP16带来的问题主要有两个:1. 溢出错误;2. 舍入误差。

  1. 溢出错误(Grad Overflow / Underflow)
    由于FP16的动态范围(6×10−8∼655046×10−8∼65504)比FP32的动态范围(1.4×10−45∼1.7×10381.4×10−45∼1.7×1038)要狭窄很多,因此在计算过程中很容易出现上溢出(Overflow,g>65504g>65504)和下溢出(Underflow,g<6×10−8g<6×10−8)的错误,溢出之后就会出现“Nan”的问题。

    在深度学习中,由于激活函数的的梯度往往要比权重梯度小,更易出现下溢出的情况。

  1. 舍入误差(Rounding Error)
    舍入误差指的是当梯度过小,小于当前区间内的最小间隔时,该次梯度更新可能会失败,用一张图清晰地表示:

4. 解决问题的办法:混合精度训练+动态损失放大

  1. 混合精度训练(Mixed Precision)
    混合精度训练的精髓在于“在内存中用FP16做储存和乘法从而加速计算,用FP32做累加避免舍入误差”。混合精度训练的策略有效地缓解了舍入误差的问题。
  2. 损失放大(Loss Scaling)
    即使用了混合精度训练,还是会存在无法收敛的情况,原因是激活梯度的值太小,造成了下溢出(Underflow)。损失放大的思路是:

    • 反向传播前,将损失变化(dLoss)手动增大2k2k倍,因此反向传播时得到的中间变量(激活函数梯度)则不会溢出;
    • 反向传播后,将权重梯度缩2k2k倍,恢复正常值。

Apex的新API:Automatic Mixed Precision (AMP)

曾经的Apex混合精度训练的api仍然需要手动half模型已经输入的数据,比较麻烦,现在新的api只需要三行代码即可无痛使用:

from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # 这里是“欧一”,不是“零一”
with amp.scale_loss(loss, optimizer) as scaled_loss:scaled_loss.backward()
  1. opt_level

    其中只有一个opt_level需要用户自行配置:

    • O0:纯FP32训练,可以作为accuracy的baseline;
    • O1:混合精度训练(推荐使用),根据黑白名单自动决定使用FP16(GEMM, 卷积)还是FP32(Softmax)进行计算。
    • O2:“几乎FP16”混合精度训练,不存在黑白名单,除了Batch norm,几乎都是用FP16计算。
    • O3:纯FP16训练,很不稳定,但是可以作为speed的baseline;
  2. 动态损失放大(Dynamic Loss Scaling)

    AMP默认使用动态损失放大,为了充分利用FP16的范围,缓解舍入误差,尽量使用最高的放大倍数(224224),如果产生了上溢出(Overflow),则跳过参数更新,缩小放大倍数使其不溢出,在一定步数后(比如2000步)会再尝试使用大的scale来充分利用FP16的范围:

干货:踩过的那些坑

这一部分是整篇博客最干货的部分,是瓦砾在最近在apex使用中的踩过的所有的坑,由于apex报错并不明显,常常debug得让人很沮丧,但只要注意到以下的点,95%的情况都可以畅通无阻了:

  1. 判断你的GPU是否支持FP16:构拥有Tensor Core的GPU(2080Ti、Titan、Tesla等),不支持的(Pascal系列)就不建议折腾了。
  2. 常数的范围:为了保证计算不溢出,首先要保证人为设定的常数(包括调用的源码中的)不溢出,如各种epsilon,INF等。
  3. Dimension最好是8的倍数:Nvidia官方的文档的2.2条表示,维度都是8的倍数的时候,性能最好。要求维度是8的整数倍最重要的目的是为了能使用 Tensor Core,Tensor Core 的算力是 CUDA Core 的好多倍。
  4. 涉及到sum的操作要小心,很容易溢出,类似Softmax的操作建议用官方API,并定义成layer写在模型初始化里。
  5. 模型书写要规范:自定义的Layer写在模型初始化函数里,graph计算写在forward里。
  6. 某些不常用的函数,在使用前需要注册:amp.register_float_function(torch, 'sigmoid')
  7. 某些函数(如einsum)暂不支持FP16加速,建议不要用的太heavy,xlnet的实现改FP16困扰了我很久。
  8. 需要操作模型参数的模块(类似EMA),要使用AMP封装后的model。
  9. 需要操作梯度的模块必须在optimizer的step里,不然AMP不能判断grad是否为Nan。
  10. 欢迎补充。。。

总结

这篇从理论到实践地介绍了混合精度计算以及Apex新API(AMP)的使用方法。瓦砾现在在做深度学习模型的时候,几乎都会第一时间把代码改成混合精度训练的了,速度快,精度还不减,确实是调参炼丹必备神器。目前网上还并没有看到关于AMP以及使用时会遇到的坑的中文博客,所以这一篇也是希望大家在使用的时候可以少花一点时间debug。当然,如果读者们有发现新的坑欢迎交流,我会补充在博客中。

Reference

  1. Intel的低精度表示用于深度学习训练与推断
  2. Nvidia官方的混合精度训练文档
  3. Apex官方使用文档
  4. [Nvidia-Training Neural Networks with Mixed Precision](http://on-demand.gputechconf.com/gtc-taiwan/2018/pdf/5-1_Internal Speaker_Michael Carilli_PDF For Sharing.pdf)

使用Apex进行混合精度训练相关推荐

  1. 混合精度训练、分布式训练等训练加速方法

    以Pytorch为例 混合精度训练 Pytorch自动混合精度(AMP)训练 Pytorch自动混合精度(AMP)介绍与使用 1. 理论基础 pytorch从1.6版本开始,已经内置了torch.cu ...

  2. ResNet实战:单机多卡DDP方式、混合精度训练

    文章目录 摘要 apex DP和DDP Parameter Server架构(PS模式) ring-all-reduce模式 DDP的基本用法 (代码编写流程) Mixup 项目结构 计算mean和s ...

  3. pytorch apex 混合精度训练和horovod分布式训练

    转载请注明出处: https://mp.csdn.net/postedit/103600124 如果你基于pytorch训练模型,然后,你想加快训练速度,增大batch_size,或者,你有一台配置多 ...

  4. 混合精度训练原理总结

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨ZOMI酱@知乎(已授权) 来源丨https://zhuanl ...

  5. 浅谈深度学习混合精度训练

    ↑ 点击蓝字 关注视学算法 作者丨Dreaming.O@知乎 来源丨https://zhuanlan.zhihu.com/p/103685761 编辑丨极市平台 本文主要记录下在学习和实际试用混合精度 ...

  6. 实战 PK!RTX2080Ti 对比 GTX1080Ti 的 CIFAR100 混合精度训练

    雷锋网 AI 科技评论按:本文作者 Sanyam Bhutani 是一名机器学习和计算机视觉领域的自由职业者兼 Fast.ai 研究员.在文章中,他将 2080Ti 与 1080Ti 就训练时长进行了 ...

  7. PyTorch基于Apex的混合精度加速

    安装:pip install apex 参考: https://blog.csdn.net/c9Yv2cf9I06K2A9E/article/details/100135729 在这篇文章里,笔者会详 ...

  8. PyTorch必备神器 | 唯快不破:基于Apex的混合精度加速

    作者丨Nicolas 单位丨追一科技AI Lab研究员 研究方向丨信息抽取.机器阅读理解 你想获得双倍训练速度的快感吗? 你想让你的显卡内存瞬间翻倍吗? 如果告诉你只需要三行代码即可实现,你信不? 在 ...

  9. 混合精度训练-Pytorch

    目录 1.需求解读 2.F16和FP32的区别与联系 3.F16优点简介 4.F16缺点简介 5.混合精度训练代码实战 5.1 代码实现 5.2 代码解析 6.F16训练效果展示 7.个人总结 参考资 ...

最新文章

  1. 26期20180601目录管理
  2. 天啊!我的xbox360突然不读盘了。。。
  3. LSTM:《Long Short-Term Memory》的翻译并解读
  4. 数据库压力变大,读写分离吧
  5. 微软高级经理:Google Chrome内有部分微软的代码
  6. 浅谈微博营销如何吸引流量
  7. android监听动画完成,android判断动画已结束示例代码
  8. 数字未来,NFT未来,Game Farmer创始人胡烜峰在IGS上讲述FoxNFT和他的故事
  9. vue样式初始化_前端Vue项目——初始化及导航栏
  10. 汽车车牌自动识别技术
  11. 飞浆领航团AI达人创造营第01课|让人拍案叫绝的创意都是如何诞生的?
  12. winpe修复改linux工具,利用WinPE修改原系统注册表来修复系统
  13. 调查计算机对运算能力的影响,计算器对运算能力的影响的报告.docx
  14. Excel 拆分 分割 数据 (对数据进行分列)
  15. 开发一个完整的iOS直播app必须技能
  16. linux zend studio 10,Zend Studio 10发布,可编写mobile apps
  17. JAVA 3DES加密 ECB模式 ZeroPadding填充
  18. 天威诚信荣获「金融科技领域最具品牌影响力奖」
  19. 又创奇迹 揭秘极米H1为何成高端家庭娱乐新爆品
  20. 【离散数学】求一个n阶群的全部子群(代码实现)

热门文章

  1. 第九篇:Spring Boot整合Spring Data JPA_入门试炼04
  2. 对于AES和RSA算法的结合使用以及MD5加盐注册登录时的密码加密
  3. 求一个数的阶乘值c语言代码,求10000的阶乘(c语言代码实现)
  4. 计算机软考网络管理员题,2020年计算机软考网络管理员考前测试题及答案
  5. highcharts ajax 数据格式,Highcharts ajax获取json对象动态生成报表生成 .
  6. linux硬盘保护卡,保护卡下安装Linux
  7. Python ord 函数 - Python零基础入门教程
  8. 从尾到头打印单链表(C语言)
  9. 东莞 小学计算机编程大赛,关于举办第二十一届东莞市中小学电脑制作活动的通知...
  10. redhat配置oracle yum源,Redhat5和6 YUM源配置的区别