图像识别:CIFAR10图形识别

1.CIFAR10数据集共有60000张彩色图像,这些图像式32*32*3,分为10个类,每个类6000张

2.这里面有50000张用于训练,构成5个训练批,每一批10000张图;另外10000张用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。

3.一个训练批中的各类图像并不一定数量相同,总的来看训练集,每一类都有5000张图片

代码如下:与官网代码不一致

# 导入包
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets

 # 设置transforms
transform = transforms.Compose([transforms.ToTensor(), # numpy -> Tensortransforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))  # 归一化 ,范围[-1,1]
])

 # 下载训练数据集
# 训练集
trainset = datasets.CIFAR10(root='./CIFAR10',train=True,download=True,transform=transform)
# 测试集
testset = datasets.CIFAR10(root='./CIFAR10',train=False,download=True,transform=transform)

 出现如下图结果数据集下载成功

# 批量获取数据
from torch.utils.data.dataloader import DataLoaderBATCH_SIZE = 32train_loader = DataLoader(trainset,batch_size=BATCH_SIZE,shuffle=True,num_workers=8,pin_memory=True)test_loader = DataLoader(testset,batch_size=BATCH_SIZE,shuffle=True,num_workers=8,pin_memory=True)

 注意:其中BATCH_SIZE = 32 中的32 可以根据自己电脑配置来定,配置高可以定128 低可以定16

# 可视化显示
import matplotlib.pyplot as plt
import numpy as np# 十个类别
classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')def imshow(img):img = img / 2 + 0.5 # 逆正则化np_img = img.numpy()  # tensor --> numpyplt.imshow(np.transpose(np_img,(1,2,0)))  # 改变通道顺序plt.show()# 随机获取一批数据
imgs,labs = next(iter(train_loader))print(imgs.shape)
print(labs.shape)#调用方法
imshow(torchvision.utils.make_grid(imgs))# 输出这批图片对应的标签
print(' '.join('%5s' % classes[labs[i]] for i in range(BATCH_SIZE)))    

 结果如下:

其中

torch.Size([32, 3, 32, 32])
torch.Size([32])

中32代表32张图片,3代表3个通道,32代表像素

# 定义网络模型
import torch.nn as nn
import torch.nn.functional as F'''
知识点:
1.特征图尺寸的计算公式为:[(原图片尺寸 = 卷积核尺寸) / 步长] + 1
'''
class Net(nn.Module):def __init__(self):super(Net,self).__init__()# 卷积层1.输入是32*32*3,计算(32-5)/ 1 + 1 = 28,那么通过conv1输出的结果是28*28*6self.conv1 = nn.Conv2d(3,6,5)  # imput:3 output:6, kernel:5# 池化层, 输入时28*28*6, 窗口2*2,计算28 / 2 = 14,那么通过max_poll层输出的结果是14*14*6self.pool = nn.MaxPool2d(2,2) # kernel:2 stride:2# 卷积层2, 输入是14*14*6,计算(14-5)/ 1 + 1 = 10,那么通过conv2输出的结果是10*10*16self.conv2 = nn.Conv2d(6,16,5) # imput:6 output:16, kernel:5# 全连接层1self.fc1 = nn.Linear(16*5*5, 120)  # input:16*5*5,output:120# 全连接层2self.fc2 = nn.Linear(120, 84)  # input:120,output:84# 全连接层3self.fc3 = nn.Linear(84, 10)  # input:84,output:10def forward(self,x):# 卷积1'''32x32x3 --> 28x28x6 -->14x14x6'''x = self.pool(F.relu(self.conv1(x)))# 卷积2'''14x14x6 --> 10x10x16 --> 5x5x16'''x = self.pool(F.relu(self.conv2(x)))# 改变shapex = x.view(-1,16*5*5)# 全连接层1x = F.relu(self.fc1(x))# 全连接层2x = F.relu(self.fc2(x))# 全连接层3x = self.fc3(x)return x 

 注意:__init__这一块下划线要注意,按理说只要将模型定义到__init__()里就ok了,但是大家容易少打一个下划线会报错,将下划线_改为__即可解决问题。

# 创建模型
net = Net().to('cuda')

电脑有GPU的话这一步是部署到CUDA上运行调用GPU, 这一步容易出现下图问题,

这时候多运行几次,代码是没有问题的,应为JUPYTER是在网页上运行,需要时间反应,多运行几次

如果出现以下问题:

注意Linear中的L要大写

# 定义优化器和损失函数
import torch.optim as optimcriterion = nn.CrossEntropyLoss()  # 交叉式损失函数optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)  # 优化器
# 定义函数
EPOCHS = 200for epoch in range(EPOCHS):train_loss = 0.0for i, (datas,labels) in enumerate(train_loader):datas,labels = datas.to('cuda'),labels.to('cuda')# 梯度置零optimizer.zero_grad()# 训练outputs = net(datas)# 计算损失loss = criterion(outputs,labels)# 反向传播loss.backward()# 参数更新optimizer.step()# 累计损失train_loss += loss.item()# 打印信息print(epoch+1, i+1,train_loss/len(train_loader.dataset))

循环次数可以自己设置,这里设置为200轮,for循环读取训练集

输出结果如下:

(可以参考网上其他输出格式)

# 测试
correct = 0
total = 0
# flag=True
with torch.no_grad():for i , (datas,labels) in enumerate(test_loader):# 输出outputs = model(datas) # outputs.data,shape --> torch.Size([128,10])_, predicted = torch.max(outputs.data, dim=1)  # 第一个是值得张量,第二个是序号张量# 累计数据值total += labels.size(0)  # labels.size() --> torch.Size([128]) , labels.size(0) --> 128# 比较有多少个预测正确correct += (predicted == labels).sum()  # 相同为1,不同为0,利用sum()总求和print("在1000张测试集图片上的准确率:{:.3f}%".format(torch.true_divide(correct,total))
# 显示每一类预测的概率
class_correct = list(0. for i in range(10))
total = list(0. for i in range(10))with torch.no_grad():for (images,labels) in test_loader:outputs = model(images)  # 输出_,predicted = torch.max(outputs,dim=1)  # 获取到每一行的最大索引c = (predicted ==labels).squeeze()  # squeeze() 去掉0维【默认】,unsqueeze() 增加一维if labels.shape[0]  == 128:for i in range(BATCH_SIZE):label = labels[i] # 获取每一个labelclass_correct[label] += c[i].item()  # 累计维True都个数,注意:1 + True = 2,1 + False = 1total[label]  += 1 # 该类总的个数# 输出正确率
for i in range(10):print("正确率 : %5s : $2d %%" % (classes[i],100 * class_correct[i] / total[i])

参考视频:​​​​​07-02 经典案例 CIFAR10 图像识别【个人实现】_哔哩哔哩_bilibili

深度学习之经典案例 CIFAR10 图形识别(jupyter)相关推荐

  1. 免费教材丨第56期:《深度学习导论及案例分析》、《谷歌黑板报-数学之美》

    小编说  离春节更近了!  本期教材        本期为大家发放的教材为:<深度学习导论及案例分析>.<谷歌黑板报-数学之美>两本书,大家可以根据自己的需要阅读哦! < ...

  2. 《深度学习导论及案例分析》一2.11概率图模型的推理

    本节书摘来自华章出版社<深度学习导论及案例分析>一书中的第2章,第2.11节,作者李玉鑑 张婷,更多章节内容可以访问云栖社区"华章计算机"公众号查看. 2.11概率图模 ...

  3. 学习=拟合?深度学习和经典统计学是一回事吗?

    来源:PaperWeekly.机器之心 本文大约8700字,建议阅读20分钟 本文介绍了理论计算机科学家.哈佛大学知名教授 Boaz Barak 详细比较了深度学习与经典统计学的差异. 深度学习和简单 ...

  4. 【深度学习】深度学习和经典统计学是一回事?

    器之心编译 编辑:rome rome 深度学习和简单的统计学是一回事吗?很多人可能都有这个疑问,毕竟二者连术语都有很多相似的地方.在这篇文章中,理论计算机科学家.哈佛大学知名教授 Boaz Barak ...

  5. 【深度学习】Pytorch实现CIFAR10图像分类任务测试集准确率达95%

    文章目录 前言 CIFAR10简介 Backbone选择 训练+测试 训练环境及超参设置 完整代码 部分测试结果 完整工程文件 Reference 前言 分享一下本人去年入门深度学习时,在CIFAR1 ...

  6. 干货丨深度学习和经典机器学习的全方位对比

    本文将对比深度学习和经典机器学习,分别介绍这两种技术的优缺点以及它们在哪些问题 如何得到最佳使用. 深度学习已成为大多数AI问题的首选技术,使得经典机器学习相形见绌.但是,尽管深度学习有很好的性能,经 ...

  7. 基于人工智能深度学习和经典算法的药物设计软件MolAICal

    使用MolAICal进行药物设计 MolAICal简介 MolAICal 教程 MolAICal开发版(Development version)教程 MolAICal简介 MolAICal可以通过人工 ...

  8. 深度学习之LSTM案例分析(三)

    #背景 来自GitHub上<tensorflow_cookbook>[https://github.com/nfmcclure/tensorflow_cookbook/tree/maste ...

  9. 《深度学习导论及案例分析》一导读

    PREFACE 前言 "深度学习"一词大家已经不陌生了,随着在不同领域取得了超越其他方法的成功,深度学习在学术界和工业界掀起了一次神经网络发展史上的新浪潮.运用深度学习解决实际问题 ...

最新文章

  1. 从 jQuery 到 VUE 技术栈
  2. [Codeforces] Round #320 (Div.2)
  3. java泛型bean copy list
  4. 将计算机退出域 脚本
  5. mysql班次和排班怎么设计表_java 员工轮询值班排班 开发设计(mysql+redis)
  6. PHP WEB程序设计信息表,PHP WEB程序设计
  7. 参数整定临界比例度实验_PID理解起来很难?系统讲解PID控制及参数调节,理论加实际才好!...
  8. 利用 Flask 动态展示 Pyecharts 图表数据的几种方法
  9. 第 45 届国际大学生程序设计竞赛(ICPC)亚洲区域赛(南京)签到题F Fireworks
  10. python语言开发环境搭建_Python开发环境搭建-Go语言中文社区
  11. 如何实现系统集约与管理运营集约相互促进而不是相互制约
  12. QNAP+Transmission
  13. 前后端api参考手册
  14. Python练习题——快乐数字
  15. SQL语句 —— 查询某天创建的数据(精确到日)
  16. 可一键生成数据分析报告的两个库
  17. 视觉心理物理学(2)matlab与ptb3
  18. 法拉科机器人接头_图解FANUC机器人I/O信号板接口定义与拆装
  19. 最近整理的一些常见的面试题,面试大全,黑马程序员面试宝典题库---最新技术--篇
  20. Springer的latex压缩包上传转不了pdf

热门文章

  1. 无盘服务器已缓存是什么意思,无盘网吧缓存是什么意思
  2. 男人必备!泡妞全攻略 1
  3. mac设置iterm2的Badge
  4. 使用函数实现两个数的交换(C语言)
  5. python电视剧口碑分析_Python分析最近大火的网剧《隐秘的角落》
  6. Dell 灵越5775 安装CentOS 7(已安装Windows系统)
  7. 【短信验证】手机登录短信验证
  8. 使用fvm管理多个flutter版本
  9. 自己制作第一个微信小程序
  10. 粤科软件:依托互联网优势创影院新生态