常见的损失函数,如交叉熵损失、平方误差损失、Hinge损失等并不是本文的重点,关于这些损失函数的介绍网上很多,可以参考如下几篇文章

  • 机器学习中的 7 大损失函数实战总结
  • 常见的损失函数(loss function)总结
  • 机器学习算法及其损失函数
  • 损失函数loss大大总结

本文的重点在于总结损失函数在应用上的一些trick和技巧

单任务

相对于多任务学习而言,我们常见的模型大部分都是单任务学习,只有一个学习目标(loss)

一、标签平滑(Label Smoothing)

1. 什么叫标签平滑

正常模型的正样本标签为1,负样本标签为0,这是一种hard的学习,这样的模型可以叫做过于自信的模型。标签平滑就是将正样本的标签为0.9,负样本的标签为0.1,这是一种soft的学习,也就是让告诉模型,不要这么自信。当然标签平滑的程度可以根据情况修改。

2. 标签平滑的作用

在论文When Does Label Smoothing Help?中,作者说明标签平滑可以提高神经网络的鲁棒性和泛化能力。

3. 标签平滑的禁忌

在知识蒸馏中的教师网络中,采用标签平滑的话会影响效果,导致教师网络无法有效传递知识。

二、样本不平衡

  • 样本不平衡在真实业务场景普遍存在,在数据角度可以通过过采样和抽样等方法来缓解。
  • 在模型角度,可以通过损失函数的设计进一步缓解样本不平衡问题,以cross entropy 为例

普通的ce损失:

通过一个参数alpha来调节ce损失:

通过alpha来增大正样本的权重,缓解正样本少的问题。

三、难易不平衡

在样本中,必然存在一些样本容易学习(特征鲜明),一些样本较难学习。如果易分样本占据多数,那么损失函数就会被容易样本主导,对难分样本的学习能力较弱,从而影响模型的学习效果。

1. Focal Loss

假设:易分样本(置信度高的样本)对模型的提升效果非常小,模型应该主要关注与那些难分样本
focal loss的思想很简单:降低易分样本的损失,提高难分样本的损失:

假设: lambda取2时,p=0.968,那么(1-p)2=0.001,也就是易分样本的损失衰减了1000倍。

2. GHM(gradient harmonizing mechanism)

首先Focal Loss存在很多问题 - 关注难分样本:如果样本中存在离群点,Focal Loss过分关注离群点,那么模型就跑偏了。 - labmda是超参数,全靠经验

GHM不是根据样本的难易程度来进行衰减,而是根据(一定梯度内)的样本数量进行衰减,也就是谁的样本数量多,就衰减谁。 - 首先定义梯度模长g:g=|p−p*|,那么g是怎么来的呢,和梯度有什么关系呢?

  • g正比于检测的难易程度,g越大则检测难度越大,可以看到,其实g就是梯度的绝对值(模长)。
  • 下图是不同梯度模长g的样本数量分布
  • 最左侧样本是易分样本,最右侧是难分样本,两者占比都较大,因此都要进行相应的衰减。
  • 具体衰减方式:将g分桶,统计数量,之后定义梯度密度GD(g),根据GD的倒数来进行衰减。具体可以参考论文

多任务

一、总结

多任务损失函数如何确定,以及如何进行融合,一直都是一个难点。
此处主要参考Deep Multi-Task Learning – 3 Lessons Learned by Zohar Komarovsky of taboola,总结3点如下:

  • 多个损失函数直接融合:简单求和与加权求和比较常用

    • 简单求和(简单)
    • 加权求和(权重难确定)
    • Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics通过不确定性(uncertainty)来调整损失函数对应的权重。
  • 多个任务通过神经网络连接:任务A的输出作为任务B的输入特征 - 前向传播:比较简单,和普通网络一样 - 后向传播:B的梯度不向A传递。
  • 不同loss对应不同learning rate
    • 不同的任务往往需要的learning rate的数量级不一样
    • 以relu为例,learning rate较大则会出现dying relu问题,learning rate较小的话就会导致收敛很慢。
    • 多任务中,对不同的损失函数设置不同的learning rate 可以一定程度缓解。

二、举例

  1. Deep Interest Evolution Network for Click-Through Rate Prediction:
    这是阿里的一篇ctr论文,这里并不是在最后的输出层增加loss,而是在网络中间部分的GRU中引入辅助loss。
  • 损失函数融合:L = Ltarget + α ∗ Laux,超参数加权的方式

2. Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate

这是一篇阿里的文章,标准的multi-task,两个网络分别对应不同的loss

  • 文章根据cvr和ctr的关系,定义了post-view click- through&conversion rate (CTCVR)
  • 不直接优化cvr,而是针对ctr和ctcvr的multi loss 进行优化
  • ctr和ctcvr简单求和的方式

其它
一、负采样 negative sample

解决神经网络输出层计算量太大问题,修改损失函数同时采样少部分负样本

二、sampled softmax

将样本分为几份,每份样本内部进行softmax损失,这样可以一定程度减少复杂度,但是缺点是serving阶段还是要*全量的softmax*

未完待续。。。

yolo-mask的损失函数l包含三部分_损失函数总结-应用和trick相关推荐

  1. yolo-mask的损失函数l包含三部分_【AI初识境】深度学习中常用的损失函数有哪些?...

    这是专栏<AI初识境>的第11篇文章.所谓初识,就是对相关技术有基本了解,掌握了基本的使用方法. 今天来说说深度学习中常见的损失函数(loss),覆盖分类,回归任务以及生成对抗网络,有了目 ...

  2. 损失函数的意义和作用_损失函数的可视化:浅论模型的参数空间与正则

    点击蓝字  关注我们 作者丨土豆@知乎来源丨https://zhuanlan.zhihu.com/p/158857128本文已获授权,不得二次转载 前言 在深度学习中,我们总是不可避免会碰到各种各样的 ...

  3. 使用数组操作解码YOLO Core ML对象检测(三)

    目录 介绍 解码YOLO输出的正确方式 下一步 总目录 将ONNX对象检测模型转换为iOS Core ML(一) 解码Core ML YOLO对象检测器(二) 使用数组操作解码YOLO Core ML ...

  4. python使用matplotlib可视化3D柱状图(3D histogram、三维柱状图、包含三个坐标轴x、y、z)、设置zdir参数为z、改变3d图观察的角度

    python使用matplotlib可视化3D柱状图(3D histogram.三维柱状图.包含三个坐标轴x.y.z).设置zdir参数为z.改变3d图观察的角度 目录

  5. python使用matplotlib可视化3D柱状图(3D bar plot、三维柱状图、包含三个坐标轴x、y、z)、设置zdir参数为y、改变3d图观察的角度

    python使用matplotlib可视化3D柱状图(3D bar plot.三维柱状图.包含三个坐标轴x.y.z).设置zdir参数为y.改变3d图观察的角度 目录

  6. python使用matplotlib可视化3D直方图(3D histogram、三维直方图、包含三个坐标轴x、y、z)、3D直方图可视化多个维度数据的区别和联系

    python使用matplotlib可视化3D直方图(3D histogram.三维直方图.包含三个坐标轴x.y.z).3D直方图可视化多个维度数据的区别和联系 目录

  7. 实现一个行内三个div等分_一个div,包含三个小的div,平均分布的样式

    从11月份开始,自学前端开发,写静态页面中,经常用到一个大的div下包含三个小的div,平均分布div大小样式,写过多次,也多次忘记,每次都要现找资料,不想之后,在这么麻烦,索性今天自己记录一下,方便 ...

  8. 编写一个制造各种车辆的程序。包含三个类,具体要求如下: (1)基类Vehicle,包含轮子数和汽车自身重量两个属性,一个两参数的构造方法,一个显示汽车信息的方法; (2)小轿车类Car,增加载客数属性

    一.题目描述 编写一个制造各种车辆的程序.包含三个类,具体要求如下: (1)基类Vehicle,包含轮子数和汽车自身重量两个属性,一个两参数的构造方法,一个显示汽车信息的方法: (2)小轿车类Car, ...

  9. 百度地图导航的接入(包含三种选择方式驾车、公交、步行)

    百度地图导航的接入(包含三种选择方式驾车.公交.步行) 步骤 1.下载百度的sdk(下载地址:http://lbsyun.baidu.com/sdk/download) 勾选下载"检索功能& ...

最新文章

  1. C#获取邮件客户端保存的邮箱密码
  2. 重置样式表--HTML
  3. 学习笔记(3.23)
  4. 高考考入北大与普通大学考研进北大,有区别吗?
  5. pdo mysql 和 mysqli_PHP中MySQL、MySQLi和PDO的用法和区别
  6. OJ1020: 两整数排序
  7. Anaconda中出现No module named cv2
  8. sql azure 语法_Azure Data Studio中SQL代码段
  9. 计算机在生活中应用视频,计算机在腐蚀防护中的应用教学视频
  10. 读书笔记 之《Thinking in Java》(对象、集合、异常)
  11. 如何求matlab的in(2.0375),东南大学Matlab作业1.doc
  12. opencv-python版本问题
  13. 订单可视化2实战-生产交付流程(流程再造核心区)
  14. 抑郁焦虑测试软件可信度,做题自测抑郁症可靠吗
  15. BI 及其相关技术概览
  16. python控制多个屏幕_多设备控制 + 屏幕操作录制重放 实现完整多设备测试流程...
  17. 基于部标1078视频协议和苏标Adas协议构建主动平台
  18. 智能定位系统实验报告
  19. 数字三角形 (15 分)
  20. 浏览器F12(开发者调试工具) 功能介绍

热门文章

  1. 两张图看懂GC 日志
  2. msql 数据去重,仅保留一条
  3. 谈谈 Swift 中的 map 和 flatMap
  4. 自动化测试测试工具 AirTest 的使用方法与简介
  5. 【Python3-OpenCV】实现实时摄像头人脸检测
  6. 兴趣部落老是显示无法连接服务器失败,qq兴趣部落为什么停运
  7. .net 获取字符串中的第一个逗号的位置_用EXCEL合并同列字符串
  8. html word-wrap,CSS3 Word-wrap
  9. mcs 4微型计算机,MCS-II高性能自主品牌的微机测速仪
  10. 树莓派上传文件到服务器,05_树莓派图片定时上传到服务器