L2-Regularization 实现的话,需要把所有的参数放在一个集合内,最后计算loss时,再减去加权值。

相比自己乱搞,代码一团糟,Tensorflow 提供了更优美的实现方法。

一、tf.GraphKeys : 多个包含Variables(Tensor)集合

(1)GLOBAL_VARIABLES:使用tf.get_variable()时,默认会将vairable放入这个集合。

我们熟悉的tf.global_variables_initializer()就是初始化这个集合内的Variables。

import tensorflow as tf
sess=tf.Session()
a=tf.get_variable("a",[3,3,32,64],initializer=tf.random_normal_initializer())
b=tf.get_variable("b",[64],initializer=tf.random_normal_initializer())
#collections=None等价于 collection=[tf.GraphKeys.GLOBAL_VARIABLES]gv= tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)          #tf.get_collection(collection_name)返回某个collection的列表
for var in gv: print(var is a)print(var.get_shape())

Tips: tf.GraphKeys.GLOBAL_VARIABLES == "variable"。即其保存的是一个字符串。

(2)自定义集合

想个集合的名字,然后在tf.get_variable时,把集合名字传给 collection 就好了。

import tensorflow as tf
sess=tf.Session()
a=tf.get_variable("a",shape=[10],collections=["mycollection"])  #不把GLOBAL_VARIABLES加进去,那么就不在那个集合里了。
keys=tf.get_collection("mycollection")
for key in keys:print(key.name)

二、L2正则化

先看看tf.contrib.layers.l2_regularizer(weight_decay)都执行了什么:
import tensorflow as tf
sess=tf.Session()
weight_decay=0.1
tmp=tf.constant([0,1,2,3],dtype=tf.float32)
"""
l2_reg=tf.contrib.layers.l2_regularizer(weight_decay)
a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp)
"""
#**上面代码的等价代码
a=tf.get_variable("I_am_a",initializer=tmp)
a2=tf.reduce_sum(a*a)*weight_decay/2;
a3=tf.get_variable(a.name.split(":")[0]+"/Regularizer/l2_regularizer",initializer=a2)
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,a2)
#**
sess.run(tf.global_variables_initializer())
keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
for key in keys:print("%s : %s" %(key.name,sess.run(key)))

我们很容易可以模拟出tf.contrib.layers.l2_regularizer都做了什么,不过会让代码变丑。

以下比较完整实现L2 正则化。
import tensorflow as tf
sess=tf.Session()
weight_decay=0.1                                                #(1)定义weight_decay
l2_reg=tf.contrib.layers.l2_regularizer(weight_decay)           #(2)定义l2_regularizer()
tmp=tf.constant([0,1,2,3],dtype=tf.float32)
a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp)  #(3)创建variable,l2_regularizer复制给regularizer参数。#目测REXXX_LOSSES集合
#regularizer定义会将a加入REGULARIZATION_LOSSES集合
print("Global Set:")
keys = tf.get_collection("variables")
for key in keys:print(key.name)
print("Regular Set:")
keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
for key in keys:print(key.name)
print("--------------------")
sess.run(tf.global_variables_initializer())
print(sess.run(a))
reg_set=tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)   #(4)则REGULARIAZTION_LOSSES集合会包含所有被weight_decay后的参数和,将其相加
l2_loss=tf.add_n(reg_set)
print("loss=%s" %(sess.run(l2_loss)))
"""
此处输出0.7,即:weight_decay*sigmal(w*2)/2=0.1*(0*0+1*1+2*2+3*3)/2=0.7
其实代码自己写也很方便,用API看着比较正规。
在网络模型中,直接将l2_loss加入loss就好了。(loss变大,执行train自然会decay)
"""

[Tensorflow]L2正则化和collection【tf.GraphKeys】相关推荐

  1. (4)[Tensorflow]L2正则化和collection【tf.GraphKeys】

    L2-Regularization 实现的话,需要把所有的参数放在一个集合内,最后计算loss时,再减去加权值. 相比自己乱搞,代码一团糟,Tensorflow 提供了更优美的实现方法. 1. tf. ...

  2. L2正则化和collection,tf.GraphKeys

    L2-Regularization 实现的话,需要把所有的参数放在一个集合内,最后计算loss时,再减去加权值. 相比自己乱搞,代码一团糟,Tensorflow 提供了更优美的实现方法. 一.tf.G ...

  3. 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

    持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...

  4. tensorflow学习笔记:tf.control_dependencies,tf.GraphKeys.UPDATE_OPS,tf.get_collection

    tf.control_dependencies(control_inputs): control_dependencies(control_inputs) ARGS: control_inputs:在 ...

  5. 【转】tensorflow中的batch_norm以及tf.control_dependencies和tf.GraphKeys.UPDATE_OPS的探究

    笔者近来在tensorflow中使用batch_norm时,由于事先不熟悉其内部的原理,因此将其错误使用,从而出现了结果与预想不一致的结果.事后对其进行了一定的调查与研究,在此进行一些总结. 一.错误 ...

  6. tensorflow教程——tf.GraphKeys

    GraphKeys tf.GraphKeys包含所有graph collection中的标准集合名,有点像Python里的build-in fuction. 首先要了解graph collection ...

  7. L2正则化—tensorflow实现

    L2正则化是一种减少过拟合的方法,在损失函数中加入刻画模型复杂程度的指标.假设损失函数是 J(θ) J(\theta),则优化的是 J(θ)+λR(w) J(\theta)+\lambda R(w), ...

  8. tensorflow中的正则化函数在_『TensorFlow』正则化添加方法整理

    一.基础正则化函数 tf.contrib.layers.l1_regularizer(scale, scope=None) 返回一个用来执行L1正则化的函数,函数的签名是func(weights). ...

  9. l2正则化python_TensorFlow keras卷积神经网络 添加L2正则化方式

    我就废话不多说了,大家还是直接看代码吧! model = keras.models.Sequential([ #卷积层1 keras.layers.Conv2D(32,kernel_size=5,st ...

最新文章

  1. 最新Java培训-NIO实战教程
  2. 百度地图之根据地图上的点确定地图的放缩比例
  3. 基于Web Services建立Asp与Asp.Net之间Session数据桥的应用研究
  4. go语言查询某个值是否在数组中_go语言中的数组
  5. java中List Set Map使用
  6. 欧几里得算法及其扩展
  7. 网络电话---异常处理01
  8. 思科网络工程师面试题
  9. 购车指南首次买车必看系列之(二): 产权篇
  10. 用户的虚拟地址 linux 0 4gb,Linux驱动虚拟地址和物理地址的映射
  11. 录入查询学生成绩C语言,学生成绩录入查询系统C语言程序
  12. Wireshark分析实际报文理解SSL(TLS)协议
  13. 一专多能、刻意练习和终身成长
  14. 华为ensp配置pap认证
  15. mysql 查询当前年份
  16. 内核自带的基于GPIO的LED驱动学习(三)
  17. 我觉得好看的规章制度文章
  18. c语言智能家居安防系统,智能家居之安防智能控制系统
  19. Federated Learning in Mobile Edge Networks: AComprehensive Survey(翻译)
  20. 有效解决VS 无法启动 IIS EXPRESS Web 服务器。(ID为xxxxx的进程当前未运行)

热门文章

  1. 郑大网教育计算机2017,郑州大学-“2017中国大学生计算机设计大赛河南省级赛”在郑州大学举行...
  2. 高级查询组件下拉框联动(三)
  3. [笔记] 关于KAG3中宏参数的类型
  4. php下载地址转换工具,PHP迅雷、快车、旋风下载专用链转换代码
  5. 《挖掘管理价值:企业软件项目管理实战》一2.3 需求分析过程
  6. 网络直播课程:神马是敏捷?(直播时间:2014-7-14 20:00-21:00)
  7. root用户无法启动vscode的解决方法
  8. 总结一下强大的ES6符号
  9. 电脑记账最简单的方法
  10. GET /static/css/bootstrap.min.js.map HTTP/1.1“ 404GET /static/css/bootstrap.min.css.map HTTP/1.1“404