引言

TensorFlow是目前流行的机器学习框架,用户可以基于TensorFlow方便地构建机器学习模型,并将模型部署到线上提供服务。

最近看Estimator框架比较流行,公司也想看Wide & Deep模型的效果,所以特来踩坑。

本文数据使用adult income二分类数据集,训练+预测数据总条数48000左右。

本文代码大量来源于Github上Wide&Deep模型的官方实现,由于其对代码进行了大量的封装导致易读性比较差,本人在其基础之上进行了一些简化。

一、开篇

先介绍一下Estimator——一种简化机器学习编程的高阶 TensorFlow API。

其封装了以下操作:

  • 训练

  • 评估

  • 预测

  • 导出

再介绍其优势:

  • 由单机向分布式过渡时代码变动少

  • 代码简单直观

  • 有预创建模型可以直接使用

  • 可以配合feature_column进行特征工程,简化线上操作,也不用顾忌线上线下模型不一致的问题。使用feature_column可以直接接受原始特征,虽然可以带来性能问题,但对于快速试验模型来说是非常友好的。从这里了解更多feature_column的信息。

  • 模型保存,导出,部署相对简洁

  • 与tensorboard配合良好

  • TensorFlow团队推荐使用,也是其开发重点

具体内容可以访问TensorFlow官方指南Introduction to Estimators。

另外推荐阅读美团的Estimator工程实践基于TensorFlow Serving的深度学习在线预估。

二、Estimator构建模型

使用Estimator通常包括四个步骤:

  1. 定义特征列,作为estimator评估器的输入。

  2. 创建数据导入函数input_fn,在训练、预测、评估函数中作为参数。

  3. 实例化Estimator。

  4. 模型训练、预测、评估以及导出。

具体代码如下:

# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
import tensorflow as tftf.logging.set_verbosity(tf.logging.INFO)ROOT_PATH = '/Users/Zhao/Data/adult/'
TRAIN_PATH = ROOT_PATH + 'train.csv'
EVAL_PATH = ROOT_PATH + 'test.csv'
PREDICT_PATH = ROOT_PATH + 'predict.csv'
MODEL_PATH = '/tmp/adult_model'
EXPORT_PATH = '/tmp/adult_export_model'
_CSV_COLUMNS = ['age', 'workclass', 'fnlwgt', 'education', 'education_num','marital_status', 'occupation', 'relationship', 'race', 'gender','capital_gain', 'capital_loss', 'hours_per_week', 'native_country','income_bracket'
]_CSV_COLUMN_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''], [''], [''],[0], [0], [0], [''], [0]]_HASH_BUCKET_SIZE = 1000_NUM_EXAMPLES = {'train': 32561,'validation': 16281,
}def build_model_columns():"""Builds a set of wide and deep feature columns."""# Continuous variable columnsage = tf.feature_column.numeric_column('age')education_num = tf.feature_column.numeric_column('education_num')capital_gain = tf.feature_column.numeric_column('capital_gain')capital_loss = tf.feature_column.numeric_column('capital_loss')hours_per_week = tf.feature_column.numeric_column('hours_per_week')education = tf.feature_column.categorical_column_with_vocabulary_list('education', ['Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college','Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school','5th-6th', '10th', '1st-4th', 'Preschool', '12th'])marital_status = tf.feature_column.categorical_column_with_vocabulary_list('marital_status', ['Married-civ-spouse', 'Divorced', 'Married-spouse-absent','Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])relationship = tf.feature_column.categorical_column_with_vocabulary_list('relationship', ['Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried','Other-relative'])workclass = tf.feature_column.categorical_column_with_vocabulary_list('workclass', ['Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov','Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'])# To show an example of hashing:occupation = tf.feature_column.categorical_column_with_hash_bucket('occupation', hash_bucket_size=_HASH_BUCKET_SIZE)# Transformations.age_buckets = tf.feature_column.bucketized_column(age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])# Wide columns and deep columns.base_columns = [education, marital_status, relationship, workclass, occupation,age_buckets,]crossed_columns = [tf.feature_column.crossed_column(['education', 'occupation'], hash_bucket_size=_HASH_BUCKET_SIZE),tf.feature_column.crossed_column([age_buckets, 'education', 'occupation'],hash_bucket_size=_HASH_BUCKET_SIZE),]wide_columns = base_columns + crossed_columnsdeep_columns = [age,education_num,capital_gain,capital_loss,hours_per_week,tf.feature_column.indicator_column(workclass),tf.feature_column.indicator_column(education),tf.feature_column.indicator_column(marital_status),tf.feature_column.indicator_column(relationship),# To show an example of embeddingtf.feature_column.embedding_column(occupation, dimension=8),]return wide_columns, deep_columnsdef input_fn(data_path, shuffle, num_epochs, batch_size):"""Generate an input function for the Estimator."""def parse_csv(value):columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)features = dict(zip(_CSV_COLUMNS, columns))labels = features.pop('income_bracket')# classes = tf.equal(labels, '>50K')  # binary classificationreturn features, labels# Extract lines from input files using the Dataset API.dataset = tf.data.TextLineDataset(data_path)if shuffle:dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])dataset = dataset.map(parse_csv, num_parallel_calls=5)# We call repeat after shuffling, rather than before, to prevent separate# epochs from blending together.dataset = dataset.repeat(num_epochs)dataset = dataset.batch(batch_size)return dataset# estimator.train()可以循环运行,模型的状态将持久保存在model_dir
def run():wide_columns, deep_columns = build_model_columns()# os.system('rm -rf {}'.format(MODEL_PATH))config = tf.estimator.RunConfig(save_checkpoints_steps=100)estimator = tf.estimator.DNNLinearCombinedClassifier(model_dir=MODEL_PATH,linear_feature_columns=wide_columns,linear_optimizer=tf.train.FtrlOptimizer(learning_rate=0.01),dnn_feature_columns=deep_columns,dnn_hidden_units=[256, 64, 32, 16],dnn_optimizer=tf.train.AdamOptimizer(learning_rate=0.001),config=config)# Linear model.# estimator = tf.estimator.LinearClassifier(feature_columns=wide_columns, n_classes=2,#                                           optimizer=tf.train.FtrlOptimizer(learning_rate=0.03))# Train the model.estimator.train(input_fn=lambda: input_fn(data_path=TRAIN_PATH, shuffle=True, num_epochs=40, batch_size=100), steps=2000)"""steps: 最大训练次数,模型训练次数由训练样本数量、num_epochs、batch_size共同决定,通过steps可以提前停止训练"""# Evaluate the model.eval_result = estimator.evaluate(input_fn=lambda: input_fn(data_path=EVAL_PATH, shuffle=False, num_epochs=1, batch_size=40))print('Test set accuracy:', eval_result)# Predict.pred_dict = estimator.predict(input_fn=lambda: input_fn(data_path=PREDICT_PATH, shuffle=False, num_epochs=1, batch_size=40))for pred_res in pred_dict:print(pred_res['probabilities'][1])columns = wide_columns + deep_columnsfeature_spec = tf.feature_column.make_parse_example_spec(feature_columns=columns)serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)estimator.export_savedmodel(EXPORT_PATH, serving_input_fn)if __name__ == '__main__':run()

输出信息:

# 评估
Test set accuracy: {'accuracy': 0.8199128, 'accuracy_baseline': 0.76377374, 'auc': 0.8519885,'auc_precision_recall': 0.6738672, 'average_loss': 0.39400044, 'label/mean': 0.23622628, 'loss': 15.722356,'precision': 0.66900885, 'prediction/mean': 0.26024768,'recall': 0.47035882, 'global_step': 3257}# 预测
0.34541896
0.6025886
0.42424703
0.6234316
0.23452707

可以在通过TensorBoard查看训练过程:

tensorboard --logdir=/tmp/adult_model

查看导出模型结构:

/tmp/adult_export_model
└── 1562756581├── saved_model.pb└── variables├── variables.data-00000-of-00002├── variables.data-00001-of-00002└── variables.index

使用CLI检查保存的模型:

/tmp/adult_export_model/1562756581  saved_model_cli show --dir ./ --allMetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:signature_def['classification']:The given SavedModel SignatureDef contains the following input(s):inputs['inputs'] tensor_info:dtype: DT_STRINGshape: (-1)name: input_example_tensor:0The given SavedModel SignatureDef contains the following output(s):outputs['classes'] tensor_info:dtype: DT_STRINGshape: (-1, 2)name: head/Tile:0outputs['scores'] tensor_info:dtype: DT_FLOATshape: (-1, 2)name: head/predictions/probabilities:0Method name is: tensorflow/serving/classify... ...

使用CLI执行保存的模型:

/tmp/adult_export_model/1562756581  saved_model_cli run --dir /tmp/adult_export_model/1562756581
--tag_set serve --signature_def="predict" --input_examples='examples=[{"age":[46.], "education_num":[10.],
"capital_gain":[7688.], "capital_loss":[0.], "hours_per_week":[38.]}, {"age":[24.], "education_num":[13.],
"capital_gain":[0.], "capital_loss":[0.], "hours_per_week":[50.]}]'2019-07-10 20:03:35.853582: I tensorflow/core/platform/cpu_feature_guard.cc:141]
Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
Result for output key class_ids:
[[1][0]]
Result for output key classes:
[[b'1'][b'0']]
Result for output key logistic:
[[0.6388541 ][0.23918226]]
Result for output key logits:
[[ 0.57039404][-1.157168  ]]
Result for output key probabilities:
[[0.3611459  0.6388541 ][0.76081777 0.23918225]]

模型保存和导出的更多信息可以看https://www.tensorflow.org/guide/saved_model。

三、模型部署

模型部署使用docker+TensorFlow Serving :

docker pull tensorflow/servingdocker run -p 8500:8500 -p 8501:8501 --name adult_export_model --mount type=bind,\
source=/tmp/adult_export_model,target=/models/adult_export_model \
-e MODEL_NAME=adult_export_model -t tensorflow/serving

之后可以通过grpc的方式(port=8500)或者REST API的方式(port=8501)请求服务

使用docker进行部署模型的详情可以看使用docker来运行tensorflow-serving。

这里需要注意的是请求服务时输入要转换为tf.Example类型,从https://www.tensorflow.org/tutorials/load_data/tf_records了解更多关于tf.Example的信息。

下面是使用grpc方式请求模型服务的代码:

from __future__ import print_function
import grpc
import requests
import tensorflow as tffrom tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpcpath = '/Users/Zhao/Desktop/tmp/adult_tf.csv'# 生成tf.Example 数据
def _bytes_feature(value):"""Returns a bytes_list from a string / byte."""return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))def _float_feature(value):"""Returns a float_list from a float / double."""return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))def _int64_feature(value):"""Returns an int64_list from a bool / enum / int / uint."""return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))feature_dict = {}
serialized_strings = []
with open(path, encoding='utf-8') as f:lines = f.readlines()names = [key for key in lines[0].strip('\n').split(',')]types = [key for key in lines[1].strip('\n').split(',')]for i in range(2, len(lines)):items = [key for key in lines[i].strip('\n').split(',')]for j in range(len(items)):item = items[j]if types[j] == 'int':item = int(item)feature_dict[names[j]] = _float_feature(item)elif types[j] == 'string':feature_dict[names[j]] = _bytes_feature(bytes(item, encoding='utf-8'))example_proto = tf.train.Example(features=tf.train.Features(feature=feature_dict))serialized = example_proto.SerializeToString()serialized_strings.append(serialized)# print(names)# print(types)# print(serialized_strings[0])## example_proto = tf.train.Example.FromString(serialized_strings[0])# print(example_proto)channel = grpc.insecure_channel(target='localhost:8500')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)request = predict_pb2.PredictRequest()
request.model_spec.name = 'adult_export_model'
request.model_spec.signature_name = 'predict'data = serialized_strings
size = len(data)
request.inputs['examples'].CopyFrom(tf.contrib.util.make_tensor_proto(data, shape=[size]))result = stub.Predict(request, 10.0)  # 10 secs timeout
print(result)

对比直接预测和通过请求服务预测两种方式的结果:

0.34541896
0.6025886
0.42424703
0.6234316
0.23452707float_val: 0.34541893005371094
float_val: 0.6025885939598083
float_val: 0.42424705624580383
float_val: 0.6234316229820251
float_val: 0.2345270812511444

结果一致。

结语

本文对TensorFlow的Estimator API进行了小小的尝试,试验了wide&Deep模型的效果,由于使用预创建的estimator,自己也没有写多少代码,emmmm,有时间还是写一下自创建的estimator。Estimator和feature_column结合来用确实很丝滑,feature_column的ID特征hash分桶、特征embedding和交叉特征看起来很唬人的亚子,就是不知道实不实用,性能先不说,效果好也行。

TensorFlow Estimator 模型从训练到部署相关推荐

  1. BERT模型从训练到部署全流程

    BERT模型从训练到部署全流程 Tag: BERT 训练 部署 缘起 在群里看到许多朋友在使用BERT模型,网上多数文章只提到了模型的训练方法,后面的生产部署及调用并没有说明. 这段时间使用BERT模 ...

  2. BERT模型从训练到部署

    BERT模型从训练到部署全流程 Tag: BERT 训练 部署 缘起 在群里看到许多朋友在使用BERT模型,网上多数文章只提到了模型的训练方法,后面的生产部署及调用并没有说明. 这段时间使用BERT模 ...

  3. Nebula 在 Akulaku 智能风控的实践:图模型的训练与部署

    本文整理自 Akulaku 反欺诈团队在 nMeetup·深圳场的演讲,B站视频见:https://www.bilibili.com/video/BV1nQ4y1B7Qd 这次主要来介绍下 Nebul ...

  4. 目标检测模型从训练到部署!

    Datawhale干货 作者:张强,Datawhale成员 训练目标检测模型并部署到你的嵌入式设备,让边缘设备长"眼睛". 目标检测的任务是找出图像中所有感兴趣的目标(物体),确定 ...

  5. 目标检测模型从训练到部署,其实如此简单

    目标检测的任务是找出图像中所有感兴趣的目标(物体),确定它们的类别和位置,是计算机视觉领域的核心问题之一.目标检测已应用到诸多领域,比如如安防.无人销售.自动驾驶和军事等. 在许多情况下,运行目标检测 ...

  6. cloud 部署_使用Google Cloud AI平台开发,训练和部署TensorFlow模型

    cloud 部署 实用指南 (A Practical Guide) The TensorFlow ecosystem has become very popular for developing ap ...

  7. tensorflow estimator详细介绍,实现模型的高效训练

    estimator是tensorflow高度封装的一个类,里面有一些可以直接使用的分类和回归模型,例如tf.estimator.DNNClassifier,但这不是这篇博客的主题,而是怎么使用esti ...

  8. bert中文分类模型训练+推理+部署

    文章预览: 0. bert简介 1. bert结构 1. bert中文分类模型训练 1 下载bert项目代码 代码结构 2 下载中文预训练模型 3 制作中文训练数据集 2. bert模型推理 1.te ...

  9. 把一个dataset的表放在另一个dataset里面_现在开始:用你的Mac训练和部署一个图像分类模型...

    可能有些同学学习机器学习的时候比较迷茫,不知道该怎么上手,看了很多经典书籍介绍的各种算法,但还是不知道怎么用它来解决问题,就算知道了,又发现需要准备环境.准备训练和部署的机器,啊,好麻烦. 今天,我来 ...

最新文章

  1. 对于“网站快照”的认识你停留在哪个阶段?
  2. 聊聊大麦网UWP版的首页顶部图片联动效果的实现方法
  3. AODV中实施watchdog
  4. 帝国cms模板仿礼品销售网站
  5. 贪心法——最优装载问题
  6. nodejs 读取excel文件,并去重
  7. catia钣金根据线段折弯_SolidWorks钣金折弯边角余料处理技巧,钣金工艺设计师都在用...
  8. [原创]互联网网站测试经验
  9. dell服务器修改sata,Dell poweredge r210进BIOS改动磁盘控制器(SATA Controller)接口模式...
  10. gtk_init参数传递过程(草稿)
  11. 片假名翻译软件测试,怎么写软件测试用例
  12. 动画(FLASH)下载任我行-----保存自己喜欢的动画的方法总结
  13. List集合排序总结
  14. 文件不小心删除了怎么恢复呢,怎么恢复误删除的文件
  15. L13-页眉页脚设计加水印
  16. WinMerge的使用(代码相同却提示有差异)。
  17. 【有利可图网】PS保存图片提示“无法完成请求”,教你4种解决方法!
  18. 天耀18期 - 11.封装类及常用类【作业】
  19. linux的原理和运用,Linux操作系统原理与应用_内存寻址
  20. 树莓派空气质量检测仪-攀藤G5003ST的连接与使用

热门文章

  1. MIME type (‘text/html‘) is not a supported stylesheet MIME type, and strict MIME checking is enabled
  2. pl/sql 美化规则
  3. 金属粉末应用领域覆盖广,市场前景可观
  4. 源码安装的Apache启动时报错:“Could not reliably determine the server‘s fully qualified domain name”
  5. 手机触摸 事件, 当触摸屏幕时候触发
  6. 上海集成电路设计产业园开园,硬核科技,尽显实力!
  7. Spring mvc 未登录 拦截跳转登陆页面
  8. 相机标定详解(转载不易)
  9. 在IP地址后面加个 /8(/16/24/32)代表什么意思
  10. php中strrpos函数的返回值类型是型_strrpos()和strripos()函数【PHP】