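I've been working on object detection recently. To label my data sensibly, I decided to first group it with a clustering algorithm. That way I avoid labeling far too many samples of one kind and too few of another, which wastes time and effort. What I found online mostly explains the algorithm itself, so here I'll humbly share my own workflow...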

```python
import os

import cv2
import numpy as np
# TensorFlow 1.x only: tf.contrib and tensorflow.examples.tutorials were removed in 2.x
import tensorflow as tf
from tensorflow.contrib.factorization import KMeans
# Import the MNIST dataset
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
num_steps = 50  # number of training steps
batch_size = 1024  # samples per batch
k = 25  # number of clusters
num_classes = 10  # 10 classes
num_features = 784  # each image is 28*28


def mytrain():
    full_data_x = mnist.train.images
    X = tf.placeholder(tf.float32, shape=[None, num_features])
    # K-Means parameters
    kmeans = KMeans(inputs=X, num_clusters=k, distance_metric='cosine',
                    use_mini_batch=True)
    # Build the KMeans graph
    training_graph = kmeans.training_graph()
    # The number of returned values differs between TF 1.x versions
    if len(training_graph) > 6:
        (all_scores, cluster_idx, scores, cluster_centers_initialized,
         cluster_centers_var, init_op, train_op) = training_graph
    else:
        (all_scores, cluster_idx, scores, cluster_centers_initialized,
         init_op, train_op) = training_graph
    cluster_idx = cluster_idx[0]
    avg_distance = tf.reduce_mean(scores)
    # Initialize variables (with default values)
    init_vars = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init_vars, feed_dict={X: full_data_x})
    sess.run(init_op, feed_dict={X: full_data_x})
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
    ckpt_file = r'./model'
    # Training
    for i in range(1, num_steps + 1):
        _, d, idx = sess.run([train_op, avg_distance, cluster_idx],
                             feed_dict={X: full_data_x})
        if i % 10 == 0 or i == 1:
            print("Step %i, Avg Distance: %f" % (i, d))
    saver.save(sess, ckpt_file)
    sess.close()


def mytest():
    # Rebuild on a fresh graph so the variable names match the checkpoint
    # (a second KMeans in the same graph gets uniquified names that restore cannot find)
    tf.reset_default_graph()
    # Input images
    X = tf.placeholder(tf.float32, shape=[None, num_features])
    kmeans = KMeans(inputs=X, num_clusters=k, distance_metric='cosine',
                    use_mini_batch=True)
    training_graph = kmeans.training_graph()
    if len(training_graph) > 6:
        (all_scores, cluster_idx, scores, cluster_centers_initialized,
         cluster_centers_var, init_op, train_op) = training_graph
    else:
        (all_scores, cluster_idx, scores, cluster_centers_initialized,
         init_op, train_op) = training_graph
    with tf.Session() as sess:
        ckpt_state = tf.train.get_checkpoint_state('./')
        print(ckpt_state)
        saver = tf.train.Saver()
        if ckpt_state:
            saver.restore(sess, tf.train.latest_checkpoint(r'./'))
        test_x, test_y = mnist.test.images, mnist.test.labels
        all_scores_, scores_ = sess.run([all_scores, scores], feed_dict={X: test_x})
        print(all_scores_[0].shape)
        print(scores_[0].shape)
        all_scores_ = all_scores_[0]
        scores_ = scores_[0]
        for i in range(scores_.shape[0]):
            # Difference between scores[i] and each element of all_scores[i]
            L = abs(all_scores_[i] - scores_[i])
            min_val = min(L)  # take the smallest difference
            # Its index in all_scores is the sample's cluster
            val = list(L).index(min_val)
            path1 = os.path.join("./images", str(val))
            # makedirs also creates ./images itself if it does not exist yet
            os.makedirs(path1, exist_ok=True)
            img = test_x[i].reshape(28, 28) * 255
            img = img.astype(np.uint8)
            # Name files by sample index; time.time() names can collide and overwrite
            cv2.imwrite(os.path.join(path1, str(i) + ".png"), img)


mytrain()
mytest()
```
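In `mytest`, each sample's cluster is recovered by finding the entry of `all_scores` (the distances to all k centers) that is closest to `scores` (the distance to the nearest center). Since `scores` is that minimum, the lookup should agree with a plain `argmin` over each row of `all_scores`, which is also the assignment held in the `cluster_idx` output of `training_graph()`. The NumPy sketch below checks this on random data; the helper names are hypothetical, and the cosine distance `1 - cos(x, c)` is an assumption about what the graph computes for `distance_metric='cosine'`.

```python
import numpy as np


def cosine_distances(x, centers):
    """Pairwise cosine distance 1 - cos(x_i, c_j), shape [n, k]."""
    x_n = x / np.linalg.norm(x, axis=1, keepdims=True)
    c_n = centers / np.linalg.norm(centers, axis=1, keepdims=True)
    return 1.0 - x_n @ c_n.T


def cluster_by_lookup(all_scores, scores):
    """The article's approach: index of the all_scores entry closest to scores."""
    idx = []
    for i in range(scores.shape[0]):
        diff = np.abs(all_scores[i] - scores[i])
        idx.append(int(np.argmin(diff)))
    return np.array(idx)


def cluster_by_argmin(all_scores):
    """Equivalent shortcut: the nearest center is the row-wise minimum."""
    return np.argmin(all_scores, axis=1)


rng = np.random.default_rng(0)
x = rng.random((200, 784))       # stand-in for flattened images
centers = rng.random((25, 784))  # stand-in for 25 trained centers
all_scores = cosine_distances(x, centers)
scores = all_scores.min(axis=1)  # distance to the nearest center

print(np.array_equal(cluster_by_lookup(all_scores, scores),
                     cluster_by_argmin(all_scores)))  # True
```

In the TensorFlow script this suggests fetching `cluster_idx[0]` directly instead of the search loop, though that variant has not been run here.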

In the end we get 25 folders, one per cluster, matching the value of k we set.
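Since the point of clustering here is to avoid labeling too many samples of one kind and too few of another, counting the images in each folder shows which clusters are crowded. A minimal sketch, assuming the `./images/<cluster>` layout that `mytest` writes; `count_per_cluster` is a hypothetical helper, not part of any library.

```python
import os


def count_per_cluster(root="./images"):
    """Return {cluster_folder_name: number_of_png_files}, largest clusters first."""
    counts = {}
    for name in os.listdir(root):
        folder = os.path.join(root, name)
        if os.path.isdir(folder):
            counts[name] = sum(1 for f in os.listdir(folder)
                               if f.lower().endswith(".png"))
    # Sort by count (descending), then by folder name for a stable order
    return dict(sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])))


if __name__ == "__main__":
    if os.path.isdir("./images"):
        for cluster, n in count_per_cluster().items():
            print(f"cluster {cluster}: {n} images")
```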

[Figure: the images in one of the cluster folders]
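This test only uses the MNIST dataset; you can adapt it to whatever dataset you need to label. The code is simply a modification of an example from the TensorFlow website.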
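On TensorFlow 2.x, where `tf.contrib` no longer exists, the same idea (k-means with cosine distance over flattened images) can be sketched in plain NumPy. This is a swapped-in, minimal Lloyd-style implementation, not the TensorFlow `KMeans` class; it assumes L2-normalized rows, farthest-point initialization (a simple stand-in for k-means++), and a fixed number of steps. `cosine_kmeans` is a hypothetical name.

```python
import numpy as np


def cosine_kmeans(data, k, num_steps=50, seed=0):
    """Cluster the rows of `data` ([N, num_features]) by cosine distance.

    Returns (labels of shape [N], centers of shape [k, num_features]).
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(data, dtype=np.float64)
    # L2-normalize rows so cosine similarity becomes a plain dot product
    x = x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), 1e-12)

    # Farthest-point initialization: each new center is the point
    # farthest (in cosine distance) from all centers chosen so far
    centers = np.empty((k, x.shape[1]))
    centers[0] = x[rng.integers(len(x))]
    min_dist = 1.0 - x @ centers[0]
    for j in range(1, k):
        centers[j] = x[np.argmax(min_dist)]
        min_dist = np.minimum(min_dist, 1.0 - x @ centers[j])

    for _ in range(num_steps):
        # Assignment step: cosine distance = 1 - cos(x, c)
        labels = np.argmin(1.0 - x @ centers.T, axis=1)
        # Update step: mean direction of each cluster, renormalized
        for j in range(k):
            members = x[labels == j]
            if len(members):  # keep the old center if a cluster is empty
                mean = members.sum(axis=0)
                centers[j] = mean / max(np.linalg.norm(mean), 1e-12)

    # Final assignment against the final centers
    labels = np.argmin(1.0 - x @ centers.T, axis=1)
    return labels, centers
```

With the arrays from the main script, something like `labels, _ = cosine_kmeans(mnist.train.images, k=25)` could stand in for `mytrain`, and `labels` could drive the same folder-writing loop as `mytest`. It has not been compared against the TensorFlow version, so its clusters need not match.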

I ran into a problem here. If anyone knows the answer, I'd really appreciate your advice.

```python
full_data_x = mnist.train.images
shape_ = full_data_x.shape
full_data_x = full_data_x.reshape(shape_[0], 28, 28)
X = tf.placeholder(tf.float32, shape=[None, 28, 28])
```

Changing the input like this raises an error:

```
Traceback (most recent call last):
  File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1356, in _do_call
    return fn(*args)
  File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input points should be a matrix.
	 [[{{node NearestNeighbors}}]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "I:/myPython/Kmeans/tensorflow_Kmeans.py", line 87, in <module>
    mytrain()
  File "I:/myPython/Kmeans/tensorflow_Kmeans.py", line 44, in mytrain
    feed_dict={X: full_data_x})
  File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 950, in run
    run_metadata_ptr)
  File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1350, in _do_run
    run_metadata)
  File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input points should be a matrix.
	 [[node NearestNeighbors (defined at /myPython/Kmeans/tensorflow_Kmeans.py:25) ]]

Original stack trace for 'NearestNeighbors':
  File "/myPython/Kmeans/tensorflow_Kmeans.py", line 87, in <module>
    mytrain()
  File "/myPython/Kmeans/tensorflow_Kmeans.py", line 25, in mytrain
    training_graph = kmeans.training_graph()
  File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\contrib\factorization\python\ops\clustering_ops.py", line 377, in training_graph
    all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
  File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\contrib\factorization\python\ops\clustering_ops.py", line 257, in _infer_graph
    inp, clusters, 1)
  File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\ops\gen_clustering_ops.py", line 258, in nearest_neighbors
    "NearestNeighbors", points=points, centers=centers, k=k, name=name)
  File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\framework\ops.py", line 3616, in create_op
    op_def=op_def)
  File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\framework\ops.py", line 2005, in __init__
    self._traceback = tf_stack.extract_stack()
```

According to the error message, my input is not a matrix. I'm stumped...
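The stack trace shows the message coming from the `NearestNeighbors` op that `_infer_graph` calls, and, as the error says, that op only accepts a 2-D matrix of points: one row per sample, one column per feature. A `[None, 28, 28]` placeholder hands it a 3-D tensor instead. A likely fix, offered as an assumption rather than a tested patch, is to keep the `[None, num_features]` placeholder and flatten each 28×28 image into a 784-long row before feeding, reshaping back only when writing images out (as `mytest` already does). `to_matrix` and `to_image` are hypothetical helpers.

```python
import numpy as np


def to_matrix(images):
    """Flatten [N, H, W] images into the [N, H*W] matrix KMeans expects."""
    images = np.asarray(images, dtype=np.float32)
    return images.reshape(images.shape[0], -1)


def to_image(row, height=28, width=28):
    """Undo the flattening for one row and scale [0, 1] floats to uint8 pixels."""
    return (np.asarray(row).reshape(height, width) * 255).astype(np.uint8)


# Stand-in for full_data_x after it was reshaped to [N, 28, 28]
batch = np.random.default_rng(0).random((4, 28, 28))
flat = to_matrix(batch)
print(flat.shape)  # (4, 784): the shape X = tf.placeholder(tf.float32, [None, 784]) accepts
```

Alternatively, the 3-D placeholder could be kept and `tf.reshape(X, [-1, num_features])` passed as `inputs` to `KMeans`; that variant is untested here.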
