代码:

import torch.nn as nn
import torch
import torch.nn.functional as F
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
import argparse
import os# 训练
def train(args, model, device, train_loader, optimizer):for epoch in range(1, args.epochs + 1):model.train()for batch_index, data in enumerate(train_loader):images, labels = dataimages = images.to(device)labels = labels.to(device)# forwardoutput = model(images)loss = F.cross_entropy(output, labels)# backwardoptimizer.zero_grad()  # 梯度清空loss.backward()  # 梯度回传,更新参数optimizer.step()# 打印lossprint(f'Epoch:{epoch},Batch ID:{batch_index}/{len(train_loader)}, loss:{loss}')# 保存模型if epoch % args.checkpoint_interval == 0:torch.save(model.state_dict(), f'checkpoints/cifar10_%d.pth' % epoch)def test(args, model, device, test_loader):model.eval()total_loss = 0num_correect = 0with torch.no_grad():for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)# 总的losstotal_loss += F.cross_entropy(outputs, labels).item()# 预测值_, predected = torch.max(outputs, dim=1)# 预测对的总个数num_correect += (predected==labels).sum().item()# 计算平均lossaverage_loss = total_loss / len(test_loader.dataset)# 计算准确率accuracy = num_correect / len(test_loader.dataset)# 打印平均loss和准确率print(f'Average loss:{average_loss}\nTest Accuracy:{accuracy*100}%')if __name__ == '__main__':parser = argparse.ArgumentParser(description = 'Pytorch-cifar10_classification')parser.add_argument('--epochs', type=int, default=10, help='number of epochs')parser.add_argument('--batch_size', type=int, default=32, help='size of each image batch')parser.add_argument('--num_classes', type=int, default=10, help='number of classes')parser.add_argument('--lr', type=float, default=0.001, help='learning rate')parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')parser.add_argument('--pretrained_weights', type=str, default='checkpoints/cifar10_17.pth',help='if specified starts from checkpoint model')parser.add_argument("--img_size", type=int, default=224, help="size of each image dimension")parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model weights")parser.add_argument("--train", default=True, help="train or test")args = parser.parse_args()print(args)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# os.makedirs() 方法用于递归创建目录os.makedirs("output", exist_ok=True)os.makedirs("checkpoints", exist_ok=True)# transformdata_transform = transforms.Compose([transforms.ToTensor(),transforms.RandomResizedCrop(args.img_size)])# 下载训练数据集trian_data = datasets.CIFAR10(root = 'data',train = True,download = False,transform = data_transform,target_transform = None,)# 下载测试数据集test_data = datasets.CIFAR10(root = "data",train = False,download = False,transform = data_transform,target_transform = None)# 加载数据train_loader = DataLoader(dataset = trian_data,batch_size = args.batch_size,shuffle = True)test_loader = DataLoader(dataset = test_data,batch_size = args.batch_size)# 创建模型,使用预训练好的权重model = models.vgg16(pretrained = True)# # 冻结模型,参数不更新# for para in model.parameters():#     para.requires_grad = False# # 只训练全连接层# model.classifier[3].requires_grad = True# model.classifier[6].requires_grad = True# 修改vgg16的输出维度model.classifier[6] = nn.Linear(in_features=4096, out_features=args.num_classes, bias=True)model = model.to(device)# 打印网络结构print(model)# 定义优化器(也可以选择其他优化器)optimizer = torch.optim.SGD(model.parameters(), lr = args.lr, momentum = args.momentum)# optimizer = torch.optim.Adam(model.parameters())if train == True:if args.pretrained_weights.endswith(".pth"):model.load_state_dict(torch.load(args.pretrained_weights))for epoch in range(1, epochs+1):train(args, model, device, train_loader, optimizer)else:if args.pretrained_weights.endswith(".pth"):model.load_state_dict(torch.load(args.pretrained_weights))test(args, model, device, test_loader)

说明:
        cifar10数据集可以通过trochvision中的datasets.CIFAR10下载,也可以自己下载(注意存放路径);我模型使用的是torchvision中的models中预训练好的vgg16网络,也可以自己搭建网络。

【Pytorch】CIFAR1010数据集的训练和测试相关推荐

  1. FCN制作自己的数据集、训练和测试 caffe

    原文:http://blog.csdn.net/zoro_lov3/article/details/74550735 FCN制作自己的数据集.训练和测试全流程 花了两三周的时间,在导师的催促下,把FC ...

  2. python划分数据集用pandas_用pandas划分数据集实现训练集和测试集

    1.使用model_select子模块中的train_test_split函数进行划分 数据:使用kaggle上Titanic数据集 划分方法:随机划分 # 导入pandas模块,sklearn中mo ...

  3. [Python+sklearn] 拆分数据集为训练和测试子集 sklearn.model_selection.train_test_split()

    Python - sklearn 拆分数据集为训练和测试子集 sklearn.model_selection.train_test_split() 功能: 将数组或矩阵拆分为随机的训练子集和测试子集 ...

  4. 将数据集分为训练集和测试集(python脚本)

    文章目录 程序: 下面简单介绍一下程序流程 1.引入库 os库 shutil random 2.mk_file函数 3.主函数 程序: 我们在训练卷积神经网络之前,要搭建好数据集,分成训练集和测试集两 ...

  5. 【Pytorch】MNIST数据集的训练和测试

    训练和测试的完整代码: import torch import torch.nn as nn import torch.nn.functional as F from torchvision impo ...

  6. 自定义ava数据集及训练与测试 完整版 时空动作/行为 视频数据集制作 yolov5, deep sort, VIA MMAction, SlowFast

    前言 这一篇博客应该是我花时间最多的一次了,从2022年1月底至2022年4月底. 我已经将这篇博客的内容写为论文,上传至arxiv:https://arxiv.org/pdf/2204.10160. ...

  7. 7个Bert变种模型baseline在7个文本分类数据集上训练和测试

    引入和代码项目简介 https://github.com/songyingxin/Bert-TextClassification 模型有哪些? 使用的模型有下面七个 BertOrigin, BertC ...

  8. 【caffe】mnist数据集lenet训练与测试

    在上一篇中,费了九牛二虎之力总算是把Caffe编译通过了,现在我们可以借助mnist数据集,测试下Caffe的训练和检测效果. 准备工作:在自己的工作目录下,新建一个文件夹,命名为mnist_test ...

  9. 机器学习之数据集划分——训练集测试集划分,划分函数,估计器的使用

    训练集测试集划分,划分函数,估计器的使用 参考文章 训练集.验证集和测试集的划分及交叉验证的讲解 划分训练集和测试集的函数学习 sklearn数据集,数据集划分,估计器详细讲解 参考文章 训练集.验证 ...

最新文章

  1. pandas使用isna函数和any函数计算返回dataframe中包含缺失值的数据行(rows with missing values in dataframe)
  2. 人体肺活量测试软件,人体肺活量怎么测试
  3. C语言学习之输入一行字符,分别统计出其中英文字母、空格、数字和其他字符的个数。
  4. 使用 TypeScript 自定义装饰器给类的属性增添监听器 Listener
  5. c语言指针++_C ++此指针| 查找输出程序| 套装3
  6. 计算机游戏高少手电影,支持switch,还有电影特技!上手简评骨伽IMMERSA Ti游戏耳机...
  7. 恒大汽车:引入腾讯、滴滴等投资者 筹集约40亿港元
  8. 基于element ui的收起展开检索条件效果
  9. L1-008 求整数段和 (10 分)—团体程序设计天梯赛
  10. 项目中一些零碎化总结的
  11. 易语言代码转php,易语言代码转PHP代码有没大佬
  12. linux内核nvme驱动程序,Linux中nvme驱动详解
  13. 高级电工实验室成套设备(带功率表、功率因数表)
  14. IP抓包精准定位教程
  15. 【渝粤题库】陕西师范大学165102管理心理学 作业(高起专)
  16. 修改win10更新服务器,修改win10更新服务器地址
  17. STM32初始化产生低电平引起的问题
  18. 机械师f117-7p安装linux禁用触摸板问题
  19. uniapp获取微信头像和昵称
  20. CocosCreator + JavaScript游戏开发

热门文章

  1. python预测药_python 最麻烦的时间有药了
  2. np python_python小白之np功能快速查
  3. CubeMX的代码生成设置
  4. 【飞控理论】从零开始学习Kalman Filters之三:非线性状态估算器
  5. 【C语言】1161: 字符串长度(指针专题)(空格和\0)
  6. STC51-A/D和D/A
  7. nginx的模块化体系结构
  8. varnish关于Grace mode和Saint mode这两中模式配置
  9. 第一次写CSDN的博客
  10. lisp xy轴不等比缩放_解决高缩放等级下的抖动问题