文件结构:

data是放数据集,img是放模型训练完成之后测试的,module放模型,model是模型,test是测试img中的,train是训练模型

模型

# -*- coding: utf-8 -*-from torch import nnclass Model(nn.Module):def __init__(self):super(Model, self).__init__()self.model = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2), # 卷积操作nn.MaxPool2d(kernel_size=2), # 最大池化nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(32, 64, 5, padding=2),nn.MaxPool2d(kernel_size=2),nn.Flatten(), # 展开nn.Linear(1024, 64), # 线形变换nn.Linear(64, 10) # 线形变换)def forward(self, input):output = self.model(input)return output

加载数据集

该数据集Pytorch可以直接下载,直接download=True,若是现在太慢,可以在迅雷下载

import torchvision
train_dataset = torchvision.datasets.CIFAR10(root='./data',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data',train=False,transform=torchvision.transforms.ToTensor(),download=True)

训练

其中可以调用gpu进行训练,若是没有gpu则会选用cpu训练,训练速度lr =0.001,用的SGD做优化器,平方差计算loss,训练次数为500次,并且如果验证集的准确率大于0.68

# -*- coding: utf-8 -*-from torch.utils.data import DataLoader
from model import Model
import torch
import torchvision
import torch.nn as nndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')train_dataset = torchvision.datasets.CIFAR10(root='./data',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data',train=False,transform=torchvision.transforms.ToTensor(),download=True)train_len = len(train_dataset)
test_len = len(test_dataset)
print(f"训练数据量有{train_len}个")
print(f"测试数据量有{test_len}个")# 加载数据
train_loader = DataLoader(train_dataset, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=64)# 加载网络
model = Model().to(device)
# 损失函数
losser = nn.CrossEntropyLoss()
losser.to(device)
# 优化器
speed = 0.001
optimzer = torch.optim.SGD(model.parameters(), lr=speed)# 训练轮数
total_train_step = 500# 记录训练次数和测试次数model.train()
for i in range(1,total_train_step+1):print(f"------第{i}轮训练开始------")train_loss = 0for data in train_loader:img, target = dataimg, target = img.to(device), target.to(device)output = model(img)loss = losser(output, target)# 模型优化optimzer.zero_grad()loss.backward()optimzer.step()train_loss += lossprint(f"\t\t训练Loss:{loss}")# 测试数据model.eval()total_test_Loss = 0total_accuracy = 0with torch.no_grad():print(f"\t测试开始")for data in test_loader:img, target = dataimg, target = img.to(device), target.to(device)output = model(img)loss = losser(output, target)total_test_Loss += lossaccuracy = (output.argmax(1) == target).sum()total_accuracy += accuracytotal_accuracy = int(total_accuracy)print(f"\t\t测试总损失:{total_test_Loss}")print(f"\t\t测试正确率:{round((total_accuracy/test_len), 4)*100}%")if (total_accuracy/test_len) > 0.68:torch.save(model.state_dict(), f"./module/CIFAR_{i}_{round((total_accuracy/test_len), 2)}.pth")print(f"------第{i}轮模型已保存------")

测试模型

注意:png图片要转化成RGB图片,CIFAR总类顺序是'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'。并且在测试是不需要梯度。

# -*- coding: utf-8 -*-import torch
from model import Model
from PIL import Image
import torchvisionmodel = Model()
model.load_state_dict(torch.load('./module/CIFAR_313_0.68.pth'))kind = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']img = Image.open('./img/img.png')transf = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),torchvision.transforms.ToTensor()])
img = transf(img)
img = torch.reshape(img, (1,3,32,32))
model.eval()
with torch.no_grad():output = model(img)print(output.argmax(1))

总结

这个模型训练最高的准确率在68.9%左右,这也是这个模型能训练最高的准确率。如果你发现后面的准确率在降低,但是训练集loss也在减少,说明模型过拟合了,就可以停止。

如果想要再度提高准确率,你可以选择增加数据集的数据量,或者改变模型(用残差等)

机器学习CIFAR10训练(卷积神经网络)相关推荐

  1. 【深度学习】基于Torch的Python开源机器学习库PyTorch卷积神经网络

    [深度学习]基于Torch的Python开源机器学习库PyTorch卷积神经网络 文章目录 1 CNN概述 2 PyTorch实现步骤2.1 加载数据2.2 CNN模型2.3 训练2.4 可视化训练 ...

  2. 机器学习笔记三—卷积神经网络与循环神经网络

    系列文章目录 机器学习笔记一-机器学习基本知识 机器学习笔记二-梯度下降和反向传播 机器学习笔记三-卷积神经网络与循环神经网络 机器学习笔记四-机器学习可解释性 机器学习笔记五-机器学习攻击与防御 机 ...

  3. pytorch1.7教程实验——迁移学习训练卷积神经网络进行图像分类

    只是贴上跑通的代码以供参考学习 参考网址:迁移学习训练卷积神经网络进行图像分类 需要用到的数据集下载网址: https://download.pytorch.org/tutorial/hymenopt ...

  4. matlab训练参数,设置参数并训练卷积神经网络

    设置参数并训练卷积神经网络 按照指定卷积神经网络的层中所述定义神经网络的层后,下一步是为网络设置训练选项.使用 trainingOptions 函数定义全局训练参数.要训练网络,请使用 trainin ...

  5. PyTorch 从零训练卷积神经网络(Convent)

    本文主要介绍从从零训练卷积神经网络(Convent).使用PyTorch创建各自的convent或神经网络样本. 原文地址:PyTorch 从零训练卷积神经网络(Convent)

  6. 全球名校课程作业分享系列(7)--斯坦福计算机视觉与深度学习CS231n之基于cifar10的卷积神经网络实践

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/yaoqiang2011/article ...

  7. 李宏毅机器学习课程10~~~卷积神经网络

    卷积的意义 数字图像是一个二维的离散信号,对数字图像做卷积操作其实就是利用卷积核(卷积模板)在图像上滑动,将图像点上的像素灰度值与对应的卷积核上的数值相乘,然后将所有相乘后的值相加作为卷积核中间像素对 ...

  8. 机器学习:利用卷积神经网络实现图像风格迁移 (一)

    相信很多人都对之前大名鼎鼎的 Prisma 早有耳闻,Prisma 能够将一张普通的图像转换成各种艺术风格的图像,今天,我们将要介绍一下Prisma 这款软件背后的算法原理.就是发表于 2016 CV ...

  9. 在MNIST图像上训练卷积神经网络

    摘要:这是第一次接触卷积神经网络,非常顺利地运行了代码,基本了解了卷积神经网络是怎么训练的. 1.代码(有空格的是一个单元格) from keras import layers from keras ...

  10. cnn风格迁移_机器学习:利用卷积神经网络实现图像风格迁移 (一)

    相信很多人都对之前大名鼎鼎的 Prisma 早有耳闻,Prisma 能够将一张普通的图像转换成各种艺术风格的图像,今天,我们将要介绍一下Prisma 这款软件背后的算法原理.就是发表于 2016 CV ...

最新文章

  1. 重磅!亚马逊将在2019年全面弃用Oracle数据库
  2. 互联网秒杀设计--转载
  3. 计算机在气象上的应用浅论,简析计算机网络在气象服务中的应用原稿
  4. ubuntu默认防火墙
  5. 机房收费系统学生下机结账小结
  6. Spring系列(六) Spring Web MVC 应用构建分析
  7. CPU Cache Line:CPU缓存行/缓存块
  8. SAP License:移动类型541(委外业务)不产生会计凭证的原因
  9. 电容或电感的电压_纯电阻、纯电感和纯电容电路
  10. AutoCAD2014的安装
  11. Captcha Cracker
  12. jQuery 案例-图片抽奖
  13. 张继群,创青春-数字经济赛道,中国创翼临沂市决赛,创客中国-中小企业创客比赛-临沂市决赛
  14. java两张图片拼接
  15. 第八讲:工业网络——单环冗余(理论)
  16. 神经网络应用现状分析,神经网络应用现状调查
  17. 人工智能AI编程基础(五)
  18. 图片名字存在txt文件中,从另一个装有图片的文件夹中筛选对应的图片。python代码
  19. 围棋计算机运算,计算机围棋中的算法研究
  20. 安装JDK报错:Failed to extract file RegUtils from the binary table

热门文章

  1. 【1月英语—罗塞塔之爱】
  2. 小程序报错 Invalid regular expression: invalid group specifier name
  3. TCP/IP——从wireshark看TCP(一)
  4. vscode 侧边栏源代码管理不见了
  5. 国际清算银行称:央行数字货币可能导致银行挤兑
  6. 【excel】如何绘制斜线表头
  7. 为你的App瘦身,优化你的App
  8. 光猫路由器一体机安装和千兆网络
  9. [N1CTF 2022] solve_pow,baby_N1ES
  10. 第七天 位置参数 变量运算if case || find locate compress