Pytorch中autograd.Variable.backward的grad_varables参数个人理解浅见

一、autograd.Variable简介

1、Pytorch在autograd模块中实现图计算的相关功能,autograd中的核心数据结构是Variable。Variable封装了tensor,并记录对tensor的操作记录用来构建计算图。
2、Variable结构包含data、grad和grad_fn(可以查看对应引用的Variable的反向传播函数,注意:用户创建的变量为计算图的叶子节点,对应的值为None,原因是叶子节点无法再往后进行函数算子操作比如加减乘除等)。

from torch.autograd import Variable as V
import torch as t
x = V(t.ones(1))
b = V(t.randn(1),requires_grad = True)
w = V(t.randn(1),requires_grad = True)
y = w * x
z = y + b # z的前向传播函数是加法
z.backward()
print(z.grad_fn) # 查看z的反向传播函数,见下图结果红色框


3、Variable的构造函数需要传入tensor,有两个可选参数分别是requires_grad(默认false表示不对该Variable求导)和volatile,并支持大部分tensor支持的函数,除inplace外。

from torch.autograd import Variable as V
import torch as t
a = V(t.ones(3,4),requires_grad=True) # 创建Variable
c = a.sum() # 支持大部分tensor支持的函数,除inplace外,c数据类型依旧是variable

4、如果想要计算各个Variable的梯度,只需要调用根节点Variable的backward方法,autograd会自动沿着计算图反向传播,计算根节点到每一个叶子节点的梯度,采用 Variable.backward(grad_variables=None,…) 方法。

二、grad_variables参数的理解

Variable.backward(grad_variables=None,…) 中的grad_variables参数形状与Variable一致,前向传播后得到 f目标函数的值一般是标量(scalar),grad_variables相当于链式法则偏导运算中的一个部分运算结果,一般是从 f 开始计算偏导数到叶子节点。如下面公式:f 表示目标函数(只当作理解参数 grad_variables 的使用,实际并不出现),x 表示叶子节点,y 表示根节点,grad_variables 等于 f 对 y 的偏导的结果

2.1、当叶子节点都是一维的tensor(此一维是专指创建tensor时参数都是1,1的数目不限)构成的variable基本等同于标量,此时的情况表示目标函数 f 就是 y,说明 grad_variables =1 ,参数grad_variables在y.backward( )中可以直接省略,那么偏导式转换成:

import torch as t
from torch.autograd import Variable as V
x = V(t.ones(1),requires_grad=True) # 叶子节点x为一维的tensor构成的variable
y = x**2
# w = y.sum()
y.backward() # y等同于目标函数f,grad_variables=1可直接省略
print(x.grad) # 输出叶子节点x的梯度

2.2、当叶子节点为非一维的tensor构成的variable,此时的情况表示目标函数 fy 初始不等同,那么偏导式整体结构初始不变。但又分两种情况分析:

(1)目标函数 f 是标量,y从非一维tensor构成的variable变成标量。对y采用sum()方法实现非一维tensor构成的variable变成标量,说明 grad_variables =1 ,参数grad_variables在f.backward( )中可以直接省略。

import torch as t
from torch.autograd import Variable as V
x = V(t.ones(2),requires_grad=True) # 叶子节点x为非一维的tensor构成的variable
y = x**2
f = y.sum() # 对y采用sum()方法实现非一维tensor构成的variable变成标量
f.backward() # 标量对标量不用求导,所以直接省略
print(x.grad) # 输出叶子节点x的梯度

(2)目标函数 f 是标量,y是非一维tensor构成的variable。那么 f 对 y 求偏导的结果也是非一维tensor构成的variable,并且维度大小和 y 一样,所以采用y.backward(t.ones(y.size())):这是把 y 当作被偏导数,t.ones(y.size())是f 对 y 求偏导的结果,y.size()是表示结果的维度和 y 保持相同设定。

import torch as t
from torch.autograd import Variable as V
x = V(t.ones(2),requires_grad=True) # 叶子节点x为非一维的tensor构成的variable
y = x**2
# f = y.sum()  # 对y采用sum()方法实现非一维tensor变成标量
y.backward(t.ones(y.size())) # f 对 y 求偏导
print(x.grad) # 输出叶子节点x的梯度

Pytorch中autograd.Variable.backward的grad_varables参数个人理解浅见相关推荐

  1. Pytorch中的variable, tensor与numpy相互转化

    来源:https://blog.csdn.net/m0_37592397/article/details/88327248 1.将numpy矩阵转换为Tensor张量 sub_ts = torch.f ...

  2. PyTorch中的Variable类型

    1 前言 今天在学习PyTorch~ 之前在莫烦的教程中看到了Variable类型的变量,后来看PyTorch的<Deep Learning with PyTorch: A 60 Minute ...

  3. Pytorch中的Variable

    Pytorch中的Variable pytorch两个基本对象:Tensor(张量)和Variable(变量) 其中,tensor不能反向传播,variable可以反向传播. Varibale包含三个 ...

  4. Pytorch中torch.nn.Softmax的dim参数含义

    自己搞了一晚上终于搞明白了,下文说的很透彻,做个记录,方便以后翻阅 Pytorch中torch.nn.Softmax的dim参数含义

  5. pytorch中tensor、backward一些总结

    目录 说明 Tensor Tensor的创建 Tensor(张量)基本数据类型与常用属性 Tensor的自动微分 设置不可积分计算 pytorch 计算图 backward一些细节 该文章解决问题如下 ...

  6. pytorch中的Variable()

    参考链接:https://www.cnblogs.com/hellcat/p/8439055.html 函数简介 torch.autograd.Variable是Autograd的核心类,它封装了Te ...

  7. Pytorch中的variable, tensor与numpy相互转化的方法

    来源:https://blog.csdn.net/pengge0433/article/details/79459679 在使用pytorch作为深度学习的框架时,经常会遇到变量variable.张量 ...

  8. pytorch中的Variable还有必要使用吗?

    pytorch1.6文档 Variable 早在在pytorch0.4已经不需要了 tensor就支持autograd了 所以见到  data=Variable(data)  这样的用法请大胆删除Va ...

  9. pytorch 中 Autograd(四)

    用Tensor训练网络很方便,但是反向传播过程需要手动实现.这对于像线性回归等较为简单的模型来说,还可以应付,但实际使用中经常出现非常复杂的网络结构,此时如果手动实现反向传播,不仅费时费力,而且容易出 ...

  10. pytorch中“_, pred = out.max(1)”语句的理解

    本人小白在入门pytorch,看mnist数据集的代码时,看到这样一句代码:"_, pred = out.max(1)",顿时就有些懵了,查阅一番资料后,才明白其中的意思. tor ...

最新文章

  1. IPv6 — 地址配置方式
  2. Go实战--也许最快的Go语言Web框架kataras/iris初识三(Redis、leveldb、BoltDB)
  3. 【ACM】最长公共子序列 - 动态规划
  4. 浅谈:国内软件公司为何无法做大做强?
  5. 如何拷贝工程_如何将premiere的工程及素材文件打包?
  6. IDEA中Alt + Insert快捷键定制生成类方法
  7. Vue 电商PC后台管理(ElementUI)
  8. 禁止按钮在一定时间内连续点击
  9. Typescript tsconfig
  10. The Top 8 Security and Risk Trends We’re Watching
  11. 思岚A1M8激光雷达-ubuntu18.04-slam建图参考
  12. P5144 【蜈蚣】
  13. 裸金属虚拟化解决方案-工业一体机(1)
  14. 【魏先生搞定Python系列】一文搞定SQLAlchemy学习与使用
  15. 山东农村商业银行计算机笔试,2021年山东农村商业银行笔试备考:计算机科目高分复习方法...
  16. 从零开始实现mini-min网易云音乐(一)
  17. 目录没有.kaggle文件夹的解决方法
  18. 《520七夕情人节表白礼物》:虚幻浪漫的爱情故事——❤520表白星空漫漫3D相册❤(HTML+CSS+JavaScript)...
  19. Pass by reference和pass by value区别举例
  20. 虚拟货币盘点:微币,Q币,苹果平台,Facebook 的F币,Google会推G币么?

热门文章

  1. 预约挂号医院管理项目----service_OSS模块—对象存储
  2. 3G产业链对我国运营商竞争关系的影响
  3. Android适配器之-----SimpleExpandableListAdapter
  4. web笔记day04
  5. 苏州大学教育学专业考研上岸经验分享
  6. Appium App UI自动化测试
  7. C++联系人管理系统2.0(用类实现)
  8. 配置apache + tomcat 并设置apache 二级域名重定向试验
  9. 1074 宇宙无敌加法器
  10. 第48章 IServiceProvider、IUrlHelper、HttpClient深入理解