torch.optim是一个实现各种优化算法的包。

如何使用优化器

创建优化器

首先需要构造一个优化器对象,该对象将保持当前状态,并根据计算的梯度更新参数。

​要构造优化器,必须给它一个包含要优化的参数的迭代对象(所有参数都应该是变量s)。然后,指定特定于优化器的选项,如学习率、权重衰减等。

如果需要通过.cuda()将模型移动到GPU,需要在为其构建优化器之前执行此操作。.cuda()操作之后的模型参数将与调用之前的对象不同。

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)

每个参数选项

Optimizer 还支持指定每个参数选项。​需要传入dict的iterable,而不是传递变量的iterable​. 它们中的每一个都将定义一个单独的参数组,并且应该包含一个params键,其中包含属于它的参数列表。其他键应与优化器接受的关键字参数匹配,并将用作此组的优化选项。

NOTE

仍然可以将选项作为关键字参数传递。在没有覆盖它们的组中,它们将用作默认值。如果只想更改单个选项,同时在参数组之间保持所有其他选项的一致性时,这非常有用。

例如,当需要指定每层学习速率时,这非常有用:

optim.SGD([{'params': model.base.parameters()},{'params': model.classifier.parameters(), 'lr': 1e-3}], lr=1e-2, momentum=0.9)

这意味着model.base的参数将使用默认学习率1e-2,model.classifier的参数将使用1e-3的学习率,所有参数将使用0.9的动量。

采取优化步骤

​所有优化器都会实现一个更新参数的step()方法。它可以通过两种方式使用:​

方法一:optimizer.step()

是大多数优化器支持的简化版本。使用backward()计算梯度后,即可调用该函数。

for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()

方法二:optimizer.step(closure)

一些优化算法(如共轭梯度法和LBFGS)需要多次重新计算函数,因此必须传入一个闭包,以便它们重新计算模型。闭合应清除梯度,计算损失,然后返回。

for input, target in dataset:def closure():optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()return lossoptimizer.step(closure)

基类

CLASStorch.optim.Optimizer(paramsdefaults)[SOURCE]

这是所有优化器的基类。

WARNING

需要将参数指定为集合,这些集合具有在运行之间一致的确定顺序

Parameters

  • params (iterable) – 一个torch.tensor s或dict s,是指定应优化的张量。

  • defaults – (dict): 包含优化选项默认值的dict(当参数组未指定它们时使用)

Optimizer.add_param_group

​ 将param组添加到优化器的param_group。

Optimizer.load_state_dict

加载优化器状态。

Optimizer.state_dict

​ 以dict的形式返回优化器的状态。

Optimizer.step

执行单个优化步骤(参数更新)。

Optimizer.zero_grad

​ 设置所有优化的torch.tensor的梯度值为零。

优化算法

Adadelta

实现Adadelta算法。

Adagrad

实现Adagrad算法。

Adam

实现Adam算法。

AdamW

实现AdamW算法。

SparseAdam

实现适用于稀疏张量的Adam算法的lazy版本。

Adamax

实现Adamax算法(基于无穷范数的Adam变体)。

ASGD

实现平均随机梯度下降。

LBFGS

实现L-BFGS算法,深受minFunc启发

NAdam

实现NAdam算法。

RAdam

实现RAdam算法。

RMSprop

实现RMSprop算法。

Rprop

实现Rpro算法。

SGD

实现随机梯度下降(可选择使用动量)。

如何调整学习率

torch.optim.lr_scheduler 提供了几种方法来根据时代数调整学习速率。torch.optim.lr_scheduler.ReduceLROnPlateau 允许基于某些验证度量动态降低学习速率。

学习率调度应在优化器更新后应用:

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)for epoch in range(20):for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()scheduler.step()

大多数学习速率调度器可以称为back-to-back(也称为链式调度器)。结果是,每个调度器都会根据前一个调度器所获得的学习速率逐一应用:

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)for epoch in range(20):for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()scheduler1.step()scheduler2.step()

WARNING

在PyTorch 1.1.0之前,在优化器更新之前调用学习率调度器;在版本1.1.0以后,如果在优化器更新(调用optimizer.step())之前使用学习速率调度器(调用scheduler.step()),这将跳过学习率的第一个值,导致升级到PyTorch 1.1.0后无法复制结果,所以需要在优化器更新之后调用学习率调度器。

lr_scheduler.LambdaLR

将每个参数组的学习速率设置为给定函数的初始lr倍。

lr_scheduler.MultiplicativeLR

将每个参数组的学习率乘以指定函数中给定的系数。

lr_scheduler.StepLR

按gamma每个步长时间衰减每个参数组的学习率。

lr_scheduler.MultiStepLR

一旦epoch数达到一个值,则通过gamma衰减每个参数组的学习速率。

lr_scheduler.ConstantLR

将每个参数组的学习速率衰减一个小的常数因子,直到epoch数达到预定义的值:total_iters。

lr_scheduler.LinearLR

通过线性改变小的乘法因子来衰减每个参数组的学习速率,直到epoch数达到预定义的里程碑:total_iters。

lr_scheduler.ExponentialLR

每个epoch用gamma衰减每个参数组的学习率。

lr_scheduler.CosineAnnealingLR

使用余弦退火策略设置每个参数组的学习速率,其中ηmax​设置为初始lr,Tcur​是自SGDR中上次重新启动以来的epoch数:

lr_scheduler.ChainedScheduler

链学习速率调度器列表。

lr_scheduler.SequentialLR

接收在优化过程中预期按顺序调用的调度程序列表和里程碑点,里程碑点提供了准确的间隔,以反映在给定的时间段应该调用哪个调度程序。

lr_scheduler.ReduceLROnPlateau

当指标停止改善时,降低学习率。

lr_scheduler.CyclicLR

根据循环学习速率策略(CLR)设置每个参数组的学习速率。

lr_scheduler.OneCycleLR

根据1周期学习率策略设置每个参数组的学习率。

lr_scheduler.CosineAnnealingWarmRestarts

使用余弦退火计划设置每个参数组的学习率,其中ηmax​设置为初始lr、Tcur​是自上次重新启动以来的epoch数,Ti​是SGDR中两次热重启之间的epoch数:

随机加权平均

torch.optim.swa_utils 实现随机权重平均(SWA),是 torch.optim.swa_utils.AveragedModel类实现SWA模型 torch.optim.swa_utils.SWALR 实现SWA学习速率调度器 torch.optim.swa_utils.update_bn() 是一个用于在培训结束时更新SWA批次规范化统计信息的函数。

构建平均模型

AveragedModel类用于计算SWA模型的权重。可以通过运行以下命令创建平均模型:

>>> swa_model = AveragedModel(model)

这里的模型可以是任意的 torch.nn.Module 对象. swa_model 将跟踪模型参数的运行平均值. 要更新这些平均值,可以使用update_parameters()函数:

>>> swa_model.update_parameters(model)

SWA学习率计划

通常,在SWA中,学习率设置为较高的常数值。SWALR是一个学习速率调度器,它将学习速率退火为固定值,然后保持不变。

例如,以下代码创建了一个调度器,该调度器在每个参数组的5个时间段内将学习速率从其初始值线性退火为0.05:

>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \
>>>         anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)

您还可以通过设置anneal_strategy=“cos”,将余弦退火改为固定值,而不是线性退火。

处理批次规范化

update_bn()是一个实用函数,允许在训练结束时计算给定数据加载程序加载程序上SWA模型的batchnorm统计信息:

>>> torch.optim.swa_utils.update_bn(loader, swa_model)

update_bn()将swa_model应用于dataloader中的每个元素,并计算模型中每个批次规范化层的激活统计信息。

WARNING

update_bn()假设dataloader加载程序中的每个批次都是张量或张量列表,其中第一个元素是应应用网络swa_model的张量。如果数据加载器具有不同的结构,则可以通过在数据集的每个元素上向前传递swa_model来更新swa_model的批处理规范化统计信息。

自定义平均策略

默认情况下,torch.optim.swa_utils.AveragedModel计算参数的运行平均值,但也可以将自定义平均函数与avg_fn参数一起使用。在以下示例中,ema_model计算指数移动平均值。

Example:

>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>>         0.1 * averaged_model_parameter + 0.9 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)

总的应用

在下面的示例中,swa_model是累积权重平均值的swa模型。对模型进行总共300个epoch训练,然后切换到SWA学习率计划策略,并开始在epoch160收集参数的SWA平均值:

>>> loader, optimizer, model, loss_fn = ...
>>> swa_model = torch.optim.swa_utils.AveragedModel(model)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
>>> swa_start = 160
>>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
>>>
>>> for epoch in range(300):
>>>       for input, target in loader:
>>>           optimizer.zero_grad()
>>>           loss_fn(model(input), target).backward()
>>>           optimizer.step()
>>>       if epoch > swa_start:
>>>           swa_model.update_parameters(model)
>>>           swa_scheduler.step()
>>>       else:
>>>           scheduler.step()
>>>
>>> # Update bn statistics for the swa_model at the end
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
>>> # Use swa_model to make predictions on test data
>>> preds = swa_model(test_input)

PyTorch-1.10(十三)--torch.optim基本用法相关推荐

  1. pytorch中调整学习率: torch.optim.lr_scheduler

    文章翻译自:https://pytorch.org/docs/stable/optim.html torch.optim.lr_scheduler 中提供了基于多种epoch数目调整学习率的方法. t ...

  2. PyTorch学习笔记:torch.optim---Optimizer以及lr_scheduler

    本文参考 PyTorch optim文档 1 概述 1.1 PyTorch文档:torch.optim解读 下图是optim的文档 TORCH.OPTIM torch.optim is a packa ...

  3. PyTorch 笔记(18)— torch.optim 优化器的使用

    到目前为止,代码中的神经网络权重的参数优化和更新还没有实现自动化,并且目前使用的优化方法都有固定的学习速率,所以优化函数相对简单,如果我们自己实现一些高级的参数优化算法,则优化函数部分的代码会变得较为 ...

  4. torch的拼接函数_从零开始深度学习Pytorch笔记(13)—— torch.optim

    前文传送门: 从零开始深度学习Pytorch笔记(1)--安装Pytorch 从零开始深度学习Pytorch笔记(2)--张量的创建(上) 从零开始深度学习Pytorch笔记(3)--张量的创建(下) ...

  5. PyTorch: torch.optim 的6种优化器及优化算法介绍

    import torch import torch.nn.functional as F import torch.utils.data as Data import matplotlib.pyplo ...

  6. PyTorch 中 torch.optim优化器的使用

    一.优化器基本使用方法 建立优化器实例 循环: 清空梯度 向前传播 计算Loss 反向传播 更新参数 示例: from torch import optim input = ..... optimiz ...

  7. pytorch中torch.optim的介绍

    pytorch中torch.optim的介绍 这是torch自带的一个优化器,里面自带了求导,更新等操作.开门见山直接讲怎么使用: 常用的引入: import torch.optim as optim ...

  8. PyTorch疑难杂症(1)——torch.matmul()函数用法总结

    目录 一.函数介绍 二.常见用法 2.1 两个一维向量的乘积运算 2.2 两个二维矩阵的乘积运算 2.3 一个一维向量和一个二维矩阵的乘积运算 2.4 一个二维矩阵和一个一维向量的乘积运算 2.5 其 ...

  9. PyTorch官方中文文档:torch.optim 优化器参数

    内容预览: step(closure) 进行单次优化 (参数更新). 参数: closure (callable) –...~ 参数: params (iterable) – 待优化参数的iterab ...

  10. pytorch torch.optim.Optimizer

    API 1)Optimizer CLASS torch.optim.Optimizer(params, defaults) 参数 描述 params (iterable) an iterable of ...

最新文章

  1. 使用centos6.6部署Cobbler--自动安装centos系统
  2. Ruby on rails环境和开发工具准备...
  3. 计算机网络ipv4到ipv6怎么实现,论计算机网络协议IPV4到IPV6的过渡策略|房屋搬迁过渡协议...
  4. 持续集成(CI)- 各种工具的资料总结
  5. ICDAR 2019 文本识别冠军方案将开源!
  6. URAL 1404. Easy to Hack! (模拟)
  7. SAP License:值字段更改注意事项
  8. [转]char数组赋值
  9. autotools归纳
  10. android sqlitelog,如何解决Sqlitelog(13)语句中止在PhoneGap中的68错误android
  11. 汉诺塔问题(C语言实现)
  12. Solidworks2022安装
  13. 操作系统的工作流程(流程图表示)
  14. 命令提示符打不开python_Windows-Python在命令提示符下不起作用?
  15. 《动手学深度学习》(PyTorch版)代码注释 - 12 【House_price_prediction】
  16. PCIe 分类、速度
  17. ERROR: Could not build wheels for numpy which use PEP 517 and cannot be installed directly
  18. ssd重装系统的详细教程
  19. Vue3 tailwindui
  20. 实例恢复(Instance Recovery)之前滚(Rolling Forward)和回滚(Rolling Back)

热门文章

  1. 一个标准的k-means(误差平方和版本)
  2. 怎么把Word转PDF格式?分享几种好用的转换方法
  3. gtx 750 linux驱动下载,Ubuntu 12.04安装NVIDIA GTX750显卡驱动
  4. 使用wps-excell画折线图
  5. 如何高效地阅读技术类书籍与博客
  6. 抖音头像有钩什么意思,抖音上有黄勾和蓝勾什么意思
  7. 微信服务通知消息找回_企业微信消息不提醒怎么办?怎么打开企业微信消息通知?...
  8. Redis可视化工具
  9. CSU 2202 EL PSY CONGROO
  10. 图像特征提取(二)——HOG特征