这节内容主要是对比在 Torch 实践中所会用到的几种优化器

编写伪数据

为了对比各种优化器的效果, 需要有一些数据, 可以自己编一些伪数据, 这批数据是这样的:

具体的数据生成代码如下:

import torch
import torch.utils.data as Data
import torch.nn.functional as F
import matplotlib.pyplot as pltLR = 0.01
BATCH_SIZE = 32
EPOCH = 12# fake dataset
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size()))# plot dataset
plt.scatter(x.numpy(), y.numpy())
plt.show()# put dateset into torch dataset
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,)

每个优化器优化一个神经网络

为了对比每一种优化器, 我们给他们各自创建一个神经网络, 但这个神经网络都来自同一个 Net 形式.。具体实现如下:

# 默认的 network 形式
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.hidden = torch.nn.Linear(1, 20)   # hidden layerself.predict = torch.nn.Linear(20, 1)   # output layerdef forward(self, x):x = F.relu(self.hidden(x))      # activation function for hidden layerx = self.predict(x)             # linear outputreturn x# 为每个优化器创建一个 net
net_SGD         = Net()
net_Momentum    = Net()
net_RMSprop     = Net()
net_Adam        = Net()
nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]

优化器Optimizer

接下来创建不同的优化器,用来训练不同的网络,并创建一个loss_func 用来计算误差. 我们用几种常见的优化器, SGD, Momentum, RMSprop, Adam.

# different optimizers
opt_SGD         = torch.optim.SGD(net_SGD.parameters(), lr=LR)
opt_Momentum    = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)
opt_RMSprop     = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)
opt_Adam        = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]loss_func = torch.nn.MSELoss()
losses_his = [[], [], [], []]   # 记录 training 时不同神经网络的 loss

训练/出图

for epoch in range(EPOCH):print('Epoch: ', epoch)for step, (b_x, b_y) in enumerate(loader):# 对每个优化器, 优化属于他的神经网络for net, opt, l_his in zip(nets, optimizers, losses_his):output = net(b_x)              # get output for every netloss = loss_func(output, b_y)  # compute loss for every netopt.zero_grad()                # clear gradients for next trainloss.backward()                # backpropagation, compute gradientsopt.step()                     # apply gradientsl_his.append(loss.data.numpy())     # loss recoder


SGD 是最普通的优化器, 也可以说没有加速效果, 而 MomentumSGD 的改良版, 它加入了动量原则. 后面的 RMSprop 又是 Momentum 的升级版. 而 Adam 又是 RMSprop 的升级版. 不过从这个结果中我们看到, Adam 的效果似乎比 RMSprop 要差一点. 所以说并不是越先进的优化器, 结果越佳. 我们在自己的试验中可以尝试不同的优化器, 找到那个最适合你数据/网络的优化器.

本文转载自莫烦python的pytorch学习板块,源代码可去Optimizer 优化器查看

Optimizer优化器相关推荐

  1. 加速神经网络训练方法及不同Optimizer优化器性能比较

    本篇博客主要介绍几种加速神经网络训练的方法. 我们知道,在训练样本非常多的情况下,如果一次性把所有的样本送入神经网络,每迭代一次更新网络参数,这样的效率是很低的.为什么?因为梯度下降法参数更新的公式一 ...

  2. pytorch 7 optimizer 优化器 加速训练

    pytorch 7 optimizer 优化器 加速训练 import torch import torch.utils.data as Data import torch.nn.functional ...

  3. PyTorch 实现批训练和 Optimizer 优化器

    批训练 import torch import torch.utils.data as DataBATCH_SIZE = 5x = torch.linspace(1, 10, 10) # this i ...

  4. [Python人工智能] 四.TensorFlow创建回归神经网络及Optimizer优化器

    从本篇文章开始,作者正式开始研究Python深度学习.神经网络及人工智能相关知识.前一篇文章讲解了TensorFlow基础和一元直线预测的案例,以及Session.变量.传入值和激励函数:这篇文章将详 ...

  5. Optimizer 优化器

    要点 这节内容主要是用 Torch 实践 这个 优化器 动画简介 中起到的几种优化器, 这几种优化器具体的优势不会在这个节内容中说了, 所以想快速了解的话, 上面的那个动画链接是很好的去处. 下图就是 ...

  6. PLSQL_性能优化系列04_Oracle Optimizer优化器

    2014-09-25 Created By BaoXinjian 一.摘要 1. Oracle优化器介绍 本文讲述了Oracle优化器的概念.工作原理和使用方法,兼顾了Oracle8i.9i以及最新的 ...

  7. 深度学习训练之optimizer优化器(BGD、SGD、MBGD、SGDM、NAG、AdaGrad、AdaDelta、Adam)的最全系统详解

    文章目录 1.BGD(批量梯度下降) 2.SGD(随机梯度下降) 2.1.SGD导致的Zigzag现象 3.MBGD(小批量梯度下降) 3.1 BGD.SGD.MBGD的比较 4.SGDM 5.NAG ...

  8. TensorFlow(四)优化器函数Optimizer

    因为大多数机器学习任务就是最小化损失,在损失定义的情况下,后面的工作就交给了优化器.因为深度学习常见的是对于梯度的优化,也就是说,优化器最后其实就是各种对于梯度下降算法的优化. 常用的optimize ...

  9. 优化器 optimizer

    优化器 optimizer optimizer 优化器,用来根据参数的梯度进行沿梯度下降方向进行调整模型参数,使得模型loss不断降低,达到全局最低,通过不断微调模型参数,使得模型从训练数据中学习进行 ...

最新文章

  1. Android库so文件及skia函数的调用
  2. python datetime datetime_Python datetime.tzinfo方法代碼示例
  3. C语言-宏定义#define的用法
  4. python3根据地址批量获取百度地图经纬度
  5. Nginx Mac笔记
  6. VTK:Utilities之FunctionParser
  7. php网址变量怎么输出,【PHP网站】如何使用dedecms v5.7前台模版里输出变量
  8. java中string范围_java中long parseLong(String s)方法中string(十进制数字)的范围?
  9. Vue-组件之间的数据共享
  10. 网页滚动截屏怎么截长图
  11. C语言必背代码大全(2021整理)
  12. 【LED灯屏控制器】国产FPGA之 AG10KSDE176 初探(1)
  13. 80后的我们为什么不结婚
  14. csgo红锁号能解锁吗_CSGO红锁黑刀号!重磅!大规模红锁!
  15. 常用网络结构:Alex,VGG,Resnet对比
  16. C++:实现量化Overnight-indexed swap 隔夜指数掉期测试实例
  17. MFC下ODBC方式连接数据库
  18. 大数据架构师拿年薪50W的方法诀窍
  19. nat模式下更改网络环境 虚拟机中Linux无法上网的问题
  20. php不支持gd库,如何解决php不支持gd库的问题

热门文章

  1. 如何将ChatGPT培养成「私人助理」
  2. 【酷炫雪花飞舞特效】(HTML+JS+CSS+效果+代码)
  3. php-fpm比php成为apache模块好在哪
  4. 我看盛大 [以下内容仅为个人观点]
  5. 2023哈工大软件工程考研 | 395+251 | 个人经验分享
  6. Python 分析youku sohu tudou视频各种清晰度的下载地址
  7. EXCEL 把几列排列组合列出所有排列组合情况的绿色工具
  8. 如何精通python语言_精通Python自然语言处理
  9. 怎样才能画好动漫人物的腿?画好动漫人物的腿有哪些技巧?
  10. python zipfile_详解python3中zipfile模块用法