多任务学习(MTL) --- 知识小结+实现
What
- 通常一个模型训练时有多个目标函数loss同时训练就可以叫多任务学习,预测时输出多个结果的模型就是多任务模型
Why
- 工业界实际应用时维护单个模型比同时维护k个模型更方便,成本更低
- 提高泛化性能
How
思路1:手工加权平均
- 基本思想:对于多任务的loss,最简单的方式是直接将loss函数对于每个任务的loss进行加权
- 这种方式手工设置权重,模型性能对权重的选择非常敏感,而且loss的权重作为超参数进行调参很不方便,更好的加权方式应该是自适应动态调整的
思路2:动态加权平均
基本思想:不同任务难易程度不同,学习速度不同,对于这点可以针对不同任务设置不同学习率,但是更好的思路是动态调整让各个任务以相近的速度学习,这就是DWA(Dynamic Weight Averaging — 动态加权平均)算法的核心思想
loss下降快的任务,则权重会变小;反之权重会变大
思路3:动态任务优先级
基本思想:难学的任务给予更高的权重
KPI高的任务,学习起来比较简单,则权重会变小;反之,难学的任务权重会变大
思路4:不确定性加权方法
基本思想:难学的任务给予更小的权重使得整体的多任务模型的训练更加顺畅和有效(和思路3相反。。)
前提概念:认知不确定性和偶然不确定性
- 认知不确定性(epistemic):指的是由于缺少数据导致的认知偏差。当数据很少的时候,训练数据提供的样本分布很难代表数据全局的分布,导致模型训练学偏。这种不确定性可以通过增加数据来改善。
- 偶然不确定性(aleatoric):指的是由于数据本身,或者任务本身带来的认知偏差。偶然不确定性有个特点,其不会随着数据量增加而改善结果,数据即使增加,偏差仍然存在。
- 偶然不确定性可以分为两种情况:
- 数据依赖型或异方差。在进行数据标注的时候的误标记、错标记等,这些错误的数据也会造成模型预测偏差;
- 任务依赖型或同方差。这个指的是,同一份数据,对于不同的任务可能会导致不同的偏差
这种思路希望基于偶然不确定性(aleatoric)中的同方差不确定性来进行建模,以两个任务为例,最终推导后的loss函数:
- 其中,sigma1和sigma2是两个任务中,各自存在的不确定性
- sigma越大,任务的不确定性越大,则任务的权重越小,即噪声大且难学的任务权重会变小,简单的任务权重变大
原论文《Multi-task learning using uncertainty to weigh losses for scene geometry and semantics》
- github上用pytorch的实现:https://github.com/Mikoto10032/AutomaticWeightedLoss
import torch import torch.nn as nnclass AutomaticWeightedLoss(nn.Module):"""automatically weighted multi-task lossParams:num: int,the number of lossx: multi-task lossExamples:loss1=1loss2=2awl = AutomaticWeightedLoss(2)loss_sum = awl(loss1, loss2)"""def __init__(self, num=2):super(AutomaticWeightedLoss, self).__init__()params = torch.ones(num, requires_grad=True)self.params = torch.nn.Parameter(params)def forward(self, *x):loss_sum = 0for i, loss in enumerate(x):loss_sum += 0.5 / (self.params[i] ** 2) * loss + torch.log(1 + self.params[i] ** 2)return loss_sumif __name__ == '__main__':awl = AutomaticWeightedLoss(2)print(awl.parameters())
- 应用示例:
from torch import optim from AutomaticWeightedLoss import AutomaticWeightedLossmodel = Model()awl = AutomaticWeightedLoss(2) # we have 2 losses loss_1 = ... loss_2 = ...# learnable parameters optimizer = optim.Adam([{'params': model.parameters()},{'params': awl.parameters(), 'weight_decay': 0}])for i in range(epoch):for data, label1, label2 in data_loader:# forwardpred1, pred2 = Model(data) # calculate lossesloss1 = loss_1(pred1, label1)loss2 = loss_2(pred2, label2)# weigh lossesloss_sum = awl(loss1, loss2)# backwardoptimizer.zero_grad()loss_sum.backward()optimizer.step()
多任务学习(MTL) --- 知识小结+实现相关推荐
- 【推荐系统多任务学习 MTL】PLE论文精读笔记(含代码实现)
论文地址: Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized ...
- 多任务学习(MTL)在转化率预估上的应用
今天主要和大家聊聊多任务学习在转化率预估上的应用. 多任务学习(Multi-task learning,MTL)是机器学习中的一个重要领域,其目标是利用多个学习任务中所包含的有用信息来帮助每个任务学习 ...
- 【推荐系统多任务学习MTL】ESMM 论文精读笔记(含代码实现)
论文地址:Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate ...
- 【推荐系统多任务学习MTL】MMoE论文精读笔记(含代码实现)
论文地址: Google KDD 2018 MMOE (内含论文官方讲解视频) PDF Modeling Task Relationships in Multi-task Learning with ...
- 深度学习面试知识小结
参考了一部分别人的,自己也整理了一点,持续更新中- 计算机视觉算法岗面试题:在这里. CNN结构特点 1.局部连接使网络可以提取数据的局部特征 2.首先权值共享就是滤波器共享,即是用相同的滤波器去扫一 ...
- DNN中多任务学习概述
为什么大公司搜索推荐都用CTR/CVR Cotrain的框架? 一元@炼丹笔记 我们平时做项目/竞赛的时候,一般都是单指标优化的问题,很多时候我们模型的评估指标也是单个指标,例如AUC, GAUC, ...
- 我们如何在Pinterest Ads中使用AutoML,多任务学习和多塔模型
Ernest Wang | Software Engineer, Ads Ranking 欧内斯特·王| 软件工程师,广告排名 People come to Pinterest in an explo ...
- 【阅读笔记】多任务学习之PLE(含代码实现)
本文作为自己阅读论文后的总结和思考,不涉及论文翻译和模型解读,适合大家阅读完论文后交流想法. PLE 一. 全文总结 二. 研究方法 三. 结论 四. 创新点 五. 思考 六. 参考文献 七. Pyt ...
- 脉络梳理:推荐系统中的多任务学习
© 作者|杨晨 机构|中国人民大学 研究方向|推荐系统 本文聚焦推荐系统中的一个研究方向 -- Multi-Task Recommendation,整理近五年内的研究工作,进行分类总结,并针对22年最 ...
- 【机器学习基础】一文看透多任务学习
作者:十方 大家在做模型的时候,往往关注一个特定指标的优化,如做点击率模型,就优化AUC,做二分类模型,就优化f-score.然而,这样忽视了模型通过学习其他任务所能带来的信息增益和效果上的提升.通过 ...
最新文章
- 每天坚持一个CSS——社会人
- 2020年最新前端学习路线
- 辽宁省2021年高考成绩位次查询,辽宁2021八省联考分数、位次表(非官方),附志愿填报样表...
- 像Excel一样使用python进行数据分析(1)
- file 选择的图片作为背景图片_酷炫!用Python把桌面变成实时更新的地球图片
- 十二、Python第十二课——函数
- Java的反射(二)
- 入门嵌入式HTML/CSS/脚本引擎 sciter(问题篇)
- 模式识别算法中英文对照
- ORBSLAM3整体框架
- 苹果 ios mdm服务器搭建
- 白苹果了怎么办_iOS更新白苹果处理及第三方售后吐槽
- 洛谷 P2342 叠积木 题解
- 安装虚拟机(VMware)保姆级教程(附安装包)
- 【java.lang.ref】当WeakReference的referent重写了finalize方法时会发生什么
- 【分享】SBO初始化的过程及内容
- 计算机网络常见的协议之ARP协议
- python数据分析-面试题
- 读书笔记---Head First 设计模式--- 装饰者模式
- 活灵活现用Git-基础篇