1. 学习目的

使用Logistics Regression进行数据分类。

2. 学习要求

  • 学习LR学习算法的核心代码
  • 写出详细的注释说明

3. 代码实践

(1)创建数据

def generate_data(seed):np.random.seed(seed)# class1data_size_1 = 300# feature1x1_1 = np.random.normal(loc=5.0, scale=1.0, size=data_size_1)# feature2x2_1 = np.random.normal(loc=4.0, scale=1.0, size=data_size_1)y_1 = [0 for _ in range(data_size_1)]# class2data_size_2 = 400# feature1x1_2 = np.random.normal(loc=5.0, scale=2.0, size=data_size_2)# feature2x2_2 = np.random.normal(loc=4.0, scale=2.0, size=data_size_2)y_2 = [1 for _ in range(data_size_2)]# concatenatex1 = np.concatenate((x1_1, x1_2), axis=0)x2 = np.concatenate((x2_1, x2_2), axis=0)# 合成为一个整的数据集,变为二维矩阵x = np.hstack((x1.reshape(-1,1), x2.reshape(-1,1)))y = np.concatenate((y_1, y_2), axis=0)# 总的数据大小data_size_all = data_size_1 + data_size_2# 打乱数据shuffled_index = np.random.permutation(data_size_all)x = x[shuffled_index]y = y[shuffled_index]return x, y

(2)分割数据

其中80%数据用于训练,20%数据用于测试,由于数据量小,不设置验证数据集

# 数据分割,由于数据量小,不设置验证数据集
def data_split(x_data, y_data):# 80%数据用于训练train_split = int(len(y_data) * 0.8)x_train = x[:train_split]y_train = y[:train_split]# 20%数据用于测试x_test = x[train_split:]y_test = y[train_split:]return x_train, y_train, x_test, y_test

(3)构建模型

根据上文:李宏毅机器学习(四),以及大佬:王佳旭同学代码。

# Logistic Regression模型
class LogisticRegression():''':param lr: 学习率:param num_iters: 更新轮数:param seed: 随机数种子'''def __init__(self, lr=0.1, num_iters=100, seed=None):self.seed = seedself.lr = lrself.num_iters = num_itersdef fit(self, x, y):np.random.seed(self.seed)# 参数初始化w bself.w = np.random.normal(loc=0.0, scale=1.0, size=x.shape[1])self.b = np.random.normal(loc=0.0, scale=1.0)# 数据集self.x = xself.y = y# 迭代更新for i in range(self.num_iters):self._update_step()# sigmod处理def _sigmoid(self, z):return 1.0 / (1.0 + np.exp(-z))# 函数模型 w*x + b,经过SIGMOD处理def _f(self, x, w, b):z = x.dot(w) + breturn self._sigmoid(z)# 初次预测算出概率        def predict_proba(self, x=None):if x is None:x = self.xy_pred = self._f(x, self.w, self.b)return y_pred# 再预测,根据概率分类def predict(self, x=None):if x is None:x = self.xy_pred_proba = self._f(x, self.w, self.b)y_pred = np.array([0 if y_pred_proba[i] < 0.5 else 1 for i in range(len(y_pred_proba))])return y_pred# 为分类进行评分def score(self, y_true=None, y_pred=None):if y_true is None or y_pred is None:y_true = self.yy_pred = self.predict()# 计算准确率            acc = np.mean([1 if y_true[i] == y_pred[i] else 0 for i in range(len(y_true))])return acc# 损失函数def loss(self, y_true=None, y_pred_proba=None):if y_true is None or y_pred_proba is None:y_true = self.yy_pred_proba = self.predict_proba()return np.mean(-1.0 * (y_true * np.log(y_pred_proba) + (1.0 - y_true) * np.log(1.0 - y_pred_proba)))# 梯度下降def gradient_descent(self):y_pred = self.predict()d_w = (y_pred - self.y).dot(self.x) / len(self.y)d_b = np.mean(y_pred - self.y)self.w = self.w - self.lr * d_wself.b = self.b - self.lr * d_breturn self.w, self.b

(4)训练生成结果

import matplotlib.pyplot as pltdef main():# 生成数据x, y = generate_data(seed = 514)x_train, y_train, x_test, y_test = data_split(x, y)# 数据归一化x_train = (x_train - np.min(x_train, axis=0)) / (np.max(x_train, axis=0) - np.min(x_train, axis=0))x_test = (x_test - np.min(x_test, axis=0)) / (np.max(x_test, axis=0) - np.min(x_test, axis=0))# 逻辑斯蒂回归分类器clf = LogisticRegression(lr=0.1, num_iters=500, seed=514)clf.fit(x_train, y_train)# 结果可视化split_boundary_func = lambda x: (-clf.b - clf.w[0] * x) / clf.w[1]xx = np.arange(0.1, 0.6, 0.1)cValue = ['g','b'] plt.scatter(x_train[:,0], x_train[:,1], c=[cValue[i] for i in y_train], marker='o')plt.plot(xx, split_boundary_func(xx), c='red')plt.show()# 测试数据集上的损失y_test_pred = clf.predict(x_test)y_test_pred_proba = clf.predict_proba(x_test)print(clf.score(y_test, y_test_pred))print(clf.loss(y_test, y_test_pred_proba))if __name__ == '__main__':main()

本人在代码方面还是有所欠缺,对numpy、matplotlib的使用不熟悉。感谢王同学提供的代码。

李宏毅机器学习-代码实践相关推荐

  1. 【机器学习基础】(六):通俗易懂无监督学习K-Means聚类算法及代码实践

    K-Means是一种无监督学习方法,用于将无标签的数据集进行聚类.其中K指集群的数量,Means表示寻找集群中心点的手段. 一. 无监督学习 K-Means 贴标签是需要花钱的. 所以人们研究处理无标 ...

  2. 【机器学习基础】通俗易懂无监督学习K-Means聚类算法及代码实践

    K-Means是一种无监督学习方法,用于将无标签的数据集进行聚类.其中K指集群的数量,Means表示寻找集群中心点的手段. 一. 无监督学习 K-Means 贴标签是需要花钱的. 所以人们研究处理无标 ...

  3. 【机器学习基础】(五):通俗易懂决策树与随机森林及代码实践

    与SVM一样,决策树是通用的机器学习算法.随机森林,顾名思义,将决策树分类器集成到一起就形成了更强大的机器学习算法.它们都是很基础但很强大的机器学习工具,虽然我们现在有更先进的算法工具来训练模型,但决 ...

  4. 【机器学习基础】(四):通俗理解支持向量机SVM及代码实践

    上一篇文章我们介绍了使用逻辑回归来处理分类问题,本文我们讲一个更强大的分类模型.本文依旧侧重代码实践,你会发现我们解决问题的手段越来越丰富,问题处理起来越来越简单. 支持向量机(Support Vec ...

  5. 【机器学习基础】(三):理解逻辑回归及二分类、多分类代码实践

    本文是机器学习系列的第三篇,算上前置机器学习系列是第八篇.本文的概念相对简单,主要侧重于代码实践. 上一篇文章说到,我们可以用线性回归做预测,但显然现实生活中不止有预测的问题还有分类的问题.我们可以从 ...

  6. 视频+笔记+能够跑通的代码,《李宏毅机器学习完整笔记》发布!

    点击我爱计算机视觉标星,更快获取CVML新技术 [导读]关于机器学习的学习资料从经典书籍.免费公开课到开源项目应有尽有,可谓是太丰富啦,给学习者提供了极大的便利.但网上比比皆是的学习资料大部分都是英文 ...

  7. 机器学习:理解逻辑回归及二分类、多分类代码实践

    作者 | caiyongji   责编 | 张红月 来源 | 转载自 caiyongji(ID:cai-yong-ji) 本文的概念相对简单,主要侧重于代码实践.现实生活中不止有预测的问题还有分类的问 ...

  8. 机器学习(三):理解逻辑回归及二分类、多分类代码实践

    本文是机器学习系列的第三篇,算上前置机器学习系列是第八篇.本文的概念相对简单,主要侧重于代码实践. 上一篇文章说到,我们可以用线性回归做预测,但显然现实生活中不止有预测的问题还有分类的问题.我们可以从 ...

  9. 视频教程-机器学习之聚类、主成分分析理论与代码实践-机器学习

    机器学习之聚类.主成分分析理论与代码实践 干过开发,做到资深Java软件开发工程师,后做过培训,总共培训近千人.目前在高校工作,博士学位.主要研究领域为机器学习与深度学习. 纪佳琪 ¥68.00 立即 ...

最新文章

  1. UNIX中的文件控制--fcntl()
  2. 「LOJ 2289」「THUWC 2017」在美妙的数学王国中畅游——LCT泰勒展开
  3. Java B2B2C o2o多用户商城 springcloud架构 (六)分布式配置中心(Spring Cloud Config)
  4. js中的window.onload和jquery中的load区别的讲解
  5. 【图像处理】——Python+opencv实现提取图像的几何特征(面积、周长、细长度、区间占空比、重心、不变矩等)
  6. 问答系统设计的一些思考
  7. 华为云发布国内首个 AI 模型市场,加速企业 AI 应用落地
  8. CSS —— 多媒体查询
  9. Unity3D数学工具(Mathf)
  10. 基于R语言的现代贝叶斯统计学(INLA下的贝叶斯回归、多层贝叶斯回归、生存分析、随机游走模型、广义可加模型、极端数据的贝叶斯分析等)
  11. 找不到硬盘分区怎么办
  12. 什么是功率因数补偿/校正
  13. java pgm_用Java读取pgm文件
  14. cents OS7配置 php curl.so方法
  15. 蜜瓜文案:水果店蜜瓜简单文案,蜜瓜水果朋友圈配的文案
  16. 什么叫python解析器_Python IDE和解释器的区别是什么?
  17. api c语言 播放视频,使用OpenCV播放视频文件(C/C++ API比较)
  18. ASP 仿 Monorail MVC 的实现思路
  19. TongWeb上应用部署方式
  20. 将 Linux 上的 VSCode 字体更改为 Consolas

热门文章

  1. window7 海康硬盘录像机+ffmpeg+nginx+ckplayer实现网页实时预览监控视频(无敌详细版)
  2. 新的开始,从头来过!
  3. 做P2C必须了解的二维码知识
  4. Java8环境下使用restTemplate单/多线程下载大文件和小文件
  5. dir /s真是个神奇的存在
  6. 移动端App架构Demo
  7. 电脑打字卡顿,onenote、word使用时打字卡顿一下后才显示的解决方案
  8. 秦心,Recyclerview+okhttp
  9. android原理分析博客,高通Android平台下zoom4X实验原理分析(一)
  10. 使用Java定义数组