【深度学习】pytorch自动求导机制的理解 | tensor.backward() 反向传播 | tensor.detach()梯度截断函数 | with torch.no_grad()函数
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
文章目录
- 前言
- 一、pytorch里自动求导的基础概念
- 1.1、自动求导 requires_grad=True
- 1.2、求导 requires_grad=True是可以传递的
- 1.3、tensor.backward() 反向计算导数
- 1.4、tensor的梯度是可以累加
- 二、tensor.detach()梯度截断函数
- 三、with torch.no_grad()函数
- 总结
前言
本来在写GAN生成手写数字这篇博客的时候,遇到了一些和梯度有关的代码没看懂,憋得自己很难受,赶紧把pytorch最基础的知识赶紧补了一下
在数学上,梯度就是由于偏导数组成的一个向量,其方向为多维曲面某点的方向导数最大值所在的方向。
一、pytorch里自动求导的基础概念
1.1、自动求导 requires_grad=True
一般来说,在tensor里需要设置requires_grad=True,这样tensor就能自动求导了。默认情况下requires_grad=False:
import torch
x = torch.tensor([[1.0,2.0],[3.0,4.0]])
print(x)
结果为:
我们将requires_grad设置为True:
import torch
x = torch.tensor([[1.0,2.0],[3.0,4.0]], requires_grad=True)
print(x)
结果为:
在这里可以看到求导开关被打开了。我们指定矩阵x可以求导
1.2、求导 requires_grad=True是可以传递的
我们设置一个函数y=x**2+2x+1,因为x是可以自动求导的,那么y也是
import torch
x = torch.tensor([[1.0,2.0],[3.0,4.0]], requires_grad=True)
y=torch.sum(x**2+2*x+1)
print(y)
print(y.requires_grad)
1.3、tensor.backward() 反向计算导数
使用backward() 函数,以本题为例,就能算出y在x上每个元素的导数,使用来查看x.grad梯度信息。梯度就是由tensor.backward()产生的
import torch
x = torch.tensor([[1.0,2.0],[3.0,4.0]], requires_grad=True)
print(x.grad)
y=torch.sum(x**2+2*x+1)
y.backward()
print(x.grad)
从这张结果图能看出,最开始直接打印x.grad的梯度信息是没有的,而是在backward()后,再使用x.grad才会看到梯度信息。
1.4、tensor的梯度是可以累加
张量的梯度是可以一直叠加的,一般都会在用之前把梯度清零(optim.zero_grad())
x = torch.tensor([[1.0,2.0],[3.0,4.0]], requires_grad=True)
print(x.grad)
y1=torch.sum(x**2+2*x+1)y1.backward()
print(x.grad)
#进行梯度叠加
y2=torch.sum(x)
y2.backward()
print(x.grad)
y2对于x的梯度是1(x求导为1),所以后续x矩阵的值都加上了1。
二、tensor.detach()梯度截断函数
张量截断的应用,我第一次是在生成对抗网络中见到的,当时是为了截断梯度,防止判别器的梯度传入生成器:
fake_image = g_net(noises.detach()).detach()
tensor.detach()梯度截断函数的解释如下:会返回一个新张量,阻断梯度传播
我们来看一个梯度截断的简单例子。
正常情况下,代码的结果应该是:
x = torch.tensor([[1.0,2.0],[3.0,4.0]], requires_grad=True)
y=torch.sum(x**2+2*x+1)
print(y)y.backward()
print(x.grad)
进行梯度截断之后:
import torch
x = torch.tensor([[1.0,2.0],[3.0,4.0]], requires_grad=True)
y=torch.sum(x**2+2*x+1)
print(y)y = y.detach()
print(y)y.backward()
print(x.grad)
代码会直接报错:
同时再次打印y,张量里的grad_fn=SumBackward0直接不见了:
三、with torch.no_grad()函数
这部分简要阐述一一下就行。
在代码里面,神经网络求梯度和求导是需要吃内存的,但是有些操作是不需要求梯度的(比如统计每一轮的损失,损失求平均这些)。为了节约内存,人们总是喜欢在这些代码前面加上with torch.no_grad()函数。下面就是个很好的例子:
# 得到生成器的损失g_optim.zero_grad()fake_output = dis(gen_img)g_loss = loss_function(fake_output,torch.ones_like(fake_output))g_loss.backward()g_optim.step()d_epoch_loss += d_lossg_epoch_loss += g_lossd_epoch_loss /= batch_countg_epoch_loss /= batch_countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print('Epoch:', epoch)gen_img_plot(gen, test_input)
你可以很明显看出后面的代码是不需要求梯度的,为了节约内存所以会改成:
# 得到生成器的损失g_optim.zero_grad()fake_output = dis(gen_img)g_loss = loss_function(fake_output,torch.ones_like(fake_output))g_loss.backward()g_optim.step()with torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_losswith torch.no_grad():d_epoch_loss /= batch_countg_epoch_loss /= batch_countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print('Epoch:', epoch)gen_img_plot(gen, test_input)
总结
提示:这里对文章进行总结:
例如:以上就是今天要讲的内容,本文仅仅简单介绍了pandas的使用,而pandas提供了大量能使我们快速便捷地处理数据的函数和方法。
【深度学习】pytorch自动求导机制的理解 | tensor.backward() 反向传播 | tensor.detach()梯度截断函数 | with torch.no_grad()函数相关推荐
- PyTorch 笔记Ⅱ——PyTorch 自动求导机制
文章目录 Autograd: 自动求导机制 张量(Tensor) 梯度 使用PyTorch计算梯度数值 Autograd 简单的自动求导 复杂的自动求导 Autograd 过程解析 扩展Autogra ...
- pytorch自动求导机制
Torch.autograd 在训练神经网络时,我们最常用的算法就是反向传播(BP). 参数的更新依靠的就是loss function针对给定参数的梯度.为了计算梯度,pytorch提供了内置的求导机 ...
- BP神经网络分类实战项目(深度学习笔记)原创!基础篇||PCA降维、反向传播公式、梯度下降、标准化、倾斜样本处理、独热编码、Adam优化算法、权值初始化、F1-Score、ROC、模型可视化
结果展示
- Pytorch Autograd (自动求导机制)
Introduce Pytorch Autograd库 (自动求导机制) 是训练神经网络时,反向误差传播(BP)算法的核心. 本文通过logistic回归模型来介绍Pytorch的自动求导机制.首先, ...
- 深度学习修炼(三)——自动求导机制
文章目录 致谢 3 自动求导机制 3.1 传播机制与计算图 3.1.1 前向传播 3.1.2 反向传播 3.2 自动求导 3.3 再来做一次 3.4 线性回归 3.4.1 回归 3.4.2 线性回归的 ...
- 深度学习框架 TensorFlow:张量、自动求导机制、tf.keras模块(Model、layers、losses、optimizer、metrics)、多层感知机(即多层全连接神经网络 MLP)
日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 安装 TensorFlow2.CUDA10.cuDNN7.6. ...
- Pytorch学习(一)—— 自动求导机制
现在对 CNN 有了一定的了解,同时在 GitHub 上找了几个 examples 来学习,对网络的搭建有了笼统地认识,但是发现有好多基础 pytorch 的知识需要补习,所以慢慢从官网 API 进行 ...
- 【PyTorch学习(三)】Aurograd自动求导机制总结
Aurograd自动求导机制总结 PyTorch中,所有神经网络的核心是 autograd 包.autograd 包为tensor上的所有操作提供了自动求导机制.它是一个在运行时定义(define- ...
- PyTorch的计算图和自动求导机制
文章目录 PyTorch的计算图和自动求导机制 自动求导机制简介 自动求导机制实例 梯度函数的使用 计算图构建的启用和禁用 总结 PyTorch的计算图和自动求导机制 自动求导机制简介 PyTorch ...
最新文章
- Linux期末复习题库(1)
- DeepI2P:基于深度分类的图像对点云配准
- python新建文件夹口令_3分钟学会一段Python代码脚本,轻松实现破解FTP密码口令...
- 能写出HTML语言框架结构,HTML语言—框架最新.ppt
- Debian下配置SSH服务器的方法
- IO模式设置,阻塞与非阻塞的比较,recv参数对性能的影响—O_NONBLOCK(open使用)、IPC_NOWAIT(msgrcv)、MSG_DONTWAIT
- 信息学奥赛C++语言:判断奇偶
- Web MVC模式实现
- 中国银行外币汇率查询
- 服务器挂机自动签到京东,解放双手,什么值得买自动签到京东自动签到给你更多时间享受生活...
- 虚拟机安装MAC-OS系统开发非常卡。使用beamoff.zip优化教程
- arcmap 坡降工具_ArcHydro_Toolbar_In_Arcmap Arcgis水文分析工具具体的操作 - 下载 - 搜珍网...
- 小米系列手机开源代码
- 计算机网络中NTFS概念及功能,什么是NTFS
- CTF.show:misc入门24-49
- 2015级计算机科学与技术2班班级博客大全
- 医院云PACS管理系统源码
- 【正点原子Linux连载】第二十五章 语音识别项目 摘自【正点原子】I.MX6U嵌入式Qt开发指南V1.0.2
- Web全栈工程师技能树梳理
- 此nvidia驱动程序与此windows版本不兼容,此图形的驱动程序无法找到兼容的驱动硬件