Google JAX
開發者 | |
---|---|
首次发布 | 2019年10月31日[1] |
当前版本 |
|
预览版本 | v0.3.13(2022年5月16日 | )
源代码库 | github |
编程语言 | Python, C++ |
操作系统 | Linux, macOS, Windows |
平台 | Python, NumPy |
类型 | 机器学习 |
许可协议 | Apache 2.0 |
网站 | jax |
Google JAX,是Google开发的用于变换数值函数的Python机器学习框架[3][4][5]。它结合了修改版本的Autograd(自动通过函数的微分获得其梯度函数)[6],和TensorFlow的XLA(加速线性代数)[7]。它被设计为尽可能的遵从NumPy的结构和工作流程,并协同工作于各种现存的框架如TensorFlow和PyTorch[8][9]。
主要功能
[编辑]JAX的主要功能是[3]:
- grad:自动微分,
- jit:即时编译,
- vmap:自动向量化,
- pmap:SPMD编程。
grad
[编辑]下面的代码演示grad
函数的自动微分。
# 导入库 from jax import grad import jax.numpy as jnp # 定义logistic函数 def logistic(x): return jnp.exp(x) / (jnp.exp(x) + 1) # 获得logistic函数的梯度函数 grad_logistic = grad(logistic) # 求值logistic函数在x = 1处的梯度 grad_log_out = grad_logistic(1.0) print(grad_log_out)
最终的输出为:
0.19661194
jit
[编辑]下面的代码演示jit
函数的优化。
# 导入库 from jax import jit import jax.numpy as jnp # 定义cube函数 def cube(x): return x * x * x # 生成数据 x = jnp.ones((10000, 10000)) # 创建cube函数的jit版本 jit_cube = jit(cube) # 应用cube函数和jit_cube函数于相同数据来比较其速度 cube(x) jit_cube(x)
可见jit_cube
的运行时间显著的短于cube
。
vmap
[编辑]下面的代码展示vmap
函数的通过SIMD的向量化。
# 导入库 from functools import partial from jax import vmap import jax.numpy as jnp # 定义函数 def grads(self, inputs): in_grad_partial = partial(self._net_grads, self._net_params) grad_vmap = vmap(in_grad_partial) rich_grads = grad_vmap(inputs) flat_grads = np.asarray(self._flatten_batch(rich_grads)) assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0] return flat_grads
pmap
[编辑]下面的代码展示pmap
函数的对矩阵乘法的并行化。
# 从JAX导入pmap和random;导入JAX NumPy from jax import pmap, random import jax.numpy as jnp # 生成2个维度为5000 x 6000的随机数矩阵,每设备一个 random_keys = random.split(random.PRNGKey(0), 2) matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys) # 没有数据传输,并行的在每个CPU/GPU上进行局部矩阵乘法 outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices) # 没有数据传输,并行的在每个CPU/GPU上分别求取这两个矩阵的均值 means = pmap(jnp.mean)(outputs) print(means)
最终的输出为:
[1.1566595 1.1805978]
使用JAX的库
[编辑]一些Python库使用JAX作为后端,这包括:
- Flax,最初由Google Brain开发的高层人工神经网络库[10]。
- Equinox,将参数化函数(包括人工神经网络)表示为PyTree的库。它由Patrick Kidger创建[11]。
- Diffrax,用于求微分方程的数值解的库,比如解常微分方程和随机微分方程[12]。
- Optax,DeepMind开发的用于梯度处理和最优化的库[13]。
- Lineax,用于解线性方程组和线性最小二乘法[14]。
- RLax,DeepMind开发的用于强化学习的库[15]
- jraph,DeepMind开发的图神经网络库[16]。
- jaxtyping,用于为阵列或张量的形状和数据类型增加类型标注的库[17]。
- NumPyro,概率编程库[18]。
- Brax,物理引擎[19]。
参见
[编辑]引用
[编辑]- ^ jax-v0.1.49.
- ^ https://github.com/google/jax/releases/tag/jax-v0.4.24.
- ^ 3.0 3.1 Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne, Skye; Zhang, Qiao, JAX: Autograd and XLA, Astrophysics Source Code Library (Google), 2022-06-18 [2022-06-18], Bibcode:2021ascl.soft11002B, (原始内容存档于2022-06-18)
- ^ Frostig, Roy; Johnson, Matthew James; Leary, Chris. Compiling machine learning programs via high-level tracing (PDF). MLsys. 2018-02-02: 1–3. (原始内容存档 (PDF)于2022-06-21).
- ^ Using JAX to accelerate our research. www.deepmind.com. [2022-06-18]. (原始内容存档于2022-06-18) (英语).
- ^ autograd. [2023-09-23]. (原始内容存档于2022-07-18).
- ^ XLA. [2023-09-23]. (原始内容存档于2022-09-01).
- ^ Lynley, Matthew. Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta. Business Insider. [2022-06-21]. (原始内容存档于2022-06-21) (美国英语).
- ^ Why is Google's JAX so popular?. Analytics India Magazine. 2022-04-25 [2022-06-18]. (原始内容存档于2022-06-18) (美国英语).
- ^ Flax: A neural network library and ecosystem for JAX designed for flexibility, Google, 2022-07-29 [2022-07-29], (原始内容存档于2022-09-03)
- ^ Kidger, Patrick, Equinox, 2022-07-29 [2022-07-29], (原始内容存档于2023-09-19)
- ^ Kidger, Patrick, Diffrax, 2023-08-05 [2023-08-08], (原始内容存档于2023-08-10)
- ^ Optax, DeepMind, 2022-07-28 [2022-07-29], (原始内容存档于2023-06-07)
- ^ Lineax, Google, 2023-08-08 [2023-08-08], (原始内容存档于2023-08-10)
- ^ RLax, DeepMind, 2022-07-29 [2022-07-29], (原始内容存档于2023-04-26)
- ^ Jraph - A library for graph neural networks in jax., DeepMind, 2023-08-08 [2023-08-08], (原始内容存档于2022-11-23)
- ^ jaxtyping, Google, 2023-08-08 [2023-08-08], (原始内容存档于2023-08-10)
- ^ NumPyro - Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU. [2022-08-31]. (原始内容存档于2022-08-31).
- ^ Brax - Massively parallel rigidbody physics simulation on accelerator hardware. [2022-08-31]. (原始内容存档于2022-08-31).
外部链接
[编辑]- Documentationː jax
.readthedocs .io - Colab (Jupyter/iPython) Quickstart Guideː colab
.research .google .com /github /google /jax /blob /main /docs /notebooks /quickstart .ipynb - TensorFlow's XLAː www
.tensorflow .org /xla (Accelerated Linear Algebra) - YouTube上的Intro to JAX: Accelerating Machine Learning research
- Original paperː mlsys
.org /Conferences /doc /2018 /146 .pdf