If you are reading transformer papers, you may have noticed Positional Embeddings (PE). They may seem reasonable. However, when you try to implement them, it becomes really confusing!
The reason to push through is simple: if you want to implement transformer-related papers, it is very important to get a good grasp of positional embeddings.
It turns out that sinusoidal positional encodings are not always sufficient for computer vision problems. Images are highly structured, and we want to incorporate a strong sense of position (order) into the multi-head self-attention (MHSA) block.
To this end, I will introduce some theory as well as my re-implementation of positional embeddings.
The code contains einsum operations. Read my past article if you are not comfortable with them. The code is also available.
Positional encodings vs positional embeddings
In the vanilla transformer, positional encodings are added to the input before the first MHSA block. Let’s start by clarifying this: positional embeddings are not the same thing as the sinusoidal positional encodings. They are very similar to word or patch embeddings, except that here we embed the position.
Each position of the sequence will be mapped to a trainable vector of size $dim$.
Moreover, positional embeddings are trainable as opposed to encodings that are fixed.
Here is a rough illustration of how this works:
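```python
import torch

# One trainable vector of size dim per position, up to max_seq_tokens positions.
pos_emb1D = torch.nn.Parameter(torch.randn(max_seq_tokens, dim))
# Add the rows for the current length; input_embedding is [batch, current_seq_tokens, dim].
input_to_transformer_mhsa = input_embedding + pos_emb1D[:current_seq_tokens, :]
out = transformer(input_to_transformer_mhsa)
```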
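To make the "trainable vs. fixed" distinction concrete, here is a minimal sketch that builds both variants side by side. It assumes the standard sine/cosine formula from the vanilla transformer and an even `dim`; `sinusoidal_encoding` and `PositionalInput` are hypothetical names, not part of the article's code.

```python
import math
import torch
import torch.nn as nn

def sinusoidal_encoding(max_tokens, dim):
    """Fixed sine/cosine encodings (vanilla transformer formula); dim assumed even."""
    pos = torch.arange(max_tokens, dtype=torch.float32)[:, None]           # [T, 1]
    freq = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32)
                     * (-math.log(10000.0) / dim))                         # [dim / 2]
    pe = torch.zeros(max_tokens, dim)
    pe[:, 0::2] = torch.sin(pos * freq)  # even channels
    pe[:, 1::2] = torch.cos(pos * freq)  # odd channels
    return pe

class PositionalInput(nn.Module):
    """Hypothetical wrapper: adds either learned or fixed position vectors."""
    def __init__(self, max_tokens, dim, learned=True):
        super().__init__()
        if learned:
            # Positional embedding: one trainable vector per position.
            self.pos = nn.Parameter(torch.randn(max_tokens, dim))
        else:
            # Positional encoding: a constant buffer the optimizer never updates.
            self.register_buffer('pos', sinusoidal_encoding(max_tokens, dim))

    def forward(self, x):  # x: [batch, tokens, dim]
        return x + self.pos[: x.shape[1]]
```

Only the learned variant exposes a parameter to the optimizer; the sinusoidal table lives in a buffer, so it moves with the module across devices but is never trained.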
By now you are probably wondering what PE learn. Me too!
A fully-connected graph with four vertices and sixteen directed bonds. Image from Gregory Berkolaiko. Source: ResearchGate.
You can think of each attention weight $\epsilon_{ij}$ as an arrow.
$$\epsilon_{ij} = \frac{x_i W^Q (x_j W^K)^T}{\sqrt{d}}$$
The index $i$ indicates the query, and the index $j$ indicates the key and the value.
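The score formula can be checked numerically. This minimal sketch (single head, random weights, sizes chosen arbitrarily) computes one $\epsilon_{ij}$ per arrow with an explicit double loop and confirms it matches the matrix form $QK^T/\sqrt{d}$.

```python
import math
import torch

torch.manual_seed(0)
n, d_model, d = 4, 8, 8  # tokens, input width, head width

x = torch.randn(n, d_model)    # row i is the token x_i
W_Q = torch.randn(d_model, d)  # query projection
W_K = torch.randn(d_model, d)  # key projection

# One score per arrow i -> j of the fully connected graph.
eps_loop = torch.empty(n, n)
for i in range(n):
    for j in range(n):
        eps_loop[i, j] = (x[i] @ W_Q) @ (x[j] @ W_K) / math.sqrt(d)

# Same scores in matrix form: Q K^T / sqrt(d).
Q, K = x @ W_Q, x @ W_K
eps = Q @ K.T / math.sqrt(d)
```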
Source: Ramachandran et al. Stand-Alone Self-Attention in Vision Models. Each individual output element comes from a single query element indexed by $i$. The query element $q_i$ is associated with all the elements of the input sequence, indexed by $j$.
PE aim to inject positional information into this computation. So we consider the positions $p_{ij} \in \mathbb{R}^d$ of the keys relative to the query element.
$$\epsilon_{ij} = \frac{x_i W^Q (x_j W^K)^T + x_i W^Q (p_{ij}^K)^T}{\sqrt{d}}$$
The added term $x_i W^Q (p_{ij}^K)^T$ scores the query element against an embedding of its distance to a particular sequence position $j$.
A nice property of this formulation is that the positional representations can be shared across heads, which adds minimal overhead. For a sequence of length $n$ and $h$ attention heads with head dimension $d$, sharing reduces the space complexity from $O(hn^2d)$ to $O(n^2d)$.
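A minimal sketch of the relative term, assuming $p_{ij}^K$ is looked up from a hypothetical table `rel_table` with $2n-1$ rows, where row $(j - i) + (n - 1)$ stores the embedding for offset $j - i$. The loop follows the equation term by term; the vectorized version gathers all $p_{ij}$ at once.

```python
import math
import torch

torch.manual_seed(0)
n, d = 5, 8

Q = torch.randn(n, d)  # rows are x_i W^Q
K = torch.randn(n, d)  # rows are x_j W^K
# Hypothetical table: row (j - i) + (n - 1) holds p^K for relative offset j - i.
rel_table = torch.randn(2 * n - 1, d)

# eps_ij = (q_i k_j^T + q_i (p_ij^K)^T) / sqrt(d), one pair at a time.
eps_loop = torch.empty(n, n)
for i in range(n):
    for j in range(n):
        p_ij = rel_table[j - i + n - 1]
        eps_loop[i, j] = (Q[i] @ K[j] + Q[i] @ p_ij) / math.sqrt(d)

# Vectorized: idx[i, j] = j - i + n - 1, then gather every p_ij.
idx = torch.arange(n)[None, :] - torch.arange(n)[:, None] + n - 1
P = rel_table[idx]  # [n, n, d]
eps = (Q @ K.T + torch.einsum('id,ijd->ij', Q, P)) / math.sqrt(d)
```

Only $2n-1$ distinct vectors are stored even though $n^2$ pairs are scored, because every pair with the same offset reuses the same row.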
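The memory claim is easy to check with a small worked example. Assuming the $p_{ij}$ vectors are materialized as a dense tensor (sizes below are arbitrary), a per-head copy is exactly $h$ times larger than a shared one, and the shared version simply broadcasts over the batch and head axes.

```python
import torch

batch, heads, n, d = 2, 8, 16, 32

p_per_head = torch.randn(heads, n, n, d)  # separate p_ij for every head: O(h n^2 d)
p_shared = torch.randn(n, n, d)           # one p_ij reused by all heads: O(n^2 d)

q = torch.randn(batch, heads, n, d)
# The shared table broadcasts over the batch (b) and head (h) axes.
rel_logits = torch.einsum('bhid,ijd->bhij', q, p_shared)

print(p_per_head.numel() // p_shared.numel())  # prints 8, the number of heads
```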
Let’s further divide Positional Embeddings (PE) into two categories.
Absolute vs. relative positional embeddings
It is often the case that additional positional info is added to the query (Q) representation in the MHSA block. There are two main approaches here:
Absolute positions: every input token at position $i$ is associated with a trainable embedding vector, namely the $i$-th row of a matrix $R$ with shape [tokens, dim]. $R$ is a trainable matrix, initialized from $N(0,1)$. It slightly alters the representation based on the position.
$$att = \mathrm{softmax}\left(\frac{1}{\sqrt{dim}}\left(Q K^T + Q R^T\right)\right)$$
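A minimal sketch of this formula for tensors shaped [batch, heads, tokens, dim], assuming a single $R$ shared by all heads and initialized from $N(0,1)$ as described above. `AbsPosEmb1D` is a hypothetical name and not necessarily how the author structures it.

```python
import torch
import torch.nn as nn

class AbsPosEmb1D(nn.Module):
    """Hypothetical module: softmax((Q K^T + Q R^T) / sqrt(dim)).

    R has one trainable row per absolute position, shape [tokens, dim_head].
    """
    def __init__(self, tokens, dim_head):
        super().__init__()
        self.R = nn.Parameter(torch.randn(tokens, dim_head))  # N(0, 1) init
        self.scale = dim_head ** -0.5

    def forward(self, q, k):
        # q, k: [batch, heads, tokens, dim_head]
        content = torch.einsum('bhid,bhjd->bhij', q, k)      # Q K^T
        position = torch.einsum('bhid,jd->bhij', q, self.R)  # Q R^T
        return ((content + position) * self.scale).softmax(dim=-1)
```

Usage: `AbsPosEmb1D(tokens=6, dim_head=16)(q, k)` with `q` and `k` of shape [2, 4, 6, 16] returns attention weights of shape [2, 4, 6, 6] whose rows sum to 1.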
The above code does nothing more than what we have already illustrated in the diagram.
Implementation of Relative PE
Since we have already solved the difficult part, converting relative to absolute embeddings, relative PE is no more difficult than absolute PE.
```python
import torch
import torch.nn as nn
from einops import rearrange

# relative_to_absolute(x) is the helper from earlier in the article: it maps
# relative logits [batch, heads, tokens, 2*tokens-1] to [batch, heads, tokens, tokens].


def rel_pos_emb_1d(q, rel_emb, shared_heads):
    """Same functionality as RelPosEmb1DAISummer.

    Args:
        q: a 4D tensor of shape [batch, heads, tokens, dim]
        rel_emb: a 2D or 3D tensor
            of shape [2*tokens-1, dim] or [heads, 2*tokens-1, dim]
    """
    if shared_heads:
        # One table, broadcast over every head.
        emb = torch.einsum('b h t d, r d -> b h t r', q, rel_emb)
    else:
        # A separate table per head.
        emb = torch.einsum('b h t d, h r d -> b h t r', q, rel_emb)
    return relative_to_absolute(emb)


class RelPosEmb1DAISummer(nn.Module):
    def __init__(self, tokens, dim_head, heads=None):
        """
        Output: [batch, heads, tokens, tokens]
        Args:
            tokens: the number of tokens in the sequence
            dim_head: the size of the last dimension of q
            heads: if None, the representation is shared across heads;
                otherwise the number of heads must be provided
        """
        super().__init__()
        scale = dim_head ** -0.5
        # Shared only when no head count is given.
        self.shared_heads = heads is None
        if self.shared_heads:
            self.rel_pos_emb = nn.Parameter(torch.randn(2 * tokens - 1, dim_head) * scale)
        else:
            self.rel_pos_emb = nn.Parameter(torch.randn(heads, 2 * tokens - 1, dim_head) * scale)

    def forward(self, q):
        return rel_pos_emb_1d(q, self.rel_pos_emb, self.shared_heads)
```
The only new piece is the call to relative_to_absolute inside the function. Next, it is interesting to see how this extends to 2D grids.
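`relative_to_absolute` is called above but defined outside this section. For reference, here is a minimal sketch of one common way to implement it, the pad-and-reshape ("skewing") trick, assuming column `r` of the input holds the logit for relative distance `r - (tokens - 1)`. It uses plain `reshape` instead of einops and may differ in detail from the author's version.

```python
import torch

def relative_to_absolute(x):
    """Map relative logits [b, h, l, 2l-1] to absolute logits [b, h, l, l].

    Assumes column r of the input is relative distance r - (l - 1), so the
    output satisfies out[..., i, j] = x[..., i, j - i + l - 1].
    """
    b, h, l, _ = x.shape
    # 1) Append one zero column: [b, h, l, 2l].
    x = torch.cat([x, x.new_zeros((b, h, l, 1))], dim=3)
    # 2) Flatten the last two axes, then append l - 1 zeros: length (l + 1) * (2l - 1).
    flat = x.reshape(b, h, l * 2 * l)
    flat = torch.cat([flat, x.new_zeros((b, h, l - 1))], dim=2)
    # 3) Reshape so each row is shifted one step relative to the previous row.
    skewed = flat.reshape(b, h, l + 1, 2 * l - 1)
    # 4) Keep the first l rows and the last l columns.
    return skewed[:, :, :l, l - 1:]
```

The padding makes the row length one shorter than the original stride, so each successive query row lands shifted by one position, which lines relative offsets up with absolute key indices without any gather.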
Two-dimensional Relative PE
The paper “Stand-Alone Self-Attention in Vision Models” extended the idea to 2D relative PE. Relative attention starts by defining the relative distance between two tokens. However, this time the tokens are pixels arranged in $h$ rows and $w$ columns of an image, so $tokens = h \cdot w$.
Thus, it makes more sense to factorize (decompose) the tokens across dimensions $h$ and $w$, so each token receives two independent distances: a row offset and a column offset. The following picture demonstrates this perfectly.
This image depicts an example of relative distances in a 2D grid. Notice that the relative distances are computed based on the yellow-highlighted pixel. Red indicates the row offset, while blue indicates the column offset.
Even though MHSA operates on a flattened sequence of pixels (tokens), we provide each pixel with two relative distances from the 2D grid.
The above image is taken from Prajit Ramachandran et al. (2019).
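To make the factorization concrete, here is a minimal sketch that computes the row and column offsets between every pair of pixels in a flattened $h \times w$ grid and uses them to look up two separate embedding tables. The function names are hypothetical, and simply summing the row and column terms is an assumption; the paper's exact way of combining the two offset embeddings may differ.

```python
import torch

def rel_offsets_2d(h, w):
    """Row/column offsets between all pixel pairs of an h x w grid.

    Pixels are flattened row-major, so token t sits at (t // w, t % w).
    Returns two [h*w, h*w] tensors: d_row[i, j] = row_j - row_i and
    d_col[i, j] = col_j - col_i.
    """
    rows = torch.arange(h).repeat_interleave(w)  # row index of each token
    cols = torch.arange(w).repeat(h)             # column index of each token
    d_row = rows[None, :] - rows[:, None]
    d_col = cols[None, :] - cols[:, None]
    return d_row, d_col

def rel_logits_2d(q, emb_row, emb_col, h, w):
    """Factorized 2D relative logits (hypothetical helper).

    q: [batch, heads, h*w, dim]; emb_row: [2h-1, dim]; emb_col: [2w-1, dim].
    """
    d_row, d_col = rel_offsets_2d(h, w)
    r = emb_row[d_row + h - 1]  # [h*w, h*w, dim], row-offset embeddings
    c = emb_col[d_col + w - 1]  # [h*w, h*w, dim], column-offset embeddings
    return torch.einsum('bhid,ijd->bhij', q, r + c)
```

For a 3×3 grid, the centre pixel (token 4) sees the top-left pixel (token 0) at row offset -1 and column offset -1. This version materializes an [h*w, h*w, dim] tensor, which is fine for small grids but memory-hungry for large images.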
I hope this helps you understand the concept of positional embeddings and how they can be implemented in different scenarios. Feel free to explore further references for a deeper dive into the topic.