Creating a Transformer model using JAX: A step-by-step guide to building and training your own models

In this tutorial, we will delve into the development of a Neural Network (NN) using JAX, focusing on the Transformer model. As JAX gains popularity, more developer teams are experimenting with it and integrating it into their projects. Despite being less mature than TensorFlow or PyTorch, JAX offers great features for building and training Deep Learning models.

For a solid foundation in JAX basics, refer to my previous article if you haven’t already read it. The full code can also be found in our GitHub repository.

One common issue people face when starting with JAX is choosing a framework. The DeepMind and Google teams have released several frameworks and libraries built on top of JAX. Some of the most well-known ones include:

– Haiku: A go-to framework for Deep Learning, widely used by Google and DeepMind internal teams.

– Optax: A library for gradient processing and optimization with out-of-the-box optimizers and mathematical operations.

– RLax: A library of reinforcement learning building blocks and operations.

– Chex: A library of utilities for testing and debugging JAX code.

– Jraph: A Graph Neural Networks library in JAX.

– Flax: A neural network library with various ready-to-use modules, optimizers, and utilities.

– Objax: An ML library that focuses on object-oriented programming and code readability.

– Trax: An end-to-end library for deep learning focusing on Transformers.

– JAXline: A supervised-learning library for distributed JAX training and evaluation.

– ACME: A research framework for reinforcement learning.

– JAX-MD: A framework for molecular dynamics.

– Jaxchem: A niche library emphasizing chemical modeling using JAX.

Choosing the right framework can be challenging, but popular options such as Haiku and Flax are widely used, have active communities, and are a good place to start. In this article, we will begin with Haiku and see whether we need another framework later on.

Are you ready to build a Transformer with JAX and Haiku? Please make sure you have a solid understanding of Transformers before diving in.

Let’s start with the self-attention block.

The self-attention block

First, we import JAX and Haiku, along with the other libraries we will need later in the tutorial (typing helpers, functools, NumPy, and Optax):

import functools
from typing import Any, Mapping, Optional
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax

Luckily, Haiku provides a built-in MultiHeadAttention block that can be extended to create a masked self-attention block. Our block accepts the query, key, value, and mask as input and returns the output as a JAX array. The code resembles standard PyTorch or TensorFlow code: we build a causal mask with np.tril(), which zeroes out every element of the array above the main diagonal, combine it with the mask that was passed in, and feed everything into the hk.MultiHeadAttention module.
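For example, for a sequence length of 3, the causal mask produced by np.tril looks like this, allowing each position to attend only to itself and to earlier positions:

np.tril(np.ones((3, 3)))
# array([[1., 0., 0.],
#        [1., 1., 0.],
#        [1., 1., 1.]])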

class SelfAttention(hk.MultiHeadAttention):
    """Self attention with a causal mask applied."""

    def __call__(
        self,
        query: jnp.ndarray,
        key: Optional[jnp.ndarray] = None,
        value: Optional[jnp.ndarray] = None,
        mask: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        # Default to self-attention: keys and values come from the query itself.
        key = key if key is not None else query
        value = value if value is not None else query

        # Build a lower-triangular causal mask and combine it with the input mask.
        seq_len = query.shape[1]
        causal_mask = np.tril(np.ones((seq_len, seq_len)))
        mask = mask * causal_mask if mask is not None else causal_mask

        return super().__call__(query, key, value, mask)

This snippet introduces a key principle of Haiku: every module should be a subclass of hk.Module and implement __init__ and __call__, much like PyTorch modules, where we implement __init__ and a forward function.

To illustrate this clearly, let’s build a simple 2-layer Multilayer Perceptron as an hk.Module, which we will use in the Transformer later.

The linear layer

A simple 2-layer MLP is structured as follows:

class DenseBlock(hk.Module):
    """A 2-layer MLP."""

    def __init__(self,
                 init_scale: float,
                 widening_factor: int = 4,
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._init_scale = init_scale
        self._widening_factor = widening_factor

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        hiddens = x.shape[-1]
        initializer = hk.initializers.VarianceScaling(self._init_scale)
        # Expand, apply the GELU non-linearity, then project back to the input size.
        x = hk.Linear(self._widening_factor * hiddens, w_init=initializer)(x)
        x = jax.nn.gelu(x)
        return hk.Linear(hiddens, w_init=initializer)(x)

Points to note:

– Haiku provides various weight initializers under hk.initializers.

– It includes popular layers like hk.Linear.

– Activation functions like relu or softmax are available in JAX’s jax.nn subpackage.

The normalization layer

Layer normalization is essential in the transformer architecture, and Haiku provides it out of the box as hk.LayerNorm, so we only need a thin wrapper:

def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
    """Apply a unique LayerNorm to x with default settings."""
    return hk.LayerNorm(axis=-1,
                        create_scale=True,
                        create_offset=True,
                        name=name)(x)

The transformer

Now, let's look at the Transformer model, which utilizes our predefined modules. In the `__init__` function, basic variables such as the number of layers, attention heads, and the dropout rate are defined. The `__call__` function builds the stack of blocks with a for loop.

Each block includes:

– A layer normalization, followed by masked multi-head self-attention and dropout, added back to the input through a residual connection.

– A second layer normalization, followed by the 2-layer MLP (our DenseBlock) and dropout, again added back through a residual connection.

Finally, a normalization layer is applied at the end of the stack.

class Transformer(hk.Module):
    """A transformer stack."""

    def __init__(self,
                 num_heads: int,
                 num_layers: int,
                 dropout_rate: float,
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._num_layers = num_layers
        self._num_heads = num_heads
        self._dropout_rate = dropout_rate

    def __call__(self,
                 h: jnp.ndarray,
                 mask: Optional[jnp.ndarray],
                 is_training: bool) -> jnp.ndarray:
        """Connects the transformer."""
        init_scale = 2. / self._num_layers
        dropout_rate = self._dropout_rate if is_training else 0.
        if mask is not None:
            # Broadcast the padding mask over the head and query dimensions.
            mask = mask[:, None, None, :]

        for i in range(self._num_layers):
            # Attention sub-block: LayerNorm -> masked self-attention -> dropout -> residual.
            h_norm = layer_norm(h, name=f'h{i}_ln_1')
            h_attn = SelfAttention(
                num_heads=self._num_heads,
                key_size=64,
                w_init_scale=init_scale,
                name=f'h{i}_attn')(h_norm, mask=mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = h + h_attn

            # MLP sub-block: LayerNorm -> DenseBlock -> dropout -> residual.
            h_norm = layer_norm(h, name=f'h{i}_ln_2')
            h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = h + h_dense

        h = layer_norm(h, name='ln_f')
        return h

As you can see, building a neural network with JAX and Haiku is relatively straightforward.

The embeddings layer

For completeness, let's include the embeddings layer as well. Haiku provides an hk.Embed module that maps the input tokens to embedding vectors. These are combined with learned positional embeddings to create the final input.

def embeddings(data: Mapping[str, jnp.ndarray], vocab_size: int):
    # d_model (the embedding size) is assumed to be defined in the enclosing scope.
    tokens = data['obs']
    input_mask = jnp.greater(tokens, 0)  # 0 is the padding token
    seq_length = tokens.shape[1]

    # Token embeddings.
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
    token_embs = token_embedding_map(tokens)

    # Learned positional embeddings, registered as a trainable parameter.
    positional_embeddings = hk.get_parameter(
        'pos_embs', [seq_length, d_model], init=embed_init)

    input_embeddings = token_embs + positional_embeddings
    return input_embeddings, input_mask

hk.get_parameter is how we access (and register) the trainable parameters of a module. The reason for this API, rather than storing parameters as plain object attributes, is that the code can then be converted into a pure function using hk.transform.

Why pure functions?

JAX’s power lies in its function transformations – vectorization, parallelization, and just-in-time compilation. To transform a function, it must be pure, meaning:

  1. The function returns the same result for identical inputs.
  2. The function has no side effects.

Haiku simplifies this transformation with hk.transform, enabling the use of automatic differentiation, parallelization, and other features.
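To make this concrete, here is a minimal sketch of what hk.transform gives us, using the DenseBlock defined earlier (the shapes and values below are purely illustrative):

def mlp_fn(x):
    return DenseBlock(init_scale=0.5)(x)

mlp = hk.transform(mlp_fn)
rng = jax.random.PRNGKey(42)
x = jnp.ones((8, 128))
params = mlp.init(rng, x)        # pure: parameters are returned explicitly
out = mlp.apply(params, rng, x)  # pure: same params and inputs -> same output

With that in place, let's continue with the training process of our Transformer model.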

The forward pass

A typical forward pass involves:

  1. Computing input embedding from input data.
  2. Running data through Transformer blocks.
  3. Returning the output.

These steps can be easily assembled using JAX as shown below:

def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,
                     num_layers: int, dropout_rate: float):
    """Create the model's forward pass."""

    def forward_fn(data: Mapping[str, jnp.ndarray],
                   is_training: bool = True) -> jnp.ndarray:
        """Forward pass."""
        input_embeddings, input_mask = embeddings(data, vocab_size)
        transformer = Transformer(
            num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)
        output_embeddings = transformer(input_embeddings, input_mask, is_training)
        # Project back to the vocabulary to obtain per-token logits.
        return hk.Linear(vocab_size)(output_embeddings)

    return forward_fn

Wrapping the forward pass in build_forward_fn fixes the model's hyperparameters (vocabulary size, model width, number of heads and layers, dropout rate) once, so the returned forward_fn only needs the data and a training flag.
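Before it can be trained, the forward pass itself has to be turned into a pure init/apply pair with hk.transform, as discussed above. A minimal sketch, with purely illustrative hyperparameter values:

vocab_size, d_model = 32000, 512  # illustrative values
forward_fn = build_forward_fn(vocab_size, d_model,
                              num_heads=8, num_layers=6, dropout_rate=0.1)
forward_fn = hk.transform(forward_fn)
# forward_fn.init(rng, data) creates the parameters,
# forward_fn.apply(params, rng, data) runs the model.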

The loss function

The loss function is a standard cross-entropy that also takes the padding mask into account, using jax.nn.one_hot and jax.nn.log_softmax provided by JAX.

def lm_loss_fn(forward_fn,
               vocab_size: int,
               params,
               rng,
               data: Mapping[str, jnp.ndarray],
               is_training: bool = True) -> jnp.ndarray:
    """Compute the loss on data with respect to parameters."""
    logits = forward_fn(params, rng, data, is_training)
    targets = jax.nn.one_hot(data['target'], vocab_size)
    assert logits.shape == targets.shape

    # Mask out padding tokens so that they do not contribute to the loss.
    mask = jnp.greater(data['obs'], 0)
    loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
    loss = jnp.sum(loss * mask) / jnp.sum(mask)

    return loss

Note that forward_fn here is the apply function of the transformed model, which is why the parameters and an RNG key are passed in explicitly, and that the mask ensures padding positions are excluded from the loss.

The training loop

As JAX and Haiku do not provide optimizers of their own, Optax comes into play for gradient processing and optimization. Optax's key building block is the GradientTransformation, which is defined by a pair of functions, init and update, for state initialization and gradient transformation respectively. We will also use Python's functools.partial to create new functions with some of their arguments fixed.
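For instance, an optimizer can be built by chaining gradient transformations, and the loss function can be specialized with functools.partial. The values below are illustrative, not prescribed by the article:

optimizer = optax.chain(
    optax.clip_by_global_norm(0.25),  # clip gradients by their global norm
    optax.adam(learning_rate=2e-4),   # Adam update rule
)
loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)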

The GradientUpdater class accepts the model's init function, the loss function, and the optimizer. The model is the pure forward_fn obtained through hk.transform. The loss function is created with partial, fixing the forward_fn and vocab_size arguments. The optimizer is a chain of gradient transformations.

A snippet of the GradientUpdater:

class GradientUpdater:
    """A stateless abstraction around an init_fn/update_fn pair."""

    def __init__(self, net_init, loss_fn,
                 optimizer: optax.GradientTransformation):
        self._net_init = net_init
        self._loss_fn = loss_fn
        self._opt = optimizer

    @functools.partial(jax.jit, static_argnums=0)
    def init(self, master_rng, data):
        # Initialize the model parameters and the optimizer state.
        out_rng, init_rng = jax.random.split(master_rng)
        params = self._net_init(init_rng, data)
        opt_state = self._opt.init(params)
        out = dict(
            step=np.array(0),
            rng=out_rng,
            opt_state=opt_state,
            params=params,
        )
        return out

    @functools.partial(jax.jit, static_argnums=0)
    def update(self, state: Mapping[str, Any], data: Mapping[str, jnp.ndarray]):
        rng, new_rng = jax.random.split(state['rng'])
        params = state['params']
        # Compute the loss and its gradients in a single pass.
        loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)

        # Transform the gradients and apply the resulting updates.
        updates, opt_state = self._opt.update(g, state['opt_state'])
        params = optax.apply_updates(params, updates)

        new_state = {
            'step': state['step'] + 1,
            'rng': new_rng,
            'opt_state': opt_state,
            'params': params,
        }

        metrics = {
            'step': state['step'],
            'loss': loss,
        }
        return new_state, metrics

The updater's init method initializes the parameters and the optimizer state, while update computes the loss and gradients, applies the optimizer updates, and returns the new training state together with the metrics. Both methods are compiled with jax.jit (static_argnums=0 excludes self from tracing).

Finally, the training loop combines all the ideas and code discussed so far to train our Transformer model.
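A minimal sketch of such a loop is shown below. It assumes a train_dataset iterator (a hypothetical name, not part of the code above) that yields batches containing 'obs' and 'target' arrays:

updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

# Initialize the training state from a first batch.
rng = jax.random.PRNGKey(428)
data = next(train_dataset)
state = updater.init(rng, data)

MAX_STEPS = 10_000  # illustrative
for step in range(MAX_STEPS):
    data = next(train_dataset)
    state, metrics = updater.update(state, data)
    if step % 100 == 0:
        print(f"step {int(metrics['step'])}, loss {float(metrics['loss']):.4f}")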
