Pytorch中autograd.Variable.backward的grad_varables参数个人理解浅见
Pytorch中autograd.Variable.backward的grad_varables参数个人理解浅见
一、autograd.Variable简介
1、Pytorch在autograd模块中实现图计算的相关功能,autograd中的核心数据结构是Variable。Variable封装了tensor,并记录对tensor的操作记录用来构建计算图。
2、Variable结构包含data、grad和grad_fn(可以查看对应引用的Variable的反向传播函数,注意:用户创建的变量为计算图的叶子节点,对应的值为None,原因是叶子节点无法再往后进行函数算子操作比如加减乘除等)。
from torch.autograd import Variable as V
import torch as t
x = V(t.ones(1))
b = V(t.randn(1),requires_grad = True)
w = V(t.randn(1),requires_grad = True)
y = w * x
z = y + b # z的前向传播函数是加法
z.backward()
print(z.grad_fn) # 查看z的反向传播函数,见下图结果红色框
3、Variable的构造函数需要传入tensor,有两个可选参数分别是requires_grad(默认false表示不对该Variable求导)和volatile,并支持大部分tensor支持的函数,除inplace外。
from torch.autograd import Variable as V
import torch as t
a = V(t.ones(3,4),requires_grad=True) # 创建Variable
c = a.sum() # 支持大部分tensor支持的函数,除inplace外,c数据类型依旧是variable
4、如果想要计算各个Variable的梯度,只需要调用根节点Variable的backward方法,autograd会自动沿着计算图反向传播,计算根节点到每一个叶子节点的梯度,采用 Variable.backward(grad_variables=None,…) 方法。
二、grad_variables参数的理解
Variable.backward(grad_variables=None,…) 中的grad_variables参数形状与Variable一致,前向传播后得到 f 是目标函数的值一般是标量(scalar),grad_variables相当于链式法则偏导运算中的一个部分运算结果,一般是从 f 开始计算偏导数到叶子节点。如下面公式:f 表示目标函数(只当作理解参数 grad_variables 的使用,实际并不出现),x 表示叶子节点,y 表示根节点,grad_variables 等于 f 对 y 的偏导的结果。
2.1、当叶子节点都是一维的tensor(此一维是专指创建tensor时参数都是1,1的数目不限)构成的variable基本等同于标量,此时的情况表示目标函数 f 就是 y,说明 grad_variables =1 ,参数grad_variables在y.backward( )中可以直接省略,那么偏导式转换成:
import torch as t
from torch.autograd import Variable as V
x = V(t.ones(1),requires_grad=True) # 叶子节点x为一维的tensor构成的variable
y = x**2
# w = y.sum()
y.backward() # y等同于目标函数f,grad_variables=1可直接省略
print(x.grad) # 输出叶子节点x的梯度
2.2、当叶子节点为非一维的tensor构成的variable,此时的情况表示目标函数 f 与 y 初始不等同,那么偏导式整体结构初始不变。但又分两种情况分析:
(1)目标函数 f 是标量,y从非一维tensor构成的variable变成标量。对y采用sum()方法实现非一维tensor构成的variable变成标量,说明 grad_variables =1 ,参数grad_variables在f.backward( )中可以直接省略。
import torch as t
from torch.autograd import Variable as V
x = V(t.ones(2),requires_grad=True) # 叶子节点x为非一维的tensor构成的variable
y = x**2
f = y.sum() # 对y采用sum()方法实现非一维tensor构成的variable变成标量
f.backward() # 标量对标量不用求导,所以直接省略
print(x.grad) # 输出叶子节点x的梯度
(2)目标函数 f 是标量,y是非一维tensor构成的variable。那么 f 对 y 求偏导的结果也是非一维tensor构成的variable,并且维度大小和 y 一样,所以采用y.backward(t.ones(y.size())):这是把 y 当作被偏导数,t.ones(y.size())是f 对 y 求偏导的结果,y.size()是表示结果的维度和 y 保持相同设定。
import torch as t
from torch.autograd import Variable as V
x = V(t.ones(2),requires_grad=True) # 叶子节点x为非一维的tensor构成的variable
y = x**2
# f = y.sum() # 对y采用sum()方法实现非一维tensor变成标量
y.backward(t.ones(y.size())) # f 对 y 求偏导
print(x.grad) # 输出叶子节点x的梯度
Pytorch中autograd.Variable.backward的grad_varables参数个人理解浅见相关推荐
- Pytorch中的variable, tensor与numpy相互转化
来源:https://blog.csdn.net/m0_37592397/article/details/88327248 1.将numpy矩阵转换为Tensor张量 sub_ts = torch.f ...
- PyTorch中的Variable类型
1 前言 今天在学习PyTorch~ 之前在莫烦的教程中看到了Variable类型的变量,后来看PyTorch的<Deep Learning with PyTorch: A 60 Minute ...
- Pytorch中的Variable
Pytorch中的Variable pytorch两个基本对象:Tensor(张量)和Variable(变量) 其中,tensor不能反向传播,variable可以反向传播. Varibale包含三个 ...
- Pytorch中torch.nn.Softmax的dim参数含义
自己搞了一晚上终于搞明白了,下文说的很透彻,做个记录,方便以后翻阅 Pytorch中torch.nn.Softmax的dim参数含义
- pytorch中tensor、backward一些总结
目录 说明 Tensor Tensor的创建 Tensor(张量)基本数据类型与常用属性 Tensor的自动微分 设置不可积分计算 pytorch 计算图 backward一些细节 该文章解决问题如下 ...
- pytorch中的Variable()
参考链接:https://www.cnblogs.com/hellcat/p/8439055.html 函数简介 torch.autograd.Variable是Autograd的核心类,它封装了Te ...
- Pytorch中的variable, tensor与numpy相互转化的方法
来源:https://blog.csdn.net/pengge0433/article/details/79459679 在使用pytorch作为深度学习的框架时,经常会遇到变量variable.张量 ...
- pytorch中的Variable还有必要使用吗?
pytorch1.6文档 Variable 早在在pytorch0.4已经不需要了 tensor就支持autograd了 所以见到 data=Variable(data) 这样的用法请大胆删除Va ...
- pytorch 中 Autograd(四)
用Tensor训练网络很方便,但是反向传播过程需要手动实现.这对于像线性回归等较为简单的模型来说,还可以应付,但实际使用中经常出现非常复杂的网络结构,此时如果手动实现反向传播,不仅费时费力,而且容易出 ...
- pytorch中“_, pred = out.max(1)”语句的理解
本人小白在入门pytorch,看mnist数据集的代码时,看到这样一句代码:"_, pred = out.max(1)",顿时就有些懵了,查阅一番资料后,才明白其中的意思. tor ...
最新文章
- IPv6 — 地址配置方式
- Go实战--也许最快的Go语言Web框架kataras/iris初识三(Redis、leveldb、BoltDB)
- 【ACM】最长公共子序列 - 动态规划
- 浅谈:国内软件公司为何无法做大做强?
- 如何拷贝工程_如何将premiere的工程及素材文件打包?
- IDEA中Alt + Insert快捷键定制生成类方法
- Vue 电商PC后台管理(ElementUI)
- 禁止按钮在一定时间内连续点击
- Typescript tsconfig
- The Top 8 Security and Risk Trends We’re Watching
- 思岚A1M8激光雷达-ubuntu18.04-slam建图参考
- P5144 【蜈蚣】
- 裸金属虚拟化解决方案-工业一体机(1)
- 【魏先生搞定Python系列】一文搞定SQLAlchemy学习与使用
- 山东农村商业银行计算机笔试,2021年山东农村商业银行笔试备考:计算机科目高分复习方法...
- 从零开始实现mini-min网易云音乐(一)
- 目录没有.kaggle文件夹的解决方法
- 《520七夕情人节表白礼物》:虚幻浪漫的爱情故事——❤520表白星空漫漫3D相册❤(HTML+CSS+JavaScript)...
- Pass by reference和pass by value区别举例
- 虚拟货币盘点:微币,Q币,苹果平台,Facebook 的F币,Google会推G币么?