Clip_by_norm 函数理解
1. 梯度裁剪场景
先看示例:
optimizer = tf.train.AdamOptimizer(self.learning_rate)
gradients, v = zip(*optimizer.compute_gradients(self.pretrain_loss))
gradients, _ = tf.clip_by_global_norm(gradients, self.grad_clip)
updates = optimizer.apply_gradients(zip(gradients, v), global_step=self.global_step)
梯度裁剪的最直接目的就是防止梯度爆炸,手段就是控制梯度的最大范式。
原型:tf.clip_by_global_norm
tf.clip_by_global_norm(
t_list,#常输入梯度
clip_norm,#裁剪率
use_norm=None,#使用已经计算规约
name=None )
返回值:
list_clipped: 裁剪后的梯度列表
global_norm:全局的规约数
下面示例计算过程:
2. 手动裁剪
根据计算原理:t_list[i] * clip_norm / max(global_norm, clip_norm)
2.1 产生列表数
#生成0-9之间的数组成的列表
init_t_list = np.asarray([0,1,2,3,4,5,6,7,8,9])
2.2 求L2值的两种方法
#方式1:使用np自带的函数
l2 = np.linalg.norm(init_t_list)
print(l2)
#方式2:手写实现方式
l2_ = np.sqrt(np.sum(np.square(init_t_list)))
print(l2_)
2.3 求梯度裁剪后的值
#假设裁剪规约数等于5.0
clip_norm = 5.0
#求裁剪后的值
t_list = init_t_list * clip_norm / max(l2, clip_norm)
print(t_list)
#裁剪后L2值
t_list_l2 = np.linalg.norm(t_list)
print(t_list_l2)
2.4 输出结果
16.8819430161
16.8819430161
[ 0. 0.29617444 0.59234888 0.88852332 1.18469776 1.48087219
1.77704663 2.07322107 2.36939551 2.66556995]
5.0
从结果中,得到裁剪后的梯度的L2值是5.0,这正是意义所在,说明对输入的梯度进行了L2范数的上限的限制,防止梯度过大。
3. tf代码裁剪
import tensorflow as tf
#产生一个梯度,此处与上面的保持一致为0-9组成的序列
gradients = [float(i) for i in range(10)]#[0,1,2,3,4,5,6,7,8,9]
#创建回话
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
gradients, global_norm = tf.clip_by_global_norm(gradients, clip_norm)
print(sess.run(gradients))
print(sess.run(global_norm))
3.1 结果
[0.0, 0.29617444, 0.59234887, 0.88852334, 1.1846977, 1.4808722, 1.7770467, 2.073221, 2.3693955, 2.66557]
16.8819
两种方式得到的最终结果一致。
注意,api中解释到: However, it is slower than clip_by_norm() because all the parameters must be
ready before the clipping operation can be performed.
参考论文:
On the difficulty of training Recurrent Neural Networks
Clip_by_norm 函数理解相关推荐
- nodejs回调函数理解
回调实例 问题:想要得到一秒后 计算出的结果 //错误写法function add(x,y) {console.log(1);setTimeout(function () {console.log(2 ...
- ML之MIC:利用有无噪音的正余弦函数理解相关性指标的不同(多图绘制Pearson系数、最大信息系数MIC)
ML之MIC:利用有无噪音的正余弦函数理解相关性指标的不同(多图绘制Pearson系数.最大信息系数MIC) 目录 利用有无噪音的正余弦函数理解相关性指标的不同(多图绘制Pearson系数.最大信息系 ...
- 高频交易配对交易学习——Copulas函数理解
Copulas函数理解 https://github.com/MalteKurz/VineCopulaCPP
- Pytorch中tensor.view().permute().contiguous()函数理解
Pytorch中tensor.view().permute().contiguous()函数理解 yolov3中有一行这样的代码,在此记录一下三个函数的含义 # 例子中batch_size为整型,le ...
- pytorch中repeat()函数理解
pytorch中repeat()函数理解 最近在学习过程中遇到了repeat()函数的使用,这里记录一下自己对这个函数的理解. 情况1:repeat参数个数与tensor维数一致时 a = torch ...
- SQLServer STUFF 函数理解
SQLServer CAST -- 转换数据类型 逗号表示分割 . STUFF 函数理解 -- 第一个就是字符串 FOR XML PATH('') 必须用 , 第二个参数 负数或0空字符串, ...
- Java回调函数理解和应用
#Java回调函数理解和应用 所谓回调:就是A类中调用B类中的某个方法C,然后B类中反过来调用A类中的方法D,D这个方法就叫回调方法,这样子说你是不是有点晕晕的. 在未理解之前,我也是一脸懵逼,等我理 ...
- pytorch 中 contiguous() 函数理解
pytorch 中 contiguous() 函数理解 文章目录 pytorch 中 contiguous() 函数理解 引言 使用 contiguous() 后记 文章抄自 Pytorch中cont ...
- Android回调函数理解
Android回调函数理解,比如我用一个activity去做显示下载进度的一个进度条,但是下载是另外一个B类来做的,这个时候我Activity获取下载的进度就可以提供一个回调接口,然后让下载类来回调就 ...
- softmax函数理解
该节课中提到了一种叫作softmax的函数,因为之前对这个概念不了解,所以本篇就这个函数进行整理,如下: 维基给出的解释:softmax函数,也称指数归一化函数,它是一种logistic函数的归一化形 ...
最新文章
- 信息管理系统(Servlet+jsp+mvc+jdbc)
- 批处理命令 For循环命令具体解释!
- 和php交互的过程_JavaScript学习笔记(二十三) 服务器PHP
- mysql bin.000047_mysql-bin.0000X 日志文件处理
- show部分书...
- Linux命令大全完整版
- android 竖屏优先,android 强制设置横屏 判断是横屏还是竖屏
- vb局域网连接mysql_VB 用代码进行局域网内数据库的连接
- laravel + xampp 除了根目录其他路由都是404的解决方法
- 服务器机柜可放多大显示器,一个标准服务器机柜究竟能够放多少服务器
- 如何提升电脑开机速度?
- KEIL MDK5 更好用 更简洁 的ARM开发环境
- 从零开始的Docker [ 7 ] --- 顶级 Volumes,数据卷, 系统限制sysctls
- 线束音视频传输连接器FAKRA与HSD区别?
- 使用SmartUpload组件上传文件,自己踩过的坑
- 【高等数学】通过俩条空间直线求得公垂线的求法
- hive -- return code 2 from org.apache.hadoop.hive.ql.exec.mr.MapRedTask
- shell经典,shell十三问
- 智慧城市地下综合管廊环境监控系统
- 绑核原理linux,DPDK性能影响因素之绑核原理