上一篇文章:「14」支持向量机——我话说完,谁支持?谁反对?,我们通过SVM的数学原理讲解了这个最常见的机器学习算法。这一篇我们用一个非常简单的python实战项目来练习一下SVM并加深理解。

复习

SVM是一种二分类模型,处理的数据可以分为三类:

  1. 线性可分,通过硬间隔最大化,学习线性分类器
  2. 近似线性可分,通过软间隔最大化,学习线性分类器
  3. 线性不可分,通过核函数以及软间隔最大化,学习非线性分类器

线性分类器,在平面上对应直线;非线性分类器,在平面上对应曲线。

硬间隔对应于线性可分数据集,可以将所有样本正确分类,也正因为如此,受噪声样本影响很大,不推荐。

软间隔对应于通常情况下的数据集(近似线性可分或线性不可分),允许一些超平面附近的样本被错误分类,从而提升了泛化性能。

如下图:

实线是由硬间隔最大化得到的,预测能力显然不及由软间隔最大化得到的虚线。

对于线性不可分的数据集,如下图:

我们直观上觉得这时线性分类器,也就是直线,不能很好的分开红点和蓝点。

但是可以用一个介于红点与蓝点之间的类似圆的曲线将二者分开,如下图:

我们假设这个黄色的曲线就是圆,不妨设其方程为x^2+y^2=1,那么核函数是干什么的呢?

我们将x^2映射为X,y^2映射为Y,那么超平面变成了X+Y=1。

那么原空间的线性不可分问题,就变成了新空间的(近似)线性可分问题。

此时就可以运用处理(近似)线性可分问题的方法去解决线性不可分数据集的分类问题。

python实战

实例:SVM建模——蓝瘦香菇到底有没有毒

数据集:8000多条蘑菇的数据,有23个特征。其中是否有毒是因变量,其余特征为自变量。Mushroom Data Set

特征举例

{伞盖形状:钟形/平面/圆锥/凸出/圆球/凹进};

{气味:杏仁味/鱼腥味/臭味/泥土味/无味/刺鼻/辣味};

{地点:草地/树叶/牧场/道路/城市/垃圾场/树木}

……

数据预处理

import pandas as pd
import numpy as np
%matplotlib inline
import numpy as np
from scipy import stats
mush_df = pd.read_csv('mushrooms.csv')
mush_df_encoded = pd.get_dummies(mush_df)

把分类特征转化为one-hot表示,这样虽然特征数量增加了几倍,但是保证每个特征对应的值为0或1。用pandas包里的get_dummies(data)可以很轻松地实现。

这样特征维度就从23变成了119,其中因变量(有毒/无毒)也被拆分为两个特征,第一列是'无毒',1表示无毒,0表示有毒;第二列只是反过来表示。因此在后续训练的时候自变量只有119-2=117个,因变量依然是一个。

# 将特征和类别标签分别赋值给 X 和 y,因变量y选择第二列(1有毒,0无毒)X_mush = mush_df_encoded.iloc[:,2:]#所有行、第三列及往后
y_mush = mush_df_encoded.iloc[:,1]#所有行、第二列(第二列是‘有毒’,1表示有毒,0表示无毒)

训练SVM

from sklearn.svm import SVC
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
pca = PCA(n_components=117, whiten=True, random_state=42)

n_components可将数据压缩为n维向量,发现在这个案例中n最多取117(自变量的维度数量)。这个降维功能多用于上万维度的向量,可以压缩为千维简便计算。但是如果在这个案例中设置n=2,压缩为2维,准确率下降较大。

考虑到这个数据维度也不是很大,完全可以不降维。因此n_components=117。

svc = SVC(kernel='linear', class_weight='balanced')

kernel设置为线性可分(而非核函数,核函数是一个类别的数据被另一个类别包围,而蘑菇数据一般不会这样)

建立模型:

model = make_pipeline(pca, svc)

划分训练数据和测试数据。随机数种子用来确保每一次建立的训练集/测试集数据固定:

from sklearn.model_selection import train_test_split
Xtrain, Xtest, ytrain, ytest = train_test_split(X_mush, y_mush, random_state=41)

调参:通过交叉验证寻找最佳的松弛参数C

from sklearn.model_selection import GridSearchCV
param_grid = {'svc__C':[1,2,5,10]}#设置的C可能的值是1,2,5,10,可以自由设置
grid = GridSearchCV(model,param_grid)
%time grid.fit(Xtrain, ytrain)
print(grid.best_params_)

print结果是1,也就是说C=1的时候训练效果最好,分割最高效

model=grid.best_estimator_

因此把model参数设置为C=1

训练好的svm储存在model里,那么这个训练好的模型到底好不好用呢?

现在把它调用来做预测:

yfit=model.predict(Xtest)

预测过程就相当于遮住测试集蘑菇的真实情况(有毒/无毒),只看117个自变量特征,预测这个蘑菇是不是有毒,然后和真实情况做对比。

预测结果怎样?生成报告:

from sklearn.metrics import classification_report
print(classification_report(ytest, yfit))

结果是100%正确!0代表无毒蘑菇,1047个全都预测出来了;1代表有毒蘑菇,984个也全都预测出来了。而且机器没有把一个有毒蘑菇的当成无毒的。可以放心地食用啦!

上面是直接调用Scikit库进行计算的。对支持向量机底层优化、算法计算原理感兴趣的同学可以参考以下代码:

'''
支持向量机,代码参考了《机器学习实战》
即寻找超平面:w*x+b=0,使得二分类有最大间隔
利用SMO算法,快速求解对偶问题,得到最优解w,b
''''''
在ε精度范围内违反KKT条件的点:
1、alpha[i]=0   ⇔ yiEi<-ε
2、0<alpha[i]<c ⇔ |yiEi|>ε
3、alpha[i]=c   ⇔ yiEi>ε
alpha[i]=0或c时,称为边界点;0<alpha[i]<c,称为非边界点(nonbound)
'''import numpy as npimport scipy.io as sciofrom matplotlib import pyplot as pltdef loadDataSet():#载入数据集dataFile='nonlinearData.mat'#读取mat数据集data=scio.loadmat(dataFile)dataset=data['nonlinear'].T#N*3的格式,N为样本个数,平面坐标X1,X2,标签Ydataset=dataset[2300:2500,:]#选200个样本进行训练positive=np.array([[0,0,0]])#+1的样本集negative=np.array([[0,0,0]])#-1的样本集for i in range(dataset.shape[0]):if(dataset[i][2]==1):positive=np.row_stack((positive,np.array([dataset[i]])))else:negative=np.row_stack((negative,np.array([dataset[i]])))return positive[1:,:],negative[1:,:],dataset'''
def kernel(xi,xj):#核函数(xi,xj为向量)return xi.dot(xj.T)#内积(数据集线性可分或近似线性可分)
'''sigma=10.0def kernel(xi,xj):#高斯核函数(数据集线性不可分)M=xi.shape[0]K=np.zeros((M,1))for l in range(M):A=np.array([xi[l]])-xjK[l]=[np.exp(-0.5*float(A.dot(A.T))/(sigma**2))]return Kdef findNonBound(alpha,C):#寻找非边界点nonbound=[]for i in range(len(alpha)):if(0<alpha[i] and alpha[i]<C):nonbound.append(i)return nonbounddef selectJrand(i,N):#随机选择jj=iwhile(j==i):j=int(np.random.uniform(0,N))#左闭右开,类型转化后0<=j<=N-1return jclass SVM(object):def __init__(self,X,Y,C,epsilon):self.X=X#数据集N*D(D维特征)self.Y=Y#标签(1,-1)self.N=X.shape[0]#数据集大小self.C=C#惩罚系数self.epsilon=epsilon#精度self.alpha=np.zeros((self.N,1))#拉格朗日乘子N*1self.b=0#位移self.E=np.zeros((self.N,2))#误差缓存表N*2,第一列为更新状态(0-未更新,1-已更新),第二列为缓存值def computeEk(self,k):#计算缓存项Ekxk=np.array([self.X[k]])y=np.array([self.Y]).Tgxk=float(self.alpha.T.dot(y*kernel(self.X,xk)))+self.bEk=gxk-self.Y[k]return Ekdef updateEk(self,k):#更新缓存项Ek包括计算Ek和设置对应的更新状态为1Ek=self.computeEk(k)self.E[k]=[1,Ek]def selectJ(self,i,Ei):#内循环,根据i选择jself.E[i]=[1,Ei]#更新EivalidE=np.nonzero(self.E[:,0])[0]#validE保存更新状态为1的缓存项的行指标if(len(validE)>1):j=0maxDelta=0Ej=0for k in validE:#寻找最大的|Ei-Ej|if(k==i):   continueEk=self.computeEk(k)if(abs(Ei-Ek)>maxDelta):j=kmaxDelta=abs(Ei-Ek)Ej=Ekelse:#随机选择j=selectJrand(i,self.N)Ej=self.computeEk(j)return j,Ejdef inner(self,i):Ei=self.computeEk(i)if((self.Y[i]*Ei>self.epsilon and float(self.alpha[i])>0) or\(self.Y[i]*Ei<-self.epsilon and float(self.alpha[i])<self.C)):#alpha[i]违反了KKT条件j,Ej=self.selectJ(i,Ei)#选择对应的alpha[j]alphaiold=float(self.alpha[i])alphajold=float(self.alpha[j])if(self.Y[i]!=self.Y[j]):L=max(0,alphajold-alphaiold)H=min(self.C,self.C+alphajold-alphaiold)else:L=max(0,alphajold+alphaiold-self.C)H=min(self.C,alphajold+alphaiold)if(L==H): return 0xi=np.array([self.X[i]])xj=np.array([self.X[j]])eta=float(kernel(xi,xi)+kernel(xj,xj)-2*kernel(xi,xj))if(eta<=0): return 0alphajnewunc=alphajold+self.Y[j]*(Ei-Ej)/eta#未剪辑的alphajnew#更新alphajif(alphajnewunc>H): self.alpha[j]=[H]elif(alphajnewunc<L): self.alpha[j]=[L]else: self.alpha[j]=[alphajnewunc]#更新Ejself.updateEk(j)if(abs(float(self.alpha[j])-alphajold)<0.00001): return 0#更新alphaiself.alpha[i]=[alphaiold+Y[i]*Y[j]*(alphajold-float(self.alpha[j]))]#更新Eiself.updateEk(i)#更新bbi=-Ei-self.Y[i]*float(kernel(xi,xi))*(float(self.alpha[i])-alphaiold)-\self.Y[j]*float(kernel(xj,xi))*(float(self.alpha[j])-alphajold)+self.bbj=-Ej-self.Y[i]*float(kernel(xi,xj))*(float(self.alpha[i])-alphaiold)-\self.Y[j]*float(kernel(xj,xj))*(float(self.alpha[j])-alphajold)+self.bif(0<float(self.alpha[i]) and float(self.alpha[i])<self.C): self.b=bielif(0<float(self.alpha[j]) and float(self.alpha[j])<self.C): self.b=bjelse: self.b=0.5*(bi+bj)return 1else: return 0def visualize(self,positive,negative):plt.xlabel('X1')#横坐标plt.ylabel('X2')#纵坐标plt.scatter(positive[:,0],positive[:,1],c = 'r',marker = 'o')#+1样本红色标出plt.scatter(negative[:,0],negative[:,1],c = 'g',marker = 'o')#-1样本绿色标出nonZeroAlpha=self.alpha[np.nonzero(self.alpha)]#非零的alphasupportVector=X[np.nonzero(self.alpha)[0]]#支持向量y=np.array([self.Y]).T[np.nonzero(self.alpha)]#支持向量对应的标签plt.scatter(supportVector[:,0],supportVector[:,1],s=100,c='y',alpha=0.5,marker='o')#标出支持向量print("支持向量个数:",len(nonZeroAlpha))X1=np.arange(-50.0,50.0,0.1)X2=np.arange(-50.0,50.0,0.1)x1,x2=np.meshgrid(X1,X2)g=self.bfor i in range(len(nonZeroAlpha)):#g+=nonZeroAlpha[i]*y[i]*(x1*supportVector[i][0]+x2*supportVector[i][1])g+=nonZeroAlpha[i]*y[i]*np.exp(-0.5*((x1-supportVector[i][0])**2+(x2-supportVector[i][1])**2)/(sigma**2))plt.contour(x1,x2,g,0,colors='b')#画出超平面plt.title("sigma: %f" %sigma)plt.show()def SMO(X,Y,C,epsilon,maxIters):#SMO的主程序SVMClassifier=SVM(X,Y,C,epsilon)iters=0iterEntire=True#由于alpha被初始化为零向量,所以先遍历整个样本集while(iters<maxIters):#循环在整个样本集与非边界点集上切换,达到最大循环次数时退出iters+=1if(iterEntire):#循环遍历整个样本集alphaPairChanges=0for i in range(SVMClassifier.N):#外层循环alphaPairChanges+=SVMClassifier.inner(i)if(alphaPairChanges==0):    break#整个样本集上无alpha对变化时退出循环else:   iterEntire=False#有alpha对变化时遍历非边界点集else:#循环遍历非边界点集alphaPairChanges=0nonbound=findNonBound(SVMClassifier.alpha,SVMClassifier.C)#非边界点集for i in nonbound:#外层循环alphaPairChanges+=SVMClassifier.inner(i)if(alphaPairChanges==0):iterEntire=True#非边界点全满足KKT条件,则循环遍历整个样本集return SVMClassifierif __name__ == "__main__":positive,negative,dataset=loadDataSet()#返回+1与-1的样本集,总训练集X=dataset[:,0:2]#X1,X2Y=dataset[:,2]#YSVMClassifier=SMO(X,Y,1,0.001,40)SVMClassifier.visualize(positive,negative)

「15」支持向量机Python实战篇——蓝瘦香菇到底有没有毒?相关推荐

  1. python决策树可视化_「决策树」| Part3—Python实现之可视化

    文章首发于微信公众号:AlgorithmDeveloper,专注机器学习与Python,编程与算法,还有生活. 1.前言 「决策树」| Part2-Python实现之构建决策树中我们已经可以基于给定数 ...

  2. python数学函数_「分享」关于Python整理的常用数学函数整理

    原标题:「分享」关于Python整理的常用数学函数整理 1.函数说明 abs(number)返回数字的绝对值,如abs(-10)返回10 pow(x,y[,z]) 返回x的y次幂(所得结果对z取模), ...

  3. python实战篇(五)---百度api实现车型识别

    十二年来,有笑泪,有阴晴,相伴一场,人来人往,只是日常.--蔡康永 前言 api全称为应用程序接口,说白了就是别人写好了一个可实现功能的函数接口,我们可以直接调用来实现功能,今天,我们一起来学习,用百 ...

  4. python实战篇(六)---打造自己的签名软件

    为什么上帝看到思考的人会笑?那是因为人在思考,却又抓不住真理.因为人越思考,一个人的思想就越跟另一个人的思想相隔万里. --米兰·昆德拉 Python实战篇重在实战,今天,我们就来设计一款自己的签名软 ...

  5. 基于「ClamAv」通过python进行病毒检测(2)-- pyClamd控制clamd详解

    介绍pyClamd模块一般用法和常用方法等. 我们可以使用python来控制clamd,从而操控ClamAv,需要引入第三方模块:pyClamd. 使用pyClamd控制clamd之前,必须先正确安装 ...

  6. 数据中心何时能摆脱夜夜割,蓝瘦香菇

    在数据中心建成投产之后,数据中心将经历一段漫长的运维周期,也许几年,也许十几年.众所周知,电子设备的使用寿命一般为三年,高精尖的设备寿命可超过五年,远远低于数据中心的生命周期.这样,在数据中心的生命周 ...

  7. 找工作找到蓝瘦香菇?看牵引力学员如何轻松过万

    不少刚毕业应届生踌躇满志的走出校园,打算在职场上大干一番,没想到在找工作的过程中碰了一鼻子灰. 蓝瘦香菇,跌跌撞撞,感觉面试真的让自己的缺点暴露无遗.无限放大.觉得自己一无是处,然而有这样一位应届毕业 ...

  8. 对于技术合伙人来说,这才叫蓝瘦香菇!

    创业者:我是做互联网+餐饮的,你们做一个简单的移动端收银系统大概需要多久? 程序猿:······ 创业者:就显示一个菜单,顾客点击购买就出现一个支付二维码就可以了. 程序猿:-- 创业者:有那么难吗? ...

  9. 「11」Python实战篇:利用KNN进行电影分类

    上一期文章:「10」民主投票法--KNN的秘密 中,我们剖析了KNN算法的本质和特点.局限.这里我们用python代码进行KNN的实现.第1部分是KNN的基础算法步骤,第2部分是一个电影分类的实战项目 ...

最新文章

  1. 有关包络Spectral Envelope的疑问
  2. Design Pattern - Iterator(C#)
  3. 删除单链上数据域值最小的节点_深入浅出数据结构
  4. 2、MySQL创建存储过程(CREATE PROCEDURE)(函数)
  5. 全球及中国服务器电源行业市场深度策略分析及投资规划咨询建议报告2022-2028年版
  6. java 默认参数 实例化_如何使用Kotlin中的默认构造函数参数值实例化对象?
  7. LeetCode77:Combinations
  8. 云计算学习笔记005---Hadoop HDFS和MapReduce 架构浅析
  9. 计算机系统-电路设计07-上升沿D触发器的内部电路实现/移位寄存器/串行接口/并行接口
  10. JAVA的日期与毫秒的相互转换
  11. Java 动态代理机制详解
  12. asp.net获取服务器信息
  13. SQL正则表达式、列表运算、涉及null的查询
  14. SSIM PSNR db
  15. dcp1608 linux驱动下载,兄弟激光 DCP-1608驱动
  16. 史上最全的开源库整理
  17. 211大学计算机找工作,华为最青睐的5所大学,每年招聘大量毕业生,第一所只是211院校...
  18. h5在线制作平台h5案例分享
  19. asp毕业设计——基于vb+VB.NET+SQL Server的图书馆管理信息系统设计与实现(毕业论文+程序源码)——图书馆管理信息系统
  20. 基于图像变换的最小二乘法及其应用(新生研讨课)

热门文章

  1. MySQL-如何定位慢查询SQL以及优化
  2. 计算机中丢失gdca,GDCA邮件证书使用指南
  3. java中指数函数的使用方法图解,基本初等函数 指数函数 代码篇
  4. 怎么发展到了Word2vec?
  5. 便携式局部放电检测仪-介绍-厂家
  6. 简述物联网应用中的短距离无线通信
  7. win10的.sql文件怎么取消默认sql server打开方式
  8. 前端国际化之react中英文切换
  9. [省选联考 2020 B 卷] 卡牌游戏 题解c++
  10. vue路由跳转刷新页面