参考链接: torch.Tensor.retain_grad()


原文及翻译:

retain_grad()
方法: retain_grad()Enables .grad attribute for non-leaf Tensors.对非叶节点(即中间节点张量)张量启用用于保存梯度的属性(.grad).(译者注: 默认情况下对于非叶节点张量是禁用该属性grad,计算完梯度之后就被释放回收内存,不会保存中间结果的梯度.)

实验代码展示:

Microsoft Windows [版本 10.0.18363.1316]
(c) 2019 Microsoft Corporation。保留所有权利。C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May  6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001CDB4A5D330>
>>> data_in = torch.randn(3,5,requires_grad=True)
>>> data_in
tensor([[ 0.2824, -0.3715,  0.9088, -1.7601, -0.1806],[ 2.0937,  1.0406, -1.7651,  1.1216,  0.8440],[ 0.1783,  0.6859, -1.5942, -0.2006, -0.4050]], requires_grad=True)
>>> data_mean = data_in.mean()
>>> data_mean
tensor(0.0585, grad_fn=<MeanBackward0>)
>>> data_in.requires_grad
True
>>> data_mean.requires_grad
True
>>> data_1 = data_mean * 20200910.0
>>> data_1
tensor(1182591., grad_fn=<MulBackward0>)
>>> data_2 = data_1 * 15.0
>>> data_2
tensor(17738864., grad_fn=<MulBackward0>)
>>> data_2.retain_grad()
>>> data_3 = 2 * (data_2 + 55.0)
>>> loss = data_3 / 2.0 +89.2
>>> loss
tensor(17739010., grad_fn=<AddBackward0>)
>>>
>>> data_in.grad
>>> data_mean.grad
>>> data_1.grad
>>> data_2.grad
>>> data_3.grad
>>> loss.grad
>>> print(data_in.grad, data_mean.grad, data_1.grad, data_2.grad, data_3.grad, loss.grad)
None None None None None None
>>>
>>> loss.backward()
>>> data_in.grad
tensor([[20200910., 20200910., 20200910., 20200910., 20200910.],[20200910., 20200910., 20200910., 20200910., 20200910.],[20200910., 20200910., 20200910., 20200910., 20200910.]])
>>> data_mean.grad
>>> data_mean.grad
>>> data_1.grad
>>> data_2.grad
tensor(1.)
>>> data_3.grad
>>> loss.grad
>>>
>>>
>>> print(data_in.grad, data_mean.grad, data_1.grad, data_2.grad, data_3.grad, loss.grad)
tensor([[20200910., 20200910., 20200910., 20200910., 20200910.],[20200910., 20200910., 20200910., 20200910., 20200910.],[20200910., 20200910., 20200910., 20200910., 20200910.]]) None None tensor(1.) None None
>>>
>>>
>>> print(data_in.is_leaf, data_mean.is_leaf, data_1.is_leaf, data_2.is_leaf, data_3.is_leaf, loss.is_leaf)
True False False False False False
>>>
>>>
>>>

torch.Tensor.retain_grad()的使用举例相关推荐

  1. 通俗讲解Pytorch梯度的相关问题:计算图、torch.no_grad、zero_grad、detach和backward;Variable、Parameter和torch.tensor

    文章目录 with torch.no_grad()和requires_grad backward() Variable,Parameter和torch.tensor() zero_grad() 计算图 ...

  2. torch.Tensor.requires_grad_(requires_grad=True)的使用说明

    参考链接: requires_grad_(requires_grad=True) → Tensor 原文及翻译: requires_grad_(requires_grad=True) → Tensor ...

  3. torch.Tensor.requires_grad属性的使用说明

    参考链接: torch.Tensor.requires_grad 原文及翻译: requires_grad() 方法: requires_grad()Is True if gradients need ...

  4. 【PyTorch系例】torch.Tensor详解和常用操作

    学习教材: 动手学深度学习 PYTORCH 版(DEMO) (https://github.com/ShusenTang/Dive-into-DL-PyTorch) PDF 制作by [Marcus ...

  5. PyTorch 笔记(02)— 常用创建 Tensor 方法(torch.Tensor、ones、zeros、eye、arange、linspace、rand、randn、new)

    1. Tensor 概念分类 PyTorch 中的张量(Tensor)类似 NumPy 中的 ndarrays,之所以称之为 Tensor 的另一个原因是它可以运行在 GPU 中,以加速运算. 1.1 ...

  6. pytorch中的torch.tensor.repeat以及torch.tensor.expand用法

    文章目录 torch.tensor.expand torch.tensor.repeat torch.tensor.expand 先看招 import torch x = torch.tensor([ ...

  7. Unable to get repr for<class‘torch.Tensor‘>

    Unable to get repr for <class 'torch.Tensor'> tensor越界访问后就会变成这样. import torcha_data=torch.Tens ...

  8. unable to get repr for class ‘torch.tensor‘

    unable to get repr for class 'torch.tensor' 出错代码: batch_conf.gather(1, conf_t.view(-1,1)) 最近码代码使用pyt ...

  9. torch.Tensor和torch.tensor的区别

    torch.Tensor和torch.tensor的区别 2019-06-10 16:34:48 Vic_Hao 阅读数 4058更多 分类专栏: Pytorch 在Pytorch中,Tensor和t ...

最新文章

  1. 2021揭东一中今年高考成绩查询入口,2021年揭阳高考状元是谁分数多少分,历年揭阳高考状元名单...
  2. 【BZOJ3314】 [Usaco2013 Nov]Crowded Cows 单调队列
  3. Star PDF Watermark Ultimate中文版
  4. 为了OFFER,花了几个小时,刷下Leetcode链表算法题
  5. [转]C++类成员修饰const和mutable
  6. linux内核奇遇记之md源代码解读之六
  7. 能在沙漠飞行的翱翔机
  8. Live2D在Unity中的使用
  9. matlab工具箱下载
  10. C窗口程序——Shell_NotifyIcon()函数的使用
  11. JAVA打卡记录计算时间
  12. 别着急抢iPhone 13了!拍照有马赛克,苹果确认部分iPhone13存在bug
  13. 安装Win7系统,提示缺少所需的CD/DVD驱动器设备驱动程序
  14. 计算机网络本地连接,电脑本地连接受限制或无连接怎么办
  15. 【无标题】阿里滑块 通过 x82y接口、dll、源码 返回x5sec,可解决!
  16. (翻译)分块模式(Chunking)
  17. mybatis一个怪异的问题: Invalid bound statement not found 作者及来源: babyblue - 博客园 收藏到→_→: 摘要: mybatis一个怪异
  18. 查快递单号物流信息查询,支持多家快递
  19. angular的传值,子传父,父传子
  20. docker ss-pannel_如何构建Docker镜像

热门文章

  1. Playbook机密
  2. master上启动的容器一直pending
  3. 医疗器械电子硬件·安规与EMC设计
  4. IOM计算机组成原理,计算机组成原理实验__实验报告
  5. 简述C和C++程序员学习历程
  6. 电信dns错误怎么办?解决dns错误的的方法
  7. openwrt网络唤醒计算机,OpenWrt实现WOL(Wake-on-LAN)网络唤醒
  8. Android 内存修改与一键修改
  9. 常见的游戏引擎有哪几种?游戏开发
  10. 拍拍贷Q2季报图解:净利6亿 环比增39%同比降4%