Self-Supervised Learning: Learning Without Labels
The Promise of Self-Supervision
Labeled data is expensive. Unlabeled data is abundant. Self-supervised learning bridges this gap by creating pretext tasks from unlabeled data.
Supervised Learning: Image → Human Label → Representation
Self-Supervised: Image → Generated Task → Representation
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import numpy as np
from typing import List, Tuple, Optional
torch.manual_seed(42)
Pretext Tasks
Classic Pretext Tasks
class PretextTasks:
"""Classic self-supervised pretext tasks."""
@staticmethod
def rotation_prediction(image: torch.Tensor) -> Tuple[torch.Tensor, int]:
"""
Rotate image and predict rotation angle.
Forces model to understand object orientation.
"""
rotation = np.random.choice([0, 1, 2, 3]) # 0°, 90°, 180°, 270°
rotated = torch.rot90(image, rotation, dims=[-2, -1])
return rotated, rotation
@staticmethod
def jigsaw_puzzle(image: torch.Tensor, grid_size: int = 3) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Shuffle image patches and predict correct order.
Forces model to understand spatial relationships.
"""
_, h, w = image.shape
patch_h, patch_w = h // grid_size, w // grid_size
# Extract patches
patches = []
for i in range(grid_size):
for j in range(grid_size):
patch = image[:, i*patch_h:(i+1)*patch_h, j*patch_w:(j+1)*patch_w]
patches.append(patch)
# Shuffle patches
n_patches = grid_size * grid_size
permutation = torch.randperm(n_patches)
shuffled_patches = [patches[i] for i in permutation]
# Reconstruct shuffled image
rows = []
for i in range(grid_size):
row_patches = shuffled_patches[i*grid_size:(i+1)*grid_size]
row = torch.cat(row_patches, dim=2) # Concat horizontally
rows.append(row)
shuffled_image = torch.cat(rows, dim=1) # Concat vertically
return shuffled_image, permutation
@staticmethod
def colorization(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Convert to grayscale and predict colors.
Forces model to understand semantic content.
"""
# Assume image is RGB [3, H, W]
grayscale = 0.299 * image[0] + 0.587 * image[1] + 0.114 * image[2]
grayscale = grayscale.unsqueeze(0).repeat(3, 1, 1)
return grayscale, image
@staticmethod
def inpainting(image: torch.Tensor, mask_ratio: float = 0.3) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Mask parts of image and predict masked content.
Forces model to understand context.
"""
_, h, w = image.shape
# Create random mask
mask = torch.rand(1, h, w) > mask_ratio
# Apply mask
masked_image = image * mask.float()
return masked_image, image, mask
# Rotation prediction model
class RotationNet(nn.Module):
"""Predict rotation angle of image."""
def __init__(self, backbone: nn.Module):
super().__init__()
self.backbone = backbone
self.classifier = nn.Linear(backbone.feature_dim, 4) # 4 rotation angles
def forward(self, x):
features = self.backbone(x)
return self.classifier(features)
# Example backbone
class SimpleBackbone(nn.Module):
def __init__(self):
super().__init__()
self.feature_dim = 512
self.conv = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(128, 256, 3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1)
)
self.fc = nn.Linear(256, self.feature_dim)
def forward(self, x):
x = self.conv(x)
x = x.view(x.size(0), -1)
return self.fc(x)
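To make the pretext idea concrete, here is a minimal training sketch that wires PretextTasks.rotation_prediction into RotationNet. The loader name unlabeled_loader is hypothetical; it is assumed to yield batches of square image tensors, and any labels are ignored.
# Minimal rotation-prediction training sketch.
# Assumes `unlabeled_loader` yields (images, _) with images of shape [B, 3, H, H]
# (square, so all four rotations keep the same shape); labels are ignored.
def train_rotation(model: RotationNet, unlabeled_loader, optimizer, device="cpu"):
    model.train()
    for images, _ in unlabeled_loader:
        rotated, targets = [], []
        for img in images:
            rot_img, rot_label = PretextTasks.rotation_prediction(img)
            rotated.append(rot_img)
            targets.append(int(rot_label))
        x = torch.stack(rotated).to(device)       # [B, 3, H, H]
        y = torch.tensor(targets, device=device)  # rotation class 0-3
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x), y)       # 4-way rotation classification
        loss.backward()
        optimizer.step()

# Example wiring:
# model = RotationNet(SimpleBackbone())
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# train_rotation(model, unlabeled_loader, optimizer)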
Contrastive Learning
The Contrastive Learning Framework
Goal: Pull together representations of similar samples (positives), push apart representations of dissimilar samples (negatives).
SimCLR (A Simple Framework for Contrastive Learning)
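In symbols, the NT-Xent (InfoNCE) loss that the code below implements is, for a positive pair (i, j) coming from two views of the same image:

    loss_i = -log[ exp(sim(z_i, z_j) / τ) / Σ_{k ≠ i} exp(sim(z_i, z_k) / τ) ]

where z are the normalized projections, sim(·,·) is cosine similarity, τ is the temperature, and the sum runs over the other 2N − 1 projections in a batch of N images (2N views). The batch loss averages loss_i over all 2N views.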
class SimCLRTransform:
"""Data augmentation for SimCLR."""
def __init__(self, size: int = 224):
self.transform = T.Compose([
T.RandomResizedCrop(size, scale=(0.2, 1.0)),
T.RandomHorizontalFlip(),
T.RandomApply([T.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
T.RandomGrayscale(p=0.2),
T.RandomApply([T.GaussianBlur(kernel_size=23)], p=0.5),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def __call__(self, x):
"""Generate two augmented views of the same image."""
return self.transform(x), self.transform(x)
class SimCLR(nn.Module):
"""SimCLR: A Simple Framework for Contrastive Learning."""
def __init__(
self,
backbone: nn.Module,
projection_dim: int = 128,
hidden_dim: int = 2048,
temperature: float = 0.5
):
super().__init__()
self.backbone = backbone
self.temperature = temperature
# Projection head (MLP)
feature_dim = backbone.feature_dim
self.projector = nn.Sequential(
nn.Linear(feature_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, projection_dim)
)
def forward(self, x1: torch.Tensor, x2: torch.Tensor):
"""
Args:
x1: First augmented view [batch_size, C, H, W]
x2: Second augmented view [batch_size, C, H, W]
Returns:
loss: NT-Xent contrastive loss
"""
batch_size = x1.size(0)
# Extract features and project
z1 = self.projector(self.backbone(x1)) # [batch, projection_dim]
z2 = self.projector(self.backbone(x2))
# Normalize projections
z1 = F.normalize(z1, dim=1)
z2 = F.normalize(z2, dim=1)
# Compute similarity matrix
representations = torch.cat([z1, z2], dim=0) # [2*batch, projection_dim]
similarity_matrix = torch.mm(representations, representations.t())
# Create labels: positive pairs are (i, i+batch_size)
labels = torch.arange(batch_size, device=x1.device)
labels = torch.cat([labels + batch_size, labels], dim=0)
# Mask out self-similarity
mask = torch.eye(2 * batch_size, device=x1.device).bool()
similarity_matrix.masked_fill_(mask, float('-inf'))
# Scale by temperature
similarity_matrix = similarity_matrix / self.temperature
# NT-Xent loss (InfoNCE)
loss = F.cross_entropy(similarity_matrix, labels)
return loss
def get_representations(self, x: torch.Tensor) -> torch.Tensor:
"""Get learned representations for downstream tasks."""
return self.backbone(x)
# Train SimCLR
def train_simclr(model, dataloader, optimizer, epochs=100):
    # Note: applying the transform to a whole batch gives every image in the
    # batch the same random augmentation; in practice the two views are usually
    # generated per-sample inside the Dataset.
    transform = SimCLRTransform()
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, _ in dataloader:  # we don't need the labels!
            # Get two augmented views
            x1, x2 = transform(images)
            x1, x2 = x1.cuda(), x2.cuda()
            optimizer.zero_grad()
            loss = model(x1, x2)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}: Loss = {total_loss/len(dataloader):.4f}")
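As a quick sanity check (a sketch on random tensors, not a benchmark), SimCLR can be instantiated with the SimpleBackbone defined earlier:
# Smoke test: SimCLR forward pass on random "views".
simclr = SimCLR(SimpleBackbone(), projection_dim=128, temperature=0.5)
x1 = torch.randn(8, 3, 96, 96)   # first augmented view
x2 = torch.randn(8, 3, 96, 96)   # second augmented view
loss = simclr(x1, x2)
print(f"SimCLR NT-Xent loss on random data: {loss.item():.4f}")
reps = simclr.get_representations(x1)
print(f"Representation shape: {tuple(reps.shape)}")  # (8, 512)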
MoCo (Momentum Contrast)
class MoCo(nn.Module):
"""
Momentum Contrast (MoCo v2).
Uses a momentum encoder and a large dictionary of negative samples.
"""
def __init__(
self,
backbone: nn.Module,
projection_dim: int = 128,
queue_size: int = 65536,
momentum: float = 0.999,
temperature: float = 0.07
):
super().__init__()
self.queue_size = queue_size
self.momentum = momentum
self.temperature = temperature
# Query encoder
self.encoder_q = backbone
self.projector_q = nn.Sequential(
nn.Linear(backbone.feature_dim, 2048),
nn.ReLU(),
nn.Linear(2048, projection_dim)
)
# Key encoder (momentum updated)
self.encoder_k = self._copy_encoder(backbone)
self.projector_k = nn.Sequential(
nn.Linear(backbone.feature_dim, 2048),
nn.ReLU(),
nn.Linear(2048, projection_dim)
)
self._copy_params(self.projector_q, self.projector_k)
# Freeze key encoder
for param in self.encoder_k.parameters():
param.requires_grad = False
for param in self.projector_k.parameters():
param.requires_grad = False
# Queue for negative samples
self.register_buffer("queue", F.normalize(torch.randn(projection_dim, queue_size), dim=0))
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
def _copy_encoder(self, encoder):
import copy
return copy.deepcopy(encoder)
def _copy_params(self, src, dst):
for param_src, param_dst in zip(src.parameters(), dst.parameters()):
param_dst.data.copy_(param_src.data)
@torch.no_grad()
def _momentum_update(self):
"""Update key encoder with momentum."""
for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
param_k.data = self.momentum * param_k.data + (1 - self.momentum) * param_q.data
for param_q, param_k in zip(self.projector_q.parameters(), self.projector_k.parameters()):
param_k.data = self.momentum * param_k.data + (1 - self.momentum) * param_q.data
@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
"""Update the queue with new key embeddings."""
batch_size = keys.size(0)
ptr = int(self.queue_ptr)
# Replace oldest entries
if ptr + batch_size > self.queue_size:
self.queue[:, ptr:] = keys[:self.queue_size - ptr].T
self.queue[:, :batch_size - (self.queue_size - ptr)] = keys[self.queue_size - ptr:].T
else:
self.queue[:, ptr:ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.queue_size
self.queue_ptr[0] = ptr
def forward(self, x_q, x_k):
"""
Args:
x_q: Query images
x_k: Key images (different augmentation of same images)
"""
batch_size = x_q.size(0)
# Compute query features
q = self.projector_q(self.encoder_q(x_q))
q = F.normalize(q, dim=1)
# Compute key features (no gradient)
with torch.no_grad():
self._momentum_update()
k = self.projector_k(self.encoder_k(x_k))
k = F.normalize(k, dim=1)
# Positive logits: Nx1
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
# Negative logits: NxK
l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
# Logits: Nx(1+K)
logits = torch.cat([l_pos, l_neg], dim=1) / self.temperature
# Labels: positive is at index 0
labels = torch.zeros(batch_size, dtype=torch.long, device=x_q.device)
# Update queue
self._dequeue_and_enqueue(k)
return F.cross_entropy(logits, labels)
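A similar sanity check for MoCo (random inputs only); note that every forward pass also advances the queue pointer by the batch size:
# Smoke test: MoCo forward pass; the queue pointer advances each call.
moco = MoCo(SimpleBackbone(), projection_dim=128, queue_size=4096)
x_q = torch.randn(8, 3, 96, 96)
x_k = torch.randn(8, 3, 96, 96)
loss = moco(x_q, x_k)
print(f"MoCo loss: {loss.item():.4f}, queue pointer: {int(moco.queue_ptr)}")  # pointer is now 8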
Non-Contrastive Methods
BYOL (Bootstrap Your Own Latent)
BYOL shows that you don't need negative samples at all!
class BYOL(nn.Module):
"""
Bootstrap Your Own Latent.
No negative samples needed - learns by predicting target network.
"""
def __init__(
self,
backbone: nn.Module,
projection_dim: int = 256,
hidden_dim: int = 4096,
moving_average_decay: float = 0.99
):
super().__init__()
self.moving_average_decay = moving_average_decay
# Online network
self.online_encoder = backbone
self.online_projector = self._make_projector(backbone.feature_dim, hidden_dim, projection_dim)
self.predictor = self._make_projector(projection_dim, hidden_dim, projection_dim)
# Target network (momentum updated)
self.target_encoder = self._copy_encoder(backbone)
        self.target_projector = self._copy_encoder(self.online_projector)  # start as a copy of the online projector
# Freeze target network
for param in self.target_encoder.parameters():
param.requires_grad = False
for param in self.target_projector.parameters():
param.requires_grad = False
def _make_projector(self, in_dim, hidden_dim, out_dim):
return nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, out_dim)
)
def _copy_encoder(self, encoder):
import copy
return copy.deepcopy(encoder)
@torch.no_grad()
def update_target_network(self):
"""EMA update of target network."""
tau = self.moving_average_decay
for online, target in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
target.data = tau * target.data + (1 - tau) * online.data
for online, target in zip(self.online_projector.parameters(), self.target_projector.parameters()):
target.data = tau * target.data + (1 - tau) * online.data
def forward(self, x1, x2):
"""
Args:
x1, x2: Two augmented views of the same images
"""
# Online network forward
online_feat1 = self.online_encoder(x1)
online_proj1 = self.online_projector(online_feat1)
online_pred1 = self.predictor(online_proj1)
online_feat2 = self.online_encoder(x2)
online_proj2 = self.online_projector(online_feat2)
online_pred2 = self.predictor(online_proj2)
# Target network forward (no gradient)
with torch.no_grad():
target_proj1 = self.target_projector(self.target_encoder(x1))
target_proj2 = self.target_projector(self.target_encoder(x2))
# Stop gradient
target_proj1 = target_proj1.detach()
target_proj2 = target_proj2.detach()
# Compute loss: predict one view from the other
loss1 = self._loss_fn(online_pred1, target_proj2)
loss2 = self._loss_fn(online_pred2, target_proj1)
return (loss1 + loss2) / 2
def _loss_fn(self, x, y):
"""Normalized MSE loss."""
x = F.normalize(x, dim=-1)
y = F.normalize(y, dim=-1)
return 2 - 2 * (x * y).sum(dim=-1).mean()
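The EMA update is not called inside forward, so a training loop has to invoke update_target_network explicitly after each optimizer step. A minimal sketch on random inputs:
# BYOL training step sketch: the target network is refreshed *after* the
# optimizer step via target = tau * target + (1 - tau) * online.
byol = BYOL(SimpleBackbone())
optimizer = torch.optim.Adam(
    [p for p in byol.parameters() if p.requires_grad],  # online network + predictor only
    lr=3e-4
)
view1 = torch.randn(8, 3, 96, 96)
view2 = torch.randn(8, 3, 96, 96)

loss = byol(view1, view2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
byol.update_target_network()  # EMA update of target encoder and projector
print(f"BYOL loss: {loss.item():.4f}")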
SimSiam (Simple Siamese)
class SimSiam(nn.Module):
"""
Simple Siamese networks.
Even simpler than BYOL - no momentum encoder, just stop-gradient.
"""
def __init__(
self,
backbone: nn.Module,
projection_dim: int = 2048,
prediction_dim: int = 512
):
super().__init__()
self.encoder = backbone
# Projection MLP
self.projector = nn.Sequential(
nn.Linear(backbone.feature_dim, projection_dim),
nn.BatchNorm1d(projection_dim),
nn.ReLU(),
nn.Linear(projection_dim, projection_dim),
nn.BatchNorm1d(projection_dim),
nn.ReLU(),
nn.Linear(projection_dim, projection_dim),
nn.BatchNorm1d(projection_dim)
)
# Prediction MLP
self.predictor = nn.Sequential(
nn.Linear(projection_dim, prediction_dim),
nn.BatchNorm1d(prediction_dim),
nn.ReLU(),
nn.Linear(prediction_dim, projection_dim)
)
def forward(self, x1, x2):
"""Forward pass with two views."""
# Compute projections
z1 = self.projector(self.encoder(x1))
z2 = self.projector(self.encoder(x2))
# Compute predictions
p1 = self.predictor(z1)
p2 = self.predictor(z2)
# Compute loss with stop-gradient
loss1 = self._negative_cosine_similarity(p1, z2.detach())
loss2 = self._negative_cosine_similarity(p2, z1.detach())
return (loss1 + loss2) / 2
def _negative_cosine_similarity(self, p, z):
"""Negative cosine similarity."""
p = F.normalize(p, dim=1)
z = F.normalize(z, dim=1)
return -(p * z).sum(dim=1).mean()
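Unlike BYOL there is no momentum network to maintain, so a training step is just forward, backward, step; the stop-gradient inside forward is what prevents collapse. A sketch on random inputs:
# SimSiam training step sketch: no EMA update needed.
simsiam = SimSiam(SimpleBackbone())
optimizer = torch.optim.SGD(simsiam.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
x1 = torch.randn(8, 3, 96, 96)
x2 = torch.randn(8, 3, 96, 96)

loss = simsiam(x1, x2)   # negative cosine similarity; approaches -1 as views align
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"SimSiam loss: {loss.item():.4f}")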
Masked Modeling
Masked Autoencoders (MAE)
class MaskedAutoencoder(nn.Module):
"""
Masked Autoencoder for Vision (MAE).
Mask random patches and reconstruct them.
"""
def __init__(
self,
image_size: int = 224,
patch_size: int = 16,
embed_dim: int = 768,
encoder_depth: int = 12,
encoder_heads: int = 12,
decoder_embed_dim: int = 512,
decoder_depth: int = 8,
decoder_heads: int = 16,
mask_ratio: float = 0.75
):
super().__init__()
self.patch_size = patch_size
self.mask_ratio = mask_ratio
num_patches = (image_size // patch_size) ** 2
# Patch embedding
self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim) * 0.02)
# Encoder (processes visible patches only)
self.encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=encoder_heads,
dim_feedforward=embed_dim * 4,
batch_first=True
),
num_layers=encoder_depth
)
# Decoder
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.decoder_pos_embed = nn.Parameter(torch.randn(1, num_patches, decoder_embed_dim) * 0.02)
self.decoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=decoder_embed_dim,
nhead=decoder_heads,
dim_feedforward=decoder_embed_dim * 4,
batch_first=True
),
num_layers=decoder_depth
)
# Prediction head (reconstruct pixels)
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * 3)
def random_masking(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Randomly mask patches.
Args:
x: [batch, num_patches, embed_dim]
Returns:
x_masked: visible patches
mask: binary mask (1 = masked)
ids_restore: indices to restore original order
"""
N, L, D = x.shape
len_keep = int(L * (1 - self.mask_ratio))
# Random noise for shuffling
noise = torch.rand(N, L, device=x.device)
# Sort noise to get shuffle indices
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
# Keep first len_keep patches (after shuffling)
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D))
# Generate binary mask: 0 = keep, 1 = masked
mask = torch.ones(N, L, device=x.device)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
def forward(self, images: torch.Tensor):
"""
Args:
images: [batch, 3, H, W]
Returns:
loss: Reconstruction loss on masked patches
pred: Predicted pixel values
mask: Binary mask
"""
# Patchify
patches = self.patch_embed(images) # [B, embed_dim, H/P, W/P]
patches = patches.flatten(2).transpose(1, 2) # [B, num_patches, embed_dim]
# Add positional embedding
patches = patches + self.pos_embed
# Random masking
patches_masked, mask, ids_restore = self.random_masking(patches)
# Encode visible patches
latent = self.encoder(patches_masked)
# Project to decoder dimension
latent = self.decoder_embed(latent)
# Append mask tokens
N, _, D = latent.shape
num_patches = mask.shape[1]
mask_tokens = self.mask_token.expand(N, num_patches - latent.shape[1], -1)
# Unshuffle: put mask tokens in correct positions
x_ = torch.cat([latent, mask_tokens], dim=1)
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, D))
# Add decoder positional embedding
x_ = x_ + self.decoder_pos_embed
# Decode
x_ = self.decoder(x_)
# Predict pixels
pred = self.decoder_pred(x_) # [B, num_patches, patch_size^2 * 3]
# Compute loss only on masked patches
target = self._patchify(images)
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # Mean over patch pixels
loss = (loss * mask).sum() / mask.sum() # Mean over masked patches
return loss, pred, mask
def _patchify(self, images: torch.Tensor) -> torch.Tensor:
"""Convert images to patches."""
p = self.patch_size
B, C, H, W = images.shape
h, w = H // p, W // p
x = images.reshape(B, C, h, p, w, p)
x = x.permute(0, 2, 4, 3, 5, 1) # [B, h, w, p, p, C]
x = x.reshape(B, h * w, p * p * C)
return x
# Test MAE
mae = MaskedAutoencoder(image_size=224, patch_size=16, embed_dim=768)
images = torch.randn(4, 3, 224, 224)
loss, pred, mask = mae(images)
print(f"MAE Loss: {loss.item():.4f}")
print(f"Mask ratio: {mask.float().mean().item():.2f}")
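The MAE above only returns a reconstruction loss. For downstream use, one common recipe (sketched here as a hypothetical helper, not part of the class) is to encode all patches without masking and mean-pool the encoder output into a single feature vector:
# Hypothetical helper: global features from a (pre)trained MAE by encoding
# every patch with no masking and average-pooling over the patch dimension.
@torch.no_grad()
def mae_representations(model: MaskedAutoencoder, images: torch.Tensor) -> torch.Tensor:
    patches = model.patch_embed(images).flatten(2).transpose(1, 2)  # [B, N, D]
    patches = patches + model.pos_embed
    latent = model.encoder(patches)   # encode all patches (no masking)
    return latent.mean(dim=1)         # [B, embed_dim]

features = mae_representations(mae, images)
print(f"MAE features: {tuple(features.shape)}")  # (4, 768)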
BERT-style Masked Language Modeling
class MaskedLanguageModel(nn.Module):
"""BERT-style masked language modeling for text."""
def __init__(
self,
vocab_size: int,
embed_dim: int = 768,
num_layers: int = 12,
num_heads: int = 12,
max_seq_len: int = 512,
mask_token_id: int = 103, # [MASK] token
mask_ratio: float = 0.15
):
super().__init__()
self.mask_token_id = mask_token_id
self.mask_ratio = mask_ratio
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.pos_embedding = nn.Embedding(max_seq_len, embed_dim)
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=num_heads,
dim_feedforward=embed_dim * 4,
batch_first=True
),
num_layers=num_layers
)
self.output = nn.Linear(embed_dim, vocab_size)
def mask_tokens(self, input_ids: torch.Tensor):
"""Apply BERT-style masking."""
labels = input_ids.clone()
# Probability matrix for masking
probability_matrix = torch.full(input_ids.shape, self.mask_ratio)
# Create masked indices
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # Only compute loss on masked tokens
# 80% -> [MASK], 10% -> random, 10% -> original
indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
input_ids[indices_replaced] = self.mask_token_id
indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.embedding.weight), input_ids.shape, device=input_ids.device)
input_ids[indices_random] = random_words[indices_random]
return input_ids, labels
def forward(self, input_ids: torch.Tensor):
# Apply masking
masked_ids, labels = self.mask_tokens(input_ids.clone())
# Get embeddings
seq_len = masked_ids.size(1)
positions = torch.arange(seq_len, device=masked_ids.device).unsqueeze(0)
x = self.embedding(masked_ids) + self.pos_embedding(positions)
# Transformer forward
x = self.transformer(x)
# Predict masked tokens
logits = self.output(x)
# Compute loss
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
labels.view(-1),
ignore_index=-100
)
return loss, logits
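A quick sanity check on random token ids (a sketch; the vocabulary size and sequence length here are arbitrary):
# Smoke test: masked language modeling on random token ids.
mlm = MaskedLanguageModel(vocab_size=30522, embed_dim=256, num_layers=2, num_heads=4)
input_ids = torch.randint(1000, 30000, (2, 64))   # [batch=2, seq_len=64]
loss, logits = mlm(input_ids)
print(f"MLM loss: {loss.item():.4f}, logits shape: {tuple(logits.shape)}")  # (2, 64, 30522)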
Evaluation & Downstream Tasks
Linear Probing
class LinearProbe(nn.Module):
"""Linear evaluation of learned representations."""
def __init__(self, encoder: nn.Module, num_classes: int):
super().__init__()
self.encoder = encoder
# Freeze encoder
for param in self.encoder.parameters():
param.requires_grad = False
self.classifier = nn.Linear(encoder.feature_dim, num_classes)
def forward(self, x):
with torch.no_grad():
features = self.encoder(x)
return self.classifier(features)
def linear_evaluation(pretrained_encoder, train_loader, val_loader, num_classes, epochs=100):
"""Evaluate representations with linear probing."""
probe = LinearProbe(pretrained_encoder, num_classes).cuda()
optimizer = torch.optim.Adam(probe.classifier.parameters(), lr=0.001)
best_acc = 0
for epoch in range(epochs):
# Train
probe.train()
for images, labels in train_loader:
images, labels = images.cuda(), labels.cuda()
optimizer.zero_grad()
outputs = probe(images)
loss = F.cross_entropy(outputs, labels)
loss.backward()
optimizer.step()
# Evaluate
probe.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in val_loader:
images, labels = images.cuda(), labels.cuda()
outputs = probe(images)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
acc = 100 * correct / total
if acc > best_acc:
best_acc = acc
print(f"Epoch {epoch+1}: Accuracy = {acc:.2f}%")
return best_acc
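Linear probing still trains a classifier; an even cheaper check (mentioned in the tips below) is k-NN evaluation, which classifies each validation sample by a majority vote over its nearest training samples in feature space. A sketch using cosine similarity, assuming all training features fit in memory:
@torch.no_grad()
def knn_evaluation(encoder, train_loader, val_loader, k: int = 20) -> float:
    """k-NN probe: no training, just nearest neighbors in feature space."""
    encoder.eval()
    train_feats, train_labels = [], []
    for images, labels in train_loader:
        train_feats.append(F.normalize(encoder(images.cuda()), dim=1))
        train_labels.append(labels.cuda())
    train_feats = torch.cat(train_feats)     # [N_train, D]
    train_labels = torch.cat(train_labels)   # [N_train]
    correct, total = 0, 0
    for images, labels in val_loader:
        feats = F.normalize(encoder(images.cuda()), dim=1)
        sims = feats @ train_feats.T          # cosine similarities [B, N_train]
        _, nn_idx = sims.topk(k, dim=1)       # k nearest training samples
        preds = train_labels[nn_idx].mode(dim=1).values  # majority vote
        correct += (preds == labels.cuda()).sum().item()
        total += labels.size(0)
    return 100 * correct / total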
Fine-tuning
def finetune_evaluation(pretrained_encoder, train_loader, val_loader, num_classes, epochs=100):
"""Fine-tune entire model on downstream task."""
model = nn.Sequential(
pretrained_encoder,
nn.Linear(pretrained_encoder.feature_dim, num_classes)
).cuda()
# Different learning rates for encoder and classifier
optimizer = torch.optim.AdamW([
{'params': pretrained_encoder.parameters(), 'lr': 1e-4},
{'params': model[-1].parameters(), 'lr': 1e-3}
])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
for epoch in range(epochs):
model.train()
for images, labels in train_loader:
images, labels = images.cuda(), labels.cuda()
optimizer.zero_grad()
outputs = model(images)
loss = F.cross_entropy(outputs, labels)
loss.backward()
optimizer.step()
scheduler.step()
# Final evaluation
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in val_loader:
images, labels = images.cuda(), labels.cuda()
outputs = model(images)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
return 100 * correct / total
Practical Considerations
Data Augmentation for SSL
class SSLAugmentations:
"""Strong augmentations for self-supervised learning."""
@staticmethod
def simclr_augmentation(size=224):
"""SimCLR augmentation strategy."""
return T.Compose([
T.RandomResizedCrop(size, scale=(0.2, 1.0)),
T.RandomHorizontalFlip(p=0.5),
T.RandomApply([
T.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2)
], p=0.8),
T.RandomGrayscale(p=0.2),
T.RandomApply([T.GaussianBlur(kernel_size=23)], p=0.5),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
@staticmethod
def byol_augmentation(size=224):
"""BYOL augmentation (asymmetric)."""
base = [
T.RandomResizedCrop(size, scale=(0.2, 1.0)),
T.RandomHorizontalFlip(p=0.5),
T.RandomApply([
T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)
], p=0.8),
T.RandomGrayscale(p=0.2),
]
view1 = T.Compose(base + [
T.RandomApply([T.GaussianBlur(kernel_size=23)], p=1.0), # Always blur
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
view2 = T.Compose(base + [
T.RandomApply([T.GaussianBlur(kernel_size=23)], p=0.1), # Rarely blur
T.RandomSolarize(threshold=128, p=0.2),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
return view1, view2
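Both helpers return torchvision transforms that end in ToTensor, so they expect PIL images as input. A typical way to use them (a sketch; TwoViewDataset and base_dataset are hypothetical names) is a small dataset wrapper that returns two views per sample:
# Sketch: wrap an image dataset so __getitem__ returns two augmented views.
# Assumes base_dataset[i] returns a (PIL image, label) pair, as in torchvision.
from torch.utils.data import Dataset

class TwoViewDataset(Dataset):
    def __init__(self, base_dataset, size=224, byol_style=False):
        self.base = base_dataset
        if byol_style:
            self.t1, self.t2 = SSLAugmentations.byol_augmentation(size)
        else:
            self.t1 = self.t2 = SSLAugmentations.simclr_augmentation(size)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, _ = self.base[idx]       # the label is ignored
        return self.t1(img), self.t2(img)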
Training Tips
def ssl_training_tips():
"""Key tips for successful self-supervised training."""
tips = """
╔════════════════════════════════════════════════════════════════╗
║ SELF-SUPERVISED LEARNING: BEST PRACTICES ║
╠════════════════════════════════════════════════════════════════╣
║ ║
║ 1. BATCH SIZE ║
║ • Contrastive methods (SimCLR, MoCo): Large batches (4096+)║
║ • Non-contrastive (BYOL, SimSiam): Smaller batches OK ║
║ • Use gradient accumulation if GPU memory limited ║
║ ║
║ 2. LEARNING RATE ║
║ • Base LR scales with batch size: lr = base_lr * batch/256 ║
║ • Use cosine schedule with warmup ║
║ • LARS optimizer for very large batches ║
║ ║
║ 3. AUGMENTATION ║
║ • Stronger augmentation = better representations ║
║ • Color jittering is crucial ║
║ • Multi-crop (DINO style) improves efficiency ║
║ ║
║ 4. ARCHITECTURE ║
║ • Projection head: 2-3 layer MLP with BN ║
║ • Predictor (BYOL/SimSiam): Crucial for preventing collapse║
║ • Larger models generally learn better representations ║
║ ║
║ 5. TRAINING DURATION ║
║ • SSL needs longer training than supervised (800+ epochs) ║
║ • Early stopping based on downstream performance ║
║ ║
║ 6. EVALUATION ║
║ • Linear probe: Freeze encoder, train linear classifier ║
║ • k-NN: Nearest neighbor classification (no training) ║
║ • Fine-tuning: Update entire model on downstream task ║
║ ║
╚════════════════════════════════════════════════════════════════╝
"""
print(tips)
ssl_training_tips()
Exercises
Exercise 1: Implement SwAV
Implement Swapping Assignments between Views:
class SwAV(nn.Module):
def forward(self, views):
# 1. Compute features for all views
# 2. Compute cluster assignments (Sinkhorn-Knopp)
        # 3. Swap: predict assignment of view1 from view2
        raise NotImplementedError  # exercise: fill in the implementation
Exercise 2: Add Multi-Crop
Implement DINO-style multi-crop augmentation:
def multi_crop(image, n_global=2, n_local=6):
# Global crops: 224x224, high coverage
# Local crops: 96x96, low coverage
    # Train with all crops simultaneously
    raise NotImplementedError  # exercise: fill in the implementation
Exercise 3: Compare Methods
Train SimCLR, BYOL, and MAE on CIFAR-10 and compare:
- Training time
- Linear probe accuracy
- Fine-tuning accuracy
- Representation quality (t-SNE)