来源:机器之心 本文约4600字,建议阅读10+分钟
你有在使用JAX吗?
近年来,谷歌于 2018 年推出的 JAX 迎来了迅猛发展,很多研究者对其寄予厚望,希望它可以取代 TensorFlow 等众多深度学习框架。但 JAX 是否真的适合所有人使用呢?这篇文章对 JAX 的方方面面展开了深入探讨,希望可以给研究者选择深度学习框架时提供有益的参考。
即时编译(Just-in-Time Compilation) 自动并行化(Automatic Parallelization) 自动向量化(Automatic Vectorization) 自动微分(Automatic Differentiation)
NumPy 加速器。NumPy 是使用 Python 进行科学计算的基础包之一,但它仅与 CPU 兼容。JAX 提供了 NumPy 的实现(具有几乎相同的 API),可以非常轻松地在 GPU 和 TPU 上运行。对于许多用户而言,仅此一项功能就足以证明使用 JAX 的合理性; XLA。XLA(Accelerated Linear Algebra)是专为线性代数设计的全程序优化编译器。JAX 建立在 XLA 之上,显著提高了计算速度上限; JIT。JAX 允许用户使用 XLA 将自己的函数转换为即时编译(JIT)版本。这意味着可以通过在计算函数中添加一个简单的函数装饰器(decorator)来将计算速度提高几个数量级; Auto-differentiation。JAX 将 Autograd(自动区分原生 Python 代码和 NumPy 代码)和 XLA 结合在一起,它的自动微分能力在科学计算的许多领域都至关重要。JAX 提供了几个强大的自动微分工具; 深度学习。虽然 JAX 本身不是深度学习框架,但它的确为深度学习提供了一个很好的基础。很多构建在 JAX 之上的库旨在提供深度学习功能,包括 Flax、Haiku 和 Elegy。甚至在最近的一些 PyTorch 与 TensorFlow 文章中强调了 JAX 作为一个值得关注的「框架」,并推荐其用于基于 TPU 的深度学习研究。JAX 对 Hessians 的高效计算也与深度学习相关,因为它们使高阶优化技术更加可行; 通用可微分编程范式(General Differentiable Programming Paradigm )。虽然我们可以使用 JAX 来构建和训练深度学习模型,但它也为通用可微编程提供了一个框架。这意味着 JAX 可以通过使用基于模型的机器学习方法来解决问题,从而可以利用数十年研究建立起的给定领域的先验知识。
Grad() 进行自动微分; Vmap() 自动向量化; Pmap() 并行化计算; Jit() 将函数转换为即时编译版本。
JAX 仍然被官方认为是一个实验性框架。JAX 是一个相对「年轻」的项目。目前,JAX 仍被视为一个研究项目,而不是成熟的谷歌产品,因此如果用户正在考虑迁移到 JAX,请记住这一点; 使用 JAX 一定要勤勉。调试的时间成本,或者更严重的是,未跟踪副作用(untracked side effects)的风险可能导致那些没有扎实掌握函数式编程的用户不适用 JAX。在开始将它用于正式项目之前,请确保自己了解使用 JAX 的常见缺陷; JAX 没有针对 CPU 计算进行优化。鉴于 JAX 是以「加速器优先」的方式开发的,因此每个操作的分派并未针对 JAX 进行完全优化。在某些情况下,NumPy 实际上可能比 JAX 更快,尤其是对于小型程序而言,这是因为 JAX 引入了开销; JAX 与 Windows 不兼容。目前在 Windows 上不支持 JAX。如果用户使用 Windows 系统但仍想尝试 JAX,可以使用 Colab 或将其安装在虚拟机(VM)上。