import matplotlib.pyplot as plt
import numpy as np
import pandas as pdfor_mat = 'Advertising.csv'
advertising = pd.read_csv(for_mat)
advertising.head()
x, y = 0.5, 0.8"""绘制RSS关于w的函数图像"""def plo_t():w = 3pred = x * wrss = ((pred - y) ** 2) / 2grad = (pred - y) * xprint("当w=3时,预测值为" + str(pred))print("当w=3时,残差平方和为" + str(rss))print("当w=3时,RSS(w)的梯度为" + str(grad))w_vec = np.linspace(-1, 4, 100)rss_vec = []for w_tmp in w_vec:rss_tmp = (y - x * w_tmp) ** 2 / 2rss_vec.append(rss_tmp)"""画出残差平方和随着权重变化的曲线当w=3时,画出RSS(w)的斜率"""plt.plot(w_vec, rss_vec)# 画出w=3时对应RSS的散点图plt.scatter(w, rss, s=100, c="y", marker="o")# 通过当w=3时的切线plt.plot(np.linspace(2.5, 3.5, 50), np.linspace(2.5, 3.5, 50) * 0.35 - 0.805, "--", linewidth=2.0)plt.xlabel("w", fontsize=16)plt.ylabel("RSS", fontsize=16)plt.show()# plo_t() #调用绘制函数"""算法3.1  梯度下降算法"""
w = 0
lr = 0.5
pred = x * w
loss = ((pred - y) ** 2) / 2
grad = (pred - y) * x
print("自变量的值:" + str(x))
print("真实因变量:" + str(y))
print("初始权重:" + str(w))
print("初始预测值:" + str(pred))
print("初始误差:" + str(loss))
print("初始梯度:" + str(grad))
"""
定义迭代函数
更新迭代后的预测值,预测误差,梯度
"""def ite_re(w_, lr_, grad_, count):count = int(count)for i in range(1, count):w_ = w_ - lr_ * grad_pred_ = x * w_loss_ = ((pred_ - y) ** 2) / 2grad_ = (pred_ - y) * xprint(f"第{i}次更新后的权重:" + str(w_))print(f"第{i}次更新后的预测值:" + str(pred_))print(f"第{i}次更新后的误差:" + str(loss_))print(f"第{i}次更新后的梯度:" + str(grad_))print("\n\n")# ite_re(w, lr, grad, 20)  # 迭代更新20次
"""算法3.2       使用随机梯度下降法迭代更新w"""def a():"""对自变量矩阵x,因变量向量y对数据进行标准化和中心化得到scaled_x和centered_y"""x_ = advertising.iloc[:, 0:2].valuesy_ = advertising.iloc[:, 3].valuesscaled_x = (x_ - np.mean(x_, axis=0, keepdims=True)) / np.std(x_, axis=0, keepdims=True)centered_y = y_ - np.mean(y_)lr_ = 0.1w_ = np.zeros(2)w_record = [w_.copy()]for item in range(5):total_loss = 0for i in range(len(scaled_x)):pred_ = np.sum(scaled_x[i] * w_)total_loss += ((pred_ - centered_y[i]) ** 2) / 2delta = (pred_ - centered_y[i])w_ -= lr_ * (delta * scaled_x[i])w_record.append(w_.copy())c = total_loss / (i + 1)print(c)print(w_)a()

3.1.2随机梯度下降法相关推荐

  1. 1. 批量梯度下降法BGD 2. 随机梯度下降法SGD 3. 小批量梯度下降法MBGD

    排版也是醉了见原文:http://www.cnblogs.com/maybe2030/p/5089753.html 在应用机器学习算法时,我们通常采用梯度下降法来对采用的算法进行训练.其实,常用的梯度 ...

  2. 【数据挖掘】神经网络 后向传播算法 ( 梯度下降过程 | 梯度方向说明 | 梯度下降原理 | 损失函数 | 损失函数求导 | 批量梯度下降法 | 随机梯度下降法 | 小批量梯度下降法 )

    文章目录 I . 梯度下降 Gradient Descent 简介 ( 梯度下降过程 | 梯度下降方向 ) II . 梯度下降 示例说明 ( 单个参数 ) III . 梯度下降 示例说明 ( 多个参数 ...

  3. 机器学习-算法背后的理论与优化(part7)--随机梯度下降法概述

    学习笔记,仅供参考,有错必究 随机梯度下降法概述 机器学习场景 算法模型和损失函数 一个有监督学习算法或模型实质上是在拟合一个预测函数侧或者称为假设函数,其形式固定但参数 w ∈ R d w \in ...

  4. 梯度下降法和随机梯度下降法

    1. 梯度 在微积分里面,对多元函数的参数求∂偏导数,把求得的各个参数的偏导数以向量的形式写出来,就是梯度.比如函数f(x,y), 分别对x,y求偏导数,求得的梯度向量就是(∂f/∂x, ∂f/∂y) ...

  5. 梯度下降法、随机梯度下降法、批量梯度下降法及牛顿法、拟牛顿法、共轭梯度法

    http://ihoge.cn/2018/GradientDescent.html http://ihoge.cn/2018/newton1.html 引言 李航老师在<统计学习方法>中将 ...

  6. 基于随机梯度下降法的手写数字识别、epoch是什么、python实现

    基于随机梯度下降法的手写数字识别.epoch是什么.python实现 一.普通的随机梯度下降法的手写数字识别 1.1 学习流程 1.2 二层神经网络类 1.3 使用MNIST数据集进行学习 注:关于什 ...

  7. DistBelief 框架下的并行随机梯度下降法 - Downpour SGD

    本文是读完 Jeffrey Dean, Greg S. Corrado 等人的文章 Large Scale Distributed Deep Networks (2012) 后的一则读书笔记,重点介绍 ...

  8. 【统计学习】随机梯度下降法求解感知机模型

    1. 感知机学习模型 感知机是一个二分类的线性分类问题,求解是使误分类点到超平面距离总和的损失函数最小化问题.采用的是随机梯度下降法,首先任意选取一个超平面w0和b0,然后用梯度下降法不断地极小化目标 ...

  9. Python随机梯度下降法(四)【完结篇】

    有了前面知识的铺垫,现在来做一个总结,利用随机梯度下降法来实现MNIST数据集的手写识别,关于MNIST的详细介绍,可以参考我的前面两篇文章 MNIST数据集手写数字识别(一),MNIST数据集手写数 ...

  10. 随机梯度下降法(stochastic gradient descent,SGD)

    梯度下降法 大多数机器学习或者深度学习算法都涉及某种形式的优化. 优化指的是改变 特征x以最小化或最大化某个函数 f(x)  的任务. 我们通常以最小化 f(x) 指代大多数最优化问题. 最大化可经由 ...

最新文章

  1. 5.1 计算机网络之传输层(传输层提供的服务及功能概述、端口、套接字--Socket、无连接UDP和面向连接TCP服务)
  2. 零基础学Python(第二十二章 常用内置函数)
  3. DevExpress WPF v18.2新版亮点(四)
  4. 电脑反应慢卡怎么解决_电脑键盘失灵怎么解决
  5. 设计模式学习笔记——原型(Prototype)框架
  6. iPhone 14 Pro系列传出好消息,有望实现8GB内存自由
  7. PHP-获取文件后缀名,并判断是否合法
  8. 朋友圈如何测试(思维导图)
  9. Java集合11 (Queue)
  10. java 二进制右移位_(九)二进制、位运算、位移运算符
  11. win2008php一键,WIN2008 一键安装PHP环境PHP5.3+FastCGI
  12. 手机App测试的相关测试点-简单总结
  13. 云游戏的架构设计和技术实现
  14. Word 技术篇-文档中不同级别标题自动重新编号设置方法,论文多级编号演示
  15. Noip2011 Day1 T1 铺地毯(模拟)
  16. dpdk:vfio-pci模式下iommu(N+Y)-Huge配置-numa配置
  17. 各国时区夏令时切换信息
  18. win10搜索框突然不能使用了
  19. 如何构建自我的认知系统
  20. 最短路算法 :Bellman-ford算法 Dijkstra算法 floyd算法 SPFA算法 详解

热门文章

  1. jquery仿直播app按钮点赞特效
  2. 设计模式--依赖倒转原则
  3. 身份证号的每位数字的意义
  4. 二叉树-求叶节点个数
  5. Redis【10】-Redis发布订阅
  6. 一文学会CentOS 文件常用命令
  7. 【转】DotNetNuke常用扩展模块
  8. 余光中《写给未来的你》
  9. 怎样在微信公众平台发文件?
  10. 【罗开传奇】传奇服务端调整人物属性脚本命令ChangeHumAbility