和这篇文章对比https://blog.csdn.net/fanzonghao/article/details/81023730

不希望重复定义图上的运算,也就是在模型恢复过程中,不想sess.run(init)首先看路径

lineRegulation_model.py定义线性回归类:

import tensorflow as tf
"""
类定义一些公共量,方便模型载入用
"""
class LineRegModel:def __init__(self):with tf.variable_scope('var'):self.a_val=tf.Variable(tf.random_normal(shape=[1]),name='a_val')self.b_val = tf.Variable(tf.random_normal(shape=[1]),name='b_val')self.x_input=tf.placeholder(dtype=tf.float32,name='input_placeholder')self.y_label = tf.placeholder(dtype=tf.float32,name='result_placeholder')self.y_output = tf.add(tf.multiply(self.x_input,self.a_val),self.b_val,name='output')self.loss=tf.reduce_mean(tf.pow(self.y_output-self.y_label,2))def get_saver(self):return tf.train.Saver()def get_op(self):return tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)

model_train.py定义模型训练过程

import tensorflow as tf
import numpy as np
from save_and_restore2 import global_variable
from save_and_restore2 import  lineRegulation_model as model
import os
if not os.path.exists('./model'):os.makedirs('./model')
"""
训练模型
"""
train_x=np.random.rand(5)
train_y=train_x*5+3
model=model.LineRegModel()#类要加括号
a_val=model.a_val
b_val=model.b_val
x_input=model.x_input
y_label=model.y_label
y_output=model.y_output
loss=model.loss
optimizer=model.get_op()
saver=model.get_saver()
if __name__ == '__main__':init=tf.global_variables_initializer()with tf.Session() as sess:sess.run(init)flag=Trueepoch=0while flag:epoch+=1cost,_=sess.run([loss,optimizer],feed_dict={x_input:train_x,y_label:train_y})if cost<1e-6:flag=Falseprint('a={},b={}'.format(a_val.eval(sess),b_val.eval(sess)))print('epoch={}'.format(epoch))print(a_val)# print(a_val.op)saver.save(sess,global_variable.save_path)print('model save finish')

print(a_val)的形式

print(a_val.op)的形式

model_restore.py恢复模型 ,利用恢复图在恢复权重的方式,可实现更细节的模型恢复

import tensorflow as tf
from save_and_restore import global_variable,lineRegulation_model as model
"""
恢复模型图文件
"""
saver=tf.train.import_meta_graph('./model/weight.meta')
#读取placeholder和最终的输出结果
graph=tf.get_default_graph()
a_val=graph.get_tensor_by_name('var/a_val:0')
b_val=graph.get_tensor_by_name('var/b_val:0')input_placeholder=graph.get_tensor_by_name('input_placeholder:0')
labels_placeholder=graph.get_tensor_by_name('result_placeholder:0')
y_output=graph.get_tensor_by_name('output:0')with tf.Session() as sess:#具体权重的恢复saver.restore(sess,'./model/weight')result=sess.run(y_output,feed_dict={input_placeholder:[1]})print(result)print(sess.run(a_val))print(sess.run(b_val))

简单的线性回归实现模型的存储和读取相关推荐

  1. 数学建模优化模型简单例题_数学建模之优化模型:存储模型

    点击上方「蓝字」关注我们 最近,为申报市级精品课程,我为我校"数学建模与科学计算"课程录制了讲课视频,下面是3.1节优化模型的第一个例子:存储模型.敬请大家批评指正! 优化模型是数 ...

  2. python计算均方根误差_如何在Python中创建线性回归机器学习模型?「入门篇」

    线性回归和逻辑回归是当今很受欢迎的两种机器学习模型. 本文将教你如何使用 scikit-learn 库在Python中创建.训练和测试你的第一个线性.逻辑回归机器学习模型,本文适合大部分的新人小白. ...

  3. 通过简单的线性回归理解机器学习的基本原理

    在本文中,我将使用一个简单的线性回归模型来解释一些机器学习(ML)的基本原理.线性回归虽然不是机器学习中最强大的模型,但由于容易熟悉并且可解释性好,所以仍然被广泛使用.简单地说,线性回归用于估计连续或 ...

  4. 简单多元线性回归(梯度下降算法与矩阵法)

    from:https://www.cnblogs.com/shibalang/p/4859645.html 多元线性回归是最简单的机器学习模型,通过给定的训练数据集,拟合出一个线性模型,进而对新数据做 ...

  5. 梯度下降的线性回归用python_运用TensorFlow进行简单实现线性回归、梯度下降示例...

    线性回归属于监督学习,因此方法和监督学习应该是一样的,先给定一个训练集,根据这个训练集学习出一个线性函数,然后测试这个函数训练的好不好(即此函数是否足够拟合训练集数据),挑选出最好的函数(cost f ...

  6. UA MATH571A 一元线性回归I 模型设定与估计

    UA MATH571A 一元线性回归I 模型设定与估计 模型设定 最小二乘法(Method of Least Square) Coefficients Mean Response and Residu ...

  7. 一个简单例子:贫血模型or领域模型

    转:一个简单例子:贫血模型or领域模型 贫血模型 我们首先用贫血模型来实现.所谓贫血模型就是模型对象之间存在完整的关联(可能存在多余的关联),但是对象除了get和set方外外几乎就没有其它的方法,整个 ...

  8. 4、python简单线性回归代码案例(完整)_python 实现一个简单的线性回归案例

    #!/usr/bin/env python # -*- coding: utf-8 -*- # @File : 自实现一个线性回归.py # @Author: 赵路仓 # @Date : 2020/4 ...

  9. 我的第一篇博文——简单的C/S模型

    这几天在学习Linux环境下的基础socket编程,作为一个小实验,自己编写了一个最基本简单的C/S模型,然而并没有像我想当然的那样一次性成功.一些错误来源于概念的偏差,而一些来源于对细节的忽略.总的 ...

最新文章

  1. ad16自动布线设置规则_PCB设计的十大误区——那些年,我们一起遵守的规则
  2. 学习率对神经网络迭代次数和准确率的影响以及近似数学表达式
  3. 孙正义举债豪购ARM的3个理由:潜伏物联网时代
  4. 列名无效如何解决_XSKY ClickHouse如何实现存算分离
  5. angular使用sass的scss语法
  6. 这几天都是在公司慢待
  7. php正则检查QQ,PHP 正则匹配手机号的QQ号
  8. 【源码阅读】Java集合之一 - ArrayList源码深度解读
  9. Prototype使用Form操作表单
  10. STM32库中自定义的数据类型
  11. poj 匈牙利二分匹配 模板 poj题目
  12. QTableView效率优化3 - 自定义Model的内容补充
  13. hsql转换oracle,Oracle To Hsql
  14. matlab绘制奈奎图,matlab画奈奎斯特图
  15. 面对传销,该怎么处理
  16. Ceph Recovery分析
  17. Destroying assets is not permitted to avoid data loss.
  18. 主引导记录(MBR)、硬盘分区表(DPT)、扩展引导记录(EBR)
  19. 【MATLAB】MATLAB矩阵的表示
  20. 【机器学习】最大均值差异MMD详解

热门文章

  1. 从前馈到反馈:解析循环神经网络(RNN)及其tricks
  2. Docker系列之二:基于容器的自动构建
  3. SQL 数据分析常用语句
  4. 带你根据源码了解View的事件触发流程,主要讲解为什么子View返回true,ViewGroup就无法接收到事件的过程
  5. 【HTML/CSS】CSS权重、继承及引入方式
  6. 知识图谱最新权威综述论文解读:知识表示学习部分
  7. String源码分析
  8. day24 反射\元类
  9. 对汉诺塔递归算法的理解(图解,附完整代码实现)
  10. USACO 1.2 Milking Cows (枚举)