TensorFlow模型持久化
模型持久化的目的在于可以使模型训练后的结果重复使用,节省重复训练模型的时间。
模型保存
train.Saver类是TensorFlow提供的用于保存和还原模型的API,使用非常简单。
import tensorflow as tf# 声明两个变量并计算其加和
a = tf.Variable(tf.constant([1.0, 2.0], shape=[2]), name='a')
b = tf.Variable(tf.constant([3.0, 4.0], shape=[2]), name='b')
result = a + b# 初始化全部变量的操作
init_op = tf.global_variables_initializer()
# 定义 Saver 类对象用于保存模型
saver = tf.train.Saver()with tf.Session() as sess:sess.run(init_op)saver.save(sess, "./model/model.ckpt")
上面的代码实现了一个简单的TensorFlow模型持久化的功能。
save()函数的sess参数用于指定要保存的模型会话,save_path参数用于指定路径。
通过Saver类的save()函数将TensorFlow模型保存到一个指定路径下的model.ckpt文件中。
(TensorFlow模型一般会保存在文件名为.ckpt的文件中,可以省略后缀名,但是好的编程习惯是对其加以指定)
虽然上面的程序只制定了一个文件路径,但是在这个文件目录下回出现4个文件:
- checkpoint文件是一个文本文件,保存了一个目录下所有的模型文件列表。该文件会被自动更新,当有更多模型被保存到model目录下时,文件内容会更新为最新的训练模型。
- model.ckpt.data-00000-of-00001文件是一个二进制文件,保存了TensorFlow中每一个变量的取值。
- model.ckpt.index文件是一个二进制文件,保存了每一个变量的名称,是一个string-string的table,其中table的key值为tensor名,value值为BundleEntryProto。
- model.ckpt.meta文件是一个二进制文件,保存了计算图的结构。
将一个模型文件分成多个文件保存的原因是TensorFlow会将模型的计算图结构以及参数的取值分开来保存。
模型加载
TensorFlow也提供了相应的函数来加载保存的模型。
with tf.Session() as sess:saver.restore(sess, "./model/model.ckpt")print(sess.run(result))
输出:
加载模型的代码和保存模型的代码相似,但是省略了初始化全部变量的过程。
使用restore()函数需要在模型参数恢复前定义计算图上的所有运算,并且变量名需要与模型存在的变量名保持一致,这样就可以将变量的值通过已保存的模型加载进来。
有时我们可能不希望重复定义计算图上的计算,太繁琐了,TensorFlow提供了import_meta_graph()函数加载模型的计算图。
import_meta_graph()函数的输入参数为.meta文件的路径,返回一个Saver类实例,再调用这个实例的restore()函数就可以恢复参数了。
saver = tf.train.import_meta_graph("./model/model.ckpt.meta")with tf.Session() as sess:saver.restore(sess, "./model/model.ckpt")# 获取默认计算图上指定节点处的张量print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))
输出:
.ckpt.meta文件保存了计算图的结构,通过import_meta_graph()函数将计算图导入到程序中并传递给saver,之后在会话中通过restore()函数对该计算图中变量的值进行加载。
get_tensor_by_name()函数用于获取指定节点处的张量(add:0 表示add节点的第一个输出)。
TensorFlow模型持久化相关推荐
- Tensorflow模型持久化与恢复
Tensorflow模型 简单点说,一个tensorflow模型包含了神经网络的结构(graph)和通过训练得到的一系列神经网络的参数. 神经网络的结构(graph)即神经网络的节点(nodes)及其 ...
- tensorflow模型持久化方法
#测试模型持久化 v1 = tf.Variable(tf.constant(1.,shape=[2,2]),name='v1') v2 = tf.Variable(tf.constant(1.,sha ...
- 【TensorFlow】TensorFlow从浅入深系列之十三 -- 教你深入理解模型持久化(模型保存、模型加载)
本文是<TensorFlow从浅入深>系列之第13篇 TensorFlow从浅入深系列之一 -- 教你如何设置学习率(指数衰减法) TensorFlow从浅入深系列之二 -- 教你通过思维 ...
- 5.2 TensorFlow:模型的加载,存储,实例
背景 之前已经写过TensorFlow图与模型的加载与存储了,写的很详细,但是或闻有人没看懂,所以在附上一个关于模型加载与存储的例子,CODE是我偶然看到了,就记下来了.其中模型很巧妙,比之前nump ...
- ONNX系列四 --- 使用ONNX使TensorFlow模型可移植
目录 TensorFlow简介 安装和导入转换器 快速浏览模型 将TensorFlow模型转换为ONNX 摘要和后续步骤 参考文献 下载源547.1 KB 系列文章列表如下: ONNX系列一 --- ...
- TensorFlow模型保存和提取方法
2019独角兽企业重金招聘Python工程师标准>>> 一.TensorFlow模型保存和提取方法 1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存 ...
- TensorFlow模型保存和提取方法(含滑动平均模型)
一.TensorFlow模型保存和提取方法 1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取.tf.train.Saver对象saver的save方法将Tens ...
- tensorflow模型固化
1 tensorflow模型固化 1.1 训练时直接固化成pb文件 import tensorflow as tf from tensorflow.python.framework import gr ...
- 干货 | tensorflow模型导出与OpenCV DNN中使用
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自|OpenCV学堂 OpenCV DNN模块 Deep N ...
最新文章
- 全图表征学习算法之无监督学习和基于卷积神经网络的监督学习
- 如何使我的Python程序休眠50毫秒?
- input 属性和用法
- fastjson反序列化多层嵌套泛型类与java中的Type类型
- 企业微信 添加白名单_企业微信群为什么只能加200人?企业微信群怎么申请扩容?...
- 【年终总结】有三AI至今在人脸图像算法领域都分享了哪些内容?
- c语言考试算法,c语言考试常用算法docx.docx
- 【Node】常用基础 API 整理
- 网络:NAT使用场景
- SAP CRM呼叫中心polling javascript - icf_notify_poll.js
- 令人叫绝的EXCEL函数功能
- SAP License:MM自动过账科目特殊库存杂谈
- 中国替扎尼定行业市场供需与战略研究报告
- 取消IDEA保存文件,默认删除行尾空格
- 判断是否离开当前页面
- zz 传苹果平板电脑的UI界面将具备“快速学习”功能
- 救急的戴尔Latitude 10商用平板电脑
- NBIOT的BC26使用
- 详细的ico图标制作与Qt修改exe图标方法
- 像京东等大厂为什么不通过减薪来代替裁员,降低成本?
热门文章
- 36. Valid Sudoku
- Python基础之(面向对象初识)
- 2018美团笔试字符串问题
- springmvc 中controller与jsp传值
- java security 详解_Spring Security入门教程 通俗易懂 超详细 【内含案例】
- linux 信号量锁 内核,Linux内核中锁机制之信号量、读写信号量
- 上海教师中级职称英语计算机考试,计算机教师如果考过了软考中级对职称评定有用吗,学校会承认嘛,有人懂吗,求助...
- Java黑皮书课后题第4章:*4.2(几何:最大圆距离)最大圆面积是指球面上两个点间的距离。编写一个程序,提示用户以度为单位输入地球上两个点的经纬度,显示其最大圆距离值
- WEBBASE篇: 第八篇, JavaScript知识2
- HDU1425 A Chess Game