官网中文文档 神经网络

文章目录

  • 核心代码
  • 卷积
  • 卷积 + 分类
  • 网络架构

核心代码

首先介绍一下 torch.nn.Conv2d(),传入参数的含义如下:

in_channels # 输入通道数
out_channels # 输出通道数
kernel_size # 卷积核尺寸,常见有 1,3,5,7
stride # 步长,默认为1
padding # 填充,默认零填充
dilation # 空洞卷积,默认为 1
groups # 组卷积,默认为 1
bias # 是否需要偏置,默认为 True

和原代码在形式上稍微有点不同,这里使用了 nn.Sequential() 模块快速进行搭建。上一层的输出直接作为下一层的输入。输入要求为 1 * 1 * 32 * 32 的四维张量。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.features = nn.Sequential(# 输入通道数为 1,输出通道数为 6,有 6 个 1*5*5 卷积核nn.Conv2d(1, 6, 5),nn.MaxPool2d(2, 2),# 输入通道数为 6,输出通道数为 16,有 16 个 6*5*5 卷积核nn.Conv2d(6, 16, 5),nn.MaxPool2d(2, 2),)self.classifier = nn.Sequential(nn.Linear(16*5*5, 120),nn.ReLU(True),nn.Linear(120, 84),nn.ReLU(True),nn.Linear(84, 10),)def forward(self, x):# 卷积x = self.features(x)# x 为 4 维张量。把 x 的尺寸调整为 [1, 16*5*5]x = x.view(x.size(0), -1)# 分类x = self.classifier(x)# x 的尺寸为 [1, 10]return x

卷积

以 1 * 1 * 32 * 32 的输入为例。第一个 1 表示 batchsize,第二个 1 表示通道数(channel),后面三个参数可视为一个立方体。conv 表示卷积,pooling 表示池化。

卷积 + 分类

网络架构

print(Net())

输出

Net((features): Sequential((0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))(1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(classifier): Sequential((0): Linear(in_features=400, out_features=120, bias=True)(1): ReLU(inplace=True)(2): Linear(in_features=120, out_features=84, bias=True)(3): ReLU(inplace=True)(4): Linear(in_features=84, out_features=10, bias=True))
)

查看参数

params = list(net.parameters())
len(params)
# conv1's .weight
params[0].size()
# conv2's .weight
params[2].size()

输出

10
torch.Size([6, 1, 5, 5])  # 表示 6 个 1 * 5 * 5 的卷积核
torch.Size([16, 6, 5, 5])  # 表示 16 个 6 * 5 * 5 的卷积核

这 10 个参数分别是

conv1.weight
conv1.bias
conv2.weight
conv2.bias
fc1.weight
fc1.bias
fc2.weight
fc2.bias
fc3.weight
fc3.bias

如果想要详细查看个参数的具体数值,这样

# 查看某个参数数值
params[0]
# 或查看所有参数数值
for param in net.parameters():print(param)

[PyTorch] 官网教程之神经网络相关推荐

  1. [pytorch] 官网教程+注释

    pytorch官网教程+注释 Classifier import torch import torchvision import torchvision.transforms as transform ...

  2. 关于pytorch官网教程中的What is torch.nn really?(三)

    文章目录 Switch to CNN `nn.Sequential` Wrapping `DataLoader` Using your GPU Closing thoughts 原文在这里. 因为MN ...

  3. pytorch官网教程:autograd代码理解

    # Autograd: 自动求导机制#PyTorch 中所有神经网络的核心是 autograd 包,torch.Tensor是这个包的核心类. #如果设置 .requires_grad 为 True, ...

  4. pytorch官网教程:cifar10代码理解

    import torch import torchvision import torchvision.transforms as transforms import matplotlib.pyplot ...

  5. pytorch官网教程:tensor代码理解

    #tensor from __future__ import print_function import torch #创建一个 5x3 矩阵, 但是未初始化 x = torch.empty(5,3) ...

  6. 02/03_Pytorch安装、Conda安装Pythorch,换源、pytorch官网、验证、安装jupyter、卸载、安装、启动jupyter、配置Jupyter notebook、使用

    1.2.Pytorch安装 1.2.1.Conda安装Pythorch,换源 1 conda添加清华镜像源 查看源 conda config --show-sources 由于从官方的conda源中下 ...

  7. MNE溯源fieldtrip官网教程

    MNE溯源fieldtrip官网教程 Introduction 在本教程中,您可以找到有关如何使用最小范数估计进行源重构的信息,以重构单个主题的事件相关字段(MEG).我们将使用预处理教程中描述的数据 ...

  8. 解决pytorch官网下载慢ubuntu16.04+anaconda3(python3.6)+pytorch0.4.1+cuda9.0+cudnn7.1安装指南

    一.准备工作 1.系统环境是ubuntu 2.去anaconda官网下载anaconda3,python版本是3.5以上的就行,官网最新的是python3.7,没有关系,反正自己可以创建新环境 来选择 ...

  9. Spring Cloud学习笔记—网关Spring Cloud Gateway官网教程实操练习

    Spring Cloud学习笔记-网关Spring Cloud Gateway官网教程实操练习 1.Spring Cloud Gateway介绍 2.在Spring Tool Suite4或者IDEA ...

最新文章

  1. 用javascript实现仿163的js广告向下挤压页面的效果
  2. 英语写作中常见语法总结(二)
  3. 3D 音频技术产品介绍(1):Iosono the future of spatial audio
  4. 极详细的ECC讲解 -OOB与ECC
  5. html5+实现图片自动切换,js图片自动切换效果处理代码
  6. net 自定义表单的设计
  7. ECCV2018--点云匹配
  8. python批量读取用例的方法
  9. 系统没有安装vc9.注意是x86 32位_Windows 软件默认安装位置之谜
  10. Bitwise聘请前联邦检察官Katherine Dowling担任总法律顾问
  11. 基于Maven的S2SH(Struts2+Spring+Hibernate)框架搭建
  12. matlab直方图均衡化函数
  13. 微信小程序——推箱子小游戏
  14. vue 上传音视频文件获取时长
  15. The-Swift-2.0-Programming-Language-playground
  16. Luckysheet 导入导出 - Java后台处理和js前端实现
  17. 【论文翻译】Learning from Few Samples: A Survey 小样本学习综述
  18. 八股文之linux常用指令
  19. Jetpack Compose 深入探索系列一:Composable 函数
  20. 谷粒商城 -->「P01-P44」

热门文章

  1. 验证方式二 html标签验证码,Django标签、转义及验证码生成
  2. postman json 中写注释_Swagger界面丑、功能弱怎么破?用Postman增强下就给力了!
  3. excel中怎么把超链接的结果(图片)直接显示出来_把500张产品图片导入Excel里?用这个方法可超速完成,码住...
  4. Hihocoder 1142 三分
  5. 洛谷 P5194 [USACO05DEC]Scales S(DFS)
  6. Java 实验5 T4 检验字符串是否合法
  7. 深度学习——最优化的学习笔记
  8. Python装饰器几个有用又好玩的例子
  9. EC600 QuecPython下载脚本代码到开发板、设置开机自运行
  10. c/c++教程 - 2.4.3 this指针作用,链式编程思想,空指针访问成员函数,const修饰成员函数,常函数,常对象