tensorflow 的 BatchNormalization的常见的坑

1. 在训练时设置批标准化的参数 training = True, 执行测试时设置  training = False

2. 在训练时,要用 tf.control_dependencies 添加对批标准化的“update_ops”的依赖控制,先执行批标准化的“update_ops”,然后后再进行当前次的迭代训练。

注:1)批标准化的“update_ops”主要是指两个滑动平均操作,一个用于计算均值,一个用于计算方差。

2) tensorflow中有多个实现批标准化的函数,有的会自动把“update_ops”放入tf.GraphKeys.UPDATE_OPS 这个collection中,有的则不会,比如tf.keras.layers.BatchNormalization, 此时需要手动把“update_ops”放入tf.GraphKeys.UPDATE_OPS中。

3)如果在训练时没有添加 tf.control_dependencies:如果不在使用时添加tf.control_dependencies函数,即在训练时(training=True)每批次时只会计算当批次的mean和var,并传递给tf.nn.batch_normalization进行归一化,由于mean_update和variance_update在计算图中并不在上述操作的依赖路径上,因为并不会主动完成,也就是说,在训练时mean_update和variance_update并不会被使用到,其值一直是初始值。因此在测试阶段(training=False)使用这两个作为mean和variance并进行归一化操作,这样就会出现错误。而如果使用tf.control_dependencies函数,会在训练阶段每次训练操作执行前被动地去执行mean_update和variance_update,因此moving_mean和moving_variance会被不断更新,在测试时使用该参数也就不会出现错误。

3. 批标准化层会带来4个新的变量,但是这四个变量不是trainable的,而且这四个变量在测试阶段要会用到,因此,在保存模型时要设置 saver = tf.train.Saver(var_list=tf.global_variables()), 即保存所有全局变量,而不是仅仅保存可训练的变量。

4. BN层模型pb转化为tflite时, 我暂时没遇到,具体可参考:BN层模型pb转化为tflite时

5. 训练时不要设置batchsize=1,当batch_size = 1时,batch_norm实际上是instance_norm.由于训练时的batch_size太小,导致滑动平均值不稳定,因为使用滑动平均值去测试效果不好.

6. 批标准化的位置放在激活函数之前


1. tf.control_dependencies

该函数的作用是:在执行其辖区的命令之前,先确保辖区命令依赖的节点被率先执行。辖区命令依赖的节点被作为参数放在tf.control_dependencies的参数位置。请看下面一个例子。

import tensorflow as tfa = tf.Variable(1)
b = tf.Variable(2)
update_op = tf.assign(a, 10) # 赋值操作,改变变量 a_1 的数值# 添加依赖控制 # 该命令的参数是被依赖的计算节点,只有当被依赖的节点被执行后才能执行该命令辖区内的操作,
with tf.control_dependencies([update_op]): # 添加了对 update_op 的依赖,执行辖区命令时,要先执行 update_op 节点
# with tf.control_dependencies([]):        # 未添加对 update_op 的依赖,执行辖区命令时,不会调用执行 update_op 节点c = tf.add(a, b)with tf.Session() as sess:sess.run(tf.global_variables_initializer())a,b,c= sess.run([a,b,c])print("a={}  b={}  c={}".format(a,b,c))"""
实验结果:
1)当添加了对 update_op 的依赖时(执行 with tf.control_dependencies([update_op]) ),输出结果为:
a=10  b=2  c=122)当未添加对 update_op 的依赖时(执行 with tf.control_dependencies([]):),输出结果为:
a=1  b=2  c=3结果分析:
虽然在图的搭建过程中定义了 update_op = tf.assign(a, 10) 节点,但是tensorflow的图的运行并不是按命令行的顺序执行的。
如果c = tf.add(a, b) 操作没有对节点 update_op 添加依赖,那么在执行 c = tf.add(a, b)时,update_op 不会被主动调用。tf.control_dependencies 就是解决了这个问题。"""

2. tf.GraphKeys.UPDATE_OPS

tf.GraphKeys.UPDATE_OPS 是一个tensorflow的计算图中内置的一个集合,该集合会保存一些需要在训练操作之前完成的操作,一般配合tf.control_dependencies函数使用

比如,在批标准化中,存在两个变量:mean和variance,这两个值是随着迭代的进行不断更新的,因此在执行每次迭代之前,需要先计算mean和variance。通过下面一个例子介绍tf.layers.batch_normalization产生的mean和variance。

import tensorflow as tfinput = tf.ones([1, 2, 2, 3])
output = tf.layers.batch_normalization(input, training=True)
#output = tf.layers.batch_normalization(input, training=False)# 打印batch_normalization中的两个操作
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print(update_ops)# with tf.control_dependencies(update_ops):
#   train_op = optimizer.minimize(loss)# with tf.Session() as sess:
#   sess.run(tf.global_variables_initializer())
#   saver = tf.train.Saver()
#   saver.save(sess, "batch_norm_layer/Model")"""
打印:
[<tf.Operation 'batch_normalization/AssignMovingAvg' type=AssignSub>, <tf.Operation 'batch_normalization/AssignMovingAvg_1' type=AssignSub>]
这两个变量即均值和方差。当tf.layers.batch_normalization的参数training=False时,打印内容为:[]
因为此时不需要对均值和方差进行更新。"""

3. tensorflow 的批标准化

tensorflow可以通过多种途径进行批标准化,有多个候选函数:

  • tf.nn.batch_normalization
  • tf.layers.batch_normalization(input,training=True)
  • tf.layers.BatchNormalization()(input,training=True)
  • tf.contrib.layers.batch_norm(input, is_training=True)
  • tf.keras.layers.BatchNormalization()(input, training=True)

在第二节的例子中,我们使用了 tf.layers.batch_normalization, 该函数可以自动创建 update_ops(即均值和方差变量)并加入tf.GraphKeys.UPDATE_OPS中,在训练时,我们直接添加对 update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 的依赖即可,但是是不是其他的批标准化函数也有这种自动操作呢?很遗憾,答案是否定的,tf.keras.layers.BatchNormalization 就没有。实验如下:

import tensorflow as tfinput = tf.ones([1, 2, 2, 3])
# 主要改变了批标准化函数
output = tf.keras.layers.BatchNormalization()(input, training=True)# 打印batch_normalization中的两个操作
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print(update_ops)# with tf.control_dependencies(update_ops):
#   train_op = optimizer.minimize(loss)# with tf.Session() as sess:
#   sess.run(tf.global_variables_initializer())
#   saver = tf.train.Saver()
#   saver.save(sess, "batch_norm_layer/Model")"""
打印:
[]
"""

可见,tf.keras.layers.BatchNormalization并没有把均值与方差的计算节点放入tf.GraphKeys.UPDATE_OPS中,因此,此时即使在训练时添加了对 tf.GraphKeys.UPDATE_OPS 的依赖,也并没有执行批标准化中均值与方差的更新。

下面将每个批标准化函数是否会把 更新变量 均值和方差放入tf.GraphKeys.UPDATE_OPS中进行统计:

  • tf.nn.batch_normalization  : 低阶API,不建议使用
  • tf.layers.batch_normalization(input,training=True) :会自动将 update_ops 添加到 tf.GraphKeys.UPDATE_OPS 中
  • tf.layers.BatchNormalization()(input,training=True) :会自动将 update_ops 添加到 tf.GraphKeys.UPDATE_OPS 中
  • tf.contrib.layers.batch_norm(input, is_training=True):会自动将 update_ops 添加到 tf.GraphKeys.UPDATE_OPS 中
  • tf.keras.layers.BatchNormalization()(input, training=True):不会自动将 update_ops 添加到 tf.GraphKeys.UPDATE_OPS 中

当以上函数的 training 参数被设置为False时,以上函数均不会将 update_ops 添加到 tf.GraphKeys.UPDATE_OPS 中。

4. 检测批标准化函数是否将更新变量添加到了tf.GraphKeys.UPDATE_OPS 中

鉴于第三节中不同批标准化函数的不通特性,在利用批标准化训练时,最好检测一下批标准化的更新变量有没有被添加到tf.GraphKeys.UPDATE_OPS 中。事实上,与上一节的代码一样,只需要在图上增加以下两行命令即可:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print(update_ops)"""
如果打印了以下内容,就说明批标准化的更新变量被添加到 tf.GraphKeys.UPDATE_OPS中了
[<tf.Operation 'batch_normalization/AssignMovingAvg' type=AssignSub>, <tf.Operation 'batch_normalization/AssignMovingAvg_1' type=AssignSub>]否则,如果打印了
[]
则说明,批标准化的更新变量没有被添加到 tf.GraphKeys.UPDATE_OPS中
"""

5. 如何将 tf.keras.layers.BatchNormalization 的更新变量添加到tf.GraphKeys.UPDATE_OPS 中

如上所述,无论tf.keras.layers.BatchNormalization的参数 training = True 还是 training = False,  该函数都不会把”更新变量“添加到tf.GraphKeys.UPDATE_OPS 中,这会对批标准化的功效产生影响。

在应用tf.keras.layers.BatchNormalization进行批标准化时,我们可以通过以下两种方式把批标准化中的”更新变量“添加到tf.GraphKeys.UPDATE_OPS中。

方式1:定义每一层tf.keras.layers.BatchNormalization时,把函数的定义与应用分开,用函数的 .updates 功能提取当前批标准化层的”更新变量“,具体示例如下:

import tensorflow as tfinput = tf.ones([1, 2, 2, 3])# step1: 把批标准化函数的定义与使用分开
bn_op = tf.keras.layers.BatchNormalization()
output = bn_op(input,training=True)# step2: 用批标准化函数自带的.updates提取函数中定义的”update_ops“
bn_update_ops = bn_op.updates# step3: 把当前层的”update_ops添加到tf.GraphKeys.UPDATE_OPS中“
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,bn_update_ops)# step4: 提取整个网络的所有 update_ops
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print(update_ops)

很明显,该方式在每一次使用批标准化函数时都要提取一次”update_ops“,很繁琐。有没有简单一点的方法?当然有,那就是方法2.

方式2:定义完整个网络结构后,或者定义完神经网络的整个图结构以后,用以下命令检索图中所有的”update_ops“,示例如下:

import tensorflow as tf# step1: 定义整个神经网络的图解钩
input = tf.ones([1, 2, 2, 3])
output = tf.nn.relu(tf.keras.layers.BatchNormalization()(input,training=True))# step2: 获取整个图的所有节点
ops = tf.get_default_graph().get_operations()# step3: 从所有节点中筛选与批标准化有关的“update_ops”,
# 注1:AssignMovingAvg 的语义是滑动平均赋值,因为 批标准化中的均值与方差的计算方式都是 滑动平均赋值,所以在节点的命名中会包含这个说明。
#    x.type=="AssignSubVariableOp" 表示该节点的类型是“AssignSubVariableOp”--次变量赋值操作,
#                                  我的理解是,批标准化中的均值与方差不是真正的变量,算是“次变量”
#                                  所以 均值与方差的赋值节点属性是次变量赋值操作
# 注2:这里的筛选条件筛选的只是与批标准化相关的“update_ops”
bn_update_ops = [x for x in ops if ("AssignMovingAvg" in x.name and x.type=="AssignSubVariableOp")]
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,bn_update_ops)# step4: 因为神经图网络中可能还存在其他的“update_ops”,因此再次从tf.GraphKeys.UPDATE_OPS中提取“update_ops”
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print(update_ops)"""
update_ops的格式是:
[[<tf.Operation 'batch_normalization/AssignMovingAvg/AssignSubVariableOp' type=AssignSubVariableOp>, <tf.Operation 'batch_normalization/AssignMovingAvg_1/AssignSubVariableOp' type=AssignSubVariableOp>]]
"""

6. 查看批标准化带来了哪几个变量

import tensorflow as tf# step1: 定义整个神经网络的图解钩
input = tf.ones([1, 2, 2, 3])
output = tf.nn.relu(tf.keras.layers.BatchNormalization()(input,training=True))# step2: 获取所有计算节点
ops = tf.get_default_graph().get_operations()# step3: 筛选批标准化的计算节点“均值更新”节点与“方差更新”节点
bn_update_ops = [x for x in ops if ("AssignMovingAvg" in x.name and x.type=="AssignSubVariableOp")]
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,bn_update_ops)# step4: 获取所有的“更新节点”并打印
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print("update_ops:",update_ops)# 打印所有变量
var_list = tf.global_variables()
for v in var_list:print("var:",v)with tf.Session() as sess:sess.run(tf.global_variables_initializer())saver = tf.train.Saver(var_list=tf.global_variables())saver.save(sess, "batch_norm_layer/Model")"""
打印:
update_ops: [[<tf.Operation 'batch_normalization/AssignMovingAvg/AssignSubVariableOp' type=AssignSubVariableOp>, <tf.Operation 'batch_normalization/AssignMovingAvg_1/AssignSubVariableOp' type=AssignSubVariableOp>]]
var: <tf.Variable 'batch_normalization/gamma:0' shape=(3,) dtype=float32>
var: <tf.Variable 'batch_normalization/beta:0' shape=(3,) dtype=float32>
var: <tf.Variable 'batch_normalization/moving_mean:0' shape=(3,) dtype=float32>
var: <tf.Variable 'batch_normalization/moving_variance:0' shape=(3,) dtype=float32>分析:
可见,每个批标准化操作带来:(1)两个滑动平均计算节点,(2)4个变量:gamma,beta,moving_mean,moving_variance.这4个变量是不可训练的,因此在保存模型的时候,var_list要设置为tf.global_variables()因为这些变量在推断时还需要用。
"""

7. 训练时没有使用依赖控制tf.control_dependencies会怎样?

如以上分析,批标准化中的更新节点并不会主动被调用,如果没有添加依赖控制,或在以来控制中没有添加批标准化的更新节点的话,会导致每个批标准化层的均值与方差永远保持不变,即保持初始化状态。

参考:

tensorflow中的batch_norm以及tf.control_dependencies和tf.GraphKeys.UPDATE_OPS的探究

TensorFlow 中 Batch Normalization API 的一些坑

正确使用Tensorflow Batch_normalization

TensorFlow中batch norm原理,使用事项与踩坑

对tensorflow 的BatchNormalization的坑的理解与测试相关推荐

  1. Tensorflow张量和维度概念的理解

    Tensorflow张量和维度概念的理解 理解tensorflow张量的概念:张量就是一个数据存储容器,一种数据结构,是人为定义的.因为在计算机内存中哪里有什么2维空间3维空间,都是一块块连续的内存区 ...

  2. cudnn 安装失败_Win10下安装tensorflow环境的一些坑

    2020更新: Attention!由于tensorflow更新频繁,特别是现在的2.x版本,改动较大,以下内容是基于tf-1.14版本的.本文基本内容如下: 导入Numpy报错问题解决 CUDA,c ...

  3. 深度学习TensorFlow取名由来,张量的理解

    以下部分为CSDN博主「麦地与诗人」的原创文章,转载请附上原文出处链接及本声明. 原文链接:https://blog.csdn.net/YPP0229/article/details/94321792 ...

  4. 【深度学习框架】Tensorflow Session.run()函数的进一步理解

    在tensorflow中session.run()用来将数据传入计算图,计算并返回出给定变量/placeholder的结果. 在看论文代码的时候遇到一段复杂的feed_dict, 本文记录了对sess ...

  5. Tensorflow on Spark爬坑指南

    北京 上海巡回站 | NVIDIA DLI深度学习培训 2018年1月26/1月12日 NVIDIA 深度学习学院 带你快速进入火热的DL领域 阅读全文                        ...

  6. Tensorflow yolov3 Intel Realsense D435 双摄像头下测试python多线程(假的多线程)self.predict()函数运行时间(191204)

    测试代码: # -*- coding: utf-8 -*- """ @File : test-191204-两个摄像头调用多线程识别.py @Time : 2019/12 ...

  7. MobileNet、GhostNet理解及测试

      MobileNet由谷歌于2017年提出,它是一种能够用在端侧设备上的轻量级网络.GhostNet是华为诺亚实验室开发的一款轻量级网络,论文发表在CVPR2020.我认为学习论文还是要实际跑一跑代 ...

  8. tensorflow版PSENet 文本检测模型训练和测试

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx psenet核心是为了解决基于分割的算法不能区分相邻文本的问题,以及对任意形状文本的检测问题. ...

  9. 卡尔曼滤波(Kalman Filter)原理理解和测试

    Kalman Filter学原理学习 1. Kalman Filter 历史 Kalman滤波器的历史,最早要追溯到17世纪,Roger Cotes开始研究最小均方问题.但由于缺少实际案例的支撑(那个 ...

最新文章

  1. 微信小程序request合法域名怎么配置啊
  2. nagios+mysql+ndo2安装总结
  3. PAT甲级1132 Cut Integer:[C++题解]
  4. linux 中查找文件,并且将目标文件按时间顺序排序
  5. 挂载(mount)深入理解
  6. 智慧停车场管理系统、停车位、停车费、停车场系统、寻车、抬杆、入位车、出位车、车流量统计、停车、收费、缴费、预警管理、业务统计、报警统计、运维管理、报警系统、异常页面、数据配置、智慧停车原型、停车场
  7. 数据库中的左连接和右连接的区别
  8. c 程序设计语言试卷,C语言程序设计试题及答案
  9. DPDK ip分片与重组的设计实现
  10. 吹响数字经济时代的冲锋号 2021宝德X86生态伙伴大会在深召开
  11. 【Usaco2009 gold 】拯救奶牛
  12. python--字符串
  13. 【移动应用开发】实验2Android UI
  14. asp.net mvc 项目使用Quartz.net添加定时任务
  15. MySQL不会丢失数据的秘密,就藏在它的 7种日志里
  16. JS逆向:狐妖小红娘漫画扒取
  17. 如何在LaTex当中给表格命名
  18. 【生成模型】变分自编码器(VAE)及图变分自编码器(VGAE)
  19. 项目申请PPT经验总结
  20. html5.js百度网盘,HTML5 Canvas+js仿百度网盘扫描文件过程加载动画

热门文章

  1. 迭代近邻算法Iterative Closest Point, ICP
  2. mui获取css参数,Mui-获取时间-调用手机api
  3. c语言编程变色,【图片】(原创)用纯C变了个变色输出字符的程序。。。【c语言吧】_百度贴吧...
  4. linux 神奇命令,Linux 命令神器:lsof 入门
  5. 快捷键截屏_Windows10自带截屏快捷键使用方法大全
  6. VS2010调用python编写的代码error:cannot open file 'python27_d.lib'.
  7. 鸿蒙os到底是什么,聊聊鸿蒙OS到底是什么!
  8. c ++ 打印二进制_C / C ++中的二进制搜索树
  9. python多重继承_Python多重继承
  10. java组合与继承始示例_Java 9功能与示例