简介

使用猫狗分类数据集中的训练集,共25000张图片。将原始训练集进行拆分,其中20000张用于训练,其余5000张用于测试。分类网络使用ResNet-18,使用了交叉熵损失函数和SGD优化方法。

环境配置

建立Conda虚拟环境,python3.7,几个重要的库:
(1)pytorch 1.7.0

(2)torchvision 0.8.0

(3)opencv-python 4.5.2.52

(4)tqdm 4.61.0

目录结构

运行方法

必须下载数据集。数据集下载完成后,存放在[工程主目录]/data路径下,首先运行如下命令完成数据集划分。

python prepare_data.py

运行完成后,[工程主目录]/data路径下会生成newtrain和newtest这2个路径,分别存放训练集和测试集。
训练过程

python train.py

训练完成后,在工程主目录下会生成名为resnet18_Cat_Dog.pth的权重文件,推理时会读取该权重文件。

测试过程

python test.py

推理完成后会打印出推理的正确率。

以下是prepare_data.py文件内容,该模块的主要功能是用来划分数据,将全部的数据一部分划分为分类训练集,一部分划分为测试集。

import os
import shutildef main():
# Step 1:创建训练集路径和测试集路径new_train_data_path = os.path.join(os.getcwd(), 'data/newtrain')new_test_data_path = os.path.join(os.getcwd(), 'data/newtest')if os.path.exists(new_train_data_path) is False:os.makedirs(new_train_data_path)if os.path.exists(new_test_data_path) is False:os.makedirs(new_test_data_path)#Step 2:将Cat和Dog类别中id>=10000的图片存到测试集路径中,其他图片存到训练集路径中origin_dataset_path = os.path.join(os.getcwd(), 'data/train')img_list = os.listdir(origin_dataset_path)for img_name in img_list:img_name_split = img_name.split('.')src_img = os.path.join(origin_dataset_path, img_name)if int(img_name_split[1])>=10000:shutil.copy(src_img, new_test_data_path)else:shutil.copy(src_img, new_train_data_path)if __name__ == '__main__':main()

以下是DogCatDataset.py。

import os
import cv2
from torch.utils.data import Datasetclass DogCatDataset(Dataset):def __init__(self, root_path, transform=None):self.label_name = {"Cat": 0, "Dog": 1}self.root_path = root_pathself.transform = transformself.get_train_img_info()def __getitem__(self, index):self.img = cv2.imread(os.path.join(self.root_path, self.train_img_name[index]))if self.transform is not None:self.img = self.transform(self.img)self.label = self.train_img_label[index]return self.img, self.labeldef __len__(self):return len(self.train_img_name)def get_train_img_info(self):self.train_img_name = os.listdir(self.root_path)self.train_img_label = [0 if 'cat' in imgname else 1 for imgname in self.train_img_name]

以下是train.py。

import os
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import torchvision.models as models
import DogCatDatasetdef main():#Step 0:查看torch版本、设置deviceprint(torch.__version__)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")#Step 1:准备数据集train_transform = transforms.Compose([transforms.ToPILImage(),transforms.Resize((224, 224)),transforms.ToTensor()])train_data = DogCatDataset.DogCatDataset(root_path=os.path.join(os.getcwd(), 'data/newtrain'),transform=train_transform)train_dataloader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)#Step 2: 初始化模型model = models.resnet18()#修改网络结构,将fc层1000个输出改为2个输出fc_input_feature = model.fc.in_featuresmodel.fc = nn.Linear(fc_input_feature, 2)#load除最后一层的预训练权重pretrained_weight = torch.hub.load_state_dict_from_url(url='https://download.pytorch.org/models/resnet18-5c106cde.pth', progress=True)del pretrained_weight['fc.weight']del pretrained_weight['fc.bias']model.load_state_dict(pretrained_weight, strict=False)model.to(device)#Step 3:设置损失函数criterion = nn.CrossEntropyLoss()     #交叉熵损失函数#Step 4:选择优化器LR = 0.01optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9)    #Step 5:设置学习率下降策略scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  #Step 6:训练网络model.train() MAX_EPOCH = 20    #设置epoch=20for epoch in range(MAX_EPOCH):loss_log = 0total_sample = 0train_correct_sample = 0    for data in tqdm(train_dataloader):img, label = dataimg, label = img.to(device), label.to(device)output = model(img)optimizer.zero_grad()loss = criterion(output, label)loss.backward()optimizer.step()_, predicted_label = torch.max(output, 1)total_sample += label.size(0)train_correct_sample += (predicted_label == label).cpu().sum().numpy()loss_log += loss.item()# if total_sample == 2400:#     print('mark!')#打印信息print('epoch: ', epoch)print("accuracy:", train_correct_sample/total_sample)print('loss:', loss_log/total_sample)scheduler.step()   #更新学习率print('train finish!')#Step 7: 存储权重torch.save(model.state_dict(), './resnet18_Cat_Dog.pth')if __name__ == '__main__':main()

以下是test.py部分。

import os
import tqdm
import torch
from torch.utils.data import DataLoader
from torch.nn import functional as F
import torch.nn as nn
from torchvision import transforms
import torchvision.models as models
import DogCatDatasetdef main():#Step 0:查看torch版本、设置deviceprint(torch.__version__)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")#Step 1:准备数据集test_transform = transforms.Compose([transforms.ToPILImage(),transforms.Resize((224, 224)),transforms.ToTensor()])test_data = DogCatDataset.DogCatDataset(root_path=os.path.join(os.getcwd(), 'data/newtest'),transform=test_transform)test_dataloader = DataLoader(dataset=test_data, batch_size=1, shuffle=False)#Step 2: 初始化网络model = models.resnet18()#修改网络结构,将fc层1000个输出改为2个输出fc_input_feature = model.fc.in_featuresmodel.fc = nn.Linear(fc_input_feature, 2)#Step 3:加载训练好的权重trained_weight = torch.load('./resnet18_Cat_Dog.pth')model.load_state_dict(trained_weight)model.to(device)#Steo 4:网络推理model.eval()correct_sample = 0total_sample = 0with torch.no_grad():for data in test_dataloader:img, label = data  #这里的label啥意思img = img.to(device)label = label.to(device)output = model(img)_, predicted_label = torch.max(output, 1)correct_sample += (predicted_label==label).cpu().numpy()total_sample += 1#print('Image Name:{},predict:{}'.format(, predicted_label))#这里想提取出文件名#Step 5:打印分类准确率print(correct_sample/total_sample)if __name__ == '__main__':main()

resnet18实现猫狗图片的分类相关推荐

  1. cnn卷积实现猫狗图片识别分类

    数据集图片https://download.csdn.net/download/qq_42363032/12737988 import tensorflow as tf import random i ...

  2. 用卷积神经网络实现猫狗图片分类

    该例程使用数据集来源于 kaggle cat_VS _dog 数据集中的一部分, 用卷积神经网络实现猫狗图片二分类,例程序比较简单,就不多解释了,代码中会有相应的注释,直接上代码: import nu ...

  3. PyTorch搭建预训练AlexNet、DenseNet、ResNet、VGG实现猫狗图片分类

    目录 前言 AlexNet DensNet ResNet VGG 前言 在之前的文章中,利用一个简单的三层CNN猫狗图片分类,正确率不高,详见: CNN简单实战:PyTorch搭建CNN对猫狗图片进行 ...

  4. 使用预训练的卷积神经网络(猫狗图片分类)

    本次所用数据来自ImageNet,使用预训练好的数据来预测一个新的数据集:猫狗图片分类.这里,使用VGG模型,这个模型内置在Keras中,直接导入就可以了. from keras.applicatio ...

  5. 使用深度学习分类猫狗图片

    使用深度学习分类猫狗图片 前言 一.下载数据 二.构建网络 三.数据预处理 四.使用数据增强 总结 前言 本文将介绍如何使用较少的数据从头开始训练一个新的深度学习模型.首先在一个2000个训练样本上训 ...

  6. Top2:CNN 卷积神经网络实现猫狗图片识别二分类

    Top2:CNN 卷积神经网络实现猫狗图片识别二分类 系统:Windows10 Professional 环境:python=3.6 tensorflow-gpu=1.14 ```python &qu ...

  7. 11.CNN实现真实猫狗图片分类

    CNN实现真实猫狗图片分类 个人认为,和上一节的mnist数据集里面的手写数字图片不同之处就是,真实的图片更加复杂,像素点更多.因此在对应的图片预处理方面会稍微麻烦一些.但是这个例子能让我们可以处理自 ...

  8. 体验AI乐趣:基于AI Gallery的二分类猫狗图片分类小数据集自动学习

    摘要:直接使用AI Gallery里面现有的数据集进行自动学习训练,很简单和方便,节约时间,不用自己去训练了,AI Gallery 里面有很多类似的有趣数据集,也非常好玩,大家一起试试吧. 本文分享自 ...

  9. CNN之从头训练一个猫狗图片分类模型

    猫狗图片下载地址: 链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw 提取码:2xq4 说明:大概有816M大小,分为train和test,trai ...

最新文章

  1. 个人做asp.net时犯过的错或是一点心得什么的(我就经常的更新一下吧)
  2. 这是我的2018年终总结,你的呢?
  3. IOS用CGContextRef画各种图形(文字、圆、直线、弧线、矩形、扇形、椭圆、三角形、圆角矩形、贝塞尔曲线、图片)...
  4. C# 10 新特性 —— Lambda 优化
  5. 计算机c语言二级题型,计算机二级C语言题型和评分标准
  6. ROS笔记(19) 摄像头仿真
  7. pymysql使用变化的变量,构造SQL语句
  8. 大数据之-Hadoop3.x_MapReduce_ETL数据清洗案例---大数据之hadoop3.x工作笔记0136
  9. librdkafka 安装
  10. String.Format用法
  11. 应用安全-浏览器安全-攻防
  12. 全网最详细的hive-site.xml配置文件里如何添加达到Hive与HBase的集成,即Hive通过这些参数去连接HBase(图文详解)...
  13. 利用 /dev/zero 创建虚拟硬盘
  14. 注册岩土工程师计算机专业的能考吗,岩土工程师报考条件
  15. 在IE/Chrome/Firefox等浏览器在线打开Word等Office文档完全解决方案
  16. 提高个人竞争力的三件法宝
  17. 【学习】Congestion Control
  18. 启用mysql系统找不到指定的文件类型_net start mysql 发生系统错误2 系统找不到指定的文件...
  19. 字节跳动CVPR 2023论文精选来啦(内含一批图像生成新研究)
  20. win10 suse linux,Windows 10现已支持安装SUSE Linux子系统 附安装教程|蓝点网

热门文章

  1. 写一篇关于招标代理机构的项目进度计划及进度保证措施
  2. Python登录qq邮箱发送邮件(附件)
  3. 计算机视觉-OpenCV(文档扫描OCR识别)
  4. gitee码云完整使用教程(部署与克隆)
  5. 语音模块之学习硬件模块改代码
  6. AD如何使用向导快速画DIP系列
  7. Java编程入门笔记(一)
  8. java 时间段重叠_Java判断多个时间段是否重叠(重叠区间个数)
  9. 大数据工程师应聘要求高么?好找工作么
  10. Java 生产神器 BTrace