TensorFlow目前保存的模型文件主要有两种,ckpt与pb,二者之间的异同请见

https://zhuanlan.zhihu.com/p/32887066

下面,我以mnist手写数据集用softmax回归为例,说明如何对训练好的模型进行保存与恢复。

1. 训练模型并保存为模型文件

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as npmnist = input_data.read_data_sets('MNIST_data', one_hot=True)
sess = tf.InteractiveSession()x = tf.placeholder("float", shape=[None, 784], name='input_x')  # 输入图像占位符
y_ = tf.placeholder("float", shape=[None, 10])  # 标签类别占位符# 模型参数一般用Variable来表示
W = tf.Variable(tf.zeros([784, 10]), name='w')  # 权重W是一个784x10的矩阵(因为我们有784个特征和10个输出值)
b = tf.Variable(tf.zeros([10]), name='b')  # 偏置b是一个10维的向量(因为我们有10个分类)sess.run(tf.initialize_all_variables())  # 变量需要通过seesion初始化后,才能在session中使用
# 使用Tensorflow提供的回归模型softmax,y代表输出,把向量化后的图片x和权重矩阵W相乘,加上偏置b,然后计算每个分类的softmax概率值
y = tf.nn.softmax(tf.matmul(x, W) + b, name='predict')cross_entropy = - tf.reduce_sum(y_ * tf.log(y))  # 计算交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  # 梯度下降算法以0.01的学习速率最小化交叉熵# tf.argmax返回某个tensor对象在某一维上的其数据最大值所在的索引值
# 下面这行返回一组布尔值如[True, False, True, True]
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))# 把布尔值转换成浮点数,然后取平均值,[True, False, True, True] 会变成 [1,0,1,1] ,取平均值后得到 0.75
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))for i in range(1000):batch = mnist.train.next_batch(50)  # 每一步迭代加载50个训练样本,然后执行一次train_stepsess.run(train_step, feed_dict={x: batch[0], y_: batch[1]})if i % 100 == 0:print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))  # 模型在测试数据集上面的正确率# 随便从测试集中取一个例子做测试
print(sess.run(y, feed_dict={x: np.expand_dims(mnist.test.images[15], axis=0)}))
print(sess.run(tf.argmax(sess.run(y, feed_dict={x: np.expand_dims(mnist.test.images[15], axis=0)}), axis=1)))   # 预测结果
print(mnist.test.labels[15])    # 标签值

对上述代码不熟悉的请参考:http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_pros.html

a.保存为ckpt格式的模型文件

saver = tf.train.Saver()
saver.save(sess, "save_path/file_name")

生成的模型文件如下:

b.保存为pb格式的模型文件

builder = tf.saved_model.builder.SavedModelBuilder('./model2')
builder.add_meta_graph_and_variables(sess, ["mytag"])
builder.save()

生成的模型文件如下:

运行结果:(第6个元素最大,表示数字5,说明预测正确)

0.2847
0.8778
0.8945
0.8972
0.9031
0.9015
0.9109
0.9007
0.8901
0.9061
[[  2.59213091e-04   1.70691292e-05   1.03438069e-04   1.55748194e-022.95701193e-05   9.70679998e-01   7.14014686e-06   7.19119780e-051.32500082e-02   6.82865766e-06]]
[5]
[ 0.  0.  0.  0.  0.  1.  0.  0.  0.  0.]

2. 模型文件的恢复与使用

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
import os
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_filemnist = input_data.read_data_sets('MNIST_data', one_hot=True)# pb模型的恢复
def restore_model_pb():sess = tf.Session()tf.saved_model.loader.load(sess, ['mytag'], os.getcwd() + '\model2')input_x = sess.graph.get_tensor_by_name('input_x:0')op = sess.graph.get_tensor_by_name('predict:0')print(sess.run(op, feed_dict={input_x: np.expand_dims(mnist.test.images[15], axis=0)}))sess.close()# ckpt模型的恢复
def restore_model_ckpt():sess = tf.Session()# 加载模型结构saver = tf.train.import_meta_graph('./save_path/file_name.meta')# 只需要指定目录就可以恢复所有变量信息saver.restore(sess, tf.train.latest_checkpoint('./save_path'))# 直接获取保存的变量print(sess.run('w:0'))input_x = sess.graph.get_tensor_by_name('input_x:0')# # 获取需要进行计算的operatorop = sess.graph.get_tensor_by_name('predict:0')print(sess.run(op, feed_dict={input_x: np.expand_dims(mnist.test.images[15], axis=0)}))sess.close()restore_model_pb()
# 打印所有变量的值
# print_tensors_in_checkpoint_file("save_path/file_name", None, True)

运行结果:

[[  2.59213091e-04   1.70691292e-05   1.03438069e-04   1.55748194e-022.95701193e-05   9.70679998e-01   7.14014686e-06   7.19119780e-051.32500082e-02   6.82865766e-06]]

java中调用以上模型文件请参考:java调用tensorflow模型文件

TensorFlow 模型的保存与恢复相关推荐

  1. Tensorflow【实战Google深度学习框架】TensorFlow模型的保存与恢复加载

    我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载. 总结一下Tensorflow常用的模型保存 ...

  2. 简单完整地讲解tensorflow模型的保存和恢复

    http://blog.csdn.net/liangyihuai/article/details/78515913 在本教程主要讲到: 1. 什么是Tensorflow模型? 2. 如何保存Tenso ...

  3. Tensorflow模型的保存与恢复的细节

    翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ ...

  4. tensorflow——模型的保存和恢复tf.trian.saver()

    保存 1创建saver对象,确定save哪些:saver=tf.trian.Saver(),不填写参数的话默认全部 2指定在哪个session中保存,以及保存路径:saver.save(sess, ' ...

  5. [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)

    [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式) 个人网站–> http://www.yansongsong.cn TensorFl ...

  6. 基于Python的模型的保存、恢复、继续训练

    资源下载地址:https://download.csdn.net/download/sheziqiong/86774566 资源下载地址:https://download.csdn.net/downl ...

  7. tensorflow 1.0 学习:模型的保存与恢复(Saver)

    将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...

  8. tensorflow 模型的保存和加载

    为了让训练结果可以复用,需要将训练得到的神经网络模型持久化,也就是把模型的参数保存下来,并保证可以持久化后的模型文件中还原出保存的模型. 1. 保存模型 tensorflow提供了一个API可以方便的 ...

  9. TensorFlow:模型的保存与恢复(Saver)

    目录 前言 1 实例化对象 2 保存训练过程中或者训练好的, 模型图及权重参数 2.1保存训练模型 2.2 查看保存 3. 重载模型的图及权重参数(模型恢复)     前言 我们经常在训练完一个模型之 ...

最新文章

  1. BZOJ 1176: [Balkan2007]Mokia( CDQ分治 + 树状数组 )
  2. 【MongoDB】5.MongoDB与java的简单结合
  3. np.c_与np.r_
  4. 胡凌:隐私的终结——大数据时代的个体生活危机
  5. jmeter更改java内存,jmeter内存溢出解决方法
  6. AngularJS XMLHttpRequest
  7. java重写重定向_JavaWeb请求转发与请求重定向理解
  8. 计算机硬件外围设备介绍,天津2012年自考“计算机外围设备使用与维护”课程考试大纲...
  9. Window/linux(Ubuntu)使用反编译工具jad
  10. Link-State协议的PRC计算详解
  11. 原地怠速油耗最大吗?为什么有人说汽车宁可跑起来也不要原地怠速?
  12. Android 中 C++ Thread线程用法
  13. 《深入浅出Python机器学习》读书笔记 第二章 基于Python语言的环境配置
  14. widi软件|widi音频转换软件
  15. 游戏运营的工作中是做什么
  16. mysql查询 NULL
  17. 演讲的思路锻炼,逆向思维需要刻意练习吗?
  18. ie上传文件到ftp服务器,通过浏览器上传文件到ftp
  19. LeetCode 739. 每日温度 | Python
  20. jQuery插件使用-瀑布流

热门文章

  1. 线程同步机制synchronized中锁的判断以及锁的作用范围
  2. 云借阅图书管理系统[基于SSM框架的项目]
  3. 三种方式实现阻塞队列(简单版)
  4. 关于我对Oracle的一些认知
  5. 电动车动力性计算MATLAB程序,matlab计算汽车动力性经济性(已编好程序).pdf
  6. java基础之idea工具使用
  7. 大数据智慧出行开发第一周:智慧出行底层数据架构剖析纵览全局
  8. Redis 集合 有序集合 python操作集合
  9. qq邮箱服务器host是什么
  10. 设计模式——装饰模式