In this tutorial, we will delve into the development of a Neural Network (NN) using JAX, focusing on the Transformer model. As JAX gains popularity, more developer teams are experimenting with it and integrating it into their projects. Despite being less mature than TensorFlow or PyTorch, JAX offers great features for building and training Deep Learning models.
For a solid foundation in JAX basics, refer to my previous article if you haven’t already. The full code can also be found in our GitHub repository.
One common issue people face when starting with JAX is choosing a framework. The DeepMind team has released several frameworks built on JAX. Some of the most well-known ones include:
– Haiku: A go-to framework for Deep Learning, widely used by Google and DeepMind internal teams.
– Optax: A library for gradient processing and optimization with out-of-the-box optimizers and mathematical operations.
– RLax: A reinforcement learning framework featuring various RL subcomponents and operations.
– Chex: A library of utilities for testing and debugging JAX code.
– Jraph: A Graph Neural Networks library in JAX.
– Flax: A neural network library with various ready-to-use modules, optimizers, and utilities.
– Objax: An ML library that focuses on object-oriented programming and code readability.
– Trax: An end-to-end library for deep learning focusing on Transformers.
– JAXline: A supervised-learning library for distributed JAX training and evaluation.
– ACME: A research framework for reinforcement learning.
– JAX-MD: A framework for molecular dynamics.
– Jaxchem: A niche library emphasizing chemical modeling using JAX.
Choosing the right framework can be challenging, but starting with popular options like Haiku and Flax, which are widely used and have active communities, can be a good starting point. In this article, we will begin with Haiku and see if we need another framework later on.
Are you ready to build a Transformer with JAX and Haiku? Please make sure you have a solid understanding of Transformers before diving in.
Let’s start with the self-attention block.
The self-attention block
First, we need to import JAX and Haiku, along with a few standard utilities that we will use throughout the tutorial:
import functools
from typing import Any, Mapping, Optional

import jax
import jax.numpy as jnp
import haiku as hk
import numpy as np
import optax
Luckily, Haiku provides a built-in MultiHeadAttention block that can be extended to create a masked self-attention block. Our block accepts the query, key, value, and mask as input and returns the output as a JAX array. The code resembles standard PyTorch or TensorFlow code: we build a causal mask using np.tril(), which zeroes out every element above the main diagonal, combine it with the user-supplied mask, and pass everything into the hk.MultiHeadAttention module.
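For example, for a sequence of length 4 the causal mask looks like this, allowing each token to attend only to itself and the tokens before it:
causal_mask = np.tril(np.ones((4, 4)))
# array([[1., 0., 0., 0.],
#        [1., 1., 0., 0.],
#        [1., 1., 1., 0.],
#        [1., 1., 1., 1.]])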
class SelfAttention(hk.MultiHeadAttention):
    """Self attention with a causal mask applied."""

    def __call__(
        self,
        query: jnp.ndarray,
        key: Optional[jnp.ndarray] = None,
        value: Optional[jnp.ndarray] = None,
        mask: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        key = key if key is not None else query
        value = value if value is not None else query

        seq_len = query.shape[1]
        causal_mask = np.tril(np.ones((seq_len, seq_len)))
        mask = mask * causal_mask if mask is not None else causal_mask

        return super().__call__(query, key, value, mask)
This snippet introduces the key principle of Haiku: every module is a subclass of hk.Module (here indirectly, since hk.MultiHeadAttention is itself an hk.Module) and implements __init__ and __call__, much like PyTorch modules implement __init__ and a forward function.
To illustrate this clearly, let’s build a simple 2-layer Multilayer Perceptron as an hk.Module, which we will use in the Transformer later.
The linear layer
A simple 2-layer MLP is structured as follows:
class DenseBlock(hk.Module):
    """A 2-layer MLP."""

    def __init__(self,
                 init_scale: float,
                 widening_factor: int = 4,
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._init_scale = init_scale
        self._widening_factor = widening_factor

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        hiddens = x.shape[-1]
        initializer = hk.initializers.VarianceScaling(self._init_scale)
        x = hk.Linear(self._widening_factor * hiddens, w_init=initializer)(x)
        x = jax.nn.gelu(x)
        return hk.Linear(hiddens, w_init=initializer)(x)
Points to note:
– Haiku provides various weight initializers under hk.initializers.
– It includes popular layers like hk.Linear.
– Activation functions such as relu, gelu, and softmax live in JAX's jax.nn subpackage (see the short usage sketch after this list).
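To see these pieces in action, here is a minimal sketch (with an arbitrary init_scale and a hypothetical input shape) of how DenseBlock can be initialized and applied. Haiku modules can only be instantiated inside a function that is wrapped with hk.transform, which we discuss in more detail later on:
def mlp_fn(x):
    return DenseBlock(init_scale=0.25)(x)

mlp = hk.transform(mlp_fn)
rng = jax.random.PRNGKey(42)
x = jnp.ones((8, 128))           # hypothetical (batch, features) input
params = mlp.init(rng, x)        # creates the weights of the two Linear layers
out = mlp.apply(params, rng, x)  # out has the same shape as x: (8, 128)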
The normalization layer
Layer normalization is a crucial building block of the Transformer architecture, and Haiku ships it as a ready-made module, hk.LayerNorm, which we wrap in a small helper:
def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
    """Apply a unique LayerNorm to x with default settings."""
    return hk.LayerNorm(axis=-1,
                        create_scale=True,
                        create_offset=True,
                        name=name)(x)
The transformer
Now, let's look at the Transformer model, which makes use of the modules we defined above. In the `__init__` function, basic hyperparameters such as the number of layers, the number of attention heads, and the dropout rate are stored. The `__call__` function stacks the blocks with a simple for loop.
Each block includes:
– A normalization layer, the masked self-attention block, dropout, and a residual connection.
– A second normalization layer, the 2-layer MLP (DenseBlock), dropout, and another residual connection.
Finally, a normalization layer is applied at the end of the stack.
class Transformer(hk.Module):
    """A transformer stack."""

    def __init__(self,
                 num_heads: int,
                 num_layers: int,
                 dropout_rate: float,
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._num_layers = num_layers
        self._num_heads = num_heads
        self._dropout_rate = dropout_rate

    def __call__(self,
                 h: jnp.ndarray,
                 mask: Optional[jnp.ndarray],
                 is_training: bool) -> jnp.ndarray:
        """Connects the transformer."""
        init_scale = 2. / self._num_layers
        dropout_rate = self._dropout_rate if is_training else 0.
        if mask is not None:
            mask = mask[:, None, None, :]

        for i in range(self._num_layers):
            h_norm = layer_norm(h, name=f'h{i}_ln_1')
            h_attn = SelfAttention(
                num_heads=self._num_heads,
                key_size=64,
                w_init_scale=init_scale,
                name=f'h{i}_attn')(h_norm, mask=mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = h + h_attn
            h_norm = layer_norm(h, name=f'h{i}_ln_2')
            h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = h + h_dense
        h = layer_norm(h, name='ln_f')
        return h
As you can see, building a neural network with JAX and Haiku is fairly straightforward, and the result stays close to what you would write in PyTorch or TensorFlow.
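Here is a minimal, hypothetical sketch of running the stack on dummy embeddings. The sizes are arbitrary, but note that with key_size=64 hard-coded above, the embedding dimension must equal 64 * num_heads for the residual connections to line up:
def transformer_fn(h, mask):
    return Transformer(num_heads=4, num_layers=2, dropout_rate=0.1)(h, mask, is_training=True)

transformer = hk.transform(transformer_fn)
rng = jax.random.PRNGKey(0)
h = jnp.ones((2, 16, 256))   # (batch, seq_len, d_model) with d_model = 64 * 4
mask = jnp.ones((2, 16))     # 1 for real tokens, 0 for padding
params = transformer.init(rng, h, mask)
out = transformer.apply(params, rng, h, mask)  # same shape as h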
The embeddings layer
For completeness, let's include the embeddings layer. Haiku offers hk.Embed, which maps the token ids of the input sentences to embedding vectors. These token embeddings are then added to learned positional embeddings to form the final input.
def embeddings(data: Mapping[str, jnp.ndarray], vocab_size: int, d_model: int):
    # d_model is passed in explicitly because this helper lives outside build_forward_fn.
    tokens = data['obs']
    input_mask = jnp.greater(tokens, 0)
    seq_length = tokens.shape[1]

    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
    token_embs = token_embedding_map(tokens)

    positional_embeddings = hk.get_parameter(
        'pos_embs', [seq_length, d_model], init=embed_init)
    input_embeddings = token_embs + positional_embeddings
    return input_embeddings, input_mask
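Note the padding mask at the top of the function: token id 0 is treated as padding, and jnp.greater marks the real tokens. For example, with hypothetical token ids:
tokens = jnp.array([[5, 2, 9, 0, 0]])
jnp.greater(tokens, 0)   # -> [[ True,  True,  True, False, False]]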
hk.get_parameter creates (or retrieves) a trainable parameter inside the current module. Haiku tracks every parameter created this way, which is exactly what allows hk.transform to later convert the code into a pure function.
Why pure functions?
JAX’s power lies in its function transformations – vectorization, parallelization, and just-in-time compilation. To transform a function, it must be pure, meaning:
- The function returns the same result for identical inputs.
- The function has no side effects.
Haiku simplifies this transformation with hk.transform, enabling the use of automatic differentiation, parallelization, and other features. Let’s continue with the training process of our Transformer model.
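To make this concrete, here is a tiny sketch (a hypothetical scale_fn, unrelated to our model) of what hk.transform does: the parameter created with hk.get_parameter ends up in an explicit params structure, and the resulting apply function is pure, so the same inputs always produce the same outputs:
def scale_fn(x):
    scale = hk.get_parameter('scale', shape=[x.shape[-1]], init=jnp.ones)
    return x * scale

scale = hk.without_apply_rng(hk.transform(scale_fn))
params = scale.init(jax.random.PRNGKey(0), jnp.ones((2, 4)))  # params == {'~': {'scale': ...}}
y = scale.apply(params, jnp.ones((2, 4)))                     # no hidden state, no side effects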
The forward pass
A typical forward pass involves:
- Computing input embedding from input data.
- Running data through Transformer blocks.
- Returning the output.
These steps can be easily assembled using JAX as shown below:
def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,
                     num_layers: int, dropout_rate: float):
    """Create the model's forward pass."""

    def forward_fn(data: Mapping[str, jnp.ndarray],
                   is_training: bool = True) -> jnp.ndarray:
        """Forward pass."""
        input_embeddings, input_mask = embeddings(data, vocab_size, d_model)

        transformer = Transformer(
            num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)
        output_embeddings = transformer(input_embeddings, input_mask, is_training)

        # Project the transformer outputs back to the vocabulary to obtain logits.
        return hk.Linear(vocab_size)(output_embeddings)

    return forward_fn
Note that the forward pass is wrapped in a builder function: build_forward_fn fixes the hyperparameters via closure and returns forward_fn, which is what we will pass to hk.transform in a moment.
The loss function
The loss function is a cross-entropy that accounts for the padding mask, built from JAX's jax.nn.one_hot and jax.nn.log_softmax.
def lm_loss_fn(forward_fn,
               vocab_size: int,
               params,
               rng,
               data: Mapping[str, jnp.ndarray],
               is_training: bool = True) -> jnp.ndarray:
    """Compute the loss on data with respect to parameters."""
    logits = forward_fn(params, rng, data, is_training)
    targets = jax.nn.one_hot(data['target'], vocab_size)
    assert logits.shape == targets.shape

    mask = jnp.greater(data['obs'], 0)
    loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
    loss = jnp.sum(loss * mask) / jnp.sum(mask)

    return loss
The mask ensures that padding tokens (ids equal to 0 in data['obs']) do not contribute to the loss: the per-token losses are zeroed out and the sum is divided by the number of real tokens.
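To illustrate the masking with hypothetical numbers, in the snippet below the second token is padding, so only the first position contributes to the loss:
logits = jnp.array([[[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]]])  # (batch=1, seq=2, vocab=3)
targets = jax.nn.one_hot(jnp.array([[0, 1]]), 3)
mask = jnp.array([[1.0, 0.0]])                             # second token is padding
loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
print(jnp.sum(loss * mask) / jnp.sum(mask))                # cross-entropy of the first token only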
The training loop
As JAX and Haiku do not ship optimizers of their own, Optax comes into play for gradient processing and optimization. Optax's core abstraction is the GradientTransformation, a pair of pure functions: init, which builds the optimizer state, and update, which transforms the gradients and updates that state. We will also use Python's functools.partial to create new functions with some arguments already fixed.
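Here is a minimal, self-contained sketch of that pattern with toy parameters and arbitrary hyperparameters, chaining gradient clipping with Adam:
params = {'w': jnp.ones((3,))}
grads = {'w': jnp.array([0.1, -0.2, 0.3])}

optimizer = optax.chain(
    optax.clip_by_global_norm(0.25),   # clip the global gradient norm
    optax.adam(learning_rate=2e-4),
)
opt_state = optimizer.init(params)                       # init: build the optimizer state
updates, opt_state = optimizer.update(grads, opt_state)  # update: transform the gradients
params = optax.apply_updates(params, updates)            # apply the transformed updates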
The GradientUpdater is constructed from the model, the loss function, and the optimizer: the model enters through the init function of the pure, hk.transform-ed forward_fn; the loss function is created with functools.partial, fixing the forward_fn and vocab_size arguments; and the optimizer is an Optax GradientTransformation (possibly a chain of several transformations).
A snippet of the GradientUpdater:
class GradientUpdater:
    """A stateless abstraction around an init_fn/update_fn pair."""

    def __init__(self, net_init, loss_fn,
                 optimizer: optax.GradientTransformation):
        self._net_init = net_init
        self._loss_fn = loss_fn
        self._opt = optimizer

    @functools.partial(jax.jit, static_argnums=0)
    def init(self, master_rng, data):
        """Initializes the model parameters and the optimizer state."""
        out_rng, init_rng = jax.random.split(master_rng)
        params = self._net_init(init_rng, data)
        opt_state = self._opt.init(params)
        out = dict(
            step=np.array(0),
            rng=out_rng,
            opt_state=opt_state,
            params=params,
        )
        return out

    @functools.partial(jax.jit, static_argnums=0)
    def update(self, state: Mapping[str, Any], data: Mapping[str, jnp.ndarray]):
        """Performs one gradient step and returns the new state plus metrics."""
        rng, new_rng = jax.random.split(state['rng'])
        params = state['params']
        loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)

        updates, opt_state = self._opt.update(g, state['opt_state'])
        params = optax.apply_updates(params, updates)

        new_state = {
            'step': state['step'] + 1,
            'rng': new_rng,
            'opt_state': opt_state,
            'params': params,
        }

        metrics = {
            'step': state['step'],
            'loss': loss,
        }
        return new_state, metrics
The updater keeps all of the mutable training state (the step counter, the RNG key, the optimizer state, and the parameters) in a plain dictionary, and both init and update are jit-compiled, with self marked as a static argument.
Finally, the training loop ties together everything we have discussed so far, as sketched below.
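Here is a rough sketch of how it all gets wired together; the hyperparameter values and the train_dataset iterator are placeholders:
# Hypothetical hyperparameters.
vocab_size, d_model, num_heads, num_layers = 1000, 256, 4, 6
dropout_rate, learning_rate, grad_clip_value, max_steps = 0.1, 2e-4, 0.25, 10_000

# Build the forward pass and transform it into a pure (init, apply) pair.
forward_fn = build_forward_fn(vocab_size, d_model, num_heads, num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)

# Fix forward_fn.apply and vocab_size in the loss with functools.partial.
loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

optimizer = optax.chain(
    optax.clip_by_global_norm(grad_clip_value),
    optax.adam(learning_rate),
)
updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

# Initialize the state with the first batch, then train.
rng = jax.random.PRNGKey(428)
data = next(train_dataset)   # train_dataset: a placeholder iterator yielding {'obs', 'target'} batches
state = updater.init(rng, data)
for step in range(max_steps):
    data = next(train_dataset)
    state, metrics = updater.update(state, data)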