Optimizer优化器
这节内容主要是对比在 Torch 实践中所会用到的几种优化器
编写伪数据
为了对比各种优化器的效果, 需要有一些数据, 可以自己编一些伪数据, 这批数据是这样的:
具体的数据生成代码如下:
import torch
import torch.utils.data as Data
import torch.nn.functional as F
import matplotlib.pyplot as pltLR = 0.01
BATCH_SIZE = 32
EPOCH = 12# fake dataset
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size()))# plot dataset
plt.scatter(x.numpy(), y.numpy())
plt.show()# put dateset into torch dataset
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,)
每个优化器优化一个神经网络
为了对比每一种优化器, 我们给他们各自创建一个神经网络, 但这个神经网络都来自同一个 Net 形式.。具体实现如下:
# 默认的 network 形式
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.hidden = torch.nn.Linear(1, 20) # hidden layerself.predict = torch.nn.Linear(20, 1) # output layerdef forward(self, x):x = F.relu(self.hidden(x)) # activation function for hidden layerx = self.predict(x) # linear outputreturn x# 为每个优化器创建一个 net
net_SGD = Net()
net_Momentum = Net()
net_RMSprop = Net()
net_Adam = Net()
nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]
优化器Optimizer
接下来创建不同的优化器,用来训练不同的网络,并创建一个loss_func
用来计算误差. 我们用几种常见的优化器, SGD, Momentum, RMSprop, Adam.
# different optimizers
opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)
opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]loss_func = torch.nn.MSELoss()
losses_his = [[], [], [], []] # 记录 training 时不同神经网络的 loss
训练/出图
for epoch in range(EPOCH):print('Epoch: ', epoch)for step, (b_x, b_y) in enumerate(loader):# 对每个优化器, 优化属于他的神经网络for net, opt, l_his in zip(nets, optimizers, losses_his):output = net(b_x) # get output for every netloss = loss_func(output, b_y) # compute loss for every netopt.zero_grad() # clear gradients for next trainloss.backward() # backpropagation, compute gradientsopt.step() # apply gradientsl_his.append(loss.data.numpy()) # loss recoder
SGD
是最普通的优化器, 也可以说没有加速效果, 而 Momentum
是 SGD
的改良版, 它加入了动量原则. 后面的 RMSprop
又是 Momentum
的升级版. 而 Adam
又是 RMSprop
的升级版. 不过从这个结果中我们看到, Adam
的效果似乎比 RMSprop
要差一点. 所以说并不是越先进的优化器, 结果越佳. 我们在自己的试验中可以尝试不同的优化器, 找到那个最适合你数据/网络的优化器.
本文转载自莫烦python的pytorch学习板块,源代码可去Optimizer 优化器查看
Optimizer优化器相关推荐
- 加速神经网络训练方法及不同Optimizer优化器性能比较
本篇博客主要介绍几种加速神经网络训练的方法. 我们知道,在训练样本非常多的情况下,如果一次性把所有的样本送入神经网络,每迭代一次更新网络参数,这样的效率是很低的.为什么?因为梯度下降法参数更新的公式一 ...
- pytorch 7 optimizer 优化器 加速训练
pytorch 7 optimizer 优化器 加速训练 import torch import torch.utils.data as Data import torch.nn.functional ...
- PyTorch 实现批训练和 Optimizer 优化器
批训练 import torch import torch.utils.data as DataBATCH_SIZE = 5x = torch.linspace(1, 10, 10) # this i ...
- [Python人工智能] 四.TensorFlow创建回归神经网络及Optimizer优化器
从本篇文章开始,作者正式开始研究Python深度学习.神经网络及人工智能相关知识.前一篇文章讲解了TensorFlow基础和一元直线预测的案例,以及Session.变量.传入值和激励函数:这篇文章将详 ...
- Optimizer 优化器
要点 这节内容主要是用 Torch 实践 这个 优化器 动画简介 中起到的几种优化器, 这几种优化器具体的优势不会在这个节内容中说了, 所以想快速了解的话, 上面的那个动画链接是很好的去处. 下图就是 ...
- PLSQL_性能优化系列04_Oracle Optimizer优化器
2014-09-25 Created By BaoXinjian 一.摘要 1. Oracle优化器介绍 本文讲述了Oracle优化器的概念.工作原理和使用方法,兼顾了Oracle8i.9i以及最新的 ...
- 深度学习训练之optimizer优化器(BGD、SGD、MBGD、SGDM、NAG、AdaGrad、AdaDelta、Adam)的最全系统详解
文章目录 1.BGD(批量梯度下降) 2.SGD(随机梯度下降) 2.1.SGD导致的Zigzag现象 3.MBGD(小批量梯度下降) 3.1 BGD.SGD.MBGD的比较 4.SGDM 5.NAG ...
- TensorFlow(四)优化器函数Optimizer
因为大多数机器学习任务就是最小化损失,在损失定义的情况下,后面的工作就交给了优化器.因为深度学习常见的是对于梯度的优化,也就是说,优化器最后其实就是各种对于梯度下降算法的优化. 常用的optimize ...
- 优化器 optimizer
优化器 optimizer optimizer 优化器,用来根据参数的梯度进行沿梯度下降方向进行调整模型参数,使得模型loss不断降低,达到全局最低,通过不断微调模型参数,使得模型从训练数据中学习进行 ...
最新文章
- Android库so文件及skia函数的调用
- python datetime datetime_Python datetime.tzinfo方法代碼示例
- C语言-宏定义#define的用法
- python3根据地址批量获取百度地图经纬度
- Nginx Mac笔记
- VTK:Utilities之FunctionParser
- php网址变量怎么输出,【PHP网站】如何使用dedecms v5.7前台模版里输出变量
- java中string范围_java中long parseLong(String s)方法中string(十进制数字)的范围?
- Vue-组件之间的数据共享
- 网页滚动截屏怎么截长图
- C语言必背代码大全(2021整理)
- 【LED灯屏控制器】国产FPGA之 AG10KSDE176 初探(1)
- 80后的我们为什么不结婚
- csgo红锁号能解锁吗_CSGO红锁黑刀号!重磅!大规模红锁!
- 常用网络结构:Alex,VGG,Resnet对比
- C++:实现量化Overnight-indexed swap 隔夜指数掉期测试实例
- MFC下ODBC方式连接数据库
- 大数据架构师拿年薪50W的方法诀窍
- nat模式下更改网络环境 虚拟机中Linux无法上网的问题
- php不支持gd库,如何解决php不支持gd库的问题
热门文章
- 如何将ChatGPT培养成「私人助理」
- 【酷炫雪花飞舞特效】(HTML+JS+CSS+效果+代码)
- php-fpm比php成为apache模块好在哪
- 我看盛大 [以下内容仅为个人观点]
- 2023哈工大软件工程考研 | 395+251 | 个人经验分享
- Python 分析youku sohu tudou视频各种清晰度的下载地址
- EXCEL 把几列排列组合列出所有排列组合情况的绿色工具
- 如何精通python语言_精通Python自然语言处理
- 怎样才能画好动漫人物的腿?画好动漫人物的腿有哪些技巧?
- python zipfile_详解python3中zipfile模块用法