Adam算法

Adam算法在RMSProp算法基础上对小批量随机梯度也做了指数加权移动平均。Adam算法可以看做是RMSProp算法与动量法的结合

算法内容

Adam算法使用了动量变量vt\boldsymbol{v}_tvt​和RMSProp算法中小批量随机梯度按元素平方的指数加权移动平均变量st\boldsymbol{s}_tst​,并在时间步0将它们中每个元素初始化为0。

  • 给定超参数0≤β1<10 \leq \beta_1 < 10≤β1​<1(算法作者建议设为0.9)

时间步ttt的动量变量vt\boldsymbol{v}_tvt​即小批量随机梯度gt\boldsymbol{g}_tgt​的指数加权移动平均:

vt←β1vt−1+(1−β1)gt.\boldsymbol{v}_t \leftarrow \beta_1 \boldsymbol{v}_{t-1} + (1 - \beta_1) \boldsymbol{g}_t. vt​←β1​vt−1​+(1−β1​)gt​.

和RMSProp算法中一样,给定超参数0≤β2<10 \leq \beta_2 < 10≤β2​<1(算法作者建议设为0.999), 将小批量随机梯度按元素平方后的项gt⊙gt\boldsymbol{g}_t \odot \boldsymbol{g}_tgt​⊙gt​做指数加权移动平均得到st\boldsymbol{s}_tst​:

st←β2st−1+(1−β2)gt⊙gt.\boldsymbol{s}_t \leftarrow \beta_2 \boldsymbol{s}_{t-1} + (1 - \beta_2) \boldsymbol{g}_t \odot \boldsymbol{g}_t. st​←β2​st−1​+(1−β2​)gt​⊙gt​.

由于我们将v0\boldsymbol{v}_0v0​和s0\boldsymbol{s}_0s0​中的元素都初始化为0, 在时间步ttt我们得到

vt=(1−β1)∑i=1tβ1t−igi\boldsymbol{v}_t = (1-\beta_1) \sum_{i=1}^t \beta_1^{t-i} \boldsymbol{g}_ivt​=(1−β1​)i=1∑t​β1t−i​gi​

将过去各时间步小批量随机梯度的权值相加,得到
(1−β1)∑i=1tβ1t−i=1−β1t(1-\beta_1) \sum_{i=1}^t \beta_1^{t-i} = 1 - \beta_1^t(1−β1​)i=1∑t​β1t−i​=1−β1t​

需要注意的是,当ttt较小时,过去各时间步小批量随机梯度权值之和会较小。

例如,当β1=0.9\beta_1 = 0.9β1​=0.9时,v1=0.1g1\boldsymbol{v}_1 = 0.1\boldsymbol{g}_1v1​=0.1g1​。为了消除这样的影响,对于任意时间步ttt,我们可以将vt\boldsymbol{v}_tvt​再除以1−β1t1 - \beta_1^t1−β1t​,从而使过去各时间步小批量随机梯度权值之和为1。这也叫作偏差修正。在Adam算法中,我们对变量vt\boldsymbol{v}_tvt​和st\boldsymbol{s}_tst​均作偏差修正:

v^t←vt1−β1t,\hat{\boldsymbol{v}}_t \leftarrow \frac{\boldsymbol{v}_t}{1 - \beta_1^t}, v^t​←1−β1t​vt​​,

s^t←st1−β2t.\hat{\boldsymbol{s}}_t \leftarrow \frac{\boldsymbol{s}_t}{1 - \beta_2^t}. s^t​←1−β2t​st​​.

接下来,Adam算法使用以上偏差修正后的变量v^t\hat{\boldsymbol{v}}_tv^t​和s^t\hat{\boldsymbol{s}}_ts^t​,将模型参数中每个元素的学习率通过按元素运算重新调整:

gt′←ηv^ts^t+ϵ,\boldsymbol{g}_t' \leftarrow \frac{\eta \hat{\boldsymbol{v}}_t}{\sqrt{\hat{\boldsymbol{s}}_t} + \epsilon},gt′​←s^t​​+ϵηv^t​​,

其中η\etaη是学习率,ϵ\epsilonϵ是为了维持数值稳定性而添加的常数,如10−810^{-8}10−8。和AdaGrad算法、RMSProp算法以及AdaDelta算法一样,目标函数自变量中每个元素都分别拥有自己的学习率。最后,使用gt′\boldsymbol{g}_t'gt′​迭代自变量:

xt←xt−1−gt′.\boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \boldsymbol{g}_t'. xt​←xt−1​−gt′​.

实现Adam优化算法

def get_data_ch7():  data = np.genfromtxt('data/airfoil_self_noise.dat', delimiter='\t')data = (data - data.mean(axis=0)) / data.std(axis=0)return torch.tensor(data[:1500, :-1], dtype=torch.float32), \torch.tensor(data[:1500, -1], dtype=torch.float32) # 前1500个样本(每个样本5个特征)
%matplotlib inline
import torch
import sysfeatures, labels = get_data_ch7()def init_adam_states():v_w, v_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)s_w, s_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)return ((v_w, s_w), (v_b, s_b))def adam(params, states, hyperparams):beta1, beta2, eps = 0.9, 0.999, 1e-6for p, (v, s) in zip(params, states):v[:] = beta1 * v + (1 - beta1) * p.grad.datas[:] = beta2 * s + (1 - beta2) * p.grad.data**2v_bias_corr = v / (1 - beta1 ** hyperparams['t'])s_bias_corr = s / (1 - beta2 ** hyperparams['t'])p.data -= hyperparams['lr'] * v_bias_corr / (torch.sqrt(s_bias_corr) + eps)hyperparams['t'] += 1

使用学习率为0.01的Adam算法来训练模型。

def train_ch7(optimizer_fn, states, hyperparams, features, labels,batch_size=10, num_epochs=2):# 初始化模型net, loss = linreg, squared_lossw = torch.nn.Parameter(torch.tensor(np.random.normal(0, 0.01, size=(features.shape[1], 1)), dtype=torch.float32),requires_grad=True)b = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=True)def eval_loss():return loss(net(features, w, b), labels).mean().item()ls = [eval_loss()]data_iter = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)for _ in range(num_epochs):start = time.time()for batch_i, (X, y) in enumerate(data_iter):l = loss(net(X, w, b), y).mean()  # 使用平均损失# 梯度清零if w.grad is not None:w.grad.data.zero_()b.grad.data.zero_()l.backward()optimizer_fn([w, b], states, hyperparams)  # 迭代模型参数if (batch_i + 1) * batch_size % 100 == 0:ls.append(eval_loss())  # 每100个样本记录下当前训练误差# 打印结果和作图print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))set_figsize()plt.plot(np.linspace(0, num_epochs, len(ls)), ls)plt.xlabel('epoch')plt.ylabel('loss')
train_ch7(adam, init_adam_states(), {'lr': 0.01, 't': 1}, features, labels)

也可以使用pytorch内置的optim.Adam实现:

train_pytorch_ch7(torch.optim.Adam, {'lr': 0.01}, features, labels)

深度学习优化算法-Adam算法相关推荐

  1. adam算法效果差原因_深度学习优化器-Adam两宗罪

    在上篇文章中,我们用一个框架来回顾了主流的深度学习优化算法.可以看到,一代又一代的研究者们为了我们能炼(xun)好(hao)金(mo)丹(xing)可谓是煞费苦心.从理论上看,一代更比一代完善,Ada ...

  2. Adam 那么棒,为什么还对 SGD 念念不忘?一个框架看懂深度学习优化算法

    作者|Juliuszh 链接 | https://zhuanlan.zhihu.com/juliuszh 本文仅作学术分享,若侵权,请联系后台删文处理 机器学习界有一群炼丹师,他们每天的日常是: 拿来 ...

  3. 2017年深度学习优化算法最新进展:改进SGD和Adam方法

    2017年深度学习优化算法最新进展:如何改进SGD和Adam方法 转载的文章,把个人觉得比较好的摘录了一下 AMSGrad 这个前期比sgd快,不能收敛到最优. sgdr 余弦退火的方案比较好 最近的 ...

  4. 2017年深度学习优化算法最新进展:如何改进SGD和Adam方法?

    2017年深度学习优化算法最新进展:如何改进SGD和Adam方法? 深度学习的基本目标,就是寻找一个泛化能力强的最小值,模型的快速性和可靠性也是一个加分点. 随机梯度下降(SGD)方法是1951年由R ...

  5. 深度学习优化算法实现(Momentum, Adam)

    目录 Momentum 初始化 更新参数 Adam 初始化 更新参数 除了常见的梯度下降法外,还有几种比较通用的优化算法:表现都优于梯度下降法.本文只记录完成吴恩达深度学习作业时遇到的Momentum ...

  6. 深度学习优化算法,Adam优缺点分析

    优化算法 首先我们来回顾一下各类优化算法. 深度学习优化算法经历了 SGD -> SGDM -> NAG ->AdaGrad -> AdaDelta -> Adam -& ...

  7. 深度学习优化算法:Adam算法

    原文链接:动手学深度学习pytorch版:7.8 Adam算法 github:https://github.com/ShusenTang/Dive-into-DL-PyTorch [1] Kingma ...

  8. Adam那么棒,为什么还对SGD念念不忘?一个框架看懂深度学习优化算法

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者|Juliuszh,https://zhuanlan.zhih ...

  9. 重磅 | 2017年深度学习优化算法研究亮点最新综述火热出炉

    翻译 | AI科技大本营(微信ID:rgznai100) 梯度下降算法是机器学习中使用非常广泛的优化算法,也是众多机器学习算法中最常用的优化方法.几乎当前每一个先进的(state-of-the-art ...

最新文章

  1. 八数码 poj 1077 广搜 A* IDA*
  2. 开源 物联网接入_我们刚刚推出了开源产品。 那么接下来会发生什么呢?
  3. 域控制器部署组策略,立即下发强制更新,显示“远程过程调用被取消”,错误代码 8007071a;以及RPC服务器不可用,800706ba【解决方案】
  4. 蜘蛛日志分析工具_如何分析蜘蛛日志?
  5. 揭开伟大架构师的秘密
  6. 2021 年 Windows 成了 Python 开发者的首选
  7. 20191024:单调栈问题的引出
  8. 测试经验谈:测试人怎么从 0—1 进化
  9. Oracle刷建表语句
  10. 代码制作数字流星雨_C语言实现流星雨 | 学步园
  11. 图片放大后模糊的照片怎么处理清晰?
  12. Science Advances:恐惧学习中内侧前额叶和杏仁核theta振荡同步活动
  13. python秒表游戏代码_在pygam游戏中添加秒表
  14. 【Ware】专业的视频剪辑软件推荐
  15. Q - Phalanx
  16. iOS 直播类APP开发流程
  17. 第1章 计算机系统漫游
  18. DRUID 连接池的使用、配置详解
  19. 海底捞市值超大多数上市房企,火锅是怎么做到比卖房还赚钱的?
  20. java解惑之字符之谜(谜题17)

热门文章

  1. python 的scrapy框架
  2. 详解python列表中冒号的用法
  3. 从0-1做产品快速启动,大型干货案例分享
  4. hive修改字段类型
  5. 【微信小程序】使用和风天气接口api(全过程)——获取天气
  6. 打怪小游戏(暂时只支持主线任务、刷怪和作者商店)
  7. 漫步者蓝牙自动断开_500左右蓝牙耳机怎么买?南卡对比漫步者让你看懂
  8. android实现更改密码,重要提醒:手机这个密码一定要改!
  9. REST风格接口学习
  10. ps和matlab哪个,设计用图(主要运行PS,CAD,MATLAB,会声会影等软件)与游戏两不误选什么CPU好?随便给个装机清单。...