多标签分类器(附pytorch代码)
多标签分类器
多标签分类任务与多分类任务有所不同,多分类任务是将一个实例分到某个类别中,多标签分类任务是将某个实例分到多个类别中。多标签分类任务有有两大特点:
- 类标数量不确定,有些样本可能只有一个类标,有些样本的类标可能高达几十甚至上百个
- 类标之间相互依赖,例如包含蓝天类标的样本很大概率上包含白云
如下图所示,即为一个多标签分类学习的一个例子,一张图片里有多个类别,房子,树,云等,深度学习模型需要将其一一分类识别出来。
多标签分类器损失函数
假设X=Rd\mathcal{X}=\mathbb{R}^dX=Rd表示ddd维样本空间,Y={y=(y1,y2,⋯,yn)∣yi∈{0,1},i=1,⋯,n}\mathcal{Y}=\{y=(y_1,y_2,\cdots,y_n)|y_i\in \{0,1\},i=1,\cdots,n\}Y={y=(y1,y2,⋯,yn)∣yi∈{0,1},i=1,⋯,n}表示nnn维标签空间。训练该多标签分类器的损失函数可以用二元交叉熵函数,该多标签分类器的最后一层为sigmoid\mathrm{sigmoid}sigmoid,多标签分类模型预测的概率向量为p=(p1,p2,⋯,pn)p=(p_1,p_2,\cdots,p_n)p=(p1,p2,⋯,pn),其中pi∈[0,1](i=1,⋯,n)p_i \in [0,1](i=1,\cdots,n)pi∈[0,1](i=1,⋯,n),此时真实标签分布yyy和预测概率分布ppp的二元损失函数为:loss1=−1n∑i=1n[yilogpi+(1−yi)log(1−pi)]\mathrm{loss1}=-\frac{1}{n}\sum\limits_{i=1}^n [y_i \log p_i+(1-y_i)\log(1-p_i)]loss1=−n1i=1∑n[yilogpi+(1−yi)log(1−pi)]
代码实现
针对图像的多标签分类器pytorch的简化代码实现如下所示。因为图像的多标签分类器的数据集比较难获取,所以可以通过对mnist数据集中的每个图片打上特定的多标签,例如类别111的多标签可以为[1,1,0,1,0,1,0,0,1][1,1,0,1,0,1,0,0,1][1,1,0,1,0,1,0,0,1],然后再利用重新打标后的数据集训练出一个mnist的多标签分类器。
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import osclass CNN(nn.Module):def __init__(self):super().__init__()self.Sq1 = nn.Sequential( nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2), # (16, 28, 28) # output: (16, 28, 28)nn.ReLU(), nn.MaxPool2d(kernel_size=2), # (16, 14, 14))self.Sq2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2), # (32, 14, 14)nn.ReLU(), nn.MaxPool2d(2), # (32, 7, 7))self.out = nn.Linear(32 * 7 * 7, 100) def forward(self, x):x = self.Sq1(x)x = self.Sq2(x)x = x.view(x.size(0), -1) x = self.out(x)## Sigmoid activation output = F.sigmoid(x) # 1/(1+e**(-x))return outputdef loss_fn(pred, target):return -(target * torch.log(pred) + (1 - target) * torch.log(1 - pred)).sum()def multilabel_generate(label):Y1 = F.one_hot(label, num_classes = 100)Y2 = F.one_hot(label+10, num_classes = 100)Y3 = F.one_hot(label+50, num_classes = 100) multilabel = Y1+Y2+Y3return multilabel# def multilabel_generate(label):
# multilabel_dict = {}
# multi_list = []
# for i in range(label.shape[0]):
# multi_list.append(multilabel_dict[label[i].item()])
# multilabel_tensor = torch.tensor(multi_list)
# return multilabeldef train():epoches = 10mnist_net = CNN()mnist_net.train()opitimizer = optim.SGD(mnist_net.parameters(), lr=0.002)mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= 128, shuffle=True)for epoch in range(epoches):loss = 0 for batch_X, batch_Y in train_loader:opitimizer.zero_grad()outputs = mnist_net(batch_X)loss = loss_fn(outputs, multilabel_generate(batch_Y)) / batch_X.shape[0]loss.backward()opitimizer.step()print(loss)if __name__ == '__main__':train()
多标签分类器(附pytorch代码)相关推荐
- mapbox 修改初始位置_一行代码教你如何随心所欲初始化Bert参数(附Pytorch代码详细解读)...
微信公众号:NLP从入门到放弃 微信文章在这里(排版更漂亮,但是内置链接不太行,看大家喜欢哪个点哪个看吧): 一行代码带你随心所欲重新初始化bert的参数(附Pytorch代码详细解读)mp.wei ...
- 聊一聊计算机视觉中常用的注意力机制 附Pytorch代码实现
聊一聊计算机视觉中常用的注意力机制以及Pytorch代码实现 注意力机制(Attention)是深度学习中常用的tricks,可以在模型原有的基础上直接插入,进一步增强你模型的性能.注意力机制起初是作 ...
- python图片自动上色_老旧黑白片修复机——使用卷积神经网络图像自动着色实战(附PyTorch代码)...
摘要: 照片承载了很多人在某个时刻的记忆,尤其是一些老旧的黑白照片,尘封于脑海之中,随着时间的流逝,记忆中对当时颜色的印象也会慢慢消散,这确实有些可惜.技术的发展会解决一些现有的难题,深度学习恰好能够 ...
- ViT结构详解(附pytorch代码)
参考这篇文章,本文会加一些注解. 源自paper: AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE ...
- 给GAN一句描述,它就能按要求画画,微软CVPR新研究 | 附PyTorch代码
晓查 发自 凹非寺 量子位 报道 | 公众号 QbitAI 让AI认得图像,根据自己的理解给出一段叙述,已经不是什么新鲜事了.从图像到文字容易,把这个过程反过来却很难. 让AI画图有了成熟的解决方案 ...
- Transformer 详解(上) — 编码器【附pytorch代码实现】
Transformer 详解(上)编码器 Transformer结构 文本嵌入层 位置编码 注意力机制 编码器之多头注意力机制层 编码器之前馈全连接层 规范化层和残差连接 代码实现Transforme ...
- SegNet学习笔记(附Pytorch 代码)
SegNet 的应用 SegNet常用于图像的语义分割.什么是语义分割了?,我们知道图像分割大致可以划分为三类,一类是语义分割.一类是实例分割,一类是全景分割,另外还有一些可以归为超像素分割.打个比方 ...
- 文本生成客观评价指标总结(附Pytorch代码实现)
前言:最近在做文本生成的工作,调研发现针对不同的文本生成场景(机器翻译.对话生成.图像描述.data-to-text 等),客观评价指标也不尽相同.虽然网络上已经有很多关于文本生成评价指标的文章,本博 ...
- 【2DWT:2维离散小波变换(附Pytorch代码)】
二维离散小波变换 一.相关基础 1.小波变换基础函数 2.小波变换 二.原理 三.基本小波基:哈尔小波 四.代码实现 参考: 图像信号具有非平稳特性,无法使用一种确定的数学模型来描述,而小波变换的多分 ...
最新文章
- 屏幕为什么要正负压供电_负压变换器的设计
- Android开发举步维艰,上弘法寺七七四十九天取得“真经”!
- 大话PHP设计模式:类自动载入、PSR-0规范、链式操作、11种面向对象设计模式实现和使用、OOP的基本原则和自动加载配置...
- 拥抱云原生,Fluid 结合 JindoFS:阿里云 OSS 加速利器
- ansa打开catia文件_关于CATIA文件格式的那些事儿
- [CF/AT/Luogu]各大网站网赛 爆肝部部长工作报告文件Ⅱ
- Java Iterator到Java 8 Stream
- 高颜值在线绘图平台ImageGP系列教程 - 参数介绍
- 一个HTTP请求,把网站打裂开了!
- 小目标神器!TPH-YOLOv5:将Transformer预测加载Yolov5!
- FireWire笔记
- Ubuntu 13.10 用sogou拼音替换ibus-转
- 面向对象设计模式与原则
- Android QFIL 烧录
- 虽迟但到,手眼标定代码实现篇
- Linux源码安装pgadmin4,pgAdmin4 - 搞定安装部署
- 计算机如何连接网络扫描仪,windows系统下怎么共享扫描仪?
- 软硬件全开源,航芯方案分享 | 热敏打印机方案
- 电话号码的字母组合---2022/01/23
- 「链节点活动年度总结」2019年区块链行业会议回顾