Training a neural network means searching for a good set of weights in a landscape with millions of dimensions. Imagine you are blindfolded on a mountain range, and you can only feel the slope directly beneath your feet. Your goal is to find the lowest valley, but you cannot see the terrain ahead, and there are countless valleys, ridges, and plateaus in every direction.

Training neural networks = finding good minima in a non-convex loss landscape.

Challenges:
Saddle points (more common than local minima in high dimensions — in a 100-million parameter model, a true local minimum requires the loss to curve upward in all 100 million directions simultaneously, which is astronomically unlikely)
Flat regions with vanishing gradients (you are on a plateau and cannot feel which direction is downhill)
Sharp vs flat minima (sharp minima generalize poorly because tiny weight perturbations send you uphill; flat minima are robust)
Ill-conditioned Hessians (the loss surface curves steeply in some directions and gently in others, making a single learning rate suboptimal)
A useful mental model: In high-dimensional spaces, most critical points are saddle points, not local minima. SGD with momentum tends to escape saddle points because the accumulated velocity carries you past the “saddle” even when the gradient is momentarily close to zero, and minibatch noise nudges you off the unstable direction. This is one reason why simple optimizers work surprisingly well in practice.
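The effect of momentum near a saddle can be sketched on a toy problem. The function f(x, y) = x**2 - y**2, the starting point, and the escape threshold below are all illustrative assumptions, and the example is deterministic (no minibatch noise), so treat it as a rough picture rather than a statement about real networks.

```python
def steps_to_escape(lr=0.01, momentum=0.0, max_steps=10_000):
    """Count updates until the iterate has clearly left the saddle of f(x, y) = x**2 - y**2."""
    x, y = 1.0, 1e-6      # start almost exactly on the ridge y = 0
    vx, vy = 0.0, 0.0     # velocity buffers; with momentum=0 this is plain gradient descent
    for step in range(1, max_steps + 1):
        gx, gy = 2.0 * x, -2.0 * y   # analytic gradient of f
        vx = momentum * vx + gx      # v <- momentum * v + grad
        vy = momentum * vy + gy
        x -= lr * vx                 # w <- w - lr * v
        y -= lr * vy
        if abs(y) > 1.0:             # far from the saddle along the descent direction
            return step
    return max_steps


plain_steps = steps_to_escape(momentum=0.0)
momentum_steps = steps_to_escape(momentum=0.9)
print(f"plain GD: {plain_steps} steps, momentum 0.9: {momentum_steps} steps")
```

Both runs start with the same tiny gradient along y. The momentum run leaves the saddle in fewer steps because each update reuses the velocity built up by the previous ones.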
Problems with plain SGD: it oscillates in steep-walled ravines (zigzagging back and forth instead of heading down the valley), is painfully slow in flat regions (tiny gradients = tiny steps), and uses the same learning rate for all parameters (even though some parameters need big updates while others need small ones).
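The ravine problem shows up even on a two-parameter quadratic. The sketch below assumes a toy loss f(x, y) = 0.5 * (100 * x**2 + y**2) and a hand-picked learning rate; neither comes from a real model. It only illustrates how one learning rate can be too large for the steep direction and too small for the gentle one at the same time.

```python
def plain_gd_path(lr=0.019, steps=50):
    """Plain gradient descent on f(x, y) = 0.5 * (100 * x**2 + y**2)."""
    x, y = 1.0, 1.0
    path = [(x, y)]
    for _ in range(steps):
        gx, gy = 100.0 * x, y    # steep in x, gentle in y
        x -= lr * gx             # x is multiplied by (1 - 1.9) = -0.9: overshoots every step
        y -= lr * gy             # y is multiplied by (1 - 0.019) = 0.981: barely moves
        path.append((x, y))
    return path


path = plain_gd_path()
x_sign_flips = sum(1 for (x0, _), (x1, _) in zip(path, path[1:]) if x0 * x1 < 0)
print(f"x changed sign on {x_sign_flips} of {len(path) - 1} steps")
print(f"y after {len(path) - 1} steps: {path[-1][1]:.3f}")
```

The x coordinate jumps from one wall of the ravine to the other on every step, while y has covered only part of the distance to the minimum. A smaller learning rate would calm the zigzag but slow y down even further.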
```python
import torch
import torch.nn as nn


class SGDMomentum:
    """SGD with momentum from scratch."""

    def __init__(self, params, lr=0.01, momentum=0.9):
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum
        self.velocities = [torch.zeros_like(p) for p in self.params]

    def step(self):
        for param, velocity in zip(self.params, self.velocities):
            if param.grad is None:
                continue
            # v <- momentum * v + grad
            velocity.mul_(self.momentum).add_(param.grad)
            # w <- w - lr * v
            param.data.add_(velocity, alpha=-self.lr)

    def zero_grad(self):
        for param in self.params:
            if param.grad is not None:
                param.grad.zero_()
```
Standard Adam handles L2 regularization poorly because of its adaptive learning rates. The problem: Adam divides each gradient by the square root of its second-moment estimate (essentially adapting the learning rate per parameter). When you add L2 regularization to the loss, the regularization gradient is divided by the same adaptive term, so parameters with large or frequent gradients are regularized less than parameters with small or rare ones. Weight decay is supposed to shrink every weight at the same relative rate, so this is not what you want.

AdamW decouples weight decay from the adaptive update, applying it directly to the weights:
```python
# AdamW update rule -- two separate steps
# Step 1: Normal Adam update (adaptive gradient descent)
param.data -= self.lr * m_hat / (torch.sqrt(v_hat) + self.eps)
# Step 2: Weight decay applied INDEPENDENTLY of gradient history
param.data -= self.lr * weight_decay * param.data  # Separate!
```
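To see why the placement of the decay term matters, here is a small scalar sketch. It looks only at the first optimizer step, where Adam's bias-corrected moments reduce to the gradient and its square; the function names and the numbers are illustrative assumptions, not part of any library.

```python
def adam_l2_decay_term(grad, weight, lr=1e-3, wd=0.1, eps=1e-8):
    """Shrinkage that the L2 penalty contributes to the first Adam step."""
    g_total = grad + wd * weight   # L2 regularization adds wd * weight to the gradient
    # On the first step the bias-corrected moments are m_hat = g_total and
    # v_hat = g_total**2, so the whole update is lr * g_total / (|g_total| + eps).
    return lr * (wd * weight) / (abs(g_total) + eps)


def adamw_decay_term(weight, lr=1e-3, wd=0.1):
    """Shrinkage from decoupled weight decay: independent of the gradient."""
    return lr * wd * weight


weight = 1.0
for grad in (10.0, 0.01):
    print(f"grad={grad:>5}: Adam+L2 decay={adam_l2_decay_term(grad, weight):.2e}, "
          f"AdamW decay={adamw_decay_term(weight):.2e}")
```

With L2 folded into the gradient, the weight that sees the large gradient is shrunk far less than the one that sees the small gradient. With decoupled decay, both weights shrink by the same amount.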
Prefer AdamW over Adam when training modern deep networks with weight decay. Loshchilov and Hutter (2019) showed that decoupling weight decay from the adaptive update improved Adam's generalization in their experiments. If you see code passing a weight_decay argument to plain Adam, it is most likely using the coupled L2 formulation described above; switch to AdamW.
Pitfall — learning rate and weight decay interaction: When using AdamW, the effective regularization strength is lr * weight_decay. If you double the learning rate, you should halve the weight decay to maintain the same effective regularization. This coupling catches many people off guard when tuning hyperparameters. Some frameworks (like timm) use an absolute weight decay convention to avoid this confusion.
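The coupling is easy to check with arithmetic. The sketch below assumes the update shown earlier, where each step multiplies a weight by (1 - lr * weight_decay), and ignores the gradient term entirely so that only the decay is visible. The specific learning rates and step count are arbitrary.

```python
def remaining_fraction(lr, weight_decay, steps):
    """Fraction of a weight left after `steps` decay-only updates w <- w - lr * wd * w."""
    return (1.0 - lr * weight_decay) ** steps


baseline = remaining_fraction(lr=1e-3, weight_decay=0.1, steps=10_000)
doubled_lr = remaining_fraction(lr=2e-3, weight_decay=0.1, steps=10_000)
compensated = remaining_fraction(lr=2e-3, weight_decay=0.05, steps=10_000)

print(f"baseline    (lr=1e-3, wd=0.10): {baseline:.3f}")
print(f"doubled lr  (lr=2e-3, wd=0.10): {doubled_lr:.3f}")
print(f"compensated (lr=2e-3, wd=0.05): {compensated:.3f}")
```

Doubling the learning rate alone makes the weights decay noticeably faster. Halving weight_decay at the same time restores the baseline, because the product lr * weight_decay is unchanged.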
Linear warmup followed by cosine decay is the de facto standard schedule for Transformers and modern deep learning. The intuition: at the beginning of training, the model’s weights are random, so gradients are noisy and unreliable. Starting with a high learning rate would send the model careening in random directions. Warmup starts with a tiny learning rate and linearly increases it, giving the optimizer time to accumulate reliable gradient statistics (the first and second moments in Adam) before taking large steps. After warmup, cosine decay gradually reduces the learning rate, which acts like a fine-tuning phase: the large early steps explore broadly, and the small late steps refine the solution.
```python
import math


class WarmupCosineScheduler:
    """Warmup followed by cosine decay -- the standard for Transformers.

    Phase 1 (warmup): LR increases linearly from 0 to base_lr
    Phase 2 (decay):  LR follows a cosine curve from base_lr to min_lr
    """

    def __init__(self, optimizer, warmup_steps, total_steps, min_lr=0):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        self.base_lr = optimizer.param_groups[0]['lr']
        self.step_count = 0

    def step(self):
        # Call this before optimizer.step() so the first update uses the warmup LR
        self.step_count += 1
        if self.step_count < self.warmup_steps:
            # Linear warmup: ramp from 0 to base_lr
            lr = self.base_lr * self.step_count / self.warmup_steps
        else:
            # Cosine decay: smoothly decrease from base_lr to min_lr
            decay_steps = max(1, self.total_steps - self.warmup_steps)
            # Clamp so the LR stays at min_lr if training runs past total_steps
            progress = min(1.0, (self.step_count - self.warmup_steps) / decay_steps)
            lr = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (1 + math.cos(math.pi * progress))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
```
How long should warmup be? A common rule of thumb: 1-5% of total training steps for large models (LLM training runs often use around 2,000 warmup steps), or 5-10 epochs for vision tasks. With too little warmup you may get early training instability (loss spikes); with too much, you waste compute on a suboptimally low learning rate. If you see loss spikes in the first few hundred steps, try increasing warmup.
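To get a feel for the numbers, the schedule can be written as a pure function of the step and evaluated at a few points. The function name, the peak learning rate of 3e-4, and the 100,000-step budget are hypothetical values chosen for illustration.

```python
import math


def warmup_cosine_lr(step, base_lr, warmup_steps, total_steps, min_lr=0.0):
    """Learning rate at `step` for linear warmup followed by cosine decay."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


total_steps = 100_000
warmup_steps = total_steps * 2 // 100    # 2% of training, inside the 1-5% rule of thumb
base_lr = 3e-4                           # hypothetical peak learning rate

for step in (1, warmup_steps // 2, warmup_steps, total_steps // 2, total_steps):
    lr = warmup_cosine_lr(step, base_lr, warmup_steps, total_steps)
    print(f"step {step:>6}: lr = {lr:.2e}")
```

The learning rate reaches its peak exactly when warmup ends and falls to min_lr at the final step. Printing a handful of values like this before launching a long run is a cheap way to catch a mis-specified warmup_steps or total_steps.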
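```python
import torch


class SAM:
    """Sharpness-Aware Minimization."""

    def __init__(self, optimizer, rho=0.05):
        self.optimizer = optimizer
        self.rho = rho

    def _params(self):
        # Every parameter the wrapped optimizer manages, across all param groups
        return [p for group in self.optimizer.param_groups for p in group['params']]

    @torch.no_grad()  # in-place edits of leaf parameters must not be tracked by autograd
    def first_step(self):
        # Compute perturbation direction from the global L2 norm of all gradients
        grad_norm = torch.sqrt(sum(
            p.grad.norm() ** 2
            for p in self._params()
            if p.grad is not None
        ))
        for p in self._params():
            if p.grad is None:
                continue
            e_w = self.rho * p.grad / (grad_norm + 1e-12)
            p.add_(e_w)  # Perturb weights toward higher loss
            p.e_w = e_w  # Remember the perturbation so it can be undone

    @torch.no_grad()
    def second_step(self):
        # Remove perturbation, then update with the gradient from the perturbed point
        for p in self._params():
            if hasattr(p, 'e_w'):
                p.sub_(p.e_w)
                del p.e_w
        self.optimizer.step()
```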
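The two-step structure is easier to follow without any framework around it. The sketch below applies the same idea to a toy loss f(w) = sum(w_i ** 2) with an analytic gradient; the loss, learning rate, and starting point are illustrative assumptions. With the SAM class above, the matching training loop would run a forward and backward pass, call first_step(), zero the gradients, run a second forward and backward pass at the perturbed weights, and then call second_step().

```python
import math


def grad(w):
    """Gradient of the toy loss f(w) = sum(w_i ** 2)."""
    return [2.0 * wi for wi in w]


def sam_step(w, lr=0.1, rho=0.05):
    # First pass: the gradient at the current weights gives the ascent direction
    g = grad(w)
    g_norm = math.sqrt(sum(gi * gi for gi in g))
    e_w = [rho * gi / (g_norm + 1e-12) for gi in g]
    # Perturb weights toward higher loss, at distance rho from the current point
    w_adv = [wi + ei for wi, ei in zip(w, e_w)]
    # Second pass: gradient at the perturbed point
    g_adv = grad(w_adv)
    # Apply that gradient to the ORIGINAL weights (the perturbation is discarded)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]


w = [3.0, -4.0]
for _ in range(100):
    w = sam_step(w)
loss = sum(wi * wi for wi in w)
print(f"loss after 100 SAM steps: {loss:.2e}")
```

Each SAM update needs two gradient evaluations, so it costs roughly twice as much per step as the base optimizer. On this toy loss the iterates settle close to the minimum but keep hovering at a distance related to rho, since the perturbed gradient never vanishes exactly.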