一、Lenet5网络结构

二、model.py

import torch.nn as nn
import torch.nn.functional as F# 定义类(类名LeNet),nn.Moudle是父类
class LeNet(nn.Module):# 初始化函数【实现在搭建网络中所需要的一些网络层结构】def __init__(self):super(LeNet, self).__init__()  # 【Super继承父类的构造函数】Super函数解决在多层继承中,调用父类方法可能会出现的问题(涉及到多继承都会使用Super)self.conv1 = nn.Conv2d(3, 16, 5) #卷积:(输入特征的深度【输入是RGB彩色图片】,输出特征的深度,即卷积核的个数【16个卷积核(使用几个卷积核就会生成多少维的特征矩阵)】,卷积核尺寸【5*5】)self.pool1 = nn.MaxPool2d(2, 2) #下采样:Maxpool2d方法self.conv2 = nn.Conv2d(16, 32, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32 * 5 * 5, 120) #全连接层的输入是一维向量,需要将得到的特征矩阵展平为一维向量。(120是LeNet的结点个数)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)# forward函数定义正向传播的过程def forward(self, x): #x是输入数据(待处理的数据,没处理前是图片)#经过卷积层1得到的输出经过relu激活函数x = F.relu(self.conv1(x))  # input(3, 32, 32) output(16, 28, 28)x = self.pool1(x)  # output(16, 14, 14)x = F.relu(self.conv2(x))  # output(32, 10, 10)x = self.pool2(x)  # output(32, 5, 5)#view函数将特征矩阵展出一维向量(这里的-1,表示不确定展开几行)x = x.view(-1, 32 * 5 * 5)  # output(32*5*5)x = F.relu(self.fc1(x))  # output(120)x = F.relu(self.fc2(x))  # output(84)x = self.fc3(x)  # output(10)return x

(1)Conv2d方法的官方说明文档链接:

https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html?highlight=conv2d#torch.nn.Conv2d

其中【输入图片大小为W*W】,【Filter大小F*F】,【步长S】,【padding的像素数为P】。

(根据上述公式,Lenet网络conv1层的输入:input(3,32,32),则一层卷积后的尺寸大小为:(32-5+2*0)/1 +1)=28,output(16,28,28))

(3)Pytorch Tensor的通道排序:[batch,channel,height,width]

(4)下采样层Pool池化层,MaxPool2d方法:

池化层只改变特征矩阵的高与宽,不会影响矩阵的深度。

(根据(2)中公式,pool1下采样层输入:input(16,28,28),则池化后的尺寸大小为:(28-2+2*0)/2 +1)=14,output(16,14,14))

(5)View()函数:

view中一个参数定为-1,代表动态调整这个维度上的元素个数,以保证元素的总数不变。

(6)调试,查看终端网络的结构:

# 测试
import torchintput1 = torch.rand([32, 3, 32, 32])  # batch数量:32,channel深度:3,heigh高:32,width宽:32
model = LeNet()  # 实例化模型
print(model)
output = model(intput1)

三、train.py

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transformsdef main():# Compose将所使用的预处理的方法打包成一个整体transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 50000张训练图片# 第一次使用时要将download设置为True才会自动去下载数据集# transform是对图像进行预处理的函数train_set = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)# 将训练集导入进来,分成批次# 【batch_size】:每次随机拿出36张图片进行训练# 【shuffle】:数据集是否打乱,True为打乱# 【num_workers】:载入数据的线程数train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,shuffle=True, num_workers=0)# 10000张验证图片# 第一次使用时要将download设置为True才会自动去下载数据集val_set = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,shuffle=False, num_workers=0)val_data_iter = iter(val_loader)val_image, val_label = val_data_iter.next()# classes = ('plane', 'car', 'bird', 'cat',#            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')net = LeNet()loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)for epoch in range(5):  # loop over the dataset multiple timesrunning_loss = 0.0for step, data in enumerate(train_loader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = loss_function(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if step % 500 == 499:    # print every 500 mini-batcheswith torch.no_grad():outputs = net(val_image)  # [batch, 10]predict_y = torch.max(outputs, dim=1)[1]accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')save_path = './Lenet.pth'torch.save(net.state_dict(), save_path)if __name__ == '__main__':main()

(1)transforms.ToTensor:

(2) transforms.Normalize标准化:

最后,仅小白学习深度学习的笔记,如有错误,请多多指教。该笔记源于观看B站UP主,霹雳吧啦Wz。相关网址:https://www.bilibili.com/video/BV187411T7Ye/?spm_id_from=333.788&vd_source=a26986127371efd81983577687e46aad

目录

一、Lenet5网络结构

二、model.py

(1)Conv2d方法的官方说明文档链接:

(3)Pytorch Tensor的通道排序:[batch,channel,height,width]

(4)下采样层Pool池化层,MaxPool2d方法:

(5)View()函数:

(6)调试,查看终端网络的结构:

三、train.py


pytorch官方代码demo(Lenet)解析笔记【B站UP主“霹雳吧啦Wz”视频观看】相关推荐

  1. pytorch常用代码

    20211228 https://mp.weixin.qq.com/s/4breleAhCh6_9tvMK3WDaw 常用代码段 本文代码基于 PyTorch 1.x 版本,需要用到以下包: impo ...

  2. PyTorch常用代码段整理合集,建议收藏!

    点击上方,选择星标或置顶,每天给你送干! 阅读大概需要12分钟 跟随小博主,每天进步一丢丢 张皓:南京大学计算机系机器学习与数据挖掘所(LAMDA)硕士生,研究方向为计算机视觉和机器学习,特别是视觉识 ...

  3. PyTorch 常用代码段示例整理

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 众所周知,程序猿在写代码时通常会在网上搜索大量资料,其中大部分是代 ...

  4. PyTorch常用代码段整理合集

    本文代码基于PyTorch 1.0版本,需要用到以下包 import collections import os import shutil import tqdmimport numpy as np ...

  5. pytorch list转tensor_点赞收藏:PyTorch常用代码段整理合集

    机器之心转载 来源:知乎 作者:张皓 众所周知,程序猿在写代码时通常会在网上搜索大量资料,其中大部分是代码段.然而,这项工作常常令人心累身疲,耗费大量时间.所以,今天小编转载了知乎上的一篇文章,介绍了 ...

  6. YOLOV5dataset.py代码注释与解析

    YOLOv5代码注释版更新啦,注释的是最近的2021.07.14的版本,且注释更全 github: https://github.com/Laughing-q/yolov5_annotations Y ...

  7. PyTorch 常用代码段整理合集

    PyTorch 常用代码段整理合集 来源:知乎 作者:张皓 众所周知,程序猿在写代码时通常会在网上搜索大量资料,其中大部分是代码段.然而,这项工作常常令人心累身疲,耗费大量时间.所以,今天小编转载了知 ...

  8. Pytorch 常用代码

    Pytorch 常用代码 本文代码基于PyTorch 1.0版本,需要用到以下包 import collections import os import shutil import tqdmimpor ...

  9. 语义分割|学习记录(5)Pytorch官方实现的FCN网络结构

    文章目录 前言 FCN网络结构 参考资料 前言 Pytorch官方实现的FCN和当年论文的结构图稍有不同,因为现在有了更多的backbone的选择,而且应用了膨胀卷积技术. FCN网络结构 下图是Py ...

最新文章

  1. 简述机器指令与微指令之间的关系_自考《计算机组成原理》模拟试题(一)
  2. python的try exception捕获异常
  3. 泛函编程(19)-泛函库设计-Parallelism In Action
  4. 用云服务器实现janus之web端与web通话!
  5. Centos 开机无法输入密码的问题
  6. STL 之随机访问迭代器
  7. 《城市建筑美学》读书笔记
  8. (原創) 為什麼VB有Dim obj As Foo = New Foo()這種語法? (初級) (Visual BASIC)
  9. .NET文档生成工具ADB[更新至2.3]
  10. 架构师必备最全SQL优化方案
  11. 如何提高页面性能并充分利用主机
  12. ​最适合女生的10个副业(上篇),只要你有执行力,实现财富自由很简单!
  13. 985、211外,你还应该清楚这些高校联盟!
  14. 又一次移植最新lvgl8到esp32的踩坑记录
  15. 产品经理——java学习之路
  16. CV10 图像模糊(均值、高斯、中值、双边滤波)
  17. 优秀好用的Mac平台上的DRM音频转换辅助工具
  18. 人生/活着有什么意义?人的一生到底该追求什么
  19. html金山打字源码,c#实现简单金山打字小游戏(源码)
  20. W32Dasm反汇编工具使用教程

热门文章

  1. Python发送消息到手机(基于IFTTT)
  2. 读书笔记——《educated 》和《atomic habits》
  3. iso镜像+kickstart实现linux系统半自动化安装
  4. YARN源码分析—AM-RM通信协议,获得资源
  5. 利用ADS中的Design-Guide进行微带线单枝节匹配
  6. 2023年温州医科大学眼科学考研考情与难度、参考书及上岸前辈经验
  7. MobData荣登创业邦企业服务创新成长50强
  8. A. Giga Tower
  9. 各种一维条形码介绍(锐浪报表)
  10. Credit Card