最新的
原文:https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

JAX快速入门


首先解答一个问题:JAX是什么?

简单的说就是GPU加速、支持自动微分(autodiff)的numpy。众所周知,numpy是Python下的基础数值运算库,得到广泛应用。用Python搞科学计算或机器学习,没人离得开它。但是numpy不支持GPU或其他硬件加速器,也没有对backpropagation的内置支持,再加上Python本身的速度限制,所以很少有人会在生产环境下直接用numpy训练或部署深度学习模型。这也是为什么会出现Theano, TensorFlow, Caffe等深度学习框架的原因。但是numpy有其独特的优势:底层、灵活、调试方便、API稳定且为大家所熟悉(与MATLAB一脉相承),深受研究者的青睐。JAX的主要出发点就是将numpy的以上优势与硬件加速结合。现在已经开源的JAX ( https://github.com/google/jax) 就是通过GPU (CUDA)来实现硬件加速。出自:https://www.zhihu.com/question/306496943/answer/557876584

小宋说:JAX 其实就是一个支持加速器(GPU 和 TPU)的科学计算库(numpy, scipy)和神经网络库(提供relu,sigmoid, conv 等),相较于PyTorch与TensorFlow更加灵活,通用性更佳。这也是笔者推荐学习和做这个翻译工作的原因,带着大家一起去学习掌握这个框架。

由于笔者非英语专业,有些内荣难免翻译有误,欢迎大家批评指正。对于有些笔者不确定的翻译,采用下划线加括号引用原词的方式来补充,例如:自动微分differentiation

官方定义:JAX是CPU,GPU和TPU上的NumPy,具有出色的自动差分differentiation),可用于高性能机器学习研究。

作为更新版本的Autograd,JAX可以自动微分本机Python和NumPy代码。它可以通过Python的大部分功能(包括循环,if,递归和闭包)进行微分,甚至可以采用派生类的派生类。它支持反向模式和正向模式微分,并且两者可以任意顺序组成。

新功能是JAX使用 XLA 在诸如GPU和TPU的加速器上编译和运行您的NumPy代码。默认情况下,编译是在后台进行的,而库调用将得到及时的编译和执行。但是,JAX甚至允许您使用单功能API即时将自己的Python函数编译为XLA优化的内核。编译和自动微分可以任意组合,因此您无需离开Python即可表达复杂的算法并获得最佳性能。

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

乘法矩阵

在以下示例中,我们将生成随机数据。NumPy和JAX之间的一大区别是生成随机数的方式。有关更多详细信息,请参见JAX中的Common Gotchas。

key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

[-0.372111    0.2642311  -0.18252774 -0.7368198  -0.44030386 -0.15214427-0.6713536  -0.59086424  0.73168874  0.56730247]

乘以两个大矩阵。

size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

489 ms ± 3.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

我们补充说,block_until_ready因为默认情况下JAX使用异步执行(请参见异步调度)。

JAX NumPy函数可在常规NumPy数组上使用。

import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

488 ms ± 942 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

这样比较慢,因为它每次都必须将数据传输到GPU。您可以使用来确保NDArray由设备内存支持device_put()

from jax import device_putx = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

487 ms ± 9.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

的输出device_put()仍然像NDArray一样,但是它仅在需要打印,绘图,保存(printing, plotting, saving)到磁盘,分支等需要它们的值时才将值复制回CPU。的行为device_put()等效于函数,但是速度更快。jit(lambda x: x)

如果您有GPU(或TPU!),这些调用将在加速器上运行,并且可能比在CPU上快得多。

x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)

235 ms ± 546 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

JAX不仅仅是一个由GPU支持的NumPy。它还带有一些程序转换,这些转换在编写数字代码时很有用。目前,主要有三个:

  • jit(),以加快您的代码

  • grad(),用于求梯度(derivatives)

  • vmap(),用于自动矢量化或批处理。

让我们一一介绍。我们还将最终以有趣的方式编写这些内容。

利用jit()加快功能

JAX在GPU上透明运行(如果没有,则在CPU上运行,而TPU即将推出!)。但是,在上面的示例中,JAX一次将内核分配给GPU一次操作。如果我们有一系列操作,则可以使用@jit装饰器使用XLA一起编译多个操作。让我们尝试一下。

def selu(x, alpha=1.67, lmbda=1.05):return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

4.4 ms ± 107 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

我们可以使用加快速度@jit,它将在第一次selu调用jit-compile并将其之后缓存。

selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

860 µs ± 27.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

通过 grad()计算梯度

除了评估数值函数外,我们还希望对其进行转换。一种转变是自动微分。在JAX中,就像在Autograd中一样,您可以使用grad()函数来计算梯度。

def sum_logistic(x):return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]

让我们以极限微分(finite differences)验证我们的结果是正确的。

def first_finite_differences(f, x):eps = 1e-3return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)for v in jnp.eye(len(x))])print(first_finite_differences(sum_logistic, x_small))

[0.24998187 0.1964569  0.10502338]

求解梯度可以通过简单调用grad()grad()jit()可以任意混合。在上面的示例中,我们先抖动sum_logistic然后取其派生词。我们继续深入学习实验:

print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

-0.035325594

对于更高级的autodiff,可以将其jax.vjp()用于反向模式矢量雅各比积和jax.jvp()正向模式雅可比矢量积。两者可以彼此任意组合,也可以与其他JAX转换任意组合。这是组合它们以构成有效计算完整的Hessian矩阵的函数的一种方法:

from jax import jacfwd, jacrev
def hessian(fun):return jit(jacfwd(jacrev(fun)))

自动向量化 vmap()

JAX在其API中还有另一种转换,您可能会发现它有用:vmap()向量化映射。它具有沿数组轴映射函数的熟悉语义( familiar semantics),但不是将循环保留在外部,而是将循环推入函数的原始操作中以提高性能。当与组合时jit(),它的速度可以与手动添加批处理尺寸一样快。

我们将使用一个简单的示例,并使用将矩阵向量乘积提升为矩阵矩阵乘积vmap()。尽管在这种特定情况下很容易手动完成此操作,但是相同的技术可以应用于更复杂的功能。

mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))def apply_matrix(v):return jnp.dot(mat, v)

给定诸如之类的功能apply_matrix,我们可以在Python中循环执行批处理维度,但是这样做的性能通常很差。

def naively_batched_apply_matrix(v_batched):return jnp.stack([apply_matrix(v) for v in v_batched])print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched

4.43 ms ± 9.91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

我们知道如何手动批处理此操作。在这种情况下,jnp.dot透明地处理额外的批次尺寸。

@jit
def batched_apply_matrix(v_batched):return jnp.dot(v_batched, mat.T)print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched

51.9 µs ± 1.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

但是,假设没有批处理支持,我们的功能更加复杂。我们可以用来vmap()自动添加批处理支持。

@jit
def vmap_batched_apply_matrix(v_batched):return vmap(apply_matrix)(v_batched)print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap

79.7 µs ± 249 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

当然,vmap()可以与任意组成jit()grad()和任何其它JAX变换。

这只是JAX可以做的事情。我们很高兴看到您的操作!

『JAX中文文档』JAX快速入门相关推荐

  1. 以太坊智能合约开发,Web3.js API 中文文档 ethereum web3.js入门说明

    以太坊智能合约开发,Web3.js API 中文文档 ethereum web3.js入门说明 为了让你的Ðapp运行上以太坊,一种选择是使用web3.js library提供的web3.对象.底层实 ...

  2. keras中文文档学习笔记—快速上手keras

    keras的核心数据结构是"model",其中最主要的是Sequential模型: Sequential模型调用 from keras.model import Sequentia ...

  3. php sequelize,Sequelize 中文文档 v4 - Getting started - 入门

    Getting started - 入门 此系列文章的应用示例已发布于 GitHub: sequelize-docs-Zh-CN. 可以 Fork 帮助改进或 Star 关注更新. 欢迎 Star. ...

  4. 最新 | Python 官方中文文档正式发布!

    点击上方"AI有道",选择"置顶"公众号 重磅干货,第一时间送达 千呼万唤始出来!Python 官方文档终于发布中文版了!受英语困扰的小伙伴终于可以更轻松地阅读 ...

  5. web3.js 中文文档 入门

    web3.js 中文文档 v1.3.4 入门(Getting Started) web3.js是包含以太坊生态系统功能的模块集合. web3-eth用于以太坊区块链和智能合约. web3-shh是针对 ...

  6. Hugo中文文档 快速开始

    Hugo中文文档 快速开始 安装Hugo 1. 二进制安装(推荐:简单.快速) 到 Hugo Releases 下载对应的操作系统版本的Hugo二进制文件(hugo或者hugo.exe) Mac下直接 ...

  7. Bootstrap 一篇就够 快速入门使用(中文文档)

    目录 一.Bootstrap 简介 什么是 Bootstrap? 历史 为什么使用 Bootstrap? Bootstrap 包的内容 在线实例 Bootstrap 实例 更多实例 Bootstrap ...

  8. Babel 是什么?· Babel 中文文档

    Babel 是一个 JavaScript 编译器 Babel 是一个工具链,主要用于将 ECMAScript 2015+ 版本的代码转换为向后兼容的 JavaScript 语法,以便能够运行在当前和旧 ...

  9. springboot中文文档_登顶 Github 的 Spring Boot 仓库!艿艿写的最肝系列

    源码精品专栏 中文详细注释的开源项目 RPC 框架 Dubbo 源码解析 网络应用框架 Netty 源码解析 消息中间件 RocketMQ 源码解析 数据库中间件 Sharding-JDBC 和 My ...

最新文章

  1. linux 隐藏权限,Linux权限位,s权限,t权限,及隐藏权限
  2. Net 消息中间件 知识小结
  3. Interview:算法岗位面试—10.30上午上海某信息公司(偏图算法)技术面试之单链表反转、给定整型数组和目标值 二分法查找+下午上海某金融公司(AI岗位,上市)CTO和主管技术面试之Xcepti
  4. matlab中predictor怎么填,在MATLAB中求解非線性有限元
  5. boost::container模块实现普通容器的程序
  6. HDU 1525 Euclid's Game
  7. 如何使用系统自带的日志转储功能logroate.存放应用日志
  8. java 什么时候用递归_如果要用Java实现算法,一定慎用递归
  9. 数据结构之栈实现中缀转后缀并计算结果
  10. 持续集成及部署利器:Go
  11. 如何在Apache官网下载Apache shiro
  12. 【全网最强C语言学习】C语言入门(工具)——库函数字典MSDN
  13. C语言CGI编程入门(一)
  14. python爬虫百度网盘_python爬取百度云网盘资源
  15. 一句话,读懂首席架构师、CTO和技术总监的区别
  16. python 气泡图 聚类_R可视化 | 气泡图
  17. Linux设置密码dictionary,Linux中修改密码出现it is based on a dictionary word解决方法
  18. 5GNR漫谈14:TM一致性测试
  19. 计算机工程与设计 北大核心,计算机工程与设计 统计源期刊北大核心期刊
  20. 博文推荐|深度解析如何在 Pulsar 中实现隔离

热门文章

  1. css叠层_css z-index层重叠顺序
  2. UE4 添加多人联机功能
  3. warmp启动图标黄色
  4. Web实时通信Socket.IO兼容浏览器版本IE7IE8IE9IE10
  5. 亚马逊防关联的三不原则|亚马逊电商--店铺运营防关联
  6. https证书中的泛域名和多域名通配符
  7. RabbitMQ订单超时(面试问答)
  8. c++ vector的内存释放
  9. 结构体PLUS(计算结构体类型大小及位段)~~
  10. Kafka实现消息生产和消费