学习资料:
https://www.tensorflow.org/get_started/tflearn

相应的中文翻译:
http://studyai.site/2017/03/05/%E3%80%90Tensorflow%20r1.0%20%E6%96%87%E6%A1%A3%E7%BF%BB%E8%AF%91%E3%80%91%E3%80%90tf.contrib.learn%E5%BF%AB%E9%80%9F%E5%85%A5%E9%97%A8%E3%80%91/


今天学习用 tf.contrib.learn 来建立 DNN 对 Iris 数据集进行分类.

问题:
我们有 Iris 数据集,它包含150个样本数据,分别来自三个品种,每个品种有50个样本,每个样本具有四个特征,以及它属于哪一类,分别由 0,1,2 代表三个品种。
我们将这150个样本分为两份,一份是训练集具有120个样本,另一份是测试集具有30个样本。
我们要做的就是建立一个神经网络分类模型对每个样本进行分类,识别它是哪个品种。

一共有 5 步:

  • 导入 CSV 格式的数据集
  • 建立神经网络分类模型
  • 用训练数据集训练模型
  • 评价模型的准确率
  • 对新样本数据进行分类

代码:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport os
import urllibimport numpy as np
import tensorflow as tf# Data sets
IRIS_TRAINING = "iris_training.csv"
IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"IRIS_TEST = "iris_test.csv"
IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"def main():# If the training and test sets aren't stored locally, download them.if not os.path.exists(IRIS_TRAINING):raw = urllib.urlopen(IRIS_TRAINING_URL).read()with open(IRIS_TRAINING, "w") as f:f.write(raw)if not os.path.exists(IRIS_TEST):raw = urllib.urlopen(IRIS_TEST_URL).read()with open(IRIS_TEST, "w") as f:f.write(raw)# Load datasets.training_set = tf.contrib.learn.datasets.base.load_csv_with_header(filename=IRIS_TRAINING,target_dtype=np.int,features_dtype=np.float32)test_set = tf.contrib.learn.datasets.base.load_csv_with_header(filename=IRIS_TEST,target_dtype=np.int,features_dtype=np.float32)# Specify that all features have real-value datafeature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]# Build 3 layer DNN with 10, 20, 10 units respectively.classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,hidden_units=[10, 20, 10],n_classes=3,model_dir="/tmp/iris_model")# Define the training inputsdef get_train_inputs():x = tf.constant(training_set.data)y = tf.constant(training_set.target)return x, y# Fit model.classifier.fit(input_fn=get_train_inputs, steps=2000)# Define the test inputsdef get_test_inputs():x = tf.constant(test_set.data)y = tf.constant(test_set.target)return x, y# Evaluate accuracy.accuracy_score = classifier.evaluate(input_fn=get_test_inputs,steps=1)["accuracy"]print("\nTest Accuracy: {0:f}\n".format(accuracy_score))# Classify two new flower samples.def new_samples():return np.array([[6.4, 3.2, 4.5, 1.5],[5.8, 3.1, 5.0, 1.7]], dtype=np.float32)predictions = list(classifier.predict(input_fn=new_samples))print("New Samples, Class Predictions:    {}\n".format(predictions))if __name__ == "__main__":main()

从代码可以看出很简短的几行就可以完成之前学过的很长的代码所做的事情,用起来和用 sklearn 相似。

关于 tf.contrib.learn 可以查看:
https://www.tensorflow.org/api_guides/python/contrib.learn

可以看到里面也有 kmeans,logistic,linear 等模型:


在上面的代码中:

  • tf.contrib.learn.datasets.base.load_csv_with_header 可以导入 CSV 数据集。
  • 分类器模型只需要一行代码,就可以设置这个模型具有多少隐藏层,每个隐藏层有多少神经元,以及最后分为几类。
  • 模型的训练也是只需要一行代码,输入指定的数据,包括特征和标签,再指定迭代的次数,就可以进行训练。
  • 获得准确率也同样很简单,只需要输入测试集,调用 evaluate。
  • 预测新的数据集,只需要把新的样本数据传递给 predict。

关于代码里几个新的方法:

1. load_csv_with_header():

用于导入 CSV,需要三个必需的参数:

  • filename,CSV文件的路径
  • target_dtype,数据集的目标值的numpy数据类型。
  • features_dtype,数据集的特征值的numpy数据类型。

在这里,target 是花的品种,它是一个从 0-2 的整数,所以对应的numpy数据类型是np.int

2. tf.contrib.layers.real_valued_column:

所有的特征数据都是连续的,因此用 tf.contrib.layers.real_valued_column,数据集中有四个特征(萼片宽度,萼片高度,花瓣宽度和花瓣高度),因此 dimension=4 。

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

3. DNNClassifier:

  • feature_columns=feature_columns, 上面定义的一组特征
  • hidden_units=[10, 20, 10],三个隐藏层分别包含10,20,10个神经元。
  • n_classes=3,三个目标类,代表三个 Iris 品种。
  • model_dir=/tmp/iris_model,TensorFlow在模型训练期间将保存 checkpoint data。

在后面会学到关于 TensorFlow 的 logging and monitoring 的章节,可以 track 一下训练中的模型: “Logging and Monitoring Basics with tf.contrib.learn”。


推荐阅读
历史技术博文链接汇总
也许可以找到你想要的

TensorFlow-4: tf.contrib.learn 快速入门相关推荐

  1. TensorFlow学习笔记12----Creating Estimators in tf.contrib.learn

    原文教程:tensorflow官方教程 记录关键内容与学习感受.未完待续.. Creating Estimators in tf.contrib.learn --tf.contrib.learn框架, ...

  2. TensorFlow高层次机器学习API (tf.contrib.learn)

    TensorFlow高层次机器学习API (tf.contrib.learn) 1.tf.contrib.learn.datasets.base.load_csv_with_header 加载csv格 ...

  3. TF学习——TF之API:TensorFlow的高级机器学习API—tf.contrib.learn的简介、使用方法、案例应用之详细攻略

    TF学习--TF之API:TensorFlow的高级机器学习API-tf.contrib.learn的简介.使用方法.案例应用之详细攻略 目录 tf.contrib.learn的简介 tf.contr ...

  4. Tensorflow高级API的进阶--利用tf.contrib.learn建立输入函数

    正文共5958个字,预计阅读时间15分钟. 笔记整理者:王小草 笔记整理时间:2017年2月27日 笔记对应的官方文档:https://www.tensorflow.org/get_started/i ...

  5. tf.contrib.learn.preprocessing.VocabularyProcessor

    tf.contrib.learn.preprocessing.VocabularyProcessor (max_document_length, min_frequency=0, vocabulary ...

  6. Tensorflow:tf.contrib.rnn.DropoutWrapper函数(谷歌已经为Dropout申请了专利!)、MultiRNNCell函数的解读与理解

    Tensorflow:tf.contrib.rnn.DropoutWrapper函数(谷歌已经为Dropout申请了专利!).MultiRNNCell函数的解读与理解 目录 1.tf.contrib. ...

  7. TensorFlow:AI工程师的快速入门实战利器

    经过几年的发展,深度学习已经成为人工智能领域最热门的技术.谷歌.亚马逊.百度.Facebook 纷纷开源了自己的深度学习框架. 在众多框架中,TensorFlow 凭借其强劲的运算性能.高效的超大集群 ...

  8. TensorFlow Estimator 教程之----快速入门

    TensorFlow 版本:1.10.0 > Guide > Introduction to Estimators Estimator 概述 本篇将介绍 TensorFlow 中的 Est ...

  9. 快速入门 TensorFlow2 模型部署

    机器学习问题不仅是一个科学问题,更是一个工程问题. 大多数年轻的数据科学家都希望将大部分时间花在构建完美的机器学习模型上,但是企业不仅需要训练一个完美的模型,同时也需要将其部署,向用户提供便捷的服务. ...

最新文章

  1. java拍照搜题软件下载_拍照即可秒出答案,搜题类App:是教辅“神器”还是偷懒“神器”?...
  2. 长波通信、中波通信、短波通信、超短波通信与微波通信介绍
  3. servlet文件上传blob_servlet实现从oracle数据库的blob字段中读出文件并显示 | 学步园...
  4. 胡言乱语集锦-大数据,手机,传统,养生
  5. 给初级拍摄者的十条好建议
  6. StereoPannerNode
  7. linux下反汇编命令,Linux命令学习手册-objdump命令
  8. Flutter进阶第6篇: 获取设备信息 以及 使用高德Api获取地理位置
  9. 省赛第七场(fzu1881 ~fzu1889)
  10. php 地区表设计,php消息表设计
  11. 安卓linux关机命令行,linux定时关机命令?安卓定时关机命令?windows关机命令?Linux系统下定时关机命令shutdown...
  12. 网络拓扑图js插件——jTopo应用
  13. python的设计哲学是什么意思_哲学是什么?
  14. 雾霾、压力、不良习惯对肌肤的4大危害
  15. Android APP在线自动更新安装
  16. SQL注入(基于 tryhackme 的讲解)
  17. scalac: Token not found...
  18. vsto 隐藏前面的列滚动条在后面需要滚动到最前面
  19. Python从小白到新手
  20. 三星I9100/galaxy S3图文ROOT教程

热门文章

  1. 基于Kappa-mu/M分布的联合多用户分集与并行中继继选择RF/FSO系统性能研究
  2. 计算机视觉编程 第六章 图像聚类
  3. 【概率论】随机变量函数的分布
  4. 免费制作证件照,这3个在线网站千万别错过
  5. 读书笔记 -公司改造 和 紧迫感
  6. layui 实现动态 radio 、select下拉框 jQuery赋值方法
  7. BOM物料清单及生产计划的分解
  8. 使用HTML+CSS实现轮播图
  9. Mysql中查询系统时间的方法
  10. 刷题之路:DP思想(动态规划)