JAX is an emerging player in the Machine Learning (ML) realm, offering a new approach to programming that is intuitive, structured, and clean. Although it is fundamentally different, it has the potential to replace established frameworks like TensorFlow and PyTorch.
As one of my friends put it, we used to have all the Aces, Kings, and Queens, but now we have JAX.
This article delves into the world of JAX, exploring what it is and why it stands out among other libraries. Using code snippets, we showcase the power of JAX and highlight some key features.
If you’re curious, let’s dive in.
What is JAX?
JAX is a Python library designed for high-performance ML research. At its core, it is a numerical computing library, similar to NumPy, but with some key enhancements. It was developed by Google and is used internally both by Google and DeepMind teams.
Source: JAX documentation
Install JAX
Before delving into the main advantages of JAX, it’s recommended to install it in your Python environment or Google Colab to follow along and run the code yourself. The full code will be provided at the end of the article.
To install JAX, use pip from the command line:
$ pip install --upgrade jax jaxlib
If you want to support GPU, you’ll need CUDA and cuDNN installed, followed by running the command:
$ pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
For troubleshooting, refer to the official GitHub instructions.
Now let’s import JAX alongside Numpy. Numpy will be used to compare different use cases.
import jax
import jax.numpy as jnp
import numpy as np
JAX basics
JAX’s primary focus is on performing numerical operations expressively and with high performance, similar to Numpy. However, the key difference lies behind the scenes.
The DeviceArray
JAX enables running the same program on hardware accelerators like GPUs and TPUs seamlessly. This is achieved through DeviceArray, which replaces Numpy’s standard array, keeping values in the accelerator and fetching them only when needed.
We can utilize DeviceArrays similar to standard arrays, passing them to other libraries, plotting graphs, differentiating, and more. Moreover, JAX supports the majority of Numpy’s API, resulting in almost identical code between JAX and Numpy.
JAX also excels in speed. To demonstrate this, let's create a (1000, 1000) array with NumPy, convert it into a JAX array, and compute the dot product of each array with itself.
Let’s time these operations:
x = np.random.rand(1000,1000)
y = jnp.array(x)
%timeit -n 1 -r 1 np.dot(x,x)
%timeit -n 1 -r 1 jnp.dot(y,y).block_until_ready()
Impressive, isn’t it? The speed improvement is significant, especially on GPUs. Note the call to block_until_ready(): because JAX dispatches operations asynchronously, we have to wait for the computation to finish in order to measure its time accurately.
But wait, there’s more to JAX…
Why JAX?
If speed and automatic GPU support aren’t convincing enough, let’s dive deeper. JAX can be viewed as a set of function transformations on regular Python and Numpy code.
One such transformation is differentiation. Does JAX support automatic differentiation?
Of course!
Auto differentiation with grad() function
JAX can differentiate through various Python and NumPy functions, including loops, branches, recursions, and more. This capability is invaluable for Deep Learning applications, making backpropagation effortless with the grad() function.
For instance, consider defining a simple quadratic function and computing its derivative at point 1.0.
To validate the result, manually compute the derivative as well.
from jax import grad

def f(x):
    return 3*x**2 + 2*x + 5

def f_prime(x):
    return 6*x + 2

grad(f)(1.0)    # 8.0
f_prime(1.0)    # 8.0
Notably, JAX computes the gradient analytically under the hood: it traces the function and applies the chain rule to its operations rather than approximating the derivative numerically. For deeper insights into automatic differentiation, refer to the official documentation.
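To illustrate the claim about branches, here is a small sketch (g below is just a toy function, not taken from the snippets above); grad() differentiates straight through ordinary Python control flow:

from jax import grad

def g(x):
    # a toy function with a Python branch
    if x > 2.:
        return 3. * x ** 2
    return x ** 3

grad(g)(1.0)   # derivative of x**3 at 1.0 -> 3.0
grad(g)(3.0)   # derivative of 3*x**2 at 3.0 -> 18.0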
Accelerated Linear Algebra (XLA compiler)
Part of JAX’s speed comes from Accelerated Linear Algebra, or XLA, a domain-specific compiler for linear algebra that is also widely used by TensorFlow.
To optimize matrix operations, the code is compiled into computation kernels that are tailored to the structure of the code, allowing XLA to apply extensive optimizations such as operation fusion.
The main way to take advantage of these optimizations in JAX is:
Just in time compilation (jit)
To leverage the power of XLA, the code has to be compiled into XLA kernels, and this is where jit comes into play.
Just-in-time compilation means that the code is compiled at runtime, during execution, rather than ahead of time.
To use XLA and jit, we can either call the jit() function or apply the @jit decorator.
from jax import jit
x = np.random.rand(1000,1000)
y = jnp.array(x)

def f(x):
    # an example elementwise computation (the body here is illustrative)
    for _ in range(10):
        x = 0.5*x + 0.1 * jnp.sin(x)
    return x

g = jit(f)
%timeit -n 5 -r 5 f(y).block_until_ready()
%timeit -n 5 -r 5 g(y).block_until_ready()
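The decorator form works the same way; a minimal sketch reusing the same illustrative body (f_jitted is just an arbitrary name):

@jit
def f_jitted(x):
    for _ in range(10):
        x = 0.5*x + 0.1 * jnp.sin(x)
    return x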
jit can also be combined with the grad transformation, which makes backpropagation considerably faster.
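For example, the two transformations compose directly. A minimal sketch, using a toy scalar function (loss here is purely illustrative and stands in for a real model's loss):

from jax import grad, jit

def loss(w):
    return 3*w**2 + 2*w + 5

grad_loss = jit(grad(loss))   # compile the gradient computation with XLA
grad_loss(1.0)                # 8.0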
Although jit has some limitations, most notably with Python "if" branches that depend on traced values, it proves immensely useful for most deep learning scenarios.
Replicate computation across devices with pmap
pmap allows replicating computations across multiple cores or devices, enabling parallel execution. It automatically distributes computation across devices and handles communication between them.
To inspect available devices, you can use jax.devices().
from jax import pmap

def f(x):
    # a simple elementwise computation (the body here is illustrative)
    return jnp.sin(x) + x**2

f(np.arange(4))        # runs on a single device
pmap(f)(np.arange(4))  # replicates f across devices (requires at least 4 devices, e.g. a TPU)
The result is a ShardedDeviceArray rather than a plain DeviceArray, reflecting the fact that the computation is now split across devices and executed in parallel.
pmap also enables collective communication between devices. For instance, psum performs a "reduce" operation that sums a value across all devices; below we use it to normalize the values on the devices so that they sum to one.
from functools import partial
from jax.lax import psum

@partial(pmap, axis_name="i")
def normalize(x):
    # divide each value by the sum of the values across all devices
    return x / psum(x, "i")

normalize(np.arange(8.))   # requires 8 devices; the results sum to 1
With pmap, you can define your own computation patterns and exploit your devices in the best possible way, much as CUDA lets you do for individual cores, except that here the unit is a separate device.
Automatic vectorization with vmap
vmap is a transformation that vectorizes a function: a function written for a single data point can automatically be applied to an entire batch of data points, which typically improves both speed and memory utilization.
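As a minimal sketch (dot and the arrays below are purely illustrative), vmap turns a per-example function into a batched one without an explicit loop:

from jax import vmap

def dot(a, b):
    # written for a single pair of vectors
    return jnp.dot(a, b)

a = jnp.ones((8, 3))   # a batch of 8 vectors
b = jnp.ones((8, 3))
vmap(dot)(a, b)        # one dot product per pair, shape (8,)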
For a deeper dive into JAX’s features, consider exploring its official documentation.
Pseudo-Random number generator
JAX’s random number generator works differently from NumPy’s: every random function takes an explicit PRNG state, called a key, as its first argument. Making the state explicit is what allows random number generation to vectorize and parallelize cleanly across devices.
from jax import random
key = random.PRNGKey(5)   # create an explicit PRNG state from a seed
random.uniform(key)       # the key is passed as the first argument
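Since the same key always produces the same numbers, in practice the key is split into subkeys whenever fresh randomness is needed. A minimal sketch:

key, subkey = random.split(key)   # derive a new subkey instead of reusing the old key
random.normal(subkey, (3,))       # use the subkey for the next random draw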
Asynchronous dispatch
JAX employs asynchronous dispatch: when an operation is issued, JAX returns the DeviceArray to the Python program immediately, as a kind of future, without waiting for the operation to complete. This lets Python keep queuing work for the hardware accelerator instead of stalling on each result.
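A small sketch of what this looks like in practice, reusing the y array from earlier:

z = jnp.dot(y, y)       # returns immediately; the computation may still be running
z.block_until_ready()   # explicitly wait until the result is actually available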
Profiling JAX and Device memory profiler
JAX programs can be profiled with TensorBoard and, for GPU code analysis, with NVIDIA’s Nsight tools. In addition, JAX ships a built-in device memory profiler that gives visibility into how the program uses GPU and TPU memory, which helps when optimizing memory usage and performance.
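As a small illustration (the file name is arbitrary), a snapshot of the current device memory can be saved with jax.profiler and then inspected with the pprof tool:

import jax.profiler
jax.profiler.save_device_memory_profile("memory.prof")   # dump a device memory snapshot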
For further exploration, experiment with profiling tools and enhance your understanding of JAX’s capabilities.
Conclusion
This article provided an overview of JAX’s advantages and showcased simple code snippets illustrating its basic syntax and intricacies. For comprehensive code examples, refer to the provided Colab notebook or GitHub repository.
Stay tuned for upcoming articles exploring building, training, and scaling deep neural networks with JAX, along with insights into frameworks leveraging JAX.
If you enjoyed this article, feel free to share it on social media.