【Pytorch】CIFAR1010数据集的训练和测试
代码:
import torch.nn as nn
import torch
import torch.nn.functional as F
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
import argparse
import os# 训练
def train(args, model, device, train_loader, optimizer):for epoch in range(1, args.epochs + 1):model.train()for batch_index, data in enumerate(train_loader):images, labels = dataimages = images.to(device)labels = labels.to(device)# forwardoutput = model(images)loss = F.cross_entropy(output, labels)# backwardoptimizer.zero_grad() # 梯度清空loss.backward() # 梯度回传,更新参数optimizer.step()# 打印lossprint(f'Epoch:{epoch},Batch ID:{batch_index}/{len(train_loader)}, loss:{loss}')# 保存模型if epoch % args.checkpoint_interval == 0:torch.save(model.state_dict(), f'checkpoints/cifar10_%d.pth' % epoch)def test(args, model, device, test_loader):model.eval()total_loss = 0num_correect = 0with torch.no_grad():for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)# 总的losstotal_loss += F.cross_entropy(outputs, labels).item()# 预测值_, predected = torch.max(outputs, dim=1)# 预测对的总个数num_correect += (predected==labels).sum().item()# 计算平均lossaverage_loss = total_loss / len(test_loader.dataset)# 计算准确率accuracy = num_correect / len(test_loader.dataset)# 打印平均loss和准确率print(f'Average loss:{average_loss}\nTest Accuracy:{accuracy*100}%')if __name__ == '__main__':parser = argparse.ArgumentParser(description = 'Pytorch-cifar10_classification')parser.add_argument('--epochs', type=int, default=10, help='number of epochs')parser.add_argument('--batch_size', type=int, default=32, help='size of each image batch')parser.add_argument('--num_classes', type=int, default=10, help='number of classes')parser.add_argument('--lr', type=float, default=0.001, help='learning rate')parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')parser.add_argument('--pretrained_weights', type=str, default='checkpoints/cifar10_17.pth',help='if specified starts from checkpoint model')parser.add_argument("--img_size", type=int, default=224, help="size of each image dimension")parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model weights")parser.add_argument("--train", default=True, help="train or test")args = parser.parse_args()print(args)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# os.makedirs() 方法用于递归创建目录os.makedirs("output", exist_ok=True)os.makedirs("checkpoints", exist_ok=True)# transformdata_transform = transforms.Compose([transforms.ToTensor(),transforms.RandomResizedCrop(args.img_size)])# 下载训练数据集trian_data = datasets.CIFAR10(root = 'data',train = True,download = False,transform = data_transform,target_transform = None,)# 下载测试数据集test_data = datasets.CIFAR10(root = "data",train = False,download = False,transform = data_transform,target_transform = None)# 加载数据train_loader = DataLoader(dataset = trian_data,batch_size = args.batch_size,shuffle = True)test_loader = DataLoader(dataset = test_data,batch_size = args.batch_size)# 创建模型,使用预训练好的权重model = models.vgg16(pretrained = True)# # 冻结模型,参数不更新# for para in model.parameters():# para.requires_grad = False# # 只训练全连接层# model.classifier[3].requires_grad = True# model.classifier[6].requires_grad = True# 修改vgg16的输出维度model.classifier[6] = nn.Linear(in_features=4096, out_features=args.num_classes, bias=True)model = model.to(device)# 打印网络结构print(model)# 定义优化器(也可以选择其他优化器)optimizer = torch.optim.SGD(model.parameters(), lr = args.lr, momentum = args.momentum)# optimizer = torch.optim.Adam(model.parameters())if train == True:if args.pretrained_weights.endswith(".pth"):model.load_state_dict(torch.load(args.pretrained_weights))for epoch in range(1, epochs+1):train(args, model, device, train_loader, optimizer)else:if args.pretrained_weights.endswith(".pth"):model.load_state_dict(torch.load(args.pretrained_weights))test(args, model, device, test_loader)
说明:
cifar10数据集可以通过trochvision中的datasets.CIFAR10下载,也可以自己下载(注意存放路径);我模型使用的是torchvision中的models中预训练好的vgg16网络,也可以自己搭建网络。
【Pytorch】CIFAR1010数据集的训练和测试相关推荐
- FCN制作自己的数据集、训练和测试 caffe
原文:http://blog.csdn.net/zoro_lov3/article/details/74550735 FCN制作自己的数据集.训练和测试全流程 花了两三周的时间,在导师的催促下,把FC ...
- python划分数据集用pandas_用pandas划分数据集实现训练集和测试集
1.使用model_select子模块中的train_test_split函数进行划分 数据:使用kaggle上Titanic数据集 划分方法:随机划分 # 导入pandas模块,sklearn中mo ...
- [Python+sklearn] 拆分数据集为训练和测试子集 sklearn.model_selection.train_test_split()
Python - sklearn 拆分数据集为训练和测试子集 sklearn.model_selection.train_test_split() 功能: 将数组或矩阵拆分为随机的训练子集和测试子集 ...
- 将数据集分为训练集和测试集(python脚本)
文章目录 程序: 下面简单介绍一下程序流程 1.引入库 os库 shutil random 2.mk_file函数 3.主函数 程序: 我们在训练卷积神经网络之前,要搭建好数据集,分成训练集和测试集两 ...
- 【Pytorch】MNIST数据集的训练和测试
训练和测试的完整代码: import torch import torch.nn as nn import torch.nn.functional as F from torchvision impo ...
- 自定义ava数据集及训练与测试 完整版 时空动作/行为 视频数据集制作 yolov5, deep sort, VIA MMAction, SlowFast
前言 这一篇博客应该是我花时间最多的一次了,从2022年1月底至2022年4月底. 我已经将这篇博客的内容写为论文,上传至arxiv:https://arxiv.org/pdf/2204.10160. ...
- 7个Bert变种模型baseline在7个文本分类数据集上训练和测试
引入和代码项目简介 https://github.com/songyingxin/Bert-TextClassification 模型有哪些? 使用的模型有下面七个 BertOrigin, BertC ...
- 【caffe】mnist数据集lenet训练与测试
在上一篇中,费了九牛二虎之力总算是把Caffe编译通过了,现在我们可以借助mnist数据集,测试下Caffe的训练和检测效果. 准备工作:在自己的工作目录下,新建一个文件夹,命名为mnist_test ...
- 机器学习之数据集划分——训练集测试集划分,划分函数,估计器的使用
训练集测试集划分,划分函数,估计器的使用 参考文章 训练集.验证集和测试集的划分及交叉验证的讲解 划分训练集和测试集的函数学习 sklearn数据集,数据集划分,估计器详细讲解 参考文章 训练集.验证 ...
最新文章
- pandas使用isna函数和any函数计算返回dataframe中包含缺失值的数据行(rows with missing values in dataframe)
- 人体肺活量测试软件,人体肺活量怎么测试
- C语言学习之输入一行字符,分别统计出其中英文字母、空格、数字和其他字符的个数。
- 使用 TypeScript 自定义装饰器给类的属性增添监听器 Listener
- c语言指针++_C ++此指针| 查找输出程序| 套装3
- 计算机游戏高少手电影,支持switch,还有电影特技!上手简评骨伽IMMERSA Ti游戏耳机...
- 恒大汽车:引入腾讯、滴滴等投资者 筹集约40亿港元
- 基于element ui的收起展开检索条件效果
- L1-008 求整数段和 (10 分)—团体程序设计天梯赛
- 项目中一些零碎化总结的
- 易语言代码转php,易语言代码转PHP代码有没大佬
- linux内核nvme驱动程序,Linux中nvme驱动详解
- 高级电工实验室成套设备(带功率表、功率因数表)
- IP抓包精准定位教程
- 【渝粤题库】陕西师范大学165102管理心理学 作业(高起专)
- 修改win10更新服务器,修改win10更新服务器地址
- STM32初始化产生低电平引起的问题
- 机械师f117-7p安装linux禁用触摸板问题
- uniapp获取微信头像和昵称
- CocosCreator + JavaScript游戏开发