SMO序列最小优化算法

import numpy as np
import math
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_splitclass SVM:def __init__(self,kernal='GKF',C=1):self.keranl=kernalself.b = 0self.X = Noneself.Y = Noneself.a = Noneself.N = Noneself.C = Cself.feature_num = Noneself.Elist = Noneself.K = Nonedef comput_kernal(self,x,y,sita=0.1):x = np.expand_dims(np.array(x),axis=0)y = np.expand_dims(np.array(y),axis=0)if self.keranl == 'GKF':return math.exp(-(x-y).dot((x-y).T)/2*sita**2)elif self.keranl == 'liner':return x.dot(y.T)[0][0]def compute_K(self):self.K = np.zeros((self.N,self.N))for i in range(self.N):for j in range(self.N):self.K[i][j] = self.comput_kernal(self.X[i],self.X[j])def conput_gx(self,index_x):gx = 0for j in range(self.N):gx += self.K[index_x][j] * self.Y[j] *self.a[j]gx += self.breturn gxdef compute_E_list(self):self.Elist = [self.conput_gx(i) - self.Y[i] for i in range(self.N)]def get_a1_index(self):for i in range(self.N):if self.a[i] < self.C and self.a[i] > 0:gi = self.conput_gx(i)if self.Y[i]*gi != 1:return ifor i in range(self.N):if self.a[i] == 0:if self.conput_gx(i)*self.Y[i] < 1:return ielif self.a[i] == self.C:if self.conput_gx(i)*self.Y[i] > 1:return ielse:return -1def get_a2_index(self,a1_index):E1_E2 = [abs(self.Elist[a1_index] - self.Elist[i]) for i in range(self.N)]a2_index = E1_E2.index(max(E1_E2))return a2_indexdef fit(self,X,Y,max_iter=20):self.N = len(X)self.X = Xself.Y = Yself.feature_num = len(X[0])self.a = np.zeros(self.N) + 0.5self.compute_K()for iter in range(max_iter):print('epoch = ' + str(iter))#更新Eself.compute_E_list()#更新a1 a2a1_index = self.get_a1_index()if a1_index == -1:print('all_is_fit_KTT')breaka2_index = self.get_a2_index(a1_index)a1_old = self.a[a1_index]a2_old = self.a[a2_index]L = max(0, a2_old + a1_old-self.C)H = min(self.C, a2_old + a1_old)n = self.K[a1_index][a1_index] + self.K[a2_index][a2_index] - 2*self.K[a1_index][a2_index]a2_new_unc = a2_old + self.Y[a2_index]*(self.Elist[a1_index]-self.Elist[a2_index])/nif a2_new_unc > H:a2_new = Helif a2_new_unc >= L and a2_new_unc <= H:a2_new = a2_new_uncelif a2_new_unc < L:a2_new = La1_new = a1_old + self.Y[a1_index]*self.Y[a2_index]*(a2_old - a2_new)self.a[a1_index] = a1_newself.a[a2_index] = a2_new#更新bb1_new = - self.Elist[a1_index] - self.Y[a1_index]*self.K[a1_index][a1_index]*(a1_new-a1_old) \- self.Y[a2_index]*self.K[a2_index][a1_index]*(a2_new-a2_old) + self.bb2_new = - self.Elist[a2_index] - self.Y[a1_index]*self.K[a1_index][a2_index]*(a1_new-a1_old) \- self.Y[a2_index]*self.K[a2_index][a2_index]*(a2_new-a2_old) + self.bif 0 < a1_new < self.C and 0 < a2_new < self.C:self.b = (b1_new + b2_new) / 2elif 0 < a1_new < self.C:self.b = b1_newelif 0 < a2_new < self.C:self.b = b2_newdef predict_single(self,x):result_1 = [self.a[i]*self.Y[i]*self.comput_kernal(x,self.X[i]) for i in range(self.N)]return np.sign(sum(result_1) + self.b)def predict(self,X):return [self.predict_single(x) for x in X]def main():# X = [[1,2],#      [2,3],#      [3,3],#      [2,1],#      [3,2]]# Y = [1,1,1,-1,-1]X = []Y = []with open('./iris.data', 'r') as f:for i in f:data = i.split(',')X.append([float(j) for j in data[:4]])Y.append(data[4])Y = [1 if i == 'Iris-setosa\n' else -1 for i in Y]train_X, test_X, train_y, test_y = train_test_split(X,Y,test_size=0.2,random_state=9999)svm_trainer = SVM(C=30,kernal='GKF')svm_trainer.fit(train_X,train_y,max_iter=10)result = svm_trainer.predict(test_X)print(result)print(accuracy_score(test_y,result))if __name__ == '__main__':main()#####result##############
/usr/bin/python3 /Users/zhengyanzhao/PycharmProjects/tongjixuexi/SVM_SMO.py
epoch = 0
epoch = 1
epoch = 2
epoch = 3
epoch = 4
epoch = 5
epoch = 6
epoch = 7
epoch = 8
epoch = 9
[1.0, 1.0, 1.0, -1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
0.8
Process finished with exit code 0

统计学习方法第七章作业:SVM非线性支持向量机之SMO序列最小优化算法代码实现相关推荐

  1. 支持向量机SVM序列最小优化算法SMO

    支持向量机(Support Vector Machine)由V.N. Vapnik,A.Y. Chervonenkis,C. Cortes 等在1964年提出.序列最小优化算法(Sequential ...

  2. 统计学习方法第十一章作业:随机条件场—概率计算问题、IIS/GD学习算法、维特比预测算法 代码实现

    随机条件场-概率计算问题.IIS/GD学习算法.维特比预测算法 这一章的算法不是很好写,整整研究了好几天,代码还是有点小问题,仅供参考. 用的是书上定义的特征函数. import numpy as n ...

  3. 统计学习方法第二十一章作业:PageRank迭代算法、幂法、代数算法 代码实现

    PageRank迭代算法.幂法.代数算法 import numpy as npclass PageRank:def __init__(self,M,D=0.85):self.M = np.array( ...

  4. 统计学习方法第十七章作业:LSA潜在语义分析算法 代码实现

    LSA潜在语义分析算法 import numpy as np import jieba import collectionsclass LSA:def __init__(self,text_list) ...

  5. 统计学习方法第六章作业:逻辑斯谛梯度下降法、最大熵模型 IIS / DFP 算法代码实现

    逻辑斯谛梯度下降法 import numpy as np import matplotlib.pyplot as pltclass logist:def __init__(self,a=1,c=Non ...

  6. 统计学习方法第五章作业:ID3/C4.5算法分类决策树、平方误差二叉回归树代码实现

    ID3/C4.5算法分类决策树 import numpy as np import math class Node:def __init__(self,feature_index=None,value ...

  7. 统计学习方法第三章作业:一般k邻近、平衡kd树构造、kd树邻近搜索算法代码实现

    一般k邻近 import numpy as np import matplotlib.pyplot as pltclass K_near:def __init__(self,X,Y,K=5,p=2): ...

  8. 速学堂第七章作业编程题答案(自写)

    速学堂第七章作业编程题答案 1. 数组查找操作:定义一个长度为10 的一维字符串数组,在每一个元素存放一个单词;然后运行时从命令行输入一个单词,程序判断数组是否包含有这个单词,包含这个单词就打印出&q ...

  9. 李航《统计学习方法》第二章课后答案链接

    李航<统计学习方法>第二章课后答案链接 李航 统计学习方法 第二章 课后 习题 答案 http://blog.csdn.net/cracker180/article/details/787 ...

最新文章

  1. 通信专业学python有用吗-通信算法工程师需要学python吗
  2. NIO详解(十二):AsynchronousFileChannel详解
  3. ui设计中的版式设计_设计中的版式-第3部分
  4. devops java使用_谁会在使用DevOps时最大程度地退缩?
  5. javascript 求解图表曲线波峰与波谷,类似股票曲线
  6. kafka配置项host.name advertised.host.name
  7. Magento网店自定义模板初探(1)——文件夹结构
  8. VC6编译64位程序
  9. 宽带猫、路由器、交换机的作用与区别是什么?
  10. The Fifty-eighth Of Word-Day
  11. 真肝,整理了一周的Spring面试大全【含答案】,吊打Java面试官
  12. Matlab函数、子函数的定义方法
  13. W5500的以太网电路,正常线序连接的话可能必须做过孔交叉线序,能否在线路上做交叉处理?
  14. mysql初始数据库出错_安装MySQL提示initialize database(初始化数据库)错误解决方法...
  15. 蚂蚁金服彭翼捷:金融科技不止用来改良 更要用来改变
  16. 【Android技巧】通过am完成发送开机广播等操作
  17. “阿里外传”之一:阿里巴巴有只宠物,叫雅虎
  18. 用于穿戴脑机接口的脑电EEG传感芯片KS1092
  19. 手机号,身份证,银行卡号数据脱敏
  20. FS_I10X接收机通道说明

热门文章

  1. 马云:我不懂技术但欣赏技术 达摩院必须超越微软 - 20171011
  2. Linux-Rsync命令参数详解
  3. Asp.net之MsChart控件动态绑定温度曲线图
  4. php isset缺陷 用array_key_exists
  5. PHP起点 - PHP常量
  6. listview winfrom 表头_winform ListView点击行表头,排序
  7. Java实现二分查找及其优化
  8. 北斗导航 | Matlab实现电离层延迟计算:Klobuchar(源代码)
  9. linux一直用户身份验证失败,linux – chsh:PAM身份验证失败
  10. oracle 10g进入ascmd,oracle 10g 默许用户名密码及解锁