现在深度学习中一般我们学习的参数都是连续的,因为这样在反向传播的时候才可以对梯度进行更新。但是有的时候我们也会遇到参数是离>散的情况,这样就没有办法进行反向传播了,比如二值神经网络。本文中讲解了如何用pytorch对二值化的参数进行梯度更新的straight-through estimator算法。
Question:
STE核心的思想就是我们的参数初始化的时候就是float这样的连续值,当我们forward的时候就将原来的连续的参数映射到{-1, 1}带入到网络进行计算,这样就可以计算网络的输出。然后backward的时候直接对原来float的参数进行更新,而不是对二值化的参数更新。这样可以完成对整个网络的更新了。
首先我们对上面问题进行一下数学的讲解。

Example:
首先我们验证一下使用torch.sign会是参数的梯度基本上都是0:

>>> input = torch.randn(4, requires_grad = True)
>>> output = torch.sign(input)
>>> loss = output.mean()
>>> loss.backward()
>>> input
tensor([-0.8673, -0.0299, -1.1434, -0.6172], requires_grad=True)
>>> input.grad
tensor([0., 0., 0., 0.])

我们需要重写sign这个函数,就好像写一个激活函数一样。

import torchclass LBSign(torch.autograd.Function):@staticmethoddef forward(ctx, input):return torch.sign(input)@staticmethoddef backward(ctx, grad_output):return grad_output.clamp_(-1, 1)
import torch
from LBSign import LBSignif __name__ == '__main__':sign = LBSign.applyparams = torch.randn(4, requires_grad = True)                                                                           output = sign(params)loss = output.mean()loss.backward()

测试梯度:

>>> params
tensor([-0.9143,  0.8993, -1.1235, -0.7928], requires_grad=True)
>>> params.grad
tensor([0.2500, 0.2500, 0.2500, 0.2500])

文章转载:https://segmentfault.com/a/1190000020993594?utm_source=tag-newest仅供参考学习,如有侵权则请联系博主。

参考文献:

  • https://segmentfault.com/a/1190000020993594?utm_source=tag-newest

pytorch实现straight-through estimator(STE)相关推荐

  1. 开源项目|基于darknet实现量化感知训练,已实现yolov3-tiny所有算子

    ◎本文为极市开发者「ArtyZe」原创投稿,转载请注明来源. ◎极市「项目推荐」专栏,帮助开发者们推广分享自己的最新工作,欢迎大家投稿.联系极市小编(fengcall19)即可投稿~ 量化简介 在实际 ...

  2. QAT(Quantization Aware Training)量化感知训练(二)【详解】

    文章目录 1.QAT(Quantization Aware Training)的建议 1.QAT(Quantization Aware Training)的建议 Quantization Aware ...

  3. 性能不打折,内存占用减少90%,Facebook提出极致模型压缩方法Quant-Noise

    对于动辄上百 M 大小的神经网络来说,模型压缩能够减少它们的内存占用.通信带宽和计算复杂度等,以便更好地进行应用部署.最近,来自 Facebook AI 的研究者提出了一种新的模型量化压缩技术 Qua ...

  4. java list 占用内存不释放_性能不打折,内存占用减少90%,Facebook提出极致模型压缩方法Quant-Noise...

    对于动辄上百 M 大小的神经网络来说,模型压缩能够减少它们的内存占用.通信带宽和计算复杂度等,以便更好地进行应用部署.最近,来自 Facebook AI 的研究者提出了一种新的模型量化压缩技术 Qua ...

  5. 闲话模型压缩之量化(Quantization)篇

    1. 前言 这些年来,深度学习在众多领域亮眼的表现使其成为了如今机器学习的主流方向,但其巨大的计算量仍为人诟病.尤其是近几年,随着端设备算力增强,业界涌现出越来越多基于深度神经网络的智能应用.为了弥补 ...

  6. 收藏 | 一文总结70篇论文,帮你透彻理解神经网络的剪枝算法

    来源:DeepHub IMBA本文约9500字,建议阅读10+分钟 本文为你详细介绍神经网络剪枝结构.剪枝标准和剪枝方法. 无论是在计算机视觉.自然语言处理还是图像生成方面,深度神经网络目前表现出来的 ...

  7. 深度学习量化总结(PTQ、QAT)

    背景  目前神经网络在许多前沿领域的应用取得了较大进展,但经常会带来很高的计算成本,对内存带宽和算力要求高.另外降低神经网络的功率和时延在现代网络集成到边缘设备时也极其关键,在这些场景中模型推理具有严 ...

  8. 我总结了70篇论文的方法,帮你透彻理解神经网络的剪枝算法

    无论是在计算机视觉.自然语言处理还是图像生成方面,深度神经网络目前表现出来的性能都是最先进的.然而,它们在计算能力.内存或能源消耗方面的成本可能令人望而却步,这使得大部份公司的因为有限的硬件资源而完全 ...

  9. 初入神经网络剪枝量化4(大白话)

    二. 量化 简单介绍目前比较SOTA的量化方法,也是最近看的. 2.1 DSQ   Differentiable Soft Quantization:Bridging Full-Precision a ...

  10. 高糊图片可以做什么?

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者:David Berthelot.Peyman Milanfar ...

最新文章

  1. vrep小车避障算法_V-REP 小车建模
  2. React是什么及特点
  3. oracle或mysql分组查询并且获取前3条排序后的数据
  4. apache 配置虚拟目录
  5. m3u8手机批量转码_手机怎么把m3u8格式转换成mp4格式?
  6. 打造最舒适的webview调试环境 1
  7. 户外lisp导向牌如何安装_他山之石可攻玉,赴成都、重庆学习考察户外广告和门头牌匾规划管理工作...
  8. 文本处理及正则表达式
  9. 藏在耳机里的小东西——蓝牙天线
  10. .net framework 3.5win10无法安装,一招解决win10无法安装.NET Framework 3.5
  11. 蒸汽管道图纸符号_供热循环系统“30问”(附管网图常见符号图例)
  12. 监控流媒体服务器的搭建和使用
  13. 新浪视频播放器站外调用代码
  14. java根据年份计算生肖
  15. vue组件深度传值provide、inject,值类型响应式的方法
  16. UVA 10306 e-Coins(二维完全背包)
  17. 2019/4/2更新 重制3617-6.17 增加918+6.21 二合一引导启动系统盘
  18. 电脑版适合什么插件HTML,推荐一些好用的Chrome插件
  19. docker进入容器时报错 Error response from daemon: Container xxx is restarting, wait until the container is
  20. 街道大动土,断网一周,学习计划照旧

热门文章

  1. 原生js与css3实现简单翻页动画
  2. Only fullscreen activities can request orientation异常解决
  3. 记事本改字体的代码java_记事本编程切换字体颜色 用java编写一个记事本程序
  4. ICM TSCC视频格式的播放
  5. C语言if的所有用法,关于if的用法
  6. OSError: [WinError 1455] 页面文件太小,无法完成操作。 Error loading “D:\Anaconda\envs\pytorch-1.4\lib\site-package
  7. python自动化plc_PYTHON – 让“Monty 语言”进入自动化行业:第 4 部分
  8. html5弹幕制作(探索ing)
  9. vmware workstation与WIFI共享大师
  10. 学计算机的可以考哪种证书,自学比较容易考的证书 哪些证书有用