使用Mini-ImageNet训练分类网络
文章目录
- 数据集下载链接
- 数据集简介
- 制作新的train以及val文件
- 训练自己的网络
数据集下载链接
百度网盘下载:
链接: https://pan.baidu.com/s/1Uro6RuEbRGGCQ8iXvF2SAQ 密码: hl31
数据集简介
提到Imagenet大家都知道,是一个非常大型、有名的开源数据集。一般设计一个新的分类网络就会在Imagenet 1000类的数据上进行训练以及验证。包括常见的目标检测网络等,所使用的backbone一般都会先基于Imagenet进行预训练。但对于普通研究员或者开发者而言,这个数据集太大了(全部下载大概有100GB左右),而且训练对硬件要求也非常高,通常都是很多块高端显卡并行训练,即使是这样的配置通常还要训练好几天的时间。所以让很多人望而却步(我就是其中之一,关键太大,而且国内下载很慢)。
2016年google DeepMind团队从Imagnet数据集中抽取的一小部分(大小约3GB)制作了Mini-Imagenet数据集,共有100个类别,每个类别都有600张图片,共60000张(都是.jpg
结尾的文件),而且图像的大小并不是固定的。
![]() |
![]() |
![]() |
![]() |
数据集的结构为:
├── mini-imagenet: 数据集根目录├── images: 所有的图片都存在这个文件夹中├── train.csv: 对应训练集的标签文件├── val.csv: 对应验证集的标签文件└── test.csv: 对应测试集的标签文件
Mini-Imagenet数据集中还包含了train.csv
、val.csv
以及test.csv
三个文件。需要注意的是,当时作者制作这个数据集时主要是针对小样本学习领域的,而且提供的标签文件并不是从每个类别中进行采样的。我自己用pandas
包分析了下每个标签文件。
train.csv
包含38400张图片,共64个类别。val.csv
包含9600张图片,共16个类别。test.csv
包含12000张图片,共20个类别。
每个csv
文件之间的图像以及类别都是相互独立的,即共60000张图片,100个类。
用pandas
读取的csv
文件数据格式如下,每一行对应一张图片的名称和所属类别:
filename label
0 n0153282900000005.jpg n01532829
1 n0153282900000006.jpg n01532829
2 n0153282900000007.jpg n01532829
3 n0153282900000010.jpg n01532829
4 n0153282900000014.jpg n01532829
至于每个类别对应的实际物体名称,可查看这个json文件,这个文件是Imagenet1000类数据中对应的标签文件。
{"0": ["n01440764", "tench"], "1": ["n01443537", "goldfish"], "2": ["n01484850", "great_white_shark"],...
}
制作新的train以及val文件
根据上面分析的,如果想用Mini-Imgenet数据集直接去训练自己的分类网络是不可行的,因为train.csv
和val.csv
并不是从每个类别中进行采样的,所以我们需要自己去构建一个新的train.csv
和val.csv
文件。下面是我自己写的一个构建train.csv
和val.csv
标签文件的脚本,该脚本会从这100个类别中按给定的比例去划分训练集和验证集。
import os
import jsonimport pandas as pd
from PIL import Image
import matplotlib.pyplot as pltdef read_csv_classes(csv_dir: str, csv_name: str):data = pd.read_csv(os.path.join(csv_dir, csv_name))# print(data.head(1)) # filename, labellabel_set = set(data["label"].drop_duplicates().values)print("{} have {} images and {} classes.".format(csv_name,data.shape[0],len(label_set)))return data, label_setdef calculate_split_info(path: str, label_dict: dict, rate: float = 0.2):# read all imagesimage_dir = os.path.join(path, "images")images_list = [i for i in os.listdir(image_dir) if i.endswith(".jpg")]print("find {} images in dataset.".format(len(images_list)))train_data, train_label = read_csv_classes(path, "train.csv")val_data, val_label = read_csv_classes(path, "val.csv")test_data, test_label = read_csv_classes(path, "test.csv")# Union operationlabels = (train_label | val_label | test_label)labels = list(labels)labels.sort()print("all classes: {}".format(len(labels)))# create classes_name.jsonclasses_label = dict([(label, [index, label_dict[label]]) for index, label in enumerate(labels)])json_str = json.dumps(classes_label, indent=4)with open('classes_name.json', 'w') as json_file:json_file.write(json_str)# concat csv datadata = pd.concat([train_data, val_data, test_data], axis=0)print("total data shape: {}".format(data.shape))# split data on every classesnum_every_classes = []split_train_data = []split_val_data = []for label in labels:class_data = data[data["label"] == label]num_every_classes.append(class_data.shape[0])# shuffleshuffle_data = class_data.sample(frac=1, random_state=1)num_train_sample = int(class_data.shape[0] * (1 - rate))split_train_data.append(shuffle_data[:num_train_sample])split_val_data.append(shuffle_data[num_train_sample:])# imshowimshow_flag = Falseif imshow_flag:img_name, img_label = shuffle_data.iloc[0].valuesimg = Image.open(os.path.join(image_dir, img_name))plt.imshow(img)plt.title("class: " + classes_label[img_label][1])plt.show()# plot classes distributionplot_flag = Falseif plot_flag:plt.bar(range(1, 101), num_every_classes, align='center')plt.show()# concatenate datanew_train_data = pd.concat(split_train_data, axis=0)new_val_data = pd.concat(split_val_data, axis=0)# save new csv datanew_train_data.to_csv(os.path.join(path, "new_train.csv"))new_val_data.to_csv(os.path.join(path, "new_val.csv"))def main():data_dir = "/home/wz/mini-imagenet/" # 指向数据集的根目录json_path = "./imagenet_class_index.json" # 指向imagenet的索引标签文件# load imagenet labelslabel_dict = json.load(open(json_path, "r"))label_dict = dict([(v[0], v[1]) for k, v in label_dict.items()])calculate_split_info(data_dir, label_dict)if __name__ == '__main__':main()
训练自己的网络
项目地址:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing
在pytorch_classification
->mini-imagenet
文件夹中,里面提供了两个训练脚本,一个是针对单GPU的,一个是针对多GPU的。在这个项目中是以训练ShuffleNetv2为例进行讲解的。训练了100个epoch,达到了78%的准确率。
接着,我拿这个预训练权重去做迁移学习,训练其他的小数据集,确实也有一定帮助。在我测试过程中,如果不使用预训练权重,训练自己的数据集能达到80%的准确率,如果使用预训练权重能达到90%的准确率。当然基于Mini-Imagenet的预训练权重和基于Imagenet的预训练权重还有一些差距,毕竟数据量摆在这。之前使用基于Imagenet的预训练权重准确率可以达到94%。
当然,对于自己新搭的网络,如果想快速验证一下,Mini-Imagenet也是一个不错的选择。
使用Mini-ImageNet训练分类网络相关推荐
- YOLO如何训练分类网络???
一般YOLO的工程应用直接上检测,源自作者提供了分类的预训练模型,但是如果自己改网络训练怎么办?预训练网络没有怎么办? 不怕,两种方法可以解决! 第一种:比较笨的方法,就是下载imagenet数据训练 ...
- 【代码实验】CNN实验——利用Imagenet子集训练分类网络(AlexNet/ResNet)
文章目录 前言 一.数据准备 二.训练 三.结果 前言 Imagenet是计算机视觉的经典分类比赛,但是Imagenet数据集本身太大了,我们穷学生没有这么大的算力,2016年google DeepM ...
- tictoc正方形网络模型_Trick | 分类网络Trick大汇总
本文介绍了训练分类网络的各个阶段可以用来提升性能的Trick,也就是俗称的调参术.结果顶级调参术的调教,ResNet- 50的top-1验证精度在ImageNet上从75.3%提高到79.29%.这个 ...
- Caffe实践】如何利用Caffe训练ImageNet分类网络
Caffe实践]如何利用Caffe训练ImageNet分类网络 源文章:https://github.com/BVLC/caffe/tree/master/examples/imagenet 由于要使 ...
- 卷积神经网络:VGG16 是基于大量真实图像的 ImageNet 图像库预训练的网络
卷积神经网络:VGG16 是基于大量真实图像的 ImageNet 图像库预训练的网络 图片输入->卷积->填充(Padding)->池化(pooling)->平坦(Flatte ...
- 迁移学习实战 | 快速训练残差网络 ResNet-101,完成图像分类与预测,精度高达 98%!...
作者 | AI 菌 出品 | CSDN博客 头图 | CSDN付费下载自视觉中国 前言 笔者在实现ResNet的过程中,由于电脑性能原因,不得不选择层数较少的ResNet-18进行训练.但是很快发现, ...
- 这就是神经网络 1:早期分类网络之LeNet-5、AlexNet、ZFNet、OverFeat、VGG
概述 本系列文章计划介绍总结经典的神经网络结构,先介绍分类网络,后续会包括通用物体检测.语义分割,然后扩展到一些相对较细的领域如人脸检测.行人检测.行人重识别.姿态估计.文本检测等. 一些经典网络的年 ...
- 神经网络和深度学习(二)——一个简单的手写数字分类网络
本文转自:https://blog.csdn.net/qq_31192383/article/details/77198870 一个简单的手写数字分类网络 接上一篇文章,我们定义了神经网络,现在我们开 ...
- 细粒度分类网络 RACNN 论文翻译
racnn论文翻译 论文原地址http://openaccess.thecvf.com/content_cvpr_2017/papers/Fu_Look_Closer_to_CVPR_2017_pap ...
最新文章
- Android8.0后版本的分区变化
- Sql Server 调用DLL
- hdu1255 扫描线,矩形重叠面积(两次以上)
- PHP易混淆函数的区分
- vi的插入模式下退格和方向键不能使用的解决方法
- Win7启动修复MBR(Win7+Linux删除Linux后进入grub rescue的情况)
- 联想一体机电源键不亮_联想电脑一体机B505拆机经验
- IMAX影厅专候天神下凡 巨幕电影2010年观影指南
- java nio 客户端_Java网络编程:Netty框架学习(二)---Java NIO,实现简单的服务端客户端消息传输...
- java 模块化 soa_OSGI与SOA的千丝万缕
- 帆软报表如何传递主表原有参数给子表呢_报表工具--钻取功能--超链接下钻
- 用视频作为Mac动态壁纸Dynamic Wallpaper
- ll1语法分析程序c语言,c语言语法分析器,实现ll1分析
- IT桌面运维常识系列 -(Windows脚本)
- 米家插件平台的技术实践之路
- python做三维图片挑战眼力_查找「儿童大家来找茬图片」安卓应用 - 豌豆荚
- prometheus+grafana搭建监控平台监控压测服务器mysql性能
- FCN分割Pascal VOC 2007
- 【已恢复】苹果再堵开发者账号注册漏洞,黑市账号价格有价无市!
- not marked as ignorable