一、背景

在廖星宇《深度学习入门》的github项目中,留了一道思考题:

小练习:上面的例子是一个三次的多项式,尝试使用二次的多项式去拟合它,看看最后能做到多好
提示:参数 w = torch.randn(2, 1),同时重新构建 x 数据集

二、代码部分

在项目中没有给出代码,作者自己改写了一个,作了大概如下改动:

  • 改动了criterion,选用了自带的MSE
  • 使用了nn.Sequential并强其放到了cuda上,试图用GPU增加运算效率
  • 训练100000000次
  • 多画几个窗口,好对比一下
  • 代码部分没有精校,可能存在一些错误
  • 也许criterion是否有更好的选择
  • 代码在我本地没有问题啊!(手动狗头
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 27 17:36:24 2020@author: ftimes
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nnSEED=2020
torch.manual_seed(SEED)
DEVICE=torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义一个多变量函数
w_target = np.array([0.5, 3, 2.4]) # 定义参数
b_target = np.array([0.9]) # 定义参数f_des = 'y = {:.2f} + {:.2f} * x + {:.2f} * x^2 + {:.2f} * x^3'.format(b_target[0], w_target[0], w_target[1], w_target[2]) # 打印出函数的式子
print(f_des)x_sample = np.arange(-3, 3.1, 0.1)
y_sample = b_target[0] + w_target[0] * x_sample + w_target[1] * x_sample ** 2 + w_target[2] * x_sample ** 3plt.plot(x_sample, y_sample ,label='real curve')
plt.legend(loc='best')# 构建数据 x 和 y
# x 是一个如下矩阵 [x, x^2, x^3]
# y 是函数的结果 [y]x_train = np.stack([x_sample ** i for i in range(1, 3)], axis=1)
x_train = torch.tensor(x_train).float().to(DEVICE) # 转换成 float tensor
y_train = torch.tensor(y_sample).float().unsqueeze(1).to(DEVICE) # 转化成 float tensorseq_net = nn.Sequential(nn.Linear(2, 1), # PyTorch 中的线性层,wx + b
).to(DEVICE)'''
w = nn.Parameter(torch.randn(2, 1)*0.01).to(DEVICE)
b = nn.Parameter(torch.zeros(1)).to(DEVICE)
'''optimizer=torch.optim.SGD(seq_net.parameters(),0.000001)
criterion=nn.MSELoss().to(DEVICE)y_pred = seq_net(x_train)plt.figure(2)
#plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')
plt.plot(x_train.data.cpu().numpy()[:, 0], y_pred.data.cpu().numpy(), label='fitting curve', color='r')
plt.plot(x_train.data.cpu().numpy()[:, 0], y_sample, label='real curve', color='b')
plt.legend()EPOCH=1000000
for e in range(EPOCH):y_pred = seq_net(x_train)loss=criterion(y_pred,y_train)optimizer.zero_grad()loss.backward()optimizer.step()if (e + 1) % 100000 == 0:print('epoch: {}, loss: {}'.format(e+1, loss.data))y_pred =seq_net(x_train)plt.figure(3)
#plt.xlim(1, 2)
#plt.ylim(0, 20)
plt.plot(x_train.data.cpu().numpy()[:, 0], y_pred.data.cpu().numpy(), label='fitting curve', color='r')
plt.plot(x_train.data.cpu().numpy()[:, 0], y_sample, label='real curve', color='b')
plt.legend()print(seq_net[0].weight,seq_net[0].bias)
print(f_des)

三、测试结果

epoch: 100000, loss: 279.2529296875
epoch: 200000, loss: 155.79791259765625
epoch: 300000, loss: 120.1278305053711
epoch: 400000, loss: 109.79125213623047
epoch: 500000, loss: 106.81745910644531
epoch: 600000, loss: 105.95108795166016
epoch: 700000, loss: 105.70142364501953
epoch: 800000, loss: 105.61643981933594
epoch: 900000, loss: 105.60639953613281
epoch: 1000000, loss: 105.60639953613281
Parameter containing:
tensor([[13.8103,  3.0138]], device='cuda:0', requires_grad=True) Parameter containing:
tensor([0.8424], device='cuda:0', requires_grad=True)
y = 0.90 + 0.50 * x + 3.00 * x^2 + 2.40 * x^3我们可以看到,在学习率非常小的情况下,这个loss始终维持在了105.58638763427734。
如果是设置成0.03,可以迅速收敛到这个数。
可能是我哪里弄错了?恳请您指正我这个初学者。
但我们可以肯定的是,用二次多项式无法很好的拟合的三次多项式。
下面上图。
  1. 原始图像
    2. 初始图像
    3. 1000000次模拟后的图像

如果您有更好的代码,欢迎在评论区留言,感激不尽~

「Pytorch」用二次多项式拟合三次多项式一百万次,看看最后能做到多好?相关推荐

  1. c++ 曲线拟合的最小二乘法 公式 二次多项式和三次多项式

    struct Hisnum//直方图结构体 进行多项式拟合 {     int gray;     int num; }; struct Hisnum//直方图结构体 进行多项式拟合 {int gra ...

  2. 玩转「Wi-Fi」系列之测试工具(三)

    以前网络有故障,都会打开电脑看看是什么地方出现故障,现在进入移动时代,可能你整个网络里没有一台电脑,那如何用手机发现网络的问题呢? 实际开发过程中,经常会使用一些第三方工具来获网络的相关信息, 介绍如 ...

  3. 程序员用「美貌」突破二维图像的人脸识别

    GitChat 作者:于航 原文: 如何利用"女装术"突破基于二维图像的人脸识别 关注微信公众号:「GitChat 技术杂谈」 一本正经的讲技术 [不要错过文末彩蛋] 首先声明,这 ...

  4. 「FastAdmin」fastadmin二次开发中如何自定义查询数据

    fastadmin二次开发中如何自定义查询数据 问题背景:最近做一个网站的过程中遇到了一个需求:对于不同用户组的用户,显示的数据要根据权限来筛选.问题看起来不是很难,文档和社区中已经给了足够的提示,我 ...

  5. 窗口管理器 实现_「42」Python布局管理器(三):place实现组件的精确与灵活布局...

    已经学习了两类布局管理器: Pack布局管理器:按照垂直或者水平的方向自然排布: Grid布局管理器:采用表格结构组织组件,组件位置受限表格形式. 两类管理器都属于那种很古板的布局方式,不能适应需要相 ...

  6. 使用三次多项式拟合天猫双十一交易额

    前言 据说天猫双十一交易额造假,交易额数据可以用二次或三次多项式完美拟合,看到这个后我觉得可以试一试.那么说干就干.我们用sklearn多项式回归来拟合,只做三次多项式,二次多项式也是一样,只要去掉三 ...

  7. c++ 三次多项式拟合_最小二乘法多项式曲线拟合数学原理及其C++实现

    本文使用 Zhihu On VSCode 创作并发布 0 前言 自动驾驶开发中经常涉及到多项式曲线拟合,本文详细描述了使用最小二乘法进行多项式曲线拟合的数学原理,通过样本集构造范德蒙德矩阵,将一元 N ...

  8. 「机器学习」机器学习算法优缺点对比(汇总篇)

    作者 | 杜博亚 来源 | 阿泽的学习笔记 「本文的目的,是务实.简洁地盘点一番当前机器学习算法」.文中内容结合了个人在查阅资料过程中收集到的前人总结,同时添加了部分自身总结,在这里,依据实际使用中的 ...

  9. 又酸啦!华为「天才少年」校招薪资曝光....

    点击上方蓝色小字,关注"涛哥聊Python" 重磅干货,第一时间送达 转子自[量子位] 西交大本科毕业年薪100万,华科博士毕业年薪201万. 开出如此天价年薪的正是华为. 果然, ...

最新文章

  1. php 删除一周前,linux下删除7天前日志的代码(php+shell)
  2. [转]asp.net文件下载方法...
  3. 10个你必须知道的Python内置函数
  4. elastic 修改map_Amazon Elastic Map Reduce使用Apache Mahout计算建议
  5. LeetCode Largest Number
  6. 基本卷-动态卷性能测评(未完待续)
  7. 使用 Redis 如何实现延迟队列?
  8. Spring-tx-PlatformTransactionManager(DataSourceTransactionManager)
  9. vss2005 配置与使用
  10. js读取txt文件中的内容
  11. 15个开发者最亲睐的Andr​​oid代码编辑器
  12. 游戏音效只是简单的改原素材吗?
  13. redis过期策略有哪些?内存淘汰机制有哪些?
  14. Zookeeper(七)开源客户端
  15. ARM 代码烧录方案与原理详解 --- SWD/JTAG + Bootloader + OTA (ICP + ISP + IAP)
  16. SQL Server 下载和安装详细教程(点赞收藏)
  17. Office 365入门教程(一):开始使用Office 365
  18. Win10卸载OneDrive
  19. PC端页面在手机端完整显示
  20. echarts 桑基图sankey

热门文章

  1. Android 停车地图及停车导航,优先停车导航app
  2. 消防应急疏散指示系统在广场住宅区项目的应用
  3. 大厂 vs 小厂,我的亲身体验
  4. IOS开发者账号的相关配置-子账号(二)
  5. 聊天软件中的窗口上滑和下滑提示上下线
  6. cubeMX 选择管脚引脚警告
  7. android墓碑机制推送,工信部放大招,又一手机品牌加入推送联盟,安卓流畅度有救了...
  8. 人工智能面试题分享(含答案)
  9. mysql 父子维,将有父子关系的一维数组转换成树形结构(多维)数据
  10. ubuntu 下 pycharm 搜狗输入法候选词在左下角问题