将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情。tf里面提供模型保存的是tf.train.Saver()模块。

模型保存,先要创建一个Saver对象:如

saver=tf.train.Saver()

在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:

saver=tf.train.Saver(max_to_keep=0)

但是这样做除了多占用硬盘,并没有实际多大的用处,因此不推荐。

当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即

saver=tf.train.Saver(max_to_keep=1)

创建完saver对象后,就可以保存训练好的模型了,如:

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

第一个参数sess,这个就不用说了。第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。

saver.save(sess,'my-model', global_step=0) ==>      filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

看一个mnist实例:

# -*- coding:utf-8 -*-

"""

Created on SunJun  4 10:29:48 2017

@author:Administrator

"""

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

mnist =input_data.read_data_sets("MNIST_data/", one_hot=False)

x =tf.placeholder(tf.float32, [None, 784])

y_=tf.placeholder(tf.int32,[None,])

dense1 =tf.layers.dense(inputs=x,

units=1024,

activation=tf.nn.relu,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.nn.l2_loss)

dense2=tf.layers.dense(inputs=dense1,

units=512,

activation=tf.nn.relu,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.nn.l2_loss)

logits=tf.layers.dense(inputs=dense2,

units=10,

activation=None,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)

train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

correct_prediction= tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)

acc=tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession()

sess.run(tf.global_variables_initializer())

saver=tf.train.Saver(max_to_keep=1)

for i in range(100):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_op, feed_dict={x: batch_xs,y_: batch_ys})

val_loss,val_acc=sess.run([loss,acc],feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

sess.close()

代码中红色部分就是保存模型的代码,虽然我在每训练完一代的时候,都进行了保存,但后一次保存的模型会覆盖前一次的,最终只会保存最后一次。因此我们可以节省时间,将保存代码放到循环之外(仅适用max_to_keep=1,否则还是需要放在循环内).

在实验中,最后一代可能并不是验证精度最高的一代,因此我们并不想默认保存最后一代,而是想保存验证精度最高的一代,则加个中间变量和判断语句就可以了。

saver=tf.train.Saver(max_to_keep=1)

max_acc=0

for i in range(100):

batch_xs, batch_ys =mnist.train.next_batch(100)

sess.run(train_op, feed_dict={x: batch_xs,y_: batch_ys})

val_loss,val_acc=sess.run([loss,acc],feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

ifval_acc>max_acc:

max_acc=val_acc

saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

sess.close()

如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。

saver=tf.train.Saver(max_to_keep=3)

max_acc=0

f=open('ckpt/acc.txt','w')

for i in range(100):

batch_xs, batch_ys =mnist.train.next_batch(100)

sess.run(train_op, feed_dict={x: batch_xs,y_: batch_ys})

val_loss,val_acc=sess.run([loss,acc],feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

f.write(str(i+1)+', val_acc:'+str(val_acc)+'\n')

if val_acc>max_acc:

max_acc=val_acc

saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

f.close()

sess.close()

模型的恢复用的是restore()函数,它需要两个参数restore(sess,save_path),save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:

model_file=tf.train.latest_checkpoint('ckpt/')

saver.restore(sess,model_file)

则程序后半段代码我们可以改为:

sess=tf.InteractiveSession()

sess.run(tf.global_variables_initializer())

is_train=False

saver=tf.train.Saver(max_to_keep=3)

#训练阶段

if is_train:

max_acc=0

f=open('ckpt/acc.txt','w')

for i in range(100):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_op, feed_dict={x:batch_xs, y_: batch_ys})

val_loss,val_acc=sess.run([loss,acc],feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

f.write(str(i+1)+', val_acc:'+str(val_acc)+'\n')

if val_acc>max_acc:

max_acc=val_acc

saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

f.close()

#验证阶段

else:

model_file=tf.train.latest_checkpoint('ckpt/')

saver.restore(sess,model_file)

val_loss,val_acc=sess.run([loss,acc],feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))

sess.close()

标红的地方,就是与保存、恢复模型相关的代码。用一个bool型变量is_train来控制训练和验证两个阶段。

整个源程序:

# -*- coding:utf-8 -*-"""Created on SunJun  4 10:29:48 2017@author:Administrator"""import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnist =input_data.read_data_sets("MNIST_data/", one_hot=False)x =tf.placeholder(tf.float32, [None, 784])y_=tf.placeholder(tf.int32,[None,])dense1 =tf.layers.dense(inputs=x,units=1024,activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss)dense2=tf.layers.dense(inputs=dense1,units=512,activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss)logits=tf.layers.dense(inputs=dense2,units=10,activation=None,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss)loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)correct_prediction= tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)   acc=tf.reduce_mean(tf.cast(correct_prediction, tf.float32))sess=tf.InteractiveSession() sess.run(tf.global_variables_initializer())is_train=Truesaver=tf.train.Saver(max_to_keep=3)#训练阶段if is_train:max_acc=0f=open('ckpt/acc.txt','w')for i in range(100):batch_xs, batch_ys =mnist.train.next_batch(100)sess.run(train_op, feed_dict={x:batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc],feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')if val_acc>max_acc:max_acc=val_accsaver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)f.close()#验证阶段else:model_file=tf.train.latest_checkpoint('ckpt/')saver.restore(sess,model_file)val_loss,val_acc=sess.run([loss,acc],feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))sess.close()

tf.train.Saver相关推荐

  1. tf.train.Saver函数的用法之保存全部变量和模型

    用于保存模型,以后再用就可以直接导入模型进行计算,方便. 例如: [python] view plaincopy import tensorflow as tf; import numpy as np ...

  2. tensorflow tf.train.Saver.restore() (用于下次训练时恢复模型)

    # 保存当前的Session到文件目录tf.train.Saver().save(sess, 'net/my_net.ckpt') # 然后在下次训练时恢复模型: tf.train.Saver().r ...

  3. tf.train.Saver中max_to_keep设置无效

    [问题描述]: saver = tf.train.Saver(max_to_keep=2) saver.save(sess, args.model_save_dir + '/lm_pretrain.c ...

  4. tf第七讲:模型保存与加载(tf.train.Saver()tf.saved_model)及fine_tune(梯度冻结)

      大家好,我是爱编程的喵喵.双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中.从事机器学习以及相关的前后端开发工作.曾在阿里云.科大讯飞.CCF等比赛获得多次Top名次.现 ...

  5. tf.train.Saver,和模型参数微调

    https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/5-06-save/ https://blog.csdn.net/ ...

  6. TF:利用TF的train.Saver将训练好的W、b模型文件保存+新建载入刚训练好模型(用于以后预测新的数据)

    TF:利用TF的train.Saver将训练好的W.b模型文件保存+新建载入刚训练好模型(用于以后预测新的数据) 目录 输出结果 代码设计 输出结果 代码设计 import tensorflow as ...

  7. tensorflow tf.train.ExponentialMovingAverage().variables_to_restore()函数 (用于加载模型时将影子变量直接映射到变量本身)

    variables_to_restore函数,是TensorFlow为滑动平均值提供.之前,也介绍过通过使用滑动平均值可以让神经网络模型更加的健壮.我们也知道,其实在TensorFlow中,变量的滑动 ...

  8. tensorflow || 滑动平均的理解--tf.train.ExponentialMovingAverage

    1 滑动平均的理解 滑动平均(exponential moving average),或者叫做指数加权平均(exponentially weighted moving average),可以用来估计变 ...

  9. 【Tensorflow教程笔记】常用模块 tf.train.Checkpoint :变量的保存与恢复

    基础 TensorFlow 基础 TensorFlow 模型建立与训练 基础示例:多层感知机(MLP) 卷积神经网络(CNN) 循环神经网络(RNN) 深度强化学习(DRL) Keras Pipeli ...

最新文章

  1. Asynctask源码分析
  2. python对文件的读操作方法有哪些-python--文件的读写操作
  3. flex布局:子子元素过大撑开了设定flex:1的子元素的解决方案
  4. 【ntp】虚拟机时间莫名异常
  5. 厦门one_“断轴”频发,李想承认理想ONE存缺陷!曾声明悬架非常安全
  6. java 中的fork join框架
  7. [html] 谈谈你对input元素中readonly和disabled属性的理解
  8. delphi 解析一维条码_科普帖:一般商用条码扫描器全知道,只需三把枪
  9. 36.软件安装:RPM,SRPM和YUM功能
  10. R语言数据最大最小归一化
  11. 对偶式与反函数_.数字逻辑.对偶式与反函数.数字逻辑下,对偶式与反函数和原函数的关系是什么?...
  12. 儿童讲堂 - 量词的解释
  13. laravel-集合对象的销毁forget,重组values(),pluck ()方法
  14. iphone 信号对应设备_如何访问iPhone的现场测试模式(并查看实际信号强度)
  15. 个人邮箱|如何群发邮件?3秒教你搞定
  16. mongoDB百度脑图总结
  17. 【总结,持续更新】java常见的线程不安全,你以为的线程安全
  18. 问题 J: 古罗马数字2
  19. 安卓手机文件系统 roots recovery bootimg
  20. D语言与C++做映射时需要注意的事情

热门文章

  1. python输入print跳到documentation-习题 48: 更复杂的用户输入
  2. 均质机工作原理动画_3D动画演示:有刷直流电机的工作原理
  3. 飞畅科技——视频光端机用光模块的选型详解
  4. [渝粤教育] 西南科技大学 政府经济学 在线考试复习资料
  5. 【渝粤题库】陕西师范大学202321投资银行学 作业(专升本)
  6. 六种常用的物联网通信协议
  7. matlab 大于并且小于,Matlab:将大于(小于)1(-1)的元素转换为1(-1)的序列
  8. 【数字信号处理】希尔伯特变换系列1之相位处理(含MATLAB代码)
  9. Atom使用方法(快捷键,插件,汉化)
  10. php中的空转为什么意思,php 长期更