y=wi*xi+b,基于最小二乘法的线性回归:寻找参数w和b,使得w和b对x_test_data的预测值y_pred_data与真实的回归目标y_test_data之间的均方误差最小。

公式推导:

基于最小二乘法构造linear_model有5个步骤:
1、导包。

from sklearn import linear_model
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error,r2_score,mean_absolute_error

sklearn中有专门的线性模型包linear_model,numpy用于生成数据,matplotlib用于画图,另外导入MSE,R_Square和MAE三个评价指标。
2、构造数据集。可以自动生成数据,也可以寻找现有数据,以下数据是作业中的数据,样本数据只有一个特征。
3、训练模型。
4、输出系数w和截距b并对测试集进行预测。
5、作图。

完整代码:

import pandas as pd
import matplotlib.pyplot as plt
from sklearn import linear_model
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_errordef load_data():data = pd.read_csv('Salary_Data.csv', encoding='gbk')data = data.values.tolist()train_x = []train_y = []test_x = []test_y = []# 前一半作为训练集,后一半作为测试集for i in range(len(data)):if i < len(data) / 2:train_x.append(data[i][0])train_y.append(data[i][1])else:test_x.append(data[i][0])test_y.append(data[i][1])return train_x, train_y, test_x, test_ydef model():print('手写:')train_x, train_y, test_x, test_y = load_data()# 最小二乘法得到参数sum = 0.0sum_square = 0.0sum_2 = 0.0sum_b = 0.0for i in range(len(train_x)):sum = sum + train_x[i]sum_square = sum_square + train_x[i] ** 2ave_x = sum / len(train_x)for i in range(len(train_x)):sum_2 = sum_2 + (train_y[i] * (train_x[i] - ave_x))w = sum_2 / (sum_square - sum ** 2 / len(train_x))for i in range(len(train_x)):sum_b = sum_b + (train_y[i] - w * train_x[i])b = sum_b / len(train_x)print('w=', w, 'b=', b)# 测试pred_y = []for i in range(len(test_x)):pred_y.append(w * test_x[i] + b)# 计算MSE,MAE,r2_scoresum_mse = 0.0sum_mae = 0.0sum1 = 0.0sum2 = 0.0for i in range(len(pred_y)):sum_mae = sum_mae + np.abs(pred_y[i] - test_y[i])sum_mse = sum_mse + (pred_y[i] - test_y[i]) ** 2sum_y = 0.0for i in range(len(test_y)):sum_y = sum_y + test_y[i]ave_y = sum_y / len(test_y)for i in range(len(pred_y)):sum1 = sum1 + (pred_y[i] - test_y[i]) ** 2sum2 = sum2 + (ave_y - test_y[i]) ** 2print('MSE:', sum_mse / len(pred_y))print('MAE:', sum_mae / len(pred_y))print('R2_Squared:', 1 - sum1 / sum2)# 显示plt.scatter(test_x, test_y, color='black')plt.plot(test_x, pred_y, color='blue', linewidth=3)plt.show()print('\n')# 调包
def sklearn_linearmodel():print('调包:')train_x, train_y, test_x, test_y = load_data()train_x = np.array(train_x).reshape(-1, 1)train_y = np.array(train_y).reshape(-1, 1)test_x = np.array(test_x).reshape(-1, 1)test_y = np.array(test_y).reshape(-1, 1)# 训练+测试lr = linear_model.LinearRegression()lr.fit(train_x, train_y)y_pred = lr.predict(test_x)# 输出系数和截距print('w:', lr.coef_, 'b:', lr.intercept_)# 输出评价指标print('MSE:', mean_squared_error(test_y, y_pred))print('MAE:', mean_absolute_error(test_y, y_pred))print('R2_Squared:', r2_score(test_y, y_pred))# 显示plt.scatter(test_x, test_y, color='black')plt.plot(test_x, y_pred, color='blue', linewidth=3)plt.show()if __name__ == '__main__':model()sklearn_linearmodel()

机器学习之linear_model(普通最小二乘法手写+sklearn实现+评价指标)相关推荐

  1. 【机器学习与算法】python手写算法:Cart树

    [机器学习与算法]python手写算法:Cart树 背景 代码 输出示例 背景 Cart树算法原理即遍历每个变量的每个分裂节点,找到增益(gini或entropy)最大的分裂节点进行二叉分割. 这里只 ...

  2. 机器学习入门-kNN算法实现手写数字识别

    实验环境 Python:3.7.0 Anconda:3-5.3.1 64位 操作系统:win10 开发工具:sublime text(非必要) 简介 本次实验中的重点为采用kNN算法进行手写数字识别, ...

  3. matlab朴素贝叶斯手写数字识别_机器学习系列四:MNIST 手写数字识别

    4. MNIST 手写数字识别 机器学习中另外一个相当经典的例子就是MNIST的手写数字学习.通过海量标定过的手写数字训练,可以让计算机认得0~9的手写数字.相关的实现方法和论文也很多,我们这一篇教程 ...

  4. 吴恩达机器学习 逻辑回归 作业3(手写数字分类) Python实现 代码详细解释

    整个项目的github:https://github.com/RobinLuoNanjing/MachineLearning_Ng_Python 里面可以下载进行代码实现的数据集 题目介绍: In t ...

  5. Python3入门机器学习经典算法与应用——手写knn模块

    文章目录 手写knn模块 kNN.py metrics.py model_selection.py 手写knn模块 `-- playML|-- __init__.py|-- kNN.py|-- met ...

  6. 机器学习入门-用KNN实现手写数字图片识别(包含自己图片转化)

    Python实现KNN手写数字图片识别 1.数据集格式 2.把自己图片转化为数据集格式(把宽高是32像素x32像素的黑白图像转换为文本格式) 3.用数据集实现 4.运行结果 4.代码下载地址 KNN是 ...

  7. python识别手写数字knn_机器学习-kNN实现简单的手写数字识别系统

    功能 利用k-邻近算法,实现识别数字0到9 开发环境Mac Python3.5(Anaconda) PIL numpy 数据集和项目源代码 数据集 下面是32*32的黑白图像 32* 32像素数据集 ...

  8. 机器学习实战 k-近邻算法 手写识别系统

    转载于:https://www.cnblogs.com/crysa/p/8735556.html

  9. (机器学习实战)2.3手写识别系统(详细注释)

    编译:python3.6 代码和训练集下载:https://pan.baidu.com/s/1m7HdAkuwGgXX8v5-DN118Q import operator from numpy imp ...

最新文章

  1. 怎么做合格的首席信息主管CIO?
  2. 逃计算机课检讨书600字,检讨书600字3篇
  3. oppo n1t android 版本,OPPO N1的手机系统是什么?OPPO N1能升级安卓4.3吗?
  4. 牛顿法 Newton Method
  5. python跟易语言的爬虫_新人Python,第一只爬虫,,我就只会re.findall,你咬我?
  6. 使用ArcGIS Server发布我们的数据
  7. mysql视图中调用函数写法_从视图中调用函数
  8. 从质疑到成为必选项,低代码技术发展及 2022 展望
  9. Centos7 使用Docker 部署Tomca+mysql+调试联通_02
  10. 机器学习 - 随机森林手动10 折交叉验证
  11. 2016年下半年信息安全工程师考试真题含答案(下午题)
  12. 在maven中做ssm整合
  13. sensenet的编译调试
  14. 自媒体短视频采集工具,采集多个平台的视频
  15. LNMP(nginx php-fpm mysql) 环境部署——php
  16. 《统计学基于R》第一章 数据与R
  17. 基于树莓派的手势识别Oled屏幕显示
  18. 1_一些文献中的英文解释和用法整理
  19. 讯飞语音识别之语音转文字------java
  20. SSH The authenticity of host can’t be established Are you sure you want to continue connecting

热门文章

  1. YDOOK:ESP8266: 官方AT固件下载 WiFi 开发固件下载
  2. 如何提高网络推广员自身的素养?
  3. go每日新闻--2020-05-12
  4. js深拷贝和浅拷贝的区别
  5. excel条件求和技巧:应用SUMIF函数计算客户余款
  6. Spring之bean对象
  7. 消费升级还是消费降级?别纠结了。
  8. python程序设计教程杨年华_Python程序设计教程
  9. 消费者组织称iPhone电池寿命被苹果高估达51%
  10. 数据结构-队列(Queue)