本博客运行环境为Jupyter Notebook、Python3。使用的数据集是鸢尾花数据集。

目录

  • 线性判别分析
  • 代码实现
  • 缺少一组数据的问题已解决!代码已更新!

线性判别分析

线性判别分析(Linear Discriminant Analysis,简称LDA)是一种经典的线性学习方法,在二分类问题.上因为最早由[Fisher, 1936]提出,亦称“Fisher判别分析”。
LDA的基本思想:给定训练样例集,设法将样例投影到一条直线上,使得同类样例的投影点尽可能接近、异类样例的投影点尽可能远离;在对新样本进行分类时,将其投影到同样的这条直线上,再根据投影点的位置来确定新样本的类别。
下图是LDA的二维示意图,“+”、“-”分别代表正倒和反倒,椭圆表示数据簇的外轮廓,虚线表示投影,红色实心圆和实心三角形分别表示两类样本投影后的中心点。

线性判别函数的一般形式可以表示为:
g(X)=WTX+w0g(X)=W^TX+w_{0} g(X)=WTX+w0​
其中,

Fisher选择投影方向W的原则,即使原样本向量在该方向上的投影能兼顾类间分布尽可能分开,类内样本投影尽可能密集的要求。

(1)W的确定
各类样本均值向量mi

样本类内离散度矩阵 Si 和总类内离散度矩阵 Sw

样本类间离散度矩阵 Sb

在投影后的一维空间中,各类样本均值

样本类内离散度和总类内离散度

样本类间离散度

Fisher准则函数为

(2)阈值的确定
W0 是个常数,称为阈值权,对于两类问题的线性分类器可以采用下属决策规则:

如果g(x)>0,则决策x属于W1;如果g(x)<0,则决策x属于W2;如果g(x)=0,则可将x任意分到某一类,或拒绝。

(3)Fisher线性判别的决策规则
Fisher准则函数满足两个性质:
1.投影后,各类样本内部尽可能密集,即总类内离散度越小越好。
2.投影后,各类样本尽可能离得远,即样本类间离散度越大越好。
根据性质确定准则函数,根据使准则函数取得最大值,可求出

这就是Fisher判别准则下的最优投影方向。
得到决策规则

若上述规则成立,则有

对于某一个未知类别的样本向量x,如果y=WT·x>y0,则x∈w1;否则x∈w2。

(4)“群内离散度”与“群间离散度”
“群内离散度”要求的是距离越远越好;而“群间离散度”的距离越近越好。
“群内离散度”(样本类内离散矩阵)的计算公式为
Si=∑x∈Xi(x−mi)(x−mi)TS_i=\sum_{x∈X_i}(x-m_i)(x-m_i)^T Si​=x∈Xi​∑​(x−mi​)(x−mi​)T
因为每一个样本有多维数据,因此需要将每一维数据代入公式计算后最后在求和即可得到样本类内离散矩阵。存在多个样本,重复该计算公式即可算出每一个样本的类内离散矩阵。
“群间离散度”(总体类离散度矩阵)的计算公式为
Swij=Si+SjS_wij=S_i+S_jSw​ij=Si​+Sj​

代码实现

例如鸢尾花数据集,将数据集分为三类样本,然后得到三个总体类离散度矩阵,三个总体类离散度矩阵根据上述公式计算即可。
IRIS数据集以鸢尾花的特征作为数据来源,数据集包含150个数据集,有4维,分为3 类,每类50个数据,每个数据包含4个属性,是在数据挖掘、数据分类中非常常用的测试集、训练集。
Python代码如下:
df = pd.read_csv(r’Iris.csv’,header = None)这句是数据集存储路径,我已将数据集保存为.csv文件,需要修改为自己的路径。若使用sklearn库引用可以参看后面的代码。

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns#path=r'Iris.csv'
#df = pd.read_csv(path, header=None)
df = pd.read_csv(r'Iris.csv',header = None)
Iris1=df.values[0:50,0:4]
Iris2=df.values[50:100,0:4]
Iris3=df.values[100:150,0:4]
m1=np.mean(Iris1,axis=0)
m2=np.mean(Iris2,axis=0)
m3=np.mean(Iris3,axis=0)
s1=np.zeros((4,4))
s2=np.zeros((4,4))
s3=np.zeros((4,4))
for i in range(0,30,1):a=Iris1[i,:]-m1a=np.array([a])b=a.Ts1=s1+np.dot(b,a)
for i in range(0,30,1):c=Iris2[i,:]-m2c=np.array([c])d=c.Ts2=s2+np.dot(d,c) #s2=s2+np.dot((Iris2[i,:]-m2).T,(Iris2[i,:]-m2))
for i in range(0,30,1):a=Iris3[i,:]-m3a=np.array([a])b=a.Ts3=s3+np.dot(b,a)
sw12=s1+s2
sw13=s1+s3
sw23=s2+s3
#投影方向
a=np.array([m1-m2])
sw12=np.array(sw12,dtype='float')
sw13=np.array(sw13,dtype='float')
sw23=np.array(sw23,dtype='float')
#判别函数以及T
#需要先将m1-m2转化成矩阵才能进行求其转置矩阵
a=m1-m2
a=np.array([a])
a=a.T
b=m1-m3
b=np.array([b])
b=b.T
c=m2-m3
c=np.array([c])
c=c.T
w12=(np.dot(np.linalg.inv(sw12),a)).T
w13=(np.dot(np.linalg.inv(sw13),b)).T
w23=(np.dot(np.linalg.inv(sw23),c)).T
#print(m1+m2) #1x4维度  invsw12 4x4维度  m1-m2 4x1维度
T12=-0.5*(np.dot(np.dot((m1+m2),np.linalg.inv(sw12)),a))
T13=-0.5*(np.dot(np.dot((m1+m3),np.linalg.inv(sw13)),b))
T23=-0.5*(np.dot(np.dot((m2+m3),np.linalg.inv(sw23)),c))
kind1=0
kind2=0
kind3=0
newiris1=[]
newiris2=[]
newiris3=[]
for i in range(30,50):x=Iris1[i,:]x=np.array([x])g12=np.dot(w12,x.T)+T12g13=np.dot(w13,x.T)+T13g23=np.dot(w23,x.T)+T23if g12>0 and g13>0:newiris1.extend(x)kind1=kind1+1elif g12<0 and g23>0:newiris2.extend(x)elif g13<0 and g23<0 :newiris3.extend(x)
#print(newiris1)
for i in range(30,50):x=Iris2[i,:]x=np.array([x])g12=np.dot(w12,x.T)+T12g13=np.dot(w13,x.T)+T13g23=np.dot(w23,x.T)+T23if g12>0 and g13>0:newiris1.extend(x)elif g12<0 and g23>0:newiris2.extend(x)kind2=kind2+1elif g13<0 and g23<0 :newiris3.extend(x)
for i in range(30,50):x=Iris3[i,:]x=np.array([x])g12=np.dot(w12,x.T)+T12g13=np.dot(w13,x.T)+T13g23=np.dot(w23,x.T)+T23if g12>0 and g13>0:newiris1.extend(x)elif g12<0 and g23>0:     newiris2.extend(x)elif g13<0 and g23<0 :newiris3.extend(x)kind3=kind3+1
correct=(kind1+kind2+kind3)/60
print("样本类内离散度矩阵S1:",s1,'\n')
print("样本类内离散度矩阵S2:",s2,'\n')
print("样本类内离散度矩阵S3:",s3,'\n')
print('-----------------------------------------------------------------------------------------------')
print("总体类内离散度矩阵Sw12:",sw12,'\n')
print("总体类内离散度矩阵Sw13:",sw13,'\n')
print("总体类内离散度矩阵Sw23:",sw23,'\n')
print('-----------------------------------------------------------------------------------------------')
print('判断出来的综合正确率:',correct*100,'%')

sklearn库引入数据集:
只需替换引入数据集的部分代码。

from sklearn.datasets import make_multilabel_classification
from sklearn import datasetsiris_datas = datasets.load_iris()
x, y = make_multilabel_classification(n_samples=20, n_features=2,n_labels=1, n_classes=1,random_state=2)  # 设置随机数种子,保证每次产生相同

运行结果如下:

缺少一组数据的问题已解决!代码已更新!

原始代码主要是文件导入那儿的问题。header=0改为header=None;如果还有错,需要把path更改为df = pd.read_csv(r’Iris.csv’,header = None),直接使用数据集。

我这输出的综合准确率只有91.6%,而有些同学比较好的能有96.7%。我一开始怀疑是数据集少了一组数据的原因,后面发现确实少了一组数据,在excel中打开数据集完整,但是运行起来就少一组。但别人使用该数据集时没有出现这种情况。我只好换成sklearn引入数据集,然而输出结果是一致的。不得而解。若有知道该问题的小伙伴,希望可以指导我一下哦。

参考教程:机器学习-西瓜书-周志华

Python-线性判别分析(Fisher判别分析)使用鸢尾花数据集 Iris相关推荐

  1. 【python数据挖掘课程】十九.鸢尾花数据集可视化、线性回归、决策树花样分析

    这是<Python数据挖掘课程>系列文章,也是我这学期上课的部分内容.本文主要讲述鸢尾花数据集的各种分析,包括可视化分析.线性回归分析.决策树分析等,通常一个数据集是可以用于多种分析的,希 ...

  2. 实验二:用python实现SVM支持向量机并对鸢尾花数据集分类

    实验二:SVM支持向量机 1. 实验内容: (1)用你熟知的语言(尽量使用python)实现支持向量机的算法,并在给定的数据集上训练. (2)在测试集上用训练好的支持向量机进行测试,并将预测结果以cs ...

  3. python进行KNN算法分析实战(鸢尾花数据集)

    KNN算法分析实战(鸢尾花数据集) 目录 KNN算法分析实战(鸢尾花数据集) 代码效果图 一.导入需要的包 二. 1.导入数据 ​ 2.建立训练集和测试集 3.设置K值 4. 十重交叉验证K值 5.模 ...

  4. 机器学习与深度学习——通过knn算法分类鸢尾花数据集iris求出错误率并进行可视化

    什么是knn算法? KNN算法是一种基于实例的机器学习算法,其全称为K-最近邻算法(K-Nearest Neighbors Algorithm).它是一种简单但非常有效的分类和回归算法. 该算法的基本 ...

  5. 2.试读取鸢尾花数据集iris.npz,绘制sepal_length和sepal_width两个特征之间的散点图,X轴添加“SepalLength”标签,Y轴添加“SepalWidth”标签,散点设置

    2022-2023学年第1期期末考试 <Python数据分析与应用>试卷A卷 (大数据技术专业2131.2132班适用 120分钟 机试开卷) 班级 学号 姓名 1 题 号 一 总 分 得 ...

  6. 鸢尾花数据集的线性多分类

    目录 一.鸢尾花数据集 二.取萼片的长宽作特征分类 三.取花瓣的长宽作特征分类 实验目的: 在Jupyter下完成一个鸢尾花数据集的线性多分类.可视化显示与测试精度实验. 实验环境: Anaconda ...

  7. Python原生代码实现KNN算法(鸢尾花数据集)

    一.作业题目 Python原生代码实现KNN分类算法,使用鸢尾花数据集. KNN算法介绍: K最近邻(k-Nearest Neighbor,KNN)分类算法,是机器学习算法之一. 该方法的思路是:如果 ...

  8. Python机器学习:KNN算法03训练数据集,测试数据集train test split

    示例代码 首先引入相关包 import numpy as np import matplotlib.pyplot as plt from sklearn import datasets import ...

  9. 如何了解Ski-learn提供的离散型数据集的构造——以鸢尾花数据集为例

    一.利用描述函数 #导入鸢尾花数据集 from sklearn.datasets import load_iris # 描述鸢尾花数据集 iris = load_iris() # 输出对iris数据集 ...

最新文章

  1. 【Java_多线程并发编程】基础篇—线程状态及实现多线程的两种方式
  2. 行为型模式:策略模式
  3. 《阿里巴巴Java开发规约》插件使用详细指南
  4. python 爬取大乐透开奖结果
  5. 联信高效的数据传输机制
  6. Eclipse 皮肤
  7. [渝粤教育] 中国地质大学 现代控制理论 复习题 (2)
  8. TestNG官方文档中文版(1)-介绍
  9. BZOJ3495 : PA2010 Riddle
  10. Modelsim SE-64 10.4版本在WIN10-64位下找不到LICENSE的解决办法
  11. android socket 丢包,socket timeout exception和常见网络丢包情况
  12. Oracle10g 基本命令
  13. 西部数码 php 伪静态,西部数码虚拟主机伪静态如何设置
  14. 你觉得学 Python 还是 Java 更好找工作?
  15. Vue项目关闭格式检查命令
  16. EDK II之USB主控制器(EHCI)驱动的实现框架
  17. python3 利用ffmpeg把音频转换为16khz的wav文件
  18. (八)以交易为生:交易系统
  19. MySQL创建用户,并赋予表权限
  20. python人脸识别考勤系统 dlib+OpenCV和Pyqt5、数据库sqlite 人脸识别系统 计算机 毕业设计 源码

热门文章

  1. windows安装第二个固态硬盘时重装系统的问题
  2. 数据分析方法01对比分析法
  3. HTML5阻击游戏《僵尸之夜》截图
  4. 关于JS跨域访问介绍
  5. 基于毫米波雷达(mmWave Radar)与摄像头(Camera)融合(Fusion)的人行道行人检测(Detection)
  6. Wdf框架之WdfObject状态机(2)
  7. ubuntu Docker无法使用zip unzip指令 解决方案
  8. mysql和oracle课程,Oracle MySQL 管理实战应用培训
  9. reCaptcha去除
  10. python项目简历内容包括哪些方面_一份完整的简历包括什么?