Multi-Label Classification

首先分清一下multiclass和multilabel:

  • 多类分类(Multiclass classification): 表示分类任务中有多个类别, 且假设每个样本都被设置了一个且仅有一个标签。比如从100个分类中击中一个。
  • 多标签分类(Multilabel classification): 给每个样本一系列的目标标签,即表示的是样本各属性而不是相互排斥的。比如图片中有很多的概念如天空海洋人等等,需要预测出一个概念集合。

Challenge

多标签任务的难度主要集中在以下问题:

  • 标签数量较大且基本会呈现长尾形态。
  • 往往类标之间相互依赖并不独立。
  • absence标签占比较高,即标注的标签并不能完美覆盖所有概念面。
  • 标签往往较短语义少,理解困难。

Solution

现有的方法应对multi的预测主要有2大路线:

  • 改造数据适应算法:将多个类别合并成单个类别。
  • 改造算法适应数据:控制激活函数阈值得到结果。

而一般研究最多的应对relation会有3种策略:
一阶策略:忽略和其它标签的相关性,比如把多标签分解成多个独立的二分类问题。
二阶策略:考虑标签之间的成对关联,比如为相关标签和不相关标签排序。
高阶策略:考虑多个标签之间的关联,比如对每个标签考虑所有其它标签的影响。

Densenet

它的基本思路与ResNet一致,但是它建立的是前面所有层与后面层的密集连接(dense connection),它的名称也是由此而来。DenseNet的另一大特色是通过特征在channel上的连接来实现特征重用(feature reuse)。这些特点让DenseNet在参数和计算成本更少的情形下实现比ResNet更优的性能,DenseNet也因此斩获CVPR 2017的最佳论文奖。

DenseBlock


相比ResNet,DenseNet提出了一个更激进的密集连接机制:即互相连接所有的层,具体来说就是每个层都会接受其前面所有层作为其额外的输入。图1为ResNet网络的连接机制,作为对比,图2为DenseNet的密集连接机制。可以看到,ResNet是每个层与前面的某层(一般是2~3层)短路连接在一起,连接方式是通过元素级相加。而在DenseNet中,每个层都会与前面所有层在channel维度上连接(concat)在一起(这里各个层的特征图大小是相同的,后面会有说明),并作为下一层的输入。对于一个 L 层的网络,包含个连接,相比ResNet,这是一种密集连接。而且DenseNet是直接concat来自不同层的特征图,这可以实现特征重用,提升效率,这一特点是DenseNet与ResNet最主要的区别。



整体网络结构


CNN网络一般要经过Pooling或者stride>1的Conv来降低特征图的大小,而DenseNet的密集连接方式需要特征图大小保持一致。为了解决这个问题,DenseNet网络中使用DenseBlock+Transition的结构,其中DenseBlock是包含很多层的模块,每个层的特征图大小相同,层与层之间采用密集连接方式。而Transition模块是连接两个相邻的DenseBlock,并且通过Pooling使特征图大小降低。上图给出了DenseNet的网络结构,它共包含3个DenseBlock,各个DenseBlock之间通过Transition连接在一起。Transition层包括一个1x1的卷积和2x2的AvgPooling,结构为BN+ReLU+1x1 Conv+2x2 AvgPooling。另外,Transition层可以起到压缩模型的作用。

原论文实验结果


综合来看,DenseNet的优势主要体现在以下几个方面:

  • 由于密集连接方式,DenseNet提升了梯度的反向传播,使得网络更容易训练。由于每层可以直达最后的误差信号,实现了隐式的“deep supervision”;
  • 参数更小且计算更高效,这有点违反直觉,由于DenseNet是通过concat特征来实现短路连接,实现了特征重用,并且采用较小的growth rate,每个层所独有的特征图是比较小的;
  • 由于特征复用,最后的分类器使用了低级特征。

服装多标签分类小实验

数据划分

总数据量:5547
训练(4993):测试(554) = 9 :1


def read_split_data(root: str, test_rate: float = 0.1):random.seed(0)  # 保证随机结果可复现assert os.path.exists(root), "dataset root: {} does not exist.".format(root)# 拿到所有类别class_ = set()for cla in os.listdir(root):class_.add(cla.split('_')[0])class_.add(cla.split('_')[1])class_ = list(class_)class_.sort()# 建立类别索引并存储class_indices = dict((k, v) for v, k in enumerate(class_))json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)# 读取所有图像路径和对应类别索引train_images_path = []  # 存储训练集的所有图片路径train_images_label = []  # 存储训练集图片对应索引信息val_images_path = []  # 存储验证集的所有图片路径val_images_label = []  # 存储验证集图片对应索引信息supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型# onehot编码形式表示出每张图像的labelimages_path_and_onehot = {}for dir_ in os.listdir(root):for img_name in os.listdir(os.path.join(root, dir_)):image_path = os.path.join(root, dir_, img_name)onehot_class = [0] * 9# print(str(image_path), str(image_path).split('\\'))class0, class1 = str(image_path).split('\\')[-2].split('_')[0], image_path.split('\\')[-2].split('_')[1]idx0, idx1 = class_indices[class0], class_indices[class1]onehot_class[idx0], onehot_class[idx1] = 1, 1images_path_and_onehot[image_path] = onehot_class# 随机抽取相应比例的数据作为测试集test_path = random.sample(list(images_path_and_onehot), k=int(len(list(images_path_and_onehot)) * test_rate))# 分别存储训练和测试的图像路径及其对应onehot标签for image_path in images_path_and_onehot.keys():if image_path in test_path:  # 如果该路径在采样的验证集样本中则存入验证集val_images_path.append(image_path)val_images_label.append(images_path_and_onehot[image_path])else:  # 否则存入训练集train_images_path.append(image_path)train_images_label.append(images_path_and_onehot[image_path])print("{} images were found in the dataset.".format(len(images_path_and_onehot.keys())))print("{} images for training.".format(len(train_images_path)))print("{} images for validation.".format(len(val_images_path)))return train_images_path, train_images_label, val_images_path, val_images_label

模型

  • 使用densenet121网络,
  • loss函数:二值交叉熵
  • pretrain:imagenet 1000k
  • lr: 0.0001
  • epoches: 50(实际跑42epoch就收敛了)
  • scheduler:余弦衰减

loss

结果评估


部分测试图像预测可视化:

【参考】
https://zhuanlan.zhihu.com/p/37189203
https://nakaizura.blog.csdn.net/article/details/114753747?spm=1001.2014.3001.5506

多标签分类任务-服装分类相关推荐

  1. 【小白学PyTorch】15.TF2实现一个简单的服装分类任务

    <<小白学PyTorch>> 小白学PyTorch | 14 tensorboardX可视化教程 小白学PyTorch | 13 EfficientNet详解及PyTorch实 ...

  2. KNN分类器、最近邻分类、KD树、KNN分类的最佳K值、基于半径的最近邻分类器、KNN多分类、KNN多标签分类、KNN多输出分类、KNN分类的优缺点

    KNN分类器.最近邻分类.KD树.KNN分类的最佳K值.基于半径的最近邻分类器.KNN多分类.KNN多标签分类.KNN多输出分类.KNN分类的优缺点 目录

  3. 分类家族:二分类、多分类、多标签分类、多输出分类

    分类家族:二分类.多分类.多标签分类.多输出分类 目录 分类家族:二分类.多分类.多标签分类.多输出分类 二分类

  4. 机器学习之深度学习 二分类、多分类、多标签分类、多任务分类

    多任务学习可以运用到许多的场景. 首先,多任务学习可以学到多个任务的共享表示,这个共享表示具有较强的抽象能力,能够适应多个不同但相关的目标,通常可以使主任务获取更好的泛化能力. 此外,由于使用了共享表 ...

  5. 深度学习模型处理多标签(multi_label)分类任务——keras实战

    深度学习模型处理多标签(multi_label)分类任务--keras实战 https://zhuanlan.zhihu.com/p/107737824

  6. 二分类、多分类、多标签分类等

    分类一般分为三种情况:二分类.多分类和多标签分类.多标签分类比较直观的理解是,一个样本可以同时拥有几个类别标签,比如一张图片含有车子.房子等,那么它的标签可以是房子.车子等,一部电影的标签可以是动作. ...

  7. 深度学习:LeNet-5实现服装分类(PyTorch)

    深度学习:LeNet-5代码实践(PyTorch) 前置知识 LeNet-5模型详解 代码实战 服装分类数据集 定义模型 测试数据 训练模型 结果展示 前置知识 卷积神经网络详细指南 SGD+动量法 ...

  8. 二分类、多分类和多标签分类

    1.基本概念 二分类:表示分类任务中有两个类别,比如我们想识别一幅图片是不是猫.也就是说,训练一个分类器,输入一幅图片,用特征向量x表示,输出是不是猫,用y=0或1表示.二类分类是假设每个样本都被设置 ...

  9. 关于二分类,多分类,及多标签分类的损失函数详解及Pytorch实现

    相信很多小伙伴最开始都是从分类任务入手深度学习这个领域的吧,这个就类似学习代码的第一课,"Hello world"一样.深度学习中,除了模型设计之外,最重要的想必就是选取合适的损失 ...

最新文章

  1. 性能超过人类炼丹师,AutoGluon 低调开源
  2. 为在innodb中什么主键用auto_increment效率会提高
  3. matlab生成wav文件并用python验证
  4. JS向对象中添加和删除属性
  5. 小程序二级页面tabbar_小程序页面推广踩坑记
  6. java等待_Java学习:等待唤醒机制
  7. JavaScript鼠标经过图片晃动效果
  8. python 中的 type(), dtype(), astype()的区别
  9. 计算机内存外存共同点,存储器:内存和外存
  10. 4用计算机显示内存不足,电脑提示内存不足的解决方法总汇
  11. 机械专业与计算机专业哪个专业更好,机械类哪个专业好
  12. 解决谷歌disabled状态下操作问题
  13. java 复制excel_利用Java实现复制Excel工作表功能
  14. Kjava林林工具箱源代码(jbuilder工程)
  15. 高中信息技术教资科目三总结
  16. 浏览器显示海康摄像头实时预览画面纯前端解决方案
  17. video视频相关问题:火狐浏览器报错“没有找到支持的视频格式和MIME类型”
  18. C++ 学习(基础语法篇)
  19. thinkphp5常用函数汇总_THINKSNS常用函数
  20. (java)使用createNewFile提示系统找不到指定路径

热门文章

  1. 最小生成树算法普利姆算法和克鲁斯卡尔算法实现
  2. java之yield(),sleep(),wait()区别详解-备忘笔记[转]
  3. 深度强化学习落地指南总结(二)-动作空间设计
  4. python的networkx 算法_Python NetworkX 学习笔记
  5. html中伪类的兼容性,css,ie8兼容性_CSS 伪类在IE8中样式无法生效,css,ie8兼容性,伪类 - phpStudy...
  6. 【CG】透视变换(Perspective Transformation)
  7. Oracle如何将GMT时间(即格林尼标准时间)转换成标准时间格式
  8. L1-008 求整数段和(Python3)
  9. 香奈儿旗下标志性酒庄将开启酒窖珍藏参与苏富比葡萄酒拍卖盛典
  10. linux内网编译源码包,Netkeeper For Linux(含源码)