多任务学习pytorch使用不同学习率同时训练多个网络的方法
多任务学习时需要多个网络一起训练,并设置不同的学习率,pytorch中有以下几种方法:
首先网络设置如下:
import torch# Encoder参数共享 Decoder分别训练
Encoder = SharedEncoder().cuda()
Dose_decoder = Dose_prediction().cuda()
Gra_decoder = Gradient_regression().cuda()criterion = torch.nn.MSELoss()
lr1 = 0.0002
lr2 = 0.0001
方法一:
# 使用三个优化器
opt1 = torch.optim.Adam(Encoder.parameters(), lr=lr1)
opt2 = torch.optim.Adam(Dose_decoder.parameters(), lr=lr1)
opt3 = torch.optim.Adam(Gra_decoder.parameters(), lr=lr2)Encoder.train()
Dose_decoder.train()
Gra_decoder.train()for input, label, gra_label, _ in tqdm.tqdm(train_dataloaders):inputs = input.cuda()labels = label.cuda()gra_labels = gra_label.cuda()opt1.zero_grad()opt2.zero_grad()opt3.zero_grad()conv1, conv2, conv3, conv4, center = Encoder(inputs)outputs = Dose_decoder(conv1, conv2, conv3, conv4, center)output_gra = Gra_decoder(conv1, conv2, conv3, conv4, center)loss_main = criterion(outputs, labels)loss_gra = criterion(output_gra, gra_labels)# 损失要加起来反向传播loss = 10 * loss_main + 5 * loss_graloss.backward()opt1.step()opt2.step()opt3.step()
方法二:
# 如果对某个网络不设置学习率,则使用最外层的lr
optimizer = torch.optim.Adam([{'params': Encoder.parameters()},{'params': Dose_decoder.parameters()},{'params': Gra_decoder.parameters(),'lr': lr2}], lr=lr1)
方法三:
使用python内置库itertools的chain方法
from itertools import chainoptimizer = torch.optim.Adam(params=chain(Encoder.parameters(), Dose_decoder.parameters(), Gra_decoder.parameters()), lr=0.0001)
多任务学习pytorch使用不同学习率同时训练多个网络的方法相关推荐
- pytorch:如何从头开始训练一个CNN网络?
文章目录 前言 一.CNN? 二.用单批量测试模型 1.引入库 2.读入数据集 3. 建造Module实例 4. 训练 总结 前言 在刚开始学习Deep Learning时,一件几乎不可能的事情就是知 ...
- 多任务学习 Pytorch实现
多任务学习MTL的简单实现,主要是为了理解MTL 代码写得挺烂的,有时间回来整理一下 import torch import torch.nn as nn import torchvision imp ...
- 多任务学习原理与优化
文章目录 一.什么是多任务学习 二.为什么我们需要多任务学习 三.多任务学习模型演进 Hard shared bottom 硬共享 Soft shared bottom 软共享 软共享: MOE &a ...
- 多任务学习(Multi-Task Learning, MTL)
目录 [显示] 1 背景 2 什么是多任务学习? 3 多任务学习如何发挥作用? 3.1 提高泛化能力的潜在原因 3.2 多任务学习机制 3.3 后向传播多任务学习如何发现任务是相关的 4 多任务学习可 ...
- 从零学习pytorch 第1课 搭建一个超简单的网络
课程目录(在更新,喜欢加个关注点个赞呗): 从零学习pytorch 第1课 搭建一个超简单的网络 从零学习pytorch 第1.5课 训练集.验证集和测试集的作用 从零学习pytorch 第2课 Da ...
- 从零学习pytorch 第2课 Dataset类
课程目录(在更新,喜欢加个关注点个赞呗): 从零学习pytorch 第1课 搭建一个超简单的网络 从零学习pytorch 第1.5课 训练集.验证集和测试集的作用 从零学习pytorch 第2课 Da ...
- 综述翻译:多任务学习-An Overview of Multi-Task Learning in Deep Neural Networks
An Overview of Multi-Task Learning in Deep Neural Networks 文章目录 An Overview of Multi-Task Learning i ...
- 【阅读笔记】多任务学习之PLE(含代码实现)
本文作为自己阅读论文后的总结和思考,不涉及论文翻译和模型解读,适合大家阅读完论文后交流想法. PLE 一. 全文总结 二. 研究方法 三. 结论 四. 创新点 五. 思考 六. 参考文献 七. Pyt ...
- 8_用opencv调用深度学习框架tenorflow、Pytorch、Torch、caffe训练好的模型(20190212)
用opencv调用深度学习框架tenorflow.Pytorch.Torch.caffe训练好模型(20190212) 文章目录: https://blog.csdn.net/hust_bochu_x ...
最新文章
- linux 卸载模块命令,Linux中module模块的编译、加载、卸载
- 2017,SAP向云看齐
- Java设计链表(不带头结点的单链表)
- WordPress后台添加侧边栏菜单
- 分页查询插件PageHelper 5.x版本
- LeetCode 632. 最小区间(排序+滑动窗口)
- Nutanix,在转型的道路上越走越远 | 人物志
- Android自定义View构造函数详解
- (cljs/run-at (JSVM. :browser) 简单类型可不简单啊~)
- @程序员:别人身边的小姐姐是这样来的,你能学学吗
- 常见花材的固定的方法有哪些_波峰焊喷嘴的常见故障及处理方法有哪些
- 链表的自顶向下归并排序
- java淘淘商城_淘淘商城-张志君分布式电商视频教程 下载
- MessageQueue的使用方法(一)
- 03 在CentOS7中安装oracle11g
- 阿里测开岗定级P7全流程加面试真题
- Spring关于AOP配置举例(注解方式)
- MATLAB中wcp什么意思,WCP是什么意思
- Python灰帽子黑客与逆向工程师的Python编程之道
- R语言28-Prosper 贷款数据分析4
热门文章
- html调用app store,调用App Store Connect Api
- IT领域标准化浅谈(二):中国IT领域标准制定工作程序
- SAP查看SPRO配置对应的事务码
- 解决导出EXCEL自动将长的数字的字符串变成E+的科学计数法
- 【Y忍冬草】Qt使用中一些小知识
- (附源码)计算机毕业设计SSM基于web的健康饮食信息管理系统
- 打开方式对话框 及 RUNDLL32(RUNDLL)的使用
- 发光字招牌色彩的意义
- 只有突破你的思维,你才能真正做大,只有少数人明白的顶尖思维!
- 一文看懂“声纹识别VPR” | AI产品经理需要了解的AI技术概念_团员分享_@cony