1. BCE公式部分

可以简单浏览下这篇博客的文章:
https://blog.csdn.net/qq_14845119/article/details/114121003

这是多分类 经典 BCELossBCELossBCELoss 公式
L=−yL+−(1−y)L−L = -y L_{+} - (1-y) L_{-} L=−yL+​−(1−y)L−​

其中,L+/−L_{+/-}L+/−​ 是正负例预测概率的log值,即:

L+=log(y^)L−=log(1−y^)y^=sigmoid(logit)\begin{aligned} L_{+} &= log( \hat{y} )\\ L_{-} &= log( 1- \hat{y} )\\ \hat{y} &= sigmoid( logit ) \end{aligned} L+​L−​y^​​=log(y^​)=log(1−y^​)=sigmoid(logit)​

实际上由于 labellabellabel 标签 yyy 值,是一个 0/10/10/1 矩阵,实际上充当了一个掩码 maskmaskmask 的作用,挑选出 L+L_{+}L+​ 中正例部分 和 L−L_{-}L−​ 中负例部分

假设:

y=[0010]y = \begin{bmatrix} 0 & 0 \\ 1 & 0 \end{bmatrix} y=[01​00​]

y^=[0.50.10.30.2]L+=[−0.6931−2.3026−1.2040−1.6094]L−=[−0.6931−0.1054−0.3567−0.2231]\hat{y} = \begin{bmatrix} 0.5 & 0.1 \\ 0.3 & 0.2 \end{bmatrix} \ L_{+} = \begin{bmatrix} -0.6931 & -2.3026 \\ -1.2040 & -1.6094 \end{bmatrix} \ L_{-} = \begin{bmatrix} -0.6931 & -0.1054 \\ -0.3567 & -0.2231 \end{bmatrix} y^​=[0.50.3​0.10.2​] L+​=[−0.6931−1.2040​−2.3026−1.6094​] L−​=[−0.6931−0.3567​−0.1054−0.2231​]

所以,LLL 左下角为L+L_{+}L+​对应的值的相反数,左上角和右上角和右下角为L−L_{-}L−​对应的值的相反数

L=[0.69310.10541.20400.2231]L = \begin{bmatrix} 0.6931 & 0.1054 \\ 1.2040 & 0.2231 \end{bmatrix} L=[0.69311.2040​0.10540.2231​]

代码验证:

x = torch.tensor([0.5, 0.1, 0.3, 0.2]).reshape(2, 2).float()
y = torch.tensor([0, 0, 1, 0]).reshape(2, 2).float()
torch.nn.functional.binary_cross_entropy(x, y, reduction='none')
tensor([[0.6931, 0.1054],[1.2040, 0.2231]])

(不要小看这个 mask 代码的操作,一会儿写 asl 代码会用的上)

2. focal loss 公式部分

基本公式依旧是这个:
L=−yL+−(1−y)L−L = -y L_{+} - (1-y) L_{-} L=−yL+​−(1−y)L−​

L+L_{+}L+​ 和 L−L_{-}L−​ 如下:
L+=(1−p)γ∗log(p)L−=pγ∗log(1−p)p=sigmoid(logit)\begin{aligned} L_{+} &= (1-p)^{\gamma} * log(p) \\ L_{-} &= p^{\gamma} * log(1-p) \\ p &= sigmoid(logit) \end{aligned} L+​L−​p​=(1−p)γ∗log(p)=pγ∗log(1−p)=sigmoid(logit)​

3. asl 公式部分

asl loss 是 focal loss的改进版

L+=(1−p)γ+∗log(p)L−=pmγ−∗log(1−pm)p=sigmoid(logit)pm=max(p−m,0)\begin{aligned} L_{+} &= (1-p)^{\gamma_{+}} &*& log(p) \\ L_{-} &= p_m^{\gamma_{-}} &*& log(1-p_m) \\ p &= sigmoid(logit) \\ p_m &= max(p-m, 0) \end{aligned} L+​L−​ppm​​=(1−p)γ+​=pmγ−​​=sigmoid(logit)=max(p−m,0)​∗∗log(p)log(1−pm​)

由于 pmp_mpm​ 仅在 L−L_{-}L−​ 中存在,而ppp一般出现在L+L_{+}L+​中,(1−p)(1-p)(1−p)一般出现在L−L_{-}L−​中,所以将 pmp_mpm​ 做一些反向操作

先引入一个引理,显然成立,x和y都是函数(或者变量),二者中大的加上负号,就是二者相反数中小的

−max(x,y)==min(−x,−y)-max(x, y) == min(-x, -y) −max(x,y)==min(−x,−y)

所以:
pm=max(p−m,0)=−min(m−p,0)−pm=min(m−p,0)1−pm=min(m−p,0)+11−pm=min(m−p+1,1)1−pm=min(m+1−p,1)1−pm=np.clip(m+1−p,max=1)\begin{aligned} p_m &= max(p-m, 0) \\ &= -min(m-p, 0) \\ -p_m &= min(m-p, 0) \\ 1-p_m &= min(m-p, 0) + 1 \\ 1-p_m &= min(m-p+ 1, 1) \\ 1-p_m &= min(m+ 1-p, 1) \\ 1-p_m &= np.clip(m+ 1-p, max=1) \\ \end{aligned} pm​−pm​1−pm​1−pm​1−pm​1−pm​​=max(p−m,0)=−min(m−p,0)=min(m−p,0)=min(m−p,0)+1=min(m−p+1,1)=min(m+1−p,1)=np.clip(m+1−p,max=1)​

这一行咱等会要用到

4. asl 代码

看看 asl loss 的代码,torch代码来自:
https://github.com/Alibaba-MIIL/ASL/blob/main/src/loss_functions/losses.py

  • self.gamma_neg 是 γ−\gamma_{-}γ−​
  • self.gamma_pos 是 γ+\gamma_{+}γ+​
  • self.eps 是用作 log 函数内部,防止溢出
class AsymmetricLossOptimized(nn.Module):''' Notice - optimized version, minimizes memory allocation and gpu uploading,favors inplace operations'''def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):super(AsymmetricLossOptimized, self).__init__()self.gamma_neg = gamma_negself.gamma_pos = gamma_posself.clip = clipself.disable_torch_grad_focal_loss = disable_torch_grad_focal_lossself.eps = eps# prevent memory allocation and gpu uploading every iteration, and encourages inplace operationsself.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = Nonedef forward(self, x, y):""""Parameters----------x: input logitsy: targets (multi-label binarized vector)"""self.targets = yself.anti_targets = 1 - y# 分别计算正负例的概率self.xs_pos = torch.sigmoid(x)self.xs_neg = 1.0 - self.xs_pos# 非对称裁剪if self.clip is not None and self.clip > 0:self.xs_neg.add_(self.clip).clamp_(max=1)  # 给 self.xs_neg 加上 clip 值# 先进行基本交叉熵计算self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))# Asymmetric Focusingif self.gamma_neg > 0 or self.gamma_pos > 0:if self.disable_torch_grad_focal_loss:torch.set_grad_enabled(False)# 以下 4 行相当于做了个并行操作self.xs_pos = self.xs_pos * self.targetsself.xs_neg = self.xs_neg * self.anti_targetsself.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)if self.disable_torch_grad_focal_loss:torch.set_grad_enabled(True)self.loss *= self.asymmetric_wreturn -self.loss.sum()

来咱单独看一下代码:

# 非对称裁剪
if self.clip is not None and self.clip > 0:self.xs_neg.add_(self.clip).clamp_(max=1)  # 给 self.xs_neg 加上 clip 值

这两行用于计算:
1−pm=np.clip(m+1−p,max=1)\begin{aligned} 1-p_m &= np.clip(m+ 1-p, max=1) \end{aligned} 1−pm​​=np.clip(m+1−p,max=1)​

# 先进行基本交叉熵计算
self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

这两行用于计算红框部分:

注意 self.targetsself.anti_targets 都相当于掩码 mask 的作用,此处的 self.loss 矩阵的shape是和 self.targets 一样的 shape,不理解可以回忆一下 BCE公式部分 的计算

而前面的 幂 相当于权重,就是代码中的 self.asymmetric_w,也就是此处的:

self.asymmetric_w 是这样计算的,这部分很妙!

self.xs_pos = self.xs_pos * self.targets
self.xs_neg = self.xs_neg * self.anti_targets
self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)

插一句 torch.pow 该函数会将两个shape相同的张量的对应位置做幂运算,看这个例子

>>> x = torch.tensor([1, 2, 3, 4])
>>> y = torch.tensor([2, 2, 3, 1])
>>> torch.pow(x, y)
tensor([ 1,  4, 27,  4])

计算 self.asymmetric_w 时,只需将pow的 xxx 参数对应位置写成 (1−p)(1-p)(1−p) 或者 pmp_mpm​,将pow的 yyy 参数对应位置写成 γ−\gamma_{-}γ−​ 或者 γ+\gamma_{+}γ+​ 即可,先看简单的,yyy 参数这里计算:

self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets

也是通过 self.targets 的 mask 操作来进行的,而 xxx 参数这样计算:

1 - self.xs_pos - self.xs_neg

当计算 L+L_{+}L+​ 时,self.xs_neg==0,xxx 参数对应位置就是 1 - self.xs_pos(1-p)
当计算 L−L_{-}L−​ 时,self.xs_pos==0,xxx 参数对应位置就是 1 - self.xs_neg 即 (1−(1−pm))=pm(1-(1-p_m))=p_m(1−(1−pm​))=pm​

通过一个 torch.pow 巧妙的计算了 self.asymmetric_w NICE!

之后二者对应位置相乘即可

self.loss *= self.asymmetric_w

5. asl 代码 Paddle 实现

class AsymmetricLossOptimizedWithLogit(nn.Layer):''' Notice - optimized version, minimizes memory allocation and gpu uploading,favors inplace operations'''def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-5, disable_paddle_grad_focal_loss=False):super(AsymmetricLossOptimizedWithLogit, self).__init__()self.gamma_neg = gamma_negself.gamma_pos = gamma_posself.clip = clipself.disable_paddle_grad_focal_loss = disable_paddle_grad_focal_lossself.eps = epsself.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = Nonedef forward(self, x, y, weights=None):""""Parameters----------x: input logitsy: targets (multi-label binarized vector)"""self.targets = yself.anti_targets = 1 - y# Calculating Probabilitiesself.xs_pos = F.sigmoid(x)self.xs_neg = 1.0 - self.xs_pos# Asymmetric Clippingif self.clip is not None and self.clip > 0:# self.xs_neg.add_(self.clip).clip_(max=1)self.xs_neg = (self.xs_neg + self.clip).clip_(max=1)# Basic CE calculationself.loss = self.targets * paddle.log(self.xs_pos.clip(min=self.eps))self.loss.add_(self.anti_targets * paddle.log(self.xs_neg.clip(min=self.eps)))# Asymmetric Focusingif self.gamma_neg > 0 or self.gamma_pos > 0:if self.disable_paddle_grad_focal_loss:paddle.set_grad_enabled(False)self.xs_pos = self.xs_pos * self.targetsself.xs_neg = self.xs_neg * self.anti_targetsself.asymmetric_w = paddle.pow(1 - self.xs_pos - self.xs_neg,(self.gamma_pos * self.targets + \self.gamma_neg * self.anti_targets).astype("float32"))if self.disable_paddle_grad_focal_loss:paddle.set_grad_enabled(True)self.loss *= self.asymmetric_wif weights is not None:self.loss *= weights_loss = -self.loss.sum()return _lossif __name__ == "__main__":np.random.seed(11070109)x = np.random.randn(3, 3)x = paddle.to_tensor(x).cast("float32")y = (x > 0.5).cast("float32")loss = AsymmetricLossOptimizedWithLogit()out = loss(x, y)

loss盘点: asl loss (Asymmetric Loss) 代码解析详细版相关推荐

  1. Focal Loss for Dense Object Detection(RetinaNet)(代码解析)

    转自:https://www.jianshu.com/p/db4ccd194109 转载于:https://www.cnblogs.com/leebxo/p/10485740.html

  2. 巨人通力电梯服务器显示,巨人通力电梯的所有故障代码大全[详细版]

    通力电梯故障代码详解.通力电梯的控制系统可监测到电梯电气系统的常见基本故障,对于监测到的故障,可通过LCECPU板上的显示窗口以故障代码的数字形式显示出来.电梯的控制系统的NVRAM可同时存储99个常 ...

  3. openCV4.0 C++ 快速入门30讲学习笔记(自用 代码+注释)详细版

    课程来源:哔哩哔哩 环境:OpenCV4.5.1 + VS2019 目录 002.图像色彩空间转换 003.图像对象的创建与赋值 004.图像像素的读写操作 005.图像像素的算术操作(加减乘除4种不 ...

  4. 贪吃蛇分析和代码(详细版)

    贪吃蛇分析: 1, 构造蛇移动的地图====>PC端的游戏 浏览器上运行程序(JS) (1)使用div 创建地图元素 (2)添加到body中 (3)设置地图的元素的样式:js实现 (4)调用地图 ...

  5. 多表连接查询详细解析(详细版)

    前言:写SQL语句:(先找出表的数据)找哪2个表,什么字段,过滤条件.(最后思考过滤条件,过滤条件肯定是2张表都有关联的字段,不一定是同名的字段,思考:表1哪个字段和表2哪个字段有关联) 1.为什么要 ...

  6. circle loss代码实现_CenterNet之loss计算代码解析

    [GiantPandaCV导语] 本文主要讲解CenterNet的loss,由偏置部分(reg loss).热图部分(heatmap loss).宽高(wh loss)部分三部分loss组成,附代码实 ...

  7. 【对比学习】CUT模型论文解读与NCE loss代码解析

    标题:Contrastive Learning for Unpaired Image-to-Image Translation(基于对比学习的非配对图像转换) 作者:Taesung Park, Ale ...

  8. pytorch代码解析:loss = y_hat - y.view(y_hat.size())

    pytorch代码解析:pytorch中loss = y_hat - y.view(y_hat.size()) import torchy_hat = torch.tensor([[-0.0044], ...

  9. Generalized Focal Loss 原理与代码解析

    Paper:Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object Det ...

最新文章

  1. ubuntu安装QQ
  2. 业界丨2018深度学习十大趋势:元学习成新SGD,多数硬件创企将失败
  3. 再见,工资!程序员工资统计平均14404元,网友:又跌了!
  4. 决战 平安京服务器维护,《决战!平安京》2018年9月7日维护公告
  5. php获取总共内存_php获取页面运行使用内存的两个函数
  6. 如果可能我们还是做好基础的事情吧
  7. LinkedHashMap 的理解以及借助其实现LRU
  8. Servlet使用适配器模式进行增删改查案例(EmpServiceImpl.java)
  9. C++ 11 新特性(十二)函数新特性、内联函数、const详解
  10. 《MySQL——Innodb改进LRU算法》
  11. 计算机英语讲课笔记02
  12. windows 防火墙疑难解答程序_Win8系统设置允许程序通过防火墙的方法
  13. 你赞同企业年薪百万的高管对员工说别羡慕赚的多,人家加班和付出的时候你在玩的说法吗?
  14. jquery中获得table中第几个td元素的值
  15. 【图像去噪】基于matlab鲁棒PCA图像去噪【含Matlab源码 463期】
  16. 使用WebService获取第三方服务数据
  17. 如何长时间高效学习?
  18. 软件测试实战(微软技术专家经验总结)--第九、十章(团队工作、个人管理)读书笔记
  19. 13-新手小白如何选购笔记本电脑?
  20. antd表格分页控件显示英文page

热门文章

  1. java 基础 ppt_Java基础培训课件.ppt
  2. u盘格式FAT32转NTFS
  3. (第七集——第一章)python面向对象
  4. 求职面试、指点迷津各类经验汇总
  5. 《Wir wilden weisen Frauen》翻译——连载
  6. 记一次蓝牙故障:蓝牙不见了或设备管理器里蓝牙设备不停的在刷新
  7. MTBF需要测试多久,MTBF失效率是多少
  8. 想做点培训教师了,和别人一起干吧
  9. android 音频转码慢,适用于Android的最佳音频和视频转码抑制软件
  10. 荧光定量PCR检测法的原理和应用领域