Roy Frostig - JAX: accelerating machine learning research by composing function transformations in Python
JAX is a system for high-performance machine learning research and numerical computing. It offers the familiarity of Python+NumPy together with hardware acceleration, plus a set of composable function transformations: automatic differentiation, automatic batching, end-to-end compilation (via XLA), parallelization across multiple accelerators, and more. JAX's core strength is its guarantee that these user-wielded transformations can be composed arbitrarily, so that programmers can write math (e.g. a loss function) and transform it into pieces of an ML program (e.g. a vectorized, compiled, batch gradient function for that loss).
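As a minimal sketch of that composition, assuming a recent JAX release and a hypothetical squared-error loss for a linear model (neither is from the talk), nesting `jax.grad`, `jax.vmap`, and `jax.jit` turns a scalar loss into a compiled function that returns one gradient per example in a batch:

```python
import jax
import jax.numpy as jnp

# Hypothetical loss for illustration: squared error of a linear model
# on a single example (x, y) with weights w.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# grad: differentiate loss with respect to its first argument, w.
# vmap: map over a batch of (x, y) pairs while sharing w (in_axes=None).
# jit:  compile the composed function end to end with XLA.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

g = per_example_grads(w, xs, ys)  # one gradient row per example
print(g.shape)  # (3, 2)
```

Each row is 2 * (x·w - y) * x for the corresponding example; the loss itself is written for a single example, and batching and compilation are layered on as transformations rather than rewritten by hand.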
JAX had its open-source release in December 2018 (https://github.com/google/jax). Researchers use it for a wide range of applications, from studying the training dynamics of neural networks to probabilistic programming to scientific applications in physics and biology.