tf.estimator.Estimator

简单介绍

是一个class 所以需要初始化,作用是用来 训练和评价 tensorflow 模型的
Estimator对象包装由一个名为model_fn函数指定的模型,model_fn在给定输入和许多其他参数的情况下,返回执行训练、评估或预测所需的操作。所有输出(checkpoints, event files, etc.等)都写入model_dir或其子目录。如果没有设置model_dir,则使用临时目录。

初始化

__init__(model_fn,model_dir=None,config=None,params=None,warm_start_from=None
)'''Args:model_fn: Model function. Follows the signature:Args:features:  是从 input_fn中返回的词典tensor 或者 单个tensor ;其实质就是模型的输入(以前我们都是用tf.placeholder输入的,这里使用input_fn 函数返回)  This is the first item returned from the input_fnlabels:  是从 input_fn中返回的词典tensor 或者 单个tensor,注意,如果mode=tf.estimator.ModeKeys.PREDICT(就是在预测的时候), labels将会被设置为None  This is the second item returned from the input_fnmode: Optional. Specifies if this training, evaluation or prediction. See tf.estimator.ModeKeys.params: Optional dict of hyperparameters.接受初始化Estimator实例时的参数params config: Optional estimator.RunConfig object.接受初始化Estimator实例时的参数config  或者一个默认的值. Allows setting up things in your model_fn based on configuration such as num_ps_replicas, or model_dir.Returns: tf.estimator.EstimatorSpec  这里一定要注意 返回的是EstimatorSpec实例model_dir: 输出路径,有关模型的输出的一切东西,全部输出在这里config: 这个是一个类,是官方固定的配置参数,如果用户觉得,不能满足使用,需要添加自己的参数,可以使用下面的这个参数paramsparams: dict of hyper parameters that will be passed into model_fn. Keys are names of parameters, values are basic python types.warm_start_from: Optional string filepath to a checkpoint or SavedModel to warm-start from, or a tf.estimator.WarmStartSettings object to fully configure warm-starting. If the string filepath is provided instead of a tf.estimator.WarmStartSettings, then all variables are warm-started, and it is assumed that vocabularies and tf.Tensor names are unchanged.
'''

重点圈出

The config argument can be passed tf.estimator.RunConfig object containing information about the execution environment. It is passed on to the model_fn, if the model_fn has a parameter named “config” (and input functions in the same manner). If the config parameter is not passed, it is instantiated by the Estimator. Not passing config means that defaults useful for local execution are used. Estimator makes config available to the model (for instance, to allow specialization based on the number of workers available), and also uses some of its fields to control internals, especially regarding checkpointing.

The params argument contains hyperparameters. It is passed to the model_fn, if the model_fn has a parameter named “params”, and to the input functions in the same manner. Estimator only passes params along, it does not inspect it. The structure of params is therefore entirely up to the developer.

方法

train 方法

从input_fn 获取数据,用来训练模型

train(input_fn,hooks=None,steps=None,max_steps=None,saving_listeners=None
)'''
Args:input_fn: A function that provides input data for training as minibatches. See Premade Estimators for more information. The function should construct and return one of the following: * A tf.data.Dataset object: Outputs of Dataset object must be a tuple (features, labels) with same constraints as below. * A tuple (features, labels): Where features is a tf.Tensor or a dictionary of string feature name to Tensor and labels is a Tensor or a dictionary of string label name to Tensor. Both features and labels are consumed by model_fn. They should satisfy the expectation of model_fn from inputs.hooks: List of tf.train.SessionRunHook subclass instances. Used for callbacks inside the training loop.steps: Number of steps for which to train the model. If None, train forever or train until input_fn generates the tf.errors.OutOfRange error or StopIteration exception. steps works incrementally. If you call two times train(steps=10) then training occurs in total 20 steps. If OutOfRange or StopIteration occurs in the middle, training stops before 20 steps. If you don't want to have incremental behavior please set max_steps instead. If set, max_steps must be None.max_steps: Number of total steps for which to train model. If None, train forever or train until input_fn generates the tf.errors.OutOfRange error or StopIteration exception. If set, steps must be None. If OutOfRange or StopIteration occurs in the middle, training stops before max_steps steps. Two calls to train(steps=100) means 200 training iterations. On the other hand, two calls to train(max_steps=100) means that the second call will not do any iteration since first call did all 100 steps.saving_listeners: list of CheckpointSaverListener objects. Used for callbacks that run immediately before or after checkpoint savings.
Returns:self, for chaining.'''

主要参数说明

input_fn:是一个为训练提供输入数据的函数(每次提供一个batch_size的数据),其返回的是的格式是(features,labels),正好作为mode_fn的输入,其返回的格式应该是下列之一:

  1. tf.data.Dataset object: Outputs of Dataset object must be a tuple (features, labels)
  2. A tuple (features, labels): Where features is a tf.Tensor or a dictionary of string feature name to Tensor and labels is a Tensor or a dictionary of string label name to Tensor

max_steps:最大训练多少step(也就是训练多少个batch_size),当我们暂停后,继续训练程序会检测目前已经训练的步数是否大于max_steps若大于等于,那么就不会继续训练(On the other hand, two calls to train(max_steps=100) means that the second call will not do any iteration since first call did all 100 steps.

step:会在原来的基础上,继续“增长式”训练,例如你调用了两次train(input_fn,step=10),那么模型就相当于训练了20个迭代

evaluate 方法

Evaluates the model given evaluation data input_fn.
For each step, calls input_fn, which returns one batch of data. Evaluates until: - steps batches are processed, or - input_fn raises an end-of-input exception获取input_fn返回的数据并输入到模型中,用来评价模型每一步都调用一次input_fn,其返回one batch of data,知道等于steps 或者input_fn raises an end-of-input exception

evaluate(input_fn,steps=None,hooks=None,checkpoint_path=None,name=None
)'''
Args:input_fn: A function that constructs the input data for evaluation. See Premade Estimators for more information. The function should construct and return one of the following: * A tf.data.Dataset object: Outputs of Dataset object must be a tuple (features, labels) with same constraints as below. * A tuple (features, labels): Where features is a tf.Tensor or a dictionary of string feature name to Tensor and labels is a Tensor or a dictionary of string label name to Tensor. Both features and labels are consumed by model_fn. They should satisfy the expectation of model_fn from inputs.steps: Number of steps for which to evaluate model. If None, evaluates until input_fn raises an end-of-input exception.hooks: List of tf.train.SessionRunHook subclass instances. Used for callbacks inside the evaluation call.checkpoint_path: Path of a specific checkpoint to evaluate. If None, the latest checkpoint in model_dir is used. If there are no checkpoints in model_dir, evaluation is run with newly initialized Variables instead of ones restored from checkpoint.name: Name of the evaluation if user needs to run multiple evaluations on different data sets, such as on training data vs test data. Metrics for different evaluations are saved in separate folders, and appear separately in tensorboard.Returns:A dict containing the evaluation metrics specified in model_fn keyed by name, as well as an entry global_step which contains the value of the global step for which this evaluation was performed. For canned estimators, the dict contains the loss (mean loss per mini-batch) and the average_loss (mean loss per sample). Canned classifiers also return the accuracy. Canned regressors also return the label/mean and the prediction/mean.
'''

参数说明

具体的参数和train方法类似,就不说了,这里主要说一下 这个方法的返回(return)
返回的是一个词典,是在mode_fn中提前指定好的,同时还会返回执行了多少step
例如在model_fn函数中一般有如下类似定义:

    estim_specs=tf.estimator.EstimatorSpec(mode=mode,predictions=pred_classes,loss=loss_op,train_op=train_op,eval_metric_ops={"accuracy":acc_op})

中的 eval_metric_ops={“accuracy”:acc_op}),最后会输出类似这种

{'accuracy': 0.9192, 'loss': 0.28470048, 'global_step': 1000}

predict方法

predict(input_fn,predict_keys=None,hooks=None,checkpoint_path=None,yield_single_examples=True
)'''
Args:input_fn: A function that constructs the features. Prediction continues until input_fn raises an end-of-input exception (tf.errors.OutOfRangeError or StopIteration). See Premade Estimators for more information. The function should construct and return one of the following:A tf.data.Dataset object: Outputs of Dataset object must have same constraints as below.features: A tf.Tensor or a dictionary of string feature name to Tensor. features are consumed by model_fn. They should satisfy the expectation of model_fn from inputs.A tuple, in which case the first item is extracted as features.predict_keys: list of str, name of the keys to predict. It is used if the tf.estimator.EstimatorSpec.predictions is a dict. If predict_keys is used then rest of the predictions will be filtered from the dictionary. If None, returns all.hooks: List of tf.train.SessionRunHook subclass instances. Used for callbacks inside the prediction call.checkpoint_path: Path of a specific checkpoint to predict. If None, the latest checkpoint in model_dir is used. If there are no checkpoints in model_dir, prediction is run with newly initialized Variables instead of ones restored from checkpoint.yield_single_examples: If False, yields the whole batch as returned by the model_fn instead of decomposing the batch into individual elements. This is useful if model_fn returns some tensors whose first dimension is not equal to the batch size.
'''

说明

给定输入,返回在model_fn中指定要输出的内容tf.estimator.EstimatorSpec(mode,predictions=pred_classes)

    ........pred_classes=tf.argmax(logits,axis=1)pred_probas=tf.nn.softmax(logits)#PREDICTSif mode==tf.estimator.ModeKeys.PREDICT:return tf.estimator.EstimatorSpec(mode,predictions=pred_classes)...........

具体参数和trian 方法的参数基本相同,就不多说,这里重点讲一下下面几个:
predict_keys: 是一个str类型的list,如果使用这个predict_keys,那么模型只会返回predictions 中和predict_keys相同的key的值
**checkpoint_path:**要预测的特定检查点的路径。如果没有,则使用model_dir中的最新检查点。如果在model_dir中没有检查点,则使用新初始化的变量而不是从检查点恢复的变量运行预测
yield_single_examples: 如果为False,则生成model_fn返回的整个批,而不是将批分解为单个元素。如果model_fn返回其第一维不等于批处理大小的一些张量,则这很有用。

tf.estimator.Estimator讲解相关推荐

  1. tf.estimator.Estimator解析

    Estimator类代表了一个模型,以及如何对这个模型进行训练和评估, class Estimator(builtins.object) 可以按照下面方式创建一个E def resnet_v1_10_ ...

  2. tf.estimator.Estimator的使用

    tf.estimator.Estimator是TF比较高级的接口. 最近在使用bert预训练模型的时候用到了tf.estimator.Estimator.使用该接口的时候需要开发者完成的工作比较少,一 ...

  3. [tensorflow]tf.estimator.Estimator构建tensorflow模型

    目录 一.Estimator简介 二.数据集 三.定义特征列 四.estimator创建模型 五.模型训练.评估和预测 六.模型保存和恢复 一.Estimator简介 Estimator是Tensor ...

  4. Tensorflow API 讲解——tf.estimator.Estimator

    class Estimator(builtins.object) #介绍 Estimator 类,用来训练和验证 TensorFlow 模型. Estimator 对象包含了一个模型 model_fn ...

  5. tf.estimator.EstimatorSpec讲解

    作用 是一个class(类),是定义在model_fn中的,并且model_fn返回的也是它的一个实例,这个实例是用来初始化Estimator类的 (Ops and objects returned ...

  6. 阿里云机器学习平台PAI的视频介绍(其中tensorflow高级教程有tf的代码优化讲解)

    https://tianchi.aliyun.com/competition/new_articleDetail.html?spm=5176.9876270.0.0.65d0a126iwqolt&am ...

  7. tf.estimator的用法

    tf.estimator的用法 利用 tf.estimator 训练模型时需要写两个重要的函数,一个用于数据输入的函数(input_fn),另一个用于模型创建的函数(model_fn).下面逐一来说明 ...

  8. 机器学习笔记5-Tensorflow高级API之tf.estimator

    前言 本文接着上一篇继续来聊Tensorflow的接口,上一篇中用较低层的接口实现了线性模型,本篇中将用更高级的API--tf.estimator来改写线性模型. 还记得之前的文章<机器学习笔记 ...

  9. tf.estimator用法

    estimator:估算器 tf.estimator -----一种高级TensorFlow API.估算器封装以下操作: 训练(training) 评价(evaluation) 预测(predict ...

最新文章

  1. python图片识别-Python+Opencv识别两张相似图片
  2. Jupyter中打印所有结果的解决办法
  3. 查看防火墙状态并关闭防火墙
  4. 打造最强加密工具之《绝密信息传递》
  5. 天梯赛-是否完全二叉搜索树
  6. 容器编排技术 -- Kubernetes 应用连接到 Service
  7. 百度微软云服务器地址,win10的ie浏览器默认地址被百度劫持
  8. Memcache入门知识
  9. 【渝粤教育】国家开放大学2018年春季 0092-22T民法 参考试题
  10. 服务器为什么经常掉线?
  11. 将输入的单词按首字母排序
  12. uniapp按照官方《针对plus.runtime.install在安卓9.0+上无法执行的解决方案》处理后报错:安装包解析错误
  13. Elasticsearch——Settings设置
  14. 中文分词之HMM详解
  15. Mysql 1022
  16. FFmpeg:基础命令
  17. (spring-第4回【IoC基础篇】)spring基于注解的配置
  18. 05-SSM版文件上传和下载
  19. 计算机专业项目指导教师评语,指导教师评语
  20. 【ACWing】665. 倍数

热门文章

  1. 未来五年内智能机器人是影响金融行业发展的技术趋势
  2. 岁月凶残,敬请珍惜——得知早已不能过五四节时之随想
  3. 客户自助服务第一步:在线客服、在线帮助中心
  4. MyBatis Generator使用方法(详细)
  5. IDEA中mybatis-generator插件的使用
  6. 有关GMT写入中文——pstext命令的相关事宜
  7. 股票和期货的区别(股指期货1个点赚多少钱)
  8. python绘制网络拓扑_python 画网络拓扑图
  9. 2021-05-27 WMS系统中的二维码技术应用
  10. iOS开发- ios学习资源