
Regularization for Deep Networks

The Overfitting Problem

Deep networks have millions of parameters, so they can memorize the training data perfectly while failing on new examples. Regularization constrains the model and improves generalization.

Weight Decay (L2 Regularization)

Add a penalty on weight magnitude to the loss:

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + \frac{\lambda}{2} \sum_i w_i^2$$

Effect: pushes weights toward zero, preventing extreme values.
import torch.optim as optim

# Apply weight decay in optimizer
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.01  # L2 penalty
)
AdamW decouples weight decay from the gradient update; prefer it over Adam with an L2 penalty folded into the loss.
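
In practice it is also common to exempt biases and normalization parameters from weight decay. A minimal sketch using optimizer parameter groups, reusing the `model` from above (the 1-D-tensor heuristic is a simplifying assumption, not a universal rule):

# Split parameters into decayed and non-decayed groups.
# Heuristic: 1-D tensors (biases, LayerNorm/BatchNorm weights) get no decay.
decay, no_decay = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    if param.ndim <= 1 or name.endswith(".bias"):
        no_decay.append(param)
    else:
        decay.append(param)

optimizer = optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-4,
)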

Dropout

Randomly zero activations during training:
import torch
import torch.nn as nn

class DropoutFromScratch(nn.Module):
    """Dropout implementation from scratch."""
    
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p
    
    def forward(self, x):
        if self.training:
            mask = (torch.rand_like(x) > self.p).float()
            return x * mask / (1 - self.p)  # Scale to maintain expectation
        return x
Why it works: it forces the network to learn redundant representations and behaves like an implicit ensemble of thinned sub-networks.
| Layer type | Typical dropout rate |
| --- | --- |
| Fully connected | 0.3 - 0.5 |
| After attention | 0.1 - 0.3 |
| Embedding | 0.0 - 0.1 |
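
As a rough illustration of those rates, a small classifier head might place dropout like this (layer sizes are arbitrary; note that nn.Dropout is only active in training mode):

classifier = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Dropout(p=0.5),   # fully connected layer: 0.3 - 0.5
    nn.Linear(256, 10),
)

classifier.train()   # dropout active, activations randomly zeroed and rescaled
classifier.eval()    # dropout disabled, inputs passed through unchanged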

Data Augmentation

The most effective regularizer: artificially expand the training set.
from torchvision import transforms

# Standard augmentation pipeline
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
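
A minimal sketch of plugging the pipeline into a DataLoader; FakeData stands in for a real dataset so the snippet is self-contained:

from torch.utils.data import DataLoader
from torchvision import datasets

# FakeData generates random labeled images; swap in your own dataset in practice.
train_dataset = datasets.FakeData(
    size=1000,
    image_size=(3, 256, 256),
    num_classes=10,
    transform=train_transform,   # augmentations re-applied per sample, every epoch
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)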

Advanced Augmentations

# CutOut: Random rectangular mask
class Cutout:
    def __init__(self, size=16):
        self.size = size
    
    def __call__(self, img):
        h, w = img.shape[1:]
        y = torch.randint(h, (1,)).item()
        x = torch.randint(w, (1,)).item()
        
        y1 = max(0, y - self.size // 2)
        y2 = min(h, y + self.size // 2)
        x1 = max(0, x - self.size // 2)
        x2 = min(w, x + self.size // 2)
        
        img[:, y1:y2, x1:x2] = 0
        return img

# MixUp: Blend two samples
def mixup(x, y, alpha=0.2):
    lam = torch.distributions.Beta(alpha, alpha).sample()
    batch_size = x.size(0)
    index = torch.randperm(batch_size)
    
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    
    return mixed_x, y_a, y_b, lam

# Training with MixUp: one forward pass, loss blended with the same lam
mixed_x, y_a, y_b, lam = mixup(x, y)
pred = model(mixed_x)
loss = lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
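
Since Cutout operates on the CHW tensor, it has to come after ToTensor when composed into a pipeline. A sketch (the mask size of 32 is arbitrary):

cutout_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    Cutout(size=32),   # zero out a random square after conversion to a tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])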

Label Smoothing

Soften hard labels to prevent overconfidence:

$$y_{\text{smooth}} = (1 - \alpha) \cdot y_{\text{hard}} + \frac{\alpha}{K}$$

where $K$ is the number of classes and $\alpha$ the smoothing factor.
class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing
    
    def forward(self, pred, target):
        n_classes = pred.size(-1)
        log_probs = torch.log_softmax(pred, dim=-1)
        
        # Smooth labels
        targets = torch.zeros_like(log_probs).scatter_(
            1, target.unsqueeze(1), 1
        )
        targets = (1 - self.smoothing) * targets + self.smoothing / n_classes
        
        loss = (-targets * log_probs).sum(dim=-1).mean()
        return loss
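
Recent PyTorch releases (1.10+) also build label smoothing into nn.CrossEntropyLoss, which should match the from-scratch version above. A quick sanity-check sketch with random logits:

criterion_custom = LabelSmoothingCrossEntropy(smoothing=0.1)
criterion_builtin = nn.CrossEntropyLoss(label_smoothing=0.1)

logits = torch.randn(8, 10)             # batch of 8 samples, 10 classes
target = torch.randint(0, 10, (8,))

print(criterion_custom(logits, target))    # the two losses should agree
print(criterion_builtin(logits, target))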

Early Stopping

Monitor validation loss; stop when it stops improving:
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.should_stop = False
    
    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), 'best_model.pt')
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop

# Usage
early_stopping = EarlyStopping(patience=10)
for epoch in range(max_epochs):
    train(...)
    val_loss = validate(...)
    if early_stopping(val_loss, model):
        print("Early stopping triggered!")
        break
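
Because the best weights were checkpointed inside EarlyStopping, restore them before final evaluation (the file name matches the one saved above):

# Roll back to the weights from the best validation epoch
model.load_state_dict(torch.load('best_model.pt'))
model.eval()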

Comparison of Regularization Techniques

| Technique | Effect | When to use |
| --- | --- | --- |
| Weight decay | Penalize large weights | Always (0.01-0.1) |
| Dropout | Random deactivation | Dense layers, attention |
| Data augmentation | Expand training data | Always for vision |
| Label smoothing | Soften targets | Classification |
| Early stopping | Prevent overtraining | Always |
| Stochastic depth | Drop whole layers | Very deep networks |
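
Stochastic depth from the table above randomly skips entire residual blocks during training. A simplified sketch that skips the branch entirely while training (the original formulation also rescales the branch at test time; the drop probability of 0.1 and the `block` module are placeholders):

class StochasticDepth(nn.Module):
    """Wrap a residual branch and drop it with probability p during training."""

    def __init__(self, block, p=0.1):
        super().__init__()
        self.block = block
        self.p = p

    def forward(self, x):
        if self.training and torch.rand(1).item() < self.p:
            return x                      # skip the block: identity only
        return x + self.block(x)          # standard residual connection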

Exercises

1. Train a network with dropout rates 0, 0.1, 0.3, 0.5, 0.7. Plot train vs. val accuracy for each.
2. Compare model performance with: no augmentation, basic flips, the full augmentation pipeline.
3. Implement CutMix (rectangular patches from different images) and compare with MixUp.

What’s Next