• TF笔记:小trick之gumbel softmax

    • 0. 引言
    • 1. gumbel softmax
    • 2. tf代码实现
    • 3. 参考链接

0. 引言

故事的起因在于我们在实际工作中遇到的一个小的需求,即我们在模型定义当中需要用到argmax的信息,因此,我们就快速地写下了如下一段代码:

import tensorflow as tfdef get_argmax(x):h = get_shape_list(x)[-1]y = tf.one_hot(tf.argmax(x, axis=-1), h)return y

由此,我们就可以找到tensor当中每一行的最大元素,并使用onehot向量将其表示出来。

但是,在实际的使用中,我们发现了一个问题,即这样定义的模型能够正常工作,但是其训练出来的模型特征表征却和我们的预期大相径庭。

原因相比大多数读者也都注意到了,即我们在这种函数定义当中,由于使用了argmax,使得梯度回传被中断了,这就导致了模型训练失败,无法达到预期的目标。

而要解决这里argmax导致的梯度回传中断的问题,gumbel softmax方法就是一种常用的方法,下面,我们就来对其进行一些简单的介绍。

1. gumbel softmax

gumbel softmax方法的本质在于说用一个连续可导的函数来模拟argmax函数的结果表达,使得其可以在不截断梯度回传的情况下完成argmax函数的功能。

argmax函数的函数曲线可以通过狄拉克函数(δ(x)\delta(x)δ(x))进行描述,即:

argmax(v⃗)=∑ini∗δ(i−u)argmax(\vec{v}) = \sum_{i}^{n}{i * \delta(i-u)} argmax(v)=i∑n​i∗δ(i−u)

其中,uuu为向量v⃗\vec{v}v中最大元素的下标。

如果用one-hot向量进行argmax的表达的话,即有其中任一元素的值为δ(i−u)\delta(i-u)δ(i−u)。

由此,我们只需要使用一个连续可导的函数来模拟δ(x−u)\delta(x-u)δ(x−u)函数即可,而对于这个问题,gumbel softmax采用的方式是基于softmax函数进行参数调制的方式进行实现。

基础的softmax函数的表达式如下:

σ(x⃗)=exi∑jexj\sigma(\vec{x}) = \frac{e^{x_i}}{\sum_j e^{x_j}} σ(x)=∑j​exj​exi​​

而gumbel softmax函数事实上就是在softmax的基础上加上参数调制。

我们给出gumbel softmax的函数表达式如下:

σ′(x)=exi/δ∑jexj/δ\sigma'(x) = \frac{e^{x_i / \delta}}{\sum_j e^{x_j / \delta}} σ′(x)=∑j​exj​/δexi​/δ​

其中,delta为一个小量。

2. tf代码实现

基于此,我们可以比较快速地写出gumbel softmax函数的tf代码了。

import tensorflow as tfdef gumbel_softmax(x, delta=1e-3, axis=None):return tf.nn.softmax(x/delta, axis=axis)

emmm,简单过头了……

嘛,那啥,simple is best!

3. 参考链接

  1. 漫谈重参数:从正态分布到Gumbel Softmax

TF笔记:小trick之gumbel softmax相关推荐

  1. 让Transformer的推理速度提高4.5倍,这个小trick还能给你省十几万

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 丰色 发自 凹非寺 量子位 报道 | 公众号 QbitAI 最近,N ...

  2. 语法上的小trick

    语法上的小trick 构造函数 虽然不写构造函数也是可以的,但是可能会开翻车,所以还是写上吧.: 提供三种写法: ​ 使用的时候只用: 注意,这里的A[i]=gg(3,3,3)的"gg&qu ...

  3. 通过一个小Trick实现shader的像素识别/统计操作

    2018/12/14日补充:后来发现compute shader里用AppendStructuredBuffer可以解决这类问题,请看这里:https://www.cnblogs.com/hont/p ...

  4. 位运算相关题目-一些小trick 1bit代表独立数字 求只出现一次的数字 无进位n进制数 n(-n) Boyer-Moore 投票算法 n(n-1)

    二进制位方法 集合的每个元素,都有可以选或不选,用二进制的位来表示,0表示不选,1表示选自.0x1 << nums.size()-1 的每一位就代表了集合中每个元素都选用.这里由于集合中每 ...

  5. 浅谈CTF中各种花式绕过的小trick

    文章目录 浅谈CTF中各种花式绕过的小trick 前言 md5加密bypass 弱比较绕过 方法一:0e绕过 方法二:数组绕过 强比较绕过 方法:数组绕过 md5碰撞绕过 方法:使用Fastcoll生 ...

  6. 【Linux高效小trick】快速查看Linux进程的开始和运行时间

    写在前面 前面介绍了,怎么杀死Linux的僵尸进程,为GPU释放更多的内存,做想做的事,文章链接如下: [Linux高效小trick]Linux下杀死僵尸进程,释放GPU内存,让代码全速运行~ 今天再 ...

  7. 会议论文投稿小trick

    所谓"初生牛犊不怕虎"."无知者无畏",鄙人今年斗胆向IJCAI 2019(人工智能顶级会议,全称International Joint Conferences ...

  8. Gumbel Max与Gumbel Softmax演示动画

    Gumbel Max以及Gumbel Softmax的理论证明见: 漫谈重参数:从正态分布到Gumbel Softmax 我用js写了一个利用Gumbel Max来对离散分布进行重参数化的过程,地址: ...

  9. EXCEL学习笔记——小技巧

    EXCEL学习笔记--小技巧(持续更新) 我赌五毛:八成的EXCEL使用者连SUM()函数的帮助都没阅读过.我敢再赌五毛:九成的EXCEL使用者没使用过我下文中九成的技巧.写本文的初衷是能让EXCEL ...

最新文章

  1. Git客户端图文详解如何安装配置GitHub操作流程攻略
  2. C++Primer Plus (第六版)阅读笔记 + 源码分析【第三章:处理数据】
  3. mysql qps计算方法_mysql计算 TPS,QPS 的方式
  4. 图像处理之添加文字水印
  5. mysql数据库迁移到另一台电脑上
  6. 第八期:实操:两台路由器,如何分别通过WAN和LAN口连接?
  7. Two sum(给定一个无重复数组和目标值,查找数组中和为目标值的两个数,并输出其下标)...
  8. python爬虫实战:《星球大战》豆瓣影评分析
  9. 使用Fraps获取3D程序的FPS
  10. 抖音品质建设 - iOS启动优化《原理篇》
  11. DEVICE_ATTR_RW 宏分析
  12. python淘宝秒拍_(python)下载秒拍美拍视频
  13. 逻辑设计基础_芯片设计--TCAM概述
  14. 玩转微信营销和推广的10种方法和技巧
  15. mac环境下cocos2dx引擎3.x版本的创建工程步骤
  16. 迪杰斯特拉(Dijkstra)
  17. 柔性上肢康复机器人研究中的VR技术
  18. 计算机辅助教学领域的先驱者,探索“三大构成”教学模式发展之路
  19. 标准应用促边缘云成熟度提升
  20. H5实现APP下载功能

热门文章

  1. 使用git上传远程的时候出现用户名密码错误 emote: Incorrect username or password ( access token )
  2. yuv420p 详解_视频格式YUV详解
  3. initrd是什么?
  4. 厦门商标注册之快闪家族
  5. 基于springboot+vue的便利店库存管理系统
  6. 毕业论文内容框架指导(程序设计类)
  7. 雷石服务器不显示加密盘,磁盘加锁专家加锁后的磁盘不见了,怎么办?
  8. RTKLIB学习总结(四)rnx2rtkp.c、Option文件读取
  9. python基于PHP+MySQL的美食网站的设计与实现
  10. 未能找到服务器主机名,未能找到主机名服务器