多任务学习时需要多个网络一起训练,并设置不同的学习率,pytorch中有以下几种方法:

首先网络设置如下:

import torch# Encoder参数共享  Decoder分别训练
Encoder = SharedEncoder().cuda()
Dose_decoder = Dose_prediction().cuda()
Gra_decoder = Gradient_regression().cuda()criterion = torch.nn.MSELoss()
lr1 = 0.0002
lr2 = 0.0001

方法一:

# 使用三个优化器
opt1 = torch.optim.Adam(Encoder.parameters(), lr=lr1)
opt2 = torch.optim.Adam(Dose_decoder.parameters(), lr=lr1)
opt3 = torch.optim.Adam(Gra_decoder.parameters(), lr=lr2)Encoder.train()
Dose_decoder.train()
Gra_decoder.train()for input, label, gra_label, _ in tqdm.tqdm(train_dataloaders):inputs = input.cuda()labels = label.cuda()gra_labels = gra_label.cuda()opt1.zero_grad()opt2.zero_grad()opt3.zero_grad()conv1, conv2, conv3, conv4, center = Encoder(inputs)outputs = Dose_decoder(conv1, conv2, conv3, conv4, center)output_gra = Gra_decoder(conv1, conv2, conv3, conv4, center)loss_main = criterion(outputs, labels)loss_gra = criterion(output_gra, gra_labels)# 损失要加起来反向传播loss = 10 * loss_main + 5 * loss_graloss.backward()opt1.step()opt2.step()opt3.step()

方法二:

# 如果对某个网络不设置学习率,则使用最外层的lr
optimizer = torch.optim.Adam([{'params': Encoder.parameters()},{'params': Dose_decoder.parameters()},{'params': Gra_decoder.parameters(),'lr': lr2}], lr=lr1)

方法三:
使用python内置库itertools的chain方法

from itertools import chainoptimizer = torch.optim.Adam(params=chain(Encoder.parameters(), Dose_decoder.parameters(), Gra_decoder.parameters()), lr=0.0001)

多任务学习pytorch使用不同学习率同时训练多个网络的方法相关推荐

  1. pytorch:如何从头开始训练一个CNN网络?

    文章目录 前言 一.CNN? 二.用单批量测试模型 1.引入库 2.读入数据集 3. 建造Module实例 4. 训练 总结 前言 在刚开始学习Deep Learning时,一件几乎不可能的事情就是知 ...

  2. 多任务学习 Pytorch实现

    多任务学习MTL的简单实现,主要是为了理解MTL 代码写得挺烂的,有时间回来整理一下 import torch import torch.nn as nn import torchvision imp ...

  3. 多任务学习原理与优化

    文章目录 一.什么是多任务学习 二.为什么我们需要多任务学习 三.多任务学习模型演进 Hard shared bottom 硬共享 Soft shared bottom 软共享 软共享: MOE &a ...

  4. 多任务学习(Multi-Task Learning, MTL)

    目录 [显示] 1 背景 2 什么是多任务学习? 3 多任务学习如何发挥作用? 3.1 提高泛化能力的潜在原因 3.2 多任务学习机制 3.3 后向传播多任务学习如何发现任务是相关的 4 多任务学习可 ...

  5. 从零学习pytorch 第1课 搭建一个超简单的网络

    课程目录(在更新,喜欢加个关注点个赞呗): 从零学习pytorch 第1课 搭建一个超简单的网络 从零学习pytorch 第1.5课 训练集.验证集和测试集的作用 从零学习pytorch 第2课 Da ...

  6. 从零学习pytorch 第2课 Dataset类

    课程目录(在更新,喜欢加个关注点个赞呗): 从零学习pytorch 第1课 搭建一个超简单的网络 从零学习pytorch 第1.5课 训练集.验证集和测试集的作用 从零学习pytorch 第2课 Da ...

  7. 综述翻译:多任务学习-An Overview of Multi-Task Learning in Deep Neural Networks

    An Overview of Multi-Task Learning in Deep Neural Networks 文章目录 An Overview of Multi-Task Learning i ...

  8. 【阅读笔记】多任务学习之PLE(含代码实现)

    本文作为自己阅读论文后的总结和思考,不涉及论文翻译和模型解读,适合大家阅读完论文后交流想法. PLE 一. 全文总结 二. 研究方法 三. 结论 四. 创新点 五. 思考 六. 参考文献 七. Pyt ...

  9. 8_用opencv调用深度学习框架tenorflow、Pytorch、Torch、caffe训练好的模型(20190212)

    用opencv调用深度学习框架tenorflow.Pytorch.Torch.caffe训练好模型(20190212) 文章目录: https://blog.csdn.net/hust_bochu_x ...

最新文章

  1. linux 卸载模块命令,Linux中module模块的编译、加载、卸载
  2. 2017,SAP向云看齐
  3. Java设计链表(不带头结点的单链表)
  4. WordPress后台添加侧边栏菜单
  5. 分页查询插件PageHelper 5.x版本
  6. LeetCode 632. 最小区间(排序+滑动窗口)
  7. Nutanix,在转型的道路上越走越远 | 人物志
  8. Android自定义View构造函数详解
  9. (cljs/run-at (JSVM. :browser) 简单类型可不简单啊~)
  10. @程序员:别人身边的小姐姐是这样来的,你能学学吗
  11. 常见花材的固定的方法有哪些_波峰焊喷嘴的常见故障及处理方法有哪些
  12. 链表的自顶向下归并排序
  13. java淘淘商城_淘淘商城-张志君分布式电商视频教程 下载
  14. MessageQueue的使用方法(一)
  15. 03 在CentOS7中安装oracle11g
  16. 阿里测开岗定级P7全流程加面试真题
  17. Spring关于AOP配置举例(注解方式)
  18. MATLAB中wcp什么意思,WCP是什么意思
  19. Python灰帽子黑客与逆向工程师的Python编程之道
  20. R语言28-Prosper 贷款数据分析4

热门文章

  1. html调用app store,调用App Store Connect Api
  2. IT领域标准化浅谈(二):中国IT领域标准制定工作程序
  3. SAP查看SPRO配置对应的事务码
  4. 解决导出EXCEL自动将长的数字的字符串变成E+的科学计数法
  5. 【Y忍冬草】Qt使用中一些小知识
  6. (附源码)计算机毕业设计SSM基于web的健康饮食信息管理系统
  7. 打开方式对话框 及 RUNDLL32(RUNDLL)的使用
  8. 发光字招牌色彩的意义
  9. 只有突破你的思维,你才能真正做大,只有少数人明白的顶尖思维!
  10. 一文看懂“声纹识别VPR” | AI产品经理需要了解的AI技术概念_团员分享_@cony