Self-Supervised Learning: Learning Without Labels

The Promise of Self-Supervision

Here is the fundamental asymmetry of modern AI: labeling a single ImageNet image takes about 60 seconds of human effort; generating an unlabeled image takes a camera shutter click. The internet produces billions of unlabeled images and text documents daily, while labeled datasets require expensive human annotation campaigns. Self-supervised learning bridges this gap by creating pretext tasks from unlabeled data — clever problems where the “answer” is already embedded in the data itself.
Supervised Learning:    Image -> Human Label -> Representation
Self-Supervised:        Image -> Generated Task -> Representation
Key Insight: Design tasks where the labels are FREE and force the model to learn useful representations. For example, take an image, create two randomly augmented versions, and train the model to recognize that both came from the same source. No human labels are needed, but the model must learn to understand objects, textures, and spatial structure to succeed. Self-supervised pretraining is now the first stage for many state-of-the-art vision models (DINOv2, MAE) and for large language models (GPT, LLaMA).
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import numpy as np
from typing import List, Tuple, Optional

torch.manual_seed(42)

Pretext Tasks

Classic Pretext Tasks

These were the first generation of self-supervised methods, and while they have been largely superseded by contrastive and masked approaches, understanding them builds intuition for the core principle: design a task where the labels come for free, and where solving the task forces useful feature learning.
Pretext tasks are not created equal: Rotation prediction forces the model to understand object orientation (useful), but colorization may cause the model to memorize statistical color priors rather than learning semantic features (less useful). The quality of the downstream representations depends entirely on how well the pretext task aligns with the features you actually need. Always evaluate with linear probing on your target task, not just pretext accuracy.
class PretextTasks:
    """Classic self-supervised pretext tasks."""
    
    @staticmethod
    def rotation_prediction(image: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """
        Rotate image and predict rotation angle.
        Forces model to understand object orientation.
        """
        # Draw k from torch's RNG so torch.manual_seed controls it; return a plain int
        rotation = int(torch.randint(0, 4, (1,)).item())  # 0°, 90°, 180°, 270°
        rotated = torch.rot90(image, rotation, dims=[-2, -1])
        return rotated, rotation
    
    @staticmethod
    def jigsaw_puzzle(image: torch.Tensor, grid_size: int = 3) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Shuffle image patches and predict correct order.
        Forces model to understand spatial relationships.
        """
        _, h, w = image.shape
        patch_h, patch_w = h // grid_size, w // grid_size
        
        # Extract patches
        patches = []
        for i in range(grid_size):
            for j in range(grid_size):
                patch = image[:, i*patch_h:(i+1)*patch_h, j*patch_w:(j+1)*patch_w]
                patches.append(patch)
        
        # Shuffle patches
        n_patches = grid_size * grid_size
        permutation = torch.randperm(n_patches)
        shuffled_patches = [patches[i] for i in permutation]
        
        # Reconstruct shuffled image
        rows = []
        for i in range(grid_size):
            row_patches = shuffled_patches[i*grid_size:(i+1)*grid_size]
            row = torch.cat(row_patches, dim=2)  # Concat horizontally
            rows.append(row)
        shuffled_image = torch.cat(rows, dim=1)  # Concat vertically
        
        return shuffled_image, permutation
    
    @staticmethod
    def colorization(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert to grayscale and predict colors.
        Forces model to understand semantic content.
        """
        # Assume image is RGB [3, H, W]
        grayscale = 0.299 * image[0] + 0.587 * image[1] + 0.114 * image[2]
        grayscale = grayscale.unsqueeze(0).repeat(3, 1, 1)
        
        return grayscale, image
    
    @staticmethod
    def inpainting(image: torch.Tensor, mask_ratio: float = 0.3) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Mask parts of image and predict masked content.
        Forces model to understand context.
        """
        _, h, w = image.shape
        
        # Create random mask on the image's device (True = visible pixel, False = masked out)
        mask = torch.rand(1, h, w, device=image.device) > mask_ratio
        
        # Apply mask
        masked_image = image * mask.float()
        
        return masked_image, image, mask


# Rotation prediction model
class RotationNet(nn.Module):
    """Predict rotation angle of image."""
    
    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(backbone.feature_dim, 4)  # 4 rotation angles
    
    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)


# Example backbone
class SimpleBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_dim = 512
        
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )
        
        self.fc = nn.Linear(256, self.feature_dim)
    
    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
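
To make the rotation task concrete, here is a minimal sketch of a single training step. The tiny encoder, the batch of random square images, and the learning rate are illustrative assumptions rather than settings from any paper; the point is only that the labels come from the rotation itself.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical tiny encoder standing in for a real backbone (assumption).
encoder = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
classifier = nn.Linear(16, 4)  # 4 classes: 0, 90, 180, 270 degrees
params = list(encoder.parameters()) + list(classifier.parameters())
optimizer = torch.optim.SGD(params, lr=0.1)


def make_rotation_batch(images: torch.Tensor):
    """Rotate each image by a random multiple of 90 degrees; k is the label."""
    ks = torch.randint(0, 4, (images.size(0),))
    rotated = torch.stack(
        [torch.rot90(img, int(k), dims=(-2, -1)) for img, k in zip(images, ks)]
    )
    return rotated, ks


images = torch.rand(8, 3, 32, 32)  # Unlabeled square images
rotated, labels = make_rotation_batch(images)

logits = classifier(encoder(rotated))  # [8, 4]
loss = F.cross_entropy(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"rotation loss: {loss.item():.4f}")
```

After pretraining this way, the classifier is discarded and only the encoder is kept for downstream use.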

Contrastive Learning

The Contrastive Learning Framework

Goal: Pull together representations of similar samples (positives), push apart representations of dissimilar samples (negatives). The analogy is organizing a photo album: pictures of the same person should be grouped together regardless of pose, lighting, or background, while pictures of different people should be clearly separated. The mathematical insight: if your model maps two augmentations of the same image to nearby points in embedding space, it must have learned features that are invariant to the augmentation (crops, rotations, color changes) but sensitive to the actual content (objects, structure, semantics). And those invariant, semantic features are exactly what you need for downstream tasks like classification.
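
The pull/push objective is usually written as the InfoNCE (NT-Xent) loss. As a sketch in the notation of the SimCLR paper (the symbols are not defined elsewhere in this page): a batch of N images yields 2N embeddings, (i, j) is a positive pair, and τ is the temperature.

```latex
\ell_{i,j} = -\log
  \frac{\exp\left(\operatorname{sim}(z_i, z_j)/\tau\right)}
       {\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\operatorname{sim}(z_i, z_k)/\tau\right)},
\qquad
\operatorname{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert}
```

The numerator rewards high similarity for the positive pair; the denominator penalizes similarity to every other embedding in the batch.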

SimCLR (Simple Contrastive Learning)

SimCLR (Chen et al., 2020) showed that contrastive learning could match supervised pretraining with a surprisingly simple recipe: strong augmentation + large batches + a projection head. The augmentation policy is critical — SimCLR found that the combination of random cropping and color distortion is especially powerful because it forces the model to match views that share neither spatial location nor color information, leaving only semantic content.
class SimCLRTransform:
    """Data augmentation for SimCLR.
    
    The augmentation pipeline is the MOST important hyperparameter in SimCLR.
    Each component serves a specific purpose:
    - RandomResizedCrop: forces spatial invariance (same object, different framing)
    - ColorJitter: forces color invariance (same object, different lighting)
    - GaussianBlur: removes fine detail so the model cannot lean on high-frequency texture
    - RandomGrayscale: prevents relying on color alone for object identity
    """
    
    def __init__(self, size: int = 224):
        self.transform = T.Compose([
            T.RandomResizedCrop(size, scale=(0.2, 1.0)),   # Aggressive crops
            T.RandomHorizontalFlip(),
            T.RandomApply([T.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),  # Strong color distortion
            T.RandomGrayscale(p=0.2),                       # Force shape-based features
            T.RandomApply([T.GaussianBlur(kernel_size=23)], p=0.5),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    def __call__(self, x):
        """Generate two augmented views of the same image."""
        return self.transform(x), self.transform(x)


class SimCLR(nn.Module):
    """SimCLR: A Simple Framework for Contrastive Learning."""
    
    def __init__(
        self,
        backbone: nn.Module,
        projection_dim: int = 128,
        hidden_dim: int = 2048,
        temperature: float = 0.5
    ):
        super().__init__()
        
        self.backbone = backbone
        self.temperature = temperature
        
        # Projection head (MLP)
        feature_dim = backbone.feature_dim
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim)
        )
    
    def forward(self, x1: torch.Tensor, x2: torch.Tensor):
        """
        Args:
            x1: First augmented view [batch_size, C, H, W]
            x2: Second augmented view [batch_size, C, H, W]
        
        Returns:
            loss: NT-Xent contrastive loss
        """
        batch_size = x1.size(0)
        
        # Extract features and project
        z1 = self.projector(self.backbone(x1))  # [batch, projection_dim]
        z2 = self.projector(self.backbone(x2))
        
        # Normalize projections
        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)
        
        # Compute similarity matrix
        representations = torch.cat([z1, z2], dim=0)  # [2*batch, projection_dim]
        similarity_matrix = torch.mm(representations, representations.t())
        
        # Create labels: positive pairs are (i, i+batch_size)
        labels = torch.arange(batch_size, device=x1.device)
        labels = torch.cat([labels + batch_size, labels], dim=0)
        
        # Mask out self-similarity
        mask = torch.eye(2 * batch_size, device=x1.device).bool()
        similarity_matrix.masked_fill_(mask, float('-inf'))
        
        # Scale by temperature
        similarity_matrix = similarity_matrix / self.temperature
        
        # NT-Xent loss (InfoNCE)
        loss = F.cross_entropy(similarity_matrix, labels)
        
        return loss
    
    def get_representations(self, x: torch.Tensor) -> torch.Tensor:
        """Get learned representations for downstream tasks."""
        return self.backbone(x)


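# Train SimCLR
def train_simclr(model, dataloader, optimizer, epochs=100):
    model.train()
    device = next(model.parameters()).device  # Follow the model instead of hard-coding .cuda()
    transform = SimCLRTransform()             # Build the augmentation pipeline once
    
    for epoch in range(epochs):
        total_loss = 0
        
        for images, _ in dataloader:  # We don't need labels!
            # Augment each image separately: calling the transform on the whole
            # batch would draw one set of random parameters for every image.
            # Assumes `images` is a float tensor [B, 3, H, W] with values in [0, 1].
            views = [transform(img) for img in images]
            x1 = torch.stack([v1 for v1, _ in views]).to(device)
            x2 = torch.stack([v2 for _, v2 in views]).to(device)
            
            optimizer.zero_grad()
            loss = model(x1, x2)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch+1}: Loss = {total_loss/len(dataloader):.4f}")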
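
A quick sanity check for the loss, independent of any encoder: embeddings whose two views agree should score a lower NT-Xent loss than unrelated embeddings. The standalone `nt_xent` function below is a hypothetical helper that mirrors the logic of `SimCLR.forward`; the batch size, dimension, and noise level are arbitrary assumptions.

```python
import torch
import torch.nn.functional as F


def nt_xent(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """Standalone NT-Xent loss over two batches of embeddings [B, D]."""
    batch_size = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # [2B, D]
    sim = z @ z.t() / temperature                        # [2B, 2B]
    # Exclude each embedding's similarity with itself
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # Row i's positive sits at i + B (first half) or i - B (second half)
    idx = torch.arange(batch_size)
    targets = torch.cat([idx + batch_size, idx])
    return F.cross_entropy(sim, targets)


torch.manual_seed(0)
z = torch.randn(16, 32)
loss_aligned = nt_xent(z, z + 0.01 * torch.randn_like(z))  # Views agree
loss_random = nt_xent(z, torch.randn(16, 32))               # Views unrelated
print(f"aligned: {loss_aligned.item():.3f}, random: {loss_random.item():.3f}")
```

If the aligned loss is not clearly lower, the positive-pair indexing or the self-similarity mask is the first place to look.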

MoCo (Momentum Contrast)

SimCLR has a practical limitation: it needs very large batch sizes (4096+) to have enough negative samples, which requires multiple GPUs with huge memory. MoCo solves this elegantly with two innovations: (1) a momentum-updated encoder that provides consistent key representations, and (2) a queue of past key embeddings that serves as a large, diverse negative pool without needing a massive batch. Think of it like this: SimCLR compares every image to every other image in the current batch. MoCo compares each image to a much larger “memory bank” of past representations, giving it many more negatives to contrast against without the GPU memory cost.
class MoCo(nn.Module):
    """
    Momentum Contrast (MoCo v2).
    Uses a momentum encoder and a large dictionary of negative samples.
    Key insight: decouple the negative pool size from the batch size.
    """
    
    def __init__(
        self,
        backbone: nn.Module,
        projection_dim: int = 128,
        queue_size: int = 65536,
        momentum: float = 0.999,
        temperature: float = 0.07
    ):
        super().__init__()
        
        self.queue_size = queue_size
        self.momentum = momentum
        self.temperature = temperature
        
        # Query encoder
        self.encoder_q = backbone
        self.projector_q = nn.Sequential(
            nn.Linear(backbone.feature_dim, 2048),
            nn.ReLU(),
            nn.Linear(2048, projection_dim)
        )
        
        # Key encoder (momentum updated)
        self.encoder_k = self._copy_encoder(backbone)
        self.projector_k = nn.Sequential(
            nn.Linear(backbone.feature_dim, 2048),
            nn.ReLU(),
            nn.Linear(2048, projection_dim)
        )
        self._copy_params(self.projector_q, self.projector_k)
        
        # Freeze key encoder
        for param in self.encoder_k.parameters():
            param.requires_grad = False
        for param in self.projector_k.parameters():
            param.requires_grad = False
        
        # Queue for negative samples
        self.register_buffer("queue", F.normalize(torch.randn(projection_dim, queue_size), dim=0))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
    
    def _copy_encoder(self, encoder):
        import copy
        return copy.deepcopy(encoder)
    
    def _copy_params(self, src, dst):
        for param_src, param_dst in zip(src.parameters(), dst.parameters()):
            param_dst.data.copy_(param_src.data)
    
    @torch.no_grad()
    def _momentum_update(self):
        """Update key encoder with momentum."""
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = self.momentum * param_k.data + (1 - self.momentum) * param_q.data
        for param_q, param_k in zip(self.projector_q.parameters(), self.projector_k.parameters()):
            param_k.data = self.momentum * param_k.data + (1 - self.momentum) * param_q.data
    
    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Update the queue with new key embeddings."""
        batch_size = keys.size(0)
        ptr = int(self.queue_ptr)
        
        # Replace oldest entries
        if ptr + batch_size > self.queue_size:
            self.queue[:, ptr:] = keys[:self.queue_size - ptr].T
            self.queue[:, :batch_size - (self.queue_size - ptr)] = keys[self.queue_size - ptr:].T
        else:
            self.queue[:, ptr:ptr + batch_size] = keys.T
        
        ptr = (ptr + batch_size) % self.queue_size
        self.queue_ptr[0] = ptr
    
    def forward(self, x_q, x_k):
        """
        Args:
            x_q: Query images
            x_k: Key images (different augmentation of same images)
        """
        batch_size = x_q.size(0)
        
        # Compute query features
        q = self.projector_q(self.encoder_q(x_q))
        q = F.normalize(q, dim=1)
        
        # Compute key features (no gradient)
        with torch.no_grad():
            self._momentum_update()
            k = self.projector_k(self.encoder_k(x_k))
            k = F.normalize(k, dim=1)
        
        # Positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        
        # Negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        
        # Logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1) / self.temperature
        
        # Labels: positive is at index 0
        labels = torch.zeros(batch_size, dtype=torch.long, device=x_q.device)
        
        # Update queue
        self._dequeue_and_enqueue(k)
        
        return F.cross_entropy(logits, labels)
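
To see how slowly the key encoder moves, here is a small sketch of the momentum rule theta_k <- m * theta_k + (1 - m) * theta_q on a toy layer. The `nn.Linear` stand-in and the artificial "+1.0" query step are assumptions chosen to make the arithmetic easy to follow.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)

query = nn.Linear(4, 4)
key = copy.deepcopy(query)  # Key encoder starts as an exact copy
for p in key.parameters():
    p.requires_grad = False


@torch.no_grad()
def momentum_update(query: nn.Module, key: nn.Module, m: float = 0.999) -> None:
    """theta_k <- m * theta_k + (1 - m) * theta_q, applied parameter-wise."""
    for p_q, p_k in zip(query.parameters(), key.parameters()):
        p_k.mul_(m).add_(p_q, alpha=1 - m)


# Simulate an optimizer step that moves every query weight by +1.0
with torch.no_grad():
    for p in query.parameters():
        p.add_(1.0)

momentum_update(query, key, m=0.999)
gap = (query.weight - key.weight).abs().max().item()
print(f"largest query-key gap after one update: {gap:.4f}")
```

With m = 0.999 the key weights close only 0.1% of the gap per step, which is what keeps the embeddings stored in the queue consistent with one another.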

Non-Contrastive Methods

BYOL (Bootstrap Your Own Latent)

BYOL dropped a bombshell in 2020: you do not need negative samples at all. This was shocking because the entire field assumed negatives were essential to prevent “collapse” (where the model maps everything to the same point). BYOL uses an asymmetric architecture with a predictor head and a momentum-updated target network, which creates enough asymmetry to prevent collapse without any negatives. The mechanism is still not fully understood theoretically, which makes it one of the more intriguing results in recent representation learning.
class BYOL(nn.Module):
    """
    Bootstrap Your Own Latent.
    No negative samples needed - learns by predicting target network.
    """
    
    def __init__(
        self,
        backbone: nn.Module,
        projection_dim: int = 256,
        hidden_dim: int = 4096,
        moving_average_decay: float = 0.99
    ):
        super().__init__()
        
        self.moving_average_decay = moving_average_decay
        
        # Online network
        self.online_encoder = backbone
        self.online_projector = self._make_projector(backbone.feature_dim, hidden_dim, projection_dim)
        self.predictor = self._make_projector(projection_dim, hidden_dim, projection_dim)
        
        # Target network (momentum updated)
        self.target_encoder = self._copy_encoder(backbone)
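        # Start the target projector as an exact copy of the online projector;
        # a freshly initialized one would give the EMA a mismatched starting point
        self.target_projector = self._copy_encoder(self.online_projector)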
        
        # Freeze target network
        for param in self.target_encoder.parameters():
            param.requires_grad = False
        for param in self.target_projector.parameters():
            param.requires_grad = False
    
    def _make_projector(self, in_dim, hidden_dim, out_dim):
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )
    
    def _copy_encoder(self, encoder):
        import copy
        return copy.deepcopy(encoder)
    
    @torch.no_grad()
    def update_target_network(self):
        """EMA update of target network."""
        tau = self.moving_average_decay
        
        for online, target in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            target.data = tau * target.data + (1 - tau) * online.data
        
        for online, target in zip(self.online_projector.parameters(), self.target_projector.parameters()):
            target.data = tau * target.data + (1 - tau) * online.data
    
    def forward(self, x1, x2):
        """
        Args:
            x1, x2: Two augmented views of the same images
        """
        # Online network forward
        online_feat1 = self.online_encoder(x1)
        online_proj1 = self.online_projector(online_feat1)
        online_pred1 = self.predictor(online_proj1)
        
        online_feat2 = self.online_encoder(x2)
        online_proj2 = self.online_projector(online_feat2)
        online_pred2 = self.predictor(online_proj2)
        
        # Target network forward (no gradient)
        with torch.no_grad():
            target_proj1 = self.target_projector(self.target_encoder(x1))
            target_proj2 = self.target_projector(self.target_encoder(x2))
            
            # Stop gradient
            target_proj1 = target_proj1.detach()
            target_proj2 = target_proj2.detach()
        
        # Compute loss: predict one view from the other
        loss1 = self._loss_fn(online_pred1, target_proj2)
        loss2 = self._loss_fn(online_pred2, target_proj1)
        
        return (loss1 + loss2) / 2
    
    def _loss_fn(self, x, y):
        """Normalized MSE loss."""
        x = F.normalize(x, dim=-1)
        y = F.normalize(y, dim=-1)
        return 2 - 2 * (x * y).sum(dim=-1).mean()
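
The "normalized MSE" name in `_loss_fn` follows from a simple identity: for unit vectors, the squared distance equals 2 - 2 * cosine similarity. A minimal numeric check, using random tensors as stand-ins for the predictions and targets:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pred = torch.randn(8, 256)    # Stand-in for online predictions
target = torch.randn(8, 256)  # Stand-in for target projections

p = F.normalize(pred, dim=-1)
t = F.normalize(target, dim=-1)

mse_form = ((p - t) ** 2).sum(dim=-1).mean()      # Squared distance, averaged over the batch
cosine_form = 2 - 2 * (p * t).sum(dim=-1).mean()  # Form used in _loss_fn

same = torch.allclose(mse_form, cosine_form, atol=1e-5)
print(f"forms agree: {same}")
```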

SimSiam (Simple Siamese)

SimSiam distills the non-contrastive approach to its absolute minimum: no negative samples, no momentum encoder, no large batches. The only thing preventing collapse is the stop-gradient operation on one branch and the asymmetric predictor head on the other. The paper by Xinlei Chen and Kaiming He (2021) is remarkable for showing that this minimal recipe works — and for providing an analysis suggesting that SimSiam implicitly performs an alternating optimization similar to Expectation-Maximization.
When to use SimSiam vs. BYOL vs. SimCLR: SimSiam is the easiest to implement and works with small batches (256), making it the best starting point for self-supervised experiments on a single GPU. BYOL adds the momentum encoder for slightly better performance on large-scale benchmarks. SimCLR requires multi-GPU setups with batch sizes of 4096+ to provide enough negatives. Start with SimSiam; scale up only if you hit a quality ceiling.
class SimSiam(nn.Module):
    """
    Simple Siamese networks.
    Even simpler than BYOL - no momentum encoder, just stop-gradient.
    """
    
    def __init__(
        self,
        backbone: nn.Module,
        projection_dim: int = 2048,
        prediction_dim: int = 512
    ):
        super().__init__()
        
        self.encoder = backbone
        
        # Projection MLP
        self.projector = nn.Sequential(
            nn.Linear(backbone.feature_dim, projection_dim),
            nn.BatchNorm1d(projection_dim),
            nn.ReLU(),
            nn.Linear(projection_dim, projection_dim),
            nn.BatchNorm1d(projection_dim),
            nn.ReLU(),
            nn.Linear(projection_dim, projection_dim),
            nn.BatchNorm1d(projection_dim)
        )
        
        # Prediction MLP
        self.predictor = nn.Sequential(
            nn.Linear(projection_dim, prediction_dim),
            nn.BatchNorm1d(prediction_dim),
            nn.ReLU(),
            nn.Linear(prediction_dim, projection_dim)
        )
    
    def forward(self, x1, x2):
        """Forward pass with two views."""
        
        # Compute projections
        z1 = self.projector(self.encoder(x1))
        z2 = self.projector(self.encoder(x2))
        
        # Compute predictions
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)
        
        # Compute loss with stop-gradient
        loss1 = self._negative_cosine_similarity(p1, z2.detach())
        loss2 = self._negative_cosine_similarity(p2, z1.detach())
        
        return (loss1 + loss2) / 2
    
    def _negative_cosine_similarity(self, p, z):
        """Negative cosine similarity."""
        p = F.normalize(p, dim=1)
        z = F.normalize(z, dim=1)
        return -(p * z).sum(dim=1).mean()
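
Because nothing in the loss itself forbids collapse, it helps to monitor it during training. One diagnostic, along the lines of the one discussed in the SimSiam paper, is the per-dimension standard deviation of the L2-normalized embeddings: it sits near 1/sqrt(d) for well-spread outputs and near zero when every input maps to the same vector. The helper name and the synthetic tensors below are assumptions for illustration.

```python
import math
import torch
import torch.nn.functional as F


def embedding_std(z: torch.Tensor) -> float:
    """Mean per-dimension std of L2-normalized embeddings [B, D]."""
    z = F.normalize(z, dim=1)
    return z.std(dim=0).mean().item()


torch.manual_seed(0)
dim = 128
healthy = torch.randn(512, dim)                                   # Spread-out embeddings
collapsed = torch.ones(512, dim) + 1e-4 * torch.randn(512, dim)   # Nearly constant

healthy_std = embedding_std(healthy)
collapsed_std = embedding_std(collapsed)
print(f"healthy:   {healthy_std:.4f} (1/sqrt(d) = {1 / math.sqrt(dim):.4f})")
print(f"collapsed: {collapsed_std:.6f}")
```

In a training loop this would be applied to the projector outputs (`z1` in `forward`) every few hundred steps.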

Masked Modeling

Masked Autoencoders (MAE)

Masked Autoencoders brought the “masked language modeling” paradigm from NLP (BERT) to vision, and the results were striking. The core idea: mask 75% of image patches (far more aggressive than BERT’s 15% token masking) and train the model to reconstruct the missing pixels. Why such aggressive masking? Because images have high spatial redundancy: neighboring patches are highly correlated, so the model can “cheat” by interpolating from nearby visible patches unless you force it to reason about large missing regions. The efficiency benefit is also notable: the encoder only processes the 25% of patches that remain visible, and the MAE paper reports a pretraining speedup of 3x or more from this design. Each step is also cheaper than in contrastive methods, which must run full images through the encoder for two augmented views.
The decoder is throwaway: After pretraining, only the encoder is used for downstream tasks. The decoder exists solely to provide a training signal. Do not waste time optimizing decoder architecture; a lightweight transformer decoder (much smaller than the encoder) works well. The original MAE (He et al.) defaults to an 8-block, 512-wide decoder, paired here with a 12-layer ViT-Base-sized encoder.
class MaskedAutoencoder(nn.Module):
    """
    Masked Autoencoder for Vision (MAE).
    Mask random patches and reconstruct them.
    """
    
    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        embed_dim: int = 768,
        encoder_depth: int = 12,
        encoder_heads: int = 12,
        decoder_embed_dim: int = 512,
        decoder_depth: int = 8,
        decoder_heads: int = 16,
        mask_ratio: float = 0.75
    ):
        super().__init__()
        
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        num_patches = (image_size // patch_size) ** 2
        
        # Patch embedding
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim) * 0.02)
        
        # Encoder (processes visible patches only)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=encoder_heads,
                dim_feedforward=embed_dim * 4,
                batch_first=True
            ),
            num_layers=encoder_depth
        )
        
        # Decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(torch.randn(1, num_patches, decoder_embed_dim) * 0.02)
        
        self.decoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=decoder_embed_dim,
                nhead=decoder_heads,
                dim_feedforward=decoder_embed_dim * 4,
                batch_first=True
            ),
            num_layers=decoder_depth
        )
        
        # Prediction head (reconstruct pixels)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * 3)
    
    def random_masking(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Randomly mask patches.
        
        Args:
            x: [batch, num_patches, embed_dim]
        
        Returns:
            x_masked: visible patches
            mask: binary mask (1 = masked)
            ids_restore: indices to restore original order
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - self.mask_ratio))
        
        # Random noise for shuffling
        noise = torch.rand(N, L, device=x.device)
        
        # Sort noise to get shuffle indices
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        
        # Keep first len_keep patches (after shuffling)
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D))
        
        # Generate binary mask: 0 = keep, 1 = masked
        mask = torch.ones(N, L, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)
        
        return x_masked, mask, ids_restore
    
    def forward(self, images: torch.Tensor):
        """
        Args:
            images: [batch, 3, H, W]
        
        Returns:
            loss: Reconstruction loss on masked patches
            pred: Predicted pixel values
            mask: Binary mask
        """
        # Patchify
        patches = self.patch_embed(images)  # [B, embed_dim, H/P, W/P]
        patches = patches.flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]
        
        # Add positional embedding
        patches = patches + self.pos_embed
        
        # Random masking
        patches_masked, mask, ids_restore = self.random_masking(patches)
        
        # Encode visible patches
        latent = self.encoder(patches_masked)
        
        # Project to decoder dimension
        latent = self.decoder_embed(latent)
        
        # Append mask tokens
        N, _, D = latent.shape
        num_patches = mask.shape[1]
        mask_tokens = self.mask_token.expand(N, num_patches - latent.shape[1], -1)
        
        # Unshuffle: put mask tokens in correct positions
        x_ = torch.cat([latent, mask_tokens], dim=1)
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, D))
        
        # Add decoder positional embedding
        x_ = x_ + self.decoder_pos_embed
        
        # Decode
        x_ = self.decoder(x_)
        
        # Predict pixels
        pred = self.decoder_pred(x_)  # [B, num_patches, patch_size^2 * 3]
        
        # Compute loss only on masked patches
        target = self._patchify(images)
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # Mean over patch pixels
        loss = (loss * mask).sum() / mask.sum()  # Mean over masked patches
        
        return loss, pred, mask
    
    def _patchify(self, images: torch.Tensor) -> torch.Tensor:
        """Convert images to patches."""
        p = self.patch_size
        B, C, H, W = images.shape
        h, w = H // p, W // p
        
        x = images.reshape(B, C, h, p, w, p)
        x = x.permute(0, 2, 4, 3, 5, 1)  # [B, h, w, p, p, C]
        x = x.reshape(B, h * w, p * p * C)
        
        return x


# Test MAE
mae = MaskedAutoencoder(image_size=224, patch_size=16, embed_dim=768)
images = torch.randn(4, 3, 224, 224)
loss, pred, mask = mae(images)
print(f"MAE Loss: {loss.item():.4f}")
print(f"Mask ratio: {mask.float().mean().item():.2f}")
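
The token counts behind the efficiency argument are easy to work out by hand. The sketch below uses the same defaults as `MaskedAutoencoder`; the "attention pairs" figure is a rough proxy for self-attention cost, which is an assumption that ignores the MLP layers and the decoder.

```python
image_size, patch_size, mask_ratio = 224, 16, 0.75

num_patches = (image_size // patch_size) ** 2      # 14 * 14 patches
num_visible = int(num_patches * (1 - mask_ratio))  # Tokens the encoder sees
num_masked = num_patches - num_visible

# Self-attention cost grows with the square of the sequence length
attention_pairs_full = num_patches ** 2
attention_pairs_visible = num_visible ** 2
attention_ratio = attention_pairs_full / attention_pairs_visible

print(f"patches: {num_patches}, visible: {num_visible}, masked: {num_masked}")
print(f"attention pairs: {attention_pairs_full} -> {attention_pairs_visible} "
      f"({attention_ratio:.0f}x fewer)")
```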

BERT-style Masked Language Modeling

class MaskedLanguageModel(nn.Module):
    """BERT-style masked language modeling for text."""
    
    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        max_seq_len: int = 512,
        mask_token_id: int = 103,  # [MASK] token
        mask_ratio: float = 0.15
    ):
        super().__init__()
        
        self.mask_token_id = mask_token_id
        self.mask_ratio = mask_ratio
        
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Embedding(max_seq_len, embed_dim)
        
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=embed_dim * 4,
                batch_first=True
            ),
            num_layers=num_layers
        )
        
        self.output = nn.Linear(embed_dim, vocab_size)
    
    def mask_tokens(self, input_ids: torch.Tensor):
        """Apply BERT-style masking (modifies input_ids in place; pass a clone)."""
        labels = input_ids.clone()
        device = input_ids.device  # Create all sampling tensors on the input's device
        
        # Probability matrix for masking
        probability_matrix = torch.full(input_ids.shape, self.mask_ratio, device=device)
        
        # Create masked indices
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # Only compute loss on masked tokens
        
        # 80% -> [MASK], 10% -> random, 10% -> original
        indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8, device=device)).bool() & masked_indices
        input_ids[indices_replaced] = self.mask_token_id
        
        # Half of the remaining 20% (10% overall) get a random token
        indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5, device=device)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(self.embedding.num_embeddings, input_ids.shape, device=device)
        input_ids[indices_random] = random_words[indices_random]
        
        return input_ids, labels
    
    def forward(self, input_ids: torch.Tensor):
        # Apply masking
        masked_ids, labels = self.mask_tokens(input_ids.clone())
        
        # Get embeddings
        seq_len = masked_ids.size(1)
        positions = torch.arange(seq_len, device=masked_ids.device).unsqueeze(0)
        
        x = self.embedding(masked_ids) + self.pos_embedding(positions)
        
        # Transformer forward
        x = self.transformer(x)
        
        # Predict masked tokens
        logits = self.output(x)
        
        # Compute loss
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            ignore_index=-100
        )
        
        return loss, logits
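
To sanity-check the 80/10/10 split described in `mask_tokens`, the masking logic can be pulled out into a standalone function and its proportions measured empirically. This is only a sketch: the name `bert_mask`, the `mask_token_id=103` default, and `vocab_size=1000` are illustrative assumptions, and it uses a single uniform draw per token rather than the two Bernoulli draws above (the resulting split is the same).

```python
import torch

def bert_mask(input_ids, mask_ratio=0.15, mask_token_id=103, vocab_size=1000):
    """Standalone BERT-style masking (hypothetical helper for illustration).

    Returns (masked_ids, labels, selected), where `selected` marks the
    positions chosen for prediction.
    """
    ids = input_ids.clone()
    labels = input_ids.clone()

    # Choose roughly mask_ratio of the positions as prediction targets
    selected = torch.rand(ids.shape, device=ids.device) < mask_ratio
    labels[~selected] = -100  # ignored by cross_entropy(ignore_index=-100)

    # One uniform draw decides the fate of each selected token
    roll = torch.rand(ids.shape, device=ids.device)
    replace = selected & (roll < 0.8)                    # 80% -> [MASK]
    randomize = selected & (roll >= 0.8) & (roll < 0.9)  # 10% -> random token
    # the remaining 10% of selected tokens are left unchanged

    ids[replace] = mask_token_id
    random_tokens = torch.randint(vocab_size, ids.shape, device=ids.device)
    ids[randomize] = random_tokens[randomize]
    return ids, labels, selected

torch.manual_seed(0)
# Token ids start at 200 so that id 103 can only come from masking
tokens = torch.randint(200, 1000, (64, 512))
masked, labels, selected = bert_mask(tokens)

frac_selected = selected.float().mean().item()
frac_mask = (masked[selected] == 103).float().mean().item()
print(f"selected: {frac_selected:.3f}, of which [MASK]: {frac_mask:.3f}")
```

The printed fractions should land near 0.15 and 0.80 respectively; positions outside `selected` are never modified and never contribute to the loss.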

Evaluation & Downstream Tasks

Linear Probing

Linear probing is the standard litmus test for representation quality: freeze the pretrained encoder entirely and train only a single linear layer on top for a downstream classification task. The idea is that if a linear classifier can achieve high accuracy, the representations must already be linearly separable — meaning the encoder has organized its feature space in a semantically meaningful way without any task-specific supervision.
Linear probe vs. fine-tuning: A high linear probe accuracy means the pretrained features are excellent “out of the box.” A large gap between linear probe and fine-tuning accuracy means the features need adaptation — they captured useful information but organized it in a way that is not directly aligned with your task. In practice, always report both numbers, as they measure different things: feature quality (linear probe) vs. feature adaptability (fine-tuning).
class LinearProbe(nn.Module):
    """Linear evaluation of learned representations."""
    
    def __init__(self, encoder: nn.Module, num_classes: int):
        super().__init__()
        
        self.encoder = encoder
        
        # Freeze encoder
        for param in self.encoder.parameters():
            param.requires_grad = False
        
        self.classifier = nn.Linear(encoder.feature_dim, num_classes)
    
    def forward(self, x):
        with torch.no_grad():
            features = self.encoder(x)
        return self.classifier(features)


def linear_evaluation(pretrained_encoder, train_loader, val_loader, num_classes, epochs=100):
    """Evaluate representations with linear probing."""
    
    probe = LinearProbe(pretrained_encoder, num_classes).cuda()
    optimizer = torch.optim.Adam(probe.classifier.parameters(), lr=0.001)
    
    best_acc = 0
    
    for epoch in range(epochs):
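        # Train (keep the frozen encoder in eval mode so BatchNorm statistics
        # and Dropout do not change during probing)
        probe.train()
        probe.encoder.eval()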
        for images, labels in train_loader:
            images, labels = images.cuda(), labels.cuda()
            
            optimizer.zero_grad()
            outputs = probe(images)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
        
        # Evaluate
        probe.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.cuda(), labels.cuda()
                outputs = probe(images)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        
        acc = 100 * correct / total
        if acc > best_acc:
            best_acc = acc
        
        print(f"Epoch {epoch+1}: Accuracy = {acc:.2f}%")
    
    return best_acc

Fine-tuning

def finetune_evaluation(pretrained_encoder, train_loader, val_loader, num_classes, epochs=100):
    """Fine-tune entire model on downstream task."""
    
    model = nn.Sequential(
        pretrained_encoder,
        nn.Linear(pretrained_encoder.feature_dim, num_classes)
    ).cuda()
    
    # Different learning rates for encoder and classifier
    optimizer = torch.optim.AdamW([
        {'params': pretrained_encoder.parameters(), 'lr': 1e-4},
        {'params': model[-1].parameters(), 'lr': 1e-3}
    ])
    
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    
    for epoch in range(epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.cuda(), labels.cuda()
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
        
        scheduler.step()
    
    # Final evaluation
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.cuda(), labels.cuda()
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    return 100 * correct / total

Practical Considerations

Data Augmentation for SSL

class SSLAugmentations:
    """Strong augmentations for self-supervised learning."""
    
    @staticmethod
    def simclr_augmentation(size=224):
        """SimCLR augmentation strategy."""
        return T.Compose([
            T.RandomResizedCrop(size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply([
                T.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2)
            ], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.RandomApply([T.GaussianBlur(kernel_size=23)], p=0.5),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    @staticmethod
    def byol_augmentation(size=224):
        """BYOL augmentation (asymmetric)."""
        base = [
            T.RandomResizedCrop(size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply([
                T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)
            ], p=0.8),
            T.RandomGrayscale(p=0.2),
        ]
        
        view1 = T.Compose(base + [
            T.RandomApply([T.GaussianBlur(kernel_size=23)], p=1.0),  # Always blur
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        view2 = T.Compose(base + [
            T.RandomApply([T.GaussianBlur(kernel_size=23)], p=0.1),  # Rarely blur
            T.RandomSolarize(threshold=128, p=0.2),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        return view1, view2
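
The pipelines above each produce a single view per call, while contrastive and BYOL-style training needs two views of the same image per sample. One way to get that is a small wrapper that applies two transforms to the same input. This is a sketch: `TwoViewTransform` is a hypothetical helper name, and the demo uses a simple noise function as a stand-in for the real augmentations so it runs without image files.

```python
import torch

class TwoViewTransform:
    """Apply two (possibly different) stochastic transforms to one input.

    Hypothetical helper: if only one transform is given, it is applied twice,
    which yields two different views as long as the transform is random.
    """

    def __init__(self, transform1, transform2=None):
        self.transform1 = transform1
        self.transform2 = transform2 if transform2 is not None else transform1

    def __call__(self, x):
        return self.transform1(x), self.transform2(x)


def add_noise(x):
    """Stand-in for a real augmentation pipeline (demo only)."""
    return x + 0.1 * torch.randn_like(x)


torch.manual_seed(0)
two_view = TwoViewTransform(add_noise)
image = torch.zeros(3, 8, 8)
v1, v2 = two_view(image)
print(v1.shape, v2.shape, torch.equal(v1, v2))
```

With the class above, the symmetric case would be `TwoViewTransform(SSLAugmentations.simclr_augmentation())` and the asymmetric case `TwoViewTransform(*SSLAugmentations.byol_augmentation())`, passed as the `transform` argument of an image dataset.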

Training Tips

def ssl_training_tips():
    """Key tips for successful self-supervised training."""
    
    tips = """
    ╔════════════════════════════════════════════════════════════════╗
    ║           SELF-SUPERVISED LEARNING: BEST PRACTICES            ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                ║
    ║  1. BATCH SIZE                                                 ║
    ║     • SimCLR: large batches (4096+); MoCo: negative queue      ║
    ║     • Non-contrastive (BYOL, SimSiam): Smaller batches OK      ║
    ║     • Use gradient accumulation if GPU memory limited          ║
    ║                                                                ║
    ║  2. LEARNING RATE                                              ║
    ║     • Base LR scales with batch size: lr = base_lr * batch/256 ║
    ║     • Use cosine schedule with warmup                          ║
    ║     • LARS optimizer for very large batches                    ║
    ║                                                                ║
    ║  3. AUGMENTATION                                               ║
    ║     • Stronger augmentation usually helps, up to a point       ║
    ║     • Color jittering is crucial                               ║
    ║     • Multi-crop (DINO style) improves efficiency              ║
    ║                                                                ║
    ║  4. ARCHITECTURE                                               ║
    ║     • Projection head: 2-3 layer MLP with BN                   ║
    ║     • Predictor (BYOL/SimSiam): Crucial for preventing collapse║
    ║     • Larger models generally learn better representations     ║
    ║                                                                ║
    ║  5. TRAINING DURATION                                          ║
    ║     • SSL needs longer training than supervised (800+ epochs)  ║
    ║     • Early stopping based on downstream performance           ║
    ║                                                                ║
    ║  6. EVALUATION                                                 ║
    ║     • Linear probe: Freeze encoder, train linear classifier    ║
    ║     • k-NN: Nearest neighbor classification (no training)      ║
    ║     • Fine-tuning: Update entire model on downstream task      ║
    ║                                                                ║
    ╚════════════════════════════════════════════════════════════════╝
    """
    print(tips)

ssl_training_tips()
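
The learning-rate advice in tip 2 (linear scaling with batch size, plus a cosine schedule with warmup) can be written out in a few lines. This is a minimal sketch: the function names, the linear warmup shape, and the example values (`base_lr=0.3`, 100 warmup steps out of 1000) are assumptions chosen for illustration, not settings taken from this page.

```python
import math

def scaled_lr(base_lr: float, batch_size: int) -> float:
    """Linear scaling rule: lr = base_lr * batch_size / 256."""
    return base_lr * batch_size / 256


def warmup_cosine_lr(step: int, total_steps: int, warmup_steps: int,
                     peak_lr: float, min_lr: float = 0.0) -> float:
    """Linear warmup to peak_lr, then cosine decay towards min_lr."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))


peak = scaled_lr(base_lr=0.3, batch_size=4096)  # 0.3 * 4096 / 256
schedule = [
    warmup_cosine_lr(s, total_steps=1000, warmup_steps=100, peak_lr=peak)
    for s in range(1000)
]
print(f"peak lr: {peak:.2f}, first: {schedule[0]:.4f}, last: {schedule[-1]:.6f}")
```

In a training loop, the returned value would typically be written into each `param_group["lr"]` of the optimizer before every step.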

Exercises

Implement SwAV (Swapping Assignments between Views):
class SwAV(nn.Module):
    def forward(self, views):
        # 1. Compute features for all views
        # 2. Compute cluster assignments (Sinkhorn-Knopp)
        # 3. Swap: predict assignment of view1 from view2
        raise NotImplementedError("Exercise: implement the SwAV forward pass")
Implement DINO-style multi-crop augmentation:
def multi_crop(image, n_global=2, n_local=6):
    # Global crops: 224x224, high coverage
    # Local crops: 96x96, low coverage
    # Train with all crops simultaneously
    raise NotImplementedError("Exercise: return the global and local crops")
Train SimCLR, BYOL, and MAE on CIFAR-10 and compare:
  • Training time
  • Linear probe accuracy
  • Fine-tuning accuracy
  • Representation quality (t-SNE)

What’s Next?

Reinforcement Learning for DL

RLHF, PPO, and preference optimization

Neural Architecture Search

Automated architecture discovery

Interview Deep-Dive

Strong Answer:
SimCLR learns representations by pulling together two augmented views of the same image while pushing apart views of different images. It requires large batch sizes (4096-8192) for sufficient negatives, and it captures primarily high-level semantic features (object identity and category) but can miss low-level visual details like texture and spatial structure.
MAE masks 75% of image patches and trains the model to reconstruct the missing pixels. It needs no negatives or large batches. MAE representations are richer in spatial and structural information because pixel reconstruction forces understanding of geometry, texture, and local context.
Choose SimCLR when your downstream task is classification or retrieval where discriminative semantic embeddings matter. SimCLR representations transfer better to linear probing because they are more discriminative by construction.
Choose MAE when your downstream task requires dense prediction (segmentation, detection, depth estimation) or when pretraining a ViT from scratch on smaller datasets. MAE’s reconstruction objective provides a very strong learning signal that overcomes ViT’s lack of inductive bias, while SimCLR on small datasets produces weaker representations due to insufficient negative diversity.
The current trend: DINOv2 combines self-distillation and masked image modeling, getting the best of both approaches.
Follow-up: Why does SimCLR need such large batch sizes, and how does BYOL avoid this requirement?
SimCLR’s contrastive loss needs negatives to define what to push apart. With small batches, most negatives are “easy” (very different images). Large batches increase the probability of “hard” negatives that force finer-grained representations.
BYOL sidesteps negatives entirely using two networks: an online network and a target network (an exponential moving average of the online network). The online network predicts the target’s representation of a different view. The asymmetry (a predictor MLP exists only in the online network, and the target updates slowly via EMA) prevents collapse without needing negatives. BYOL works well with batch sizes as small as 256, making it accessible without massive GPU clusters, while remaining competitive with SimCLR on ImageNet linear probing.
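
The target-network update that the BYOL answer relies on is a plain exponential moving average over parameters. The following is a minimal sketch of that one step, not a full BYOL implementation: `ema_update` and `tau=0.99` are illustrative choices, the `nn.Linear` layers stand in for real encoders, and buffers such as BatchNorm running statistics are not handled.

```python
import copy
import torch
import torch.nn as nn

@torch.no_grad()
def ema_update(online: nn.Module, target: nn.Module, tau: float = 0.99) -> None:
    """target <- tau * target + (1 - tau) * online, parameter by parameter."""
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(tau).add_(p_online, alpha=1 - tau)


torch.manual_seed(0)
online = nn.Linear(4, 2)           # stand-in for the online encoder
target = copy.deepcopy(online)     # target starts as a copy of the online network
for p in target.parameters():
    p.requires_grad = False        # the target is never updated by gradients

# Simulate an optimizer step that moved every online weight by +1
with torch.no_grad():
    online.weight.add_(1.0)

before = target.weight.clone()
ema_update(online, target, tau=0.99)
print((target.weight - before).mean().item())  # target moved 1% of the way
```

Because the target only drifts a small fraction of the way towards the online weights on each step, it provides a slowly changing regression target, which is the asymmetry described above.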
Strong Answer:
Every major LLM (GPT, LLaMA, Mistral) uses causal language modeling: predicting the next token given all previous tokens. The “labels” are just the next word in existing text, requiring no human annotation.
Why it works so well: next-token prediction is an extraordinarily rich objective that implicitly requires learning syntax, semantics, world knowledge, reasoning, and theory of mind. To predict the next word after “The capital of France is,” the model must encode geographic knowledge. To complete code, it must understand programming logic. Many downstream tasks can be framed as special cases of next-token prediction.
The information-theoretic argument: natural language has about 1-2 bits of entropy per character. A model achieving near-human perplexity has necessarily compressed vast world knowledge into its parameters, because that knowledge is required for accurate prediction. This is not just pattern matching; it is implicit multi-task learning.
The practical insight: data quality matters enormously. A model trained on high-quality data with fewer tokens typically outperforms one trained on low-quality data with more tokens. Teams like Meta invest months in data curation, deduplication, and quality filtering before training begins.
Follow-up: BERT uses masked language modeling instead of causal language modeling. Why did GPT-style causal models win for generative tasks?
Causal models are autoregressive: they generate tokens left-to-right, matching how text generation works in practice. BERT is bidirectional (it sees context on both sides of masked tokens), giving stronger understanding but making it unsuitable for generation, since you cannot condition on future tokens that do not exist yet.
The deeper reason causal models won: generation is more commercially valuable and more general than understanding. A model that generates coherent text implicitly understands context, but an understanding model cannot easily generate. This asymmetry drove industry convergence on autoregressive architectures despite BERT’s early dominance on language-understanding benchmarks.
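
The claim that the labels are "just the next word" amounts to shifting the input sequence by one position. A minimal sketch of that loss follows; `next_token_loss` is a hypothetical helper, and the logits here are hand-made tensors rather than the output of a real model.

```python
import torch
import torch.nn.functional as F

def next_token_loss(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Causal LM loss. logits: (batch, seq, vocab), input_ids: (batch, seq).

    The prediction at position t is scored against the token at position t + 1,
    so the last position has no label and the first token is never a target.
    """
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )


torch.manual_seed(0)
vocab_size = 50
input_ids = torch.randint(vocab_size, (2, 10))

# A model that knows nothing assigns equal probability to every token,
# so its loss is log(vocab_size)
uniform_logits = torch.zeros(2, 10, vocab_size)
loss = next_token_loss(uniform_logits, input_ids)
print(loss.item())
```

In a real model, a causal attention mask is what guarantees that the logits at position t depend only on tokens up to t; this sketch covers only the label construction.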
Strong Answer:
Phase one: self-supervised pretraining on the 10M unlabeled images using MAE. MAE works well without huge batch sizes, and for domain-specific data, the spatial and structural features from reconstruction are more useful than SimCLR’s purely semantic features. Pretrain a ViT-Base for 200-400 epochs; MAE’s 75% masking makes each epoch cheap.
Phase two: supervised fine-tuning on the 5K labeled images. Use aggressive augmentation (RandAugment, CutMix, MixUp), low learning rate (1e-4 with cosine decay), label smoothing (0.1), and weight decay (0.05) to prevent overfitting. Fine-tune the full model, not just a linear probe, since 5K images is sufficient for end-to-end tuning.
Phase three: semi-supervised refinement. Use the fine-tuned model to pseudo-label the most confident 50% of the 10M unlabeled images. Retrain on real labels plus pseudo-labels for 2-5% additional accuracy.
If the domain is very different from ImageNet (microscopy, satellite, sonar), domain-specific pretraining significantly outperforms ImageNet pretraining. If similar to natural images, the gap narrows but domain pretraining still wins.
Follow-up: How do you validate that self-supervised pretraining learned useful representations before investing in expensive fine-tuning?
Three quick checks under one hour total. First, linear probing: freeze the backbone, attach a linear classifier, and train on 500 of the 5K labels; accuracy significantly above random indicates useful features. Second, nearest-neighbor retrieval: embed all 5K labeled images and check if nearest neighbors share labels; high recall at k=5 indicates good semantic organization. Third, t-SNE visualization of labeled set embeddings: clusters corresponding to meaningful categories confirm structural learning. These checks provide high confidence before committing to full fine-tuning.
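
The nearest-neighbor check from the follow-up can be sketched in a few lines once features have been extracted. This version is an assumption-laden illustration: `knn_accuracy` is a hypothetical helper that uses cosine similarity and a majority vote over the k neighbors (a close cousin of the recall-at-k check described above), and the features are synthetic clusters rather than real encoder outputs.

```python
import torch
import torch.nn.functional as F

def knn_accuracy(train_feats, train_labels, test_feats, test_labels, k=5):
    """Majority-vote k-NN accuracy on frozen features (cosine similarity)."""
    train_feats = F.normalize(train_feats, dim=1)
    test_feats = F.normalize(test_feats, dim=1)
    sims = test_feats @ train_feats.T               # (n_test, n_train)
    _, idx = sims.topk(k, dim=1)                    # indices of the k nearest
    neighbor_labels = train_labels[idx]             # (n_test, k)
    preds = torch.mode(neighbor_labels, dim=1).values
    return (preds == test_labels).float().mean().item()


# Synthetic stand-in for encoder features: three well-separated clusters
torch.manual_seed(0)
centers = 5 * torch.eye(3, 16)
train_labels = torch.arange(3).repeat_interleave(50)
test_labels = torch.arange(3).repeat_interleave(20)
train_feats = centers[train_labels] + 0.5 * torch.randn(150, 16)
test_feats = centers[test_labels] + 0.5 * torch.randn(60, 16)

acc = knn_accuracy(train_feats, train_labels, test_feats, test_labels, k=5)
print(f"k-NN accuracy: {acc:.3f}")
```

With real data, `train_feats` and `test_feats` would come from running the frozen backbone over a split of the labeled images; no training is involved, so the check is cheap to repeat across checkpoints.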