多标签分类中存在类别不平衡的问题,想要尝试用focalloss损失函数,但是网上很少有多标签分类的损失函数设计,终于在kaggle上别人做的keras下的focalloss中举例了多标签问题:
Focalloss for Keras
代码和例子如下:

Focal loss主要思想是这样:在数据集中,很自然的有些样本是很容易分类的,而有些是比较难分类的。在训练过程中,这些容易分类的样本的准确率可以达到99%,而那些难分类的样本的准确率则很差。问题就在于,那些容易分类的样本仍然在贡献着loss,那我们为什么要给所有的样本同样的权值?
这正是Focal loss要解决的问题。focal loss减小了正确分类的样本的权值,而不是给所有的样本同样的权值。这和给与训练样本更多的难分类样本时一样的效果。在实际中,当我们有数据不均衡的情况时,我们的多数的类别很快的会训练的很好,分类准确率很高,因为我们有更多的数据。但是,为了确保我们在少数类别上也能有很好的准确率,我们使用focal loss,给与少数类别的样本更高的权值。focal loss使用Keras是很容易实现的:

from keras import backend as K
import tensorflow as tfdef KerasFocalLoss(target, input):gamma = 2.input = tf.cast(input, tf.float32)max_val = K.clip(-input, 0, 1)loss = input - input * target + max_val + K.log(K.exp(-max_val) + K.exp(-input - max_val))invprobs = tf.log_sigmoid(-input * (target * 2.0 - 1.0))loss = K.exp(invprobs * gamma) * lossreturn K.mean(K.sum(loss, axis=1))# 这个例子就是多标签的情况
Y_true = np.array([[0, 1, 0, 1, 0], [1, 0, 1, 0, 1]])
Y_pred = np.array([[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1]], dtype=np.float32)
print(K.eval(KerasFocalLoss(Y_true, Y_pred)))

当然文章里面还拓展了一下pytorch的focalloss设计,这里不搬运了。
这个代码里面没有加入alpha权重,我想实现一下:


使用FL可以改善单个类别中正负样本的不平衡,提高难挖掘样本的分类精度,计算每个类别的权重W。使阳性病例数越多的疾病,权重越小,以平衡各类别之间的数据倾斜程度。因此,使用W和FL的乘积作为最终的损失函数,可以同时解决单类正负样本不平衡和多类数据偏斜的问题

#多标签版本
from tensorflow.keras import backend as K
import tensorflow as tf
import numpy as npdef KerasFocalLoss(target, input, Num):#Num为样本数量gamma = 2.input = tf.cast(input, tf.float32)max_val = K.clip(-input, 0, 1)loss = input - input * target + max_val + K.log(K.exp(-max_val) + K.exp(-input - max_val))invprobs = tf.compat.v1.log_sigmoid(-input * (target * 2.0 - 1.0))loss = K.exp(invprobs * gamma) * lossW = 1/np.log(Num)W = tf.cast(W, tf.float32)we_loss = tf.compat.v1.matmul(loss, W)return K.mean(K.sum(we_loss, axis=1))

多标签分类的Focal loss设计相关推荐

  1. pytorch多标签分类交叉熵loss

    import torch import numpy as np def multilabel_categorical_crossentropy(y_true, y_pred):"" ...

  2. 从loss的硬截断、软化到Focal Loss

    对于二分类模型,我们总希望模型能够给正样本输出1,负样本输出0,但限于模型的拟合能力等问题,一般来说做不到这一点.而事实上在预测中,我们也是认为大于0.5的就是正样本了,小于0.5的就是负样本.这样就 ...

  3. (HEM/OHEM)hard negative(example)mining难例挖掘 与focal loss、GHM损失函数

    目录 分类任务中的样本不均衡及hard negative mining的必要性 hard negative example HEM(hard example/negative mining) 与 OH ...

  4. RetinaNet论文详解Focal Loss for Dense Object Detection

    一.论文相关信息 ​ 1.论文题目:Focal Loss for Dense Object Detection ​ 2.发表时间:2017 ​ 3.文献地址:https://arxiv.org/pdf ...

  5. CE Loss,BCE Loss以及Focal Loss的原理理解

    一.交叉熵损失函数(CE Loss,BCE Loss) 最开始理解交叉熵损失函数被自己搞的晕头转向的,最后发现是对随机变量的理解有偏差,不知道有没有读者和我有着一样的困惑,所以在本文开始之前,先介绍一 ...

  6. RetinaNet+focal loss

    one stage 精度不高,一个主要原因是正负样本的不平衡,以YOLO为例,每个grid cell有5个预测,本来正负样本的数量就有差距,再相当于进行5倍放大后,这种数量上的差异更会被放大. 文中提 ...

  7. 多标签分类之非对称损失-Asymmetric Loss

    论文:Asymmetric Loss For Multi-Label Classification GitHub:https://github.com/Alibaba-MIIL/ASL https:/ ...

  8. 【CV】RetinaNet:使用二分类类别不平衡损失 Focal Loss 实现更好的目标检测

    论文名称:Focal Loss for Dense Object Detection 论文下载:https://arxiv.org/abs/1610.02357 论文年份:ICCV 2017 论文被引 ...

  9. Focal Loss 分类问题 pytorch实现代码(简单实现)

    ps:由于降阳性这步正负样本数量在差距巨大.正样本1500多个,而负样本750000多个.要用 Focal Loss来解决这个问题. 首先感谢Code_Mart的博客把理论汇总了下https://bl ...

最新文章

  1. 【效率】推荐10个堪称神器的网站!
  2. springboot中的mybatis是如果使用pagehelper的
  3. 【Java正则表达式】正则基本语法、使用方式(分组、替换、分割)、简单爬虫基础
  4. wdcp查看mysql日志_查看修改服务器中的WDCP数据库操作记录
  5. 如何设置MySQL的环境变量
  6. 和宝塔可以同时安装吗_服用钙拮抗剂可以同时补钙吗
  7. 以太坊创世区块源码分析
  8. 掌握这个套路,让你的可视化大屏万里挑一
  9. HDOJ 1755 - A Number Puzzle 排列数字凑同余,状态压缩DP
  10. 执行oracle 函数,oracle 函数function语法及简单实例
  11. 数字技术加持 华为云为测绘地理信息产业夯实“云底座”
  12. A+B,氵题一道,84种解法!大佬羡慕
  13. sap服务器之间文件复制,sap跨服务器客户端复制
  14. Easypack之Alpine容器系列:Redmine
  15. 一起来了解一下FIFO!
  16. 只有在细细品读她的作品的时候,我才找到久违的宁静
  17. Google Play 之 deviceId
  18. 知乎:面朝大海,春暖花开!
  19. java毕业设计-大学生实习管理系统 实习申请系统【附源码+文档】
  20. 2023最新短视频去水印解析API接口开发文档

热门文章

  1. 100G的文件如何读取 - 第306篇
  2. 卡特兰数 卡塔兰数 概念 代码实现 模型分析全集
  3. 明确市场定位让软文营销从针对性出发
  4. Linux总线pice错误,PCIe总线错误严重性=已更正
  5. 服务器或者docker容器中安装pip
  6. C语言中的int8_t,uint8_t, int16_t,uint16_t, int32_t,uint32_t, int64_t,uint64_t和int数组,char数组以及sizeof()的理解
  7. 元素的 tabIndex 属性
  8. 详谈高大上的图片加载框架Glide
  9. @synchronized和NSLock产生死锁场景
  10. window program try