概述

torch.optim.lr_scheduler 模块提供了一些根据 epoch 迭代次数来调整学习率 lr 的方法。为了能够让损失函数最终达到收敛的效果,通常 lr 随着迭代次数的增加而减小时能够得到较好的效果。torch.optim.lr_scheduler.ReduceLROnPlateau 则提供了基于训练中某些测量值使学习率动态下降的方法。

学习率的调整应该放在optimizer更新之后,下面是Demo示例:

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)for epoch in range(20):for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()# 1.进行参数的更新optimizer.step()# 2.对学习率进行更新    scheduler.step()
# 注意:现在,在很多祖传代码中,scheduler.step()的位置可能是在参数更新optimizer.step()之前
# 检查您的pytorch版本如果是V1.1.0+,那么需要将scheduler.step()在optimizer.step()之后调用

PyTorch 1.1.0 之前, 学习率更新操作 scheduler.step() 会在 optimizer.step() 操作之前调用; v1.1.0 修改了这种调用机制。如果在 optimizer.step() 之前调用 scheduler.step() , 会自动跳过第一次 lr 的更新。如果更新了 v1.1.0 后您的结果不一样了,请确认是不是在这里的调用顺序有误。

优化器optimizer

为了进一步说明 lr_scheduler 的机制,我们首先需要了解一下 optimizer 的结构,以 Adam() 为例(所有 optimizers 都继承自 torch.optim.Optimizer 类)。
对于 class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

  • params (iterable):需要优化的网络参数,传进来的网络参数必须是Iterable。
  • 优化一个网络,网络的每一层看做一个parameter group,一整个网络就是parameter groups(一般给赋值为net.parameters()——generator的字典);
  • 优化多个网络,有两种方法:
  • 多个网络的参数合并到一起,形如[*net_1.parameters(), *net_2.parameters()]或itertools.chain(net_1.parameters(), net_2.parameters());
  • 当成多个网络优化,让多个网络的学习率各不相同,形如[{‘params’: net_1.parameters()}, {‘params’: net_2.parameters()}]
  • lr (float, optional):学习率;
  • betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999));
  • eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8);
  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0);
  • amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond (default: False)。

optimizer的属性

  • optimizer.defaults: 字典,存放这个优化器的一些初始参数,有:'lr', 'betas', 'eps', 'weight_decay', 'amsgrad'
  • optimizer.param_groups:列表,每个元素都是一个字典,每个元素包含的关键字有:'params', 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad'params类是各个网络的参数放在了一起。
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
import itertoolsinitial_lr = 0.1class model(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)self.conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)def forward(self, x):passnet_1 = model()
net_2 = model()optimizer_1 = torch.optim.Adam(net_1.parameters(), lr = initial_lr)
print("******************optimizer_1*********************")
print("optimizer_1.defaults:", optimizer_1.defaults)
print("optimizer_1.param_groups长度:", len(optimizer_1.param_groups))
print("optimizer_1.param_groups一个元素包含的键:", optimizer_1.param_groups[0].keys())
print()
####################################################################################
******************optimizer_1*********************
optimizer_1.defaults: {'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}
optimizer_1.param_groups长度: 1
optimizer_1.param_groups一个元素包含的键: dict_keys(['params', 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad'])
####################################################################################optimizer_2 = torch.optim.Adam([*net_1.parameters(), *net_2.parameters()], lr = initial_lr)
# optimizer_2 = torch.opotim.Adam(itertools.chain(net_1.parameters(), net_2.parameters())) # 和上一行作用相同
print("******************optimizer_2*********************")
print("optimizer_2.defaults:", optimizer_2.defaults)
print("optimizer_2.param_groups长度:", len(optimizer_2.param_groups))
print("optimizer_2.param_groups一个元素包含的键:", optimizer_2.param_groups[0].keys())
print()
####################################################################################
******************optimizer_2*********************
optimizer_2.defaults: {'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}
optimizer_2.param_groups长度: 1
optimizer_2.param_groups一个元素包含的键: dict_keys(['params', 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad'])
####################################################################################optimizer_3 = torch.optim.Adam([{"params": net_1.parameters()}, {"params": net_2.parameters()}], lr = initial_lr)
print("******************optimizer_3*********************")
print("optimizer_3.defaults:", optimizer_3.defaults)
print("optimizer_3.param_groups长度:", len(optimizer_3.param_groups))
print("optimizer_3.param_groups一个元素包含的键:", optimizer_3.param_groups[0].keys())
####################################################################################
******************optimizer_3*********************
optimizer_3.defaults: {'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}
optimizer_3.param_groups长度: 2
optimizer_3.param_groups一个元素包含的键: dict_keys(['params', 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad'])
####################################################################################

lr_scheduler更新optimizerlr,是更新的optimizer.param_groups[n][‘lr’],而不是optimizer.defaults[‘lr’]

torch.optim.lr_scheduler.LambdaLR

以lambdaLR为例:

CLASS torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=- 1, verbose=False)

  • optimizer (Optimizer) – Wrapped optimizer.
  • lr_lambda (function or list) – A function which computes a multiplicative factor given an integer parameter epoch, or a list of such functions, one for each group in optimizer.param_groups.
  • last_epoch (int) – The index of last epoch. Default: -1.
  • verbose (bool) – If True, prints a message to stdout for each update. Default: False.

Demo示例:

# Assuming optimizer has two groups.
lambda1 = lambda epoch: epoch // 30
lambda2 = lambda epoch: 0.95 ** epoch
scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
for epoch in range(100):train(...)validate(...)scheduler.step()

CLASS torch.optim.lr_scheduler.LambdaLR实例函数

  • get_last_lr()

返回上次计算的lr

  • print_lr(is_verbose, group, lr, epoch=None)

打印当前lr

  • state_dict()

Returns the state of the scheduler as a dict.
It contains an entry for every variable in self.dict which is not the optimizer. The learning rate lambda functions will only be saved if they are callable objects and not if they are functions or lambdas.
When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.

  • load_state_dict(state_dict)

Loads the schedulers state.
When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.

后期会写一个完整的训练代码出来,给后面的项目进行参考

参考文献

Pytorch官方文档
csdn中关于lr_scheduler用法的详细介绍

【Pytorch教程】使用lr_scheduler调整学习率相关推荐

  1. pytorch 动态调整学习率,学习率自动下降,根据loss下降

    0 为什么引入学习率衰减? 我们都知道几乎所有的神经网络采取的是梯度下降法来对模型进行最优化,其中标准的权重更新公式: W+=α∗gradient W+=\alpha * \text { gradie ...

  2. Pytorch不同层设置不同学习率

    1 主要目标 不同的参数可能需要不同的学习率,本文主要实现的是不同层中参数的不同学习率设置. 尤其是当我们在使用预训练的模型时,需要对一些除了主干网络以外的分支进行单独修改并进行初始化,其他主干网络层 ...

  3. pytorch中调整学习率的lr_scheduler机制

    pytorch中调整学习率的lr_scheduler机制 </h1><div class="clear"></div><div class ...

  4. 【PyTorch】lr_scheduler.StepLR==>调整学习率的方法

    lr_scheduler.StepLR class torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoc ...

  5. pytorch中调整学习率: torch.optim.lr_scheduler

    文章翻译自:https://pytorch.org/docs/stable/optim.html torch.optim.lr_scheduler 中提供了基于多种epoch数目调整学习率的方法. t ...

  6. PyTorch框架学习十四——学习率调整策略

    PyTorch框架学习十四--学习率调整策略 一._LRScheduler类 二.六种常见的学习率调整策略 1.StepLR 2.MultiStepLR 3.ExponentialLR 4.Cosin ...

  7. asp文本框输入控制是5的倍数_DL知识拾贝(Pytorch)(五):如何调整学习率

    知识导图 学习率对于深度学习是一个重要的超参数,它控制着基于损失梯度调整神经网络权值的速度,大多数优化算法(SGD.RMSprop.Adam)对其都有所涉及.学习率过下,收敛的太慢,网络学习的也太慢: ...

  8. pytorch 学习率代码_DL知识拾贝(Pytorch)(五):如何调整学习率

    知识导图 学习率对于深度学习是一个重要的超参数,它控制着基于损失梯度调整神经网络权值的速度,大多数优化算法(SGD.RMSprop.Adam)对其都有所涉及.学习率过下,收敛的太慢,网络学习的也太慢: ...

  9. [翻译Pytorch教程]NLP部分:使用TorchText进行文本分类

    本教程展示如何在torchtext中调用文本分类数据集,包括: AG_NEWS, SogouNews, DBpedia, YelpReviewPolarity, YelpReviewFull, Yah ...

最新文章

  1. git 拉取远程分支及修改远程仓库地址
  2. [我的1024开源程序]100元写的单词本说明书
  3. 机器学习中的算法(2)-支持向量机(SVM)基础
  4. 【题解】lugu P4095 Eden的新背包问题
  5. 有效的数独Python解法
  6. php mysql 作业计划,关于php:我需要使用cron作业每30分钟恢复一次数据库(mysql)
  7. Android ListView使用
  8. C语言实现hello world代码
  9. 21个php常用方法汇总
  10. linux下MySQL安装及设置
  11. cjson读取json文件_JSON格式介绍和使用cJSON解析 | 学步园
  12. TwinCAT 3 基础——安装
  13. To Kill a Mockingbird(杀死一只反舌鸟)简记
  14. 我所热爱的多触摸系统 bill buxton
  15. CAS单点登录:CAS客户端搭建(整合Shiro和Spring Boot)
  16. Scala----特质trait的使用
  17. matlab 双y轴对数坐标 误差线,matlab双y轴添加误差棒(转载)
  18. 手机开机后Spreadtrum Factory Test phone test factory used full phone test item test BT EUT exit
  19. 内是不是半包围结构_半包围结构是什么意思 半包围结构字的书写规则
  20. Vue脚手架创建项目流程

热门文章

  1. HTML+CSS仿京东购物车页面静态页面
  2. Android逆向(一)Android逆向工具(一)
  3. python求列表的峰度系数
  4. 20212高考成绩查询,高考查分手机客户端
  5. 新建pycharm项目只输出Hi pycharm 怎么办
  6. 几种不同的推荐引擎比较
  7. 对Softmax函数的理解
  8. 【ffmpeg】-fflags nobuffer 会导致 av_find_stream_info失败
  9. Linux 输出的重定向
  10. 锋利的SQL2014:联接算法