pytorch官网教程:autograd代码理解
# Autograd: 自动求导机制#PyTorch 中所有神经网络的核心是 autograd 包,torch.Tensor是这个包的核心类。
#如果设置 .requires_grad 为 True,那么将会追踪所有对于该张量的操作
import torch
x = torch.ones(2,2,requires_grad=True) #创建一个张量并设置 requires_grad=True 用来追踪他的计算历史
print(x) #tensor([[1., 1.], [1., 1.]], requires_grad=True)#对张量进行操作
y = x + 2
print(y) #tensor([[3., 3.], [3., 3.]], grad_fn=<AddBackward0>)print(y.grad_fn) #<AddBackward0 object at 0x000001D6F5100AC8>, grad_fn已经被自动生成了#对y进行一个操作
z = y * y * 3
out = z.mean()
print(z,out)# tensor([[27., 27.], [27., 27.]], grad_fn=<MulBackward0>) tensor(27., grad_fn=<MeanBackward0>)# .requires_grad_( ... ) 可以改变现有张量的 requires_grad属性。
# 如果没有指定的话,默认输入的flag是 False。
a = torch.randn(2,2)
a = ((a * 3) / (a - 1))
print(a.requires_grad) #False
a.requires_grad_(True) #主义这里有“_”
print(a.requires_grad) #True
b = (a * a).sum()
print(b.grad_fn) #<SumBackward0 object at 0x000002004F7D5608>#梯度
# 反向传播 因为 out是一个纯量(scalar),
# out.backward() 等于out.backward(torch.tensor(1))。
out.backward()
print(x.grad)#下面是一个雅可比向量积的例子
x = torch.randn(3, requires_grad=True)
y = x * 2
while y.data.norm() < 1000: #y.data.norm()将张量y中的每个元素平方,然后对它们求和,最后得到结果和的平方根。这些运算计算所谓的L2范数y = y * 2
print(y) #,而是tensor([ -561.3829, -1019.7476, 191.2780], grad_fn=<MulBackward0>)# y不再是个标量,torch.autograd无法直接计算出完整的雅可比行列,
# 但是如果我们只想要雅可比向量积,只需将向量作为参数传入backward
gradients = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)
y.backward(gradients)
print(x.grad)# 如果.requires_grad=True但是你又不希望进行autograd的计算,
# 那么可以将变量包裹在 with torch.no_grad()中print(x.requires_grad) # True
print((x ** 2).requires_grad) # Truewith torch.no_grad():print((x ** 2).requires_grad) #False
pytorch官网教程:autograd代码理解相关推荐
- [pytorch] 官网教程+注释
pytorch官网教程+注释 Classifier import torch import torchvision import torchvision.transforms as transform ...
- [PyTorch] 官网教程之神经网络
官网中文文档 神经网络 文章目录 核心代码 卷积 卷积 + 分类 网络架构 核心代码 首先介绍一下 torch.nn.Conv2d(),传入参数的含义如下: in_channels # 输入通道数 o ...
- pytorch官网教程:cifar10代码理解
import torch import torchvision import torchvision.transforms as transforms import matplotlib.pyplot ...
- pytorch官网教程:tensor代码理解
#tensor from __future__ import print_function import torch #创建一个 5x3 矩阵, 但是未初始化 x = torch.empty(5,3) ...
- 关于pytorch官网教程中的What is torch.nn really?(三)
文章目录 Switch to CNN `nn.Sequential` Wrapping `DataLoader` Using your GPU Closing thoughts 原文在这里. 因为MN ...
- pytorch实现:Resnet模型识别花朵数据集(参考pytorch官网代码)
pytorch实现:Resnet模型识别花朵数据集 一.pytorch实现:Resnet模型识别花朵数据集 1.1 训练模型 1.2 图像预测和可视化 1.3 对新来的数据进行处理和展示 一.pyto ...
- Spring Cloud学习笔记—网关Spring Cloud Gateway官网教程实操练习
Spring Cloud学习笔记-网关Spring Cloud Gateway官网教程实操练习 1.Spring Cloud Gateway介绍 2.在Spring Tool Suite4或者IDEA ...
- Gem5模拟器,详解官网教程Event-driven programming(五)
目录 一.解释一下gem5中的event-driven? 二.Creating a simple event callback (1)定义一个新的 C++ 类,并继承自 SimObject 抽象基类 ...
- java官网教程(基础篇)—— 基础的Java类 —— 基础 I / O
目录 基本 Java 类 基础 I/O I/O流 字节流 字符流 缓冲流 扫描和格式化 扫描 格式化 从命令行中进行IO操作 数据流 对象流 文件 I/O(采用 NIO.2) 什么是路径? Path类 ...
最新文章
- 2017全球中国锂电池市场趋势概述
- Android中BaseAdapter使用总结(imooc笔记)
- c++字符加密_linux安全Linux下RAR加密解密
- windows server 2012 FTP 服务器 / 访问网络共享盘
- 极域电子教室软件怎么脱离控制_全自动点胶机的控制系统都有哪些?
- 域传送漏洞(vulhub)
- 自动切换电脑或手机版(php aspx),ASP程序自动判断是电脑或手机访问网站。
- Java当中TreeMap用法
- 写作有困扰?不知道用什么词?不知道怎么解释不一致的结果?这个网站来帮你。
- 内存类型范围寄存器 (MTRR)
- 删除用户账号的命令 mysql_【Mysql】常用指令之——用户操作(创建,授权,修改,删除)...
- SQLAlchemy create table
- 程序员的自我修养笔记3 内存管理
- 苹果手机上网速度慢_手机信号明明满格却上不去网?4招帮你搞定它!
- 计算机上硬盘驱动器,什么是计算机硬盘驱动器?它有什么作用?如何维护?
- VIN码识别技术,扫一扫自动获取车架号
- oracle含有特殊字符查询,Oracle特殊字符查询
- 数显之家快讯:「SHIO世硕心语」shio是什么牌子?
- 如何实现字体沟边与发光特效?
- 【Prometheus】Alertmanager告警全方位讲解