目录

  • 一、线性判别分析介绍
  • 二、线性判别分析原理
    • 1. 类内散度矩阵(within-class scatter matrix)
    • 2. 类间散度矩阵(between-class scatter matrix)
    • 3. 广义瑞利商(generalized Rayleigh quotiet)
  • 三、sklearn库实现线性判别分析LDA
  • 四、总结
  • 五、参考

一、线性判别分析介绍

  线性判别分析(Linear Discriminant Analysis,简称 LDALDALDA)是一种经典的线性学习方法,亦称"Fisher 判别分析"。

  线性判别分析思想:给定训练样本集,设法将样例投影到一条直线上。使得同类样例的投影点尽可能接近、异类样例的投影点尽可能远;在对新样本进行分类时,将其投影到该直线上,再根据投影点的位置来确定新样本的类别。

二、线性判别分析原理

  给定数据集 D={(xi,yi)}i=1m,yi∈{0,1}D= \{(\pmb{x_i} , y_i) \}^m_{i=1},y_i\in\{ 0,1\}D={(xi​​xi​​​xi​,yi​)}i=1m​,yi​∈{0,1} ,令XiX_iXi​、μi\pmb{\mu_i}μi​​μi​​​μi​、∑i\pmb{\sum_i}∑i​​∑i​​​∑i​ 分别表示 i∈{0,1}i\in\{0,1\}i∈{0,1} 类示例的集合、均值向量、协方差矩阵。若将数据投影到直线 ω\pmb{\omega}ωωω 上,则两类样本的中心在直线上的投影分别为 ωTμ0\pmb{\omega^T\mu_0}ωTμ0​​ωTμ0​​​ωTμ0​ 和 ωTμ1\pmb{\omega^T\mu_1}ωTμ1​​ωTμ1​​​ωTμ1​ ;若将所有样本点都投影到直线上,则两类样本的协方差分别为 ωT∑0ω\pmb{\omega^T\sum_0\omega}ωT∑0​ω​ωT∑0​ω​​ωT∑0​ω 和 ωT∑1ω\pmb{\omega^T\sum_1\omega}ωT∑1​ω​ωT∑1​ω​​ωT∑1​ω 。由于直线处于一维空间,因此 ωTμ0\pmb{\omega^T\mu_0}ωTμ0​​ωTμ0​​​ωTμ0​、ωTμ1\pmb{\omega^T\mu_1}ωTμ1​​ωTμ1​​​ωTμ1​ 、ωT∑0ω\pmb{\omega^T\sum_0\omega}ωT∑0​ω​ωT∑0​ω​​ωT∑0​ω 和 ωT∑1ω\pmb{\omega^T\sum_1\omega}ωT∑1​ω​ωT∑1​ω​​ωT∑1​ω 均为实数。

  要使得同类样例的投影点尽可能接近,所以应让同类样例投影点的协方差尽可能小,即ωT∑0ω+ωT∑1ω\pmb{{\omega^T\sum_0\omega}+{\omega^T\sum_1\omega}}ωT∑0​ω+ωT∑1​ω​ωT∑0​ω+ωT∑1​ω​​ωT∑0​ω+ωT∑1​ω 尽可能小。
  要使得异类样例的投影点尽可能地远,则让类中心之间的距离尽可能大,即 ∣∣ωTμ0−ωTμ1∣∣22||\pmb{\omega^T\mu_0}-\pmb{\omega^T\mu_1||_2^2}∣∣ωTμ0​​ωTμ0​​​ωTμ0​−ωTμ1​∣∣22​​ωTμ1​∣∣22​​​ωTμ1​∣∣22​ 尽可能大。同时考虑二者,则需要得到的最大化目标为:
J=∣∣ωTμ0−ωTμ1∣∣22ωT∑0ω+ωT∑1ω=ωT(μ0−μ1)(μ0−μ1)TωωT(∑0+∑1)ω\begin{aligned} \pmb{J} &= \frac{||\omega^T\mu_0 - \omega^T\mu_1||_2^2}{{\omega^T\sum_0\omega}+\omega^T\sum_1\omega}\\ &= \frac{\omega^T(\mu_0 - \mu_1)(\mu_0-\mu_1)^T\omega}{\omega^T(\sum_0+\sum_1)\omega}\\ \end{aligned}JJJ​=ωT∑0​ω+ωT∑1​ω∣∣ωTμ0​−ωTμ1​∣∣22​​=ωT(∑0​+∑1​)ωωT(μ0​−μ1​)(μ0​−μ1​)Tω​​

1. 类内散度矩阵(within-class scatter matrix)

  类内散度矩阵用来判断同类样例的投影点之间的距离。
Sw=∑0+∑1=∑x∈X0(x−μ0)(x−μ0)T+∑x∈X1(x−μ1)(x−μ1)T\begin{aligned} S_w &= \sum_0 + \sum_1 \\ &= \sum_{x\in X_0} (x-\mu_0)(x-\mu_0)^T +\sum_{x\in X_1} (x-\mu_1)(x-\mu_1)^T \end{aligned}Sw​​=0∑​+1∑​=x∈X0​∑​(x−μ0​)(x−μ0​)T+x∈X1​∑​(x−μ1​)(x−μ1​)T​

2. 类间散度矩阵(between-class scatter matrix)

  类间散度矩阵用来判断异类样例的投影点之间的距离。
Sb=(μ0−μ1)(μ0−μ1)TS_b = (\mu_0-\mu_1)(\mu_0-\mu_1)^TSb​=(μ0​−μ1​)(μ0​−μ1​)T

3. 广义瑞利商(generalized Rayleigh quotiet)

  广义瑞利商(generalized Rayleigh quotiet)就是 LDALDALDA欲最大化的目标。
  使用类内散度矩阵类间散度矩阵将最大化目标改写为:
J=ωTSbωωTSwω\pmb{J} = \frac{\omega^TS_b\omega}{\omega^TS_w\omega}\\ JJJ=ωTSw​ωωTSb​ω​

  LDALDALDA 可从贝叶斯决策理论的角度来阐释,并可证明,当两类数据同先验、满足高斯分布且协方差相等时,LDALDALDA可达到最优分类。

三、sklearn库实现线性判别分析LDA

  1. 数据生成
#生成200个三个维度样本
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.datasets import make_classification
x, y = make_classification(n_samples=200, n_features=2, n_redundant=0, n_classes=2, n_informative=2,n_clusters_per_class=2,class_sep =1, random_state =0)
fig = plt.figure()
plt.scatter(x[:, 0], x[:, 1], c=y)

  1. 数据处理
#设置分类平滑度
h = .01
#设置X和Y的边界值
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1#使用meshgrid函数返回X和Y两个坐标向量矩阵
xx, yy = np.meshgrid(np.arange(x_min, x_max,h), np.arange(y_min, y_max,h))
Z = lda.predict(np.c_[xx.ravel(), yy.ravel()])
  1. 数据集划分
from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test = train_test_split(x, y, random_state=33, test_size=0.25)
  1. LDA分类
#使用LDA进行降维
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression
lda = LinearDiscriminantAnalysis(n_components=1)x_train_lda = lda.fit_transform(x_train, y_train)  # LDA是有监督方法,需要用到标签
x_test_lda = lda.fit_transform(x_test, y_test)   # 预测时候特征向量正负问题,乘-1反转镜像
  1. 绘制训练集分类图像
#设置colormap颜色
cm_bright = ListedColormap(['#D9E021', '#0D8ECF'])
#绘制数据点
plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train, cmap=cm_bright)
plt.title('Linear Discriminant Analysis Classifiers')
plt.axis('tight')
plt.show()

  1. 绘制测试集分类图像
plt.title('Linear Discriminant Analysis Classifiers')
plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test, cmap=cm_bright)
plt.show()

四、总结

   LDA算法既可以用来降维,也可以用来分类,但是目前来说,主要还是用于降维,和PCA类似,LDA降维基本也不用调参,只需要指定降维到的维数即可。

五、参考

Python机器学习笔记:线性判别分析(LDA)算法

【机器学习】机器学习之线性判别分析(LDA)相关推荐

  1. 机器学习 周志华 课后习题3.5 线性判别分析LDA

    机器学习 周志华 课后习题3.5 线性判别分析LDA 照着书上敲了敲啥都不会,雀食折磨 python代码 # coding=UTF-8 from numpy import * # 我安装numpy的时 ...

  2. lda 吗 样本中心化 需要_机器学习 —— 基础整理(四):特征提取之线性方法——主成分分析PCA、独立成分分析ICA、线性判别分析LDA...

    本文简单整理了以下内容: (一)维数灾难 (二)特征提取--线性方法 1. 主成分分析PCA 2. 独立成分分析ICA 3. 线性判别分析LDA (一)维数灾难(Curse of dimensiona ...

  3. 线性分类(二)-- 线性判别分析 LDA

    在机器学习领域,LDA是两个常用模型的简称:线性判别分析(Linear Discriminant Analysis) 和隐含狄利克雷分布(Latent Dirichlet Allocation).在自 ...

  4. 线性判别分析LDA—西瓜书课后题3.5—MATLAB代码

    题目:编程实现线性判别分析LDA,给出西瓜数据集 3.0a上的结果 简单说就是找一个分离度最大的投影方向,把数据投射上去. clc clear all [num,txt]=xlsread('D:\机器 ...

  5. 数据分享|R语言逻辑回归、线性判别分析LDA、GAM、MARS、KNN、QDA、决策树、随机森林、SVM分类葡萄酒交叉验证ROC...

    全文链接:http://tecdat.cn/?p=27384 在本文中,数据包含有关葡萄牙"Vinho Verde"葡萄酒的信息(点击文末"阅读原文"获取完整代 ...

  6. ML之NB:基于news新闻文本数据集利用纯统计法、kNN、朴素贝叶斯(高斯/多元伯努利/多项式)、线性判别分析LDA、感知器等算法实现文本分类预测

    ML之NB:基于news新闻文本数据集利用纯统计法.kNN.朴素贝叶斯(高斯/多元伯努利/多项式).线性判别分析LDA.感知器等算法实现文本分类预测 目录 基于news新闻文本数据集利用纯统计法.kN ...

  7. 07_数据降维,降维算法,主成分分析PCA,NMF,线性判别分析LDA

    1.降维介绍 保证数据所具有的代表性特性或分布的情况下,将高维数据转化为低维数据. 聚类和分类都是无监督学习的典型任务,任务之间存在关联,比如某些高维数据的分类可以通过降维处理更好的获得. 降维过程可 ...

  8. 『矩阵论笔记』线性判别分析(LDA)最全解读+python实战二分类代码+补充:矩阵求导可以参考

    线性判别分析(LDA)最全解读+python实战二分类代码! 文章目录 一.主要思想! 二.具体处理流程! 三.补充二中的公式的证明! 四.目标函数的求解过程! 4.1.优化问题的转化 4.2.拉格朗 ...

  9. sklearn实现lda模型_运用sklearn进行线性判别分析(LDA)代码实现

    基于sklearn的线性判别分析(LDA)代码实现 一.前言及回顾 本文记录使用sklearn库实现有监督的数据降维技术--线性判别分析(LDA).在上一篇LDA线性判别分析原理及python应用(葡 ...

最新文章

  1. mysql nosql引擎_nosql与mysql的区别是什么
  2. 大话Django之一:安装与启动
  3. 搭建nfs共享存储服务之一nfs服务端搭建
  4. 利用 FFmpeg palettegen paletteuse 生成接近全色的 gif 动画
  5. 运动基元_发现大量Java基元集合处理
  6. LeetCode 616. 给字符串添加加粗标签(Trie树)
  7. 使用SourceTree
  8. java中的集合_Java 集合介绍,常用集合类
  9. 调整DOS窗口大小的方法 2021-03-06
  10. 李彦宏:离破产永远只有30天
  11. python什么字体好看_python docx 中文字体设置的操作方法
  12. 大学生如何合理利用计算机,大学生如何安排自己的课余时间?6招,学霸教会你正确使用手机...
  13. Mac安装软件时各种异常情况的解决方法
  14. CSS核心概念一把梭-基础部分
  15. jqGrid subGrid配置 如何首次加载动态展开所有的子表格
  16. 如何用计算机玩扫雷,扫雷怎么玩_玩好扫雷游戏的技巧是什么【图文】-太平洋电脑网PConline-太平洋电脑网...
  17. linux查看进程和端口信息的命令
  18. 一份golang令牌桶攻略(juju/ratelimit)
  19. ps2018首选项出现要求96和8之间的整数怎么办?
  20. deepin 输入法频繁重启,无法正常输入汉字解决方法

热门文章

  1. 使用Altium Designer 20绘制双层板以及四层板
  2. 如何用数据管理去挖掘大数据的商业价值
  3. rabbitmq消费者“无故消失”
  4. Python WordCloud 文本分析 生成词云图
  5. 网络套接字编程之IO模型详解
  6. PHP开发环境搭建与工具
  7. 根文件系统(四)——U盘文件系统制作
  8. MES系统,即制造执行系统Manufacturing Execution System)
  9. 计算机科学与技术的职业需求,职业规划:计算机科学与技术专业就业前景
  10. 真相了,读完研就能找到好工作吗?