深度学习框架Tensorflow学习与应用实战

1.训练模型参数(权值矩阵、偏置值)并保存到指定文件夹下

#保存训练好的模型参数(权重矩阵、偏置值等)到指定文件夹中
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data#载入数据集,one_hot:将标签转换为只有一位为1,其它为0,会自动从网上下载数据集到当前目录
mnist=input_data.read_data_sets("MNIST_data",one_hot=True)#每个批次的大小
batch_size=100
#计算一共有多少个批次(整除)
n_batch=mnist.train.num_examples//batch_size#该神经网络只输入层和输出层,输入层包含784个神经元,输出层包含10个神经元
#定义两个placeholder,将28*28数字图片偏平为规格为784的向量
x=tf.placeholder(tf.float32,[None,784])
#标签结果
y=tf.placeholder(tf.float32,[None,10])#创建一个简单的神经网络
#权值初始化为0,  784x10
W=tf.Variable(tf.zeros([784,10]))
#偏置值
b=tf.Variable(tf.zeros([10]))
#softmax将输出转化为概率值
prediction=tf.nn.softmax(tf.matmul(x,W)+b)#二次代价函数,差的平方的平均值
#loss=tf.reduce_mean(tf.square(y-prediction))
#交叉熵代价函数的平均值
loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))#梯度下降法
train_step=tf.train.GradientDescentOptimizer(0.2).minimize(loss)#初始化变量
init=tf.global_variables_initializer()#tf.argmax(y,1)返回1的位置(真实值),tf.argmax(prediction,1)(预测值)返回概率值最大的位置,比较位置是否相等,若想等返回true,不等返回false,存放在布尔列表中
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#tf.argmax()返回一维张量中最大值的位置
#求准确率
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))saver=tf.train.Saver()with tf.Session() as sess:sess.run(init)# 对所有图片迭代11次for epoch in range(11):#对所有图片分批训练一次for batch in range(n_batch):#获取一批(100个)样本图片,batch_xs:图片信息,batch_ys:图片标签batch_xs,batch_ys=mnist.train.next_batch(batch_size)#利用训练图片信息及对应标签,梯度下降法训练模型,得到权重W及bsess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})#利用测试集进行测试该迭代时模型的准确率acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})#打印迭代次数及对应准确率print("Iter "+str(epoch)+",Testing Accuracy "+str(acc))#保存训练好的模型(权值、偏置值等信息)saver.save(sess,'net/my_net.ckpt')

查看当前目录net文件夹下结果:

2.载入训练好的模型参数并用于测试集

#载入已经训练好的模型参数测试
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data#载入数据集,one_hot:将标签转换为只有一位为1,其它为0,会自动从网上下载数据集到当前目录
mnist=input_data.read_data_sets("MNIST_data",one_hot=True)#每个批次的大小
batch_size=100
#计算一共有多少个批次(整除)
n_batch=mnist.train.num_examples//batch_size#该神经网络只输入层和输出层,输入层包含784个神经元,输出层包含10个神经元
#定义两个placeholder,将28*28数字图片偏平为规格为784的向量
x=tf.placeholder(tf.float32,[None,784])
#标签结果
y=tf.placeholder(tf.float32,[None,10])#创建一个简单的神经网络
#权值初始化为0,  784x10
W=tf.Variable(tf.zeros([784,10]))
#偏置值
b=tf.Variable(tf.zeros([10]))
#softmax将输出转化为概率值
prediction=tf.nn.softmax(tf.matmul(x,W)+b)#二次代价函数,差的平方的平均值
#loss=tf.reduce_mean(tf.square(y-prediction))
#交叉熵代价函数的平均值
loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))#梯度下降法
train_step=tf.train.GradientDescentOptimizer(0.2).minimize(loss)#初始化变量
init=tf.global_variables_initializer()#tf.argmax(y,1)返回1的位置(真实值),tf.argmax(prediction,1)(预测值)返回概率值最大的位置,比较位置是否相等,若想等返回true,不等返回false,存放在布尔列表中
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#tf.argmax()返回一维张量中最大值的位置
#求准确率
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))saver=tf.train.Saver()with tf.Session() as sess:sess.run(init)#打印准确率(不准确),因为此时权重矩阵为0,偏置值为0print("未载入模型时准确率:"+str(sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})))#载入已经训练好的模型参数(权值矩阵、偏置值)saver.restore(sess,'net/my_net.ckpt')#打印准确率,此时权值矩阵、偏置值均不为0print("载入模型后准确率:"+str(sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})))

运行结果:

Tensorflow-saver模型参数保存及载入相关推荐

  1. TF:利用TF的train.Saver将训练好的W、b模型文件保存+新建载入刚训练好模型(用于以后预测新的数据)

    TF:利用TF的train.Saver将训练好的W.b模型文件保存+新建载入刚训练好模型(用于以后预测新的数据) 目录 输出结果 代码设计 输出结果 代码设计 import tensorflow as ...

  2. [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)

    [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式) 个人网站–> http://www.yansongsong.cn TensorFl ...

  3. pytorch多卡并行模型的保存与载入

    pytorch多卡并行模型的保存与载入 当模型是在数据并行方式在多卡上进行训练的训练和保存,那么载入的时候也是一样需要是多卡.并且,load_state_dict()函数的调用要放在DataParal ...

  4. TensorFlow:模型的保存与恢复(Saver)

    目录 前言 1 实例化对象 2 保存训练过程中或者训练好的, 模型图及权重参数 2.1保存训练模型 2.2 查看保存 3. 重载模型的图及权重参数(模型恢复)     前言 我们经常在训练完一个模型之 ...

  5. tensorflow之pb文件保存与载入

    pb是protocol(协议) buffer(缓冲)的缩写.TensorFlow训练模型后存成的pb文件,是一种表示模型(神经网络)结构的二进制文件,将图中的变量保存成为常量,便于调用,一般无法将pb ...

  6. TensorFlow(1)-模型相关基础概念

    TensorFlow-1 1.Graph对象 2.Session对象 3.Variabels变量 4. placeholders与feed_dict 5. tf.train.Saver() 模型参数保 ...

  7. tensorflow 1.0 学习:模型的保存与恢复(Saver)

    将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...

  8. Pytorch中参数和模型的保存与读取

    Tensor变量的存取(包括parameter) 对于普通Tensor变量的存取,如下代码所示: import torch import torch.nn as nn x = torch.ones(3 ...

  9. paddlepaddle(六)模型保存与载入

    目录 1.API分类 1.1基础API 1.2高级API 2.训练调优场景的模型&参数保存载入 2.1动态图参数保存载入 2.2静态图参数保存载入 3.训练部署场景的模型参数保存载入 3.1 ...

  10. 【待更新】GPU 保存模型参数,GPU 加载模型参数

    GPU 保存模型参数,GPU 加载模型参数 保存 # 模型 device = torch.device('cuda') net = KGCN(num_user, num_entity, num_rel ...

最新文章

  1. yaf_dispatcher.c 的 yaf_dispatcher_fix_default函数
  2. linux centos grub grub2 加密、清除
  3. 应用分析:CIO须注意SOA使用中的五大隐患
  4. android 填满手机磁盘空间方法
  5. winform point数组带数值_带你学够浪:Go语言基础系列 - 8分钟学复合类型
  6. xampp mysqli_query and后的条件不行_Java笔记不用!null作为判空条件
  7. 使用phpstorm+wamp实现php代码实时调试审计
  8. CREO 6.0 - 基础 - 01 - 零件 - 零件的装配 - 零件的移动、偏转、角度角度设定
  9. Linux——SUID、SGID、SBIT简介
  10. pytorch拼接与拆分
  11. 安川焊接机器人做圆弧运动编程_安川焊接机器人编程
  12. linux xp双系统引导修复工具,Ubuntu与XP双系统引导修复备忘
  13. 【刘晓燕长难句分析】1.简单句
  14. 完全平方数-动态规划
  15. 和你走在南京种满梧桐的大街小巷
  16. 游戏的分类及相关热点
  17. Android 快速集成阿里云OSS服务2020
  18. 001-2019-0124 前端Html
  19. win10切换输入法快捷键_电脑小白必学的5个Win10技巧
  20. PHP正则表达式提取html超链接中的h…

热门文章

  1. ECMAScript标准命名
  2. 西部世界:币本位是什么?
  3. 金仓数据库字段_金仓数据库认证工程师(KCE)考试试题_含答案_
  4. CImage::Loda 方法加载图片失败,因为vs2013中该方法不支持中文变量
  5. CorelDRAW 12快捷键
  6. XMLHTTP的ReadyState与Statu详解
  7. 中国“神威•太湖之光”蝉联世界超算冠军
  8. WPF 通过Image控件实现多张图片的播放
  9. 现代C++的文艺复兴
  10. java怎么用蓝牙传_[技巧]蓝牙传输JAVA简易教程(图文及小常识)