TensorFlow——共享变量的使用方法
1.共享变量用途
在构建模型时,需要使用tf.Variable来创建一个变量(也可以理解成节点)。当两个模型一起训练时,一个模型需要使用其他模型创建的变量,比如,对抗网络中的生成器和判别器。如果使用tf.Variable,将会生成一个新的变量,而我们需要使用原来的那个变量。这时就是通过引入get_Variable方法,实现共享变量来解决这个问题。这种方法可以使用多套网络模型来训练一套权重。
2.使用get_Variable获取变量
get_Variable一般会配合Variable_scope一起使用,以实现共享变量。Variable_scope的含义是变量作用域。在某一作用域中的变量可以被设置成共享的方式,被其他网络模型使用。
get_Variable函数的定义如下:
tf.get_Variable(<name>, <shape>, <initializer>)
在TensorFlow里,使用get_Variable时候生成的变量是以指定的name属性为唯一标识,并不是定义的变量名称。使用时一般是通过name属性定位到具体变量,并将其共享到其他的模型中。
import tensorflow as tf import numpy as npvar1 = tf.Variable(1.0, name='first_var') print("var1: ", var1.name)var1 = tf.Variable(2.0, name='first_var') print('var1: ', var1.name)var2 = tf.Variable(3.0) print('var2: ', var2.name)var2 = tf.Variable(4.0) print('var1: ', var2.name)with tf.Session() as sess:sess.run(tf.global_variables_initializer())print("var1=", var1.eval())print("var2=", var2.eval())print()
在上述的代码中,,可以看到内存中有两个var1,并且他们的name是不一样的,对于图来说,后面的var1是生效的。当Variable定义没有指定名字时,系统会自动的加上一个名字Variable:0
3.get_Variable用法演示
import tensorflow as tf import numpy as npget_var1 = tf.get_variable('firat_var_1', [1], initializer=tf.constant_initializer(2)) print("var1: ", get_var1.name)get_var1 = tf.get_variable('firat_var_2', [1], initializer=tf.constant_initializer(3)) print("var1: ", get_var1.name)with tf.Session() as sess:sess.run(tf.global_variables_initializer())print("var1=", get_var1.name)print("var1=", get_var1.eval())
使用不同的name定义变量,当使用相同的name时,会抛出异常,变量名可以相同,但是name是不能相同的。
如果要使用相同的name的话,我们需要使用variable_scope将他们隔开,看如下代码:
import tensorflow as tfwith tf.variable_scope('test_1'):var1 = tf.get_variable('first_var', shape=[2], dtype=tf.float32)with tf.variable_scope('test_2'):var1 = tf.get_variable('first_var', shape=[2], dtype=tf.float32)print("var1:", var1.name) print("var1:", var1.name)
根据程序的运行结果,我们可以发现变量的名字加上了作用域的名称,这样使得我们能够在不同的作用域下面定义name相同的变量,同时,scope还支持嵌套定义,
with tf.variable_scope('test_0'):with tf.variable_scope('test_1'):var1 = tf.get_variable('first_var', shape=[2], dtype=tf.float32)with tf.variable_scope('test_2'):var1 = tf.get_variable('first_var', shape=[2], dtype=tf.float32)print("var1:", var1.name) print("var1:", var1.name)
4.共享作用域
使用作用域中的参数reuse可以实现共享变量功能
在variable_scope里面有一个reuse=True属性,表示使用已经定义过的变量,这时,get_variable将不会在创建新的变量,而是去图中get_variable所创建的变量中找与name相同的变量。
import tensorflow as tfwith tf.variable_scope('test_0'):var1 = tf.get_variable('first_var', shape=[2], dtype=tf.float32)with tf.variable_scope('test_2'):var2 = tf.get_variable('first_var', shape=[2], dtype=tf.float32)with tf.variable_scope('test_0', reuse=True):var3 = tf.get_variable('first_var', shape=[2], dtype=tf.float32)with tf.variable_scope('test_2'):var4 = tf.get_variable('first_var', shape=[2], dtype=tf.float32)print("var1:", var1.name) print("var2:", var2.name)print("var3:", var3.name) print("var4:", var4.name)
在上述的输出结果中,我们可以看到,var1和var3的名字一样,var2和var4的名字一样,则表明他们是同一个变量,如此就实现了变量的共享。在实际应用中,可以将1,2和3,4分别放在不同的模型进行训练,但是他们会作用于同一个模型的学习参数上。
使用anaconda的spyder工具运行时,代码只能运行一次,第二次运行将会报错。可以退出当前的kernel,再重新进入一下,因为tf.get_varibale在创建变量时,会去检查图中是否已经创建过该变量,如果创建过且不是共享的方式,则会报错。
因而可以使用tf.reset_default_graph(),将图里面的变量清空,就可以解决这个问题。
5.初始化共享变量
variable_scope和get_variable都具有初始化的功能。在初始化时,如果没有对当前变量初始化,则TensorFlow会默认使用作用域的初始化,并且作用域的初始化方法也有继承功能。
import tensorflow as tfwith tf.variable_scope('test_0', initializer=tf.constant_initializer(0.15)):var1 = tf.get_variable('first_var', shape=[2], dtype=tf.float32)with tf.variable_scope('test_2'):var2 = tf.get_variable('first_var', shape=[2], dtype=tf.float32)var3 = tf.get_variable('first_var_2', shape=[2], initializer=tf.constant_initializer(0.315))with tf.Session() as sess:sess.run(tf.global_variables_initializer())print("var1:", var1.eval())print("var2:", var2.eval())print("var3:", var3.eval())
当变量没有进行初始化时,会继承它的域的初始化方式,域也会继承它的上一级的域的初始化方式。在多模型训练时,常常可以对模型中的张量进行分区,同时,同一进行初始化。在变量共享方面,可以使用tf.AUTO_REUSE来为reuse属性赋值。tf。AUTO_REUSE可以实现第一次调用variable_scope时,传入reuse的值为false,再次调用时,reuse的值为True。
6.作用域与操作符的受限范围
variable_scope还可以使用with variable_scope as xxxscope的方式定义作用域,当使用这种方式时,将不会在受到外层的scope所限制。
import tensorflow as tfwith tf.variable_scope('test2', initializer=tf.constant_initializer(1.5)) as sp:var1 = tf.get_variable('var1', [2], dtype=tf.float32)print(sp.name) print(var1.name)with tf.variable_scope('test1', dtype=tf.float32, initializer=tf.constant_initializer(5.5)):var2 = tf.get_variable('var1', [2], dtype=tf.float32)with tf.variable_scope(sp) as sp1:var3 = tf.get_variable('var2', [2], dtype=tf.float32)print("var2:", var2.name)print("var3:", var3.name)with tf.Session() as sess:sess.run(tf.global_variables_initializer())print("var2:", var2.eval())print("var2:", var3.eval())
通过with tf.variable_scope(sp) as sp1我们可知其没有收到外层的作用域所限制,初始化的操作时,它的值不是外层作用域的初始化值,而是指定的作用域的初始化的值。
对于操作符而言,不仅收到tf.name_scope限制还收到tf.variable_scope限制。
import tensorflow as tf import numpy as npwith tf.variable_scope('scope'):with tf.name_scope('op'):v = tf.get_variable('var1', [1])x = v + 2.0print("v:", v.name) print('x', x.name)
根据结果,我们可知,通过添加tf.name_scope('op'):作用域时,变量的命名并没有收到限制,只是改变了op的命名,通过tf.name_scope(''):还可以返回到顶层的作用域中。
import tensorflow as tf import numpy as npwith tf.variable_scope('scope'):var1 = tf.get_variable("v", [1])with tf.variable_scope('scope_1'):var2 = tf.get_variable("v", [1])with tf.name_scope(''):var3 = tf.get_variable('var1', [1])x = var3 + 2.0print("var1 ", var1.name) print('var2 ', var2.name) print('var3 ', var3.name) print('x ', x.name)
通过将通过tf.name_scope('')设置为空,对于变量名是没有影响,但是可以看到x的命名,它已经变成了最外层的命名了。
转载于:https://www.cnblogs.com/baby-lily/p/10934131.html
TensorFlow——共享变量的使用方法相关推荐
- Tensorflow中scope命名方法
两篇文章掌握Tensorflow中scope用法: [1]Tensorflow中scope命名方法(本文) [2]Tensorflow中tf.name_scope() 和 tf.variable_sc ...
- 【tensorflow】Sequential 模型方法
深入学习Keras中Sequential模型及方法 - 战争热诚 - bky https://www.cnblogs.com/wj-1314/p/9579490.html Sequential 序贯模 ...
- 【tensorflow】Sequential 模型方法 compile, model.compile
Sequential 顺序模型 API - Keras 中文文档 https://keras.io/zh/models/sequential/ Sequential 序贯模型 序贯模型是函数式模型的简 ...
- 新版本GPU加速的tensorflow库的配置方法
本文介绍在Anaconda环境中,配置可以用GPU运行的Python新版tensorflow库的方法. 在上一篇文章Anaconda配置Python新版本tensorflow库(CPU.GPU ...
- ModuleNotFoundError: No module named ‘tensorflow.compat.v2‘解决方法
ModuleNotFoundError: No module named 'tensorflow.compat.v2'解决方法 原因: tensorflow和keras版本不对齐或者keras版本过高 ...
- python36+centos7离线安装tensorflow与talib的方法
由于应用程序的服务器不能连接外网,导致无法使用pip install tensorflow/TA-Lib的方法: 环境配置:python3.6+centos7.2 通过间接的方法来完成安装:找一个能连 ...
- Anaconda安装tensorflow报错问题解决方法
最近脱离了googlecolab想使用本地的anaconda进行机器学习课题的演练,在安装tensorflow时报错 : UnsatisfiableError: The following speci ...
- python安装tensorflow报错_Anaconda安装tensorflow报错问题解决方法
最近脱离了googlecolab想使用本地的anaconda进行机器学习课题的演练,在安装tensorflow时报错 : UnsatisfiableError: The following speci ...
- tensorflow中的Session方法解释
Session()方法 首先我们需要创建一个Session对象.在不传参数的情况下,该Session的构造器将启动默认的图.之后我们可以通过Session对象的run(op)来执行我们想要的操作.te ...
最新文章
- StructureMap 代码分析之Widget 之Registry 分析 (1)
- ssh调用expect使用以及shell同时传入两个参数调用
- IT-标准化-系列-6.关闭事件跟踪程序
- Linux命令(三) 移动文件 mv
- jzoj3058-火炬手【高精度,暴力】
- 18000 6c java_面向ISO18000-6C协议的无源超高频射频识别标签芯片设计
- 无ide编译java_无IDE编译和运行java
- Python21天打卡Day13-生成器表达式
- 【HDOJ4699】Editor(对顶栈,模拟)
- Mybatis使用技巧
- 【运动控制篇】(7)路径跟踪及组合动作方向
- 使用ItextPdf给PDF文件加文字水印和图片水印
- 4、golang 发送电子邮件
- php edm 系统,edm.php · 那些年我们一起/fanwe - Gitee.com
- 服务器系统需要装显卡驱动吗,显卡驱动需要更新吗,详细教您显卡驱动需要更新吗...
- Java使用模板导出带图片word文档
- 最新版akamai2.0逆向分析爬虫破盾风控绕过tls指纹
- 【pySerial3.4官方文档】1、pySerial
- airpods pro是按压还是触摸_外媒曝AirPods Pro出现广泛故障:触摸和佩戴识别失效...
- 5G 的三大应用场景——ITU-R原文
热门文章
- 创建maven项目多模块项目
- redhat7下对用户账户的管理
- no suitable driver found for jdbc:mysql//localhost:3306/..
- C#调用存储过程详解
- hibernate学习内容
- __FILE__,__LINE__,FUNCTION__实现代码跟踪调试(linux下c语言编程 )(转自IT博客)
- 在JScript中运行应用程序
- (转)用DynamicMethod提升ORM系统转换业务数据的性能
- 年薪 66万+,西澳大学招聘 CV DL Research Fellow(研究员)
- 收藏 | 用 Keras 实现神经网络来解决梯度消失的问题