为了对比各种优化器的效果,需要模拟一些数据:

import torch
import torch.utils.data as Data
import torch.nn.functional as F
import matplotlib.pyplot as pltLR = 0.01
BATCH_SIZE = 32
EPOCH = 12# fake dataset
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1 * torch.normal(torch.zeros(*x.size()))# put dateset into torch dataset
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

为了对比每一种优化器,我们给它们各自创建一个神经网络,但这个神经网络都来自同一个Net形式:

class Net(torch.nn.Module):  # default networkdef __init__(self):super(Net, self).__init__()self.hidden = torch.nn.Linear(1, 20)  # hidden layerself.predict = torch.nn.Linear(20, 1)  # output layerdef forward(self, x):x = F.relu(self.hidden(x))  # activation function for hidden layerx = self.predict(x)  # linear outputreturn xif __name__ == '__main__':# 为每个优化器创建一个netnet_SGD = Net()net_Momentum = Net()net_RMSprop = Net()net_Adam = Net()nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]# different optimizersopt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]loss_func = torch.nn.MSELoss()losses_his = [[], [], [], []]  # 记录training时不同神经网络的lossfor epoch in range(EPOCH):  # trainingprint('Epoch: ', epoch)for step, (b_x, b_y) in enumerate(loader):  # for each training stepfor net, opt, l_his in zip(nets, optimizers, losses_his):output = net(b_x)  # get output for every netloss = loss_func(output, b_y)  # compute loss for every netopt.zero_grad()  # clear gradients for next trainloss.backward()  # backpropagation, compute gradientsopt.step()  # apply gradientsl_his.append(loss.data.numpy())  # loss recoderlabels = ['SGD', 'Momentum', 'RMSprop', 'Adam']for i, l_his in enumerate(losses_his):plt.plot(l_his, label=labels[i])plt.legend(loc='best')plt.xlabel('Steps')plt.ylabel('Loss')plt.ylim((0, 0.2))plt.show()

SGD是最普通的优化器,也可以说没有加速效果;而MomentumSGD的改良版,它加入了动量原则;后面的RMSprop又是Momentum的升级版;而Adam又是RMSprop的升级版。

莫烦Python_优化器相关推荐

  1. 莫烦大大TensorFlow学习笔记(8)----优化器

    一.TensorFlow中的优化器 tf.train.GradientDescentOptimizer:梯度下降算法 tf.train.AdadeltaOptimizer tf.train.Adagr ...

  2. PyTorch: torch.optim 的6种优化器及优化算法介绍

    import torch import torch.nn.functional as F import torch.utils.data as Data import matplotlib.pyplo ...

  3. 莫烦Python--Tensorflow Day2

    以下是自己学习莫烦Python中的笔记 构建自己的神经网络 import tensorflow as tf import numpy as npdef add_layer(inputs, in_siz ...

  4. tensorflow学习笔记-bili莫烦

    bilibili莫烦tensorflow视频教程学习笔记 1.初次使用Tensorflow实现一元线性回归 # 屏蔽警告 import os os.environ['TF_CPP_MIN_LOG_LE ...

  5. 莫烦python--搭建CNN

    和其他教程一样,莫烦大神也用MNIST作为CNN的入门 一. 调用库 import torch import torch.nn as nn import torch.utils.data as Dat ...

  6. 莫烦---Pytorch学习

    今天翻翻资料,发现有些地方的说明不太到位,修改过来了. Will Yip 2020.7.29 莫烦大神Pytorch -->> 学习视频地址 2020年开年就遇上疫情,还不能上学,有够难受 ...

  7. 莫烦python---pytorch学习(上)

    一.推荐学习网站: 莫烦python 二.pytorch学习 1.介绍 PyTorch是一个非常有可能改变深度学习领域前景的Python库. PyTorch是一个基于Python的库,用来提供一个具有 ...

  8. 莫烦Python视频笔记

    背景:打算学习CNN,上一周看了莫烦的Python课程,目前看到了P28 18.3 CNN卷积神经网络 视频链接:https://www.bilibili.com/video/av16001891/? ...

  9. Optimizer优化器

    这节内容主要是对比在 Torch 实践中所会用到的几种优化器 编写伪数据 为了对比各种优化器的效果, 需要有一些数据, 可以自己编一些伪数据, 这批数据是这样的: 具体的数据生成代码如下: impor ...

最新文章

  1. 2022-2028年中国文化创意产业园区域发展模式与产业整体规划研究报告
  2. 重大里程碑!VOLO屠榜CV任务,无需额外数据,首次在ImageNet 上达到87.1%
  3. Atlas, AJAX
  4. [转载]如何用关键字优化网站?
  5. python line strip_Python进阶---python strip() split()函数实战(转)
  6. 卡尔曼滤波的理解、推导和应用
  7. java blob转为图片_导出的图片为什么会糊?!
  8. Android开发工具之DDMS
  9. 第一次学游泳技巧_第一次学游泳小学生作文(精选5篇)
  10. 老路MBA商学课|第003课:沉没成本|因为来都来了,所以将错就错
  11. -1073740791 (0xC0000409)错误,附加内容:qt布局、页面跳转
  12. fafa什么意思_fafafafafa 什么意思
  13. rpc系列-动态代理
  14. TF、keras两种padding方式:vaild和same
  15. RDLC报表 报表数据(参数栏)不显示怎么办?
  16. 你给文字描述,AI艺术作画,精美无比!附源码,快来试试!
  17. Android开发聊天功能
  18. the type xxx cannot be resoved,It is indirectly referenced from required .class files错误.....
  19. 南方的才子北方的将,陕西的黄土埋皇上
  20. Future与CompletableFuture

热门文章

  1. C语言中getch()与getchar()
  2. 笔记:复杂网络的关键技术及应用
  3. Java网络编程(二) 连续发送数据
  4. voip的会议服务器Conference Servers
  5. 读卡器等设备由于串口冲突显示查找不到设备的问题
  6. 如何用 Excel 做出专业的甘特图?详细来了!
  7. 【python爬虫】爬取网站数据,整理三句半语料数据集
  8. 基于 Markdown 与 Git 的知识管理系统
  9. java-UDP协议实现数据的发送和接收
  10. 数据库-SQL索引相关