tensorflow saver_TensorFlow: Model Persistence
TensorFlow提供了一个非常简单的API来保存和还原神经网络模型。这个API就是tf.train.Saver类。以下代码给出了保存TensorFlow计算图的方法:
import tensorflow as tf# 声明两个变量并计算它们的和
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result = v1 + v2init_op = tf.global_variables_initializer()
saver = tf.train.Saver()with tf.Session() as sess:sess.run(init_op)# 将模型保存到Saved_model目录下saver.save(sess, "Saved_model/model.ckpt")
上面的代码实现了持久化一个简单的TensorFlow模型的功能。虽然上述程序只指定了一个文件路径,但这个目录下会出现多个文件。原书中说会生成3个文件,分别是
1) model.ckpt.meta —— 保存了TensorFlow计算图的结构
2) model.ckpt —— 保存了TensorFlow程序中每个变量的取值
3) checkpoint —— 保存了一个目录下所有的模型文件列表
但我运行的结果是出现了4个文件,不知和系统是否有关。我用的是OpenSUSE(Linux)系统。
![](http://img-02.proxy.5ce.com/view/image?&type=2&guid=a2c1ab50-e02e-eb11-8da9-e4434bdf6706&url=https://pic2.zhimg.com/v2-aada72d3849e15a8441c3362a7299831_b.jpg)
加载已保存的模型:
import tensorflow as tfv1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result = v1 + v2init_op = tf.global_variables_initializer()
saver = tf.train.Saver()with tf.Session() as sess:saver.restore(sess, "Saved_model/model.ckpt")print(sess.run(result))
运行结果如下:
[ 3.]
这段加在模型的代码和前面保存模型的代码几乎是一样的——也是先定义了TensorFlow计算图上的所有运算,并声明了一个tf.train.Saver类。两段代码唯一的区别是,在加在模型的代码中没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。
如果不希望重复定义图上的运算,也可以直接加在已经持久化的图,代码如下:
import tensorflow as tf# 直接加载持久化的图
saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta")with tf.Session() as sess:saver.restore(sess, "Saved_model/model.ckpt")print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [3.]# 运行结果:
# [ 3.]
- 保存滑动平均模型
import tensorflow as tfv = tf.Variable(0, dtype=tf.float32, name="v")
# 【勘误】原书的代码中用的是tf.all_variables(),版本太老了;系统提示改用tf.global_variables
# 在没有声明滑动平均模型时只有一个变量v,所以下面的语句只会输出“v:0”
for variables in tf.global_variables():print(variables.name)# 运行结果:
# v:0# ----------------------分割线----------------------ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
# 在声明了滑动平均模型之后,TensorFlow会自动生成一个影子变量 v/ExponentialMovingAverage
for variables in tf.global_variables():print(variables.name)# 运行结果:
# v:0
# v/ExponentialMovingAverage:0# ----------------------分割线----------------------saver = tf.train.Saver()
with tf.Session() as sess:#【勘误】原书是tf.initialze_all_variables(),版本太老,现在推荐的是tf.global_variables_initializer()init_op = tf.global_variables_initializer()sess.run(init_op)sess.run(tf.assign(v, 10))sess.run(maintain_averages_op)saver.save(sess, "Saved_model/model2.ckpt")print(sess.run([v, ema.average(v)]))# 运行结果:
# [10.0, 0.099999905]# ----------------------分割线----------------------import tensorflow as tfv1 = tf.Variable(0, dtype=tf.float32, name="v1")
v2 = tf.Variable(0, dtype=tf.float32, name="v2")# 通过变量重命名,将原模型的v赋值给v1,原模型的v/ExponentialMovingAverage赋值给v2
saver = tf.train.Saver({"v": v1, "v/ExponentialMovingAverage": v2})
with tf.Session() as sess:saver.restore(sess, "Saved_model/model2.ckpt")print(sess.run([v1, v2]))# 运行结果:
# [10.0, 0.099999905]
- variables_to_restore函数的使用样例
为了方便加载时重命名滑动平均变量,tf.train.ExponentialMovingAverage类提供了variables_to_restore函数来生成tf.train.Saver类所需要的变量重命名字典
import tensorflow as tfv = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)
print(ema.variables_to_restore())# 运行结果:
# {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}# 注意整个saver和上一段代码片中的saver的区别,这里就不用以变量重命名的方式载入ema了
saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:saver.restore(sess, "Saved_model/model2.ckpt")print(sess.run(v))# 运行结果:
# 0.0999999
- 保存为pb格式
使用tf.train.Saver会保存运行TensorFlow程序所需要的全部信息,然而有时并不需要某些信息。比如在测试或离线预测时,只需要知道如何从神经网络的输入层经过前向传播计算得到输出层即可,而不需要类似于变量初始化、模型保存等辅助节点的信息。在第6章介绍迁移学习时,会遇到类似的情况。而且,将变量取值和计算图结构分成不同的文件存储有时也不方便,于是TensorFlow提供了convert_variables_to_constants函数,通过这个函数可以将计算图中的变量及其取值通过常量的方式保存,这样整个TensorFlow计算图可以统一存放在一个文件中。下面的程序提供了一个样例:
import tensorflow as tf
from tensorflow.python.framework import graph_utilv1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result = v1 + v2init_op = tf.global_variables_initializer()
with tf.Session() as sess:sess.run(init_op)# 导出当前计算图中的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程graph_def = tf.get_default_graph().as_graph_def()# 将图中的变量及其取值转化为常量,同时将图中不必要的节点去掉# 如果只关心程序中定义的某些计算时,和这些计算无关的节点就没有必要导出并保存了# 在下面一行代码中,最后一个参数['add']给出了需要保存的节点名称# 注意add节点是上面定义的两个变量相加的操作,其后面没有:0# 而张量的名称后面有:0,表示的是某个计算节点的第一个输出output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:f.write(output_graph_def.SerializeToString())# 运行结果:
# Converted 2 variables to const ops.
![](http://img-03.proxy.5ce.com/view/image?&type=2&guid=a2c1ab50-e02e-eb11-8da9-e4434bdf6706&url=https://pic4.zhimg.com/v2-166413c1a97090fe2654b072606c3757_b.jpg)
- 加载pb格式的文件
import tensorflow as tf
from tensorflow.python.platform import gfilewith tf.Session() as sess:model_filename = "Saved_model/combined_model.pb"# 读取保存的模型文件,并将文件解析成对应的GraphDef Protocol Buffer:with gfile.FastGFile(model_filename, 'rb') as f:graph_def = tf.GraphDef()graph_def.ParseFromString(f.read())# 将graph_def中保存的图加在到当前的图中,return_elements=["add:0"]给出了返回的张量的名称。在保存的时候给出的是计算节点的名称,所以为"add"# 在加载的时候给出的是张量的名称,所以是"add:0"result = tf.import_graph_def(graph_def, return_elements=["add:0"])print(sess.run(result))# 运行结果:
# [array([ 3.], dtype=float32)]
tensorflow saver_TensorFlow: Model Persistence相关推荐
- Tensorflow 获取model中的变量列表,用于模型加载等
目录 前言 1. 用tensorflow自带的工具 2. 用tensorflow.contrib.slim. 3. 从保存的model中提取var_list 4. 其他 前言 在加载预训练的网络模型时 ...
- TensorFlow及model的安装
最近在学习深度学习的卷积神经网络,采用的学习框架是tensorflow框架.现在主要介绍一下自己再安装过程中遇到的问题及安装方法. 一.基于python3.5的安装 安装tensorflow可以基于p ...
- tensorflow中model.compile()
model.compile()用来配置模型的优化器.损失函数,评估指标等 里面的具体参数有: compile(optimizer='rmsprop',loss=None,metrics=None,lo ...
- 【tensorflow】model.fit() fit函数
[转载并参考]model.fit() fit函数_a1111h的博客-CSDN博客 https://blog.csdn.net/a1111h/article/details/82148497 fit( ...
- tensorflow model几种模型文件
转载:https://blog.eson.org/pub/3da24a26/ 模型文件相关 checkpoint (.ckpt): variable的序列化存储,常用于保存和还原模型参数.保存方式是变 ...
- tensorflow 1.x Saver(保存与加载模型) 预测
20201231 tensorflow 1.X 模型保存 https://blog.csdn.net/qq_35290785/article/details/89646248 保存模型 saver=t ...
- TensorRT 3:更快的TensorFlow推理和Volta支持
TensorRT 3:更快的TensorFlow推理和Volta支持 TensorRT 3: Faster TensorFlow Inference and Volta Support 英伟达Tens ...
- 最喜欢随机森林?周志华团队 DF21 后,TensorFlow 开源决策森林库 TF-DF
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 转自 | 机器之心 TensorFlow 决策森林 (TF-DF) ...
- ubuntu16.04 cuda9.0 cudnn Tensorflow GPU 1.10.0
Ubuntu14.04升级到Ubuntu16.04 查看目前版本 lsb_release -a apt-get update && apt-get dist-upgrade reboo ...
最新文章
- 【c语言】蓝桥杯入门训练 圆的面积
- [Swift]LeetCode388. 文件的最长绝对路径 | Longest Absolute File Path
- 电脑配置清单_2020电脑配置清单AMD指南
- 洛谷1091合唱队形
- 大数据产品开发流程规范_华为内部资料流出!揭秘华为数据湖:3大特点、6个标准、入湖流程...
- c++11线程必须要懂得同步技术
- 深度学习常用框架和基础模型
- c语言编程星号输出图形的步骤,使用C语言打印不同星号图案
- Flutter 开源社交电商项目Flutter_Mycommunity_App
- zabbix如何网站监控web
- PNP与NPN的区别与判断(一)
- 如何解决Worm.Win32.AutoRun.bqn(文件夹改exe病毒)
- 【转载】手机UC浏览器缓存视频合并方法
- 微信小程序腾讯服务器地址要购买吗,微信小程序JavaScript SDK
- Harbor安装(待补充)
- 大厂代码规范及个人本学期的代码规范
- 建站过程中如何防止被骗
- UILabel的使用
- win7旗舰版蓝屏代码说明
- Tesla数据标注系统解析