查看非叶节点梯度的两种方法

在反向传播过程中非叶子节点的导数计算完之后即被清空。若想查看这些变量的梯度,有两种方法:

  • 使用autograd.grad函数
  • 使用hook

autograd.gradhook方法都是很强大的工具,更详细的用法参考官方api文档,这里举例说明基础的使用。推荐使用hook方法,但是在实际使用中应尽量避免修改grad的值。

求z对y的导数

x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum()# hook
# hook没有返回值,参数是函数,函数的参数是梯度值
def variable_hook(grad):print("hook梯度输出:\r\n",grad)hook_handle = y.register_hook(variable_hook)         # 注册hook
z.backward(retain_graph=True)                        # 内置输出上面的hook
hook_handle.remove()                                 # 释放print("autograd.grad输出:\r\n",t.autograd.grad(z,y)) # t.autograd.grad方法

hook梯度输出:Variable containing:111
[torch.FloatTensor of size 3]autograd.grad输出:(Variable containing:111
[torch.FloatTensor of size 3]
,)

多次反向传播试验

实际就是使用retain_graph参数,

# 构件图
x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum()z.backward(retain_graph=True)
print(w.grad)
z.backward()
print(w.grad)

Variable containing:111
[torch.FloatTensor of size 3]Variable containing:222
[torch.FloatTensor of size 3]

如果不使用retain_graph参数,

实际上效果是一样的,AccumulateGrad object仍然会积累梯度

# 构件图
x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum()z.backward()
print(w.grad)
y = w.mul(x)  # <-----
z = y.sum()  # <-----
z.backward()
print(w.grad)

Variable containing:111
[torch.FloatTensor of size 3]Variable containing:222
[torch.FloatTensor of size 3]

分析:

这里的重新建立高级节点意义在这里:实际上高级节点在创建时,会缓存用于输入的低级节点的信息(值,用于梯度计算),但是这些buffer在backward之后会被清空(推测是节省内存),而这个buffer实际也体现了上面说的动态图的"动态"过程,之后的反向传播需要的数据被清空,则会报错,这样我们上面过程就分别从:保留数据不被删除&重建数据两个角度实现了多次backward过程。

实际上第二次的z.backward()已经不是第一次的z所在的图了,体现了动态图的技术,静态图初始化之后会留在内存中等待feed数据,但是动态图不会,动态图更类似我们自己实现的机器学习框架实践,相较于静态逻辑简单一点,只是PyTorch的静态图和我们的比会在反向传播后清空存下的数据:下次要么完全重建,要么反向传播之后指定不舍弃图z.backward(retain_graph=True)。

总之图上的节点是依赖buffer记录来完成反向传播,TensorFlow中会一直存留,PyTorch中就会backward后直接舍弃(默认时)。

『PyTorch』第五弹_深入理解autograd_中:Variable梯度探究相关推荐

  1. 『TensorFlow』第七弹_保存载入会话_霸王回马

    首更: 由于TensorFlow的奇怪形式,所以载入保存的是sess,把会话中当前激活的变量保存下来,所以必须保证(其他网络也要求这个)保存网络和载入网络的结构一致,且变量名称必须一致,这是caffe ...

  2. 『TensorFlow』第十一弹_队列多线程TFRecod文件_我辈当高歌

    TF数据读取队列机制详解 一.TFR文件多线程队列读写操作 TFRecod文件写入操作 import tensorflow as tf def _int64_feature(value):# valu ...

  3. 『PyTorch』第十一弹_torch.optim优化器 每层定制参数

    一.简化前馈网络LeNet 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 im ...

  4. 关于『HTML』:第三弹

    关于『HTML』:第三弹 建议缩放90%食用 盼望着, 盼望着, 第三弹来了, HTML基础系列完结了!! 一切都像刚睡醒的样子(包括我), 欣欣然张开了眼(我没有) 敬请期待Markdown语法系列 ...

  5. 『TensorFlow』函数查询列表_张量属性调整

    博客园 首页 新随笔 新文章 联系 订阅 管理 『TensorFlow』函数查询列表_张量属性调整 数据类型转换Casting 操作 描述 tf.string_to_number (string_te ...

  6. python iterable对象_如何理解Python中的iterable对象

    转载请注明出处:https://www.jianshu.com/u/5e6f798c903a [^*] 表示注脚,在文末可以查看对应连接,但简书不支持该语法. 首先,容器和 iterable 间没有必 ...

  7. python的上下文管理用哪个关键字_正确理解python中的关键字“with”与上下文管理器...

    正确理解python中的关键字"with"与上下文管理器 来源:中文源码网    浏览: 次    日期:2018年9月2日 [下载文档:  正确理解python中的关键字&quo ...

  8. 『PyTorch』第十五弹_torch.nn.Module的属性设置查询

    一.背景知识 python中两个属相相关方法 result = obj.name 会调用builtin函数getattr(obj,'name')查找对应属性,如果没有name属性则调用obj.__ge ...

  9. 『统计学』第五部分:方差分析和F检验

    第四部分的卡方检验是研究类别变量之间的关系,而这一部分的方差分析则是研究类别型自变量与数值型因变量之间的关系,它在形式上是比较多个总体的均值是否相等. 从形式上看,方差分析与之前的t检验或z检验区别不 ...

最新文章

  1. 服务器控件调用JS方法
  2. 使用shell定时自动备份mysql数据库
  3. mysql 时序 存储引擎_MySQL常见的三种存储引擎
  4. 神曲背后的故事:算法工程师带你理性解构“蚂蚁呀嘿”
  5. [Leetcode][第206题][JAVA][反转一个单链表][递归][迭代]
  6. 基于Tensorflow实现多层感知机网络MLPs
  7. web.config中配置数据库连接的两种方式
  8. 在 Laravel 5 中集成七牛云存储实现云存储功能
  9. wps如何自己制作流程图_WPS如何绘制流程图? WPS绘图流程图详细教程
  10. 修复MacOS X上QuickTime 7.2中的AVI播放错误
  11. office 论文 页码_word如何设置毕业论文页码
  12. 我对管理和领导的理解
  13. 我靠海外抖音搬运视频赚到了人生第一桶金:这个风口行业,真的很赚钱
  14. 【转】微信小程序日期时间选择器(年月日时分秒)
  15. 03.服务限流实现方案
  16. 基本共射极放大电路电路分析
  17. CSUOJ-1986: 玄学
  18. Saber吃苹果,保持每箱苹果数量递增
  19. [计算机毕业设计]模糊聚类算法
  20. 客似云来——习题精解

热门文章

  1. 联想G480类似没有小键盘开关的机器
  2. Chrome 浏览器跨域和安全访问问题 使用 chrome的命令行标记:disable-web-security 参数联调线上数据...
  3. cocos2d ccLayer响应触摸事件方法:CCStandardTouchDelegate 与 CCTargetedTouchDelegate
  4. Dataguard - 通过主库热备方式创建容灾库
  5. 旧手机的新玩法:postmarketOS 已适配上百款安卓手机
  6. (原)vs2013编译boost1.60库
  7. 使用外部表关联MySQL数据到Oracle
  8. Jquery zTree实例
  9. 任何时候不要把普通PC机接入到三层交换机
  10. [置顶] mmog游戏开发之业务篇