Self-Supervised Learning on CIFAR Images with PyTorch: A Step-By-Step Tutorial

After presenting SimCLR, a contrastive self-supervised learning framework, I decided to demonstrate another well-known method called BYOL. Bootstrap Your Own Latent (BYOL) is a new algorithm for self-supervised learning of image representations. BYOL has two main advantages:

It does not explicitly use negative samples. Instead, it directly maximizes the similarity between the representations of two differently augmented views of the same image (the positive pair). Negative samples, by contrast, are the other images in the batch.

As a result, BYOL is claimed to require smaller batch sizes, which makes it an attractive choice.
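To make the notion of a positive pair concrete, here is a minimal sketch that produces two augmented views of the same batch. The function `simple_augment` is a hypothetical stand-in (random horizontal flip plus a small random shift); the BYOL paper uses a considerably stronger augmentation pipeline, so treat this only as an illustration.

```python
import torch
import torch.nn.functional as F


def simple_augment(images, max_shift=4):
    """Random horizontal flip plus a random shift of up to `max_shift` pixels.

    Hypothetical, simplified augmentation; images has shape [B, C, H, W].
    """
    batch, _, height, width = images.shape
    # Flip roughly half of the images along the width axis.
    flip = torch.rand(batch) < 0.5
    images = torch.where(flip.view(-1, 1, 1, 1), images.flip(-1), images)
    # Reflect-pad, then take one random crop of the original size per image.
    padded = F.pad(images, (max_shift,) * 4, mode="reflect")
    crops = []
    for img in padded:
        dy, dx = torch.randint(0, 2 * max_shift + 1, (2,)).tolist()
        crops.append(img[:, dy:dy + height, dx:dx + width])
    return torch.stack(crops)


torch.manual_seed(0)
batch = torch.rand(8, 3, 32, 32)   # stand-in for a CIFAR batch
view_1 = simple_augment(batch)     # first view of every image
view_2 = simple_augment(batch)     # second view of the same images
# (view_1[i], view_2[i]) is a positive pair; no negatives are needed for BYOL.
```

Each call draws fresh random parameters, so the two views of an image generally differ while still showing the same content.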

Below, you can examine the method. Unlike the original paper, I call the online network student and the target network teacher.


Overview of BYOL method. Source: BYOL paper

Online network aka student: compared to SimCLR, there is a second MLP, called predictor, which makes the whole method asymmetric. Asymmetric compared to what? Well, to the teacher model (target network).

Why is that important?

Because the teacher model is updated only through an exponential moving average (EMA) of the student’s parameters. At each iteration, the teacher’s weights move a small step toward the student’s weights: 1% with the decay of 0.99 used below. No gradients are computed for the teacher, so gradients flow only through the student network. This can be implemented as:

class EMA:
    def __init__(self, alpha):
        # alpha is the decay rate: the fraction of the old (teacher) value that is kept
        self.alpha = alpha

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.alpha + (1 - self.alpha) * new


ema = EMA(0.99)

# student_model and teacher_model are assumed to share the same architecture
for student_params, teacher_params in zip(student_model.parameters(), teacher_model.parameters()):
    old_weight, up_weight = teacher_params.data, student_params.data
    teacher_params.data = ema.update_average(old_weight, up_weight)
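To see how the asymmetric forward pass and the EMA update fit together, here is a minimal sketch of one training step. The tiny linear networks, the learning rate, and the helper `ema_update` are assumptions made for illustration and are not the tutorial’s actual models; the cosine-based loss is used only as a placeholder objective.

```python
import copy

import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Toy stand-ins (hypothetical sizes): the encoder plays the role of backbone + projector.
student_encoder = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 8))
predictor = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))

# The teacher starts as a frozen copy of the student encoder and has no predictor.
teacher_encoder = copy.deepcopy(student_encoder)
for p in teacher_encoder.parameters():
    p.requires_grad_(False)

optimizer = torch.optim.SGD(
    list(student_encoder.parameters()) + list(predictor.parameters()), lr=0.1
)


@torch.no_grad()
def ema_update(student, teacher, alpha=0.99):
    """In place: teacher <- alpha * teacher + (1 - alpha) * student."""
    for s, t in zip(student.parameters(), teacher.parameters()):
        t.mul_(alpha).add_(s, alpha=1 - alpha)


# Random tensors stand in for two augmented views of the same images.
view_1, view_2 = torch.randn(4, 32), torch.randn(4, 32)

prediction = predictor(student_encoder(view_1))  # student branch: encoder + predictor
with torch.no_grad():
    target = teacher_encoder(view_2)             # teacher branch: encoder only

loss = (2 - 2 * F.cosine_similarity(prediction, target, dim=-1)).mean()
optimizer.zero_grad()
loss.backward()    # gradients reach only the student encoder and the predictor
optimizer.step()

teacher_before = [t.clone() for t in teacher_encoder.parameters()]
ema_update(student_encoder, teacher_encoder)     # teacher follows the student slowly

teacher_has_grad = any(t.grad is not None for t in teacher_encoder.parameters())
```

Updating the teacher in place under `torch.no_grad()` avoids reassigning `.data` and keeps the teacher out of the autograd graph entirely.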

Another key difference between SimCLR and BYOL is the loss function.

Loss function

The predictor MLP is only applied to the student, making the architecture asymmetric. This is a key design choice to avoid mode collapse. Here, mode collapse means that the network outputs the same projection for every input.
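A short sketch of why collapse is tempting for the optimizer and how one might detect it. It uses the normalized loss that this section defines further down; the diagnostic `mean_feature_std` is a hypothetical helper, not something taken from the BYOL paper.

```python
import torch
import torch.nn.functional as F


def loss_fn(x, y):
    # Same normalized loss as in this section: 2 - 2 * cosine similarity.
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)


def mean_feature_std(z):
    """Average per-dimension standard deviation of the normalized embeddings.

    Hypothetical diagnostic: values near zero suggest that all inputs map to
    (almost) the same point, i.e. collapse.
    """
    z = F.normalize(z, dim=-1, p=2)
    return z.std(dim=0).mean().item()


torch.manual_seed(0)
collapsed = torch.full((16, 8), 3.0)   # every input mapped to the same vector
diverse = torch.randn(16, 8)           # inputs spread over the embedding space

# A constant output makes prediction and target identical, so the loss is
# (numerically) zero although the representation carries no information.
collapsed_loss = loss_fn(collapsed, collapsed).mean().item()
collapsed_std = mean_feature_std(collapsed)
diverse_std = mean_feature_std(diverse)
```

Logging a statistic like this next to the loss is one inexpensive way to notice collapse during pretraining.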


Overview of BYOL method. Source: BYOL paper

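Finally, the authors define the loss as the mean squared error between the L2-normalized prediction of the student and the L2-normalized projection of the teacher:

$$\mathcal{L} = \left\| \bar{q} - \bar{z}' \right\|_2^2 = 2 - 2 \cdot \frac{\langle q, z' \rangle}{\| q \|_2 \, \| z' \|_2}$$

Here q is the student’s prediction, z' is the teacher’s projection, and a bar denotes L2 normalization. The loss can be implemented as follows, with the L2 normalization applied first.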

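import torch
import torch.nn.functional as F


def loss_fn(x, y):
    # L2-normalize both inputs so that the dot product below is the cosine similarity
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    # squared distance between unit vectors: 2 - 2 * cos(x, y), one value per sample
    return 2 - 2 * (x * y).sum(dim=-1)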
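As a quick sanity check, the sketch below verifies numerically that the squared error between L2-normalized vectors equals 2 minus twice the cosine similarity, which is why the implementation never computes a difference explicitly. The tensor shapes are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 256)   # stand-in for student predictions
y = torch.randn(8, 256)   # stand-in for teacher projections

x_norm = F.normalize(x, dim=-1, p=2)
y_norm = F.normalize(y, dim=-1, p=2)

# Form 1: squared Euclidean distance between the normalized vectors.
mse_form = (x_norm - y_norm).pow(2).sum(dim=-1)
# Form 2: the expression returned by loss_fn.
cos_form = 2 - 2 * F.cosine_similarity(x, y, dim=-1)

max_gap = (mse_form - cos_form).abs().max().item()
```

Because cosine similarity lies in [-1, 1], the per-sample loss is bounded between 0 (identical directions) and 4 (opposite directions).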

Code is available on GitHub

Tracking down what’s happening in self-supervised pretraining: KNN accuracy

Nonetheless, the loss in self-supervised learning is not a reliable metric to track. In my experience, the best way to track what’s happening during training is to measure the KNN accuracy.

The critical advantage of using KNN is that we don’t have to train a linear classifier on top each time, so it’s faster and needs no gradient-based training. Labels are used only to score the neighbours, never to update the network.

Note: Measuring KNN only applies to image classification, but you get the idea. For this purpose, I made a class to encapsulate the logic of KNN in our context:

import numpy as np
import torch
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier


class KNN:
    def __init__(self, model, k, device):
        self.k = k
        self.device = device
        self.model = model.to(device)
        self.model.eval()

    def extract_features(self, loader):
        """
        Infer/Extract features from a trained model
        Args:
            loader: train or test loader
        Returns: 3 tensors of all: input_images, features, labels
        """
        x_lst = []
        features = []
        label_lst = []
        with torch.no_grad():
            for input_tensor, label in loader:
                h = self.model(input_tensor.to(self.device))
                features.append(h)
                x_lst.append(input_tensor)
                label_lst.append(label)
        # concatenate along the batch axis (torch.stack fails if the last batch is smaller)
        x_total = torch.cat(x_lst)
        h_total = torch.cat(features)
        label_total = torch.cat(label_lst)
        return x_total, h_total, label_total

    def knn(self, features, labels, k=1):
        """
        Evaluating knn accuracy in feature space.
        Calculates only top-1 accuracy (returns 0 for top-5)
        Args:
            features: [..., dataset_size, feat_dim]
            labels: [..., dataset_size]
            k: nearest neighbours
        Returns: train accuracy, or train and test acc
        """
        feature_dim = features.shape[-1]
        with torch.no_grad():
            features_np = features.cpu().view(-1, feature_dim).numpy()
            labels_np = labels.cpu().view(-1).numpy()
            self.cls = KNeighborsClassifier(k, metric="cosine").fit(features_np, labels_np)
            acc = self.eval(features, labels)
        return acc

    def eval(self, features, labels):
        feature_dim = features.shape[-1]
        features = features.cpu().view(-1, feature_dim).numpy()
        labels = labels.cpu().view(-1).numpy()
        # cross_val_score refits clones of self.cls on folds of the given features
        acc = 100 * np.mean(cross_val_score(self.cls, features, labels))
        return acc

    def _find_best_indices(self, h_query, h_ref):
        # cosine similarity between every query and every reference feature
        h_query = h_query / h_query.norm(dim=1).view(-1, 1)
        h_ref = h_ref / h_ref.norm(dim=1).view(-1, 1)
        scores = torch.matmul(h_query, h_ref.t())
        score, indices = scores.topk(1, dim=1)
        return score, indices

    def fit(self, train_loader, test_loader=None):
        with torch.no_grad():
            x_train, h_train, l_train = self.extract_features(train_loader)
            train_acc = self.knn(h_train, l_train, k=self.k)
            # without a test loader there is no test accuracy to report
            if test_loader is None:
                return train_acc
            x_test, h_test, l_test = self.extract_features(test_loader)
            test_acc = self.eval(h_test, l_test)
            return train_acc, test_acc
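For intuition, the core of a cosine-similarity KNN fits in a few lines of plain PyTorch. The sketch below is a simplified alternative to the class above, not a replacement for it: it uses k=1 with the training features as the reference set, and the function name `knn_top1_accuracy` and the toy data are made up for illustration.

```python
import torch
import torch.nn.functional as F


def knn_top1_accuracy(ref_feats, ref_labels, query_feats, query_labels):
    """Top-1 accuracy (in %) of a 1-nearest-neighbour classifier with cosine similarity."""
    ref = F.normalize(ref_feats, dim=1)
    query = F.normalize(query_feats, dim=1)
    sims = query @ ref.t()            # [num_query, num_ref] cosine similarities
    nearest = sims.argmax(dim=1)      # index of the most similar reference feature
    predictions = ref_labels[nearest]
    return 100.0 * (predictions == query_labels).float().mean().item()


# Toy 2-D features: class 0 points along the x-axis, class 1 along the y-axis.
train_feats = torch.tensor([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
train_labels = torch.tensor([0, 0, 1, 1])
test_feats = torch.tensor([[2.0, 0.1], [0.2, 3.0]])
test_labels = torch.tensor([0, 1])

accuracy = knn_top1_accuracy(train_feats, train_labels, test_feats, test_labels)
```

Since cosine similarity ignores vector length, the test points are classified by direction only, which is why the longer test vectors still match their class.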

Now we can focus on the method and BYOL model.

Modify resnet: add MLP projection heads

We will start with a base model (resnet18) and modify it for self-supervised learning. The last layer that normally does the classification is replaced with an identity function. The output features of resnet18 will be fed to the MLP projector.

import copy

import torch
from torch import nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(self, dim, embedding_size=256, hidden_size=2048, batch_norm_mlp=False):
        super().__init__()
        norm = nn.BatchNorm1d(hidden_size) if batch_norm_mlp else nn.Identity()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_size),
            norm,
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, embedding_size)
        )

    def forward(self, x):
        return self.net(x)


class AddProjHead(nn.Module):
    def __init__(self, model, in_features, layer_name, hidden_size=4096,
                 embedding_size=256, batch_norm_mlp=True):
        super(AddProjHead, self).__init__()
        self.backbone = model
        # replace the classification layer (e.g. "fc" for resnet18) with an identity
        setattr(self.backbone, layer_name, nn.Identity())
        # CIFAR-sized inputs: 3x3 stem convolution and no max-pooling
        self.backbone.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.backbone.maxpool = torch.nn.Identity()
        self.projection = MLP(in_features, embedding_size, hidden_size=hidden_size, batch_norm_mlp=batch_norm_mlp)

    def forward(self, x, return_embedding=False):
        embedding = self.backbone(x)
        if return_embedding:
            return embedding
        return self.projection(embedding)

I also replaced the first 7×7 conv layer of resnet18 with a 3×3 convolution (stride 1) and removed the max-pooling layer, since we are working with 32×32 images (CIFAR-10).
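The arithmetic behind this choice can be sketched in plain Python. The helper names below are hypothetical, and the layer settings assume the standard torchvision resnet18 layout: a 7×7 stride-2 stem convolution, a 3×3 stride-2 max-pool, and four stages whose first convolutions use strides 1, 2, 2, 2.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet18_final_map(size, cifar_stem):
    """Side length of the feature map that enters the global average pooling."""
    if cifar_stem:
        # Modified stem: 3x3 convolution with stride 1, no max-pooling.
        size = conv_out(size, kernel=3, stride=1, padding=1)
    else:
        # Original stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max-pool.
        size = conv_out(size, kernel=7, stride=2, padding=3)
        size = conv_out(size, kernel=3, stride=2, padding=1)
    # Four residual stages; only the first convolution of each stage changes the size.
    for stride in (1, 2, 2, 2):
        size = conv_out(size, kernel=3, stride=stride, padding=1)
    return size


original = resnet18_final_map(32, cifar_stem=False)  # 1: the feature map shrinks to 1x1
modified = resnet18_final_map(32, cifar_stem=True)   # 4: a 4x4 map survives
```

With the original stem, a 32×32 image is already reduced to 8×8 before the first residual stage, so the later stages have very little spatial information to work with.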


The actual BYOL method
