小技巧(6):进行BelgiumTSC交通标志数据集识别(定义自己的数据集)
小技巧(5):将TT100K数据集转成VOC格式,并且用Python脚本选出45类超过100张的图片和XML
文章目录
- 1 数据预处理
- 1.1 下载数据集
- 1.2 制作bs_dataset
- 1.3 制作bs_dataloader
- 2 搭建网络模型
- 3 训练和测试模型
- 4 总结
【本文完整代码:https://github.com/cqfdch/BelgiumTSC-pytorch】
图像分类是计算机视觉的基础,pytorch做图像分类都是模块化的操作,主要包括数据预处理,搭建网络模型,训练和测试模型。
1 数据预处理
1.1 下载数据集
能拿来训练的数据集=数据集+标签文件
方法一:可以去官网https://btsd.ethz.ch/shareddata/下载数据集,然后进行自行处理
方法二:(1积分,看着支持就行)通过csdn资源下载(含数据集+标签文件)
https://download.csdn.net/download/Hankerchen/13073778
方法三:我的百度网盘(含数据集+标签文件)链接: https://pan.baidu.com/s/1JYWEFYFJCSRsVPmBfkauPQ 密码: wqtv
1.2 制作bs_dataset
根据pytorch标准的制作数据集的三段式
def init(self):
def len(self):
def getitem(self, idx):
import torch
import os
import pandas as pd
from torch.utils.data import Dataset
import numpy as np
from PIL import Imageclass BelgiumTSC(Dataset):base_folder = 'BelgiumTSC'def __init__(self, root_dir, train=False, transform=None):"""Args:train (bool): Load trainingset or test set.root_dir (string): Directory containing GTSRB folder.transform (callable, optional): Optional transform to be appliedon a sample."""self.root_dir = root_dirself.sub_directory = 'Training' if train else 'Testing'self.csv_file_name = 'train_data.csv' if train else 'test_data.csv'csv_file_path = os.path.join(root_dir, self.base_folder, self.sub_directory, self.csv_file_name)self.csv_data = pd.read_csv(csv_file_path)self.transform = transformdef __len__(self):return len(self.csv_data)def __getitem__(self, idx):img_path = os.path.join(self.root_dir, self.base_folder, self.sub_directory,self.csv_data.iloc[idx, 0])img = Image.open(img_path)classId = self.csv_data.iloc[idx, 1]if self.transform is not None:img = self.transform(img)return img, classId
1.3 制作bs_dataloader
参考链接https://github.com/tomlawrenceuk/GTSRB-Dataloader
通过torch.utils.data.DataLoader()将bs_dataset导入bs_loader。
import bs_dataset as dataset
import torchvision.transforms as transforms
import torchdef get_train_valid_loader(data_dir,batch_size,num_workers=0,):# Create Transformstransform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.3403, 0.3121, 0.3214),(0.2724, 0.2608, 0.2669))])# Create Datasetstrainset = dataset.BelgiumTSC(root_dir=data_dir, train=True, transform=transform)testset = dataset.BelgiumTSC(root_dir=data_dir, train=False, transform=transform)# Load Datasetstrainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)return trainloader, testloader
2 搭建网络模型
网络模型有很多,AlexNet,VGGNet,ResNet等,注意要数据图片尺寸。
具体的代码可以看我的github
https://github.com/cqfdch/BelgiumTSC-pytorch
3 训练和测试模型
参考链接
https://blog.csdn.net/qq_37541097/article/details/104710784
将前面制作好的bs_loader进行训练,每个epoch测试一次。
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import bs_loader
# from model import AlexNet
from model import Modelprint(torch.cuda.is_available())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")train_loader, validate_loader = bs_loader.get_train_valid_loader('D:\\Networks\\cnn-ga-master\\data', batch_size=32, num_workers=0)net = Model()
net.to(device)
loss_function = nn.CrossEntropyLoss()
# pata = list(net.parameters())
optimizer = optim.Adam(net.parameters(), lr=0.0002)
epoch = 30
save_path = './model.pth'
best_acc = 0.0
for epoch in range(epoch):# trainnet.train()running_loss = 0.0total = 0correct = 0show_step = 32for step, data in enumerate(tqdm(train_loader),0):images, labels = datadevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")images, labels = images.to(device), labels.to(device)optimizer.zero_grad()outputs = net(images)loss = loss_function(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()_, predicted = torch.max(outputs.detach(), 1)total += labels.size(0)correct += predicted.eq(labels.data).sum().item()# print train process# rate = (step+1)/len(train_loader)# a = "*" * int(rate * 50)# b = "." * int((1 - rate) * 50)# print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")# print('Train-Epoch:%3d, %3d / %3d ,Loss: %.3f, Acc:%.3f'% (epoch+1, step+1, len(train_loader),running_loss/total, (correct/total)))if step % show_step == 0:print("Epoch [{}][{}/{}]:Loss:{:.3f},Acc:{:.3f}".format(epoch+1, step+1, len(train_loader),running_loss/total, (correct/total)))print()# validatenet.eval()val_loss = 0.0total = 0correct = 0acc = 0.0 # accumulate accurate number / epochwith torch.no_grad():for _,val_data in enumerate(validate_loader,0):val_images, val_labels = val_datadevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")val_images, val_labels = val_images.to(device), val_labels.to(device)outputs = net(val_images)loss = loss_function(outputs, val_labels)val_loss += loss.item()_, predicted = torch.max(outputs.detach(), 1)total += val_labels.size(0)correct += predicted.eq(val_labels.data).sum().item()if correct / total > best_acc:best_acc = correct / total# print('*'*100, self.best_acc)torch.save(net.state_dict(), save_path)print('Validate-Loss:%.3f, Acc:%.3f' % (val_loss / total, correct / total))print('Finished Training')
4 总结
完整代码
https://github.com/cqfdch/BelgiumTSC-pytorch
觉得有帮助的可以给个赞,给个星!
小技巧(6):进行BelgiumTSC交通标志数据集识别(定义自己的数据集)相关推荐
- 基于深度学习的道路交通标志数字识别
基于MATLAB深度学习的交通标志识别 课题介绍 交通标志识别技术的研究最早开始于奔驰等14家大型汽车公 司组成的民间组织所资助的Prometheus(Program for European Tra ...
- 车拍条件下交通标志实时识别
一.项目设计与实现 1.1总体设计简介 交通标志有着显著的颜色和形状特征,主要功能以指示.提示和警示为主.标志的意义:警告标志用于警告车辆.行人注意危险地点:禁令标志用于禁止或限制车辆.行人的交通行为 ...
- matlab交通标志神经网络识别,基于神经网络的交通标志识别方法
Municipal & Traffic Construction SCIENCE & TECHNOLOGY FOR DEVELOPMENT 149 基于神经网络的交通标志识别方法 赵丹 ...
- 基于Yolov5的交通标志检测识别设计
项目介绍 上一篇文章介绍了基于卷积神经网络的交通标志分类识别Python交通标志识别基于卷积神经网络的保姆级教程(Tensorflow),并且最后实现了一个pyqt5的GUI界面,并且还制作了一个简单 ...
- 德国交通标志检测识别数据集
http://benchmark.ini.rub.de/?section=gtsdb&subsection=dataset http://benchmark.ini.rub.de/?secti ...
- new 一个结构体数组_每天一个IDA小技巧(四):结构体识别
之前提到IDA可以将一长串的数组数据声明变成一行数组声明,简化反汇编代码,对于结构体,IDA也同样支持通过各种设置工具来改善结构体代码的可读性. 这篇文章的目标是将[edx+10h]之类的结构体元素访 ...
- 小技巧(5):将TT100K数据集转成VOC格式,并且用Python脚本选出45类超过100张的图片和XML
上一篇:小技巧(4):将txt中的某两列数据写入csv文件中,制作图像分类标签 文章目录 一.相关准备 1.1 下载数据集 1.2 下载代码文件 1.3 将相关文件移入代码文件 二.创建标准的VOC文 ...
- 基于深度学习的大规模交通标志识别(附6GB交通标志数据集)
01 1.文章信息 <Deep Learning for Large-Scale Traffic-Sign Detection and Recognition>. 国外学者2020年发在I ...
- 【服务器数据集】中心服务器上存放(下载)的交通标志和交通信号(红绿灯)灯数据集 整理
文章目录 交通标志 和 信号灯(红绿灯) 数据集整理(服务器) 0. 数据集整体说明如下: 1. 交通标志 1.1. 整理好的数据集(交通标志) 2. 信号灯(红绿灯) 2.1. 整理好的数据集(交通 ...
最新文章
- 第三周项目4(2)-顺序表应用 将所有奇数移到所有偶数前面
- MySQL【案例讲解】单行函数
- execjs执行js出现window对象未定义时的解决_10个常见的JS语言错误总汇
- P3193-[HNOI2008]GT考试【KMP,dp,矩阵乘法】
- python搜论文_python论文
- 插画素材 | 圣诞节设计离不了!
- java线程执行顺序执行_Java多线程系列四——控制线程执行顺序
- 用计算机制作标准曲线的方法,如何绘制标准曲线
- NSOperation
- pci-e串口卡linux 驱动下载,PCI/PCIe串口卡并口卡驱动
- c语言函数cot怎么表示,谁知道三角函数sin,cos,tan,cot之间的换算公式?
- CCF-CSP真题《202209-3—防疫大数据》思路+python题解
- android 4.4 5.1.1,兼容Android 4.4 搜狗输入法5.1版发布
- 如何用python爬虫获取百度贴吧内容
- DeepSpeech语音转文本合成技术
- 经典SQL练习——详细到令人发指(未完待续)
- 短视频高流量的秘诀,上热门全靠这些技巧
- PHP短信通知+语音播报自动双呼
- win10安装wsl 2.0子系统 安装在非C盘
- 轻松高效搭建可视化数据网站
热门文章
- [附源码]Java计算机毕业设计SSM电影院购票系统
- Fluke 17B+ 万用表与PC通讯(数据实时采集)-- 升级篇
- 高通MSM平台上的AMSS
- 人工智能如何改变半导体的分层技术
- vmpalyer.exe - 应用程序错误: 应用程序无法正常启动(0xc000007b)。请单击“确定
- Pytest框架 —— setUp()和tearDown()函数
- Dynamo For Revit:宜家小方桌
- Verilog编程艺术(3)——第四部分 高级设计
- 2022年长沙市湘雅三医院三基考试模拟试题及答案
- ncl如何添加线shp文件_教程合集 | NCL与GrADS地图绘制合集