1、冻结训练

为什么要冻结,a、因为加载的预训练模型参数一般是效果较好的,如果我们不冻结还是从头开始训练浪费资源(本来可以较快收敛的,结果从头开始训练浪费训练资源)甚至降低原模型精度;b、冻结训练需要的显存较小,显卡非常差的情况下,可以冻结backbone等部分,剩余的进行微调

1.1初始化模型(可以略过)

初始化模型参数

torch.init.normal_

给tensor初始化,一般是给网络中参数weight初始化,初始化参数值符合正态分布。

torch.init.normal_(tensor,mean=,std=) 

mean:均值,std:正态分布的标准差

 torch.init.constant_

初始化参数使其为常值,即每个参数值都相同。一般是给网络中bias进行初始化。

torch.nn.init.constant_(tensor,val)

给不同层使用不同的初始化策略

pytorch的初始化方式总结 - 知乎

 for m in self.children():if isinstance(m, nn.Linear):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, -100)# 也可以判断是否为conv2d,使用相应的初始化方式 elif isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight.item(), 1)nn.init.constant_(m.bias.item(), 0)    

1.2 加载预训练模型参数

参考https://blog.csdn.net/weixin_44791964?type=blog 博主的博客及视频及代码仓库。

模型的 预训练权重 比较重要的部分是 主干特征提取网络的权值部分,用于进行特征提取。预训练权重对于99%的情况都必须要用,不用的话主干部分的权值太过随机,特征提取效果不明显,网络训练的结果也不会好。训练自己的数据集时提示维度不匹配正常,预测的东西都不一样了自然维度不匹配

torch.load()加载模型及其map_location参数_eecspan的博客-CSDN博客_torch加载模型

torch.load()

函数格式为:torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args),一般我们使用的时候,基本只使用前两个参数。

模型的保存

模型保存有两种形式,一种是保存模型的state_dict(),只是保存模型的参数。那么加载时需要先创建一个模型的实例model,之后通过torch.load()将保存的模型参数加载进来,得到dict,再通过model.load_state_dict(dict)将模型的参数更新。
        另一种是将整个模型保存下来,之后加载的时候只需要通过torch.load()将模型加载,即可返回一个加载好的模型。
具体可参考:PyTorch模型的保存与加载。

模型加载中的map_location参数

具体来说,map_location参数是用于重定向,比如此前模型的参数是在cpu中的,我们希望将其加载到cuda:0中。或者我们有多张卡,那么我们就可以将卡1中训练好的模型加载到卡2中,这在数据并行的分布式深度学习中可能会用到。

如果预训练模型与定义模型层数不匹配,使用如下代码,剥出你需要的层的参数

pretrained_dict = torch.load(pretrained_model)
model_dict = model.state_dict()
pretrained_dict = {
k: v for k, v in pretrained_dict.items()
if (k in model_dict and 'Prediction' not in k)
}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

1.3 冻结参数

深度学习基本功2:网络训练小技巧之使用预训练权重、冻结训练和断点恢复 - 知乎

# 冻结阶段训练参数,learning_rate和batch_size可以设置大一点
Init_Epoch          = 0
Freeze_Epoch        = 50
Freeze_batch_size   = 8
Freeze_lr           = 1e-3
# 解冻阶段训练参数,learning_rate和batch_size设置小一点
UnFreeze_Epoch      = 100
Unfreeze_batch_size = 4
Unfreeze_lr         = 1e-4
# 可以加一个变量控制是否进行冻结训练
Freeze_Train        = True
# 冻结一部分进行训练
batch_size  = Freeze_batch_size
lr          = Freeze_lr
start_epoch = Init_Epoch
end_epoch   = Freeze_Epoch
if Freeze_Train:for param in model.backbone.parameters():param.requires_grad = False
# 解冻后训练
batch_size  = Unfreeze_batch_size
lr          = Unfreeze_lr
start_epoch = Freeze_Epoch
end_epoch   = UnFreeze_Epoch
if Freeze_Train:for param in model.backbone.parameters():param.requires_grad = True

如果不进行冻结训练,一定要注意参数设置,注意上述代码中冻结阶段和解冻阶段的learning_rate和batch_size是不一样的,另外起始epoch和结束epoch也要重新调整一下。如果是从0开始训练模型(不使用预训练权重),那么一定不能进行冻结训练

在冻结部分,重新设置了学习率,batchsize。

冻结阶段:为了快速收敛,可以加大学习率及batchsize(此时由于部分参数冻结不需要计算梯度所以可以加大这些)

解冻阶段:所有层都需要计算梯度等加大开销,所以学习率及batchsize的设置不能再像冻结阶段设置的那么大

因为冻结阶段涉及到学习率的设计,所以也写一下学习率的介绍

2、优化器

PyTorch 源码解读之 torch.optim:优化算法接口详解 - 知乎

这个博主写的非常详细,看完就悟了

1.0 基本用法

  • 优化器主要是在模型训练阶段对模型可学习参数进行更新, 常用优化器有 SGD,RMSprop,Adam等
  • 优化器初始化时传入传入模型的可学习参数,以及其他超参数如 lrmomentum
  • 在训练过程中先调用 optimizer.zero_grad() 清空梯度,再调用 loss.backward() 反向传播,最后调用 optimizer.step()更新模型参数
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)

1.1 PyTorch 中的优化器

所有优化器都是继承父类 Optimizer,如下列表是 PyTorch 提供的优化器:

  • SGD
  • ASGD
  • Adadelta
  • Adagrad
  • Adam
  • AdamW
  • Adamax
  • SparseAdam
  • RMSprop
  • Rprop
  • LBFGS

1.2 父类Optimizer 基本原理

Optimizer 是所有优化器的父类,它主要有如下公共方法:

  • add_param_group(param_group): 添加模型可学习参数组
  • step(closure): 进行一次参数更新
  • zero_grad(): 清空上次迭代记录的梯度信息
  • state_dict(): 返回 dict 结构的参数状态
  • load_state_dict(state_dict): 加载 dict 结构的参数状态

3、学习率

参考文章同优化器的那篇

有了优化器,还需要根据 epoch 来调整学习率,lr_schedluer提供了在训练模型时学习率的调整策略。

目前 PyTorch 提供了如下学习率调整策略:

  • StepLR: 等间隔调整策略
  • MultiStepLR: 多阶段调整策略
  • ExponentialLR: 指数衰减调整策略
  • ReduceLROnPlateau: 自适应调整策略
  • CyclicLR: 循环调整策略
  • OneCycleLR: 单循环调整策略
  • CosineAnnealingLR: 余弦退火调整策略
  • CosineAnnealingWarmRestarts: 带预热启动的余弦退火调整策略
  • LambdaLR: 自定义函数调整策略
  • MultiplicativeLR: 乘法调整策略

基类: _LRScheduler

学习率调整类主要的逻辑功能就是每个 epoch 计算参数组的学习率,更新 optimizer对应参数组中的lr值,从而应用在optimizer里可学习参数的梯度更新。所有的学习率调整策略类的父类是torch.optim.lr_scheduler._LRScheduler,基类 _LRScheduler 定义了如下方法:

  • step(epoch=None): 子类公用
  • get_lr(): 子类需要实现
  • get_last_lr(): 子类公用
  • print_lr(is_verbose, group, lr, epoch=None): 显示 lr 调整信息
  • state_dict(): 子类可能会重写
  • load_state_dict(state_dict): 子类可能会重写

剩余部分见原博客

冻结训练优化器学习率相关推荐

  1. 中国博士生提出最先进AI训练优化器,收敛快精度高,网友亲测:Adam可以退休了...

    栗子 鱼羊 晓查 发自 凹非寺  量子位 报道 | 公众号 QbitAI 找到一种快速稳定的优化算法,是所有AI研究人员的目标. 但是鱼和熊掌不可兼得.Adam.RMSProp这些算法虽然收敛速度很快 ...

  2. pytorch优化器学习率调整策略以及正确用法

    优化器 optimzier优化器的作用:优化器就是需要根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用. 从优化器的作用出发,要使得优化器能够起作用,需要主要两个东西: ...

  3. 图像语义分割实践(五)优化器与学习率

    概述 在数据制作环节中,提到minibatch思想用于数据批次量获取,是一种优化器思想,而该文则是对各种优化器进行介绍. 优化器:最小化损失函数算法,把深度学习当炼丹的话,优化器就是炉子,决定火候大小 ...

  4. Pytorch:优化器、损失函数与深度神经网络框架

    Pytorch: 优化器.损失函数与深度神经网络框架 Copyright: Jingmin Wei, Pattern Recognition and Intelligent System, Schoo ...

  5. pytorch学习笔记十二:优化器

    前言 机器学习中的五个步骤:数据 --> 模型 --> 损失函数 --> 优化器 --> 迭代训练,通过前向传播,得到模型的输出和真实标签之间的差异,也就是损失函数,有了损失函 ...

  6. 优化器 optimizer

    优化器 optimizer optimizer 优化器,用来根据参数的梯度进行沿梯度下降方向进行调整模型参数,使得模型loss不断降低,达到全局最低,通过不断微调模型参数,使得模型从训练数据中学习进行 ...

  7. Pytorch模型训练实用教程学习笔记:四、优化器与学习率调整

    前言 最近在重温Pytorch基础,然而Pytorch官方文档的各种API是根据字母排列的,并不适合学习阅读. 于是在gayhub上找到了这样一份教程<Pytorch模型训练实用教程>,写 ...

  8. pytorch优化器与学习率设置详解

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者 | 小新 来源 | https://a.3durl.cn/Yr ...

  9. 让大规模深度学习训练线性加速、性能无损,基于BMUF的Adam优化器并行化实践...

    导语:深度学习领域经典的 Adam 算法在大规模并行训练的情况下会导致模型性能损失.为了解决这一问题,微软亚洲研究院采用 BMUF 框架对 Adam 算法进行了并行化,并在微软大规模 OCR 和语音产 ...

最新文章

  1. 10个经典又容易被人疏忽的JVM面试题
  2. J2EE(一)——开发简单WEB服务器
  3. rfid4-写成platform驱动
  4. JDK7新特性简单翻译介绍
  5. [转]图片格式WEBP全面解析
  6. YBTOJ:圈套问题(分治法、鸽笼原理)
  7. 移动端怎么让底部固定_移动端排名应该怎么做?两种匹配移动端实战排名干货分享!...
  8. Visual Studio 类视图和“对象浏览器”图标含义
  9. 如何实现自动化前端开发?
  10. 返回顶部的几种方法总结
  11. RS422--ARINC429通讯转换模块 RS422支持全双工通讯接口,通讯速率可设置,ARINC429支持发送和接收
  12. box-sizing属性的理解
  13. 一位计算机专业硕士毕业生的求职经历和感想
  14. 周集中团队Nature子刊中网络图布局的R语言可视化复现
  15. Learn Go with tests 学习笔记(9)——Mocking
  16. contour 函数详解
  17. 【CSS特效扫盲】精选40种纯CSS特效应用实例,肝了10个晚上整理纯CSS特效(上)(附源码下载)
  18. Java通过二维码下载Apk====安卓手机
  19. 像素鸟html代码,flappy-bird方块版(用小方块替代像素鸟)
  20. 机器学习与知识发现电子书_2019年,5本关于机器学习的免费电子书你应该知道(有资源)...

热门文章

  1. linux蓝牙不识别微软鼠标,windows10系统下无法找到蓝牙鼠标的解决方案
  2. 开源架构Fabric、FISCO BCOS(以下简称“BCOS”)、CITA 技术对比
  3. 2021年广东省安全员A证第三批(主要负责人)报名考试及最全题库免费模拟考试
  4. “/wechat”应用程序中的服务器错误。
  5. 6种自媒体赚钱方法!
  6. Android ConstraintLayout约束布局的使用
  7. android 仿新浪微博发现效果
  8. Blender 安装
  9. MTK6572关于相机默认像素问题
  10. anyHouse-iOS 高仿ClubHouse