tensorflow estimator详细介绍,实现模型的高效训练
estimator
是tensorflow高度封装的一个类,里面有一些可以直接使用的分类和回归模型,例如tf.estimator.DNNClassifier
,但这不是这篇博客的主题,而是怎么使用estimator
来实现我们自定义模型的训练。它的步骤主要分为以下几个部分:
- 构建
model_fn
,在这个方法里面定义自己的模型以及训练和测试过程要做的事情; - 构建
input_fn
,在这个方法数据的来源和喂给模型的方式; - 最后,创建
estimator
对象,然后开始训练模型了。可以添加一些config,比如:loss的输出频率等。
构建model_fn方法
import tensorflow as tfdef model_fn(features, labels, mode, params): # 必须要有前面三个参数# feature和labels其实就是`input_fn`方法传输过来的# mode是用来判断你现在是训练或测试阶段# params是在创建`estimator`对象的输入参数lr = params['lr']try:init_checkpoint = params['init_checkpoint']except KeyError:init_checkpoint = Nonex = features['inputs']y = features['labels']#####################在这里定义你自己的网络模型###################pre = tf.layers.dense(x, 1)loss = tf.reduce_mean(tf.pow(pre-y, 2), name='loss')######################在这里定义你自己的网络模型#################### 这里可以加载你的预训练模型assignment_map = dict()if init_checkpoint:for var in tf.train.list_variables(init_checkpoint): # 存放checkpoint的变量名称和shapeassignment_map[var[0]] = var[0]tf.train.init_from_checkpoint(init_checkpoint, assignment_map)# 定义你训练过程要做的事情if mode == tf.estimator.ModeKeys.TRAIN:optimizer = tf.train.AdamOptimizer(lr)train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())output_spec = tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)# 定义你测试(验证)过程elif mode == tf.estimator.ModeKeys.EVAL:metrics = {'eval_loss': tf.metrics.mean_tensor(loss), "accuracy": tf.metrics.accuracy(labels, pre)}output_spec = tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)# 定义你的预测过程elif mode == tf.estimator.ModeKeys.PREDICT:predictions = {'predictions': pre}output_spec = tf.estimator.EstimatorSpec(mode, predictions=predictions)else:raise TypeErrorreturn output_spec
提几点需要注意的地方:
model_fn
方法返回的是tf.estimator.EstimatorSpec
;- TRAIN、EVAL和PREDICT模式不可缺少的参数是不一样的。
构建input_fn方法
def input_fn_bulider(inputs_file, batch_size, is_training):name_to_features = {'inputs': tf.FixedLenFeature([3], tf.float32),'labels': tf.FixedLenFeature([], tf.float32)}def input_fn(params): d = tf.data.TFRecordDataset(inputs_file)if is_training:d = d.repeat()d = d.shuffle()# map_and_batch其实就是将map和batch结合起来而已d = d.apply(tf.contrib.data.map_and_batch(lambda x: tf.parse_single_example(x, name_to_features), batch_size=batch_size))return dreturn input_fn
执行eatimator
if __name__ == '__main':# 定义日志消息的输出级别,为了获取模型的反馈信息,选择INFOtf.logging.set_verbosity(tf.logging.INFO)# 我在这里是指定模型的保存和loss输出频率runConfig = tf.estimator.RunConfig(save_checkpoints_steps=1,log_step_count_steps=1)estimator = tf.estimator.Estimator(model_fn, model_dir='your_save_path',config=runConfig, params={'lr': 0.01})# log_step_count_steps控制的只是loss的global_step的输出# 我们还可以通过tf.train.LoggingTensorHook自定义更多的输出# tensor是我们要输出的内容,输入一个字典,key为打印出来的名称,value为你要输出的tensor的namelogging_hook = tf.train.LoggingTensorHook(every_n_iter=1,tensors={'loss': 'loss'})# 其实给到estimator.train是一个dataset对象input_fn = input_fn_bulider('test.tfrecord', batch_size=1, is_training=True)estimator.train(input_fn, max_steps=1000)# 下面你还可以对模型进行验证和测试,做法是差不多的,我就不列举了
欢迎关注同名公众号:“我就算饿死也不做程序员”。
交个朋友,一起交流,一起学习,一起进步。
tensorflow estimator详细介绍,实现模型的高效训练相关推荐
- 详细介绍BERT模型
文章目录 BERT简介 BERT, OpenAI GPT, 和ELMo之间的区别 相关工作 BERT的改进 BERT 的详细实现 输入/输出表示 预训练BERT 微调BERT BERT用在下游任务 G ...
- 机器学习算法——线性回归的详细介绍 及 利用sklearn包实现线性回归模型
目录 1.线性回归简介 1.1 线性回归应用场景 1.2 什么是线性回归 1.2.1 定义与公式 1.2.2 线性回归的特征与目标的关系分析 2.线性回归api初步使用 2.1 线性回归API 2.2 ...
- TensorFlow Estimator 模型从训练到部署
引言 TensorFlow是目前流行的机器学习框架,用户可以基于TensorFlow方便地构建机器学习模型,并将模型部署到线上提供服务. 最近看Estimator框架比较流行,公司也想看Wide &a ...
- Attension Mechanism模型的详细介绍,原理、分类及应用
模型汇总24 - 深度学习中Attention Mechanism详细介绍:原理.分类及应用 Attention是一种用于提升基于RNN(LSTM或GRU)的Encoder + Decoder模型的效 ...
- 软件过程各类模型详细介绍(重要)
软件过程各类模型详细介绍(重要) 瀑布模型 瀑布模型的优点 瀑布模型的缺点 瀑布模型适合的项目类型 V过程模型 V过程模型的特点 V过程模型适合的项目 原型模型 原型模型的特点 原型模型的优点 增量模 ...
- 华为战略规划落地的核心:VDBD(价值驱动业务设计模型) 盈利模式(利润模型)详细介绍
华为战略规划落地的核心:VDBD(价值驱动业务设计模型) & 盈利模式(利润模型)详细介绍 本文作者 | 谢宁,<华为战略管理法:DSTE实战体系>.<智慧研发管理>作 ...
- 利用谷歌的联邦学习框架Tensorflow Federated实现FedAvg(详细介绍)
目录 I. 前言 II. 数据介绍 III. 联邦学习 1. 整体框架 2. 服务器端 3. 客户端 IV. Tensorflow Federated 1. 数据处理 2. 构造TFF的Keras模型 ...
- BERT模型的详细介绍
1.BERT 的基本原理是什么? BERT 来自 Google 的论文Pre-training of Deep Bidirectional Transformers for Language Unde ...
- [深度学习] 分布式Tensorflow 2.0 介绍(二)
[深度学习] 分布式模式介绍(一) [深度学习] 分布式Tensorflow 2.0介绍(二) [深度学习] 分布式Pytorch 1.0介绍(三) [深度学习] 分布式Horovod介绍(四) 一 ...
最新文章
- 用乐观的心态去面对生活,能让你的生活过得更加快乐
- 写给新手:2021版调参上分手册!
- 如何解决VHDL中参数化赋值:赋全0、全1、全z
- nginx php post限制,叫你如何修改Nginx与PHP的文件上传大小限制
- CSS五种水平居中:text-align margin incline-block flex relative
- Weex学习资料整合
- java 1.6 ubuntu_ubuntu配置 Java SE 1.6
- 微信模版消息 touser 能否多个 群发
- centos7 python3安装numpy_centos下pip3安装numpy
- 计算机硬盘 u盘和光盘属于,磁盘U盘光盘的区别
- 封装和使用Docker流程
- 人工智能第一讲:人工智能概论
- WIN10 LTSC 转 WIN10 专业版(纯净)
- 笔记本计算机内部部件,笔记本内部硬件构造有哪些
- php多层if函数,if函数嵌套计算公式用法
- 招聘中的热门技术技能:SQL、Java、Python 和 Linux
- 基于Qt的智能管家客户端设计
- Ubuntu18.04安装分析
- 解决idea控制台中文乱码问题
- Windows环境下用cloc统计代码量
热门文章
- NTC热敏电阻,错误检测显示
- 若变量已正确定义并赋值,下面符合C语言语法的表达式是
- 实力凸显 | 思迈特软件入选“2022中国软件150强“等三大重磅榜单
- 编程界“网络工程师”都用过的Python学习教程+PDF电子版曝光了
- No Shedule lines due for delivery up to the selected date
- 工业和信息化部关于发布5150-5350兆赫兹频段无线接入系统频率使用相关事宜的通知工信部无函〔2012〕620号
- 《Servlet、JSP和Spring MVC初学指南》——第2章 会话管理 2.1URL重写
- 本田雅阁 2022款混动 2.0L E-CVT 锐酷版这款车怎么样?
- 织梦html引入html代码,织梦标签引入共用html.doc
- Android实现系统相册选择APP全局背景图片