小技巧(5):将TT100K数据集转成VOC格式,并且用Python脚本选出45类超过100张的图片和XML

文章目录

  • 1 数据预处理
    • 1.1 下载数据集
    • 1.2 制作bs_dataset
    • 1.3 制作bs_dataloader
  • 2 搭建网络模型
  • 3 训练和测试模型
  • 4 总结

本文完整代码:https://github.com/cqfdch/BelgiumTSC-pytorch

图像分类是计算机视觉的基础,pytorch做图像分类都是模块化的操作,主要包括数据预处理,搭建网络模型,训练和测试模型。

1 数据预处理

1.1 下载数据集

能拿来训练的数据集=数据集+标签文件
方法一:可以去官网https://btsd.ethz.ch/shareddata/下载数据集,然后进行自行处理

方法二:(1积分,看着支持就行)通过csdn资源下载(含数据集+标签文件)
https://download.csdn.net/download/Hankerchen/13073778

方法三:我的百度网盘(含数据集+标签文件)链接: https://pan.baidu.com/s/1JYWEFYFJCSRsVPmBfkauPQ 密码: wqtv

1.2 制作bs_dataset

根据pytorch标准的制作数据集的三段式
def init(self):
def len(self):
def getitem(self, idx):

import torch
import os
import pandas as pd
from torch.utils.data import Dataset
import numpy as np
from PIL import Imageclass BelgiumTSC(Dataset):base_folder = 'BelgiumTSC'def __init__(self, root_dir, train=False, transform=None):"""Args:train (bool): Load trainingset or test set.root_dir (string): Directory containing GTSRB folder.transform (callable, optional): Optional transform to be appliedon a sample."""self.root_dir = root_dirself.sub_directory = 'Training' if train else 'Testing'self.csv_file_name = 'train_data.csv' if train else 'test_data.csv'csv_file_path = os.path.join(root_dir, self.base_folder, self.sub_directory, self.csv_file_name)self.csv_data = pd.read_csv(csv_file_path)self.transform = transformdef __len__(self):return len(self.csv_data)def __getitem__(self, idx):img_path = os.path.join(self.root_dir, self.base_folder, self.sub_directory,self.csv_data.iloc[idx, 0])img = Image.open(img_path)classId = self.csv_data.iloc[idx, 1]if self.transform is not None:img = self.transform(img)return img, classId

1.3 制作bs_dataloader

参考链接https://github.com/tomlawrenceuk/GTSRB-Dataloader
通过torch.utils.data.DataLoader()将bs_dataset导入bs_loader。

import bs_dataset as dataset
import torchvision.transforms as transforms
import torchdef get_train_valid_loader(data_dir,batch_size,num_workers=0,):# Create Transformstransform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.3403, 0.3121, 0.3214),(0.2724, 0.2608, 0.2669))])# Create Datasetstrainset = dataset.BelgiumTSC(root_dir=data_dir, train=True,  transform=transform)testset = dataset.BelgiumTSC(root_dir=data_dir, train=False,  transform=transform)# Load Datasetstrainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)return trainloader, testloader

2 搭建网络模型

网络模型有很多,AlexNet,VGGNet,ResNet等,注意要数据图片尺寸。
具体的代码可以看我的github
https://github.com/cqfdch/BelgiumTSC-pytorch

3 训练和测试模型

参考链接
https://blog.csdn.net/qq_37541097/article/details/104710784
将前面制作好的bs_loader进行训练,每个epoch测试一次。

from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import bs_loader
# from model import AlexNet
from model import Modelprint(torch.cuda.is_available())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")train_loader, validate_loader = bs_loader.get_train_valid_loader('D:\\Networks\\cnn-ga-master\\data', batch_size=32, num_workers=0)net = Model()
net.to(device)
loss_function = nn.CrossEntropyLoss()
# pata = list(net.parameters())
optimizer = optim.Adam(net.parameters(), lr=0.0002)
epoch = 30
save_path = './model.pth'
best_acc = 0.0
for epoch in range(epoch):# trainnet.train()running_loss = 0.0total = 0correct = 0show_step = 32for step, data in enumerate(tqdm(train_loader),0):images, labels = datadevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")images, labels = images.to(device), labels.to(device)optimizer.zero_grad()outputs = net(images)loss = loss_function(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()_, predicted = torch.max(outputs.detach(), 1)total += labels.size(0)correct += predicted.eq(labels.data).sum().item()# print train process# rate = (step+1)/len(train_loader)# a = "*" * int(rate * 50)# b = "." * int((1 - rate) * 50)# print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")# print('Train-Epoch:%3d, %3d / %3d ,Loss: %.3f, Acc:%.3f'% (epoch+1, step+1, len(train_loader),running_loss/total, (correct/total)))if step % show_step == 0:print("Epoch [{}][{}/{}]:Loss:{:.3f},Acc:{:.3f}".format(epoch+1, step+1, len(train_loader),running_loss/total, (correct/total)))print()# validatenet.eval()val_loss = 0.0total = 0correct = 0acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():for _,val_data in enumerate(validate_loader,0):val_images, val_labels = val_datadevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")val_images, val_labels = val_images.to(device), val_labels.to(device)outputs = net(val_images)loss = loss_function(outputs, val_labels)val_loss += loss.item()_, predicted = torch.max(outputs.detach(), 1)total += val_labels.size(0)correct += predicted.eq(val_labels.data).sum().item()if correct / total > best_acc:best_acc = correct / total# print('*'*100, self.best_acc)torch.save(net.state_dict(), save_path)print('Validate-Loss:%.3f, Acc:%.3f' % (val_loss / total, correct / total))print('Finished Training')

4 总结

完整代码
https://github.com/cqfdch/BelgiumTSC-pytorch

觉得有帮助的可以给个赞,给个星!

小技巧(6):进行BelgiumTSC交通标志数据集识别(定义自己的数据集)相关推荐

  1. 基于深度学习的道路交通标志数字识别

    基于MATLAB深度学习的交通标志识别 课题介绍 交通标志识别技术的研究最早开始于奔驰等14家大型汽车公 司组成的民间组织所资助的Prometheus(Program for European Tra ...

  2. 车拍条件下交通标志实时识别

    一.项目设计与实现 1.1总体设计简介 交通标志有着显著的颜色和形状特征,主要功能以指示.提示和警示为主.标志的意义:警告标志用于警告车辆.行人注意危险地点:禁令标志用于禁止或限制车辆.行人的交通行为 ...

  3. matlab交通标志神经网络识别,基于神经网络的交通标志识别方法

    Municipal & Traffic Construction SCIENCE & TECHNOLOGY FOR DEVELOPMENT 149 基于神经网络的交通标志识别方法 赵丹 ...

  4. 基于Yolov5的交通标志检测识别设计

    项目介绍 上一篇文章介绍了基于卷积神经网络的交通标志分类识别Python交通标志识别基于卷积神经网络的保姆级教程(Tensorflow),并且最后实现了一个pyqt5的GUI界面,并且还制作了一个简单 ...

  5. 德国交通标志检测识别数据集

    http://benchmark.ini.rub.de/?section=gtsdb&subsection=dataset http://benchmark.ini.rub.de/?secti ...

  6. new 一个结构体数组_每天一个IDA小技巧(四):结构体识别

    之前提到IDA可以将一长串的数组数据声明变成一行数组声明,简化反汇编代码,对于结构体,IDA也同样支持通过各种设置工具来改善结构体代码的可读性. 这篇文章的目标是将[edx+10h]之类的结构体元素访 ...

  7. 小技巧(5):将TT100K数据集转成VOC格式,并且用Python脚本选出45类超过100张的图片和XML

    上一篇:小技巧(4):将txt中的某两列数据写入csv文件中,制作图像分类标签 文章目录 一.相关准备 1.1 下载数据集 1.2 下载代码文件 1.3 将相关文件移入代码文件 二.创建标准的VOC文 ...

  8. 基于深度学习的大规模交通标志识别(附6GB交通标志数据集)

    01 1.文章信息 <Deep Learning for Large-Scale Traffic-Sign Detection and Recognition>. 国外学者2020年发在I ...

  9. 【服务器数据集】中心服务器上存放(下载)的交通标志和交通信号(红绿灯)灯数据集 整理

    文章目录 交通标志 和 信号灯(红绿灯) 数据集整理(服务器) 0. 数据集整体说明如下: 1. 交通标志 1.1. 整理好的数据集(交通标志) 2. 信号灯(红绿灯) 2.1. 整理好的数据集(交通 ...

最新文章

  1. 第三周项目4(2)-顺序表应用 将所有奇数移到所有偶数前面
  2. MySQL【案例讲解】单行函数
  3. execjs执行js出现window对象未定义时的解决_10个常见的JS语言错误总汇
  4. P3193-[HNOI2008]GT考试【KMP,dp,矩阵乘法】
  5. python搜论文_python论文
  6. 插画素材 | 圣诞节设计离不了!
  7. java线程执行顺序执行_Java多线程系列四——控制线程执行顺序
  8. 用计算机制作标准曲线的方法,如何绘制标准曲线
  9. NSOperation
  10. pci-e串口卡linux 驱动下载,PCI/PCIe串口卡并口卡驱动
  11. c语言函数cot怎么表示,谁知道三角函数sin,cos,tan,cot之间的换算公式?
  12. CCF-CSP真题《202209-3—防疫大数据》思路+python题解
  13. android 4.4 5.1.1,兼容Android 4.4 搜狗输入法5.1版发布
  14. 如何用python爬虫获取百度贴吧内容
  15. DeepSpeech语音转文本合成技术
  16. 经典SQL练习——详细到令人发指(未完待续)
  17. 短视频高流量的秘诀,上热门全靠这些技巧
  18. PHP短信通知+语音播报自动双呼
  19. win10安装wsl 2.0子系统 安装在非C盘
  20. 轻松高效搭建可视化数据网站

热门文章

  1. [附源码]Java计算机毕业设计SSM电影院购票系统
  2. Fluke 17B+ 万用表与PC通讯(数据实时采集)-- 升级篇
  3. 高通MSM平台上的AMSS
  4. 人工智能如何改变半导体的分层技术
  5. vmpalyer.exe - 应用程序错误: 应用程序无法正常启动(0xc000007b)。请单击“确定
  6. Pytest框架 —— setUp()和tearDown()函数
  7. Dynamo For Revit:宜家小方桌
  8. Verilog编程艺术(3)——第四部分 高级设计
  9. 2022年长沙市湘雅三医院三基考试模拟试题及答案
  10. ncl如何添加线shp文件_教程合集 | NCL与GrADS地图绘制合集