Optimization Algorithms
The Optimization Landscape
Training a neural network means finding good minima in a non-convex loss landscape.
Challenges:
Saddle points (more common than local minima in high dimensions)
Flat regions with vanishing gradients
Sharp vs. flat minima (flat minima tend to generalize better)
Ill-conditioned Hessians
Gradient Descent Variants
Vanilla SGD
w_{t+1} = w_t - \eta \nabla_w \mathcal{L}(w_t)
# Vanilla SGD
for param in model.parameters():
    param.data -= learning_rate * param.grad
Problems: oscillates in ravine-like regions, makes slow progress in flat regions, and uses the same learning rate for every parameter.
SGD with Momentum
v_t = \mu v_{t-1} + \nabla_w \mathcal{L}(w_t)
w_{t+1} = w_t - \eta v_t
import torch
import torch.nn as nn

class SGDMomentum:
    """SGD with momentum from scratch."""

    def __init__(self, params, lr=0.01, momentum=0.9):
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum
        self.velocities = [torch.zeros_like(p) for p in self.params]

    def step(self):
        for param, velocity in zip(self.params, self.velocities):
            if param.grad is None:
                continue
            velocity.mul_(self.momentum).add_(param.grad)
            param.data.add_(velocity, alpha=-self.lr)

    def zero_grad(self):
        for param in self.params:
            if param.grad is not None:
                param.grad.zero_()
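A minimal usage sketch of this optimizer in a training step; model, loss_fn, inputs, and targets are placeholders assumed to be defined elsewhere:

# Hypothetical training step; model, loss_fn, inputs, targets are placeholders.
optimizer = SGDMomentum(model.parameters(), lr=0.01, momentum=0.9)

optimizer.zero_grad()
loss = loss_fn(model(inputs), targets)
loss.backward()
optimizer.step()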
Adaptive Learning Rates
RMSprop
Adapt the learning rate per parameter based on a running average of squared gradients:
v_t = \beta v_{t-1} + (1 - \beta) g_t^2
w_{t+1} = w_t - \frac{\eta}{\sqrt{v_t + \epsilon}} g_t
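A minimal from-scratch sketch of this update, in the same style as the momentum optimizer above; the hyperparameter defaults are illustrative:

class RMSprop:
    """RMSprop from scratch: per-parameter adaptive step sizes (sketch)."""

    def __init__(self, params, lr=1e-3, beta=0.99, eps=1e-8):
        self.params = list(params)
        self.lr = lr
        self.beta = beta
        self.eps = eps
        self.v = [torch.zeros_like(p) for p in self.params]  # Running average of squared gradients

    def step(self):
        for i, param in enumerate(self.params):
            if param.grad is None:
                continue
            g = param.grad
            # Exponential moving average of squared gradients
            self.v[i] = self.beta * self.v[i] + (1 - self.beta) * g ** 2
            # Scale the step inversely with the recent gradient magnitude
            param.data -= self.lr * g / torch.sqrt(self.v[i] + self.eps)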
Adam (Adaptive Moment Estimation)
Combines momentum and adaptive learning rates:
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \quad \text{(first moment)}
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \quad \text{(second moment)}
\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \quad \text{(bias correction)}
w_{t+1} = w_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t
class Adam:
    """Adam optimizer from scratch."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.t = 0
        self.m = [torch.zeros_like(p) for p in self.params]  # First moment
        self.v = [torch.zeros_like(p) for p in self.params]  # Second moment

    def step(self):
        self.t += 1
        for i, param in enumerate(self.params):
            if param.grad is None:
                continue
            g = param.grad
            # Update biased first moment estimate
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            # Update biased second raw moment estimate
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g ** 2
            # Compute bias-corrected estimates
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            # Update parameters
            param.data -= self.lr * m_hat / (torch.sqrt(v_hat) + self.eps)
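One way to sanity-check the from-scratch version is to compare a single step against torch.optim.Adam on identically initialized parameters; a hypothetical check using a small linear layer, where the two should agree up to floating-point error:

# Hypothetical sanity check: both optimizers should take (nearly) identical steps.
model_a = nn.Linear(4, 2)
model_b = nn.Linear(4, 2)
model_b.load_state_dict(model_a.state_dict())

opt_a = Adam(model_a.parameters(), lr=1e-3)
opt_b = torch.optim.Adam(model_b.parameters(), lr=1e-3)

x = torch.randn(8, 4)
for net, opt in [(model_a, opt_a), (model_b, opt_b)]:
    loss = net(x).pow(2).mean()
    loss.backward()
    opt.step()

print(torch.allclose(model_a.weight, model_b.weight, atol=1e-6))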
AdamW (Decoupled Weight Decay)
Standard Adam couples L2 regularization with the adaptive step size, so parameters with large gradient histories are effectively regularized less.
AdamW decouples weight decay from the gradient-based update:
# AdamW update rule
param.data -= self.lr * m_hat / (torch.sqrt(v_hat) + self.eps)
param.data -= self.lr * weight_decay * param.data  # Separate, decoupled weight decay
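Concretely, the Adam class above could be extended into AdamW by adding a weight_decay argument and applying the decay outside the adaptive update; a minimal sketch (the decay-before-update ordering matches PyTorch's built-in AdamW):

class AdamW(Adam):
    """AdamW sketch: Adam with decoupled weight decay."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
        super().__init__(params, lr=lr, betas=betas, eps=eps)
        self.weight_decay = weight_decay

    def step(self):
        # Decoupled decay: shrink weights directly, independent of m_hat / v_hat
        for param in self.params:
            if param.grad is not None:
                param.data.mul_(1 - self.lr * self.weight_decay)
        super().step()  # Then apply the standard Adam update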
Prefer AdamW over Adam whenever you use weight decay, which is the norm when training modern deep networks.
Learning Rate Schedules
Step Decay
from torch.optim.lr_scheduler import StepLR

scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
# Multiply the LR by 0.1 every 30 epochs
Cosine Annealing
from torch.optim.lr_scheduler import CosineAnnealingLR

scheduler = CosineAnnealingLR(optimizer, T_max=100)
# Cosine decay over 100 epochs
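Both of these schedulers are stepped once per epoch, after the optimizer has run through the epoch's batches; a minimal sketch, where train_one_epoch is a hypothetical helper for the usual batch loop:

# Hypothetical epoch loop; train_one_epoch is a placeholder for the batch loop.
for epoch in range(100):
    train_one_epoch(model, optimizer)
    scheduler.step()  # Update the learning rate once per epoch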
import math

class WarmupCosineScheduler:
    """Warmup followed by cosine decay."""

    def __init__(self, optimizer, warmup_steps, total_steps, min_lr=0):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        self.base_lr = optimizer.param_groups[0]['lr']
        self.step_count = 0

    def step(self):
        self.step_count += 1
        if self.step_count < self.warmup_steps:
            # Linear warmup
            lr = self.base_lr * self.step_count / self.warmup_steps
        else:
            # Cosine decay
            progress = (self.step_count - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            lr = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (1 + math.cos(math.pi * progress))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
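Unlike the epoch-based schedulers above, this one is stepped after every optimizer update; a minimal sketch, assuming dataloader, loss_fn, and model are defined elsewhere:

# Hypothetical step-based loop; dataloader, loss_fn, and model are placeholders.
scheduler = WarmupCosineScheduler(optimizer, warmup_steps=1000, total_steps=50000)

for inputs, targets in dataloader:
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()
    scheduler.step()  # Update the LR every step, not every epoch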
Optimizer Comparison
Optimizer | Learning Rate | Weight Decay | Best For
--- | --- | --- | ---
SGD + Momentum | 0.1 - 0.01 | 1e-4 | CNNs, well-tuned
Adam | 1e-3 - 1e-4 | 0 | Quick prototyping
AdamW | 1e-4 - 5e-5 | 0.01 - 0.1 | Transformers, general
LAMB | 1e-3 - 1e-2 | 0.01 | Large batch training
Modern Optimizers
Lion (Evolved Sign Momentum)
# Simplified Lion update
sign_momentum = torch.sign(beta1 * m + (1 - beta1) * g)
param.data -= lr * (sign_momentum + weight_decay * param.data)
m = beta2 * m + (1 - beta2) * g
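A minimal from-scratch sketch of the full update, following the same pattern as the optimizers above; the hyperparameter defaults are illustrative, not tuned:

class Lion:
    """Lion sketch: a single momentum buffer and sign-based updates."""

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.weight_decay = weight_decay
        self.m = [torch.zeros_like(p) for p in self.params]

    def step(self):
        for i, param in enumerate(self.params):
            if param.grad is None:
                continue
            g = param.grad
            # Decoupled weight decay, as in AdamW
            if self.weight_decay != 0:
                param.data.mul_(1 - self.lr * self.weight_decay)
            # The sign of the interpolated momentum/gradient drives the step
            update = torch.sign(self.beta1 * self.m[i] + (1 - self.beta1) * g)
            param.data.add_(update, alpha=-self.lr)
            # The momentum buffer is updated with a different beta, after the step
            self.m[i] = self.beta2 * self.m[i] + (1 - self.beta2) * g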
Sharpness-Aware Minimization (SAM)
Seeks flat minima by perturbing weights:
class SAM:
    """Sharpness-Aware Minimization."""

    def __init__(self, optimizer, rho=0.05):
        self.optimizer = optimizer
        self.rho = rho

    def first_step(self):
        # Compute perturbation direction from the gradient norm across parameters
        grad_norm = torch.sqrt(sum(
            p.grad.norm() ** 2 for p in self.optimizer.param_groups[0]['params']
            if p.grad is not None
        ))
        with torch.no_grad():
            for p in self.optimizer.param_groups[0]['params']:
                if p.grad is None:
                    continue
                e_w = self.rho * p.grad / (grad_norm + 1e-12)
                p.add_(e_w)  # Perturb weights toward higher loss
                p.e_w = e_w

    def second_step(self):
        # Remove the perturbation, then update the original weights
        with torch.no_grad():
            for p in self.optimizer.param_groups[0]['params']:
                if hasattr(p, 'e_w'):
                    p.sub_(p.e_w)
        self.optimizer.step()
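SAM needs two forward/backward passes per batch: one to find the perturbation and one to compute the gradient at the perturbed point; a minimal sketch, where model, loss_fn, inputs, and targets are placeholders:

# Hypothetical SAM training step; model, loss_fn, inputs, targets are placeholders.
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
sam = SAM(base_optimizer, rho=0.05)

# First pass: gradients at the current weights define the perturbation
loss_fn(model(inputs), targets).backward()
sam.first_step()

# Second pass: gradients at the perturbed weights drive the actual update
base_optimizer.zero_grad()
loss_fn(model(inputs), targets).backward()
sam.second_step()
base_optimizer.zero_grad()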
Practical Recipes
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.05,
    betas=(0.9, 0.999)
)
scheduler = CosineAnnealingLR(optimizer, T_max=300)
Large Language Models
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.1,
    betas=(0.9, 0.95)
)
# Warmup for the first 2000 steps
scheduler = WarmupCosineScheduler(optimizer, warmup_steps=2000, total_steps=100000)
Exercises
Exercise 1: Optimizer Shootout
Train a classifier on CIFAR-10 with SGD, Adam, and AdamW. Compare convergence speed and final accuracy.
Exercise 2: LR Schedule Impact
Compare constant LR vs step decay vs cosine annealing on the same model.
Exercise 3: Warmup Analysis
Train a transformer with and without warmup. Observe gradient norms and loss curves.
What’s Next