来源:DeepHub IMBA
本文约3300字,建议阅读10+分钟
本文中,我们了解了 JAX 是什么,并了解了它的一些基本概念。

JAX 是一个由 Google 开发的用于优化科学计算Python 库:

  • 它可以被视为 GPU 和 TPU 上运行的NumPy , jax.numpy提供了与numpy非常相似API接口。

  • 它与 NumPy API 非常相似,几乎任何可以用 numpy 完成的事情都可以用 jax.numpy 完成。

  • 由于使用XLA(一种加速线性代数计算的编译器)将Python和JAX代码JIT编译成优化的内核,可以在不同设备(例如gpu和tpu)上运行。而优化的内核是为高吞吐量设备(例如gpu和tpu)进行编译,它与主程序分离但可以被主程序调用。JIT编译可以用jax.jit()触发。

  • 它对自动微分有很好的支持,对机器学习研究很有用。可以使用 jax.grad() 触发自动区分。

  • JAX 鼓励函数式编程,因为它是面向函数的。与 NumPy 数组不同,JAX 数组始终是不可变的。

  • JAX提供了一些在编写数字处理时非常有用的程序转换,例如JIT . JAX()用于JIT编译和加速代码,JIT .grad()用于求导,以及JIT .vmap()用于自动向量化或批处理。

  • JAX 可以进行异步调度。所以需要调用 .block_until_ready() 以确保计算已经实际发生。

JAX 使用 JIT 编译有两种方式:

  • 自动:在执行 JAX 函数的库调用时,默认情况下 JIT 编译会在后台进行。

  • 手动:您可以使用 jax.jit() 手动请求对自己的 Python 函数进行 JIT 编译。

JAX 使用示例

我们可以使用 pip 安装库。

pip install jax

导入需要的包,这里我们也继续使用 NumPy ,这样可以执行一些基准测试。

import jax
import jax.numpy as jnp
from jax import random
from jax import grad, jit
import numpy as npkey = random.PRNGKey(0)

与 import numpy as np 类似,我们可以 import jax.numpy as jnp 并将代码中的所有 np 替换为 jnp 。如果 NumPy 代码是用函数式编程风格编写的,那么新的 JAX 代码就可以直接使用。但是,如果有可用的GPU,JAX则可以直接使用。

JAX 中随机数的生成方式与 NumPy 不同。JAX需要创建一个 jax.random.PRNGKey 。我们稍后会看到如何使用它。

我们在 Google Colab 上做一个简单的基准测试,这样我们就可以轻松访问 GPU 和 TPU。我们首先初始化一个包含 25M 元素的随机矩阵,然后将其乘以它的转置。使用针对 CPU 优化的 NumPy,矩阵乘法平均需要 1.61 秒。

# runs on CPU - numpy
size = 5000
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)
# 1 loop, best of 5: 1.61 s per loop

在 CPU 上使用 JAX 执行相同的操作平均需要大约 3.49 秒。

# runs on CPU - JAX
size = 5000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
# 1 loop, best of 5: 3.49 s per loop

在 CPU 上运行时,JAX 通常比 NumPy 慢,因为 NumPy 已针对CPU进行了非常多的优化。但是,当使用加速器时这种情况会发生变化,所以让我们尝试使用 GPU 进行矩阵乘法。

# runs on GPU
size = 5000
x = random.normal(key, (size, size), dtype=jnp.float32)
%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time
%time jnp.dot(x_jax, x_jax.T).block_until_ready() # 2. measure JAX compilation time
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() # 3. measure JAX running time
# 1. CPU times: user 102 µs, sys: 42 µs, total: 144 µs
#   Wall time: 155 µs
# 2. CPU times: user 1.3 s, sys: 195 ms, total: 1.5 s
#   Wall time: 2.16 s
# 3. 10 loops, best of 5: 68.9 ms per loop

从示例中可以看出,要进行公平的基准比较,我们需要使用 JAX 测量不同的步骤:

设备传输时间:将矩阵传输到 GPU 所经过的时间。耗时 0.155 毫秒。
编译时间:JIT 编译经过的时间。耗时 2.16 秒。
运行时间:有效的代码运行时间。耗时 68.9 毫秒。

在 GPU 上使用 JAX 进行单个矩阵乘法的总耗时约为 2.23 秒,高于 NumPy 的总时间 1.61 秒。但是对于每个额外的矩阵乘法,JAX 只需要 68.9 毫秒,而 NumPy 需要 1.61 秒,快了 22 倍多!因此,如果多次执行线性代数运算,那么使用 JAX 是有意义的。

让我们测试使用 TPU 进行矩阵乘法。

# runs on TPU
size = 5000
x = random.normal(key, (size, size), dtype=jnp.float32)
%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time
%time jnp.dot(x_jax, x_jax.T).block_until_ready() # 2. measure JAX compilation time
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() # 3. measure JAX running time
# 1. CPU times: user 131 µs, sys: 72 µs, total: 203 µs
#   Wall time: 164 µs
# 2. CPU times: user 190 ms, sys: 302 ms, total: 492 ms
#   Wall time: 837 ms
# 3. 100 loops, best of 5: 16.5 ms per loop

忽略设备传输时间和编译时间,每个矩阵乘法平均需要 16.5 毫秒:GPU 相比快了4倍,与 CPU 的 NumPy相比快了88倍。需要说明的是,当乘以不同大小的矩阵时,获得相同的加速效果也不同:相乘的矩阵越大,GPU可以优化操作的越多,加速也越大。

为了在 Google Colab 上复制上述基准,需要运行以下代码让 JAX 知道有可用的 TPU。

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

让我们看看 XLA 编译器。

XLA

XLA 是 JAX(和其他库,例如 TensorFlow,TPU的Pytorch)使用的线性代数的编译器,它通过创建自定义优化内核来保证最快的在程序中运行线性代数运算。XLA 最大的好处是可以让我们在应用中自定义内核,该部分使用线性代数运算,以便它可以进行最多的优化。

XLA 最重要的优化是融合,即可以在同一个内核中进行多个线性代数运算,将中间输出保存到 GPU 寄存器中,而不将它们具体化到内存中。这可以显著增加我们的“计算强度”,即所做的工作量与负载和存储数量的比例。融合还可以让我们完全省略仅在内存中shuffle 的操作(例如reshape)。

下面我们看看如何使用 XLA 和 jax.jit 手动触发 JIT 编译。

使用 jax.jit 进行即时编译

这里有一些新的基准来测试 jax.jit 的性能。我们定义了两个实现 SELU(Scaled Exponential Linear Unit)的函数:一个使用 NumPy,一个使用 JAX。暂时先不考虑 jax.jitat

def selu_np(x, alpha=1.67, lmbda=1.05):
return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)def selu_jax(x, alpha=1.67, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

然后,我们使用 NumPy 在 1M 个元素的向量上运行它。

# runs on the CPU - numpy
x = np.random.normal(size=(1000000,)).astype(np.float32)
%timeit selu_np(x)
# 100 loops, best of 5: 7.6 ms per loop

平均需要 7.6 毫秒。现在让我们在 CPU 上使用 JAX。

# runs on the CPU - JAX
x = random.normal(key, (1000000,))
%time selu_jax(x).block_until_ready() # 1. measure JAX compilation time
%timeit selu_jax(x).block_until_ready() # 2. measure JAX runtime
# 1. CPU times: user 124 ms, sys: 5.01 ms, total: 129 ms
#   Wall time: 124 ms
# 2. 100 loops, best of 5: 4.8 ms per loop

现在平均需要 4.8 毫秒,在这种情况下比 NumPy 快。下一个测试是在 GPU 上使用 JAX。

# runs on the GPU
x = random.normal(key, (1000000,))
%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time
%time selu_jax(x_jax).block_until_ready() # 2. measure JAX compilation time
%timeit selu_jax(x_jax).block_until_ready() # 3. measure JAX runtime
# 1. CPU times: user 103 µs, sys: 0 ns, total: 103 µs
#   Wall time: 109 µs
# 2. CPU times: user 148 ms, sys: 9.09 ms, total: 157 ms
#   Wall time: 447 ms
# 3. 1000 loops, best of 5: 1.21 ms per loop

函数运行时间为1.21毫秒。下面我们用 jax.jit 测试它,触发 JIT 编译器使用 XLA 将 SELU 函数编译到优化的 GPU 内核中,同时优化函数内部的所有操作。

# runs on the GPU
x = random.normal(key, (1000000,))
selu_jax_jit = jit(selu_jax)
%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time
%time selu_jax_jit(x_jax).block_until_ready() # 2. measure JAX compilation time
%timeit selu_jax_jit(x_jax).block_until_ready() # 3. measure JAX runtime
# 1. CPU times: user 70 µs, sys: 28 µs, total: 98 µs
#   Wall time: 104 µs
# 2. CPU times: user 66.6 ms, sys: 1.18 ms, total: 67.8 ms
#   Wall time: 122 ms
# 3. 10000 loops, best of 5: 130 µs per loop

使用编译内核,函数运行时间为0.13毫秒!

让我们回顾一下不同的运行时间:

  • CPU 上的 NumPy:7.6 毫秒。

  • CPU 上的 JAX:4.8 毫秒(x1.58 加速)。

  • 没有 JIT 的 GPU 上的 JAX:1.21 毫秒(x6.28 加速)。

  • 带有 JIT 的 GPU 上的 JAX:0.13 毫秒(x58.46 加速)。

使用 JIT 编译避免从 GPU 寄存器中移动数据这样给我们带来了非常大的加速。一般来说在不同类型的内存之间移动数据与代码执行相比非常慢,因此在实际使用时应该尽量避免!

将 SELU 函数应用于不同大小的向量时,您可能会获得不同的结果。矢量越大,加速器越能优化操作,加速也越大。

除了执行 selu_jax_jit = jit(selu_jax) 之外,还可以使用 @jit 装饰器对函数进行 JIT 编译,如下所示。

@jit
def selu_jax_jit(x, alpha=1.67, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

JIT 编译可以加速,为什么我们不能全部都这样做呢?因为并非所有代码都可以 JIT 编译,JIT要求数组形状是静态的并且在编译时已知。另外就是引入jax.jit 也会带来一些开销。因此通常只有编译的函数比较复杂并且需要多次运行才能节省时间。但是这在机器学习中很常见,例如我们倾编译一个大而复杂的模型,然后运行它进行数百万次训练、损失函数和指标的计算。

使用 jax.grad 自动微分

另一个 JAX 转换是使用 jit.grad() 函数的自动微分。

借助 Autograd ,JAX 可以自动对原生 Python 和 NumPy 代码进行微分。并且支持 Python 的大部分特性,包括循环、if、递归和闭包。

下面看看一个带有 jit.grad() 的代码示例,我们计算一个自定义的包含 JAX 函数的Python 函数的导数。

def sum_logistic(x):
return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
# [0.25, 0.19661197, 0.10499357]

总结

在本文中,我们了解了 JAX 是什么,并了解了它的一些基本概念:NumPy 接口、JIT 编译、XLA、优化内核、程序转换、自动微分和函数式编程。在 JAX 之上,开源社区为机器学习构建了更多高级库,例如 Flax 和 Haiku。有兴趣的可以搜索查看。

编辑:黄继彦

猜您喜欢:

 戳我,查看GAN的系列专辑~!

一顿午饭外卖,成为CV视觉前沿弄潮儿!

CVPR 2022 | 25+方向、最新50篇GAN论文

 ICCV 2021 | 35个主题GAN论文汇总

超110篇!CVPR 2021最全GAN论文梳理

超100篇!CVPR 2020最全GAN论文梳理

拆解组新的GAN:解耦表征MixNMatch

StarGAN第2版:多域多样性图像生成

附下载 | 《可解释的机器学习》中文版

附下载 |《TensorFlow 2.0 深度学习算法实战》

附下载 |《计算机视觉中的数学方法》分享

《基于深度学习的表面缺陷检测方法综述》

《零样本图像分类综述: 十年进展》

《基于深度神经网络的少样本学习综述》

JAX介绍和快速入门示例相关推荐

  1. Castle Active Record for .NET2.0快速入门示例

    一.创建Web工程 创建一个Web站点或者Web应用程序,添加对Castle.ActiveRecord.dll的引用. 二.创建需要持久化的业务实体 在.NET2.0下,由于引入了泛型,创建业务实体比 ...

  2. TIA博途中变长数组的介绍与使用入门示例

    TIA博途中变长数组的介绍与使用入门示例 使用变长数组的注意事项: 接口参数Array[*] of -,实参使用指定类型数组: 可以使用的范围:FC的Input.InOut,FB的InOut: FC或 ...

  3. 第1讲韩顺平 PHP视频教程 zend framework zend介绍 zend快速入门 韩顺平php视频教程ppt 笔记心得

    韩顺平 PHP视频教程  zend framework  PPT笔记心得 本教程贯穿了两个软件公司常用的两个项目,数据采集系统和购物车,共分20讲从本质上深入浅出的分析了zend framework运 ...

  4. Helm基本介绍及快速入门

    文章目录 Helm基本介绍及快速入门 一.Helm基本介绍 Helm简介 Helm 相关组件及概念 二.Helm部署 Helm客户端安装 校验是否安装成功 三.Helm 使用 使用仓库(helm re ...

  5. c++做界面_Adobe Photoshop基本介绍②,快速入门界面

    Adobe Photoshop基本介绍②,快速入门界面.此篇文章意在讲解Adobe Photoshop界面的介绍.在本篇文章,up主会使用Adobe Photoshop CC 2018 作为讲解.(以 ...

  6. SpringCloud——Gateway(介绍、快速入门、网关集群)

    介绍 网关是微服务最边缘的服务,直接暴露给用户,用来做用户和微服务的桥梁. Gateway是Spring官方提供的用来代替zuul的网关组件 核心逻辑:路由转发 + 执行过滤器链 三大核心概念 Rou ...

  7. [JavaWeb-HTML]HTML概念介绍和快速入门

    HTML 1. 概念:是最基础的网页开发语言* Hyper Text Markup Language 超文本标记语言* 超文本:* 超文本是用超链接的方法,将各种不同空间的文字信息组织在一起的网状文本 ...

  8. HOJ 系统常用功能介绍 部署快速入门 c++ python java编程语言在线自动评测 信息奥赛一本通 USACO GESP 洛谷 蓝桥 CSP NOIP题库

    技术支持微 makytony 服务器配置需求 腾讯云 2H4G 5M 60GB 轻量应用服务器  承载大约 200~400人使用,经过压力测试,评测并发速度可满足130人左右的在线比赛. 系统镜像选 ...

  9. 【JAVA】Dozer 介绍及快速入门教程

    文章目录 概述 使用 安装 入门 XML 映射 注解映射 SpringBoot 集成 结语 概述 Dozer 是什么? Dozer 是 Java Bean 到 Java Bean 的映射器,他以递归的 ...

最新文章

  1. 动态代理(JDK的动态代理)
  2. Windows Thin PC中文化
  3. html不可选择的按钮,HTML功能无法使用按钮
  4. Facebook:苹果谷歌支持HTML5会死啊
  5. mysql 行列转换 动态_mysql 行列动态转换的实现(列联表,交叉表)
  6. 显示器接口VGA、DVI、HDMI、DP
  7. flask如何连接mysql数据库_flask连接mysql数据库
  8. spring核心:bean工厂的装配 3
  9. matlab ct投影数据,CT_projection_and_reconstruction
  10. 【图】二分图最大权匹配
  11. java panel边框_java – 如何在jPanel上设置边框?
  12. javassist组件分享利用javassist动态创建一个类
  13. ES6-模块导入导出
  14. .axf文件_「嵌入式笔记」hex文件、bin文件、axf文件的区别?
  15. 爱可生 mysql监控_actiontech-zabbix-mysql-monitor
  16. Oracle Comment添加表备注和列备注添加和查询comment on table or culumn
  17. java -- 百度API 接口使用
  18. 一个简单的拼音输入法,实现常用汉字的输入
  19. 复制宝贝到淘宝店铺,主图和标题不做修改,是否会被封店?
  20. 测试主管面试必问合集:get 与 post 的区别

热门文章

  1. close_wait状态和time_wait状态
  2. mysql的所有聚合函数_MySQL常用聚合函数详解
  3. jpa SqlQuery casewhen用法
  4. 以太网控制芯片W5300与W5100差异对比
  5. uC-OS2 V2.93 STM32L476 移植:系统移植篇
  6. 上海人工智能实验室面试题
  7. Java script事件详解
  8. Word文档最后一页页码与总页码不一致怎么解决?
  9. System.Runtime.InteropServices.COMException:“没有注册类 (异常来自 HRESULT:0x80040154 (REGDB_E_CLASSNOTREG))”
  10. 我的世界服务器被中断,我的世界说服务器正在维护,我的世界服务器出现这个问题怎么办?...