
Self-Supervised Learning: Learning Without Labels

The Promise of Self-Supervision

Labeled data is expensive. Unlabeled data is abundant. Self-supervised learning bridges this gap by creating pretext tasks from unlabeled data.
Supervised Learning:    Image → Human Label → Representation
Self-Supervised:        Image → Generated Task → Representation
Key Insight: Design tasks whose labels are FREE to generate and that force the model to learn useful representations.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import numpy as np
from typing import List, Tuple, Optional

torch.manual_seed(42)

Pretext Tasks

Classic Pretext Tasks

class PretextTasks:
    """Classic self-supervised pretext tasks."""
    
    @staticmethod
    def rotation_prediction(image: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """
        Rotate image and predict rotation angle.
        Forces model to understand object orientation.
        """
        rotation = int(np.random.choice([0, 1, 2, 3]))  # class id for 0°, 90°, 180°, 270°
        rotated = torch.rot90(image, rotation, dims=[-2, -1])
        return rotated, rotation
    
    @staticmethod
    def jigsaw_puzzle(image: torch.Tensor, grid_size: int = 3) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Shuffle image patches and predict correct order.
        Forces model to understand spatial relationships.
        """
        _, h, w = image.shape
        patch_h, patch_w = h // grid_size, w // grid_size
        
        # Extract patches
        patches = []
        for i in range(grid_size):
            for j in range(grid_size):
                patch = image[:, i*patch_h:(i+1)*patch_h, j*patch_w:(j+1)*patch_w]
                patches.append(patch)
        
        # Shuffle patches
        n_patches = grid_size * grid_size
        permutation = torch.randperm(n_patches)
        shuffled_patches = [patches[i] for i in permutation]
        
        # Reconstruct shuffled image
        rows = []
        for i in range(grid_size):
            row_patches = shuffled_patches[i*grid_size:(i+1)*grid_size]
            row = torch.cat(row_patches, dim=2)  # Concat horizontally
            rows.append(row)
        shuffled_image = torch.cat(rows, dim=1)  # Concat vertically
        
        return shuffled_image, permutation
    
    @staticmethod
    def colorization(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert to grayscale and predict colors.
        Forces model to understand semantic content.
        """
        # Assume image is RGB [3, H, W]
        grayscale = 0.299 * image[0] + 0.587 * image[1] + 0.114 * image[2]
        grayscale = grayscale.unsqueeze(0).repeat(3, 1, 1)
        
        return grayscale, image
    
    @staticmethod
    def inpainting(image: torch.Tensor, mask_ratio: float = 0.3) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Mask parts of image and predict masked content.
        Forces model to understand context.
        """
        _, h, w = image.shape
        
        # Random visibility mask: True keeps a pixel, so roughly mask_ratio of pixels are hidden
        mask = torch.rand(1, h, w) > mask_ratio
        
        # Apply mask
        masked_image = image * mask.float()
        
        return masked_image, image, mask


# Rotation prediction model
class RotationNet(nn.Module):
    """Predict rotation angle of image."""
    
    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(backbone.feature_dim, 4)  # 4 rotation angles
    
    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)


# Example backbone
class SimpleBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_dim = 512
        
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )
        
        self.fc = nn.Linear(256, self.feature_dim)
    
    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
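
To see how the pieces fit together, here is a minimal sketch of one rotation-pretext training step using the classes above (the batch of random tensors stands in for an unlabeled dataloader):

# Hypothetical rotation-pretext training step: the labels are generated, not annotated
backbone = SimpleBackbone()
rotation_model = RotationNet(backbone)
optimizer = torch.optim.Adam(rotation_model.parameters(), lr=1e-3)

images = torch.randn(8, 3, 64, 64)  # stands in for a batch of unlabeled images

# Free labels: rotate each image and remember which rotation class was applied
rotated, targets = zip(*[PretextTasks.rotation_prediction(img) for img in images])
rotated = torch.stack(rotated)
targets = torch.tensor([int(t) for t in targets], dtype=torch.long)

optimizer.zero_grad()
logits = rotation_model(rotated)          # [8, 4] rotation logits
loss = F.cross_entropy(logits, targets)
loss.backward()
optimizer.step()
print(f"Rotation pretext loss: {loss.item():.4f}")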

Contrastive Learning

The Contrastive Learning Framework

Goal: Pull together representations of similar samples (positives), push apart representations of dissimilar samples (negatives).
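
Concretely, contrastive methods such as SimCLR below optimize the NT-Xent (InfoNCE) loss. For a positive pair of views $(i, j)$ with normalized projections $z$, temperature $\tau$, and $2N$ views in the batch:

$$\ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)}, \qquad \mathrm{sim}(u, v) = \frac{u^\top v}{\lVert u\rVert\,\lVert v\rVert}$$

The other augmented view of the same image is the positive; every other sample in the batch serves as a negative.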

SimCLR (Simple Contrastive Learning)

class SimCLRTransform:
    """Data augmentation for SimCLR."""
    
    def __init__(self, size: int = 224):
        self.transform = T.Compose([
            T.RandomResizedCrop(size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(),
            T.RandomApply([T.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.RandomApply([T.GaussianBlur(kernel_size=23)], p=0.5),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    def __call__(self, x):
        """Generate two augmented views of the same image."""
        return self.transform(x), self.transform(x)


class SimCLR(nn.Module):
    """SimCLR: A Simple Framework for Contrastive Learning."""
    
    def __init__(
        self,
        backbone: nn.Module,
        projection_dim: int = 128,
        hidden_dim: int = 2048,
        temperature: float = 0.5
    ):
        super().__init__()
        
        self.backbone = backbone
        self.temperature = temperature
        
        # Projection head (MLP)
        feature_dim = backbone.feature_dim
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim)
        )
    
    def forward(self, x1: torch.Tensor, x2: torch.Tensor):
        """
        Args:
            x1: First augmented view [batch_size, C, H, W]
            x2: Second augmented view [batch_size, C, H, W]
        
        Returns:
            loss: NT-Xent contrastive loss
        """
        batch_size = x1.size(0)
        
        # Extract features and project
        z1 = self.projector(self.backbone(x1))  # [batch, projection_dim]
        z2 = self.projector(self.backbone(x2))
        
        # Normalize projections
        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)
        
        # Compute similarity matrix
        representations = torch.cat([z1, z2], dim=0)  # [2*batch, projection_dim]
        similarity_matrix = torch.mm(representations, representations.t())
        
        # Create labels: positive pairs are (i, i+batch_size)
        labels = torch.arange(batch_size, device=x1.device)
        labels = torch.cat([labels + batch_size, labels], dim=0)
        
        # Mask out self-similarity
        mask = torch.eye(2 * batch_size, device=x1.device).bool()
        similarity_matrix.masked_fill_(mask, float('-inf'))
        
        # Scale by temperature
        similarity_matrix = similarity_matrix / self.temperature
        
        # NT-Xent loss (InfoNCE)
        loss = F.cross_entropy(similarity_matrix, labels)
        
        return loss
    
    def get_representations(self, x: torch.Tensor) -> torch.Tensor:
        """Get learned representations for downstream tasks."""
        return self.backbone(x)


# Train SimCLR
def train_simclr(model, dataloader, optimizer, epochs=100):
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        
        for images, _ in dataloader:  # Labels are ignored - only the images are used
            # Get two augmented views (in practice the transform lives in the dataset,
            # so each sample gets its own random augmentation parameters)
            x1, x2 = SimCLRTransform()(images)
            x1, x2 = x1.cuda(), x2.cuda()
            
            optimizer.zero_grad()
            loss = model(x1, x2)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch+1}: Loss = {total_loss/len(dataloader):.4f}")

MoCo (Momentum Contrast)

class MoCo(nn.Module):
    """
    Momentum Contrast (MoCo v2).
    Uses a momentum encoder and a large dictionary of negative samples.
    """
    
    def __init__(
        self,
        backbone: nn.Module,
        projection_dim: int = 128,
        queue_size: int = 65536,
        momentum: float = 0.999,
        temperature: float = 0.07
    ):
        super().__init__()
        
        self.queue_size = queue_size
        self.momentum = momentum
        self.temperature = temperature
        
        # Query encoder
        self.encoder_q = backbone
        self.projector_q = nn.Sequential(
            nn.Linear(backbone.feature_dim, 2048),
            nn.ReLU(),
            nn.Linear(2048, projection_dim)
        )
        
        # Key encoder (momentum updated)
        self.encoder_k = self._copy_encoder(backbone)
        self.projector_k = nn.Sequential(
            nn.Linear(backbone.feature_dim, 2048),
            nn.ReLU(),
            nn.Linear(2048, projection_dim)
        )
        self._copy_params(self.projector_q, self.projector_k)
        
        # Freeze key encoder
        for param in self.encoder_k.parameters():
            param.requires_grad = False
        for param in self.projector_k.parameters():
            param.requires_grad = False
        
        # Queue for negative samples
        self.register_buffer("queue", F.normalize(torch.randn(projection_dim, queue_size), dim=0))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
    
    def _copy_encoder(self, encoder):
        import copy
        return copy.deepcopy(encoder)
    
    def _copy_params(self, src, dst):
        for param_src, param_dst in zip(src.parameters(), dst.parameters()):
            param_dst.data.copy_(param_src.data)
    
    @torch.no_grad()
    def _momentum_update(self):
        """Update key encoder with momentum."""
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = self.momentum * param_k.data + (1 - self.momentum) * param_q.data
        for param_q, param_k in zip(self.projector_q.parameters(), self.projector_k.parameters()):
            param_k.data = self.momentum * param_k.data + (1 - self.momentum) * param_q.data
    
    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Update the queue with new key embeddings."""
        batch_size = keys.size(0)
        ptr = int(self.queue_ptr)
        
        # Replace oldest entries
        if ptr + batch_size > self.queue_size:
            self.queue[:, ptr:] = keys[:self.queue_size - ptr].T
            self.queue[:, :batch_size - (self.queue_size - ptr)] = keys[self.queue_size - ptr:].T
        else:
            self.queue[:, ptr:ptr + batch_size] = keys.T
        
        ptr = (ptr + batch_size) % self.queue_size
        self.queue_ptr[0] = ptr
    
    def forward(self, x_q, x_k):
        """
        Args:
            x_q: Query images
            x_k: Key images (different augmentation of same images)
        """
        batch_size = x_q.size(0)
        
        # Compute query features
        q = self.projector_q(self.encoder_q(x_q))
        q = F.normalize(q, dim=1)
        
        # Compute key features (no gradient)
        with torch.no_grad():
            self._momentum_update()
            k = self.projector_k(self.encoder_k(x_k))
            k = F.normalize(k, dim=1)
        
        # Positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        
        # Negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        
        # Logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1) / self.temperature
        
        # Labels: positive is at index 0
        labels = torch.zeros(batch_size, dtype=torch.long, device=x_q.device)
        
        # Update queue
        self._dequeue_and_enqueue(k)
        
        return F.cross_entropy(logits, labels)
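
A minimal usage sketch (the queue is shrunk so the example stays light; the default above is 65,536 keys):

moco = MoCo(SimpleBackbone(), projection_dim=128, queue_size=1024)
x_q = torch.randn(8, 3, 64, 64)  # query view
x_k = torch.randn(8, 3, 64, 64)  # key view: a different augmentation of the same images
loss = moco(x_q, x_k)
print(f"MoCo InfoNCE loss: {loss.item():.4f}")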

Non-Contrastive Methods

BYOL (Bootstrap Your Own Latent)

BYOL shows that you don’t need negative samples at all!
class BYOL(nn.Module):
    """
    Bootstrap Your Own Latent.
    No negative samples needed - learns by predicting target network.
    """
    
    def __init__(
        self,
        backbone: nn.Module,
        projection_dim: int = 256,
        hidden_dim: int = 4096,
        moving_average_decay: float = 0.99
    ):
        super().__init__()
        
        self.moving_average_decay = moving_average_decay
        
        # Online network
        self.online_encoder = backbone
        self.online_projector = self._make_projector(backbone.feature_dim, hidden_dim, projection_dim)
        self.predictor = self._make_projector(projection_dim, hidden_dim, projection_dim)
        
        # Target network (momentum updated)
        self.target_encoder = self._copy_encoder(backbone)
        self.target_projector = self._make_projector(backbone.feature_dim, hidden_dim, projection_dim)
        self.target_projector.load_state_dict(self.online_projector.state_dict())  # initialize target = online
        
        # Freeze target network
        for param in self.target_encoder.parameters():
            param.requires_grad = False
        for param in self.target_projector.parameters():
            param.requires_grad = False
    
    def _make_projector(self, in_dim, hidden_dim, out_dim):
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )
    
    def _copy_encoder(self, encoder):
        import copy
        return copy.deepcopy(encoder)
    
    @torch.no_grad()
    def update_target_network(self):
        """EMA update of target network."""
        tau = self.moving_average_decay
        
        for online, target in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            target.data = tau * target.data + (1 - tau) * online.data
        
        for online, target in zip(self.online_projector.parameters(), self.target_projector.parameters()):
            target.data = tau * target.data + (1 - tau) * online.data
    
    def forward(self, x1, x2):
        """
        Args:
            x1, x2: Two augmented views of the same images
        """
        # Online network forward
        online_feat1 = self.online_encoder(x1)
        online_proj1 = self.online_projector(online_feat1)
        online_pred1 = self.predictor(online_proj1)
        
        online_feat2 = self.online_encoder(x2)
        online_proj2 = self.online_projector(online_feat2)
        online_pred2 = self.predictor(online_proj2)
        
        # Target network forward (no gradient)
        with torch.no_grad():
            target_proj1 = self.target_projector(self.target_encoder(x1))
            target_proj2 = self.target_projector(self.target_encoder(x2))
            
            # Stop gradient
            target_proj1 = target_proj1.detach()
            target_proj2 = target_proj2.detach()
        
        # Compute loss: predict one view from the other
        loss1 = self._loss_fn(online_pred1, target_proj2)
        loss2 = self._loss_fn(online_pred2, target_proj1)
        
        return (loss1 + loss2) / 2
    
    def _loss_fn(self, x, y):
        """Normalized MSE loss."""
        x = F.normalize(x, dim=-1)
        y = F.normalize(y, dim=-1)
        return 2 - 2 * (x * y).sum(dim=-1).mean()
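
Note that update_target_network() is not called inside forward - the training loop is responsible for the EMA update after each optimizer step. A minimal sketch with placeholder views:

byol = BYOL(SimpleBackbone())
optimizer = torch.optim.Adam(byol.parameters(), lr=3e-4)

x1 = torch.randn(8, 3, 64, 64)  # first augmented view
x2 = torch.randn(8, 3, 64, 64)  # second augmented view

optimizer.zero_grad()
loss = byol(x1, x2)
loss.backward()
optimizer.step()
byol.update_target_network()  # EMA update of the target encoder and projector
print(f"BYOL loss: {loss.item():.4f}")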

SimSiam (Simple Siamese)

class SimSiam(nn.Module):
    """
    Simple Siamese networks.
    Even simpler than BYOL - no momentum encoder, just stop-gradient.
    """
    
    def __init__(
        self,
        backbone: nn.Module,
        projection_dim: int = 2048,
        prediction_dim: int = 512
    ):
        super().__init__()
        
        self.encoder = backbone
        
        # Projection MLP
        self.projector = nn.Sequential(
            nn.Linear(backbone.feature_dim, projection_dim),
            nn.BatchNorm1d(projection_dim),
            nn.ReLU(),
            nn.Linear(projection_dim, projection_dim),
            nn.BatchNorm1d(projection_dim),
            nn.ReLU(),
            nn.Linear(projection_dim, projection_dim),
            nn.BatchNorm1d(projection_dim)
        )
        
        # Prediction MLP
        self.predictor = nn.Sequential(
            nn.Linear(projection_dim, prediction_dim),
            nn.BatchNorm1d(prediction_dim),
            nn.ReLU(),
            nn.Linear(prediction_dim, projection_dim)
        )
    
    def forward(self, x1, x2):
        """Forward pass with two views."""
        
        # Compute projections
        z1 = self.projector(self.encoder(x1))
        z2 = self.projector(self.encoder(x2))
        
        # Compute predictions
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)
        
        # Compute loss with stop-gradient
        loss1 = self._negative_cosine_similarity(p1, z2.detach())
        loss2 = self._negative_cosine_similarity(p2, z1.detach())
        
        return (loss1 + loss2) / 2
    
    def _negative_cosine_similarity(self, p, z):
        """Negative cosine similarity."""
        p = F.normalize(p, dim=1)
        z = F.normalize(z, dim=1)
        return -(p * z).sum(dim=1).mean()

Masked Modeling

Masked Autoencoders (MAE)

class MaskedAutoencoder(nn.Module):
    """
    Masked Autoencoder for Vision (MAE).
    Mask random patches and reconstruct them.
    """
    
    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        embed_dim: int = 768,
        encoder_depth: int = 12,
        encoder_heads: int = 12,
        decoder_embed_dim: int = 512,
        decoder_depth: int = 8,
        decoder_heads: int = 16,
        mask_ratio: float = 0.75
    ):
        super().__init__()
        
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        num_patches = (image_size // patch_size) ** 2
        
        # Patch embedding
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim) * 0.02)
        
        # Encoder (processes visible patches only)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=encoder_heads,
                dim_feedforward=embed_dim * 4,
                batch_first=True
            ),
            num_layers=encoder_depth
        )
        
        # Decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(torch.randn(1, num_patches, decoder_embed_dim) * 0.02)
        
        self.decoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=decoder_embed_dim,
                nhead=decoder_heads,
                dim_feedforward=decoder_embed_dim * 4,
                batch_first=True
            ),
            num_layers=decoder_depth
        )
        
        # Prediction head (reconstruct pixels)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * 3)
    
    def random_masking(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Randomly mask patches.
        
        Args:
            x: [batch, num_patches, embed_dim]
        
        Returns:
            x_masked: visible patches
            mask: binary mask (1 = masked)
            ids_restore: indices to restore original order
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - self.mask_ratio))
        
        # Random noise for shuffling
        noise = torch.rand(N, L, device=x.device)
        
        # Sort noise to get shuffle indices
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        
        # Keep first len_keep patches (after shuffling)
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D))
        
        # Generate binary mask: 0 = keep, 1 = masked
        mask = torch.ones(N, L, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)
        
        return x_masked, mask, ids_restore
    
    def forward(self, images: torch.Tensor):
        """
        Args:
            images: [batch, 3, H, W]
        
        Returns:
            loss: Reconstruction loss on masked patches
            pred: Predicted pixel values
            mask: Binary mask
        """
        # Patchify
        patches = self.patch_embed(images)  # [B, embed_dim, H/P, W/P]
        patches = patches.flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]
        
        # Add positional embedding
        patches = patches + self.pos_embed
        
        # Random masking
        patches_masked, mask, ids_restore = self.random_masking(patches)
        
        # Encode visible patches
        latent = self.encoder(patches_masked)
        
        # Project to decoder dimension
        latent = self.decoder_embed(latent)
        
        # Append mask tokens
        N, _, D = latent.shape
        num_patches = mask.shape[1]
        mask_tokens = self.mask_token.expand(N, num_patches - latent.shape[1], -1)
        
        # Unshuffle: put mask tokens in correct positions
        x_ = torch.cat([latent, mask_tokens], dim=1)
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, D))
        
        # Add decoder positional embedding
        x_ = x_ + self.decoder_pos_embed
        
        # Decode
        x_ = self.decoder(x_)
        
        # Predict pixels
        pred = self.decoder_pred(x_)  # [B, num_patches, patch_size^2 * 3]
        
        # Compute loss only on masked patches
        target = self._patchify(images)
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # Mean over patch pixels
        loss = (loss * mask).sum() / mask.sum()  # Mean over masked patches
        
        return loss, pred, mask
    
    def _patchify(self, images: torch.Tensor) -> torch.Tensor:
        """Convert images to patches."""
        p = self.patch_size
        B, C, H, W = images.shape
        h, w = H // p, W // p
        
        x = images.reshape(B, C, h, p, w, p)
        x = x.permute(0, 2, 4, 3, 5, 1)  # [B, h, w, p, p, C]
        x = x.reshape(B, h * w, p * p * C)
        
        return x


# Test MAE
mae = MaskedAutoencoder(image_size=224, patch_size=16, embed_dim=768)
images = torch.randn(4, 3, 224, 224)
loss, pred, mask = mae(images)
print(f"MAE Loss: {loss.item():.4f}")
print(f"Mask ratio: {mask.float().mean().item():.2f}")

BERT-style Masked Language Modeling

class MaskedLanguageModel(nn.Module):
    """BERT-style masked language modeling for text."""
    
    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        max_seq_len: int = 512,
        mask_token_id: int = 103,  # [MASK] token
        mask_ratio: float = 0.15
    ):
        super().__init__()
        
        self.mask_token_id = mask_token_id
        self.mask_ratio = mask_ratio
        
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Embedding(max_seq_len, embed_dim)
        
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=embed_dim * 4,
                batch_first=True
            ),
            num_layers=num_layers
        )
        
        self.output = nn.Linear(embed_dim, vocab_size)
    
    def mask_tokens(self, input_ids: torch.Tensor):
        """Apply BERT-style masking."""
        labels = input_ids.clone()
        
        device = input_ids.device

        # Probability matrix for masking
        probability_matrix = torch.full(input_ids.shape, self.mask_ratio, device=device)
        
        # Create masked indices
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # Only compute loss on masked tokens
        
        # Of the masked positions: 80% -> [MASK], 10% -> random token, 10% -> left unchanged
        indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8, device=device)).bool() & masked_indices
        input_ids[indices_replaced] = self.mask_token_id
        
        indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5, device=device)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.embedding.weight), input_ids.shape, device=device)
        input_ids[indices_random] = random_words[indices_random]
        
        return input_ids, labels
    
    def forward(self, input_ids: torch.Tensor):
        # Apply masking
        masked_ids, labels = self.mask_tokens(input_ids.clone())
        
        # Get embeddings
        seq_len = masked_ids.size(1)
        positions = torch.arange(seq_len, device=masked_ids.device).unsqueeze(0)
        
        x = self.embedding(masked_ids) + self.pos_embedding(positions)
        
        # Transformer forward
        x = self.transformer(x)
        
        # Predict masked tokens
        logits = self.output(x)
        
        # Compute loss
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            ignore_index=-100
        )
        
        return loss, logits
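
A quick smoke test with a small configuration and random token ids (the 30522 vocabulary size mirrors BERT's WordPiece vocab, but any value works):

mlm = MaskedLanguageModel(vocab_size=30522, embed_dim=256, num_layers=2, num_heads=4)
token_ids = torch.randint(0, 30522, (4, 128))  # stands in for a tokenized text batch
loss, logits = mlm(token_ids)
print(f"MLM loss: {loss.item():.4f}, logits shape: {tuple(logits.shape)}")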

Evaluation & Downstream Tasks

Linear Probing

class LinearProbe(nn.Module):
    """Linear evaluation of learned representations."""
    
    def __init__(self, encoder: nn.Module, num_classes: int):
        super().__init__()
        
        self.encoder = encoder
        
        # Freeze encoder
        for param in self.encoder.parameters():
            param.requires_grad = False
        
        self.classifier = nn.Linear(encoder.feature_dim, num_classes)
    
    def forward(self, x):
        with torch.no_grad():
            features = self.encoder(x)
        return self.classifier(features)


def linear_evaluation(pretrained_encoder, train_loader, val_loader, num_classes, epochs=100):
    """Evaluate representations with linear probing."""
    
    probe = LinearProbe(pretrained_encoder, num_classes).cuda()
    optimizer = torch.optim.Adam(probe.classifier.parameters(), lr=0.001)
    
    best_acc = 0
    
    for epoch in range(epochs):
        # Train
        probe.train()
        probe.encoder.eval()  # keep the frozen encoder's BatchNorm/Dropout in eval mode
        for images, labels in train_loader:
            images, labels = images.cuda(), labels.cuda()
            
            optimizer.zero_grad()
            outputs = probe(images)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
        
        # Evaluate
        probe.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.cuda(), labels.cuda()
                outputs = probe(images)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        
        acc = 100 * correct / total
        if acc > best_acc:
            best_acc = acc
        
        print(f"Epoch {epoch+1}: Accuracy = {acc:.2f}%")
    
    return best_acc

Fine-tuning

def finetune_evaluation(pretrained_encoder, train_loader, val_loader, num_classes, epochs=100):
    """Fine-tune entire model on downstream task."""
    
    model = nn.Sequential(
        pretrained_encoder,
        nn.Linear(pretrained_encoder.feature_dim, num_classes)
    ).cuda()
    
    # Different learning rates for encoder and classifier
    optimizer = torch.optim.AdamW([
        {'params': pretrained_encoder.parameters(), 'lr': 1e-4},
        {'params': model[-1].parameters(), 'lr': 1e-3}
    ])
    
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    
    for epoch in range(epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.cuda(), labels.cuda()
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
        
        scheduler.step()
    
    # Final evaluation
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.cuda(), labels.cuda()
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    return 100 * correct / total
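
Besides linear probing and fine-tuning, a common training-free protocol is k-NN classification on frozen features (see tip 6 in the best practices below). A minimal sketch, using a hypothetical knn_evaluation helper and assuming all features fit in memory:

def knn_evaluation(pretrained_encoder, train_loader, val_loader, k=20):
    """Classify each validation sample by majority vote over its k nearest training features."""
    pretrained_encoder.eval()

    def extract(loader):
        feats, labels = [], []
        with torch.no_grad():
            for images, targets in loader:
                f = pretrained_encoder(images.cuda())
                feats.append(F.normalize(f, dim=1).cpu())
                labels.append(targets)
        return torch.cat(feats), torch.cat(labels)

    train_feats, train_labels = extract(train_loader)
    val_feats, val_labels = extract(val_loader)

    # Cosine similarity (features are L2-normalized), then majority vote over the top-k neighbors
    sims = val_feats @ train_feats.T                # [n_val, n_train]
    topk = sims.topk(k, dim=1).indices              # [n_val, k]
    neighbor_labels = train_labels[topk]            # [n_val, k]
    preds = neighbor_labels.mode(dim=1).values      # majority vote
    return 100.0 * (preds == val_labels).float().mean().item()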

Practical Considerations

Data Augmentation for SSL

class SSLAugmentations:
    """Strong augmentations for self-supervised learning."""
    
    @staticmethod
    def simclr_augmentation(size=224):
        """SimCLR augmentation strategy."""
        return T.Compose([
            T.RandomResizedCrop(size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply([
                T.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2)
            ], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.RandomApply([T.GaussianBlur(kernel_size=23)], p=0.5),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    @staticmethod
    def byol_augmentation(size=224):
        """BYOL augmentation (asymmetric)."""
        base = [
            T.RandomResizedCrop(size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply([
                T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)
            ], p=0.8),
            T.RandomGrayscale(p=0.2),
        ]
        
        view1 = T.Compose(base + [
            T.RandomApply([T.GaussianBlur(kernel_size=23)], p=1.0),  # Always blur
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        view2 = T.Compose(base + [
            T.RandomApply([T.GaussianBlur(kernel_size=23)], p=0.1),  # Rarely blur
            T.RandomSolarize(threshold=128, p=0.2),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        return view1, view2

Training Tips

def ssl_training_tips():
    """Key tips for successful self-supervised training."""
    
    tips = """
    ╔════════════════════════════════════════════════════════════════╗
    ║           SELF-SUPERVISED LEARNING: BEST PRACTICES            ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                ║
    ║  1. BATCH SIZE                                                 ║
    ║     • Contrastive methods (SimCLR, MoCo): Large batches (4096+)║
    ║     • Non-contrastive (BYOL, SimSiam): Smaller batches OK      ║
    ║     • Use gradient accumulation if GPU memory limited          ║
    ║                                                                ║
    ║  2. LEARNING RATE                                              ║
    ║     • Base LR scales with batch size: lr = base_lr * batch/256 ║
    ║     • Use cosine schedule with warmup                          ║
    ║     • LARS optimizer for very large batches                    ║
    ║                                                                ║
    ║  3. AUGMENTATION                                               ║
    ║     • Stronger augmentation = better representations           ║
    ║     • Color jittering is crucial                               ║
    ║     • Multi-crop (DINO style) improves efficiency              ║
    ║                                                                ║
    ║  4. ARCHITECTURE                                               ║
    ║     • Projection head: 2-3 layer MLP with BN                   ║
    ║     • Predictor (BYOL/SimSiam): Crucial for preventing collapse║
    ║     • Larger models generally learn better representations     ║
    ║                                                                ║
    ║  5. TRAINING DURATION                                          ║
    ║     • SSL needs longer training than supervised (800+ epochs)  ║
    ║     • Early stopping based on downstream performance           ║
    ║                                                                ║
    ║  6. EVALUATION                                                 ║
    ║     • Linear probe: Freeze encoder, train linear classifier    ║
    ║     • k-NN: Nearest neighbor classification (no training)      ║
    ║     • Fine-tuning: Update entire model on downstream task      ║
    ║                                                                ║
    ╚════════════════════════════════════════════════════════════════╝
    """
    print(tips)

ssl_training_tips()
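
For tip 2, here is a sketch of the linear LR scaling rule combined with warmup and cosine decay. The helper name, the 0.3 base LR, and the 10-epoch warmup are illustrative choices, not fixed prescriptions; for very large batches the SGD optimizer would typically be swapped for LARS:

def build_optimizer_and_scheduler(model, batch_size, epochs, steps_per_epoch,
                                  base_lr=0.3, warmup_epochs=10):
    # Linear scaling rule: lr = base_lr * batch_size / 256
    lr = base_lr * batch_size / 256
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-6)

    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = epochs * steps_per_epoch

    def lr_lambda(step):
        if step < warmup_steps:  # linear warmup from 0 to the peak LR
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1 + np.cos(np.pi * progress))  # cosine decay to 0

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler  # call scheduler.step() once per training step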

Exercises

1. Implement SwAV (Swapping Assignments between Views):

class SwAV(nn.Module):
    def forward(self, views):
        # 1. Compute features for all views
        # 2. Compute cluster assignments (Sinkhorn-Knopp)
        # 3. Swap: predict the assignment of view 1 from view 2 (and vice versa)
        ...

2. Implement DINO-style multi-crop augmentation:

def multi_crop(image, n_global=2, n_local=6):
    # Global crops: 224x224, high coverage of the image
    # Local crops: 96x96, low coverage
    # Train with all crops simultaneously
    ...

3. Train SimCLR, BYOL, and MAE on CIFAR-10 and compare:
  • Training time
  • Linear probe accuracy
  • Fine-tuning accuracy
  • Representation quality (t-SNE visualization of features)

What’s Next?