detach的作用

Tensor.detach() 的作用是阻断反向梯度传播,当我们再训练网络的时候可能希望保持一部分的网络参数不变,只对其中一部分的参数进行调整;或者值训练部分分支网络,并不让其梯度对主网络的梯度造成影响,这时候我们就需要使用detach()函数来切断一些分支的反向传播,例如在生成对抗网络的训练当中,在训练判别器的时候不需要生成器进行反向梯度传播,这时候就会使用到 detach()。

detach文档说明

  • 返回一个新的 Tensor,与当前计算图形分离。
  • 结果永远不需要梯度,requires_grad为false
  • 这种方法也会影响前向模式 AD 梯度,结果永远不会有前向模式 AD 梯度。
  • 返回的 Tensor 与原始张量共享相同的存储。将看到对它们中的任何一个进行就地修改,另一个也会发生改变

如何使用detach

import torcha = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()out.sum().backward()
print(a.grad)
'''返回:
None
tensor([0.2139, 0.2217, 0.2445])
'''

当使用detach()分离tensor但是没有更改这个tensor时,并不会影响backward()

# %%
import torcha = torch.tensor([0.8, 0.7, 0.3], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)#添加detach(),c的requires_grad为False
c = out.detach()
print(c)#使用新生成的Variable进行反向传播
c.sum().backward()
print(a.grad)

打印结果

当使用detach()分离tensor,然后用这个分离出来的tensor去求导数,会影响backward(),会出现错误

# %%
import torcha = torch.tensor([0.8, 0.7, 0.3], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)#添加detach(),c的requires_grad为False
c = out.detach()
print(c)#使用新生成的Variable进行反向传播
c.sum().backward()
print(a.grad)''' 执行结果
None
tensor([0.6900, 0.6682, 0.5744], grad_fn=<SigmoidBackward>)
tensor([0.6900, 0.6682, 0.5744])
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn'''

当使用detach()分离tensor并且更改这个tensor时,即使再对原来的out求导数,会影响backward(),会出现错误


# %%
import torcha = torch.tensor([0.8, 0.7, 0.3], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)#添加detach(),c的requires_grad为False
c = out.detach()
print(c)
c.zero_() #使用in place函数对其进行修改#会发现c的修改同时会影响out的值
print(c)
print(out)#这时候对c进行更改,所以会影响backward(),这时候就不能进行backward(),会报错
out.sum().backward()
print(a.grad)'''执行结果
None
tensor([0.6900, 0.6682, 0.5744], grad_fn=<SigmoidBackward>)
tensor([0.6900, 0.6682, 0.5744])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor
[3]], which is output 0 of SigmoidBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the
operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).'''

detach_()


其实就相当于变量之间的关系本来是x -> m -> y,这里的叶子tensor是x,但是这个时候对m进行了m.detach_()操作,其实就是进行了两个操作:

  • 将m的grad_fn的值设置为None,这样m就不会再与前一个节点x关联,这里的关系就会变成x, m -> y,此时的m就变成了叶子结点
  • 然后会将m的requires_grad设置为False,这样对y进行backward()时就不会求m的梯度

其实detach()和detach_()很像,两个的区别就是detach_()是对本身的更改,detach()则是生成了一个新的tensor

比如x -> m -> y中如果对m进行detach(),如果还想对原来的计算图进行操作是可以的。

参考文档:https://blog.csdn.net/qq_27825451/article/details/95498211

pytorch 中的Tensor.detach介绍相关推荐

  1. 实践教程 | 浅谈 PyTorch 中的 tensor 及使用

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者 | xiaopl@知乎(已授权) 来源 | https://z ...

  2. 【深度学习理论】一文搞透pytorch中的tensor、autograd、反向传播和计算图

    转载:https://zhuanlan.zhihu.com/p/145353262 前言 本文的主要目标: 一遍搞懂反向传播的底层原理,以及其在深度学习框架pytorch中的实现机制.当然一遍搞不定两 ...

  3. pytorch中torch.optim的介绍

    pytorch中torch.optim的介绍 这是torch自带的一个优化器,里面自带了求导,更新等操作.开门见山直接讲怎么使用: 常用的引入: import torch.optim as optim ...

  4. Python数据类型、Numpy数据类型和Pytorch中的tensor类型间的相互转化

    数据类型包括Python数据类型.Numpy数据类型和Pytorch中的tensor,Pytorch中的tensor又包括CPU上的数据类型和GPU上的数据类型. 一.Python数据类型 Pytho ...

  5. PyTorch中使用Tensor作为索引

    一维Tensor作为索引 在Numpy中,我们可以传入数组作为索引,称为花式索引.这里只演示使用两个一维List的例子. In[42]: a=np.arange(18).reshape(6,3) In ...

  6. python/pytorch中的一些函数介绍

    看cvt代码,记录里面的一些不认识的函数或功能. 1 collections.OrderedDict 包含:from collections import OrderedDict 作用:建立有序的键值 ...

  7. pyTorch中tensor运算

    文章目录 PyTorch的简介 PyTorch中主要的包 PyTorch的安装 使用GPU的原因 使数据在GPU上运行 什么使Tensor(张量) 一些术语介绍 Tensor的属性介绍(Rank,ax ...

  8. pytorch中tensor、backward一些总结

    目录 说明 Tensor Tensor的创建 Tensor(张量)基本数据类型与常用属性 Tensor的自动微分 设置不可积分计算 pytorch 计算图 backward一些细节 该文章解决问题如下 ...

  9. pytorch中的卷积操作详解

    首先说下pytorch中的Tensor通道排列顺序是:[batch, channel, height, width] 我们常用的卷积(Conv2d)在pytorch中对应的函数是: torch.nn. ...

最新文章

  1. Sublime Text 2报“Decode error - output not utf-8”错误的解决办法
  2. 探秘 | 平安人寿人工智能研发团队北京研发中心
  3. c#中怎样取得某坐标点的颜色
  4. C# 线程的定义和使用
  5. 机器学习-算法背后的理论与优化(part1)--从线性回归到逻辑回归
  6. 著名开源项目_著名开源项目案例研究
  7. Missing artifact com.oracle:ojdbc6:jar:11.2.0.1.0问题解决 ojdbc包pom.xml出错
  8. BGP-13 配置BGP多路径发布
  9. Charades数据集
  10. 人事管理系统都有哪些功能和优势?
  11. 物联网学习之路——物联网通信技术简介
  12. python怎么下载panda包_pandas python下载
  13. centos修改ftp服务器密码是什么,centos ftp服务器密码忘记了
  14. python可能实现办公自动化吗,让工作化繁为简:用Python实现办公自动化
  15. 山东罕见姓氏百家姓都没有,翻家谱竟是皇室后裔,专家:是真的
  16. Python网页编程(CGI)
  17. Linux打开wim文件,linux笔记 wim编辑器
  18. 类选择器和ID选择器的区别
  19. 【第二届青训营-寒假前端场】- 「小游戏开发」笔记
  20. Win7系统如何用记事本打开文件?

热门文章

  1. [PAT乙级]1037. 在霍格沃茨找零钱(20)
  2. TableLayout 和 GridLayout 的区别
  3. 韶关生物实验室建设平面布局
  4. 对linux内核中GDT和LDT的理解
  5. onekey ghost怎么用
  6. IDEA中SpringBoot集成Swagger总结,思路清晰
  7. nexus3支持docker匿名拉取
  8. 【闲聊CQF的门槛,个人观点,不喜勿喷,欢迎交流指导】
  9. 社会责任审核-安全出口
  10. VMware Workstation Pro虚拟机黑屏无反应解决方案 [硬核版]