介绍

K_Means其实用sklearn即可,TensorFlow1.0早期版本支持K_Means,在2.0之后,由于很多api废弃,导致实现K_Means有很多坑。以下为踩坑记录。

完整代码路径:https://github.com/lilihongjava/leeblog_python/tree/master/tensorflow_kmeans

数据集

采用sklearn iris.csv数据集,位于data目录下

训练方法

入口代码

tf_k_means_model(feature_column="sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)",

center_count=3, input1="./data/iris.csv", output1="./data/")

采用tf.compat.v1.estimator.experimental.KMeans api,此API是从1.X版本迁移来的,目前处于experimental阶段,用于生产环境要小心!

train方法需要接受输入函数(input function),input_fn用于将feature和target data传递给Estimator的train/evaluate/predict方法。这里,将numpy数据转换为Tensors。

def input_fn():

return tf.data.Dataset.from_tensors(tf.convert_to_tensor(points, dtype=tf.float32)).repeat(2)

model.train(input_fn)

模型导出

用的是tf.Estimator.export_saved_model方法,需要指定特征列的类型,这里用的是numeric_column

if output1:

my_feature_columns = []

for key in feature_column:

my_feature_columns.append(tf.feature_column.numeric_column(key=key))

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(

tf.feature_column.make_parse_example_spec(my_feature_columns))

estimator_path = model.export_saved_model(output1, serving_input_fn)

导入模型

使用tf.saved_model.load导入目录下的模型,这里k_means导出模型signatures没有predict,这里采用cluster_index替代预测

imported = tf.saved_model.load(model_path)

imported.signatures["cluster_index"]

多维预测

这里要注意的是,一个tf.train.Example代表一个样本数据,这里需要用个list存放多个样本数据。

# 将输入数据转换成序列化后的 Example 字符串。

examples = []

for index, row in feature_dict.iterrows():

feature = {}

for col, value in row.iteritems():

feature[col] = tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

example = tf.train.Example(

features=tf.train.Features(

feature=feature

)

)

examples.append(example.SerializeToString())

整体代码

# encoding: utf-8

"""

@author: lee

@time: 2020/6/29 10:41

@file: main.py

@desc:

"""

import pandas as pd

import tensorflow as tf

import numpy as np

from tensorflow_kmeans.util.common_util import create_df

from tensorflow_kmeans.util.fileUtil import get_last_dir

from util.common_util import arg_check_transformation

def tf_k_means_model(feature_column=None, center_count=None, input1=None, output1=None):

print("输入参数:", locals())

feature_column = arg_check_transformation("list_name_str", "feature_column", feature_column)

if center_count:

center_count = arg_check_transformation("int", "center_count", center_count)

else:

raise Exception("聚类数不能为空")

df = pd.read_csv(input1)

model = tf.compat.v1.estimator.experimental.KMeans(

num_clusters=center_count, use_mini_batch=False)

points = np.array(df[feature_column])

def input_fn():

return tf.data.Dataset.from_tensors(tf.convert_to_tensor(points, dtype=tf.float32)).repeat(2)

# train

num_iterations = 10

previous_centers = None

for _ in range(num_iterations):

model.train(input_fn)

cluster_centers = model.cluster_centers()

if previous_centers is not None:

print('delta:', cluster_centers - previous_centers)

previous_centers = cluster_centers

print('score:', model.score(input_fn))

print('cluster centers:', cluster_centers)

if output1:

my_feature_columns = []

for key in feature_column:

my_feature_columns.append(tf.feature_column.numeric_column(key=key))

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(

tf.feature_column.make_parse_example_spec(my_feature_columns))

estimator_path = model.export_saved_model(output1, serving_input_fn)

def model_predict(input_data, input_model_path, feature_column):

feature_column = arg_check_transformation("list_name_str", "feature_column", feature_column)

model_path = get_last_dir(input_model_path)

imported = tf.saved_model.load(model_path)

feature_dict = input_data[feature_column]

# 将输入数据转换成序列化后的 Example 字符串。

examples = []

for index, row in feature_dict.iterrows():

feature = {}

for col, value in row.iteritems():

feature[col] = tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

example = tf.train.Example(

features=tf.train.Features(

feature=feature

)

)

examples.append(example.SerializeToString())

re = imported.signatures["cluster_index"](

examples=tf.constant(examples))

return re["output"].numpy()

if __name__ == '__main__':

tf_k_means_model(feature_column="sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)",

center_count=3, input1="./data/iris.csv", output1="./data/")

data_frame = pd.DataFrame(np.array([[5.0, 3.3, 1.4, 0.2, 0], [7.0, 3.2, 4.7, 1.4, 1]]),

columns=['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)',

'petal width (cm)', 'target'])

predict = model_predict(

feature_column="sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)",

input_model_path="./data/", input_data=data_frame)

print(predict)

参考:https://www.tensorflow.org/api_docs/python/tf/compat/v1/estimator/experimental/KMeans

https://www.tensorflow.org/guide/saved_model#savedmodels_from_estimators

本文地址:https://blog.csdn.net/qq_33873431/article/details/107160609

如您对本文有疑问或者有任何想说的,请点击进行留言回复,万千网友为您解惑!

java 重写 tensorflow_TensorFlow2.0怎么实现K_Means相关推荐

  1. 三分钟了解“Java重写”

    要了解"Java重写",首先要知道"继承",继承是一种基于已有类(父类)创建新类(子类)的一种方式 下面的Son类继承了Father类 public class ...

  2. [转载] Java 重写paint绘图

    参考链接: 用Java重写Override 感谢原文:https://zhidao.baidu.com/question/260060153.html 这个方法需要注意的地方不多,也就是重写时,先调用 ...

  3. Java 重写(Override)与重载(Overload)

    TestDog.java /*  * 重写(Override)  * 重写是子类对父类的允许访问的方法的实现过程进行重新编写, 返回值和形参都不能改变.即外壳不变,核心重写!  * 重写的好处在于子类 ...

  4. Scala error: type mismatch; found : java.util.List[?0] required: java.util.List[B]

    Scala error: type mismatch; found : java.util.List[?0] required: java.util.List[B] 目录 Scala error: t ...

  5. uestWrapper.getSession(HttpServletRequestWrapper.java:241) ~[tomcat-embed-core-9.0.13.jar:9.0.13]

    报错信息如下: uestWrapper.getSession(HttpServletRequestWrapper.java:241) ~[tomcat-embed-core-9.0.13.jar:9. ...

  6. AS5 安装 JAVA 1.6.0 用于 TOTOplayer的启动

    系统:RedHat as-5 TOTOPLAYER (LINUX下的千千静听)启动需要JAVA-1.6的版本,系统默认安装是1.4,所以需要升级 查看本机JAVA版本命令:rpm -qa|grep j ...

  7. java 7 发布,【UC浏览器】Java平台7.0正式版发布啦

    [UC浏览器]Java平台7.0正式版发布啦 UC浏览器是UC 优视科技开发的一款手机浏览器,支持WEB.WAP页面浏览,速度快而稳定,页面排版美观:具有网站导航.搜索.下载.个人数据管理等功能,您能 ...

  8. java oauth2.0_教程:如何实现Java OAuth 2.0以使用GitHub和Google登录

    java oauth2.0 将Google和GitHub OAuth登录添加到Java应用程序的指南 我们添加到Takipi的最新功能之一是3rd party登录. 如果您像我一样懒惰,那么我想您也希 ...

  9. 教程:如何实现Java OAuth 2.0以使用GitHub和Google登录

    将Google和GitHub OAuth登录添加到Java应用程序的指南 我们添加到Takipi的最新功能之一是3rd party登录. 如果您像我一样懒惰,那么我想您也希望跳过填写表单和输入新密码的 ...

最新文章

  1. R语言层次聚类(hierarchical clustering):使用scale函数进行特征缩放、hclust包层次聚类(创建距离矩阵、聚类、绘制树状图dendrogram,在树状图上绘制红色矩形框)
  2. Asp.net控件开发学习笔记(四)---Asp.net服务端状态管理
  3. 电子科技学院计算机调剂,2020年电子科技大学电子科学技术研究院考研调剂信息...
  4. MapReduce入门2-流量监控
  5. Intel Hyperscan简介
  6. Wireshark-002导入导出
  7. 拿别人源码去申请软著_别拿自己的尺子,去丈量别人的生活!
  8. MFC ------- AfxGetMainWnd( )
  9. 语义噪声 | 语义网:重新发明轮子,创新者的窘境
  10. 斐讯k2路由器刷PandoraBox一宽带多人用
  11. 网页制作html新手代码,网页制作HTML基础标签代码大全
  12. 双线性插值GPU加速
  13. imagej得到灰度图数据_老司机带你解锁ImageJ实用技巧(下)
  14. mac jupyter notebook 服务似乎挂掉了,但是会立刻重启的
  15. java多个文件加密压缩_java中文件如何加密压缩?
  16. mysql的sql语句没错但是报错_sql语句没错·但是却报错,怎么回事?
  17. win8、server 2012 清除winsxs文件夹
  18. Android 音视频开发-FFmpeg 命令
  19. dot全称_DOT是什么
  20. 如何把旧电脑游戏数据迁移到新电脑?

热门文章

  1. 计算机信息管理招聘笔试题,计算机信息管理专业卫生事业单位招聘考试笔试模拟题(七)...
  2. 核密度聚类(二)核密度估计、自适应核密度的数学原理
  3. Salesforce搬砖之功能概况
  4. 操作系统之处理器管理的概念
  5. Jetty9 NO JSP Support for /, did not find org.apache.jasper.servlet.JspServlet
  6. 群晖安装GitServer
  7. 换热站远程监控系统方案
  8. 可折叠列表ExpandableList
  9. 工业相机 linux驱动软件,机器视觉软件及工业相机软件下载 - pylon, ToF 等 | Basler...
  10. enum c++语言_第三章 C语言关键字