torch.Tensor.retain_grad()的使用举例
参考链接: torch.Tensor.retain_grad()
原文及翻译:
retain_grad()
方法: retain_grad()Enables .grad attribute for non-leaf Tensors.对非叶节点(即中间节点张量)张量启用用于保存梯度的属性(.grad).(译者注: 默认情况下对于非叶节点张量是禁用该属性grad,计算完梯度之后就被释放回收内存,不会保存中间结果的梯度.)
实验代码展示:
Microsoft Windows [版本 10.0.18363.1316]
(c) 2019 Microsoft Corporation。保留所有权利。C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May 6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001CDB4A5D330>
>>> data_in = torch.randn(3,5,requires_grad=True)
>>> data_in
tensor([[ 0.2824, -0.3715, 0.9088, -1.7601, -0.1806],[ 2.0937, 1.0406, -1.7651, 1.1216, 0.8440],[ 0.1783, 0.6859, -1.5942, -0.2006, -0.4050]], requires_grad=True)
>>> data_mean = data_in.mean()
>>> data_mean
tensor(0.0585, grad_fn=<MeanBackward0>)
>>> data_in.requires_grad
True
>>> data_mean.requires_grad
True
>>> data_1 = data_mean * 20200910.0
>>> data_1
tensor(1182591., grad_fn=<MulBackward0>)
>>> data_2 = data_1 * 15.0
>>> data_2
tensor(17738864., grad_fn=<MulBackward0>)
>>> data_2.retain_grad()
>>> data_3 = 2 * (data_2 + 55.0)
>>> loss = data_3 / 2.0 +89.2
>>> loss
tensor(17739010., grad_fn=<AddBackward0>)
>>>
>>> data_in.grad
>>> data_mean.grad
>>> data_1.grad
>>> data_2.grad
>>> data_3.grad
>>> loss.grad
>>> print(data_in.grad, data_mean.grad, data_1.grad, data_2.grad, data_3.grad, loss.grad)
None None None None None None
>>>
>>> loss.backward()
>>> data_in.grad
tensor([[20200910., 20200910., 20200910., 20200910., 20200910.],[20200910., 20200910., 20200910., 20200910., 20200910.],[20200910., 20200910., 20200910., 20200910., 20200910.]])
>>> data_mean.grad
>>> data_mean.grad
>>> data_1.grad
>>> data_2.grad
tensor(1.)
>>> data_3.grad
>>> loss.grad
>>>
>>>
>>> print(data_in.grad, data_mean.grad, data_1.grad, data_2.grad, data_3.grad, loss.grad)
tensor([[20200910., 20200910., 20200910., 20200910., 20200910.],[20200910., 20200910., 20200910., 20200910., 20200910.],[20200910., 20200910., 20200910., 20200910., 20200910.]]) None None tensor(1.) None None
>>>
>>>
>>> print(data_in.is_leaf, data_mean.is_leaf, data_1.is_leaf, data_2.is_leaf, data_3.is_leaf, loss.is_leaf)
True False False False False False
>>>
>>>
>>>
torch.Tensor.retain_grad()的使用举例相关推荐
- 通俗讲解Pytorch梯度的相关问题:计算图、torch.no_grad、zero_grad、detach和backward;Variable、Parameter和torch.tensor
文章目录 with torch.no_grad()和requires_grad backward() Variable,Parameter和torch.tensor() zero_grad() 计算图 ...
- torch.Tensor.requires_grad_(requires_grad=True)的使用说明
参考链接: requires_grad_(requires_grad=True) → Tensor 原文及翻译: requires_grad_(requires_grad=True) → Tensor ...
- torch.Tensor.requires_grad属性的使用说明
参考链接: torch.Tensor.requires_grad 原文及翻译: requires_grad() 方法: requires_grad()Is True if gradients need ...
- 【PyTorch系例】torch.Tensor详解和常用操作
学习教材: 动手学深度学习 PYTORCH 版(DEMO) (https://github.com/ShusenTang/Dive-into-DL-PyTorch) PDF 制作by [Marcus ...
- PyTorch 笔记(02)— 常用创建 Tensor 方法(torch.Tensor、ones、zeros、eye、arange、linspace、rand、randn、new)
1. Tensor 概念分类 PyTorch 中的张量(Tensor)类似 NumPy 中的 ndarrays,之所以称之为 Tensor 的另一个原因是它可以运行在 GPU 中,以加速运算. 1.1 ...
- pytorch中的torch.tensor.repeat以及torch.tensor.expand用法
文章目录 torch.tensor.expand torch.tensor.repeat torch.tensor.expand 先看招 import torch x = torch.tensor([ ...
- Unable to get repr for<class‘torch.Tensor‘>
Unable to get repr for <class 'torch.Tensor'> tensor越界访问后就会变成这样. import torcha_data=torch.Tens ...
- unable to get repr for class ‘torch.tensor‘
unable to get repr for class 'torch.tensor' 出错代码: batch_conf.gather(1, conf_t.view(-1,1)) 最近码代码使用pyt ...
- torch.Tensor和torch.tensor的区别
torch.Tensor和torch.tensor的区别 2019-06-10 16:34:48 Vic_Hao 阅读数 4058更多 分类专栏: Pytorch 在Pytorch中,Tensor和t ...
最新文章
- 2021揭东一中今年高考成绩查询入口,2021年揭阳高考状元是谁分数多少分,历年揭阳高考状元名单...
- 【BZOJ3314】 [Usaco2013 Nov]Crowded Cows 单调队列
- Star PDF Watermark Ultimate中文版
- 为了OFFER,花了几个小时,刷下Leetcode链表算法题
- [转]C++类成员修饰const和mutable
- linux内核奇遇记之md源代码解读之六
- 能在沙漠飞行的翱翔机
- Live2D在Unity中的使用
- matlab工具箱下载
- C窗口程序——Shell_NotifyIcon()函数的使用
- JAVA打卡记录计算时间
- 别着急抢iPhone 13了!拍照有马赛克,苹果确认部分iPhone13存在bug
- 安装Win7系统,提示缺少所需的CD/DVD驱动器设备驱动程序
- 计算机网络本地连接,电脑本地连接受限制或无连接怎么办
- 【无标题】阿里滑块 通过 x82y接口、dll、源码 返回x5sec,可解决!
- (翻译)分块模式(Chunking)
- mybatis一个怪异的问题: Invalid bound statement not found 作者及来源: babyblue - 博客园 收藏到→_→: 摘要: mybatis一个怪异
- 查快递单号物流信息查询,支持多家快递
- angular的传值,子传父,父传子
- docker ss-pannel_如何构建Docker镜像