论文链接:https://arxiv.org/pdf/1711.02257.pdf

之前讲过了多任务学习,如简单的shared bottom,都存在一个问题:多个任务的loss如何融合?简单的方式,就是将多个任务的loss直接相加:

但实际情况是,不同任务loss梯度的量级不同,造成有的task在梯度反向传播中占主导地位,模型过分学习该任务而忽视其它任务。此外,不同任务收敛速度不一致的,可能导致有些任务还处于欠拟合,可有些任务已经过拟合了。当然,我们可以人工的设置超参数,如:

由于各任务在训练过程中自己的梯度量级和收敛速度也是动态变化的,所以很显然这样定值的w做并没有很好的解决问题。作者提出了一种可以动态调整loss的w的算法——GradNorm


从上图可知,GradNorm 是以平衡的梯度作为目标,优化Grad Loss,从而动态调整各个任务的w。

那下面我们就来看看Grad Loss是怎么样的:


要注意的是,上式中,减号后面的项,是基于当轮各任务的梯度所计算出来的常量。其中:


G调节着梯度的量级。r调节着任务收敛速度:收敛速度越快,rir_iri​就越小,从而Gw(i)(t)G^{(i)}_w(t)Gw(i)​(t)应该被优化的变小。

算法步骤如下:

实现如下(引自 GitHub):

            # switch for each weighting algorithm:# --> grad normif args.mode == 'grad_norm':# get layer of shared weightsW = model.get_last_shared_layer()# get the gradient norms for each of the tasks# G^{(i)}_w(t) norms = []for i in range(len(task_loss)):# get the gradient of this task loss with respect to the shared parametersgygw = torch.autograd.grad(task_loss[i], W.parameters(), retain_graph=True)# compute the normnorms.append(torch.norm(torch.mul(model.weights[i], gygw[0])))norms = torch.stack(norms)#print('G_w(t): {}'.format(norms))# compute the inverse training rate r_i(t) # \curl{L}_i if torch.cuda.is_available():loss_ratio = task_loss.data.cpu().numpy() / initial_task_losselse:loss_ratio = task_loss.data.numpy() / initial_task_loss# r_i(t)inverse_train_rate = loss_ratio / np.mean(loss_ratio)#print('r_i(t): {}'.format(inverse_train_rate))# compute the mean norm \tilde{G}_w(t) if torch.cuda.is_available():mean_norm = np.mean(norms.data.cpu().numpy())else:mean_norm = np.mean(norms.data.numpy())#print('tilde G_w(t): {}'.format(mean_norm))# compute the GradNorm loss # this term has to remain constantconstant_term = torch.tensor(mean_norm * (inverse_train_rate ** args.alpha), requires_grad=False)if torch.cuda.is_available():constant_term = constant_term.cuda()#print('Constant term: {}'.format(constant_term))# this is the GradNorm loss itselfgrad_norm_loss = torch.tensor(torch.sum(torch.abs(norms - constant_term)))#print('GradNorm loss {}'.format(grad_norm_loss))# compute the gradient for the weightsmodel.weights.grad = torch.autograd.grad(grad_norm_loss, model.weights)[0]

多任务学习——【ICML 2018】GradNorm相关推荐

  1. 【多任务优化】DWA、DTP、Gradnorm(CVPR 2019、ECCV 2018、 ICML 2018)

    多任务学习模型的优化 有多个task就有多个loss,常见的MTL模型loss可以直接简单的对多个任务的loss相加: L = ∑ i L i L=\sum_{i} L_{i} L=i∑​Li​ 显然 ...

  2. ICML 2018 | 从强化学习到生成模型:40篇值得一读的论文

    https://blog.csdn.net/y80gDg1/article/details/81463731 感谢阅读腾讯AI Lab微信号第34篇文章.当地时间 7 月 10-15 日,第 35 届 ...

  3. pareto最优解程序_NIPS 2018 | 作为多目标优化的多任务学习:寻找帕累托最优解

    原标题:NIPS 2018 | 作为多目标优化的多任务学习:寻找帕累托最优解 选自arXiv 作者:Ozan Sener.Vladlen Koltun 参与:李诗萌.王淑婷 多任务学习本质上是一个多目 ...

  4. ICML2018见闻 | 迁移学习、多任务学习领域的进展

    作者 | Isaac Godfried 译者 | 王天宇 编辑 | Jane 出品 | AI科技大本营 [导读]如今 ICML(International Conference on Machine ...

  5. 密集预测任务的多任务学习(Multi-Task Learning)研究综述 - 网络结构篇(上)

    [ TPAMI 2021 ] Multi-Task Learning for Dense Prediction Tasks: A Survey [ The authors ] • Simon Vand ...

  6. 7篇顶会论文带你梳理多任务学习建模方法

    如果觉得我的算法分享对你有帮助,欢迎关注我的微信公众号"圆圆的算法笔记",更多算法笔记和世间万物的学习记录- 公众号后台回复"多任务",即可获取相关论文资料集合 ...

  7. 【datawhale202206】pyTorch推荐系统:多任务学习 ESMMMMOE

    结论速递 多任务学习是排序模型的一种发展方式,诞生于多任务的背景.实践表明,多任务联合建模可以有效提升模型效果,因其可以:任务互助:实现隐式数据增强:学到通用表达,提高泛化能力(特别是对于一些数据不足 ...

  8. 多任务学习优化总结 Multi-task learning(附代码)

    目录 一.多重梯度下降multiple gradient descent algorithm (MGDA) 二.Gradient Normalization (GradNorm) 三.Uncertai ...

  9. 2021年浅谈多任务学习

    作者 | 多多笔记 来源 |AI部落联盟 头图 | 下载于视觉中国 写此文的动机: 最近接触到的几个大厂推荐系统排序模型都无一例外的在使用多任务学习,比如腾讯PCG在推荐系统顶会RecSys 2020 ...

最新文章

  1. 开源多年后,Facebook这个调试工具,再登Github热门榜
  2. day1 作业编写登录窗口
  3. Centos 的inotify和rsync文件实时同步
  4. 【JEECG TBSchedule】详解应对平台高并发的分布式调度框架TBSchedule
  5. python爬虫解析数据_Python爬虫入门知识:解析数据篇
  6. eclipse中修改项目文件夹目录显示结构
  7. 可以用img做参数的成功例子
  8. Spring中IOC和AOP的详细解释(转)
  9. IT故事:软件测试点亮了我人生的烛光
  10. 解决验证码不显示问题
  11. USBPD充电协议,快充协议IC,PD3.0芯片
  12. 如何降低开关电源空载损耗
  13. cd/etc 文件目录浅解
  14. 英语语法形容词的顺序
  15. Python:实现counting sort计数排序算法(附完整源码)
  16. Android 根据逗号分隔String
  17. YDOOK:Pytorch教程:转置矩阵 转置张量 T
  18. Armbian (jammy) 上安装 Docker
  19. 求解非线性方程组的牛顿法c语言,牛顿下山法求解非线性方程(组)(C实现)...
  20. 哪款文件比较软件适合程序员

热门文章

  1. 山东理工大学ACM平台题答案 2561 九九乘法表
  2. MySQL子查询的优缺点_浅谈mysql的子查询
  3. 【LeetCode】重复元素相关题目
  4. Android Studio 快捷键整理
  5. 【统计学笔记】各种假设检验的假设的建立和各统计量公式总结
  6. 从零开始成为优秀的交互设计师
  7. Mandriva 2009 Spring PWP中3D桌面的使用
  8. PHP 图片转base64编码 和 base64编码字符串转换成图片保存
  9. 图解系统(六)——调度算法
  10. Android计算标准BMI值