文章目录

  • 数据集下载链接
  • 数据集简介
  • 制作新的train以及val文件
  • 训练自己的网络

数据集下载链接

百度网盘下载:
链接: https://pan.baidu.com/s/1Uro6RuEbRGGCQ8iXvF2SAQ 密码: hl31

数据集简介

提到Imagenet大家都知道,是一个非常大型、有名的开源数据集。一般设计一个新的分类网络就会在Imagenet 1000类的数据上进行训练以及验证。包括常见的目标检测网络等,所使用的backbone一般都会先基于Imagenet进行预训练。但对于普通研究员或者开发者而言,这个数据集太大了(全部下载大概有100GB左右),而且训练对硬件要求也非常高,通常都是很多块高端显卡并行训练,即使是这样的配置通常还要训练好几天的时间。所以让很多人望而却步(我就是其中之一,关键太大,而且国内下载很慢)。

2016年google DeepMind团队从Imagnet数据集中抽取的一小部分(大小约3GB)制作了Mini-Imagenet数据集,共有100个类别,每个类别都有600张图片,共60000张(都是.jpg结尾的文件),而且图像的大小并不是固定的。

数据集的结构为:

├── mini-imagenet: 数据集根目录├── images: 所有的图片都存在这个文件夹中├── train.csv: 对应训练集的标签文件├── val.csv: 对应验证集的标签文件└── test.csv: 对应测试集的标签文件

Mini-Imagenet数据集中还包含了train.csvval.csv以及test.csv三个文件。需要注意的是,当时作者制作这个数据集时主要是针对小样本学习领域的,而且提供的标签文件并不是从每个类别中进行采样的。我自己用pandas包分析了下每个标签文件。

  • train.csv包含38400张图片,共64个类别。
  • val.csv包含9600张图片,共16个类别。
  • test.csv包含12000张图片,共20个类别。

每个csv文件之间的图像以及类别都是相互独立的,即共60000张图片,100个类。

pandas读取的csv文件数据格式如下,每一行对应一张图片的名称和所属类别:

                filename      label
0  n0153282900000005.jpg  n01532829
1  n0153282900000006.jpg  n01532829
2  n0153282900000007.jpg  n01532829
3  n0153282900000010.jpg  n01532829
4  n0153282900000014.jpg  n01532829

至于每个类别对应的实际物体名称,可查看这个json文件,这个文件是Imagenet1000类数据中对应的标签文件。

{"0": ["n01440764", "tench"], "1": ["n01443537", "goldfish"], "2": ["n01484850", "great_white_shark"],...
}

制作新的train以及val文件

根据上面分析的,如果想用Mini-Imgenet数据集直接去训练自己的分类网络是不可行的,因为train.csvval.csv并不是从每个类别中进行采样的,所以我们需要自己去构建一个新的train.csvval.csv文件。下面是我自己写的一个构建train.csvval.csv标签文件的脚本,该脚本会从这100个类别中按给定的比例去划分训练集和验证集。

import os
import jsonimport pandas as pd
from PIL import Image
import matplotlib.pyplot as pltdef read_csv_classes(csv_dir: str, csv_name: str):data = pd.read_csv(os.path.join(csv_dir, csv_name))# print(data.head(1))  # filename, labellabel_set = set(data["label"].drop_duplicates().values)print("{} have {} images and {} classes.".format(csv_name,data.shape[0],len(label_set)))return data, label_setdef calculate_split_info(path: str, label_dict: dict, rate: float = 0.2):# read all imagesimage_dir = os.path.join(path, "images")images_list = [i for i in os.listdir(image_dir) if i.endswith(".jpg")]print("find {} images in dataset.".format(len(images_list)))train_data, train_label = read_csv_classes(path, "train.csv")val_data, val_label = read_csv_classes(path, "val.csv")test_data, test_label = read_csv_classes(path, "test.csv")# Union operationlabels = (train_label | val_label | test_label)labels = list(labels)labels.sort()print("all classes: {}".format(len(labels)))# create classes_name.jsonclasses_label = dict([(label, [index, label_dict[label]]) for index, label in enumerate(labels)])json_str = json.dumps(classes_label, indent=4)with open('classes_name.json', 'w') as json_file:json_file.write(json_str)# concat csv datadata = pd.concat([train_data, val_data, test_data], axis=0)print("total data shape: {}".format(data.shape))# split data on every classesnum_every_classes = []split_train_data = []split_val_data = []for label in labels:class_data = data[data["label"] == label]num_every_classes.append(class_data.shape[0])# shuffleshuffle_data = class_data.sample(frac=1, random_state=1)num_train_sample = int(class_data.shape[0] * (1 - rate))split_train_data.append(shuffle_data[:num_train_sample])split_val_data.append(shuffle_data[num_train_sample:])# imshowimshow_flag = Falseif imshow_flag:img_name, img_label = shuffle_data.iloc[0].valuesimg = Image.open(os.path.join(image_dir, img_name))plt.imshow(img)plt.title("class: " + classes_label[img_label][1])plt.show()# plot classes distributionplot_flag = Falseif plot_flag:plt.bar(range(1, 101), num_every_classes, align='center')plt.show()# concatenate datanew_train_data = pd.concat(split_train_data, axis=0)new_val_data = pd.concat(split_val_data, axis=0)# save new csv datanew_train_data.to_csv(os.path.join(path, "new_train.csv"))new_val_data.to_csv(os.path.join(path, "new_val.csv"))def main():data_dir = "/home/wz/mini-imagenet/"  # 指向数据集的根目录json_path = "./imagenet_class_index.json"  # 指向imagenet的索引标签文件# load imagenet labelslabel_dict = json.load(open(json_path, "r"))label_dict = dict([(v[0], v[1]) for k, v in label_dict.items()])calculate_split_info(data_dir, label_dict)if __name__ == '__main__':main()

训练自己的网络

项目地址:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing
pytorch_classification->mini-imagenet文件夹中,里面提供了两个训练脚本,一个是针对单GPU的,一个是针对多GPU的。在这个项目中是以训练ShuffleNetv2为例进行讲解的。训练了100个epoch,达到了78%的准确率。

接着,我拿这个预训练权重去做迁移学习,训练其他的小数据集,确实也有一定帮助。在我测试过程中,如果不使用预训练权重,训练自己的数据集能达到80%的准确率,如果使用预训练权重能达到90%的准确率。当然基于Mini-Imagenet的预训练权重和基于Imagenet的预训练权重还有一些差距,毕竟数据量摆在这。之前使用基于Imagenet的预训练权重准确率可以达到94%。
当然,对于自己新搭的网络,如果想快速验证一下,Mini-Imagenet也是一个不错的选择。

使用Mini-ImageNet训练分类网络相关推荐

  1. YOLO如何训练分类网络???

    一般YOLO的工程应用直接上检测,源自作者提供了分类的预训练模型,但是如果自己改网络训练怎么办?预训练网络没有怎么办? 不怕,两种方法可以解决! 第一种:比较笨的方法,就是下载imagenet数据训练 ...

  2. 【代码实验】CNN实验——利用Imagenet子集训练分类网络(AlexNet/ResNet)

    文章目录 前言 一.数据准备 二.训练 三.结果 前言 Imagenet是计算机视觉的经典分类比赛,但是Imagenet数据集本身太大了,我们穷学生没有这么大的算力,2016年google DeepM ...

  3. tictoc正方形网络模型_Trick | 分类网络Trick大汇总

    本文介绍了训练分类网络的各个阶段可以用来提升性能的Trick,也就是俗称的调参术.结果顶级调参术的调教,ResNet- 50的top-1验证精度在ImageNet上从75.3%提高到79.29%.这个 ...

  4. Caffe实践】如何利用Caffe训练ImageNet分类网络

    Caffe实践]如何利用Caffe训练ImageNet分类网络 源文章:https://github.com/BVLC/caffe/tree/master/examples/imagenet 由于要使 ...

  5. 卷积神经网络:VGG16 是基于大量真实图像的 ImageNet 图像库预训练的网络

    卷积神经网络:VGG16 是基于大量真实图像的 ImageNet 图像库预训练的网络 图片输入->卷积->填充(Padding)->池化(pooling)->平坦(Flatte ...

  6. 迁移学习实战 | 快速训练残差网络 ResNet-101,完成图像分类与预测,精度高达 98%!...

    作者 | AI 菌 出品 | CSDN博客 头图 | CSDN付费下载自视觉中国 前言 笔者在实现ResNet的过程中,由于电脑性能原因,不得不选择层数较少的ResNet-18进行训练.但是很快发现, ...

  7. 这就是神经网络 1:早期分类网络之LeNet-5、AlexNet、ZFNet、OverFeat、VGG

    概述 本系列文章计划介绍总结经典的神经网络结构,先介绍分类网络,后续会包括通用物体检测.语义分割,然后扩展到一些相对较细的领域如人脸检测.行人检测.行人重识别.姿态估计.文本检测等. 一些经典网络的年 ...

  8. 神经网络和深度学习(二)——一个简单的手写数字分类网络

    本文转自:https://blog.csdn.net/qq_31192383/article/details/77198870 一个简单的手写数字分类网络 接上一篇文章,我们定义了神经网络,现在我们开 ...

  9. 细粒度分类网络 RACNN 论文翻译

    racnn论文翻译 论文原地址http://openaccess.thecvf.com/content_cvpr_2017/papers/Fu_Look_Closer_to_CVPR_2017_pap ...

最新文章

  1. Android8.0后版本的分区变化
  2. Sql Server 调用DLL
  3. hdu1255 扫描线,矩形重叠面积(两次以上)
  4. PHP易混淆函数的区分
  5. vi的插入模式下退格和方向键不能使用的解决方法
  6. Win7启动修复MBR(Win7+Linux删除Linux后进入grub rescue的情况)
  7. 联想一体机电源键不亮_联想电脑一体机B505拆机经验
  8. IMAX影厅专候天神下凡 巨幕电影2010年观影指南
  9. java nio 客户端_Java网络编程:Netty框架学习(二)---Java NIO,实现简单的服务端客户端消息传输...
  10. java 模块化 soa_OSGI与SOA的千丝万缕
  11. 帆软报表如何传递主表原有参数给子表呢_报表工具--钻取功能--超链接下钻
  12. 用视频作为Mac动态壁纸Dynamic Wallpaper
  13. ll1语法分析程序c语言,c语言语法分析器,实现ll1分析
  14. IT桌面运维常识系列 -(Windows脚本)
  15. 米家插件平台的技术实践之路
  16. python做三维图片挑战眼力_查找「儿童大家来找茬图片」安卓应用 - 豌豆荚
  17. prometheus+grafana搭建监控平台监控压测服务器mysql性能
  18. FCN分割Pascal VOC 2007
  19. 【已恢复】苹果再堵开发者账号注册漏洞,黑市账号价格有价无市!
  20. not marked as ignorable

热门文章

  1. 结合Free to Earn和Play to Earn,Monsterra在GameFi领域的尝试
  2. 使用CGlib实现Bean拷贝(BeanCopier)
  3. idea注释模版配置(吐血推荐!!!)
  4. Reward Machines for Cooperative Multi-Agent Reinforcement Learning论文阅读
  5. linux文本剪切命令---cut
  6. css 图片使用过滤器
  7. 你可以在虚拟世界里过上美好生活吗?
  8. DAZ设置dForce模拟权重
  9. 学完计算机图形学可以做什么,计算机图形学心得体会.doc
  10. 百度 搜索原理 如何 应对百度 的封杀 和 降权