文章目录

  • 简介
  • 原理
  • 代码
  • 过拟合

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站。

简介


多项式回归(Polynomial Regression)顾名思义是包含多个自变量的回归算法,也叫多元线性回归,多数时候利用一元线性回归(一条直线)不能很好拟合数据时,就需要用曲线,而多项式回归就是求解这条曲线。

也就是说一元回归方程是y=wx+by=wx+by=wx+b
而多元回归方程是y=wnxn+wn−1xn−1+⋅⋅⋅+w1x+w0y=w_nx^n+w_{n-1}x^{n-1}+···+w_1x+w_0y=wn​xn+wn−1​xn−1+⋅⋅⋅+w1​x+w0​
比如二元就是y=ax2+bx+cy=ax^2+bx+cy=ax2+bx+c,三元就是y=ax3+bx2+cx+dy=ax^3+bx^2+cx+dy=ax3+bx2+cx+d
但是并不是元数越多越好,可能存在过拟合问题,在最后一节介绍。

一元线性回归可参考另一篇博客:回归-线性回归算法(房价预测项目)

原理


多元线性回归很复杂,特别是当特征数多元数多的时候,可视化难以想象。
用向量矩阵的来表达:y=xw\bold y=\bold x\bold wy=xw
x=(1x1x12⋯x1n1x2x22⋯x2n⋮⋮⋮⋱⋮1xkxk2⋯xkn)x=\begin{pmatrix}\begin{array}{ccccc}1& x_1 & x_1^2 &\cdots& x_1^n\\1& x_2 & x_2^2 &\cdots& x_2^n\\ \vdots & \vdots & \vdots&\ddots & \vdots\\ 1&x_k&x_k^2&\cdots&x_k^n \end{array}\end{pmatrix}x=⎝⎜⎜⎜⎛​11⋮1​x1​x2​⋮xk​​x12​x22​⋮xk2​​⋯⋯⋱⋯​x1n​x2n​⋮xkn​​​⎠⎟⎟⎟⎞​,w=(w01w02⋯w0kw11w12⋯w1k⋮⋮⋱⋮wn1wn2⋯wnk)\bold w=\begin{pmatrix}\begin{array}{cccc}w_{01} & w_{02} &\cdots& w_{0k}\\w_{11} & w_{12} &\cdots& w_{1k}\\ \vdots & \vdots&\ddots & \vdots\\ w_{n1}&w_{n2}&\cdots&w_{nk} \end{array}\end{pmatrix}w=⎝⎜⎜⎜⎛​w01​w11​⋮wn1​​w02​w12​⋮wn2​​⋯⋯⋱⋯​w0k​w1k​⋮wnk​​​⎠⎟⎟⎟⎞​

比如一个特征量二元回归方程:k=1k=1k=1,n=2n=2n=2:
y=(1x1x12)(w01w11w21)=w01+w11x1+w21x12y=(\begin{array}{ccc}1&x_1&x_1^2 \end{array})\begin{pmatrix} w_{01}\\ w_{11}\\w_{21}\end{pmatrix}=w_{01}+w_{11}x_1+w_{21}x_1^2y=(1​x1​​x12​​)⎝⎛​w01​w11​w21​​⎠⎞​=w01​+w11​x1​+w21​x12​

再如两个特征量二元回归方程:k=2k=2k=2,n=2n=2n=2:
y=x=(1x1x121x2x22)(w01w02w11w12w21w22)=(w01+w11x1+w21x12w02+w12x1+w22x12w01+w11x2+w21x22w02+w12x2+w22x22)y= x=\begin{pmatrix}\begin{array}{ccc}1& x_1 & x_1^2\\1& x_2 & x_2^2 \end{array}\end{pmatrix}\begin{pmatrix} w_{01}&w_{02}\\ w_{11}&w_{12}\\w_{21}&w_{22}\end{pmatrix}=\begin{pmatrix}\begin{array}{cc}w_{01}+w_{11}x_1+w_{21}x_1^2 & w_{02}+w_{12}x_1+w_{22}x_1^2\\w_{01}+w_{11}x_2+w_{21}x_2^2& w_{02}+w_{12}x_2+w_{22}x_2^2 \end{array}\end{pmatrix}y=x=(11​x1​x2​​x12​x22​​​)⎝⎛​w01​w11​w21​​w02​w12​w22​​⎠⎞​=(w01​+w11​x1​+w21​x12​w01​+w11​x2​+w21​x22​​w02​+w12​x1​+w22​x12​w02​+w12​x2​+w22​x22​​​)

∣y∣=w01+w11x1+w21x12+w02+w12x2+w22x22|y|=w_{01}+w_{11}x_1+w_{21}x_1^2+w_{02}+w_{12}x_2+w_{22}x_2^2∣y∣=w01​+w11​x1​+w21​x12​+w02​+w12​x2​+w22​x22​

可以看出计算量其实是很大的。

使用最小二乘法作为损失函数,并选择优化算法:正规方程或梯度下降。
正规方程:w=(xTx)−1xTy\bold{w}=(\bold{x}^T\bold{x})^{-1}\bold{x}^T\bold{y}w=(xTx)−1xTy
如果(xTx)(\bold{x}^T\bold{x})(xTx)不可逆,则使用梯度下降求解即可。
可参考:浅谈梯度下降与模拟退火算法

代码


多元线性回归与一元线性回归其实只是x\bold xx的维度不同,也就是说通过设置x\bold xx的维度,调用线性模型LinearRegression即可进行求解,即对数据进行预处理,需要几次方即升到几维,如下2种方法。

  1. 使用hstack()hstack()叠加
    如果维数低,我们可以手动添加即可。
import numpy as npa = np.array([[1, 2], [3, 4]])
b = np.array([[5, 6], [7, 8]])
print("水平方向叠加:", np.hstack((a, b)))
print("垂直方向叠加:", np.vstack((a, b)))

  1. 使用PolynomialFeatures()对特征预处理
    如果维度多,可以用该函数计算生成x\bold xx。
    包括参数:
    degree:默认2,多项式特征的次数;
    interaction_only:默认default=False,若为True,则不含自己和自己相结合的特征项;
    include_bias:默认True,若为True,则包含一列为1的偏差项;
    order:默认‘C’,若为’F’则计算更快,但是后续的拟合慢。
    包括属性:
    powers_:n维幂运算数组,根据degree的值而确定行,根据属性个数而确定列。
    n_input_features_:输入特征的总数,即幂运算矩阵的列;
    n_output_features_:输出特征的总数,即幂运算矩阵的行。
import numpy as np
from sklearn.preprocessing import PolynomialFeaturesx = np.arange(6).reshape(3, 2)
print(x)
poly = PolynomialFeatures(degree=2)
poly.fit(x)
print(poly.powers_)
print("输入特征:", poly.n_input_features_)
print("输出特征:", poly.n_output_features_)
x = poly.transform(x)
print(x)

(插播反爬信息 )博主CSDN地址:https://wzlodq.blog.csdn.net/

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures# 生成数据
X = np.linspace(-5, 5, 100)
Y = 2 * X ** 2 + 3 * X + 5 + np.random.randn(100) * 5
x = X.reshape(-1, 1)# 数据预处理
# 法一、使用hstack直接添加x方
x1 = np.hstack([x, x ** 2])# 法二、使用PolynomialFeatures计算二次方
poly = PolynomialFeatures()
poly.fit(x)
x2 = poly.transform(x)model1 = LinearRegression()  # 创建模型1
model1.fit(x1, Y)  # 训练模型1
y_pred1 = model1.predict(x1)  # 测试1
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)  # 可视化1
plt.scatter(X, Y)
plt.plot(x, y_pred1, color='red')
plt.title("使用hstack()")model2 = LinearRegression()  # 创建模型2
model2.fit(x2, Y)  # 训练模型2
y_pred2 = model2.predict(x2)  # 测试2
plt.subplot(1, 2, 2)  # 可视化2
plt.scatter(X, Y)
plt.plot(x, y_pred2, color='gold')
plt.title("使用PolynomialFeatures()")plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.show()

过拟合


正如前面所说的一样,多项式的幂次并不是越高越好,过高可能出现过拟合情况,导致泛化能力低,过低可能出现欠拟合情况,导致预测结果差,如下图所示。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures# 生成数据
np.random.seed(20221005)
X = np.linspace(-np.pi, np.pi, 100)
Y = np.sin(X) + np.random.randn(100) * 0.4
x = X.reshape(-1, 1)
plt.scatter(x, Y, color="lightblue")for idx, degree in enumerate([1, 3, 30, 100]):print(degree)poly = PolynomialFeatures(degree=degree)poly.fit(x)x1 = poly.transform(x)model = LinearRegression()model.fit(x1, Y)y_pred = model.predict(x1)plt.plot(x, y_pred, label=("%d次方" % degree))plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.legend()
plt.show()

可见并不是幂次越高越好,一般遵循“奥卡姆剃刀”定律,也就是简单平滑的曲线即可。当然了,也有很多方法度量和避免欠拟合与过拟合。

原创不易,请勿转载(本不富裕的访问量雪上加霜 )
博主首页:https://wzlodq.blog.csdn.net/
来都来了,不评论两句吗

机器学习-多项式回归算法相关推荐

  1. 机器学习-多项式回归、正规方程(标准方程)

    机器学习-多项式回归 多项式回归 线性回归并不适用于所有数据,因此有时我们会选择使用曲线来回归数据. 也就是我们在选择特征的时候,可以选择多个特征来进行回归. 如图所示,二次方模型最后会一直向下,很显 ...

  2. 机器学习与算法面试太难?

    机器学习与算法面试太难? 来源: https://mp.weixin.qq.com/s/GrkCvU2Ia_mEaQmiffLotQ 作者:石晓文 八月参加了一些提前批的面试,包括阿里.百度.头条.贝 ...

  3. 免费技术直播:唐宇迪带你一节课了解机器学习经典算法

    常常有小伙伴在后台反馈:机器学习经典算法有哪些? 自学难度大又没有效果,该怎么办? CSDN为了解决这个难题,联合唐宇迪老师为大家带来了一场精彩的直播[一节课掌握机器学习经典算法-线性回归模型].本次 ...

  4. 15分钟带你入门sklearn与机器学习——分类算法篇

    作者 | 何从庆 本文转载自AI算法之心(ID:AIHeartForYou) [导读]众所周知,Scikit-learn(以前称为scikits.learn)是一个用于Python编程语言的免费软件机 ...

  5. 阿里资深AI工程师教你逐个击破机器学习核心算法

    01 近年来,随着 Google 的 AlphaGo 打败韩国围棋棋手李世乭之后,机器学习尤其是深度学习的热潮席卷了整个 IT 界. 所有的互联网公司,尤其是 Google 微软,百度,腾讯等巨头,无 ...

  6. 调包侠福音!机器学习经典算法开源教程(附参数详解及代码实现)

    Datawhale 作者:赵楠.杨开漠.谢文昕.张雨 寄语:本文针对5大机器学习经典算法,梳理了其模型.策略和求解等方面的内容,同时给出了其对应sklearn的参数详解和代码实现,帮助学习者入门和巩固 ...

  7. 机器学习Top10算法,教你选择最合适的那一个!

    本文经AI新媒体量子位(公众号ID:qbitai )授权转载,转载请联系出处 本文共3800字,建议阅读6分钟. 选什么算法?本文为你梳理TOP10机器学习算法特点. 在机器学习领域里,不存在一种万能 ...

  8. 来!一起捋一捋机器学习分类算法

    点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 来自:算法与数学之美 可是,你能够如数家珍地说出所有常用的分类算法,以及他们的特征.优 ...

  9. 【机器学习】机器学习Top10算法,教你选择最合适的那一个!一文读懂ML中的解析解与数值解...

    在机器学习领域里,不存在一种万能的算法可以完美解决所有问题,尤其是像预测建模的监督学习里. 比方说,神经网络不见得比决策树好,同样反过来也不成立. 最后的结果是有很多因素在起作用的,比方说数据集的大小 ...

最新文章

  1. 盘点程序员写过的惊天Bug:亏损30亿、致6人死亡,甚至差点毁灭世界
  2. hm55主板支持最大内存_内存频率取决于CPU还是主板?内存频率看主板支持还是看CPU支持?...
  3. 关系数据库范式粗略理解
  4. 【控制】《多智能体系统的动力学分析与设计》徐光辉老师-第8章-有输入时滞的二阶多智能体系统的多一致
  5. 2.3.6 生产者消费者问题
  6. 工具坐标6点法_轻松学机器人系列之各坐标系关系
  7. MATLAB中的wavedec、wrcoef函数简析
  8. html 头尾代码自动,HTML Head Generator - 纯 CSS 实现的头部元标签代码生成器 - 钉子の次元...
  9. LeetCode第617题:合并二叉树
  10. 使用http请求发送文件,文件标题乱码
  11. Oracle从入门到精通
  12. 写给非网工的CCNA教程(6)VLAN和802.1q协议
  13. 计算机二级用的ms什么版本,计算机二级ms office用的哪个版本
  14. 3Dmax哪个版本最好用?3dmax哪个版本稳定一点?
  15. html5shiv 无效,解决低版本IE关于html5新特性的兼容性问题html5shiv.js和Respond.js
  16. 如何查看当前域名的注册信息?
  17. 收发一体超声波测距离传感器模块_咸阳KUS3000 超声波额液位物位计
  18. 上班聊天,摸鱼神器,手写一款即时通讯工具(附源码!!!)
  19. 我和王争学设计模式|原型模式
  20. 开关电源PCB走线的时候需要注意什么?

热门文章

  1. bottle 文件服务器,Python库glances和Bottle完成服务器交互式动态监控
  2. Windows ISO镜像资源专用下载工具(Windows ISO Downloader) v4.0 绿色免费版
  3. 采用sql存储的方法保存所爬取的豆瓣电影
  4. Min-Max Max-Min problem algorithm and analysis
  5. QT QMap QMultiMap使用说明
  6. 过敏性鼻炎给宝宝带来哪些危害?
  7. Jetson Nano调试记录:机电设备控制
  8. Netron神经网络结构可视化只显示权重没有箭头,已解决
  9. 危机先知:TOOM舆情监控助力风险预警
  10. Execl根据一列分组,找出另外一列的最大值