(本文首发于公众号,没事来逛逛)

有读者让我讲一下 LSQ (Learned Step Size Quantization) 这篇论文,刚好我自己在实践中有用到,是一个挺实用的算法,因此这篇文章简单介绍一下。阅读这篇文章需要了解量化训练的基本过程,可以参考我之前的系列教程。

LSQ 是 IBM 在 2020 年发表的一篇文章,从题目意思也可以看出,文章是把量化参数 step size (也叫 scale) 也当作参数进行训练。这种把量化参数也进行求导训练的技巧也叫作可微量化参数。在这之后,高通也发表了增强版的 LSQ+,把另一个量化参数 zero point 也进行训练,从而把 LSQ 推广到非对称量化中。

这篇文章就把 LSQ 和 LSQ+ 放在一起介绍了。由于两篇文章的公式符号不统一,为了防止符号错乱,统一使用 LSQ 论文中的符号进行表述。

普通量化训练

在量化训练中需要加入伪量化节点 (Fake Quantize),这些节点做的事情就是把输入的 float 数据量化一遍后,再反量化回 float,以此来模拟量化误差,同时在反向传播的时候,发挥 STE 的功能,把导数回传到前面的层。


Fake Quantize 的过程可以总结成以下公式 (为了方便讲解 LSQ,这里采用 LSQ 中的对称量化的方式):
v ‾ = r o u n d ( c l i p ( v / s , − Q N , Q P ) ) v ^ = v ‾ × s (1) \begin{aligned} \overline v&=round(clip(v/s, -Q_N,Q_P)) \\ \hat v&=\overline v \times s \tag{1} \end{aligned} vv^​=round(clip(v/s,−QN​,QP​))=v×s​(1)
其中, v v v 是 float 的输入, v ‾ \overline v v 是量化后的数据 (仍然使用 float 来存储,但数值由于做了 round 操作,因此是整数), v ^ \hat v v^ 是反量化的结果。 − Q N -Q_N −QN​ 和 Q P Q_P QP​ 分别是量化数值的最小值和最大值 (在对称量化中, Q N Q_N QN​、 Q P Q_P QP​ 通常是相等的), s s s 是量化参数。

由于 round 操作会带来误差,因此 v ^ \hat v v^ 和 v v v 之间存在量化误差,这些误差反应到 loss 上会产生梯度,这样就可以反向传播进行学习。每次更新 weight 后,我们会得到新的 float 的数值范围,然后重新估计量化参数 s s s:
s = ∣ v ∣ m a x Q P (3) s=\frac{|v|_{max}}{Q_P} \tag{3} s=QP​∣v∣max​​(3)
之后,开始新一次迭代训练。

LSQ

可以看到,上面这个过程的量化参数都是根据每一轮的权重计算出来的,而整个网络在训练的过程中只会更新权重的数值。

LSQ 想做的,就是把这里的 s s s 也放到网络的训练当中,而不是通过权重来计算。

也就是说,每次反向传播的时候,需要对 s s s 求导进行更新。

这个导数可以这样计算:把 (1) 式统一一下得到:
v ^ = r o u n d ( c l i p ( v / s , − Q N , Q P ) ) × s = { − Q N × s v / s < = − Q N r o u n d ( v / s ) × s − Q N < v / s < Q P Q P × s v / s > = Q P (5) \begin{aligned} \hat v&=round(clip(v/s, -Q_N, Q_P))\times s \\ &=\begin{cases}-Q_N \times s & v/s <= -Q_N \\ round(v/s)\times s & -Q_N < v/s < Q_P \tag{5} \\ Q_P \times s & v/s >= Q_P \end{cases} \end{aligned} v^​=round(clip(v/s,−QN​,QP​))×s=⎩⎪⎨⎪⎧​−QN​×sround(v/s)×sQP​×s​v/s<=−QN​−QN​<v/s<QP​v/s>=QP​​​(5)
然后对 s s s 求导得到:
∂ v ^ ∂ s = { − Q N v / s < = − Q N r o u n d ( v / s ) + ∂ r o u n d ( v / s ) ∂ s × s − Q N < v / s < Q P Q P v / s > = Q P (6) \frac{\partial \hat v}{\partial s}= \begin{cases} -Q_N & v/s <= -Q_N \\ round(v/s)+\frac{\partial round(v/s)}{\partial s}\times s & -Q_N< v/s<Q_P \tag{6} \\ Q_P & v/s >= Q_P \\ \end{cases} ∂s∂v^​=⎩⎪⎨⎪⎧​−QN​round(v/s)+∂s∂round(v/s)​×sQP​​v/s<=−QN​−QN​<v/s<QP​v/s>=QP​​(6)
r o u n d ( v / s ) round(v/s) round(v/s) 这一步的导数可以通过 STE 得到:
∂ r o u n d ( v / s ) ∂ s = ∂ ( v / s ) ∂ s = − v s 2 (7) \begin{aligned} \frac{\partial round(v/s)}{\partial s}&=\frac{\partial (v/s)}{\partial s} \tag{7} \\ &=-\frac{v}{s^2} \end{aligned} ∂s∂round(v/s)​​=∂s∂(v/s)​=−s2v​​(7)
最终得到论文中的求导公式:
∂ v ^ ∂ s = { − Q N v / s < = − Q N − v s + r o u n d ( v / s ) − Q N < v / s < Q P Q P v / s > = Q P (8) \frac{\partial \hat v}{\partial s}= \begin{cases} -Q_N & v/s <= -Q_N \\ -\frac{v}{s}+round(v/s) & -Q_N< v/s <Q_P \tag{8} \\ Q_P & v/s >= Q_P \\ \end{cases} ∂s∂v^​=⎩⎪⎨⎪⎧​−QN​−sv​+round(v/s)QP​​v/s<=−QN​−QN​<v/s<QP​v/s>=QP​​(8)
(上面这堆公式敲得非常辛苦,给个赞不过分吧)

作者在实验中发现,这种简单粗暴的训练方式有一个好处。

假设把量化范围固定在 [0, 3] 区间,(即 Q N = 0 Q_N=0 QN​=0, Q P = 3 Q_P=3 QP​=3)。下面 A 图表示量化前的 v v v 和反量化后的 v ^ \hat{v} v^ 之间的映射关系(假设 s = 1 s=1 s=1),这里面 round 采用四舍五入的原则,也就是说,在 0.5 这个地方 (图中第一道虚线), v ^ \hat{v} v^ 是会从 0 突变到 1 的,从而带来巨大的量化误差。


因此,从 0.5 的左侧走到右侧,梯度应该是要陡然增大的。

在 B 图中,作者就对比了 QIL、PACT 和 LSQ (前面两个是另外两种可微量化参数的方法) 在这些突变处的梯度变化,结果发现,QIL 和 PACT 在突变处的梯度没有明显变化,还是按照原来的趋势走,而 LSQ 则出现了一个明显突变 (注意每条虚线右侧)。因此,LSQ 在梯度计算方面是更加合理的。

此外,作者还认为,在计算 s s s 梯度的时候,还需要兼顾模型权重的梯度,二者差异不能过大,因此,作者设计了一个比例系数来约束 s s s 的梯度大小:
R = ∂ s L s / ∣ ∣ ∂ w L ∣ ∣ ∣ ∣ w ∣ ∣ ≈ 1 (9) R=\frac{\partial_s L}{s}/\frac{||\partial_w L||}{||w||} \approx 1 \tag{9} R=s∂s​L​/∣∣w∣∣∣∣∂w​L∣∣​≈1(9)
同时,为了保持训练稳定,作者在 s s s 的梯度上还乘了一个缩放系数 g g g,对于 weight 来说, g = 1 / N W Q P g=1/\sqrt{N_W Q_P} g=1/NW​QP​ ​,对于 feature 来说, g = 1 / N F Q P g=1/\sqrt{N_F Q_P} g=1/NF​QP​ ​, N W N_W NW​ 和 N F N_F NF​ 分别表示 weight 和 feature 的大小。

而在初始化方面,作者采用 2 ∣ v ∣ Q P \frac{2|v|}{\sqrt{Q_P}} QP​ ​2∣v∣​ 的方式初始化 s s s。

到这里,LSQ 的要点基本讲完了,其实,精华的部分就是把 s s s 作为量化参数进行训练,至于后面的梯度约束、初始化等,在不同网络结构、不同任务中可能需要灵活调整,没必要完全照论文来。

LSQ+

LSQ+ 的思路和 LSQ 基本一致,就是把零点 (zero point,也叫 offset) 也变成可微参数进行训练。

加入零点后,(1) 式就变成了:
v ‾ = r o u n d ( c l i p ( ( v − β ) / s , − Q N , Q P ) ) v ^ = v ‾ × s + β (11) \begin{aligned} \overline v&=round(clip((v-\beta)/s, -Q_N,Q_P)) \\ \hat v&=\overline v \times s + \beta \tag{11} \end{aligned} vv^​=round(clip((v−β)/s,−QN​,QP​))=v×s+β​(11)
(高通这个零点计算方式和我之前使用的差得比较多,我自己使用的时候是遵照我之前文章的风格 v / s + β v/s+\beta v/s+β 来计算的,因此大家也可以灵活调整)

之后就是按照 LSQ 的方式分别计算导数 ∂ v ^ ∂ s \frac{\partial \hat{v}}{\partial s} ∂s∂v^​ 和 ∂ v ^ ∂ β \frac{\partial \hat{v}}{\partial \beta} ∂β∂v^​,再做量化训练。

论文还给出了一些初始化 s s s 和 β \beta β 的方式,但还是那句话,视具体任务、具体网络结构而定,可以自己调整 (比如我通常就按照 v v v 取 90% 左右的区间来估计 s s s 和 β \beta β 的初始值),甚至你可以用 weight equalize 先预处理一遍网络的权重再来跑 LSQ+ 的算法。

实验

这两篇文章都只给出了分类任务的实验,我觉得应该增加一点别的任务来体现算法的通用性。这里就不列举实验结果了,感兴趣的同学可以看看论文。值得注意的一点是在低比特 (4bit 以下) 的情况下,精度也可以保持得比较好。

一点思考

我自己在一个 GAN 类型的网络上尝试过 LSQ+ 算法,当时被它的效果惊艳到。

这个问题的背景是这样的:最开始的时候,我用普通的量化训练 (8bit) 加上一些蒸馏的技巧来量化这个网络,结果和全精度模型差不多。后来,团队的小伙伴对这个 GAN 网络做了巨量的压缩,同时用了一些技巧大大增强了这个网络的生成能力。然后,我的量化算法在这个网络上就失效了,精度损失非常明显。期间尝试了很多种方案,但都没法拯救。

我自己在分析这个网络权重的时候,发现一个现象,随着网络被压缩得越来越小,权重的数值范围是在逐渐增大的,换句话说,这个网络本身的信息量在逐渐增大。对量化来说,这是件很可怕的事情,因为留给我量化的信息容量是固定的,就只有 8 比特。随着网络信息量增大,每次做量化训练时,round 带来的误差也会更大,这可能使得网络的梯度变得非常不稳定。甚至我会想,是不是 8 比特的信息量就不可能承载得了新网络的容量?

后来,在万念俱灰之下,尝试了 LSQ+ 算法,结果一下子把精度提高了一个档次,我感觉我又活过来了!事后分析的时候,我觉得一个很重要的原因就是:LSQ+ 在前向传播的时候, s s s 本身也在控制调整权重的数值分布,而且这种调整是可微的,可以用损失函数进行学习,是一种动态的调整。相比仅仅更新 weight 来调整数值分布的做法,LSQ 多了一条路径来学习。

最后,给需要做量化部署的同学提个醒,在导出量化模型进行部署时,需要根据训练好的 s s s 来确定权重的 minmax 大小,因为在 LSQ 的前向传播中,模型权重的数值范围是受 s s s 影响的,最终也是根据 s s s 反应到损失函数上的。

参考

  • Learned Step Size Quantization
  • LSQ+: Improving low-bit quantization through learnable offsets and better initialization
  • https://www.yuque.com/yahei/hey-yahei/quantization-retrain_differentiable

欢迎关注我的公众号:大白话AI,立志用大白话讲懂AI。

量化训练之可微量化参数—LSQ相关推荐

  1. AI模型工业部署:综述【常用的部署框架:TensorRT、Libtorch】【常见提速方法:模型结构、剪枝、蒸馏、量化训练、稀疏化】【常见部署流程:onnx2trt】【常见服务部署搭配】

    作为深度学习算法工程师,训练模型和部署模型是最基本的要求,每天都在重复着这个工作,但偶尔静下心来想一想,还是有很多事情需要做的: 模型的结构,因为上线业务需要,更趋向于稳定有经验的,而不是探索一些新的 ...

  2. tensorflow sigmoid 如何计算训练数据的正确率_量化训练:Quantization Aware Training in Tensorflow(一)...

    本文的内容包括对神经网络模型量化的基本介绍.对Tensorflow量化训练的理解与上手实操. 此外,后续系列还对量化训练中的by pass和batch norm两种情况进行补充解释,欢迎点击浏览,量化 ...

  3. 深度学习——训练时碰到的超参数

    深度学习--训练时碰到的超参数 文章目录 深度学习--训练时碰到的超参数 一.前言​ 二.一些常见的超参数 学习率(Learning rate) 迭代次数(iteration) batchsize e ...

  4. R语言构建xgboost模型:控制训练信息输出级别verbose参数

    R语言构建xgboost模型:控制训练信息输出级别verbose参数 目录 R语言构建xgboost模型:控制训练信息输出级别verbose参数

  5. 开源 java CMS - FreeCMS2.8 栏目页静态化参数

    项目地址:http://www.freeteam.cn/ 栏目页静态化参数 在栏目页静态化时,系统会使用此栏目指定的模板文件(如未指定,默认为站点所选模板中的"channel.html&qu ...

  6. Keras保存和载入训练好的模型和参数

    1.保存模型 my_model = create_model_function( ...... )my_model.compile( ...... )my_model.fit( ...... )mod ...

  7. 验证 Boost.Optional 复制构造函数不会尝试调用从模板化参数初始化构造函数的元素类型

    验证 Boost.Optional 复制构造函数不会尝试调用从模板化参数初始化构造函数的元素类型 实现功能 C++实现代码 实现功能 验证 Boost.Optional 复制构造函数不会尝试调用从模板 ...

  8. 开源 免费 java CMS - FreeCMS-信息页静态化参数 .

    下载地址:http://code.google.com/p/freecms/ 信息页静态化参数 在信息页静态化时,系统会使用此信息指定的模板文件(如未指定,默认为站点所选模板中的"信息页面. ...

  9. 南大周志华团队开源深度森林软件包DF21:训练效率高、超参数少,普通设备就能跑 | AI日报...

    中国学者研发新型电子纹身,实现8倍延展,有望用于医疗.VR和可穿戴机器人等领域 可穿戴设备,已经成为我们生活中极为常见的一种设备,它们体积轻巧.佩戴方便.检测数据齐全,但也存在一个很明显的缺点--无法 ...

最新文章

  1. Spring 加载、解析applicationContext.xml 流程
  2. 在nginx.conf中配置https
  3. kafka版本_Apache Kafka 版本演进及特性介绍
  4. ObjectDataSource控件的使用...
  5. Wondershare Recoverit for Mac(数据恢复套件)
  6. IEEE会议文章接收后提交流程
  7. wsimport生成wsdl代码
  8. android灰度发布平台,安卓版微信灰度发布购物直播功能 小程序直播上线公域流量入口...
  9. win10如何调整计算机时间同步,Win10如何修改时间同步服务器?Windows时间同步出错解决方法...
  10. Python:正则表达式 re.sub()替换功能
  11. 手机上如何使用Termux当终端,以及开启SSH服务的步骤
  12. js模糊匹配(like)
  13. 金蝶标准单据扩展类开发
  14. 如何快速而准确地进行 IP 和端口信息扫描:渗透测试必备技能
  15. amd锐龙笔记本cpu怎么样_如果你要购买笔记本的话!千万不要现在购买AMD锐龙笔记本!因为...
  16. 23个带给你灵感的英文字体Logo设计欣赏
  17. vue3 :deep() :slotted() :global() css动态绑定变量
  18. 并行计算程序设计(CUDA C)
  19. 新构造运动名词解释_构造运动与地质构造(教材第八章)_普通地质学矿物
  20. 盘点企业生产管理软件的四大优势

热门文章

  1. 华为p40安装包管理一直在扫描_华为Mate40系列屏幕排列出炉,只有保时捷版全系三星...
  2. MTU和Fragment详解
  3. 关于对Match-Sea,第一次做完的游戏进行反思
  4. Springboot+ssm课堂教学效果实时评价系统
  5. linux按修改时间排序
  6. 第五章:IO流-字节流不能读中文,可以写中文
  7. LiveQing视频点播RTMP推流直播服务支持H5无插件WebRTC超低延时视频直播
  8. 学会这招,再也不怕女朋友生气啦。
  9. rknn-toolkit 国内源链接
  10. Leetcode 两数相除