文章目录

  • model
  • predict
  • train

教程连接
数据集 cifa10

batch 一批数据集的个数
channel 深度,灰度图为1,rgb为3
height,width 宽高,32


在pytorch官网查文档

model

#///导包时用的是torch
import torch.nn as nn
import torch.nn.functional as F#///流程:建一个类,这个类要继承nn.Model这个父类,这个类中要实现两个方法,一个是初始化函数(实现在网络中要使用的网络层结构),另一个是forward函数中定义正向传播的过程,当我们实例化这个类之后,将参数传递到这个实例中,就会按照forward的顺序进行正向传播过程
class LeNet(nn.Module):def __init__(self)://初始化函数super(LeNet, self).__init__()//继承,解决多继承,调用基类的构造函数self.conv1 = nn.Conv2d(3, 16, 5)#//第一个卷积层(按住alt+鼠标左键查看函数定义:“采用2D卷积对输入进行处理Applies a 2D //convolution over an input signal composed of several input planes.”)参数:#     self,#     in_channels: int,  输入特征矩阵的输入(RGB为3)#     out_channels: int,(卷积核的个数=输出为深度维的特征矩阵)#     kernel_size: _size_2_t,(卷积核的大小)#     stride: _size_2_t = 1,(步距)#     padding: _size_2_t = 0,(补齐)#     dilation: _size_2_t = 1,(暂时用不到)#     groups: int = 1,(暂时用不到)#     bias: bool = True,(偏置)#     padding_mode: str = 'zeros'  # TODO: refine this type# ):self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)#//10是因为10分类任务,主观修改的def forward(self, x):#x代表输入的输入,形式是[batch,channel,height,width]x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)x = self.pool1(x)            # output(16, 14, 14)x = F.relu(self.conv2(x))    # output(32, 10, 10)x = self.pool2(x)            # output(32, 5, 5)#与全连接层进行拼接,需要展开成为一维向量x = x.view(-1, 32*5*5)       # output(32*5*5)#-1代表第一个维度是自动推理的,(batch)x = F.relu(self.fc1(x))      # output(120)x = F.relu(self.fc2(x))      # output(84)x = self.fc3(x)              # output(10)return x#测试
# import torch
#input1=torch.rand([32,3,32,32])
#model=LeNet() #实例化模型
#print(model)
#output=莫得了(input1)   #输入

predict

#调用模型权重进行预测

import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNettransform = transforms.Compose([transforms.Resize((32, 32)),#缩放transforms.ToTensor(),#转化成tensortransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])#标准化classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))#载入权重文件im = Image.open('1.jpg')#用Image模块载入图像
im = transform(im)  # [C, H, W]#图像的shape一般都是【】,就要转化成pytorch tensor的格式
im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]#最前面增加一个batch的新维度with torch.no_grad():outputs = net(im)predict = torch.max(outputs, dim=1)[1].data.numpy()
print(classes[int(predict)])

train

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
# 导包transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])#Compose将一些预处理的方法打包到一起1.toTensor,把PIL或numoy改成tensor2.Normalize是标准化,使用均值与标准差来标准化tensor# 50000张训练图片
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)#DOWNLOAD改成ture就下载,root表示数据集下载到什么地方,train维ture就会导入训练集的样本,transform就是预处理的函数,在上面train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,shuffle=False, num_workers=0)#把数据集导入进来并分成一批一批的,shuffle代表是否随机挑出来,num——workers代表线程数,windows下只能为0# 10000张验证图片,与上面相同方法
val_set = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,shuffle=False, num_workers=0)
val_data_iter = iter(val_loader)   #将val_loader转化为迭代器,之后通过next的方法就能获取一批数据,就包括测试的图像与图像对应的标签值
val_image, val_label = val_data_iter.next()# classes = ('plane', 'car', 'bird', 'cat',
#            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
#将10个标签导入,是元组类型,值不能改index0就对应表签飞机
#用官方的imshow可以看图片(numpy与matplotlib包)29:23
net = LeNet()#实例化模型
loss_function = nn.CrossEntropyLoss()#定义损失函数(有softmax)
optimizer = optim.Adam(net.parameters(), lr=0.001)#优化器用的adam,传入的是要训练的参数,lr是学习率for epoch in range(5):  # loop over the dataset multiple times
#要将训练集训练多少轮running_loss = 0.0#用来累加训练中的损失for step, data in enumerate(train_loader, start=0):#遍历训练集样本,enumrate不仅返回每一批的数据data,还返回这一批data对应的步数,start=0从0开始# get the inputs; data is a list of [inputs, labels]inputs, labels = data#分成图像与标签# zero the parameter gradientsoptimizer.zero_grad()#将历史损失梯度清零,不清零可能对计算的历史梯度进行累加(多次进行小batch的训练)# forward + backward + optimizeoutputs = net(inputs)#输入网络得到输出loss = loss_function(outputs, labels)#根据之前定义的损失函数、输出、真实标签求得损失loss.backward()#反向穿播optimizer.step()#参数更新# print statisticsrunning_loss += loss.item()#将loss累加if step % 500 == 499:    # print every 500 mini-batcheswith torch.no_grad():#在接下来的计算过程中不要去计算每个节点的误差损失梯度,不用的话,测试过程中也会计算1.会占用更多算力2.占用更多内存资源outputs = net(val_image)  # [batch, 10]predict_y = torch.max(outputs, dim=1)[1]#寻找输出的最大index在什么位置:网络预测最可能是哪一类,dim是第几个维度,【1】是只需要知道是哪一类就行了accuracy = (predict_y == val_label).sum().item() / val_label.size(0)#//与类别进行比较(是个tensor(数值),sum求和是个数值),正确率print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))#epoch代表迭代到多少轮,step,某一轮到多少步,平均训练误差,准确率running_loss = 0.0#清零并进行下一轮print('Finished Training')save_path = './Lenet.pth'#模型进行保存
torch.save(net.state_dict(), save_path)#保存的是参数

lenet pytorch 官方demo学习笔记相关推荐

  1. Dynamic Quantization PyTorch官方教程学习笔记

    诸神缄默不语-个人CSDN博文目录 本文是PyTorch的教程Dynamic Quantization - PyTorch Tutorials 1.11.0+cu102 documentation的学 ...

  2. RPG游戏Demo学习笔记一

    导航 RPG游戏Demo学习笔记一 UE笔记 -- 一个简单的AI_weixin_52918492的博客-CSDN博客 目录 一.准备资源 二.基础功能 生命值与能量值 创建Widget Bluepr ...

  3. 【从零开始的大数据学习】Flink官方教程学习笔记(一)

    Flink官方教程学习笔记 学习资源 基础Scala语法 Scala数据结构专题 声明变量 代码块 函数(function) 方法(methods) Traits (接口) class(类) tupl ...

  4. CORE-ESP32C3|eink|日期格式化|IO11解锁|墨水屏操作库|SNTP自动同步|局部刷新|全局刷新|LuatOS-SOC接口|官方demo|学习(12):简单日期显示

    目录 基础资料 探讨重点 参考博文: 实现功能 硬件准备 软件版本 日志及soc下载工具 软件使用 接线示意图 IO11解锁教程可参考: 功能1:基于墨水屏的日期显示: 初始化: 日期显示: 功能2: ...

  5. CORE-ESP32C3|eink|墨水屏日历|天气API|LuatOS公共接口|气象要素数据V1|collectgarbage|LuatOS-SOC接口|官方demo|学习(13):墨水屏动态日历

    目录 参考博文 项目官方地址 显示效果: 硬件准备 软件版本 日志及soc下载工具 软件使用 接线示意图 硬件接线 一.Elink驱动管脚适配 二.天气信息获取 API使用方式: 接口格式(注意需不需 ...

  6. 合宙Air105|摄像头|capture|SPI|Serial 串口|TFTLCD|Micro SD卡|GC032A|USB转TTL|官方demo|学习(2-1):摄像头camera-capture

    目录 基础资料 探讨重点 实现功能 硬件准备 软件版本 软件使用 接线示意图 功能1:捕捉图片并存入SPI接口外置SD卡 lcd初始化 摄像头初始化 指定capture按钮 SD卡初始化 图片存储 功 ...

  7. CORE-ESP32C3|eink|墨水屏日历+时间日期+温度显示|I2C软件模拟| LuatOS-SOC接口|官方demo|学习(14):墨水屏动态日历+oled日期显示+ AHT10测温模组

    目录 参考博文 源于网友oled+eink+aht10项目 源代码修改及复现说明 主要修改 显示效果 ​编辑硬件准备 软件版本 日志及soc下载工具 软件使用 接线说明 天气显示屏 硬件接线 温度采集 ...

  8. 学习Pytorch官方Demo——Lenet,以及遇到的问题

    文章目录 1.官方Demo的项目目录 2.模型 3.训练 4.预测 5.遇到的问题 1.官方Demo的项目目录 2.模型 代码: import torch.nn as nn import torch. ...

  9. pytorch官方demo实现图像分类(LeNet)

    深度学习学习笔记 导师博客:https://blog.csdn.net/qq_37541097/article/details/103482003 导师github:https://github.co ...

最新文章

  1. R语言均匀分布函数uniform Distribution(dunif, punif, qunif runif)实战
  2. chrome取消安全模式
  3. 架构师之路 — 部署架构 — 集群部署
  4. WINCE6.0下开始菜单的“挂起(suspend)”是否可见及阻止系统进入睡眠模式
  5. CentOs7下lnmp环境安装
  6. oracle 中时间类型 date 与 long 互转
  7. 英语自动提取高频词_斑马英语提分营免费体验课
  8. mysql: union / union all / 自定义函数用法详解
  9. uni-app自定义tabBar;uni-app小程序自定义tabBar;uni-app小程序修改中间tabBar导航栏大小;uni-app中间导航栏凸起;uni-app修改底部导航栏
  10. 技术支持诈骗手段翻新:借勒索软件类锁屏界面恐吓用户
  11. python gephi可视化 金庸人物关系图
  12. Linux磁盘配额教程,磁盘配额设置及使用
  13. Blender建模练习:人物模型多边形建模流程图解(二形体调整篇)
  14. 游戏因为音效而变得触动人心
  15. ORACLE数据库日期更新到时分秒格式
  16. 基于深度学习的图标型验证码识别系统(包含完整代码、界面)
  17. 为云主机申请配置免费的域名和证书
  18. 快手极速版(目前稳定奔跑中~)别问能不能跑了~
  19. 亲爱的老狼-display的使用
  20. (第一章) UI---PS基础和选框工具

热门文章

  1. RTT设备与驱动之PIN设备
  2. 1、【设计模式】组合模式
  3. Arch Linux中安装Anaconda
  4. 【javascript 对日期的扩展 Format\addDays】
  5. BZOJ 1059 - 二分图匹配
  6. Android WebRTC视频旋转问题
  7. vmware设置centos虚拟机nat联网(转)
  8. 什么时候加上android.intent.category.DEFAULT和LAUNCHER
  9. php文件改写nodejs,node.js – 提供PHP文件的nodejs,expressjs
  10. java io 过滤数据,Java IO文件后缀名过滤总结