K-Means works on MNIST

一、实验环境
编程语言:Python 3.8

二、一点小建议
在小数据集上,或许我们还可以使用嵌套循环解决问题,但在MNIST这样稍微有点大的数据集上,如果我们还是坚持使用嵌套循环,那么仅仅是完成一次迭代就需要花上一段时间,因此,我们需要学会使用向量化编程,甚至是二维数组(矩阵)化编程,如若读者有过手动实现BPNN的经历,或许就更加能够体会到上述两种编程方式所带来的便利之处,这也是我非常喜欢矩阵的原因,因此,下面我们即将实现的K-Means和GMM都是采用了上述两种编程方式。

三、基本原理
有时间我再回来补上

四、代码展示
K-Means

# Filename: kMeans.py
# Usage: python kMeans.pyimport numpy as np
from munkres import Munkres
import matplotlib.pyplot as plt
import random'''
np.load helps you load data from train/test-images/labels.npy
train-images.npy contains 60000 images with each image is a 28*28 matrix.
train-labels.npy contains the corresponding labels ranging from 0 to 9.
10000 test instances are included in test_images and test_labels.
'''
# Load dataset, including train_images, train_labels, test_images, test_labels
train_images = np.load("train-images.npy")
train_labels = np.load("train-labels.npy")
test_images = np.load("test-images.npy")
test_labels = np.load("test-labels.npy")
'''
print(train_images.shape)       # (60000, 784)
print(train_labels.shape)       # (60000,)
print(test_images.shape)        # (10000, 784)
print(test_labels.shape)        # (10000,)
'''def distEclud(vecA, vecB):'''计算两个向量的欧式距离'''return np.sqrt(np.sum(np.power(vecA - vecB, 2)))def randCent_1(dataSet, k):'''为给定数据集构建一个包含k个随机质心的集合随机质心必须要在整个数据集的边界之内,这可以通过找到数据集中每一维的最小和最大值来完成生成0到1.0之间的随机数并通过取值范围和最小值,以便确保随机点在数据的边界之内'''n = np.shape(dataSet)[1]centroids = np.mat(np.zeros((k, n)))print(np.shape(centroids))for j in range(n):minJ = np.min(dataSet[:, j])rangeJ = np.float(np.max(dataSet[:, j]) - minJ)centroids[:,j] = minJ + rangeJ * np.random.rand(k,1)return centroidsdef randCent_2(dataSet, k):'''Distance-based methodStart with one random data instanceChoose the point that is farthest to the existing centersIssue: may choose outliers'''n = np.shape(dataSet)[1]m = np.shape(dataSet)[0]centroids = []index = random.randint(0, m - 1)centroids.append(dataSet[index])mySet = set()mySet.add(index)for i in range(1, k):dist = []for j in range(len(np.array(centroids))):dist.append(np.sum(np.power(dataSet - centroids[j], 2), axis = 1))dist = np.array(dist)max_dist = np.sum(dist, axis = 0)max_index = np.argmax(max_dist)if max_index not in mySet:mySet.add(max_index)centroids.append(dataSet[max_index])centroids = np.mat(centroids)return centroidsdef randCent_3(dataSet, k):'''Random methodChoose the data instance randomlyIssue: may choose nearby instance'''centroids = np.array(random.sample(list(dataSet), k)) centroids = np.mat(centroids)return centroidsdef kMeans(dataSet, labels, k, iter_num, distMeas = distEclud, createCent = randCent_2):'''函数一开始确定数据集中样本点的总数,并创建一个矩阵来存储每个点的簇分配结果簇分类结果矩阵clusterAssment包含两列:一列记录簇索引值,第二列存储误差,误差是指当前点到簇质心的距离反复迭代计算质心-分配-重新计算的过程,直到所有数据点的簇分配结果不再改变为止'''epoch = []acc = []num = 0m = np.shape(dataSet)[0]clusterAssment = np.mat(np.zeros((m, 2)))centroids = createCent(dataSet, k)clusterChanged = Truewhile clusterChanged and num < iter_num:num += 1                        # 迭代次数+1clusterChanged = False          # 样本点的簇分配是否发生变化# 寻找最近的质心minDist =[]for i in range(k):minDist.append(np.sum(np.power(dataSet - np.array(centroids)[i], 2), axis = 1))minDist = np.array(minDist)min_dist = np.min(minDist, axis = 0)            # 最小距离min_index = np.argmin(minDist, axis = 0)        # 最小距离对应的簇心标号if((clusterAssment.T[0] == min_index.T).all()):clusterChanged = False                      # 样本点的簇分配没有发生变化else:clusterChanged = True                       # 样本点的簇分配发生变化clusterAssment.T[0] = min_index.T# 更新质心的位置for cent in range(k):ptsInClust = dataSet[np.nonzero(clusterAssment[:, 0].A == cent)[0]]centroids[cent, :] = np.mean(ptsInClust, axis = 0)min_index = maplabels(labels + 1, min_index + 1)    # 匈牙利算法count = np.sum(labels + 1 == min_index)               # 分类正确的样本数目accuracy = count / len(labels)                        # 准确率print("当迭代次数为", num, "时在训练数据集上的聚类精度为", accuracy)epoch.append(num)acc.append(accuracy)plt.figure(1)plt.xlabel("iteration")plt.ylabel("accuracy")plt.title("accuracy-iteration")plt.plot(epoch, acc)plt.show()return centroids, clusterAssmentdef calcAcc(dataSet, labels, centroids):# 给定数据集和真实标签以及簇心,寻找最近的质心,计算准确率# 寻找最近的质心minDist =[]k = len(centroids)for i in range(k):minDist.append(np.sum(np.power(dataSet - np.array(centroids)[i], 2), axis = 1))minDist = np.array(minDist)min_dist = np.min(minDist, axis = 0)min_index = np.argmin(minDist, axis = 0)# 计算准确率min_index = maplabels(labels + 1, min_index + 1)count = np.sum(labels + 1 == min_index)   # 分类正确的样本数目accuracy = count/len(labels)            # 准确率return accuracydef maplabels(L1, L2):L2 = L2Label1 = np.unique(L1)Label2 = np.unique(L2)nClass1 = len(Label1)nClass2 = len(Label2)nClass = np.maximum(nClass1, nClass2)G = np.zeros((nClass, nClass))for i in range(nClass1):ind_cla1 = L1 == Label1[i]ind_cla1 = ind_cla1.astype(float)for j in range(nClass2):ind_cla2 = L2 == Label2[j]ind_cla2 = ind_cla2.astype(float)G[i, j] = np.sum(ind_cla2*ind_cla1)m = Munkres()index = m.compute(-G.T)index = np.array(index)index = index+1# print(-G.T)# print(index)newL2 = np.zeros(L2.shape, dtype=int)for i in range(nClass2):for j in range(len(L2)):if L2[j] == index[i, 0]:newL2[j] = index[i, 1]return newL2def k_foldCrossValidation(dataSet, k, i):'''交叉验证法或k折交叉验证先将数据集D划分为k个大小相似的互斥子集,每次用k-1个子集的并集作为训练集,余下的那个子集作为验证集这样就可以获得k组训练集/验证集,从而可以进行k次训练和验证,最终返回的是这k个验证结果的均值:param dataSet:数据集:param k:将数据集划分为k个大小相似的互斥子集:param i:选取k个互斥子集中的一个作为验证集:return: train_set:训练集validation_set:验证集'''quotient = len(dataSet) // kremainder = len(dataSet) - quotient * kif(i < remainder):count = quotient + 1first = i * countelse:count = quotientfirst = i * count + remainderlast = first + counttrain_set = dataSet[:first] + dataSet[last:]validation_set = dataSet[first:last]return train_set, validation_setdef main():print("Running...")centroids, clusterAssment = kMeans(train_images, train_labels, 10, 40)accuracy = calcAcc(test_images, test_labels, centroids)     # 在测试数据集上的聚类精度print("在测试数据集上的聚类精度为", accuracy)if __name__ == '__main__':main()

五、实验结果

Random method Distance-based method

六、注意
值得注意的是,由于Random method是从训练数据集中随机选取k个样本点作为簇的质心,因此,有可能会出现这k个样本中的某几个距离较近的情况。在运行上面这份代码时,如果选择质心的函数是randCent_1或者randCent_3,有可能会在第一次迭代时出现Running warning的情况,在这种情况下,请按Ctrl+C中止程序运行,并在终端中输入python kMeans.py重新运行即可。

kmeans works on mnist相关推荐

  1. 信号处理深度学习机器学习_机器学习和信号处理如何融合?

    信号处理深度学习机器学习 As a design engineer, I am increasingly exposed to more complex engineering challenges ...

  2. 层次聚类算法 算法_聚类算法简介

    层次聚类算法 算法 Take a look at the image below. It's a collection of bugs and creepy-crawlies of different ...

  3. 数据挖掘之聚类分析(Cluster Analysis)

    1.Motivations(目的) Identify grouping structure of data so that objects within the same group are clos ...

  4. k-均值聚类算法_聚类算法-K-均值算法

    k-均值聚类算法 聚类算法-K-均值算法 (Clustering Algorithms - K-means Algorithm) K-Means算法简介 (Introduction to K-Mean ...

  5. MNIST | 基于k-means和KNN的0-9数字手写体识别

    MNIST | 基于k-means和KNN的0-9数字手写体识别 1 背景说明 2 算法原理 3 代码实现 3.1 文件目录 3.2 核心代码 4 实验与结果分析 5 后记 概要: 本实验是在实验&q ...

  6. 客户细分_客户细分:K-Means聚类和A / B测试

    客户细分 语境 (Context) I have been working in Advertising, specifically Digital Media and Performance, fo ...

  7. ecw2c理解元数据:使用BigQuery k-means将4,000个堆栈溢出标签聚类

    您如何将超过4,000个活动的Stack Overflow标签分组为有意义的组? 对于无监督学习和k均值聚类来说,这是一项完美的任务-现在您可以在BigQuery中完成所有这些工作. 让我们找出方法. ...

  8. K-均值聚类(K-Means) C++代码实现

    K-均值聚类(K-Means)简介可以参考: http://blog.csdn.net/fengbingchun/article/details/79276668 以下是K-Means的C++实现,c ...

  9. OpenCV3.3中K-Means聚类接口简介及使用

    OpenCV3.3中给出了K-均值聚类(K-Means)的实现,即接口cv::kmeans,接口的声明在include/opencv2/core.hpp文件中,实现在modules/core/src/ ...

最新文章

  1. “勤奋”,是能让你走出低谷最有效的方法
  2. pandas画时间序列图
  3. WEB安全:XSS漏洞与SQL注入漏洞介绍及解决方案
  4. ios底部栏设计规范_超全面的UI设计规范整理,你值得收藏!
  5. poj1066--Treasure Hunt(规范相交)
  6. lvm硬盘管理及LVM扩容
  7. bzoj 3316: JC loves Mkk(二分+单调队列)
  8. 「日常训练」Queue(Codeforces Round 303 Div.2 D)
  9. ghub无法安装_好用了还是更别扭了,简析罗技G HUB驱动程序
  10. 数据库建模逆向工程工具
  11. 删除卸载不干净的任务 vmware卸载不干净->服务清理 / 注册表清理
  12. Vue绘制业务流程图(附源码)
  13. SONY笔记本电脑SVS131100C系统重装后Fn键功能问题
  14. Flutter获取网络图片:The following SocketException was thrown resolving an image codec:
  15. vs2017下配置Xamarin
  16. Xtend官方文档——第二部分(一)
  17. S7-200SMART PLC基础知识汇总
  18. html显示hdf5文件,图片转换成HDF5文件(加载,保存)
  19. Android运行ListView的代码,Android ListView组件详解及示例代码
  20. 计算机开机f8键,开机F8键“高级启动选项”的秘密

热门文章

  1. 第十七届全国大学生智能汽车华南赛区竞赛 - 流程册(文档)
  2. 邱博士 复旦计算机,邱世藩博士简介
  3. 我的世界工业服务器发展最快,《我的世界》服务器生存及快速发展心得
  4. “性格测试”刷爆朋友圈 小心隐私泄露
  5. 探访“视障人士”智能化住所 阿里云IoT如何用物联网改变他的生活起居
  6. Pango-ERROR **: 15:31:09.214: Harfbuzz version too old (1.4.2)解决
  7. 复杂链表的复制(C语言)
  8. HDOJ(HDU) 2500 做一个正气的杭电人(水~)
  9. 基于ASP.NET的电影搜索网站设计与实现
  10. VS2019实现简易的射击坦克小游戏(easyx)