复现经典:《统计学习方法》第 3 章 k 近邻法
本文是李航老师的《统计学习方法》[1]一书的代码复现。
作者:黄海广[2]
备注:代码都可以在github[3]中下载。
我将陆续将代码发布在公众号“机器学习初学者”,敬请关注。
代码目录
第 1 章 统计学习方法概论
第 2 章 感知机
第 3 章 k 近邻法
第 4 章 朴素贝叶斯
第 5 章 决策树
第 6 章 逻辑斯谛回归
第 7 章 支持向量机
第 8 章 提升方法
第 9 章 EM 算法及其推广
第 10 章 隐马尔可夫模型
第 11 章 条件随机场
第 12 章 监督学习方法总结
代码参考:wzyonggege[4],WenDesi[5],火烫火烫的[6]
第 3 章 k 近邻法
1.
近邻法是基本且简单的分类与回归方法。近邻法的基本做法是:对给定的训练实例点和输入实例点,首先确定输入实例点的个最近邻训练实例点,然后利用这个训练实例点的类的多数来预测输入实例点的类。
2.
近邻模型对应于基于训练数据集对特征空间的一个划分。近邻法中,当训练集、距离度量、值及分类决策规则确定后,其结果唯一确定。
3.
近邻法三要素:距离度量、值的选择和分类决策规则。常用的距离度量是欧氏距离及更一般的pL距离。值小时,近邻模型更复杂;值大时,近邻模型更简单。值的选择反映了对近似误差与估计误差之间的权衡,通常由交叉验证选择最优的。
常用的分类决策规则是多数表决,对应于经验风险最小化。
4.
近邻法的实现需要考虑如何快速搜索 k 个最近邻点。kd树是一种便于对 k 维空间中的数据进行快速检索的数据结构。kd 树是二叉树,表示对维空间的一个划分,其每个结点对应于维空间划分中的一个超矩形区域。利用kd树可以省去对大部分数据点的搜索, 从而减少搜索的计算量。
距离度量
设特征空间
是维实数向量空间 ,,, ,则:,的距离定义为:
-
曼哈顿距离 -
欧氏距离 -
切比雪夫距离
import math
from itertools import combinations
def L(x, y, p=2):# x1 = [1, 1], x2 = [5,1]if len(x) == len(y) and len(x) > 1:sum = 0for i in range(len(x)):sum += math.pow(abs(x[i] - y[i]), p)return math.pow(sum, 1 / p)else:return 0
课本例 3.1
x1 = [1, 1]
x2 = [5, 1]
x3 = [4, 4]
# x1, x2
for i in range(1, 5):r = {'1-{}'.format(c): L(x1, c, p=i) for c in [x2, x3]}print(min(zip(r.values(), r.keys())))
(4.0, '1-[5, 1]')
(4.0, '1-[5, 1]')
(3.7797631496846193, '1-[4, 4]')
(3.5676213450081633, '1-[4, 4]')
python 实现,遍历所有数据点,找出
个距离最近的点的分类情况,少数服从多数
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inlinefrom sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from collections import Counter
# data
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
# data = np.array(df.iloc[:100, [0, 1, -1]])
df.head(10)
sepal length | sepal width | petal length | petal width | label | |
---|---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
5 | 5.4 | 3.9 | 1.7 | 0.4 | 0 |
6 | 4.6 | 3.4 | 1.4 | 0.3 | 0 |
7 | 5.0 | 3.4 | 1.5 | 0.2 | 0 |
8 | 4.4 | 2.9 | 1.4 | 0.2 | 0 |
9 | 4.9 | 3.1 | 1.5 | 0.1 | 0 |
plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')
plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend()
data = np.array(df.iloc[:100, [0, 1, -1]])
X, y = data[:,:-1], data[:,-1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
class KNN:def __init__(self, X_train, y_train, n_neighbors=3, p=2):"""parameter: n_neighbors 临近点个数parameter: p 距离度量"""self.n = n_neighborsself.p = pself.X_train = X_trainself.y_train = y_traindef predict(self, X):# 取出n个点knn_list = []for i in range(self.n):dist = np.linalg.norm(X - self.X_train[i], ord=self.p)knn_list.append((dist, self.y_train[i]))for i in range(self.n, len(self.X_train)):max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))dist = np.linalg.norm(X - self.X_train[i], ord=self.p)if knn_list[max_index][0] > dist:knn_list[max_index] = (dist, self.y_train[i])# 统计knn = [k[-1] for k in knn_list]count_pairs = Counter(knn)
# max_count = sorted(count_pairs, key=lambda x: x)[-1]max_count = sorted(count_pairs.items(), key=lambda x: x[1])[-1][0]return max_countdef score(self, X_test, y_test):right_count = 0n = 10for X, y in zip(X_test, y_test):label = self.predict(X)if label == y:right_count += 1return right_count / len(X_test)
clf = KNN(X_train, y_train)
clf.score(X_test, y_test)
1.0
test_point = [6.0, 3.0]
print('Test Point: {}'.format(clf.predict(test_point)))
Test Point: 1.0
plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')
plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')
plt.plot(test_point[0], test_point[1], 'bo', label='test_point')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend()
scikit-learn 实例
from sklearn.neighbors import KNeighborsClassifier
clf_sk = KNeighborsClassifier()
clf_sk.fit(X_train, y_train)
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',metric_params=None, n_jobs=None, n_neighbors=5, p=2,weights='uniform')
clf_sk.score(X_test, y_test)
1.0
sklearn.neighbors.KNeighborsClassifier
n_neighbors: 临近点个数
p: 距离度量
algorithm: 近邻算法,可选{'auto', 'ball_tree', 'kd_tree', 'brute'}
weights: 确定近邻的权重
kd 树
kd树是一种对 k 维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。
kd树是二叉树,表示对
维空间的一个划分(partition)。构造kd树相当于不断地用垂直于坐标轴的超平面将维空间切分,构成一系列的 k 维超矩形区域。kd 树的每个结点对应于一个维超矩形区域。
构造kd树的方法如下:
构造根结点,使根结点对应于
维空间中包含所有实例点的超矩形区域;通过下面的递归方法,不断地对维空间进行切分,生成子结点。在超矩形区域(结点)上选择一个坐标轴和在此坐标轴上的一个切分点,确定一个超平面,这个超平面通过选定的切分点并垂直于选定的坐标轴,将当前超矩形区域切分为左右两个子区域 (子结点);这时,实例被分到两个子区域。这个过程直到子区域内没有实例时终止(终止时的结点为叶结点)。在此过程中,将实例保存在相应的结点上。
通常,依次选择坐标轴对空间切分,选择训练实例点在选定坐标轴上的中位数 (median)为切分点,这样得到的kd树是平衡的。注意,平衡的kd树搜索时的效率未必是最优的。
构造平衡 kd 树算法
输入:
维空间数据集,
其中 ,
=
;
输出:kd树。
(1)开始:构造根结点,根结点对应于包含
的维空间的超矩形区域。
选择
为坐标轴,以 T 中所有实例的坐标的中位数为切分点,将根结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴垂直的超平面实现。
由根结点生成深度为 1 的左、右子结点:左子结点对应坐标
小于切分点的子区域, 右子结点对应于坐标大于切分点的子区域。
将落在切分超平面上的实例点保存在根结点。
(2)重复:对深度为
的结点,选择为切分的坐标轴,,以该结点的区域中所有实例的坐标的中位数为切分点,将该结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴垂直的超平面实现。
由该结点生成深度为
的左、右子结点:左子结点对应坐标小于切分点的子区域,右子结点对应坐标大于切分点的子区域。
将落在切分超平面上的实例点保存在该结点。
(3)直到两个子区域没有实例存在时停止。从而形成kd树的区域划分。
# kd-tree每个结点中主要包含的数据结构如下
class KdNode(object):def __init__(self, dom_elt, split, left, right):self.dom_elt = dom_elt # k维向量节点(k维空间中的一个样本点)self.split = split # 整数(进行分割维度的序号)self.left = left # 该结点分割超平面左子空间构成的kd-treeself.right = right # 该结点分割超平面右子空间构成的kd-treeclass KdTree(object):def __init__(self, data):k = len(data[0]) # 数据维度def CreateNode(split, data_set): # 按第split维划分数据集exset创建KdNodeif not data_set: # 数据集为空return None# key参数的值为一个函数,此函数只有一个参数且返回一个值用来进行比较# operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为需要获取的数据在对象中的序号#data_set.sort(key=itemgetter(split)) # 按要进行分割的那一维数据排序data_set.sort(key=lambda x: x[split])split_pos = len(data_set) // 2 # //为Python中的整数除法median = data_set[split_pos] # 中位数分割点split_next = (split + 1) % k # cycle coordinates# 递归的创建kd树return KdNode(median,split,CreateNode(split_next, data_set[:split_pos]), # 创建左子树CreateNode(split_next, data_set[split_pos + 1:])) # 创建右子树self.root = CreateNode(0, data) # 从第0维分量开始构建kd树,返回根节点# KDTree的前序遍历
def preorder(root):print(root.dom_elt)if root.left: # 节点不为空preorder(root.left)if root.right:preorder(root.right)
# 对构建好的kd树进行搜索,寻找与目标点最近的样本点:
from math import sqrt
from collections import namedtuple# 定义一个namedtuple,分别存放最近坐标点、最近距离和访问过的节点数
result = namedtuple("Result_tuple","nearest_point nearest_dist nodes_visited")def find_nearest(tree, point):k = len(point) # 数据维度def travel(kd_node, target, max_dist):if kd_node is None:return result([0] * k, float("inf"),0) # python中用float("inf")和float("-inf")表示正负无穷nodes_visited = 1s = kd_node.split # 进行分割的维度pivot = kd_node.dom_elt # 进行分割的“轴”if target[s] <= pivot[s]: # 如果目标点第s维小于分割轴的对应值(目标离左子树更近)nearer_node = kd_node.left # 下一个访问节点为左子树根节点further_node = kd_node.right # 同时记录下右子树else: # 目标离右子树更近nearer_node = kd_node.right # 下一个访问节点为右子树根节点further_node = kd_node.lefttemp1 = travel(nearer_node, target, max_dist) # 进行遍历找到包含目标点的区域nearest = temp1.nearest_point # 以此叶结点作为“当前最近点”dist = temp1.nearest_dist # 更新最近距离nodes_visited += temp1.nodes_visitedif dist < max_dist:max_dist = dist # 最近点将在以目标点为球心,max_dist为半径的超球体内temp_dist = abs(pivot[s] - target[s]) # 第s维上目标点与分割超平面的距离if max_dist < temp_dist: # 判断超球体是否与超平面相交return result(nearest, dist, nodes_visited) # 不相交则可以直接返回,不用继续判断#----------------------------------------------------------------------# 计算目标点与分割点的欧氏距离temp_dist = sqrt(sum((p1 - p2)**2 for p1, p2 in zip(pivot, target)))if temp_dist < dist: # 如果“更近”nearest = pivot # 更新最近点dist = temp_dist # 更新最近距离max_dist = dist # 更新超球体半径# 检查另一个子结点对应的区域是否有更近的点temp2 = travel(further_node, target, max_dist)nodes_visited += temp2.nodes_visitedif temp2.nearest_dist < dist: # 如果另一个子结点内存在更近距离nearest = temp2.nearest_point # 更新最近点dist = temp2.nearest_dist # 更新最近距离return result(nearest, dist, nodes_visited)return travel(tree.root, point, float("inf")) # 从根节点开始递归
例 3.2
data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
kd = KdTree(data)
preorder(kd.root)
[7, 2]
[5, 4]
[2, 3]
[4, 7]
[9, 6]
[8, 1]
from time import clock
from random import random# 产生一个k维随机向量,每维分量值在0~1之间
def random_point(k):return [random() for _ in range(k)]# 产生n个k维随机向量
def random_points(k, n):return [random_point(k) for _ in range(n)]
ret = find_nearest(kd, [3,4.5])
print (ret)
Result_tuple(nearest_point=[2, 3], nearest_dist=1.8027756377319946, nodes_visited=4)
N = 400000
t0 = clock()
kd2 = KdTree(random_points(3, N)) # 构建包含四十万个3维空间样本点的kd树
ret2 = find_nearest(kd2, [0.1,0.5,0.8]) # 四十万个样本点中寻找离目标最近的点
t1 = clock()
print ("time: ",t1-t0, "s")
print (ret2)
time: 5.204035100000002 s
Result_tuple(nearest_point=[0.09308431086306368, 0.5071110780404813, 0.7998624450062822], nearest_dist=0.009920338124925524, nodes_visited=88)
参考资料
[1] 《统计学习方法》: https://baike.baidu.com/item/统计学习方法/10430179
[2] 黄海广: https://github.com/fengdu78
[3] github: https://github.com/fengdu78/lihang-code
[4] wzyonggege: https://github.com/wzyonggege/statistical-learning-method
[5] WenDesi: https://github.com/WenDesi/lihang_book_algorithm
[6] 火烫火烫的: https://blog.csdn.net/tudaodiaozhale
关于本站
“机器学习初学者”公众号由是黄海广博士创建,黄博个人知乎粉丝23000+,github排名全球前100名(33000+)。本公众号致力于人工智能方向的科普性文章,为初学者提供学习路线和基础资料。原创作品有:吴恩达机器学习个人笔记、吴恩达深度学习笔记等。
往期精彩回顾
那些年做的学术公益-你不是一个人在战斗
适合初学者入门人工智能的路线及资料下载
吴恩达机器学习课程笔记及资源(github标星12000+,提供百度云镜像)
吴恩达深度学习笔记及视频等资源(github标星8500+,提供百度云镜像)
《统计学习方法》的python代码实现(github标星7200+)
机器学习的数学精华(在线阅读版)
备注:加入本站微信群或者qq群,请回复“加群”
复现经典:《统计学习方法》第 3 章 k 近邻法相关推荐
- 统计学习方法第三章 k近邻法
文章目录 第三章 k近邻法 k近邻算法 k近邻模型的距离划分 k值的选择 k近邻分类决策规则 第三章 k近邻法 只讨论分类问题的k近邻法 k近邻三个基本要素: k值选择 距离度量 分类决策规则 k近邻 ...
- 《统计学习方法》读书笔记——K近邻法(原理+代码实现)
传送门 <统计学习方法>读书笔记--机器学习常用评价指标 <统计学习方法>读书笔记--感知机(原理+代码实现) <统计学习方法>读书笔记--K近邻法(原理+代码实现 ...
- 机器学习理论《统计学习方法》学习笔记:第三章 k近邻法
机器学习理论<统计学习方法>学习笔记:第三章 k近邻法 3 k近邻法 3.1 K近邻算法 3.2 K近邻模型 3.2.1 模型 3.2.2 距离度量 3.2.3 K值的选择 3.2.4 分 ...
- 【机器学习】《统计学习方法》学习笔记 第三章 k近邻法
第三章 k k k 近邻法(KNN) 多分类模型,思路是将最近的 N N N 个邻居的分类值中的多数作为自己的分类值.没有显式的学习过程. 三个基本要素:距离度量. k k k 值选择和分类决策规则. ...
- 统计机器学习【3】- K近邻法(三)Ball Tree
在计算机科学中,球树(ball tree)是一种空间划分数据结构,用于组织在多维空间中的点.球数之所有得到此名,是因为它将数据点划分为一组嵌套的超球体.这种类型的数据结构特征使其在很多方面都有用,特别 ...
- 统计学习方法 - 第1章 - 概论
全书章节 第1章 统计学习方法概论 第2章 感知机 第3章 k近邻法 第4章 朴素贝叶斯法 第5章 决策树 第6章 逻辑斯谛回归与最大熵模型 第7章 支持向量机 第8章 提升方法 第9章 EM算法及其 ...
- 统计学习方法笔记(李航)———第三章(k近邻法)
k 近邻法 (k-NN) 是一种基于实例的学习方法,无法转化为对参数空间的搜索问题(参数最优化 问题).它的特点是对特征空间进行搜索.除了k近邻法,本章还对以下几个问题进行较深入的讨 论: 切比雪夫距 ...
- 统计学习方法——第1章(个人笔记)
统计学习方法--第1章 统计学习及监督学习概论 <统计学习方法>(第二版)李航,学习笔记 1.1 统计学习 1.特点 (1)以计算机及网络为平台,是建立在计算机及网络上的: (2)以数据为 ...
- 一篇详解带你再次重现《统计学习方法》——第二章、感知机模型
个性签名:整个建筑最重要的是地基,地基不稳,地动山摇. 而学技术更要扎稳基础,关注我,带你稳扎每一板块邻域的基础. 博客主页:七归的博客 专栏:<统计学习方法>第二版--个人笔记 创作不易 ...
最新文章
- 独家 | 如何在BigQueryML中使用K-均值聚类来更好地理解和描述数据(附代码)
- mysql删除数据后id自增不连续的解决方法
- javaweb开发3.基于Servlet+JSP+JavaBean开发模式的用户登录注册
- The difference between sleep(), wait(), and yield() in human terms.
- java的(PO,VO,TO,BO,DAO,POJO)解释
- 公钥(Public Key)与私钥(Private Key)
- kong 使用jwt RSA256证书
- 如何提高你的工作效率?
- 大数据、物联网、AI 等技术正当时!
- .html与.htm为网页后缀的区别
- 命令查询每个文件文件数
- 五款不错的Web前端开发工具,对小白来说完全够用了!
- 哪些软件可以做国外问卷调查
- PageHelper关闭count语句优化
- 利用代码快速批量取消微博的关注
- 数据结构——树和二叉树章节思维导图
- MatLab 画图方法
- python处理excel数据并对数据进行打分
- windows下hadoop的部署和使用
- Linux终端模式下查看电脑的硬件配置信息小技巧
热门文章
- 你不知道的Node.js性能优化,读了之后水平直线上升
- 160. Intersection of Two Linked Lists
- 使用条件注释完成浏览器兼容
- 小项目--bank1
- asp.net网站中CrystalReport的简单应用
- 伊利诺伊香槟分校计算机科学,伊利诺伊大学香槟分校计算机科学与工程世界排名2019年最新排名第24(ARWU世界排名)...
- Cochrane系统综述注册的具体流程
- 大型单细胞数据分析解决方案
- 医学科研如何快速掌握R语言?
- gson解析天气json_几种常用JSON解析库性能比较