前言

本文主要翻译自JAX在github上的一篇文档(Authors: Matteo Hessel & Rosalia Schneider),同时增加了部分个人理解。
原文链接如下:
https://github.com/google/jax/blob/main/docs/jax-101/05-random-numbers.md

关于伪随机数的生成,pseudo random number generation:
伪随机数并非真正的随机数,伪随机数是根据一定的算法,依据初始值(种子,key)生成的数值,当算法不变时,生成的结果不变。

在各个方面,JAX力求和Numpy保持一致,而在伪随机数方面是一个例外。下面将具体介绍一下JAX和Numpy之间关于伪随机数的区别。

Numpy中的伪随机数

numpy中的伪随机数通过np.random生成,伪随机数的状态是全局统一的(原文:In NumPy, pseudo random number generation is based on a global state)。
怎么理解这句话呢?我理解在Numpy中,所有的随机数状态均是基于一种算法生成的,即所有的随机数的状态均在一个序列当中。
这里再介绍一下状态,即state。在Numpy当中,有一个方法是np.random.get_state(),在官方文档中,解释为:Return a tuple representing the internal state of the generator。即返回一个代表生成器内部状态的元组。我理解这个状态和种子的概念是类似的,在同一状态下,得到的随机数是相同的。
我们看看这个状态具体是个什么东东:

def print_truncated_random_state():"""To avoid spamming the outputs, print only part of the state."""full_random_state = np.random.get_state()print(str(full_random_state)[:460], '...')print_truncated_random_state()('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,3904844661,  676747479, 2085143622, 1056793272, 3812477442,2168787041,  275552121, 2696932952, 3432054210, 1657102335,3518946594,  962584079, 1051271004, 3806145045, 1414436097,2032348584, 1661738718, 1116708477, 2562755208, 3176189976,696824676, 2399811678, 3992505346,  569184356, 2626558620,136797809, 4273176064,  296167901, 343 ...

插曲:可以看到,输出结果中有"MT19937" 个东东,这个是个什么东西了?
查了一下:MT19937表示一个伪随机数生成算法。
梅森旋转算法(Mersenne twister)是一个伪随机数发生算法。由松本真和西村拓士[1]在1997年开发,基于有限二进制字段上的矩阵线性递归F2F_2F2​。可以快速产生高质量的伪随机数,修正了古典随机数发生算法的很多缺陷。
Mersenne Twister这个名字来自周期长度取自梅森质数的这样一个事实。这个算法通常使用两个相近的变体,不同之处在于使用了不同的梅森素数。一个更新的和更常用的是MT19937, 32位字长。还有一个变种是64位版的MT19937-64。对于一个k位的长度,Mersenne Twister会在[0,2k−1][0,2^k-1][0,2k−1]的区间之间生成离散型均匀分布的随机数。

在每次调用Numpy时,state都会更新:

np.random.seed(0)
print_truncated_random_state()('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,2481403966, 4042607538,  337614300, 3232553940, 1018809052,3202401494, 1775180719, 3192392114,  594215549,  184016991,829906058,  610491522, 3879932251, 3139825610,  297902587,4075895579, 2943625357, 3530655617, 1423771745, 2135928312,2891506774, 1066338622,  135451537,  933040465, 2759011858,2273819758, 3545703099, 2516396728, 127 ...
_ = np.random.uniform()
print_truncated_random_state()('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,3904844661,  676747479, 2085143622, 1056793272, 3812477442,2168787041,  275552121, 2696932952, 3432054210, 1657102335,3518946594,  962584079, 1051271004, 3806145045, 1414436097,2032348584, 1661738718, 1116708477, 2562755208, 3176189976,696824676, 2399811678, 3992505346,  569184356, 2626558620,136797809, 4273176064,  296167901, 343 ...

同时我们也可以把状态保存下来,后面可以读取这个状态来返回相同的值。

state=np.random.get_state()  #获取并保存state
np.random.uniform()
0.6027633760716439
np.random.set_state(state) #读取已保存的状态
np.random.uniform()
0.6027633760716439  ---两个随机数是一样的

在Numpy里面,不仅一次可以获取一个随机数,还可以获取一个随机的向量或是张量。

np.random.seed(0)
print(np.random.uniform(size=3))
[0.5488135  0.71518937 0.60276338]

Numpy中一个比较有意思的东西是,它提供了一个顺序等价保证,怎么理解呢?就是同样是3个数字,分3次取3个数和一次取一个包含3个元素的向量得到的结果是一样的。
如下:

np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))
individually: [0.5488135  0.71518937 0.60276338]
np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
all at once:  [0.5488135  0.71518937 0.60276338]

是不是很神奇?这个如果在搞明白为什么会这样,估计是要研究下Numpy的random.uniform的实现代码了,我估计np.random.uniform(size=3)也是循环生成的。

JAX中的伪随机数

JAX中的伪随机数与Numpy中有很大不同,Numpy中的随机数设计满足不了JAX的需求。JAX要求具备以下特点:

  1. 可重复,reproducible ——个人理解意思为可复现,即重复操作时结果是一样的(就是种子的意思)
  2. 可并行,parallelizable
  3. 向量化,vectorisable
    下面我们来具体说明。
    首先我们看看一个全局状随机数的含义。上代码:
import numpy as np
np.random.seed(0)
def bar(): return np.random.uniform()
def baz(): return np.random.uniform()
def foo(): return bar() + 2 * baz()
print(foo())
------------------
1.9791922366721637

上面的代码中,方法foo表示两个服从均匀分布的标量之和。
在假设方法bar和方法baz是按一定顺序执行的情况下,计算结果才能满足JAX的三点要求的第一条,即reproducible。
啥意思呢?每次运行代码时,必须保证bar与baz的执行顺序相同,得到的结果才相同。如果第一次先执行bar再执行baz,第二次先执行baz再执行bar,这两次的结果是不一样的。
这个现象在Numpy里似乎是无关紧要,不是啥问题,本来也是会按序执行的。但是这个东东在JAX里就不行了。。。
为啥呢?因为JAX是支持并行的!
这段代码如果想在JAX里复现,那就得强制按顺序执行,但是bar和baz两个方法没有依赖,在编译时,会被搞成并行执行的。这个就违背了JAX的第二点要求:parallelizable!

为了解决这个问题,在JAX中,不再使用全局状态,也就是不再设置全局的种子了。随机函数显式的、明确的消费随机状态,这个状态在JAX一般用“key”表示。
怎么理解呢?我们先看看这个key具体是什么:

from jax import random
key = random.PRNGKey(42)
print(key)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
------------------
[ 0 42]

key是一个shape为2的数组[0,42]。
其实这个key也就是numpy中的seed啦,只是在numpy中,seed只需要设置一次,但是在JAX里,只要用到了random的方法,需要明确的指定key,就是每次调用random的方法都要单独设置key,下文会见到。
random方法使用key,但是并不会改变它。同样,当一个random方法消费相同的key时,得到的结果也是一样的。

print(random.normal(key))
print(random.normal(key))
-0.18471184
-0.18471184

需要注意的是:
当不同的随机函数使用相同的key时,得到的结果是存在相关性的,这在一般情况不是我们想要的,我们希望要的东西是独立的。(PS:相关和独立就是概率里的那个概念)
不要重复使用key,不要重复使用key,不要重复使用key。。。

那么问题来了,一个key不能重复使用,但是又要求每个random方法都要明确的指定key,还要结果可重复,还让不让人玩儿了??

简单~!JAX在设计里肯定考虑这个了,不然早就混不下去了。怎么解决呢,就是把一个key给掰成几瓣!看代码:

from jax import random
key = random.PRNGKey(42)
print("old key", key)
new_key, subkey = random.split(key)
del key  # The old key is discarded -- we must never use it again.
normal_sample = random.normal(subkey)
print(r"    \---SPLIT --> new key   ", new_key)
print(r"             \--> new subkey", subkey, "--> normal", normal_sample)
del subkey  # The subkey is also discarded after use.
# Note: you don't actually need to `del` keys -- that's just for emphasis.
# Not reusing the same values is enough.
key = new_key  # If we wanted to do this again, we would use new_key as the key.
----------------------------------------------------------------------------------
old key [ 0 42]\---SPLIT --> new key    [2465931498 3679230171]\--> new subkey [255383827 267815257] --> normal 1.3694694

首先:split()是一个确定性函数(输入相同时,输出总是一样的),可以把一个key分裂成多个相互独立的key,并且还能满足伪随机的性质。上面的代码中,我们通过把一个老K,分裂成了两个小K。分裂之后的小K,还可以继续发展“下线”,继续向下分裂。
总之,只要保证不同的random方法使用不同的key就可以了。还有就是,做为母K,分裂之后,就不要再用了,是什么原因原文没有介绍,我也没查到。

无论是谁叫Key还是叫subkey都是不重要的,它们都是在相同状态下的伪随机数。上面那个例子,一般写成如下的形式,这时,老的key会被自动discarded,key被赋值分裂后的伪随机数。

key, subkey = random.split(key)

当然了,split不只是能分裂出两个子key,你想要几个都可以。

key, *forty_two_subkeys = random.split(key, num=43)

Numpy和JAX的random模型的另一个区别是顺序等价保证( the sequential equivalence guarantee),就是上面提到的执行顺序的问题。
JAX也是和Numpy一样,可以生成一个多维随机向量,但是JAX并不提供 the sequential equivalence guarantee,因为那样的话会影响在SIMD(单指令多数据结构)硬件上的向量化,也就是前面提到的第3点:vectorisable。
再看看前面一个例子在JAX中情况:

key = random.PRNGKey(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)key = random.PRNGKey(42)
print("all at once: ", random.normal(key, shape=(3,)))individually: [-0.04838839  0.10796146 -1.2226542 ]
all at once:  [ 0.18693541 -1.2806507  -1.5593133 ]

我们没办法再得到两个相同的结果了。

同时需要注意的是,这里面我们把母key也用了,原文是说因为在别的地方没使用,所以没有违反只使用一次的原则。但是我个人觉得这个违反了分裂就丢弃的原则了啊!!

以上就是全文了,由于个人水平有限,理解的不一定正确,欢迎大家一起讨论~~

Numpy和JAX中的随机数相关推荐

  1. python numpy.random模块中提供啦大量的随机数相关的函数

    1. numpy中产生随机数的方法 1)rand() 产生[0,1]的浮点随机数,括号里面的参数可以指定产生数组的形状 2)randn() 产生标准正太分布随机数,参数含义与random相同 3)ra ...

  2. 『JAX中文文档』JAX快速入门

    最新的 原文:https://jax.readthedocs.io/en/latest/notebooks/quickstart.html JAX快速入门 首先解答一个问题:JAX是什么? 简单的说就 ...

  3. 一篇搞懂Python中的随机数

    在 python 中生成随机样本的所有你需要的示例列表 长按关注<Python学研大本营>,加入读者群,分享更多精彩 扫码关注<Python学研大本营>,加入读者群,分享更多精 ...

  4. Java中的随机数生成器:Random,ThreadLocalRandom,SecureRandom

    Java中的随机数生成器:Random,ThreadLocalRandom,SecureRandom 文中的 Random即:java.util.Random, ThreadLocalRandom 即 ...

  5. java中随机数怎么定义类_浅析Java中的随机数类

    Java中的随机数是否可以重复?Java中产生的随机数能否可以用来产生数据库主键?带着这个问题,我们做了一系列测试. 1.测试一: 使用不带参数的Random()构造函数 * @author Carl ...

  6. python找出值为nan_Python Numpy:找到list中的np.nan值方法

    这个问题源于在训练机器学习的一个模型时,使用训练数据时提示prepare的数据中存在np.nan 报错信息如下: ValueError: np.nan is an invalid document, ...

  7. numpy找到数组中符合条件的数

    numpy找到数组中符合条件的数 import numpy as nparr = np.array([1, 1, 1, 134, 45, 3, 46, 45, 65, 3, 23424, 234, 1 ...

  8. java 随机数生成实现_Java中生成随机数的实现方法总结

    搜索热词 在实际开发工作中经常需要用到随机数.如有些系统中创建用户后会给用户一个随机的初始化密码.这个密码由于是随机的,为此往往只有用户自己知道.他们获取了这个随机密码之后,需要马上去系统中更改.这就 ...

  9. css 加随机数 引用_在CSS中生成随机数

    Python部落(python.freelycode.com)组织翻译,禁止转载,欢迎转发. Robin Rendle 于2017年1月11日 前几天,我遇到了一个特别有趣的问题.我想用random ...

最新文章

  1. 模态对话框的父窗口设置
  2. 国际机器人联合会:全球工业机器人2019报告
  3. 第13届景驰-埃森哲杯广东工业大学ACM程序设计大赛 L-回旋星空
  4. css样式表和选择器
  5. 目标检测第7步:如何在Windows 10下,配置Pycharm中的YOLOv5(5.0)虚拟环境?
  6. 电商项目实战项目需求以及技术选型
  7. Visual Studio Code 10 月 Python 扩展更新
  8. android listview 分页
  9. Inception v1
  10. 神舟刷蓝天w650dbios_神舟z6kp5D1记录一次艰难的刷蓝天bios,总算成功了
  11. 六年级上册计算机教材分析,人教版六年级上册数学教材分析
  12. 6种方法轻松将PDF转换为Word文档,办公必备!
  13. 【历史上的今天】6 月 6 日:世界 IPv6 启动纪念日;《俄罗斯方块》发布;小红书成立
  14. 趋势交易大师php,大道至简——多级别均线共振交易系统
  15. Tableau 实现percentile分类计算功能
  16. 1分钟搞定两个电脑之间谷歌收藏夹的迁移,不用账号!不用下载!
  17. CentOS7(Linux)在VMware Workstation上的 安装使用教程
  18. cindy POSA2读书笔记(二)
  19. nodemcu刷鸿蒙系统,ESP01S刷入NodeMCU固件
  20. antd 每次打开modal 初始化数据

热门文章

  1. 歌手列表快速导航入口
  2. 目标检测算法——Faster R-CNN
  3. 1.Premiere Pro CS6界面介绍
  4. html实现鼠标拖拽按钮,JS实现鼠标按下拖拽效果
  5. XPath实战之爬取豆瓣电影
  6. HTML---背景样式
  7. 《图解易经:一本终于可以读懂的易…
  8. ElasticSearch的search.max_buckets值1000限制问题
  9. 微信支付提示支付验证签名失败
  10. Outlook Express 收件箱修复