Here is the central tension of modern deep learning: the best models are enormous, but GPU memory and compute are finite. A 70B-parameter model in FP32 (4 bytes per parameter) requires 280 GB just for the weights. That is 3.5x the memory of an 80 GB A100. And that is only the parameters. You also need memory for activations, gradients, and optimizer states. Adam, for example, keeps two extra tensors (the first- and second-moment estimates), each the same size as the parameters.
Modern models are huge:
GPT-3: 175B parameters (~700 GB in FP32, impossible on any single GPU)
LLaMA 70B: 70B parameters (~280 GB in FP32)
Vision models: trained on billions of images, requiring weeks of compute
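The figures above are simple arithmetic: bytes per value times parameter count. Below is a minimal sketch of that estimate for plain FP32 training with Adam. It assumes 4 bytes each for weights, gradients, and Adam's two moment tensors, and it ignores activations entirely. The function name is illustrative only.

```python
def training_memory_gb(num_params: int, bytes_per_value: int = 4) -> dict:
    """Rough FP32 + Adam training memory in GB (1 GB = 1e9 bytes).

    Counts weights, gradients, and Adam's two moment tensors only;
    activations, temporary buffers, and framework overhead are ignored.
    """
    weights = num_params * bytes_per_value
    gradients = num_params * bytes_per_value
    adam_states = 2 * num_params * bytes_per_value  # first and second moments
    total = weights + gradients + adam_states
    return {
        "weights": weights / 1e9,
        "gradients": gradients / 1e9,
        "adam_states": adam_states / 1e9,
        "total": total / 1e9,
    }


for name, n in [("GPT-3", 175_000_000_000), ("LLaMA 70B", 70_000_000_000)]:
    est = training_memory_gb(n)
    print(f"{name}: weights {est['weights']:.0f} GB, total {est['total']:.0f} GB")
```

For LLaMA 70B this gives 280 GB of weights but 1120 GB in total. That works out to 16 bytes per parameter before a single activation is stored.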
How do we train these efficiently? The answer is a toolkit of techniques — mixed precision, gradient accumulation, checkpointing, and various parallelism strategies — that trade off between memory, speed, and implementation complexity. A senior ML engineer needs to know when to reach for each tool.
Think of it like packing for a trip. A single suitcase (one GPU) can only hold so much. Mixed precision is like using vacuum bags to compress your clothes (same content, half the space). Gradient accumulation is like making multiple trips to the car. Model parallelism is like using multiple suitcases. FSDP/ZeRO is like having each family member carry a different part of a modular tent — nobody carries the full thing, but you can assemble it when needed.
Mixed precision training is the single highest-impact optimization for most practitioners. The idea: most operations in a neural network do not need 32-bit floating-point precision. Matrix multiplications, convolutions, and attention computations work fine in 16-bit, which takes half the memory. On GPUs with tensor cores, 16-bit math also runs much faster. Tensor-core 16-bit throughput is many times the plain FP32 rate, although end-to-end speedups are smaller because not every operation is a matrix multiply.
The “mixed” part means you keep a master copy of the weights in FP32 for the optimizer update, where precision matters for accumulating tiny weight updates. Most of the forward and backward passes still run in FP16/BF16.
```python
import torch
from torch.amp import autocast, GradScaler  # PyTorch 2.3+; older releases use torch.cuda.amp

# Model, criterion, and dataloader are assumed to be defined elsewhere.
model = Model().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# GradScaler prevents FP16 underflow by scaling the loss up
# before backward, then scaling gradients back down before the step.
# Only needed for FP16; BF16 has enough dynamic range to skip this.
scaler = GradScaler("cuda")

for batch in dataloader:
    x, y = batch
    x, y = x.cuda(), y.cuda()
    optimizer.zero_grad()

    # Forward pass in mixed precision -- PyTorch automatically
    # casts operations to FP16 where safe and keeps FP32 where needed
    # (e.g., softmax, layer norm, many loss functions)
    with autocast("cuda", dtype=torch.float16):
        output = model(x)
        loss = criterion(output, y)

    # Backward pass with loss scaling to prevent gradient underflow
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # Unscales gradients; skips the step if they contain inf/NaN
    scaler.update()         # Adjusts scale factor for next iteration
```
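Why does FP16 need a GradScaler while BF16 does not? The sketch below uses only the standard library to round a small gradient value to each format. The struct "e" format code packs IEEE half precision (FP16). BF16 is emulated here by rounding a float32 to its upper 16 bits, which is an assumption of this sketch rather than how any GPU library does it. The scale factor 2**16 is illustrative: it matches GradScaler's default starting scale, but the real scaler adjusts it dynamically.

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to IEEE 754 half precision (FP16) and back to a Python float."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Emulate bfloat16: keep the upper 16 bits of the float32, round-to-nearest-even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


grad = 1e-8      # tiny gradient, below FP16's smallest subnormal (~6e-8)
scale = 2.0**16  # illustrative loss-scale factor

print("FP16, unscaled:", to_fp16(grad))                  # underflows to 0.0
print("FP16, scaled:  ", to_fp16(grad * scale) / scale)  # ~1e-8 survives
print("BF16, unscaled:", to_bf16(grad))                  # ~1e-8, no scaling needed

# Too large a scale overflows FP16 (max 65504). On a GPU this yields inf
# (struct raises OverflowError instead), which is why GradScaler lowers
# the scale whenever it finds inf/NaN gradients.
try:
    to_fp16(grad * 1e13)
except OverflowError:
    print("FP16 overflow")
```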
When you cannot fit the batch size you want into GPU memory, gradient accumulation lets you simulate it. You split the batch across several forward-backward passes and accumulate the gradients before taking a single optimizer step. Mathematically, the result matches running the full batch: the gradients are the same, just computed in pieces. This holds up to floating-point rounding, provided the loss is averaged correctly and no layer depends on batch statistics.
Think of it like filling a swimming pool with a garden hose instead of a fire hose. It takes longer per “batch” but needs far less water pressure (memory) at any given moment.
```python
batch_size = 64          # mini-batch that fits in memory
accumulation_steps = 4   # 4 mini-batches of 64 = effective batch of 256
effective_batch_size = batch_size * accumulation_steps

optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    x, y = batch
    x, y = x.cuda(), y.cuda()

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        output = model(x)
        # CRITICAL: divide loss by accumulation_steps so the accumulated
        # gradient is the same as if we had run the full batch at once.
        # Forgetting this multiplies the gradient by accumulation_steps,
        # which for SGD silently scales the effective learning rate.
        loss = criterion(output, y) / accumulation_steps

    scaler.scale(loss).backward()  # Gradients accumulate in .grad

    # Only step the optimizer every accumulation_steps mini-batches.
    # If len(dataloader) is not a multiple of accumulation_steps, the
    # leftover gradients carry over; use drop_last=True or step at epoch end.
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()  # Reset gradients for next accumulation cycle
```
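Both claims can be checked on a toy problem without a GPU: that accumulation reproduces the full-batch gradient, and that the division matters. This sketch uses a hypothetical one-parameter linear model with a mean-squared-error loss and hand-written gradients. It compares the full-batch gradient with the sum of mini-batch gradients, once with and once without dividing by the number of accumulation steps.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)**2) with respect to w over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [float(i) for i in range(16)]
ys = [3.0 * x + 1.0 for x in xs]
w = 0.5
accumulation_steps = 4
micro = len(xs) // accumulation_steps  # equal-sized micro-batches

full_batch = mse_grad(w, xs, ys)

accumulated = 0.0         # loss divided by accumulation_steps (correct)
accumulated_no_div = 0.0  # division forgotten (the bug)
for k in range(accumulation_steps):
    xb = xs[k * micro:(k + 1) * micro]
    yb = ys[k * micro:(k + 1) * micro]
    g = mse_grad(w, xb, yb)
    accumulated += g / accumulation_steps
    accumulated_no_div += g

print("full batch:", full_batch)
print("accumulated:", accumulated)
print("without division, ratio:", accumulated_no_div / full_batch)
```

The equality relies on equal-sized micro-batches and a mean-reduced loss. With a smaller final micro-batch, the average of per-batch means no longer equals the full-batch mean.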
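Pitfall (BatchNorm with gradient accumulation): BatchNorm computes statistics per mini-batch, not per accumulated batch. With accumulation_steps=4, BatchNorm sees a batch of 64 even though your effective batch is 256. Its statistics are therefore noisier than the effective batch size suggests. SyncBatchNorm does not solve this on its own. It pools statistics across GPUs within one step, not across sequential accumulation steps, so it helps only if you can spread the larger batch over more devices. The simpler fix is to switch to GroupNorm or LayerNorm, which compute statistics per sample and are unaffected by batch size.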
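If you take the GroupNorm route, an existing CNN can often be converted in place. This is a minimal sketch that assumes the model uses nn.BatchNorm2d layers. The helper name is hypothetical, and the cap of 32 groups is an illustrative choice (a common GroupNorm setting). The group count must divide the channel count, hence the gcd. The new layers start from fresh affine parameters, so this suits training from scratch or fine-tuning, not reusing trained BatchNorm statistics.

```python
import math

import torch
import torch.nn as nn


def batchnorm_to_groupnorm(module: nn.Module, max_groups: int = 32) -> nn.Module:
    """Recursively replace every nn.BatchNorm2d in `module` with nn.GroupNorm.

    GroupNorm normalizes over channel groups within each sample, so its
    output does not depend on the mini-batch size.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            channels = child.num_features
            groups = math.gcd(max_groups, channels)  # must divide num_channels
            setattr(module, name, nn.GroupNorm(groups, channels))
        else:
            batchnorm_to_groupnorm(child, max_groups)
    return module


model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Conv2d(64, 48, kernel_size=3, padding=1),
    nn.BatchNorm2d(48),
)
model = batchnorm_to_groupnorm(model)
out = model(torch.randn(2, 3, 8, 8))  # works for any batch size, even 1
print(model)
```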
During the forward pass, PyTorch saves the intermediate activations that backpropagation will need. For a 100-layer Transformer, that means keeping several activation tensors for each of the 100 layers, which can dwarf the model parameters in memory usage. Gradient checkpointing (also called activation checkpointing) trades compute for memory. Instead of saving those activations, it keeps only the inputs at checkpoint boundaries and discards everything in between during the forward pass. It then recomputes the missing activations on the fly during backward.
The analogy: instead of photocopying every page of a book as you read it (so you can refer back later), you just remember the page numbers and re-read the pages when you need them. Slower, but requires almost no shelf space.
```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):
    """Wraps a block with gradient checkpointing.

    During forward: runs the block but keeps only its input, not the
    intermediate activations inside it.
    During backward: re-runs the block's forward pass to recompute those
    activations on the fly.
    """
    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        # use_reentrant=False is the variant PyTorch recommends: unlike the
        # reentrant one, it works with torch.autograd.grad and with inputs
        # that do not require gradients.
        return checkpoint(self.block, x, use_reentrant=False)


# Apply to model (assumed to expose its blocks as model.layers) --
# typically checkpoint every Transformer block
model.layers = nn.ModuleList([
    CheckpointedBlock(layer) for layer in model.layers
])
```
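Checkpointing changes what is stored, not what is computed, so gradients should match a normal run. The following is a small CPU-only sanity check. It uses a hypothetical two-layer MLP standing in for a Transformer block and compares parameter and input gradients with and without checkpoint.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
x = torch.randn(4, 16, requires_grad=True)

# Normal forward/backward: all intermediate activations are stored.
block(x).sum().backward()
grads_plain = [p.grad.clone() for p in block.parameters()]
x_grad_plain = x.grad.clone()

# Reset gradients, then run the same block with checkpointing.
block.zero_grad()
x.grad = None
checkpoint(block, x, use_reentrant=False).sum().backward()
grads_ckpt = [p.grad.clone() for p in block.parameters()]

same = all(torch.allclose(a, b) for a, b in zip(grads_plain, grads_ckpt))
print("parameter grads match:", same)
print("input grads match:", torch.allclose(x_grad_plain, x.grad))
```

With dropout inside the block, the match still holds because checkpoint restores the RNG state during recomputation by default (preserve_rng_state=True).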
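Memory savings: activation memory typically drops by roughly 2-3x, at the cost of roughly 30% slower training. The slowdown follows from the arithmetic: the backward pass costs about twice the forward, so one extra forward pass per block adds about a third to each step.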
Selective checkpointing: You do not have to checkpoint every layer. Checkpointing only a subset of blocks (for example every 2nd or 3rd) keeps part of the memory savings at a proportionally smaller speed penalty. That is often the better trade when you are only slightly over your memory budget. Profile your specific model to find the sweet spot. The blocks with the largest activations benefit most. In Transformers these are often the attention layers, whose score tensors grow quadratically with sequence length.
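A deliberately simplified cost model helps you reason about this trade-off before profiling. The sketch below rests on several assumptions, none of them measured:

- N identical blocks, each with activation size 1.0.
- A checkpointed block keeps only an input_fraction of that (an assumed parameter).
- One recomputed block is live at a time during backward.
- Backward costs twice the forward.

Real models deviate from all of these, so treat the output as a rough guide, not a prediction.

```python
def checkpoint_cost(num_blocks: int, every: int, input_fraction: float = 0.1):
    """Crude estimate for checkpointing every `every`-th block (0 = no checkpointing).

    Returns (peak activation memory, step time), both relative to no checkpointing.
    """
    checkpointed = 0 if every == 0 else num_blocks // every
    plain = num_blocks - checkpointed
    stored = plain * 1.0 + checkpointed * input_fraction
    # During backward, one checkpointed block is recomputed at full size.
    peak = stored + (1.0 if checkpointed else 0.0)
    memory = peak / num_blocks
    # Forward = 1 unit, backward = 2 units per block; each checkpointed
    # block pays one extra forward.
    step_time = (3 * num_blocks + checkpointed) / (3 * num_blocks)
    return memory, step_time


for every in (0, 3, 2, 1):
    mem, t = checkpoint_cost(num_blocks=24, every=every)
    print(f"every={every}: memory x{mem:.2f}, time x{t:.2f}")
```

Under these assumptions, checkpointing every block costs about a third more step time. Checkpointing every 2nd block roughly halves both the memory saving and the slowdown.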
A surprisingly common bottleneck: your GPU sits idle waiting for data. If your data loading is slower than your model’s forward-backward pass, you are data-bound — upgrading your GPU will not help at all. The fix is to overlap data loading with computation using multiple worker processes.
```python
from torch.utils.data import DataLoader

# dataset is assumed to be defined elsewhere.
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=8,            # Parallel data loading -- 4-8 workers per GPU is a common start.
                              # Too many workers wastes CPU/RAM; too few starves the GPU.
    pin_memory=True,          # Puts batches in page-locked (pinned) host memory, which speeds
                              # up CPU-to-GPU copies and lets them run asynchronously when
                              # you call x.cuda(non_blocking=True).
    prefetch_factor=2,        # Each worker prefetches 2 batches ahead
                              # (the default; only valid with num_workers > 0).
    persistent_workers=True,  # Keep worker processes alive between epochs --
                              # avoids the overhead of spawning new processes each epoch.
)
```
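To tell whether you are data-bound, time how long each iteration waits for the next batch versus how long the training step itself takes. Below is a minimal, framework-agnostic sketch; the helper name and the toy loader are hypothetical. With a real GPU step you would call torch.cuda.synchronize() before reading the clock, because CUDA kernels run asynchronously and the step would otherwise look nearly free.

```python
import time


def profile_loader(loader, train_step, max_steps=50):
    """Return (total data-wait seconds, total step seconds) over up to max_steps batches."""
    wait, compute = 0.0, 0.0
    it = iter(loader)
    for _ in range(max_steps):
        t0 = time.perf_counter()
        try:
            batch = next(it)
        except StopIteration:
            break
        t1 = time.perf_counter()
        train_step(batch)  # with CUDA, end the step with torch.cuda.synchronize()
        t2 = time.perf_counter()
        wait += t1 - t0
        compute += t2 - t1
    return wait, compute


def slow_loader(n=10):
    """Toy stand-in for a loader that needs ~20 ms per batch."""
    for i in range(n):
        time.sleep(0.02)
        yield i


wait, compute = profile_loader(slow_loader(), lambda batch: time.sleep(0.005), max_steps=10)
print(f"data wait {wait:.2f}s vs compute {compute:.2f}s")
if wait > compute:
    print("Data-bound: add workers, enable pin_memory, or simplify preprocessing.")
```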
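Pitfall (num_workers on Windows): On Windows (and on macOS since Python 3.8), multiprocessing starts workers with spawn instead of fork, which is significantly slower. If you see long pauses at epoch boundaries, set persistent_workers=True (as above) and consider reducing num_workers to 4. Spawned workers re-import your main script, so put the training code under an if __name__ == "__main__": guard. Without it, worker startup fails with a RuntimeError. Also make sure your dataset code is picklable: lambdas in transforms will crash the data loader with a cryptic pickling error.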
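The pickling failure is easy to reproduce without a DataLoader. Spawned workers receive the dataset, including its transforms, through pickle, and pickle stores functions by their importable name, which a lambda does not have. This sketch uses hypothetical transform names to show the failure and two picklable alternatives: a module-level function wrapped in functools.partial, and a small callable class. It assumes the code runs as a script, so module-level names are importable.

```python
import pickle
from functools import partial


def scale(x, factor):
    """Module-level function: pickled by reference to its importable name."""
    return x * factor


class Scale:
    """Callable class: pickled as a class reference plus instance attributes."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return x * self.factor


lambda_transform = lambda x: x * 2.0  # what breaks under spawn

try:
    pickle.dumps(lambda_transform)
    print("lambda pickled (unexpected)")
except (pickle.PicklingError, AttributeError, TypeError) as err:
    print("lambda failed:", type(err).__name__)

for transform in (partial(scale, factor=2.0), Scale(2.0)):
    restored = pickle.loads(pickle.dumps(transform))
    print(type(transform).__name__, restored(3.0))  # both give 6.0
```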