01.TensorFlow与自定义预估器

1.1 预估器

预估器也是一种高级API,其优点为:

  • 不必编写大量样板文件代码
  • 灵活,模型允许替换默认行为
    可以通过两种可能的方式构建模型:
  • 预制预估器:预先定义的估算器,旨在生成特定类型的模型
  • 自定义预估器:允许使用model_fn函数,可完全掌握模型的创建方式

1.2 自定义预估器

在之前的案例中我们使用了TensorFlow里面自带的深度神经网络分类器tf.estimator.DNNClassifier().这些TensorFlow自带的Estimator称为预制估算器Pre-made Estimator(预创制的Estimator)。

classifier = tf.estimator.DNNClassifier(feature_columns = feature_columns,hidden_units = [10,10],n_classes = 3,madel_dir = models_path,config = ckpt_config)

和自定义输入函数input_fn一样,TensorFlow允许我们自己创建更加灵活的Estimator,自定义Estimator是tf.estimator.Estimator()方法生成,和预制估算器一样使用。

1.2.1 结构概览

自定义Estimator应该具有DNNClassifier一样的功能

  • 创建的时候接收一些参数,如feature_columnshidden_unitsn_classes
  • 具有train()evaluate()predict()三个方法来训练、评价和预测
    则其语法格式为:
tf.estimator.Estimator(model_fn, # 模型函数mode_dir = None, # 存储目录config = None, # 设置参数对象params = None, # 超参数,将传递给model_fn使用warm_start_from = None # 热启动目录路径)

模型函数model_fn是唯一没有设置默认值得参数,它也是自定义Estimator最关键的部分,包含了最核心的算法。model_fn是一个能进行运算的函数,伪代码为:

my_model(feature, # 输入的特征数据lables, # 输入的标签数据mode, # train、evaluate、或predictparams # 超参数,对应上面Estimator传来的参数)

1.2.2 神经网络layers

下图model_fn运作流程,以iris为例。

从上图可以看到结构:

  • 输入层Input Layer,数据从这里输入
  • 隐藏层Hidden Layer,2层,每层包含多个节点,数据流经这里,被推测规律
  • 输出层Output Layer,将推测结果整理显示
    我们并不需要手工实现隐层的算法和工作原理,TensorFlow已经设计好了相关算法,我们只需要创建好网络层,并按顺序连接起来即可。

1.2.3 编写model_fn

如上例预制估算器DNNClassifier中的参数对应自定义Estimator的参数,这些参数都会被Estimator打包放在params超参数中,传递给model_fn,下面伪代码是在model_fn内创建网络层。

improt tensorflow as tf# 自定义模型函数
def my_model_fn(features,labels,model,params):# 输入层,feature_columns对应Classifier(feature_columns = ...)net  = tf.feature_column.input_layer(features,params['feature_columns'])# 隐藏层,hidden_units对应Classifier(unit = [10,10]),2层各含10个节点for units in params['hidden_units']:net = tf.layers.dense(net,units = units,activation = tf.nn.relu)# 输出层,n_classes对应3中鸢尾花logits = tf.layers.dense(net,params['n_classes'],activation = None)

1.2.4 训练(train)、评价(evaluate)和预测(predict)

训练:深度学习模型 = 模型表示 + 优化
评估:指标

  • tf.layer:op封装好的层次板块
  • tf.losses:损失函数板块
  • tf.train.AdamOptimizer:优化器板块
  • tf.metrics:评估指标板块
  • tf.summary:信息总结,给tensorflow作可视化积累日志文件
  • tf.estimator.EstimatorSpec:不同阶段返回的对象
    • For mode == ModeKeys.TRAIN:required fields are loss and train_op.
    • For mode == ModeKeys.EVAL:required fields are loss
    • For mode == ModeKeys.PREDICT:required fields are predictions
      前面我们知道,自定义的估算分类器必须能够用来执行my_classifier.train()my_classifier.evaluate()my_classifier.predict()三个方法。但实际上,它们都是model_fn这个函数的分身,上面my_model中的mode包含trainevaluatepredict。示例代码
 my_model(...,...,"TRAIN",...) # 如果是"EVAL"就执行评价,"PREDICT"就执行预测

修改my_model代码来实现三个功能:

improt tensorflow as tf# 自定义模型函数
def my_model_fn(features,labels,model,params):# 输入层,feature_columns对应Classifier(feature_columns = ...)net  = tf.feature_column.input_layer(features,params['feature_columns'])# 隐藏层,hidden_units对应Classifier(unit = [10,10]),2层各含10个节点for units in params['hidden_units']:net = tf.layers.dense(net,units = units,activation = tf.nn.relu)# 输出层,n_classes对应3中鸢尾花logits = tf.layers.dense(net,params['n_classes'],activation = None)# 预测predicted_classes = tf.argmax(logits,1) # 预测的结果中最大值即种类if mode = tf.estimator.ModeKeys.PREDICT:predictions = {'class_ids': predicted_classes[:,tf.newaxis], # 拼成[[3],[2]]格式'probabilities':tf.nn.softmax(logits), # 把[-1.3,2.6,-0.9]规则化到0-1范围,表示可能性'logits':logits # [-1.3,2.6,-0.9]}return tf.estimator.EstimatorSpec(mode,predictions = predictions)# 训练if mode = tf.estimator.ModeKeys.TRAIN:# 优化函数,用来优化损失函数optimizer = tf.train.AdagradOptimizer(learning_rate = 0.1)# 执行优化train_op = optimizer。minimize(loss,global_step = tf.train.get_global_step())return tf.estimator.EstimatorSpec(mode,loss = loss,train_op = train_op)# 评价accuracy = tf.metrics.accuracy(labels = labels,predictions = predicted_classes,name = 'acc_op' # 计算精度)metrics = {'accuracy':accuracy}tf.summary.scalar('accuracy':accuracy[1]) # 可视化使用if mode = tf.estimator.ModeKeys.EVAL:return tf.estimator.EstimatorSpec(mode,loss = loss,eval_metric_ops = metrics)

注意,请将预测Predict放在最先编写,否则会引发后续错误。
则创建自定义分类器:

classifier = tf.eatimator.Estimator(model_fn = my_model, # 注意这里,调用my_modelparams = {'feature_columns':feature_columns,'hidden_units':[10,10],'n_classes':3,})

02.基于TensorFlow自定义CNN预估器

2.1 Estimator的优势

  • 学习流程

    • Estimator封装了对训练、评估和预测的控制,用户无需不断的为新的任务重复编写代码,可以专注于对网络结构的控制。
  • 网络结构
    • Estimator的网络结构在model_fn中独立定义的
    • 用户创建的任何网络结构都可以在Estimator的控制下使用
    • 可允许用户使用别人定义好的model_fn
  • 数据导入
    • Estimator的数据导入也是有input_fn独立定义的
    • 可仅通过改变input_fn的定义,来使用相同的网络结构学习不同的数据

2.2代码详解

核心主要是Estimator的构建,其它的则是简单的CNN搭建,在此就不列出代码了(因为没有现成的,不想从视频中搬过来……)。

03.基于TensorFlow自定义RNN预估器

3.1 代码详解

核心主要是Estimator的构建,其它的则是简单的CNN搭建,在此就不列出代码了(因为没有现成的,主要不想从视频中搬过来……)。

第五章 TensorFlow工具库(下)相关推荐

  1. 《RabbitMQ 实战指南》第五章 RabbitMQ 进阶(下)

    <RabbitMQ 实战指南>第五章 RabbitMQ 进阶(下) 文章目录 <RabbitMQ 实战指南>第五章 RabbitMQ 进阶(下) 一.持久化 二.生产者确认 1 ...

  2. 操作系统课堂同步练习选择题(第五章)题库信阳师范学院柳春华老师

    (备注:有红色,红色为正确答案 :若无,蓝(绿)色为正确答案) 第五章  虚拟存储器 一.章节练习 1.系统抖动是指(        ). A. 使用机器时,千万屏幕闪烁的现象 B. 刚被调出的页面又 ...

  3. 【第三章:标准单元库 下】静态时序分析圣经翻译计划

    本文由知乎赵俊军授权转载,知乎主页为https://www.zhihu.com/people/zhao-jun-jun-19 3.6 黑盒的接口时序模型 本节将介绍黑盒(任意模块或块)的IO接口时序弧 ...

  4. 《UNIX环境高级编程》笔记 第五章-标准IO库

    1. 流和FILE对象 在第三章的系统调用都是围绕文件描述符fd的.但是标准I/O库函数操作则是围绕流进行的.当使用标准I/O库打开或创建一个文件时,使用一个流与一个文件关联. 当打开一个流时,标准I ...

  5. 数论概论 第五章 习题解答(下) (宋二娃的BLOG)

    5.5  (a)n       算法长度     终止值 21        8                  1 13        10                1 31        ...

  6. 2021.12.19【读书笔记】丨生物信息学与功能基因组学(第五章 高级数据库搜索 下)

    5.5 用类似于BLAST的比对工具快速搜索基因组DNA 需求:随着基因组DNA数据库数量增长,对比对工具要求越来越高 能在基因组DNA中找到外显子 比对时考虑基因组DNA包含的测序错误 有相应的算法 ...

  7. 第五章 Scrapy爬虫框架(5.1 Scrapy框架基础)

    Scrapy是一个高级Web爬虫框架,用于爬取网站并从页面中提取结构化数据.它可以用于数据挖掘.数据监控和自动化测试等多个方面.与之前讲过的Requests库和Selenium库不同,Scrapy更适 ...

  8. Qt-UI 界面工具库简介

    一.关于Qt-UI界面工具库 Qt-UI界面工具库是武汉维仕杰科技有限公司基于Qt上进行扩展开发的控件包和界面工具,并且拥有完全自主的知识产权.得益于丰富的界面开发经验和强大的支持团队,使得Qt-UI ...

  9. 天龙日梅兰竹菊_第三百一十五章 梅兰竹菊

    第三百一十五章梅兰竹菊 自打应下无崖子的承诺以来,楚柏便是一直马不停蹄的赶路! 赶到西夏,在见了李秋水之后,又被李秋水拉着前往[缥缈峰],这一路,风尘仆仆的楚柏,总算是难得空闲下来了: 不得不说! [ ...

最新文章

  1. bootstrap跟vue冲突吗_知道微服务,但你知道微前端吗?
  2. python 的文件读写方法:read readline readlines wirte   writelines
  3. 【强烈推荐】Github star 10K+,周志华机器学习详细公式推导!
  4. C#版本与.NET版本对应关系以及各版本的特性
  5. python有内存处理模块吗_使用Python多处理的高内存使用
  6. 机器学习之数据预处理
  7. 文献综述写作之“结构内容”
  8. 【华为云技术分享】云图说 | 华为云AnyStack on BMS解决方案:助力线下虚拟化业务迁移上云
  9. android动画能超过父容器吗,Android中你不得不知道的动画知识 (一)
  10. 《大话》之 三大工厂
  11. java面向对象(封装-继承-多态)
  12. Alert提示框插件
  13. HDR到底是干什么的?建模的时候有什么用处?
  14. java 通过身份证计算年龄性别
  15. 扒开系统调用的三层皮(上)
  16. SSL基础:23:生成Kubernetes集群证书(OpenSSL方式)
  17. linux中gimp命令截图,Linux利用GIMP截图
  18. jar包打包成exe安装包
  19. linux ps1 主机名 ip,Bash Shell PS1: 自定义你的linux提示符十例
  20. java poi 填充单元格_POI操作excel表格(建立工作薄、创建工作表、将数据填充到单元格中)...

热门文章

  1. 网络安全漏洞分析之重定向漏洞分析
  2. 实用的文字转语音软件介绍
  3. 计算机图形学-二维图形-几何变换
  4. NAO机器人代码编译总结
  5. 遇到冲击波病毒引起的准备关机时怎么办?
  6. 吸附硼酸盐树脂,溴化锂溶液净化除硼CH-99
  7. Mybaits逆向工程
  8. 看完这篇文章学会epub批量转txt
  9. 中国制造VS国际品牌 耳熟能详游戏引擎比拼
  10. office办公---ppt技法及常用网址推荐