多标签分类器

 多标签分类任务与多分类任务有所不同,多分类任务是将一个实例分到某个类别中,多标签分类任务是将某个实例分到多个类别中。多标签分类任务有有两大特点:

  • 类标数量不确定,有些样本可能只有一个类标,有些样本的类标可能高达几十甚至上百个
  • 类标之间相互依赖,例如包含蓝天类标的样本很大概率上包含白云

如下图所示,即为一个多标签分类学习的一个例子,一张图片里有多个类别,房子,树,云等,深度学习模型需要将其一一分类识别出来。

多标签分类器损失函数

 假设X=Rd\mathcal{X}=\mathbb{R}^dX=Rd表示ddd维样本空间,Y={y=(y1,y2,⋯,yn)∣yi∈{0,1},i=1,⋯,n}\mathcal{Y}=\{y=(y_1,y_2,\cdots,y_n)|y_i\in \{0,1\},i=1,\cdots,n\}Y={y=(y1​,y2​,⋯,yn​)∣yi​∈{0,1},i=1,⋯,n}表示nnn维标签空间。训练该多标签分类器的损失函数可以用二元交叉熵函数,该多标签分类器的最后一层为sigmoid\mathrm{sigmoid}sigmoid,多标签分类模型预测的概率向量为p=(p1,p2,⋯,pn)p=(p_1,p_2,\cdots,p_n)p=(p1​,p2​,⋯,pn​),其中pi∈[0,1](i=1,⋯,n)p_i \in [0,1](i=1,\cdots,n)pi​∈[0,1](i=1,⋯,n),此时真实标签分布yyy和预测概率分布ppp的二元损失函数为:loss1=−1n∑i=1n[yilog⁡pi+(1−yi)log⁡(1−pi)]\mathrm{loss1}=-\frac{1}{n}\sum\limits_{i=1}^n [y_i \log p_i+(1-y_i)\log(1-p_i)]loss1=−n1​i=1∑n​[yi​logpi​+(1−yi​)log(1−pi​)]

代码实现

  针对图像的多标签分类器pytorch的简化代码实现如下所示。因为图像的多标签分类器的数据集比较难获取,所以可以通过对mnist数据集中的每个图片打上特定的多标签,例如类别111的多标签可以为[1,1,0,1,0,1,0,0,1][1,1,0,1,0,1,0,0,1][1,1,0,1,0,1,0,0,1],然后再利用重新打标后的数据集训练出一个mnist的多标签分类器。

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import osclass CNN(nn.Module):def __init__(self):super().__init__()self.Sq1 = nn.Sequential(         nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),   # (16, 28, 28)                           #  output: (16, 28, 28)nn.ReLU(),                    nn.MaxPool2d(kernel_size=2),    # (16, 14, 14))self.Sq2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),  # (32, 14, 14)nn.ReLU(),                      nn.MaxPool2d(2),                # (32, 7, 7))self.out = nn.Linear(32 * 7 * 7, 100)  def forward(self, x):x = self.Sq1(x)x = self.Sq2(x)x = x.view(x.size(0), -1)    x = self.out(x)## Sigmoid activation   output = F.sigmoid(x)  # 1/(1+e**(-x))return outputdef loss_fn(pred, target):return -(target * torch.log(pred) + (1 - target) * torch.log(1 - pred)).sum()def multilabel_generate(label):Y1 = F.one_hot(label, num_classes = 100)Y2 = F.one_hot(label+10, num_classes = 100)Y3 = F.one_hot(label+50, num_classes = 100)  multilabel = Y1+Y2+Y3return multilabel# def multilabel_generate(label):
#   multilabel_dict = {}
#   multi_list = []
#   for i in range(label.shape[0]):
#       multi_list.append(multilabel_dict[label[i].item()])
#   multilabel_tensor = torch.tensor(multi_list)
#     return multilabeldef train():epoches = 10mnist_net = CNN()mnist_net.train()opitimizer = optim.SGD(mnist_net.parameters(), lr=0.002)mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= 128, shuffle=True)for epoch in range(epoches):loss = 0 for batch_X, batch_Y in train_loader:opitimizer.zero_grad()outputs = mnist_net(batch_X)loss = loss_fn(outputs, multilabel_generate(batch_Y)) / batch_X.shape[0]loss.backward()opitimizer.step()print(loss)if __name__ == '__main__':train()

多标签分类器(附pytorch代码)相关推荐

  1. mapbox 修改初始位置_一行代码教你如何随心所欲初始化Bert参数(附Pytorch代码详细解读)...

    微信公众号:NLP从入门到放弃 微信文章在这里(排版更漂亮,但是内置链接不太行,看大家喜欢哪个点哪个看吧): 一行代码带你随心所欲重新初始化bert的参数(附Pytorch代码详细解读)​mp.wei ...

  2. 聊一聊计算机视觉中常用的注意力机制 附Pytorch代码实现

    聊一聊计算机视觉中常用的注意力机制以及Pytorch代码实现 注意力机制(Attention)是深度学习中常用的tricks,可以在模型原有的基础上直接插入,进一步增强你模型的性能.注意力机制起初是作 ...

  3. python图片自动上色_老旧黑白片修复机——使用卷积神经网络图像自动着色实战(附PyTorch代码)...

    摘要: 照片承载了很多人在某个时刻的记忆,尤其是一些老旧的黑白照片,尘封于脑海之中,随着时间的流逝,记忆中对当时颜色的印象也会慢慢消散,这确实有些可惜.技术的发展会解决一些现有的难题,深度学习恰好能够 ...

  4. ViT结构详解(附pytorch代码)

    参考这篇文章,本文会加一些注解. 源自paper: AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE ...

  5. 给GAN一句描述,它就能按要求画画,微软CVPR新研究 | 附PyTorch代码

    晓查 发自 凹非寺  量子位 报道 | 公众号 QbitAI 让AI认得图像,根据自己的理解给出一段叙述,已经不是什么新鲜事了.从图像到文字容易,把这个过程反过来却很难. 让AI画图有了成熟的解决方案 ...

  6. Transformer 详解(上) — 编码器【附pytorch代码实现】

    Transformer 详解(上)编码器 Transformer结构 文本嵌入层 位置编码 注意力机制 编码器之多头注意力机制层 编码器之前馈全连接层 规范化层和残差连接 代码实现Transforme ...

  7. SegNet学习笔记(附Pytorch 代码)

    SegNet 的应用 SegNet常用于图像的语义分割.什么是语义分割了?,我们知道图像分割大致可以划分为三类,一类是语义分割.一类是实例分割,一类是全景分割,另外还有一些可以归为超像素分割.打个比方 ...

  8. 文本生成客观评价指标总结(附Pytorch代码实现)

    前言:最近在做文本生成的工作,调研发现针对不同的文本生成场景(机器翻译.对话生成.图像描述.data-to-text 等),客观评价指标也不尽相同.虽然网络上已经有很多关于文本生成评价指标的文章,本博 ...

  9. 【2DWT:2维离散小波变换(附Pytorch代码)】

    二维离散小波变换 一.相关基础 1.小波变换基础函数 2.小波变换 二.原理 三.基本小波基:哈尔小波 四.代码实现 参考: 图像信号具有非平稳特性,无法使用一种确定的数学模型来描述,而小波变换的多分 ...

最新文章

  1. 屏幕为什么要正负压供电_负压变换器的设计
  2. Android开发举步维艰,上弘法寺七七四十九天取得“真经”!
  3. 大话PHP设计模式:类自动载入、PSR-0规范、链式操作、11种面向对象设计模式实现和使用、OOP的基本原则和自动加载配置...
  4. 拥抱云原生,Fluid 结合 JindoFS:阿里云 OSS 加速利器
  5. ansa打开catia文件_关于CATIA文件格式的那些事儿
  6. [CF/AT/Luogu]各大网站网赛 爆肝部部长工作报告文件Ⅱ
  7. Java Iterator到Java 8 Stream
  8. 高颜值在线绘图平台ImageGP系列教程 - 参数介绍
  9. 一个HTTP请求,把网站打裂开了!
  10. 小目标神器!TPH-YOLOv5:将Transformer预测加载Yolov5!
  11. FireWire笔记
  12. Ubuntu 13.10 用sogou拼音替换ibus-转
  13. 面向对象设计模式与原则
  14. Android QFIL 烧录
  15. 虽迟但到,手眼标定代码实现篇
  16. Linux源码安装pgadmin4,pgAdmin4 - 搞定安装部署
  17. 计算机如何连接网络扫描仪,windows系统下怎么共享扫描仪?
  18. 软硬件全开源,航芯方案分享 | 热敏打印机方案
  19. 电话号码的字母组合---2022/01/23
  20. 「链节点活动年度总结」2019年区块链行业会议回顾

热门文章

  1. 用最复杂的方式学会数组(Python实现动态数组)
  2. python线程延时函数_延迟队列的python实现
  3. mac如何更新git版本
  4. 数据挖掘中的模式发现(一)频繁项集、频繁闭项集、最大频繁项集
  5. 07.RequestResponse
  6. js计算胎儿体重的代码
  7. 数据挖掘与数据分析的主要区别
  8. 2021数澜科技牛势开局,吹响新年冲锋号!
  9. c语言求子串 位置 长度,Manchester-求子串位置的定位函数
  10. Web开发基础-新闻页面-老九门