文章目录

  • pytorch的两个函数:.detach()、.detach_()的作用和区别
    • 一、torch.detach()
    • 二、tensor.detach_()
  • 补充:requires_grad、grad_fn、grad的含义和作用
    • 参考

pytorch的两个函数:.detach()、.detach_()的作用和区别

当我们在训练神经网络的时候可能希望保持一部分的网络参数不变,只对其中一部分参数进行调整;或者只训练部分分支网络,并不让其梯度对主网络的梯度造成影响,这时候我们就需要使用detach()函数来切断一些分支的反向传播

一、torch.detach()

返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。

使用detach返回的tensor和原始的tensor共同一个内存,即一个修改另一个也会跟着改变。

import torcha = torch.tensor([1., 2., 3.], requires_grad=True)
print(a.grad)out = a.sigmoid()out.sum().backward()
print(a.grad)"""
None
tensor([0.1966, 0.1050, 0.0452])
"""

1.1、当使用detach()分离tensor但是没有更改这个tensor时,并不会影响backward()

import torcha = torch.tensor([1., 2., 3.], requires_grad=True)
print(a.grad)out = a.sigmoid()
print(out)# 添加detach(),c的requires_grad为False
c = out.detach()
print(c)# 这个时候没有对c进行更改,所以并不会影响backward()
out.sum().backward()
print(a.grad)

c、out之间的区别是c是没有梯度的,out是有梯度的

1.2、当使用detach()分离tensor,然后这个分离出来的tensor去求导数,会影响backward(),会出现错误

import torcha = torch.tensor([1., 2., 3.], requires_grad=True)
print(a.grad)out = a.sigmoid()
print(out)# 添加detach(),c的requires_grad为False
c = out.detach()
print(c)# 使用新生成的Variable进行反向传播
c.sum().backward()
print(a.grad)"""
None
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
"""

1.3、当使用detach()分离tensor并且更改这个tensor时,即使再对原来的out求导数,会影响backward(),会出现错误

如果此时对c进行了更改,这个更改会被autograd追踪,在对out.sum()进行backward()时也会报错,因为此时的值进行backward()得到的梯度是错误的

import torcha = torch.tensor([1., 2., 3.], requires_grad=True)
print(a.grad)out = a.sigmoid()
print(out)# 添加detach(),c的requires_grad为False
c = out.detach()
print(c)# 使用inplace函数对其进行修改
c.zero_()
print(c)
print(out)# 这个时候没有对c进行更改,所以并不会影响backward()
out.sum().backward()
print(a.grad)"""
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
"""

二、tensor.detach_()

import torcha = torch.tensor([1., 2., 3.], requires_grad=True)b = a + 2
print(b)c = b * b * 3
print(c)out = c.mean()
print(out)out.backward()
print(a.grad)"""
tensor([3., 4., 5.], grad_fn=<AddBackward0>)
tensor([27., 48., 75.], grad_fn=<MulBackward0>)
tensor(50., grad_fn=<MeanBackward0>)
tensor([ 6.,  8., 10.])
"""
import torcha = torch.tensor([1., 2., 3.], requires_grad=True)b = a + 2
print(b)b=b.detach_()
print(b)c = b * b * 3
print(c)out = c.mean()
print(out)out.backward()
print(c.grad)"""
tensor([3., 4., 5.], grad_fn=<AddBackward0>)
tensor([3., 4., 5.])
tensor([27., 48., 75.])
tensor(50.)
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
"""

torch.detach_()将一个tensor从创建它的图中分离,并把它设置成叶子tensor

其实就相当于变量之间的关系本来是x -> m -> y,这里的叶子tensor是x,但是这时候对m进行了m.detach_()操作,其实就是进行了两个操作:

  • 将m的grad_fn的值设置为None,这样m就不会再与前一个节点x观点,这里的关系就会变成x, m -> y,此时的m就变成了叶子结点

  • 然后会将m的requires_grad设置为False,这样对y进行backward()时就不会求m的梯度

总结:其实detach()和detach_()很像,两个的区别就是detach_()是对本身的更改,detach()则是生成了一个新的tensor

比如x -> m -> y中如果对m进行detach(),后面如果反悔想还是对原来的计算图进行操作还是可以的

但是如果是进行了detach_(),那么原来的计算图也发生了变化,就不能反悔了

补充:requires_grad、grad_fn、grad的含义和作用

requires_grad:如果需要为张量计算梯度,则为True,否则为False。我们使用pytorch创建tensor时,可以指定requires_grad为True(默认为False)

grad_fn:grad_fn用来记录变量是怎么来的,方便计算梯度

grad:当执行完backward()之后,通过x.grad查看x的梯度

创建一个tensor并设置requires_grad=True,requires_grad=True说明该变量需要计算梯度

import torchx = torch.ones(2, 2, requires_grad=True)print(x)
print(x.grad_fn)"""
tensor([[1., 1.],[1., 1.]], requires_grad=True)
None
"""y = x + 2
print(y)
print(y.grad_fn)"""
tensor([[3., 3.],[3., 3.]], grad_fn=<AddBackward0>)
<AddBackward0 object at 0x00000143B2129700>
"""

由于x是直接创建的,所以它没有grad_fn,而y是通过一个加法操作创建的,所以y有grad_fn

像x这种直接创建的称为叶子节点,叶子节点对应的grad_fn是None

z = y * y * 3
out = z.mean()
print(out)"""
tensor(27., grad_fn=<MeanBackward0>)
"""out.backward()
print(x.grad)"""
tensor([[4.5000, 4.5000],[4.5000, 4.5000]])
"""

grad在反向传播过程中是累加的(accumulated),这意味着每一次运行反向传播,梯度都会累加之前的梯度,所以一般在反向传播之前把梯度清零

参考

1、pytorch的两个函数 .detach() .detach_() 的作用和区别

2、requires_grad,grad_fn,grad的含义及使用

pytorch:.detach()、.detach_()的作用和区别相关推荐

  1. pytorch .detach() .detach_() 和 .data用于切断反向传播

    参考:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-autograd/#detachsource 当我们再训 ...

  2. 【pytorch】.detach() .detach_() 和 .data==>用于切断反向传播

    当我们再训练网络的时候可能 希望保持一部分的网络参数不变,只对其中一部分的参数进行调整: 或者只训练部分分支网络,并不让其梯度对主网络的梯度造成影响, 这时候我们就需要使用detach()函数来切断一 ...

  3. Spring中SmartLifecycle和Lifecycle的作用和区别

    欢迎关注方志朋的博客,回复"666"获面试宝典 本文基于SpringBoot 2.5.0-M2讲解Spring中Lifecycle和SmartLifecycle的作用和区别,以及如 ...

  4. 浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式 pth中的路径加载使用

    首先xxx.pth文件里面会书写一些路径,一行一个. 将xxx.pth文件放在特定位置,则可以让python在加载模块时,读取xxx.pth中指定的路径. Python客栈送红包.纸质书 有时,在用i ...

  5. ANDROID 中UID与PID的作用与区别

    ANDROID 中UID与PID的作用与区别 PID:为Process Identifier, PID就是各进程的身份标识,程序一运行系统就会自动分配给进程一个独一无二的PID.进程中止后PID被系统 ...

  6. (转)从一道面试题彻底搞懂hashCode与equals的作用与区别及应当注意的细节

    背景:学习java的基础知识,每次回顾,总会有不同的认识.该文系转载 最近去面试了几家公司,被问到hashCode的作用,虽然回答出来了,但是自己还是对hashCode和equals的作用一知半解的, ...

  7. python类方法中使用:修饰符@staticmethod和@classmethod的作用与区别,还有装饰器@property的使用

    python类方法中使用:修饰符@staticmethod和@classmethod的作用与区别,还有装饰器@property的使用(3-20181205) 文章目录: 一. @staticmetho ...

  8. Jar/War/Ear等包的作用与区别详解

    Jar/War/Ear等包的作用与区别详解 以客户角度来看,jar文件就是一种封装格式,用户不需要知道jar包中有多少个.class格式的文件及每个文件中的功能与作用,也可以得到相应的访问的结果.ja ...

  9. java ear war_[转] 基于Java的打包jar、war、ear包的作用与区别详解

    以最终客户的角度来看,JAR文件就是一种封装,他们不需要知道jar文件中有多少个.class文件,每个文件中的功能与作用,同样可以得到他们希望的结果.除jar以外对于J2EE来说还有war和ear.区 ...

最新文章

  1. Jvm面试题及答案 100道(持续更新)
  2. linux vi 命令大全
  3. Java复习第三天-静态方法
  4. android上传文件用哪个布局,每周总结20130821——android控件的尺寸、http文件上传...
  5. Java经典设计模式(1):五大创建型模式(附实例和详解)
  6. Linux下NTP服务器配置
  7. 按拼音首字母排列的地区选择代码 中文和拼音已配好链接
  8. html中绝对定位的父级,【CSS学习笔记】绝对定位的父类参照物的确定
  9. execl2010数据有效性验证,保存后丢失问题
  10. PayPal接口开发
  11. 电脑端微信可以打开微信小程序了
  12. vue网页打印针式打印机内容显示不全
  13. 编一行代码,飞向星辰的大海
  14. OpenSIPS实战(八):修改sip消息-使用lumps system
  15. Android Studio学习笔记
  16. LabVIEW AI视觉工具包(非NI Vision)下载与安装教程
  17. 固实压缩文件容易损坏_你不知道的压缩软件小技巧1
  18. 使用xmind绘制思维导向
  19. 串口硬盘和并口硬盘的区别
  20. Shell小干货学到就不亏

热门文章

  1. WPF 实现3维图片墙相关展示效果(凹面墙,凸面墙)
  2. python基础:动态方法、私有属性、property、继承、重写、super、多态、符号重载、拷贝、组合、工厂模式,单例
  3. 微波射频学习笔记12-------单节/多节定向耦合器的设计
  4. MySQl 计算本年的天数
  5. 敏感度、特异性:TP TN FP FN sensitivity Accuracy
  6. 谷歌(Google): reCaptcha(3.0版本)做网站验证
  7. oldboy day 4
  8. 【UI】 element-ui 表格标题加背景 斑马线
  9. 语义分割数据集之RGB与索引图的转换
  10. 5g上行速率怎么提升_小兴君课堂 | 5G上行的痛,有解药啦!