estimator是tensorflow高度封装的一个类,里面有一些可以直接使用的分类和回归模型,例如tf.estimator.DNNClassifier,但这不是这篇博客的主题,而是怎么使用estimator来实现我们自定义模型的训练。它的步骤主要分为以下几个部分:

  1. 构建model_fn,在这个方法里面定义自己的模型以及训练和测试过程要做的事情;
  2. 构建input_fn,在这个方法数据的来源和喂给模型的方式;
  3. 最后,创建estimator对象,然后开始训练模型了。可以添加一些config,比如:loss的输出频率等。

构建model_fn方法

import tensorflow as tfdef model_fn(features, labels, mode, params):  # 必须要有前面三个参数# feature和labels其实就是`input_fn`方法传输过来的# mode是用来判断你现在是训练或测试阶段# params是在创建`estimator`对象的输入参数lr = params['lr']try:init_checkpoint = params['init_checkpoint']except KeyError:init_checkpoint = Nonex = features['inputs']y = features['labels']#####################在这里定义你自己的网络模型###################pre = tf.layers.dense(x, 1)loss = tf.reduce_mean(tf.pow(pre-y, 2), name='loss')######################在这里定义你自己的网络模型#################### 这里可以加载你的预训练模型assignment_map = dict()if init_checkpoint:for var in tf.train.list_variables(init_checkpoint):  # 存放checkpoint的变量名称和shapeassignment_map[var[0]] = var[0]tf.train.init_from_checkpoint(init_checkpoint, assignment_map)# 定义你训练过程要做的事情if mode == tf.estimator.ModeKeys.TRAIN:optimizer = tf.train.AdamOptimizer(lr)train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())output_spec = tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)# 定义你测试(验证)过程elif mode == tf.estimator.ModeKeys.EVAL:metrics = {'eval_loss': tf.metrics.mean_tensor(loss), "accuracy": tf.metrics.accuracy(labels, pre)}output_spec = tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)# 定义你的预测过程elif mode == tf.estimator.ModeKeys.PREDICT:predictions = {'predictions': pre}output_spec = tf.estimator.EstimatorSpec(mode, predictions=predictions)else:raise TypeErrorreturn output_spec

提几点需要注意的地方:

  1. model_fn方法返回的是tf.estimator.EstimatorSpec;
  2. TRAIN、EVAL和PREDICT模式不可缺少的参数是不一样的。

构建input_fn方法

def input_fn_bulider(inputs_file, batch_size, is_training):name_to_features = {'inputs': tf.FixedLenFeature([3], tf.float32),'labels': tf.FixedLenFeature([], tf.float32)}def input_fn(params): d = tf.data.TFRecordDataset(inputs_file)if is_training:d = d.repeat()d = d.shuffle()# map_and_batch其实就是将map和batch结合起来而已d = d.apply(tf.contrib.data.map_and_batch(lambda x: tf.parse_single_example(x, name_to_features), batch_size=batch_size))return dreturn input_fn

执行eatimator

if __name__ == '__main':# 定义日志消息的输出级别,为了获取模型的反馈信息,选择INFOtf.logging.set_verbosity(tf.logging.INFO)# 我在这里是指定模型的保存和loss输出频率runConfig = tf.estimator.RunConfig(save_checkpoints_steps=1,log_step_count_steps=1)estimator = tf.estimator.Estimator(model_fn, model_dir='your_save_path',config=runConfig, params={'lr': 0.01})# log_step_count_steps控制的只是loss的global_step的输出# 我们还可以通过tf.train.LoggingTensorHook自定义更多的输出# tensor是我们要输出的内容,输入一个字典,key为打印出来的名称,value为你要输出的tensor的namelogging_hook = tf.train.LoggingTensorHook(every_n_iter=1,tensors={'loss': 'loss'})# 其实给到estimator.train是一个dataset对象input_fn = input_fn_bulider('test.tfrecord', batch_size=1, is_training=True)estimator.train(input_fn, max_steps=1000)# 下面你还可以对模型进行验证和测试,做法是差不多的,我就不列举了

欢迎关注同名公众号:“我就算饿死也不做程序员”。
交个朋友,一起交流,一起学习,一起进步。

tensorflow estimator详细介绍,实现模型的高效训练相关推荐

  1. 详细介绍BERT模型

    文章目录 BERT简介 BERT, OpenAI GPT, 和ELMo之间的区别 相关工作 BERT的改进 BERT 的详细实现 输入/输出表示 预训练BERT 微调BERT BERT用在下游任务 G ...

  2. 机器学习算法——线性回归的详细介绍 及 利用sklearn包实现线性回归模型

    目录 1.线性回归简介 1.1 线性回归应用场景 1.2 什么是线性回归 1.2.1 定义与公式 1.2.2 线性回归的特征与目标的关系分析 2.线性回归api初步使用 2.1 线性回归API 2.2 ...

  3. TensorFlow Estimator 模型从训练到部署

    引言 TensorFlow是目前流行的机器学习框架,用户可以基于TensorFlow方便地构建机器学习模型,并将模型部署到线上提供服务. 最近看Estimator框架比较流行,公司也想看Wide &a ...

  4. Attension Mechanism模型的详细介绍,原理、分类及应用

    模型汇总24 - 深度学习中Attention Mechanism详细介绍:原理.分类及应用 Attention是一种用于提升基于RNN(LSTM或GRU)的Encoder + Decoder模型的效 ...

  5. 软件过程各类模型详细介绍(重要)

    软件过程各类模型详细介绍(重要) 瀑布模型 瀑布模型的优点 瀑布模型的缺点 瀑布模型适合的项目类型 V过程模型 V过程模型的特点 V过程模型适合的项目 原型模型 原型模型的特点 原型模型的优点 增量模 ...

  6. 华为战略规划落地的核心:VDBD(价值驱动业务设计模型) 盈利模式(利润模型)详细介绍

    华为战略规划落地的核心:VDBD(价值驱动业务设计模型) & 盈利模式(利润模型)详细介绍 本文作者 | 谢宁,<华为战略管理法:DSTE实战体系>.<智慧研发管理>作 ...

  7. 利用谷歌的联邦学习框架Tensorflow Federated实现FedAvg(详细介绍)

    目录 I. 前言 II. 数据介绍 III. 联邦学习 1. 整体框架 2. 服务器端 3. 客户端 IV. Tensorflow Federated 1. 数据处理 2. 构造TFF的Keras模型 ...

  8. BERT模型的详细介绍

    1.BERT 的基本原理是什么? BERT 来自 Google 的论文Pre-training of Deep Bidirectional Transformers for Language Unde ...

  9. [深度学习] 分布式Tensorflow 2.0 介绍(二)

    [深度学习] 分布式模式介绍(一) [深度学习] 分布式Tensorflow 2.0介绍(二) [深度学习] 分布式Pytorch 1.0介绍(三) [深度学习] 分布式Horovod介绍(四) 一 ...

最新文章

  1. 用乐观的心态去面对生活,能让你的生活过得更加快乐
  2. 写给新手:2021版调参上分手册!
  3. 如何解决VHDL中参数化赋值:赋全0、全1、全z
  4. nginx php post限制,叫你如何修改Nginx与PHP的文件上传大小限制
  5. CSS五种水平居中:text-align margin incline-block flex relative
  6. Weex学习资料整合
  7. java 1.6 ubuntu_ubuntu配置 Java SE 1.6
  8. 微信模版消息 touser 能否多个 群发
  9. centos7 python3安装numpy_centos下pip3安装numpy
  10. 计算机硬盘 u盘和光盘属于,磁盘U盘光盘的区别
  11. 封装和使用Docker流程
  12. 人工智能第一讲:人工智能概论
  13. WIN10 LTSC 转 WIN10 专业版(纯净)
  14. 笔记本计算机内部部件,笔记本内部硬件构造有哪些
  15. php多层if函数,if函数嵌套计算公式用法
  16. 招聘中的热门技术技能:SQL、Java、Python 和 Linux
  17. 基于Qt的智能管家客户端设计
  18. Ubuntu18.04安装分析
  19. 解决idea控制台中文乱码问题
  20. Windows环境下用cloc统计代码量

热门文章

  1. NTC热敏电阻,错误检测显示
  2. 若变量已正确定义并赋值,下面符合C语言语法的表达式是
  3. 实力凸显 | 思迈特软件入选“2022中国软件150强“等三大重磅榜单
  4. 编程界“网络工程师”都用过的Python学习教程+PDF电子版曝光了
  5. No Shedule lines due for delivery up to the selected date
  6. 工业和信息化部关于发布5150-5350兆赫兹频段无线接入系统频率使用相关事宜的通知工信部无函〔2012〕620号
  7. 《Servlet、JSP和Spring MVC初学指南》——第2章 会话管理 2.1URL重写
  8. 本田雅阁 2022款混动 2.0L E-CVT 锐酷版这款车怎么样?
  9. 织梦html引入html代码,织梦标签引入共用html.doc
  10. Android实现系统相册选择APP全局背景图片