python代码完成Fisher判别的推导

  • 一、Fisher算法的主要思想
  • 二、Fisher数学算法步骤
    • ①计算各类样本均值向量 m i m_i mi​, m i m_i mi​是各个类的均值, N i N_i Ni​是 w i w_i wi​类的样本个数。
    • ②计算样本类内离散度矩阵 S i S_i Si​和总类内离散度矩阵 S w S_w Sw​
    • ③计算样本类间离散度矩阵 S b S_b Sb​
    • ④求投影方向向量 W W W (维度和样本的维度相同)。我们希望投影后,在一维 Y Y Y空间里各类样本尽可能分开,就是我们希望的两类样本均值之差 ( m 1 ‾ − m 2 ‾ ) (\overline{m_1}-\overline{m_2}) (m1​​−m2​​)越大越好,同时希望各类样本内部尽量密集,即是:希望类内离散度越小越好。因此,我们可以定义Fisher准则函数为:
    • 2使得 J F ( w ) J_F(w) JF​(w)取得最大值 w w w为:
    • ⑤将训练集内所有样本进行投影。
    • ⑥. 计算在投影空间上的分割阈值 y 0 y_0 y0​,在一维Y空间,各类样本均值 m i ‾ \overline{m_i} mi​​为:
    • ⑦对于给定的测试样本 x x x,计算出它在 w w w上的投影点 y y y
    • ⑧根据决策规则分类!
  • 三、python实现代码

数据集Iris.csv:链接下载:
提取码:eah8

一、Fisher算法的主要思想

  • 线性判别分析(Linear Discriminant Analysis
    简称LDA)是一种经典的线性学习方法,在二分类问题上因为最早由【Fisher,1936年】提出,所以也称为“Fisher 判别分析!”
    Fisher(费歇)判别思想是投影,使多维问题简化为一维问题来处理。选择一个适当的投影轴,使所有的样本点都投影到这个轴上得到一个投影值。对这个投影轴的方向的要求是:使每一类内的投影值所形成的类内离差尽可能小,而不同类间的投影值所形成的类间离差尽可能大。

二、Fisher数学算法步骤

  • 为了找到最佳投影方向,需要计算出 各类样本均值、样本类内离散度矩阵 Si\boldsymbol S_{i}S i和样本总类内离散度矩阵 Sw\boldsymbolS_{w}Sw、样本类间离散度矩阵 Sb\boldsymbol S_{b}Sb ,根据Fisher准则,找到最佳投影向量,将训练集内的所有样本进行投影,投影到一维Y空间,由于Y空间是一维的,则需要求出Y空间的划分边界点,找到边界点后,就可以对待测样本进行一维Y空间投影,判断它的投影点与分界点的关系,将其归类。具体方法如下(以两类问题为例子):

①计算各类样本均值向量 m i m_i mi​, m i m_i mi​是各个类的均值, N i N_i Ni​是 w i w_i wi​类的样本个数。

②计算样本类内离散度矩阵 S i S_i Si​和总类内离散度矩阵 S w S_w Sw​

③计算样本类间离散度矩阵 S b S_b Sb​

④求投影方向向量 W W W (维度和样本的维度相同)。我们希望投影后,在一维 Y Y Y空间里各类样本尽可能分开,就是我们希望的两类样本均值之差 ( m 1 ‾ − m 2 ‾ ) (\overline{m_1}-\overline{m_2}) (m1​​−m2​​)越大越好,同时希望各类样本内部尽量密集,即是:希望类内离散度越小越好。因此,我们可以定义Fisher准则函数为:

2使得 J F ( w ) J_F(w) JF​(w)取得最大值 w w w为:

⑤将训练集内所有样本进行投影。

⑥. 计算在投影空间上的分割阈值 y 0 y_0 y0​,在一维Y空间,各类样本均值 m i ‾ \overline{m_i} mi​​为:


样本类内离散度 S i ‾ 2 \overline{S_i}^2 Si​​2和总类内离散度 S w ‾ \overline{S_w} Sw​​

而此时类间离散度就成为两类均值差的平方。

计算阈值 y 0 y_0 y0​

⑦对于给定的测试样本 x x x,计算出它在 w w w上的投影点 y y y

⑧根据决策规则分类!

三、python实现代码

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
path=r'D:/iris-data/iris.csv'
df = pd.read_csv(path, header=0)
Iris1=df.values[0:50,0:4]
Iris2=df.values[50:100,0:4]
Iris3=df.values[100:150,0:4]#类均值向量
m1=np.mean(Iris1,axis=0)
m2=np.mean(Iris2,axis=0)
m3=np.mean(Iris3,axis=0)#各类内离散度矩阵
s1=np.zeros((4,4))
s2=np.zeros((4,4))
s3=np.zeros((4,4))
for i in range(0,30,1):a=Iris1[i,:]-m1a=np.array([a])b=a.Ts1=s1+np.dot(b,a)
for i in range(0,30,1):c=Iris2[i,:]-m2c=np.array([c])d=c.Ts2=s2+np.dot(d,c)
for i in range(0,30,1):a=Iris3[i,:]-m3a=np.array([a])b=a.Ts3=s3+np.dot(b,a) #总类内离散矩阵
sw12=s1+s2
sw13=s1+s3
sw23=s2+s3
#投影方向
a=np.array([m1-m2])
sw12=np.array(sw12,dtype='float')
sw13=np.array(sw13,dtype='float')
sw23=np.array(sw23,dtype='float')
#判别函数以及T
#需要先将m1-m2转化成矩阵才能进行求其转置矩阵
a=m1-m2
a=np.array([a])
a=a.T
b=m1-m3
b=np.array([b])
b=b.T
c=m2-m3
c=np.array([c])
c=c.T
w12=(np.dot(np.linalg.inv(sw12),a)).T
w13=(np.dot(np.linalg.inv(sw13),b)).T
w23=(np.dot(np.linalg.inv(sw23),c)).T
#print(m1+m2) #1x4维度  invsw12 4x4维度  m1-m2 4x1维度
#判别函数以及阈值T(即w0)
T12=-0.5*(np.dot(np.dot((m1+m2),np.linalg.inv(sw12)),a))
T13=-0.5*(np.dot(np.dot((m1+m3),np.linalg.inv(sw13)),b))
T23=-0.5*(np.dot(np.dot((m2+m3),np.linalg.inv(sw23)),c))
kind1=0
kind2=0
kind3=0
newiris1=[]
newiris2=[]
newiris3=[]
for i in range(30,49):x=Iris1[i,:]x=np.array([x])g12=np.dot(w12,x.T)+T12g13=np.dot(w13,x.T)+T13g23=np.dot(w23,x.T)+T23if g12>0 and g13>0:newiris1.extend(x)kind1=kind1+1elif g12<0 and g23>0:newiris2.extend(x)elif g13<0 and g23<0 :newiris3.extend(x)
#print(newiris1)
for i in range(30,49):x=Iris2[i,:]x=np.array([x])g12=np.dot(w12,x.T)+T12g13=np.dot(w13,x.T)+T13g23=np.dot(w23,x.T)+T23if g12>0 and g13>0:newiris1.extend(x)elif g12<0 and g23>0:newiris2.extend(x)kind2=kind2+1elif g13<0 and g23<0 :newiris3.extend(x)
for i in range(30,50):x=Iris3[i,:]x=np.array([x])g12=np.dot(w12,x.T)+T12g13=np.dot(w13,x.T)+T13g23=np.dot(w23,x.T)+T23if g12>0 and g13>0:newiris1.extend(x)elif g12<0 and g23>0:     newiris2.extend(x)elif g13<0 and g23<0 :newiris3.extend(x)kind3=kind3+1
correct=(kind1+kind2+kind3)/60
print("样本类内离散度矩阵S1:",s1,'\n')
print("样本类内离散度矩阵S2:",s2,'\n')
print("样本类内离散度矩阵S3:",s3,'\n')
print('-----------------------------------------------------------------------------------------------')
print("总体类内离散度矩阵Sw12:",sw12,'\n')
print("总体类内离散度矩阵Sw13:",sw13,'\n')
print("总体类内离散度矩阵Sw23:",sw23,'\n')
print('-----------------------------------------------------------------------------------------------')
print('判断出来的综合正确率:',correct*100,'%')

运行结果:

以上便是此次实验的所有结果。参考博客:http://bob0118.club/?p=266

Fisher判别的推导概念和过程+python代码实现(三分类)相关推荐

  1. 皮尔森相关性系数的计算python代码(三)

    部分代码 import os import pandas as pd import numpy as np from scipy.stats import pearsonrdef Pearson(da ...

  2. knn算法python代码_K-最近邻分类算法(KNN)及python实现

    一.引入 问题:确定绿色圆是属于红色三角形.还是蓝色正方形? KNN的思想: 从上图中我们可以看到,图中的数据集是良好的数据,即都打好了label,一类是蓝色的正方形,一类是红色的三角形,那个绿色的圆 ...

  3. pca算法python代码_三种方法实现PCA算法(Python)

    主成分分析,即Principal Component Analysis(PCA),是多元统计中的重要内容,也广泛应用于机器学习和其它领域.它的主要作用是对高维数据进行降维.PCA把原先的n个特征用数目 ...

  4. Python 代码中三种波浪线和 PEP8

    红色 红色波浪线是代码的错误, 必须处理,代码才能执行 注意: 在后续课程中,某些代码没有写完,也会出现红色波浪线 灰色 灰色波浪线,不会影响代码的正常执行,基本上所有的灰色 浪线都是PEP8 造成的 ...

  5. python代码实现决策树分类

    0. 前言 上一篇博客对决策树算法的思想作了描述,也详细写了如何构造一棵决策树.现在希望用python代码来实现它.此处先调用机器学习中的算法库来实现. 2. python代码实现决策树(决策树分类器 ...

  6. 深度学习笔记——神经网络(ANN)搭建过程+python代码

    目录 1.多维数组的运算 (1)多维数组 (2)矩阵乘法 (3)神经网络的内积 2.3层神经网络的实现 (1)第一层加权和 (2)输入层到第1层的信号传递 (3)第1层到第2层的信号传递 (4)完整代 ...

  7. 吴恩达机器学习python代码练习三(多类别分类)

    import numpy as np import pandas as pd import matplotlib.pyplot as plt import scipy.io as sio from s ...

  8. 判别性的低秩字典学习代码matlab,基于分类的判别性字典学习的稀疏编码算法研究...

    第1章绪论1.1课题研究的背景及意义计算机视觉一直是人类视觉研究中的一项非常热门的领域.计算机视觉研究的目的是为了让计算机能够利用图像和图像序列来识别和感知周围的世界,以帮助人们在复杂的情况下解决未知 ...

  9. 【视频】TFLearn深度学习库,20行Python代码实现情感分类

    向AI转型的程序员都关注了这个号

最新文章

  1. SQL中的case when then else end用法
  2. elasticsearch 查询(match和term)
  3. envi矢量图层外面有蓝色边框_晒晒装完的新房,头次见全屋浅蓝背景墙,加石膏线边框,温馨别致...
  4. android数据回传多个页面_Android页面之间进行数据回传
  5. 51nod---无法表示的数
  6. CRM_DOC_FLOW_READ_DB debug
  7. c#中的奇异递归模式
  8. 关于@SuppressWarnings(unchecked)注解
  9. 心语收集12:我以为要是唱的用心良苦,你就会对我多点在乎
  10. UVA12468 Zapping【水题】
  11. Joseph UVA 1452 Jump
  12. CSS布局:图片在DIV中上下左右居中(水平和垂直都居中)
  13. 计算机视觉SLAM方向顶会
  14. 【SCIENTIFIC AMERICAN】Internet Cables Could Also Measure Quakes 网络光纤也可以用来测量地震(20191204)
  15. 【8.8gzoj综合】师生树【BFS】
  16. ps cc2019 安装教程
  17. shader镜面反射(Reflection)
  18. emg采集精度_EMG
  19. 基于java的springboot多用户商城(类淘宝京东)系统毕业设计springboot开题报告
  20. 行业研究报告-全球与中国PH/ORP变送器市场现状及未来发展趋势

热门文章

  1. 三兄弟GETRO、GETTO、SETTO各显神通
  2. 利用crontab实现SVN的自动化备份
  3. 针对UI设计面试,你应注意的几个细节!
  4. 三维可视化技术都有哪些运用
  5. C#,WebApi接口开发
  6. strptime和strftime的用法
  7. Go 语言设置goproxy.io镜像源
  8. Flink学习笔记(一):No new data sinks have been defined since the last execution.
  9. 数值计算——系数矩阵部分对角线为0时线性方程组求解方法(附程序)
  10. 云图说丨初识可信分布式身份服务