Optimization Algorithms

The Optimization Landscape

Training a neural network amounts to finding good minima in a non-convex loss landscape. Key challenges:
  • Saddle points (more common than local minima in high dimensions; see the sketch after this list)
  • Flat regions with vanishing gradients
  • Sharp vs. flat minima (flat minima tend to generalize better)
  • Ill-conditioned Hessians (curvature varies widely across directions)
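
A toy illustration of the saddle-point issue (this example is an addition, not from the original text): the function f(x, y) = x^2 - y^2 has zero gradient at the origin even though the origin is not a minimum, so plain gradient descent stalls there.

import torch

# f(x, y) = x^2 - y^2 has a saddle point at (0, 0): the gradient vanishes,
# but the loss still decreases along the y-axis.
xy = torch.zeros(2, requires_grad=True)
f = xy[0] ** 2 - xy[1] ** 2
f.backward()
print(xy.grad)  # tensor([0., 0.]) -- no update signal for vanilla gradient descent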

Gradient Descent Variants

Vanilla SGD

$$w_{t+1} = w_t - \eta \nabla_w \mathcal{L}(w_t)$$

# Vanilla SGD
for param in model.parameters():
    param.data -= learning_rate * param.grad
Problems: oscillates in narrow ravines of the loss surface, makes slow progress in flat regions, and applies the same learning rate to every parameter.

SGD with Momentum

$$v_t = \mu v_{t-1} + \nabla_w \mathcal{L}(w_t)$$

$$w_{t+1} = w_t - \eta v_t$$

import torch
import torch.nn as nn

class SGDMomentum:
    """SGD with momentum from scratch."""
    
    def __init__(self, params, lr=0.01, momentum=0.9):
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum
        self.velocities = [torch.zeros_like(p) for p in self.params]
    
    def step(self):
        for param, velocity in zip(self.params, self.velocities):
            if param.grad is None:
                continue
            velocity.mul_(self.momentum).add_(param.grad)
            param.data.add_(velocity, alpha=-self.lr)
    
    def zero_grad(self):
        for param in self.params:
            if param.grad is not None:
                param.grad.zero_()
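
A minimal usage sketch for the class above (the model, data, and hyperparameters here are illustrative placeholders):

# Hypothetical usage: fit a small linear model with the from-scratch optimizer.
model = nn.Linear(10, 1)
optimizer = SGDMomentum(model.parameters(), lr=0.01, momentum=0.9)

x, y = torch.randn(64, 10), torch.randn(64, 1)
for _ in range(100):
    optimizer.zero_grad()
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()
    optimizer.step()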

Adaptive Learning Rates

RMSprop

Adapt the learning rate per parameter based on gradient history:

$$v_t = \beta v_{t-1} + (1 - \beta) g_t^2$$

$$w_{t+1} = w_t - \frac{\eta}{\sqrt{v_t + \epsilon}} g_t$$
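
A minimal from-scratch sketch in the same style as the other optimizers in this section (the defaults here are illustrative, not canonical):

class RMSprop:
    """RMSprop from scratch (illustrative sketch)."""
    
    def __init__(self, params, lr=1e-3, beta=0.99, eps=1e-8):
        self.params = list(params)
        self.lr = lr
        self.beta = beta
        self.eps = eps
        self.v = [torch.zeros_like(p) for p in self.params]  # Running mean of squared gradients
    
    def step(self):
        for i, param in enumerate(self.params):
            if param.grad is None:
                continue
            g = param.grad
            # v_t = beta * v_{t-1} + (1 - beta) * g_t^2
            self.v[i] = self.beta * self.v[i] + (1 - self.beta) * g ** 2
            # w_{t+1} = w_t - lr / sqrt(v_t + eps) * g_t
            param.data -= self.lr * g / torch.sqrt(self.v[i] + self.eps)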

Adam (Adaptive Moment Estimation)

Combines momentum and adaptive learning rates:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \quad \text{(first moment)}$$

$$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \quad \text{(second moment)}$$

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \quad \text{(bias correction)}$$

$$w_{t+1} = w_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t$$

class Adam:
    """Adam optimizer from scratch."""
    
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.t = 0
        
        self.m = [torch.zeros_like(p) for p in self.params]  # First moment
        self.v = [torch.zeros_like(p) for p in self.params]  # Second moment
    
    def step(self):
        self.t += 1
        
        for i, param in enumerate(self.params):
            if param.grad is None:
                continue
            
            g = param.grad
            
            # Update biased first moment estimate
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            # Update biased second raw moment estimate
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g ** 2
            
            # Compute bias-corrected estimates
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            
            # Update parameters
            param.data -= self.lr * m_hat / (torch.sqrt(v_hat) + self.eps)
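
As a quick sanity check (this comparison is an addition, not from the original), the from-scratch version should track torch.optim.Adam closely on the same parameters, since both implement the update above:

# Compare the from-scratch Adam against torch.optim.Adam on a toy objective.
torch.manual_seed(0)
w1 = torch.randn(5, requires_grad=True)
w2 = w1.detach().clone().requires_grad_(True)

custom = Adam([w1], lr=1e-3)
reference = torch.optim.Adam([w2], lr=1e-3)

for _ in range(3):
    (w1 ** 2).sum().backward()
    (w2 ** 2).sum().backward()
    custom.step()
    reference.step()
    w1.grad = None
    w2.grad = None

print(torch.allclose(w1, w2, atol=1e-6))  # Expected: True (updates should match closely)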

AdamW (Decoupled Weight Decay)

Standard Adam adds the L2 penalty to the gradient, so the adaptive per-parameter scaling distorts it and it no longer acts as a uniform weight decay. AdamW decouples weight decay from the gradient-based update:
# AdamW update rule
param.data -= self.lr * m_hat / (torch.sqrt(v_hat) + self.eps)
param.data -= self.lr * weight_decay * param.data  # Separate!
Always use AdamW over Adam for training modern deep networks.
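
A minimal sketch of how the from-scratch Adam above could be extended with decoupled decay, mirroring the ordering of the snippet (the weight_decay default is illustrative):

class AdamW(Adam):
    """Adam with decoupled weight decay (illustrative sketch)."""
    
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
        super().__init__(params, lr=lr, betas=betas, eps=eps)
        self.weight_decay = weight_decay
    
    def step(self):
        super().step()  # Plain Adam update, with no L2 term folded into the gradient
        for param in self.params:
            if param.grad is None:
                continue
            # Decoupled decay: shrink weights directly, outside the adaptive update
            param.data -= self.lr * self.weight_decay * param.data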

Learning Rate Schedules

Step Decay

from torch.optim.lr_scheduler import StepLR

scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
# Multiply the LR by 0.1 every 30 epochs

Cosine Annealing

from torch.optim.lr_scheduler import CosineAnnealingLR

scheduler = CosineAnnealingLR(optimizer, T_max=100)
# Cosine decay over 100 epochs

Warmup + Cosine (Transformer Standard)

import math

class WarmupCosineScheduler:
    """Warmup followed by cosine decay."""
    
    def __init__(self, optimizer, warmup_steps, total_steps, min_lr=0):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        self.base_lr = optimizer.param_groups[0]['lr']
        self.step_count = 0
    
    def step(self):
        self.step_count += 1
        
        if self.step_count < self.warmup_steps:
            # Linear warmup
            lr = self.base_lr * self.step_count / self.warmup_steps
        else:
            # Cosine decay
            progress = (self.step_count - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            lr = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (1 + math.cos(math.pi * progress))
        
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
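
A usage sketch, assuming the scheduler is stepped once per optimizer step rather than per epoch (model, train_loader, and loss_fn are placeholders):

# Hypothetical training loop with a per-step warmup + cosine schedule.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = WarmupCosineScheduler(optimizer, warmup_steps=1000, total_steps=50_000)

for x, y in train_loader:
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    scheduler.step()  # Adjust the learning rate after every optimizer step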

Optimizer Comparison

| Optimizer | Learning Rate | Weight Decay | Best For |
|---|---|---|---|
| SGD + Momentum | 0.1 - 0.01 | 1e-4 | CNNs, well-tuned |
| Adam | 1e-3 - 1e-4 | 0 | Quick prototyping |
| AdamW | 1e-4 - 5e-5 | 0.01 - 0.1 | Transformers, general |
| LAMB | 1e-3 - 1e-2 | 0.01 | Large batch training |

Modern Optimizers

Lion (Evolved Sign Momentum)

# Simplified Lion update
sign_momentum = torch.sign(beta1 * m + (1 - beta1) * g)
param.data -= lr * (sign_momentum + weight_decay * param.data)
m = beta2 * m + (1 - beta2) * g
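
The snippet above is schematic; below is a self-contained sketch as an optimizer class in the style of this section (defaults here are illustrative):

class Lion:
    """Lion optimizer (illustrative sketch of the simplified update above)."""
    
    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.weight_decay = weight_decay
        self.m = [torch.zeros_like(p) for p in self.params]  # Momentum buffer
    
    def step(self):
        for i, param in enumerate(self.params):
            if param.grad is None:
                continue
            g = param.grad
            # Update direction: sign of an interpolation between momentum and gradient
            update = torch.sign(self.beta1 * self.m[i] + (1 - self.beta1) * g)
            param.data -= self.lr * (update + self.weight_decay * param.data)
            # Momentum is refreshed with beta2 after the parameter update
            self.m[i] = self.beta2 * self.m[i] + (1 - self.beta2) * g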

Sharpness-Aware Minimization (SAM)

Seeks flat minima by perturbing weights:
class SAM:
    """Sharpness-Aware Minimization."""
    
    def __init__(self, optimizer, rho=0.05):
        self.optimizer = optimizer
        self.rho = rho
    
    def first_step(self):
        # Compute the global gradient norm used to scale the perturbation
        grad_norm = torch.sqrt(sum(
            p.grad.norm() ** 2 for p in self.optimizer.param_groups[0]['params']
            if p.grad is not None
        ))
        
        # In-place edits of leaf parameters must happen outside autograd
        with torch.no_grad():
            for p in self.optimizer.param_groups[0]['params']:
                if p.grad is None:
                    continue
                e_w = self.rho * p.grad / (grad_norm + 1e-12)
                p.add_(e_w)   # Perturb weights toward the (approximate) local worst case
                p.e_w = e_w   # Remember the perturbation so second_step can undo it
    
    def second_step(self):
        # Undo the perturbation, then update using gradients from the perturbed point
        with torch.no_grad():
            for p in self.optimizer.param_groups[0]['params']:
                if hasattr(p, 'e_w'):
                    p.sub_(p.e_w)
                    del p.e_w
        self.optimizer.step()
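
SAM requires two forward/backward passes per batch: one at the current weights to compute the perturbation, and one at the perturbed weights to compute the actual update. A usage sketch (model, loss_fn, and the batch x, y are placeholders):

# Hypothetical SAM training step wrapping a base optimizer.
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
sam = SAM(base_optimizer, rho=0.05)

loss_fn(model(x), y).backward()   # 1st pass: gradients at the current weights
sam.first_step()                  # Perturb weights toward higher loss

base_optimizer.zero_grad()
loss_fn(model(x), y).backward()   # 2nd pass: gradients at the perturbed weights
sam.second_step()                 # Undo the perturbation, then step the base optimizer
base_optimizer.zero_grad()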

Practical Recipes

Vision Transformers

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.05,
    betas=(0.9, 0.999)
)

scheduler = CosineAnnealingLR(optimizer, T_max=300)

Large Language Models

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.1,
    betas=(0.9, 0.95)
)

# Warmup for first 2000 steps
scheduler = WarmupCosineScheduler(optimizer, warmup_steps=2000, total_steps=100000)

Exercises

1. Train CIFAR-10 with SGD, Adam, and AdamW. Compare convergence speed and final accuracy.
2. Compare constant LR vs. step decay vs. cosine annealing on the same model.
3. Train a transformer with and without warmup. Observe gradient norms and loss curves.

What’s Next