第五章 TensorFlow工具库(下)
01.TensorFlow与自定义预估器
1.1 预估器
预估器也是一种高级API,其优点为:
- 不必编写大量样板文件代码
- 灵活,模型允许替换默认行为
可以通过两种可能的方式构建模型: - 预制预估器:预先定义的估算器,旨在生成特定类型的模型
- 自定义预估器:允许使用model_fn函数,可完全掌握模型的创建方式
1.2 自定义预估器
在之前的案例中我们使用了TensorFlow里面自带的深度神经网络分类器tf.estimator.DNNClassifier()
.这些TensorFlow自带的Estimator称为预制估算器Pre-made Estimator
(预创制的Estimator)。
classifier = tf.estimator.DNNClassifier(feature_columns = feature_columns,hidden_units = [10,10],n_classes = 3,madel_dir = models_path,config = ckpt_config)
和自定义输入函数input_fn
一样,TensorFlow允许我们自己创建更加灵活的Estimator,自定义Estimator是tf.estimator.Estimator()
方法生成,和预制估算器一样使用。
1.2.1 结构概览
自定义Estimator
应该具有DNNClassifier
一样的功能
- 创建的时候接收一些参数,如
feature_columns
、hidden_units
、n_classes
等 - 具有
train()
、evaluate()
和predict()
三个方法来训练、评价和预测
则其语法格式为:
tf.estimator.Estimator(model_fn, # 模型函数mode_dir = None, # 存储目录config = None, # 设置参数对象params = None, # 超参数,将传递给model_fn使用warm_start_from = None # 热启动目录路径)
模型函数model_fn
是唯一没有设置默认值得参数,它也是自定义Estimator
最关键的部分,包含了最核心的算法。model_fn
是一个能进行运算的函数,伪代码为:
my_model(feature, # 输入的特征数据lables, # 输入的标签数据mode, # train、evaluate、或predictparams # 超参数,对应上面Estimator传来的参数)
1.2.2 神经网络layers
下图model_fn
运作流程,以iris为例。
从上图可以看到结构:
- 输入层
Input Layer
,数据从这里输入 - 隐藏层
Hidden Layer
,2层,每层包含多个节点,数据流经这里,被推测规律 - 输出层
Output Layer
,将推测结果整理显示
我们并不需要手工实现隐层的算法和工作原理,TensorFlow已经设计好了相关算法,我们只需要创建好网络层,并按顺序连接起来即可。
1.2.3 编写model_fn
如上例预制估算器DNNClassifier
中的参数对应自定义Estimator
的参数,这些参数都会被Estimator
打包放在params
超参数中,传递给model_fn
,下面伪代码是在model_fn
内创建网络层。
improt tensorflow as tf# 自定义模型函数
def my_model_fn(features,labels,model,params):# 输入层,feature_columns对应Classifier(feature_columns = ...)net = tf.feature_column.input_layer(features,params['feature_columns'])# 隐藏层,hidden_units对应Classifier(unit = [10,10]),2层各含10个节点for units in params['hidden_units']:net = tf.layers.dense(net,units = units,activation = tf.nn.relu)# 输出层,n_classes对应3中鸢尾花logits = tf.layers.dense(net,params['n_classes'],activation = None)
1.2.4 训练(train)、评价(evaluate)和预测(predict)
训练:深度学习模型 = 模型表示 + 优化
评估:指标
- tf.layer:op封装好的层次板块
- tf.losses:损失函数板块
- tf.train.AdamOptimizer:优化器板块
- tf.metrics:评估指标板块
- tf.summary:信息总结,给tensorflow作可视化积累日志文件
- tf.estimator.EstimatorSpec:不同阶段返回的对象
- For
mode == ModeKeys.TRAIN
:required fields areloss
andtrain_op
. - For
mode == ModeKeys.EVAL
:required fields areloss
- For
mode == ModeKeys.PREDICT
:required fields arepredictions
前面我们知道,自定义的估算分类器必须能够用来执行my_classifier.train()
、my_classifier.evaluate()
和my_classifier.predict()
三个方法。但实际上,它们都是model_fn
这个函数的分身,上面my_model
中的mode
包含train
、evaluate
和predict
。示例代码
- For
my_model(...,...,"TRAIN",...) # 如果是"EVAL"就执行评价,"PREDICT"就执行预测
修改my_model
代码来实现三个功能:
improt tensorflow as tf# 自定义模型函数
def my_model_fn(features,labels,model,params):# 输入层,feature_columns对应Classifier(feature_columns = ...)net = tf.feature_column.input_layer(features,params['feature_columns'])# 隐藏层,hidden_units对应Classifier(unit = [10,10]),2层各含10个节点for units in params['hidden_units']:net = tf.layers.dense(net,units = units,activation = tf.nn.relu)# 输出层,n_classes对应3中鸢尾花logits = tf.layers.dense(net,params['n_classes'],activation = None)# 预测predicted_classes = tf.argmax(logits,1) # 预测的结果中最大值即种类if mode = tf.estimator.ModeKeys.PREDICT:predictions = {'class_ids': predicted_classes[:,tf.newaxis], # 拼成[[3],[2]]格式'probabilities':tf.nn.softmax(logits), # 把[-1.3,2.6,-0.9]规则化到0-1范围,表示可能性'logits':logits # [-1.3,2.6,-0.9]}return tf.estimator.EstimatorSpec(mode,predictions = predictions)# 训练if mode = tf.estimator.ModeKeys.TRAIN:# 优化函数,用来优化损失函数optimizer = tf.train.AdagradOptimizer(learning_rate = 0.1)# 执行优化train_op = optimizer。minimize(loss,global_step = tf.train.get_global_step())return tf.estimator.EstimatorSpec(mode,loss = loss,train_op = train_op)# 评价accuracy = tf.metrics.accuracy(labels = labels,predictions = predicted_classes,name = 'acc_op' # 计算精度)metrics = {'accuracy':accuracy}tf.summary.scalar('accuracy':accuracy[1]) # 可视化使用if mode = tf.estimator.ModeKeys.EVAL:return tf.estimator.EstimatorSpec(mode,loss = loss,eval_metric_ops = metrics)
注意,请将预测Predict
放在最先编写,否则会引发后续错误。
则创建自定义分类器:
classifier = tf.eatimator.Estimator(model_fn = my_model, # 注意这里,调用my_modelparams = {'feature_columns':feature_columns,'hidden_units':[10,10],'n_classes':3,})
02.基于TensorFlow自定义CNN预估器
2.1 Estimator的优势
- 学习流程
- Estimator封装了对训练、评估和预测的控制,用户无需不断的为新的任务重复编写代码,可以专注于对网络结构的控制。
- 网络结构
- Estimator的网络结构在model_fn中独立定义的
- 用户创建的任何网络结构都可以在Estimator的控制下使用
- 可允许用户使用别人定义好的model_fn
- 数据导入
- Estimator的数据导入也是有input_fn独立定义的
- 可仅通过改变input_fn的定义,来使用相同的网络结构学习不同的数据
2.2代码详解
核心主要是Estimator的构建,其它的则是简单的CNN搭建,在此就不列出代码了(因为没有现成的,不想从视频中搬过来……)。
03.基于TensorFlow自定义RNN预估器
3.1 代码详解
核心主要是Estimator的构建,其它的则是简单的CNN搭建,在此就不列出代码了(因为没有现成的,主要不想从视频中搬过来……)。
第五章 TensorFlow工具库(下)相关推荐
- 《RabbitMQ 实战指南》第五章 RabbitMQ 进阶(下)
<RabbitMQ 实战指南>第五章 RabbitMQ 进阶(下) 文章目录 <RabbitMQ 实战指南>第五章 RabbitMQ 进阶(下) 一.持久化 二.生产者确认 1 ...
- 操作系统课堂同步练习选择题(第五章)题库信阳师范学院柳春华老师
(备注:有红色,红色为正确答案 :若无,蓝(绿)色为正确答案) 第五章 虚拟存储器 一.章节练习 1.系统抖动是指( ). A. 使用机器时,千万屏幕闪烁的现象 B. 刚被调出的页面又 ...
- 【第三章:标准单元库 下】静态时序分析圣经翻译计划
本文由知乎赵俊军授权转载,知乎主页为https://www.zhihu.com/people/zhao-jun-jun-19 3.6 黑盒的接口时序模型 本节将介绍黑盒(任意模块或块)的IO接口时序弧 ...
- 《UNIX环境高级编程》笔记 第五章-标准IO库
1. 流和FILE对象 在第三章的系统调用都是围绕文件描述符fd的.但是标准I/O库函数操作则是围绕流进行的.当使用标准I/O库打开或创建一个文件时,使用一个流与一个文件关联. 当打开一个流时,标准I ...
- 数论概论 第五章 习题解答(下) (宋二娃的BLOG)
5.5 (a)n 算法长度 终止值 21 8 1 13 10 1 31 ...
- 2021.12.19【读书笔记】丨生物信息学与功能基因组学(第五章 高级数据库搜索 下)
5.5 用类似于BLAST的比对工具快速搜索基因组DNA 需求:随着基因组DNA数据库数量增长,对比对工具要求越来越高 能在基因组DNA中找到外显子 比对时考虑基因组DNA包含的测序错误 有相应的算法 ...
- 第五章 Scrapy爬虫框架(5.1 Scrapy框架基础)
Scrapy是一个高级Web爬虫框架,用于爬取网站并从页面中提取结构化数据.它可以用于数据挖掘.数据监控和自动化测试等多个方面.与之前讲过的Requests库和Selenium库不同,Scrapy更适 ...
- Qt-UI 界面工具库简介
一.关于Qt-UI界面工具库 Qt-UI界面工具库是武汉维仕杰科技有限公司基于Qt上进行扩展开发的控件包和界面工具,并且拥有完全自主的知识产权.得益于丰富的界面开发经验和强大的支持团队,使得Qt-UI ...
- 天龙日梅兰竹菊_第三百一十五章 梅兰竹菊
第三百一十五章梅兰竹菊 自打应下无崖子的承诺以来,楚柏便是一直马不停蹄的赶路! 赶到西夏,在见了李秋水之后,又被李秋水拉着前往[缥缈峰],这一路,风尘仆仆的楚柏,总算是难得空闲下来了: 不得不说! [ ...
最新文章
- bootstrap跟vue冲突吗_知道微服务,但你知道微前端吗?
- python 的文件读写方法:read readline readlines wirte writelines
- 【强烈推荐】Github star 10K+,周志华机器学习详细公式推导!
- C#版本与.NET版本对应关系以及各版本的特性
- python有内存处理模块吗_使用Python多处理的高内存使用
- 机器学习之数据预处理
- 文献综述写作之“结构内容”
- 【华为云技术分享】云图说 | 华为云AnyStack on BMS解决方案:助力线下虚拟化业务迁移上云
- android动画能超过父容器吗,Android中你不得不知道的动画知识 (一)
- 《大话》之 三大工厂
- java面向对象(封装-继承-多态)
- Alert提示框插件
- HDR到底是干什么的?建模的时候有什么用处?
- java 通过身份证计算年龄性别
- 扒开系统调用的三层皮(上)
- SSL基础:23:生成Kubernetes集群证书(OpenSSL方式)
- linux中gimp命令截图,Linux利用GIMP截图
- jar包打包成exe安装包
- linux ps1 主机名 ip,Bash Shell PS1: 自定义你的linux提示符十例
- java poi 填充单元格_POI操作excel表格(建立工作薄、创建工作表、将数据填充到单元格中)...