作用:使用滑动平均模型可以使模型在测试数据上更健壮。即如果在测试过程中,出现了一些噪声数据,滑动平均模型可以很好地应对这些数据,使这些噪声数据不会对模型的变量造成太大的影响。

1.滑动平均模型原理:

在创建滑动平均模型后,滑动平均模型会对每一个变量维护一个影子变量(shadow variable),影子变量的初始值为相应变量的初始值,每当变量更新时,影子变量的值会更新为:

 shadow_variable = decay * shadow_variable + (1-decay)*variable

shadow_variable为 影子变量,variable为待更新的变量,decay为衰减率,衰减率越大模型越稳定,因为从上实在可以看出,衰减率越大,影子变量受变量更新的影响越小。在实际应用中,decay一般会设置成非常接近1的数(如0.999或0.9999)。

[注意!变量的影子变量和变量的滑动平均值是一样的!]

滑动平均可以看作是变量的过去一段时间取值的均值,相比对变量直接赋值而言,滑动平均得到的值在图像上更加平缓光滑,抖动性更小,不会因为某次的异常取值而使得滑动平均值波动很大。

而滑动平均为什么会在测试中被使用呢?

滑动平均可以使模型在测试数据上更健壮(robust)。“采用随机梯度下降算法训练神经网络时,使用滑动平均在很多应用中都可以在一定程度上提高最终模型在测试数据上的表现。”

  对神经网络边的权重 weights 使用滑动平均,得到对应的影子变量 shadow_weights。在训练过程仍然使用原来不带滑动平均的权重 weights,不然无法得到 weights 下一步更新的值,又怎么求下一步 weights 的影子变量 shadow_weights。之后在测试过程中使用 shadow_weights 来代替 weights 作为神经网络边的权重,这样在测试数据上效果更好。因为 shadow_weights 的更新更加平滑,对于随机梯度下降而言,更平滑的更新说明不会偏离最优点很远;

  设decay=0.999,一个更直观的理解,在最后的1000次训练过程中,模型早已经训练完成,正处于抖动阶段,而滑动平均相当于将最后的1000次抖动进行了平均,这样得到的权重会更加robust。在整个训练过程中影子变量并不会对实际需要训练的变量产生影响啊,后面持久化的变量也不是影子变量。 在训练过程中,为参数维护更新一个影子变量,这样影子变量会停留在最终参数的周围保持稳定。 在测试阶段,使用影子变量代替参数,进行测试。

2.定义滑动平均模型:

tensorflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模型。

形式:tf.train.ExponentialMovingAverage(decay, num_updates=None, name="ExponentialMovingAverage")这个

其中有两个较为重要的参数:

decay :必填,为衰减率,用于控制模型更新的速度。

num_updates:选填,默认为none。用于控制衰减率decay的变化,若num_updates为none,则衰减率不变。当使用num_updates后,衰减率就为:

可以看出,num_updates越大,衰减率就越大。num_updates一般会为迭代轮数,所以当迭代轮数越大,模型参数就越稳定。

3.代码

import tensorflow as tfif __name__ == "__main__":#定义一个变量用于计算滑动平均,变量的初始值为0v1 = tf.Variable(5,dtype=tf.float32)#定义一个迭代轮数的变量,动态控制衰减率,并设置为不可训练step = tf.Variable(10,trainable=False)#定义一个滑动平均类,初始化衰减率为0.99和衰减率的变量stepema = tf.train.ExponentialMovingAverage(0.99,step)#定义每次滑动平均所更新的列表maintain_average_op = ema.apply([v1])#初始化上下文会话with tf.Session() as sess:#初始化所有变量init = tf.initialize_all_variables()sess.run(init)#更新v1的滑动平均值'''衰减率为min(0.99,(1+step)/(10+step)=0.1}=0.1'''sess.run(maintain_average_op)#[5.0, 5.0]print(sess.run([v1,ema.average(v1)]))sess.run(tf.assign(v1,4))sess.run(maintain_average_op)#[4.0, 4.5500002],5*(11/20) + 4*(9/20)print(sess.run([v1, ema.average(v1)]))'''在实际中,v1变量很经常是网络中的权重值weights'''

tensorflow中滑动平均模型的说明相关推荐

  1. 深入解析TensorFlow中滑动平均模型与代码实现

    因为本人是自学深度学习的,有什么说的不对的地方望大神指出 指数加权平均算法的原理 TensorFlow中的滑动平均模型使用的是滑动平均(Moving Average)算法,又称为指数加权移动平均算法( ...

  2. 深度学习中滑动平均模型的作用、计算方法及tensorflow代码示例

    滑动平均模型: 用途:用于控制变量的更新幅度,使得模型在训练初期参数更新较快,在接近最优值处参数更新较慢,幅度较小 方式:主要通过不断更新衰减率来控制变量的更新幅度 衰减率计算公式 :     dec ...

  3. TensorFlow模型保存和提取方法(含滑动平均模型)

    一.TensorFlow模型保存和提取方法 1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取.tf.train.Saver对象saver的save方法将Tens ...

  4. tensorflow 滑动平均模型 ExponentialMovingAverage

    ____tz_zs学习笔记 滑动平均模型对于采用GradientDescent或Momentum训练的神经网络的表现都有一定程度上的提升. 原理:在训练神经网络时,不断保持和更新每个参数的滑动平均值, ...

  5. 滑动平均模型(MA)—tensorflow

    在采用梯度下降的方式训练神经网络的时候,我们使用滑动平均模型会在一定的程度上提高最终模型在测试集上的表现. 在TensorFlow中提供了tf.train.ExponentialMovingAvera ...

  6. tensorflow tf.train.ExponentialMovingAverage() (滑动平均模型)(移动平均法 Moving average,MA)(用于平滑数据波动对预测结果的影响)

    tf.train.ExponentialMovingAverage 函数定义 tensorflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模型,他使用指 ...

  7. TensorFlow滑动平均模型作用

    在TensorFlow中经常会提到滑动平均模型,目的是为了控制变量更新的速度,防止变量的突然变化对变量的整体影响. TensorFlow下的 tf.train.ExponentialMovingAve ...

  8. tensorflow中学习率、过拟合、滑动平均的学习

    1. 学习率的设置 我们知道在参数的学习主要是通过反向传播和梯度下降,而其中梯度下降的学习率设置方法是指数衰减. 通过指数衰减的学习率既可以让模型在训练的前期快速接近较优解,又可以保证模型在训练后期不 ...

  9. TensorFlow中EMA的概念和正确使用方法

    目录 EMA介绍 概念 弥补不足:初始数据积累不足的情况 深度学习训练中的作用 实现 典型步骤 一个EMA影子变量的例子 进一步接近真实情景,让w1变动 例2:global_step的trainabl ...

最新文章

  1. 单点登陆_别再问我单点登陆
  2. Bitcoin 中的挖矿算法(3) 挖矿算法代码说明
  3. 含有js的英文单词_JavaScript 常用单词整理
  4. 这还没毕业呢,肩膀就不舒服,唉。。。要是工作了,那该有多累啊
  5. linux 卷文件满,LVM逻辑卷容量的增减
  6. 至读博客朋友的一封信
  7. Linux学习笔记 -- rpm 与 shell 编程
  8. c语言学习宝典怎么样,C语言学习宝典
  9. 安装驱动省心办法:驱动总裁
  10. 前端面试官经验总结 | 前端面试小技巧
  11. English--consonant_摩擦音
  12. LTE下行传输机制--PBCH
  13. 华为服务器告警状态,华为RH2288H V5服务器CPU告警
  14. 点餐系统mysql设计,外卖点餐系统数据库设计.doc
  15. DEVC++第五人格V2.0
  16. 2019年计算机二级获证条件,2019年下半年全国计算机等级考试报考简章
  17. 99_包(package)
  18. vue watch store
  19. 核心技术及创新点怎么写
  20. Pytorch快速搭建Alexnet实现手写英文字母识别+PyQt实现鼠标绘图

热门文章

  1. php设计模式epub,大话设计模式(pdf+epub+mobi+txt+azw3)
  2. Excel启动AutoCAD
  3. 性能监控工具的配置及使用 - 听云-Server
  4. 尚硅谷java多线程
  5. 速码工具箱5.0,二维码生产力工具
  6. php 尾递归,又见尾递归
  7. 【并发编程】程序的启动和终结
  8. emu8086 第一个程序
  9. java编程基础篇-- 编写一个程序,从键盘输入三个整数,求三个整数中的最小值。
  10. [转]【读书笔记】《俞军产品方法论》——产品经理的枕边书