提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 前言
  • 一、pytorch里自动求导的基础概念
    • 1.1、自动求导 requires_grad=True
    • 1.2、求导 requires_grad=True是可以传递的
    • 1.3、tensor.backward() 反向计算导数
    • 1.4、tensor的梯度是可以累加
  • 二、tensor.detach()梯度截断函数
  • 三、with torch.no_grad()函数
  • 总结

前言

本来在写GAN生成手写数字这篇博客的时候,遇到了一些和梯度有关的代码没看懂,憋得自己很难受,赶紧把pytorch最基础的知识赶紧补了一下

在数学上,梯度就是由于偏导数组成的一个向量,其方向为多维曲面某点的方向导数最大值所在的方向。

一、pytorch里自动求导的基础概念

1.1、自动求导 requires_grad=True

一般来说,在tensor里需要设置requires_grad=True,这样tensor就能自动求导了。默认情况下requires_grad=False:

import torch
x = torch.tensor([[1.0,2.0],[3.0,4.0]])
print(x)

结果为:

我们将requires_grad设置为True:

import torch
x = torch.tensor([[1.0,2.0],[3.0,4.0]], requires_grad=True)
print(x)

结果为:

在这里可以看到求导开关被打开了。我们指定矩阵x可以求导

1.2、求导 requires_grad=True是可以传递的

我们设置一个函数y=x**2+2x+1,因为x是可以自动求导的,那么y也是

import torch
x = torch.tensor([[1.0,2.0],[3.0,4.0]], requires_grad=True)
y=torch.sum(x**2+2*x+1)
print(y)
print(y.requires_grad)

1.3、tensor.backward() 反向计算导数

使用backward() 函数,以本题为例,就能算出y在x上每个元素的导数,使用来查看x.grad梯度信息。梯度就是由tensor.backward()产生的

import torch
x = torch.tensor([[1.0,2.0],[3.0,4.0]], requires_grad=True)
print(x.grad)
y=torch.sum(x**2+2*x+1)
y.backward()
print(x.grad)


从这张结果图能看出,最开始直接打印x.grad的梯度信息是没有的,而是在backward()后,再使用x.grad才会看到梯度信息。

1.4、tensor的梯度是可以累加

张量的梯度是可以一直叠加的,一般都会在用之前把梯度清零(optim.zero_grad())

x = torch.tensor([[1.0,2.0],[3.0,4.0]], requires_grad=True)
print(x.grad)
y1=torch.sum(x**2+2*x+1)y1.backward()
print(x.grad)
#进行梯度叠加
y2=torch.sum(x)
y2.backward()
print(x.grad)


y2对于x的梯度是1(x求导为1),所以后续x矩阵的值都加上了1。

二、tensor.detach()梯度截断函数

张量截断的应用,我第一次是在生成对抗网络中见到的,当时是为了截断梯度,防止判别器的梯度传入生成器:

fake_image = g_net(noises.detach()).detach()

tensor.detach()梯度截断函数的解释如下:会返回一个新张量,阻断梯度传播

我们来看一个梯度截断的简单例子。
正常情况下,代码的结果应该是:

x = torch.tensor([[1.0,2.0],[3.0,4.0]], requires_grad=True)
y=torch.sum(x**2+2*x+1)
print(y)y.backward()
print(x.grad)


进行梯度截断之后:

import torch
x = torch.tensor([[1.0,2.0],[3.0,4.0]], requires_grad=True)
y=torch.sum(x**2+2*x+1)
print(y)y = y.detach()
print(y)y.backward()
print(x.grad)

代码会直接报错:

同时再次打印y,张量里的grad_fn=SumBackward0直接不见了:


三、with torch.no_grad()函数

这部分简要阐述一一下就行。
在代码里面,神经网络求梯度和求导是需要吃内存的,但是有些操作是不需要求梯度的(比如统计每一轮的损失,损失求平均这些)。为了节约内存,人们总是喜欢在这些代码前面加上with torch.no_grad()函数。下面就是个很好的例子:

# 得到生成器的损失g_optim.zero_grad()fake_output = dis(gen_img)g_loss = loss_function(fake_output,torch.ones_like(fake_output))g_loss.backward()g_optim.step()d_epoch_loss += d_lossg_epoch_loss += g_lossd_epoch_loss /= batch_countg_epoch_loss /= batch_countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print('Epoch:', epoch)gen_img_plot(gen, test_input)

你可以很明显看出后面的代码是不需要求梯度的,为了节约内存所以会改成:

# 得到生成器的损失g_optim.zero_grad()fake_output = dis(gen_img)g_loss = loss_function(fake_output,torch.ones_like(fake_output))g_loss.backward()g_optim.step()with torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_losswith torch.no_grad():d_epoch_loss /= batch_countg_epoch_loss /= batch_countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print('Epoch:', epoch)gen_img_plot(gen, test_input)

总结

提示:这里对文章进行总结:
例如:以上就是今天要讲的内容,本文仅仅简单介绍了pandas的使用,而pandas提供了大量能使我们快速便捷地处理数据的函数和方法。

【深度学习】pytorch自动求导机制的理解 | tensor.backward() 反向传播 | tensor.detach()梯度截断函数 | with torch.no_grad()函数相关推荐

  1. PyTorch 笔记Ⅱ——PyTorch 自动求导机制

    文章目录 Autograd: 自动求导机制 张量(Tensor) 梯度 使用PyTorch计算梯度数值 Autograd 简单的自动求导 复杂的自动求导 Autograd 过程解析 扩展Autogra ...

  2. pytorch自动求导机制

    Torch.autograd 在训练神经网络时,我们最常用的算法就是反向传播(BP). 参数的更新依靠的就是loss function针对给定参数的梯度.为了计算梯度,pytorch提供了内置的求导机 ...

  3. BP神经网络分类实战项目(深度学习笔记)原创!基础篇||PCA降维、反向传播公式、梯度下降、标准化、倾斜样本处理、独热编码、Adam优化算法、权值初始化、F1-Score、ROC、模型可视化

    结果展示

  4. Pytorch Autograd (自动求导机制)

    Introduce Pytorch Autograd库 (自动求导机制) 是训练神经网络时,反向误差传播(BP)算法的核心. 本文通过logistic回归模型来介绍Pytorch的自动求导机制.首先, ...

  5. 深度学习修炼(三)——自动求导机制

    文章目录 致谢 3 自动求导机制 3.1 传播机制与计算图 3.1.1 前向传播 3.1.2 反向传播 3.2 自动求导 3.3 再来做一次 3.4 线性回归 3.4.1 回归 3.4.2 线性回归的 ...

  6. 深度学习框架 TensorFlow:张量、自动求导机制、tf.keras模块(Model、layers、losses、optimizer、metrics)、多层感知机(即多层全连接神经网络 MLP)

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 安装 TensorFlow2.CUDA10.cuDNN7.6. ...

  7. Pytorch学习(一)—— 自动求导机制

    现在对 CNN 有了一定的了解,同时在 GitHub 上找了几个 examples 来学习,对网络的搭建有了笼统地认识,但是发现有好多基础 pytorch 的知识需要补习,所以慢慢从官网 API 进行 ...

  8. 【PyTorch学习(三)】Aurograd自动求导机制总结

    ​Aurograd自动求导机制总结 PyTorch中,所有神经网络的核心是 autograd 包.autograd 包为tensor上的所有操作提供了自动求导机制.它是一个在运行时定义(define- ...

  9. PyTorch的计算图和自动求导机制

    文章目录 PyTorch的计算图和自动求导机制 自动求导机制简介 自动求导机制实例 梯度函数的使用 计算图构建的启用和禁用 总结 PyTorch的计算图和自动求导机制 自动求导机制简介 PyTorch ...

最新文章

  1. Linux期末复习题库(1)
  2. DeepI2P:基于深度分类的图像对点云配准
  3. python新建文件夹口令_3分钟学会一段Python代码脚本,轻松实现破解FTP密码口令...
  4. 能写出HTML语言框架结构,HTML语言—框架最新.ppt
  5. Debian下配置SSH服务器的方法
  6. IO模式设置,阻塞与非阻塞的比较,recv参数对性能的影响—O_NONBLOCK(open使用)、IPC_NOWAIT(msgrcv)、MSG_DONTWAIT
  7. 信息学奥赛C++语言:判断奇偶
  8. Web MVC模式实现
  9. 中国银行外币汇率查询
  10. 服务器挂机自动签到京东,解放双手,什么值得买自动签到京东自动签到给你更多时间享受生活...
  11. 虚拟机安装MAC-OS系统开发非常卡。使用beamoff.zip优化教程
  12. arcmap 坡降工具_ArcHydro_Toolbar_In_Arcmap Arcgis水文分析工具具体的操作 - 下载 - 搜珍网...
  13. 小米系列手机开源代码
  14. 计算机网络中NTFS概念及功能,什么是NTFS
  15. CTF.show:misc入门24-49
  16. 2015级计算机科学与技术2班班级博客大全
  17. 医院云PACS管理系统源码
  18. 【正点原子Linux连载】第二十五章 语音识别项目 摘自【正点原子】I.MX6U嵌入式Qt开发指南V1.0.2
  19. Web全栈工程师技能树梳理
  20. 此nvidia驱动程序与此windows版本不兼容,此图形的驱动程序无法找到兼容的驱动硬件

热门文章

  1. Python numpy.corrcoef函数方法的使用
  2. Linux·VFS虚拟文件系统
  3. 分布式是大数据处理的万能药?
  4. [flow] 1.Spyglass CDC
  5. python 二维码制作
  6. elastic官网文档翻译来.1
  7. Ubuntu 电脑下插入移动硬盘,显示不能挂载该硬盘
  8. WaveDrom的使用
  9. 程序人生 - 数字化人民币的无网络支付是如何实现的?
  10. 【华为中央硬件部】最新社会招聘公告!