文 / David Budden 与 Matteo Hessel

DeepMind 工程师通过构建工具、对算法进行拓展和创造具有挑战性的虚拟和物理环境来训练和测试人工智能 (AI) 系统,加速我们的研究。作为这项工作的一部分,我们在持续评估机器学习新的库和框架。

近来,我们发现由 Google Research 团队开发的机器学习框架 JAX 为越来越多的项目提供良好支持。JAX 与我们的工程理念产生了很好的共鸣,并在去年被我们的研究社区广泛使用。本文将分享我们的 JAX 使用经验,来说明我们认为它有助于我们 AI 研究的原因,并概述我们正在为支持各地研究人员而建立的生态系统。

  • Google Research
    https://research.google/

  • JAX
    https://github.com/google/jax#jax-autograd-and-xla-

为什么选择 JAX?

JAX 是为高性能数字计算(尤其是机器学习研究)而设计的 Python 库。其用于数值计算的 API 基于 NumPy 这样一个用于科学计算的函数库所构建。得益于 Python 和 NumPy 较高的使用率和知名度,使得 JAX 简洁灵活、易于使用。

  • NumPy
    https://www.nature.com/articles/s41586-020-2649-2

除了其 NumPy API 之外,JAX 还具有一个用于可组合函数的转换的扩展系统,在以下几方面帮助机器学习研究:

  • 微分:梯度优化是 ML 的基础。通过 grad、hessian、jacfwd 和 jacrev 等方法实现了函数转换,JAX 为任意数值函数的正向和反向自动微分提供了原生支持。

  • 自动微分
    https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html

  • 向量化:在 ML 研究中,我们经常将一个函数应用于大量数据中,例如计算一个批次数据的损失,或在微分独立学习时评估每个样本的梯度。JAX 通过 vmap 转换实现自动向量化,简化了这种形式的编程。又例如,研究人员在实现新算法时,无需推理批处理。JAX 还提供相关 pmap 转换来支持大规模数据并行,在数据过大时精妙地分配单个加速器内存。

  • 评估每个样本的梯度
    https://arxiv.org/abs/2010.09063

  • JIT 编译:XLA 被用于在 GPU 和 Cloud TPU 加速器上进行及时 (JIT) 编译和执行 JAX 程序。JIT 编译结合 JAX 中与 NumPy 一致的 API,使没有高性能计算经验的研究人员也可以轻松扩展研究至一个或多个加速器上。

  • XLA
    https://tensorflow.google.cn/xla

  • Cloud TPU
    https://cloud.google.com/tpu

我们发现,JAX 帮助新型算法和架构的研究进行快速实验,为近期发表的多篇论文奠定了基础。要了解详情,请参考我们在 NeurIPS 虚拟大会上举办的 JAX 圆桌会议。

  • NeurIPS
    https://neurips.cc/

DeepMind 中的 JAX

对前沿 AI 研究的支持意味着能在快速原型设计与快速迭代间保持平衡的同时,兼顾在传统生产环境中成规模部署的能力。而这一切带来挑战的原因为研究领域发展十分迅速且难以预测。往往一项新的研究突破能在任意时刻改变整个领域发展的方向与需求。在这种瞬息万变的环境中,我们工程团队的核心使命便是确保在研究项目中可以有效复用现有的经验与代码。

一种成熟的方法是模块化:我们将每个研究项目中开发的最重要和最关键的代码块提取至经过测试且高效的组件中。这使得研究人员能够专注研究的同时受益于我们的核心库所实现的算法部分的代码重用、错误修复和性能提升。我们还发现,应该确保每个库都有明确定义的范围,并确保库之间在能够互相调用的同时保证相互独立。增量更新,即使用版本特性时不会受制于其余部分,对于为研究人员提供最大的灵活性并持续支持其选择正确的工作工具至关重要。

JAX 生态系统开发中的其他考虑因素包括确保其与现有 TensorFlow 库(如 Sonnet 和 TRFL)的设计(尽可能)保持一致。我们还构建了(在相关时)尽可能接近其基础数学的组件,以实现自我描述,并最大程度地减少“从纸面到代码”的思维跳转。最后,我们选择将我们的库开源,以促进分享研究成果,并鼓励更广泛的社区探索 JAX 生态系统。

  • TensorFlow 库
    https://tensorflow.google.cn/guide

  • Sonnet
    https://deepmind.com/blog/article/open-sourcing-sonnet

  • TRFL
    https://deepmind.com/blog/article/trfl

最后,我们选择将我们的库开源,以促进分享研究成果,并鼓励更广泛的社区探索 JAX 生态系统。

  • 开源
    https://github.com/deepmind

当今生态系统

Haiku 

可组合函数转换的 JAX 编程模型可能会使对有状态对象的处理复杂化,例如具有可训练参数的神经网络。Haiku 神经网络库允许用户使用常见的面向对象的编程模型,同时利用强劲而便利的 JAX 纯功能范式。

Haiku 的活跃用户包括 DeepMind 和 Google 的数百名研究员,Haiku 也已在多个外部项目(如 Coax、DeepChem、NumPyro)中得到采用。它以 Sonnet 的 API 为基础。Sonnet 是我们在 TensorFlow 中基于模块的神经网络编程模型,我们希望尽可能简化从 Sonnet 到 Haiku 的移植。

  • Sonnet
    https://github.com/deepmind/sonnet

在 GitHub 上了解更多信息。

  • https://github.com/deepmind/dm-haiku

Optax 

梯度优化是 ML 的基础。Optax 提供了梯度转换库以及允许在单行代码中实现许多标准优化器(例如 RMSProp 或 Adam)的合成算子(例如链)。

Optax 的合成性质自然支持在自定义优化器中重组相同的基本成分。此外,它还提供了许多用于随机梯度估算和二阶优化的实用工具。

许多 Optax 用户已经采用 Haiku,但根据我们的增量购买理念,任何以 JAX 树结构表示参数的库都可获得支持(例如 Elegy、Flax 和 Stax)。请在此处查看关于这一丰富多样的 JAX 库生态系统的更多信息。

  • 此处
    https://github.com/google/jax#neural-network-libraries

在 GitHub 上了解更多信息。

  • https://github.com/deepmind/optax

RLax

我们许多最成功的项目都位于深度学习与强化学习 (RL) 的交汇处,也就是深度强化学习。RLax 库为构建 RL 代理提供了实用的构建块。

  • 深度强化学习
    https://deepmind.com/blog/article/deep-reinforcement-learning

RLax 中的组件涵盖了广泛的算法和概念:TD 学习、政策梯度、actor-critic、MAP、近端政策优化、非线性价值转换、一般价值函数和许多探索方法。

虽然提供了一些介绍性的示例代理,但 RLax 并不是用于构建和部署完整 RL 代理系统的框架。Acme 是基于 RLax 组件构建的全功能代理框架示例。

  • 示例代理
    https://github.com/deepmind/rlax/tree/master/examples

  • Acme
    https://deepmind.com/research/publications/Acme

在 GitHub 上了解更多信息。

  • https://github.com/deepmind/rlax

Chex

测试对于软件可靠性至关重要,研究代码也不例外。只有保证研究代码正确,才能从研究实验中得出科学结论。Chex 测试实用工具集合可支持库作者验证通用构建块是否正确耐用,还可支持最终用户检查其实验代码。

Chex 提供了多种实用工具,包括 JAX 感知单元测试、JAX 数据类型的属性断言、mock 和 fake 以及多设备测试环境。Chex 广泛用于 DeepMind 的整个 JAX 生态系统以及 Coax 和 MineRL 等外部项目。

  • Coax
    https://github.com/microsoft/coax

  • MineRL
    https://github.com/dzorlu/minerl

在 GitHub 上了解更多信息。

  • https://github.com/deepmind/chex

Jraph

图神经网络 (GNN) 是一个激动人心的研究领域,包括许多大有前途的应用。例如,我们最近在 Google 地图中的交通预测工作和物理模拟方面的工作。Jraph(发音同“giraffe”)是一个轻量级库,支持在 JAX 中使用 GNN。

  • 交通预测
    https://deepmind.com/blog/article/traffic-prediction-with-advanced-graph-neural-networks

  • 物理模拟
    https://www.youtube.com/watch?v=2Bw5f4vYL98

Jraph 提供了标准化的图数据结构,用于处理图的一组实用程序,以及易于分叉和可扩展的图神经网络模型的“zoo”。包括其他关键特性:有效利用硬件加速器的 GraphTuples 批处理,通过填充和遮蔽对可变形图的 JIT 编译支持,以及在输入分区上定义的损失。与 Optax 和我们的其他库一样,Jraph 对用户的神经网络库选择没有任何限制。

从我们丰富的示例中详细了解如何使用库。

  • 示例
    https://github.com/deepmind/jraph/tree/master/jraph/examples

在 GitHub 上了解更多信息。

  • https://github.com/deepmind/jraph

我们的 JAX 生态系统正在不断发展,我们希望 ML 研究社区能够探索我们的库和 JAX 的潜力,从而加速自己的研究。

  • 我们的库
    https://deepmind.com/research?filters=%7B%22collection%22:%5B%22OpenSource%22%5D%7D

引用 DeepMind JAX 生态系统

如果您发现 DeepMind JAX 生态系统有助于您的工作,请使用此引用(托管在 GitHub 上)。

  • 此引用
    https://github.com/deepmind/jax/blob/main/deepmind2020jax.txt

更多 AI 相关阅读:

  • 步履不停:TensorFlow 2.4 新功能一览!

  • URL2Video 流水线:网页端自动创作视频的实现

  • 新一代端侧声音过滤方案:VoiceFilter-Lite

  • 通过 Performer 架构再探注意力机制

  • 衡量预训练 NLP 模型中的性别相关性

 点击屏末 | 阅读原文 | 探索 JAX 库

不断发展的 JAX:加速 AI 研究的利器相关推荐

  1. Facebook开源Torchnet,加速AI研究

    OpenStack Days China将于7月14-15日在北京国家会议中心举办,届时包括OpenStack基金会的Jonathan Bryce.Mark Collier.Alan Clark等大牛 ...

  2. 专访普林斯顿大学贡三元教授:做 AI 研究要有价值观,数学更是「制胜法宝」

    https://mp.weixin.qq.com/s?__biz=MzI5NTIxNTg0OA==&mid=2247495153&idx=1&sn=71d58ac0b3dc50 ...

  3. Yann LeCun:未来几十年AI研究的最大挑战是「预测世界模型」

    来源:机器之心 本文约4000字,建议阅读8分钟 本文为你介绍一种叫做分层 JEPA(联合嵌入预测架构)的架构. LeCun 认为,构造自主 AI 需要预测世界模型,而世界模型必须能够执行多模态预测, ...

  4. 图灵奖获得者Yann LeCun:未来几十年AI研究的最大挑战是「预测世界模型」

    来源:机器之心 LeCun 认为,构造自主 AI 需要预测世界模型,而世界模型必须能够执行多模态预测,对应的解决方案是一种叫做分层 JEPA(联合嵌入预测架构)的架构.该架构可以通过堆叠的方式进行更抽 ...

  5. IEEE Fellow、AI大牛田奇加入华为云!他为何而来?“加速AI基础研究落地”

    金磊 发自 凹非寺 量子位 报道 | 公众号 QbitAI 云+AI正在驱动的数字化.智能化变革趋势,现如今与新基建相辅相成,早已深入人心. 作为面向企业.行业和产业的基础变革,华为云入局时,最被外界 ...

  6. Github1.3万星,迅猛发展的JAX对比TensorFlow、PyTorch

    点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:机器之心 AI博士笔记系列推荐 周志华<机器学习>手推笔记正式开源!可打印版本附pdf下载链接 J ...

  7. Github 1.3万星,迅猛发展的JAX对比TensorFlow、PyTorch

    ©作者 | 机器之心编辑部 来源 | 机器之心 在机器学习领域,大家可能对 TensorFlow 和 PyTorch 已经耳熟能详,但除了这两个框架,一些新生力量也不容小觑,它就是谷歌推出的 JAX. ...

  8. 助AI研究社群发出内建18种预先训练模型工具

    社群发布协助重现机器学习研究的工具PyTorch Hub,目前为测试版,透过简单的API和工作流程,提供开发者基本的模型,来重现机器学习相关的研究,PyTorch Hub包含多种预先训练的模型repo ...

  9. 聚焦场景 共建生态 加速AI落地——2018中国人工智能应用与生态峰会成功举办

    上千规模的参会人员,上百位的加盟AI专家,上百个的AI产品与应用现场展示,众多直播平台的现场直播,精准的"场景应用 智能平台"会议主题,十几位大咖的干货分享,精彩的观点碰撞,注定了 ...

最新文章

  1. 深度丨AI界的七大未解之谜:OpenAI丢出一组AI研究课题
  2. OpenAI详细解析:攻击者是如何使用「对抗样本」攻击机器学习的
  3. EXCEL自定义的应用
  4. sql语句技巧,不敢独享,特此呈上
  5. POJ 1723 Soldiers (中位数)
  6. jboss eap 7.0_创建委托登录模块(用于JBoss EAP 6.1)
  7. mysql:多表查询方式
  8. 如何更新google chrome浏览器
  9. macpro如何清理磁盘空间_在MacBook上,释放磁盘空间的7种方法
  10. 2021-03-23美团面试
  11. HUAWEI 机试题:最长元音字串的长度
  12. 关于CISC和RISC的一些总结
  13. 一种云化busybox demolets的设想和一种根本降低编程实践难度的设想:免部署无语法编程
  14. 微信小程序-组件样式覆盖
  15. Rosalind第五题:计算GC内容
  16. FutureTask源码解析
  17. 计算机理论导引 试卷,计算机理论导引实验————ADFA的可判定性
  18. linux磁盘无法识别移动硬盘
  19. 那些让你惊掉下巴到肚皮上的python冷知识(二)
  20. Python爬虫爬取链家网上的房源信息练习

热门文章

  1. 部署Faster-RCNN TensorFlow版本
  2. C#中ref、out类型参数的区别和params类型参数的用法
  3. 【量化交易】资产配置决策
  4. 月薪3万的一道面试题---看看你的IQ
  5. 使用开源代码搭建资产管理系统
  6. 苏轼与江西交通的不解之缘
  7. Python中imread()函数
  8. Java 单向链表和单向循环链表的代码实现
  9. 水壶问题解法和原理解析
  10. JAVA数据结构与算法之斐波那契查找(黄金分割点)