
Training Strategies at Scale

The Scale Challenge

Here is the central tension of modern deep learning: the best models are enormous, but GPU memory and compute are finite. A 70B-parameter model in FP32 requires 280 GB just for the weights (70B x 4 bytes), more than 3x the memory of an A100 80GB GPU. And that is only the parameters. You also need memory for activations, gradients, and optimizer states (Adam keeps two extra state tensors per parameter, the first- and second-moment estimates). Modern models are huge:
  • GPT-3: 175B parameters (~700 GB in FP32, impossible on any single GPU)
  • LLaMA 70B: 70B parameters (~280 GB in FP32)
  • Vision models: trained on billions of images, requiring weeks of compute
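The arithmetic behind these numbers is easy to script. The sketch below is a rough lower bound that assumes plain FP32 training with Adam: 4 bytes each for the weight, its gradient, and Adam's two moment buffers, so 16 bytes per parameter. It deliberately ignores activations and framework overhead, and `training_memory_gb` is a hypothetical helper, not a library API.

```python
def training_memory_gb(num_params: int, bytes_per_param: int = 4) -> dict:
    """Rough lower bound on FP32 + Adam training memory, in decimal GB.

    Ignores activations, temporary buffers, and framework overhead.
    """
    weights = num_params * bytes_per_param
    grads = num_params * bytes_per_param            # one gradient per parameter
    adam_states = 2 * num_params * bytes_per_param  # Adam's m and v buffers
    total = weights + grads + adam_states
    return {
        "weights": weights / 1e9,
        "gradients": grads / 1e9,
        "adam_states": adam_states / 1e9,
        "total": total / 1e9,
    }


for name, n in [("LLaMA 70B", 70_000_000_000), ("GPT-3", 175_000_000_000)]:
    mem = training_memory_gb(n)
    print(f"{name}: weights {mem['weights']:.0f} GB, total {mem['total']:.0f} GB")
# LLaMA 70B: weights 280 GB, total 1120 GB
# GPT-3: weights 700 GB, total 2800 GB
```

Even before counting activations, the 70B model needs over 1 TB of training state, which no single GPU can hold.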
How do we train these efficiently? The answer is a toolkit of techniques — mixed precision, gradient accumulation, checkpointing, and various parallelism strategies — that trade off between memory, speed, and implementation complexity. A senior ML engineer needs to know when to reach for each tool.
Think of it like packing for a trip. A single suitcase (one GPU) can only hold so much. Mixed precision is like using vacuum bags to compress your clothes (same content, half the space). Gradient accumulation is like making multiple trips to the car. Model parallelism is like using multiple suitcases. FSDP/ZeRO is like having each family member carry a different part of a modular tent — nobody carries the full thing, but you can assemble it when needed.

Mixed Precision Training (FP16/BF16)

The single highest-impact optimization for most practitioners. The idea: most operations in a neural network do not need 32-bit floating point precision. Matrix multiplications, convolutions, and attention computations work fine in 16-bit, which halves the memory per value and is usually much faster, because modern GPUs have dedicated 16-bit tensor cores with far higher throughput than their FP32 units. The “mixed” part means precision-sensitive work stays in FP32: the weights the optimizer updates are kept in FP32 (where precision matters for accumulating tiny updates), while the forward and backward passes run in FP16/BF16. With PyTorch's autocast, the parameters themselves stay in FP32 and are cast to 16-bit on the fly, op by op.
import torch

# Model, criterion, and dataloader are placeholders for your own code.
model = Model().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# GradScaler prevents FP16 gradient underflow by scaling the loss up
# before backward, then unscaling the gradients before the optimizer step.
# Only needed for FP16; BF16 has enough dynamic range to skip this.
# torch.amp.GradScaler("cuda") replaces the deprecated torch.cuda.amp.GradScaler
# in recent PyTorch releases.
scaler = torch.amp.GradScaler("cuda")

for batch in dataloader:
    x, y = batch
    x, y = x.cuda(), y.cuda()
    
    optimizer.zero_grad()
    
    # Forward pass in mixed precision: autocast runs eligible ops (matmuls,
    # convolutions) in FP16 and keeps precision-sensitive ops in FP32
    # (e.g., softmax, layer norm, loss computation)
    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
        output = model(x)
        loss = criterion(output, y)
    
    # Backward pass on the scaled loss to prevent gradient underflow
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # Unscales gradients; skips the step if they contain inf/NaN
    scaler.update()         # Adjusts the scale factor for the next iteration
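On BF16-capable hardware the loop gets simpler, because BF16's FP32-sized exponent range makes GradScaler unnecessary. The sketch below runs on CPU so it works anywhere; the toy two-layer model and random data are stand-ins for a real model and dataloader, and on an A100/H100 you would pass device_type="cuda" instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model and data (stand-ins for a real model and dataloader).
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
x = torch.randn(8, 16)
y = torch.randint(0, 4, (8,))

for step in range(3):
    optimizer.zero_grad()
    # BF16 autocast on CPU; use device_type="cuda" on a GPU.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        logits = model(x)            # linear layers run in BF16 under autocast
        loss = criterion(logits, y)
    loss.backward()                  # no GradScaler: BF16 has FP32's exponent range
    optimizer.step()                 # updates the FP32 parameters directly
    print(f"step {step}: loss={loss.item():.4f}")

print(logits.dtype)                    # torch.bfloat16
print(next(model.parameters()).dtype)  # torch.float32
```

The weights and their gradients stay in FP32; only the activations produced inside the autocast region are BF16.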

Precision Comparison

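| Type | Bits | Max magnitude | Use case |
|---|---|---|---|
| FP32 | 32 | ±3.4e38 | Accumulation, critical ops |
| FP16 | 16 | ±65504 | Most compute; may need loss scaling |
| BF16 | 16 | ±3.4e38 | Modern GPUs; no loss scaling needed |
| FP8 | 8 | ±448 (E4M3) / ±57344 (E5M2) | Inference; training on cutting-edge GPUs (e.g., H100) |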
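BF16 is preferred on modern GPUs (A100, H100): it has the same exponent range as FP32, so no loss scaling is needed. The trade-off is precision, since BF16 keeps 8 significant bits versus FP16's 11, which is usually fine for training.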
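The range difference is easy to see directly. This CPU-only sketch uses torch.finfo to print each format's limits, then shows a tiny gradient-sized value underflowing to zero in FP16 (the problem GradScaler exists to solve) and a large value overflowing to inf, while BF16 keeps both finite at reduced precision.

```python
import torch

# Numeric limits of each format (CPU only, no GPU needed).
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15} max={info.max:.3e}  eps={info.eps:.3e}  tiny={info.tiny:.3e}")

# A small gradient value like 1e-8 underflows to zero in FP16...
small = torch.tensor(1e-8)
print(small.to(torch.float16).item())   # 0.0: the update is silently lost
# ...but survives in BF16, which shares FP32's 8-bit exponent.
print(small.to(torch.bfloat16).item())  # small but non-zero

# Large values overflow FP16 but not BF16.
big = torch.tensor(70000.0)
print(big.to(torch.float16).item())     # inf
print(big.to(torch.bfloat16).item())    # 70144.0: coarsely rounded, but finite
```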

Gradient Accumulation

When you cannot fit the batch size you want into GPU memory, gradient accumulation lets you simulate it by splitting the batch across multiple forward-backward passes and accumulating the gradients before taking a single optimizer step. If the mini-batches are the same size, each mini-batch loss is a mean that you divide by the number of accumulation steps, and no layer depends on batch statistics, the accumulated gradient equals the full-batch gradient up to floating-point rounding: the same gradient, just computed in pieces. Think of it like filling a swimming pool with a garden hose instead of a fire hose. It takes longer per “batch” but needs far less water pressure (memory) at any given moment.
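# Assumes model, optimizer, criterion, scaler, and dataloader from the
# mixed precision example above, with the dataloader yielding batches of 64.
batch_size = 64
accumulation_steps = 4  # 4 mini-batches of 64 = effective batch of 256
effective_batch_size = batch_size * accumulation_steps

optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    x, y = batch
    x, y = x.cuda(), y.cuda()
    
    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
        output = model(x)
        # CRITICAL: divide loss by accumulation_steps so the accumulated
        # gradient is the same as if we had run the full batch at once.
        # Forgetting this division is a common bug that silently multiplies
        # the gradient by accumulation_steps (for SGD, the same as scaling
        # the learning rate by accumulation_steps).
        loss = criterion(output, y) / accumulation_steps
    
    scaler.scale(loss).backward()  # Gradients accumulate (sum) in .grad
    
    # Only step the optimizer every accumulation_steps mini-batches.
    # If len(dataloader) is not a multiple of accumulation_steps, gradients
    # from the final partial cycle are never applied and must be zeroed
    # (or stepped) before the next epoch.
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()  # Reset gradients for next accumulation cycle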
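Pitfall: BatchNorm with gradient accumulation. BatchNorm computes statistics per mini-batch, not per accumulated batch. With accumulation_steps=4, BatchNorm sees a batch of 64 even though your effective batch is 256, so its normalization statistics (and running estimates) are noisier than the effective batch size suggests, and the equivalence with full-batch training no longer holds exactly. SyncBatchNorm does not fix this: it synchronizes statistics across GPUs within a single step, not across accumulation steps. If this matters for your model, increase the per-step batch size or switch to GroupNorm/LayerNorm, which normalize within each sample and are unaffected by batch size.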
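The equivalence claim is worth checking once on your own setup. This CPU sketch, using a toy linear model and MSE loss as stand-ins, compares the gradient from one batch of 256 with the gradient accumulated over 4 mini-batches of 64, both with and without the division by accumulation_steps.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
criterion = nn.MSELoss()  # mean reduction, as the division trick assumes
x, y = torch.randn(256, 10), torch.randn(256, 1)
accumulation_steps = 4

# Reference: gradient of the full batch of 256 in a single pass.
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()


def accumulated_grad(divide: bool) -> torch.Tensor:
    """Weight gradient after accumulating over 4 mini-batches of 64."""
    model.zero_grad()
    for xb, yb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
        loss = criterion(model(xb), yb)
        if divide:
            loss = loss / accumulation_steps
        loss.backward()  # .grad accumulates (sums) across calls
    return model.weight.grad.clone()


good = accumulated_grad(divide=True)
bad = accumulated_grad(divide=False)
print(torch.allclose(good, full_grad, atol=1e-6))                      # True
print(torch.allclose(bad, full_grad * accumulation_steps, atol=1e-5))  # True: 4x too large
```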

Gradient Checkpointing

During the forward pass, PyTorch saves intermediate activations for use during backpropagation. For a 100-layer Transformer, that means storing several activation tensors per layer, which can dwarf the model parameters in memory usage, especially with large batches or long sequences. Gradient checkpointing (also called activation checkpointing) trades compute for memory: instead of saving every activation, it keeps only the inputs at chosen checkpoints, discards the rest during the forward pass, and recomputes them on the fly during backward. The analogy: instead of photocopying every page of a book as you read it (so you can refer back later), you just remember the page numbers and re-read the pages when you need them. Slower, but requires almost no shelf space.
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    """Wraps a block with gradient checkpointing.
    
    During forward: runs the block but saves only its input, not its internal activations.
    During backward: re-runs the block's forward pass to recompute those activations.
    """
    def __init__(self, block):
        super().__init__()
        self.block = block
    
    def forward(self, x):
        # use_reentrant=False is the variant the PyTorch docs recommend: unlike
        # the reentrant one, it works when inputs don't require grad and with
        # torch.autograd.grad, and recent versions warn if the flag is omitted.
        return checkpoint(self.block, x, use_reentrant=False)

# Apply to model -- typically checkpoint every Transformer block.
# Assumes model.layers is an nn.ModuleList of blocks.
model.layers = nn.ModuleList([
    CheckpointedBlock(layer) for layer in model.layers
])
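Memory savings: roughly 2-3x less activation memory, at the cost of ~30% slower training. The extra cost is one additional forward pass per checkpointed block; since backward costs about twice as much as forward, that adds roughly a third to the compute per step.
Selective checkpointing: You do not have to checkpoint every layer. In practice, checkpointing every 2nd or 3rd block often captures much of the memory benefit (roughly 60-70%) with a smaller speed penalty (roughly 10-15%), though the exact numbers depend on the model. Profile your specific model to find the sweet spot. The layers with the largest activations (in Transformers, usually attention, whose score matrices grow quadratically with sequence length) benefit most from checkpointing.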
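Checkpointing changes what is stored, not what is computed, so gradients should match an uncheckpointed run. This CPU sketch builds a hypothetical `Stack` toy model of six Linear+GELU blocks, checkpoints every 2nd block (the selective scheme described above), and verifies that its gradients match a plain copy with the same weights.

```python
import copy

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Stack(nn.Module):
    """Hypothetical toy model: checkpoint every `every`-th block (0 = never)."""

    def __init__(self, blocks, every=0):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        self.every = every

    def forward(self, x):
        for i, block in enumerate(self.blocks):
            if self.every and i % self.every == 0:
                # Only this block's input is kept; its internals are recomputed in backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


torch.manual_seed(0)
blocks = [nn.Sequential(nn.Linear(32, 32), nn.GELU()) for _ in range(6)]
plain = Stack(copy.deepcopy(blocks))               # stores every activation
selective = Stack(copy.deepcopy(blocks), every=2)  # checkpoints blocks 0, 2, and 4

x = torch.randn(4, 32)
plain(x).sum().backward()
selective(x).sum().backward()

match = all(torch.allclose(p.grad, q.grad)
            for p, q in zip(plain.parameters(), selective.parameters()))
print(match)  # True: same gradients, different memory/compute trade-off
```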

Data Parallel Training

DataParallel (Single Machine, Multi-GPU)

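model = nn.DataParallel(model)  # Simple but inefficient: one process drives all GPUs,
                                # re-replicates the model on every forward pass, and
                                # gathers outputs on GPU 0 (uneven memory, GIL contention)
output = model(input)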

DistributedDataParallel (Preferred)

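DDP runs one process per GPU. Each process holds a full model replica, and gradients are averaged across processes with an all-reduce during backward, overlapped with the remaining computation.

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def setup():
    # torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT
    # in each process's environment; init_process_group reads them.
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank

def train():
    local_rank = setup()
    
    # Model, dataset, batch_size (per GPU), and epochs are placeholders
    model = Model().to(local_rank)
    model = DDP(model, device_ids=[local_rank])
    
    # Each process gets a different shard of the data; the sampler reads
    # rank and world size from the process group. Don't also pass shuffle=True.
    sampler = DistributedSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
    
    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # Without this, every epoch reuses the same shuffle order
        for batch in dataloader:
            # Normal training code; DDP all-reduces gradients during backward()
            ...
    
    dist.destroy_process_group()

if __name__ == "__main__":
    train()

# Launch with:
# torchrun --nproc_per_node=4 train.py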
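You can see how DistributedSampler shards the data without launching any processes, because it accepts num_replicas and rank explicitly. This sketch simulates two ranks over a toy 10-element dataset and shows that their shards are disjoint and together cover the dataset, with set_epoch reseeding the shuffle each epoch.

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))  # toy dataset: sample i is just the number i
world_size = 2

for epoch in range(2):
    shards = []
    for rank in range(world_size):
        # Passing num_replicas and rank explicitly means no process group is needed.
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                     shuffle=True, seed=0)
        sampler.set_epoch(epoch)  # the shuffle is seeded with seed + epoch
        shards.append(list(sampler))
    print(f"epoch {epoch}: rank 0 -> {shards[0]}, rank 1 -> {shards[1]}")
```

If the dataset size is not divisible by the number of replicas, DistributedSampler pads by repeating indices (unless drop_last=True), so a few samples are seen twice per epoch.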

Model Parallelism

Split model across GPUs when it doesn’t fit on one:

Pipeline Parallelism

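import torch.nn as nn

# Naive model parallelism: split the model into sequential stages on
# different GPUs (layer sizes are illustrative). Only one GPU is busy at a
# time here; real pipeline parallelism (e.g., GPipe) also splits each batch
# into micro-batches so the stages can work concurrently.
class PipelinedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.stage1 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()).to('cuda:0')
        self.stage2 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()).to('cuda:1')
        self.stage3 = nn.Sequential(nn.Linear(1024, 10)).to('cuda:2')
    
    def forward(self, x):
        # Activations are copied between GPUs at each stage boundary
        x = self.stage1(x.to('cuda:0'))
        x = self.stage2(x.to('cuda:1'))
        x = self.stage3(x.to('cuda:2'))
        return x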
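The cost of the naive version is easy to quantify. The pure-Python sketch below models a simplified GPipe-style forward schedule (assuming every stage takes one time unit per micro-batch and ignoring the backward pass) and computes the idle "bubble" fraction (p - 1) / (m + p - 1) for p stages and m micro-batches. With m = 1 it describes the naive model above.

```python
def gpipe_forward_schedule(num_stages, num_microbatches):
    """For each time step, the micro-batch each stage works on (None = idle).

    Simplified GPipe-style forward pass: every stage takes one time unit per
    micro-batch, and stage s can start micro-batch m once stage s-1 finishes it.
    """
    schedule = []
    for t in range(num_stages + num_microbatches - 1):
        row = []
        for s in range(num_stages):
            mb = t - s
            row.append(mb if 0 <= mb < num_microbatches else None)
        schedule.append(row)
    return schedule


def bubble_fraction(num_stages, num_microbatches):
    """Fraction of stage-time slots spent idle in the schedule above."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


for t, row in enumerate(gpipe_forward_schedule(3, 4)):
    print(f"t={t}: " + "  ".join("--" if mb is None else f"m{mb}" for mb in row))

print(f"idle with 1 micro-batch:   {bubble_fraction(3, 1):.0%}")  # 67%, the naive model above
print(f"idle with 8 micro-batches: {bubble_fraction(3, 8):.0%}")  # 20%
```

More micro-batches shrink the bubble, at the cost of smaller per-stage matrix multiplications.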

Tensor Parallelism

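Split individual layers across GPUs, so each GPU computes a slice of every layer's output (used in LLM training):
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.nn.functional import all_gather  # autograd-aware collective

# Column-parallel linear layer: each rank owns out_features / world_size rows
# of the weight matrix, i.e. a slice of the output features.
# Assumes the process group is already initialized, one rank per GPU.
class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        assert out_features % world_size == 0
        self.local_out = out_features // world_size
        self.weight = nn.Parameter(torch.empty(self.local_out, in_features))
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)  # nn.Linear's default init
    
    def forward(self, x):
        local_out = F.linear(x, self.weight)  # (..., local_out)
        # All-gather the slices from every rank and concatenate along the
        # feature dimension. Megatron-style layers often skip this gather and
        # feed the split output straight into a row-parallel layer.
        return torch.cat(all_gather(local_out), dim=-1)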
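Column parallelism works because each output feature depends only on its own row of the weight matrix. This single-process sketch simulates 4 ranks by chunking the weight of an ordinary linear layer (sizes are arbitrary) and checks that concatenating the per-rank outputs, which is what the all-gather does, reproduces the full layer's output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size, in_features, out_features = 4, 8, 12
x = torch.randn(2, in_features)
W = torch.randn(out_features, in_features)  # full weight of an ordinary linear layer

# Each simulated "rank" holds out_features / world_size rows of W.
shards = torch.chunk(W, world_size, dim=0)
partial_outputs = [F.linear(x, w_shard) for w_shard in shards]  # computed independently
y_parallel = torch.cat(partial_outputs, dim=-1)                 # what all-gather reassembles

y_full = F.linear(x, W)
print(torch.allclose(y_parallel, y_full))  # True
```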

FSDP (Fully Sharded Data Parallel)

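Shard model parameters, gradients, and optimizer states across GPUs. Each GPU stores only its 1/N shard; the full parameters of one wrapped unit are all-gathered just before that unit's forward or backward pass and freed right after:
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Assumes the process group is initialized (e.g., launched via torchrun) and
# that TransformerBlock is your model's Transformer layer class (placeholder).
# Each TransformerBlock becomes its own FSDP unit, so only one block's full
# parameters need to be materialized at a time.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)

model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,  # compute forward/backward in BF16
        reduce_dtype=torch.float32,  # reduce gradients across GPUs in FP32
    ),
    device_id=torch.cuda.current_device(),
)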
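| Parallelism | Memory savings (per GPU) | Communication |
|---|---|---|
| DDP | None (full replica on every GPU) | AllReduce gradients |
| FSDP (ZeRO-3) | Up to ~Nx for parameters, gradients, and optimizer states (N = GPUs); activations are not sharded | AllGather parameters + ReduceScatter gradients |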
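The ZeRO paper's accounting makes these savings concrete. Assuming mixed-precision Adam at 16 bytes per parameter (2-byte weights, 2-byte gradients, and 12 bytes of optimizer state for the FP32 master copy plus Adam's two moments) and ignoring activations, this sketch computes per-GPU model-state memory for each ZeRO stage. FSDP's full sharding corresponds to stage 3, and `zero_memory_gb` is a hypothetical helper.

```python
def zero_memory_gb(num_params: int, num_gpus: int, stage: int) -> float:
    """Per-GPU memory for model states (decimal GB) under ZeRO stages 0-3.

    Stage 0 = plain data parallel (DDP); 1 shards optimizer states;
    2 also shards gradients; 3 also shards parameters (like FSDP).
    """
    p, n = num_params, num_gpus
    # Bytes: FP16/BF16 weights and grads, FP32 master weights + Adam m and v.
    weights, grads, optim = 2 * p, 2 * p, 12 * p
    if stage >= 1:
        optim /= n
    if stage >= 2:
        grads /= n
    if stage >= 3:
        weights /= n
    return (weights + grads + optim) / 1e9


for stage in range(4):
    gb = zero_memory_gb(7_500_000_000, 64, stage)
    print(f"7.5B params, 64 GPUs, stage {stage}: {gb:.1f} GB")
```

The output (120.0, 31.4, 16.6, and 1.9 GB) matches the 7.5B-parameter, 64-GPU example in the ZeRO paper. For the 70B model from the introduction, stage 3 on 8 GPUs still needs about 140 GB per GPU before activations, which is why CPU offload or more GPUs come into play.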

DeepSpeed Integration

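DeepSpeed exposes the ZeRO stages through a config dictionary and can additionally offload parameters and optimizer states to CPU memory:

import deepspeed

# model and dataloader are placeholders; launch with the deepspeed launcher
# (e.g., deepspeed train.py).
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config={
        # train_batch_size = micro-batch per GPU x gradient_accumulation_steps x number of GPUs
        "train_batch_size": 256,
        "gradient_accumulation_steps": 4,
        "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
        "fp16": {"enabled": True},  # or "bf16": {"enabled": True} on A100/H100
        "zero_optimization": {
            "stage": 3,  # ZeRO Stage 3: shard params, grads, and optimizer states
            "offload_param": {"device": "cpu"},  # Stage 3 only
            "offload_optimizer": {"device": "cpu"},
        },
    },
)

for batch in dataloader:
    batch = batch.to(model_engine.local_rank)  # assumes batch is a single tensor
    loss = model_engine(batch)   # assumes the model's forward returns the loss
    model_engine.backward(loss)  # handles loss scaling and gradient accumulation
    model_engine.step()          # only updates weights at accumulation boundaries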
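DeepSpeed expects train_batch_size to equal the per-GPU micro-batch times gradient_accumulation_steps times the number of GPUs, and raises an error when the values are inconsistent. The small helper below is ours, not part of DeepSpeed; it shows which per-GPU micro-batch the config above implies at different scales.

```python
def micro_batch_per_gpu(train_batch_size: int, grad_accum_steps: int, world_size: int) -> int:
    """Solve train_batch_size = micro_batch * grad_accum_steps * world_size."""
    denom = grad_accum_steps * world_size
    if train_batch_size % denom:
        raise ValueError(
            f"train_batch_size={train_batch_size} is not divisible by "
            f"{grad_accum_steps} accumulation steps x {world_size} GPUs"
        )
    return train_batch_size // denom


print(micro_batch_per_gpu(256, 4, 1))  # 64 on a single GPU
print(micro_batch_per_gpu(256, 4, 8))  # 8 per GPU on 8 GPUs
```

On 3 GPUs the same config fails, because 256 is not divisible by 4 x 3.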

Efficient Data Loading

A surprisingly common bottleneck: your GPU sits idle waiting for data. If your data loading is slower than your model’s forward-backward pass, you are data-bound — upgrading your GPU will not help at all. The fix is to overlap data loading with computation using multiple worker processes.
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=8,           # Parallel data loading -- use 4-8 workers per GPU.
                             # Too many workers wastes CPU/RAM; too few starves the GPU.
    pin_memory=True,         # Returns batches in page-locked (pinned) host memory, which
                             # enables async host-to-GPU copies via .to(device, non_blocking=True).
    prefetch_factor=2,       # Each worker loads 2 batches ahead (the default; needs num_workers > 0).
    persistent_workers=True, # Keep worker processes alive between epochs --
                             # avoids the overhead of spawning new processes each epoch.
)
Pitfall: num_workers on Windows and macOS. These platforms start workers with spawn instead of fork (macOS has defaulted to spawn since Python 3.8), which is significantly slower. If you see long pauses at epoch boundaries, set persistent_workers=True (as above) and consider reducing num_workers to 4. With spawn, your dataset and transforms must also be picklable (a lambda in a transform crashes the data loader with a cryptic pickling error), and the training entry point must be guarded by if __name__ == "__main__":.
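A quick way to tell whether you are data-bound is to time the DataLoader on its own, with no model attached, and compare that rate with your training steps per second. The `loader_throughput` helper below is a hypothetical sketch; the random tensor dataset stands in for your real one, and num_workers=0 keeps it portable (rerun it with the settings above to measure their effect).

```python
import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def loader_throughput(loader, max_batches=50):
    """Return (batches read, batches/sec) for a DataLoader with no model attached.

    If this rate is not clearly higher than your training steps/sec, the GPU
    is likely waiting on data, and a faster GPU will not help.
    """
    start = time.perf_counter()
    n = 0
    for n, _ in enumerate(loader, start=1):
        if n >= max_batches:
            break
    elapsed = time.perf_counter() - start
    return n, (n / elapsed if elapsed > 0 else float("inf"))


dataset = TensorDataset(torch.randn(1024, 3, 32, 32), torch.randint(0, 10, (1024,)))
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)
batches, rate = loader_throughput(loader)
print(f"{batches} batches at {rate:.1f} batches/s")
```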

Training Recipe Summary

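| Scale | Technique | Memory | Speed |
|---|---|---|---|
| 1 GPU | Mixed precision | ~2x | ~2x |
| 1 GPU | + Gradient checkpointing | ~4x | ~0.7x |
| 1 GPU | + Gradient accumulation | Simulates a bigger batch | - |
| Multi-GPU | DDP | 1x (full replica per GPU) | Up to ~Nx |
| Multi-GPU | FSDP/ZeRO | Up to ~Nx | ~Nx |
| Multi-node | + CPU offload | Even larger | Slower |

Memory and Speed are rough multipliers relative to plain FP32 training on one GPU (N = number of GPUs); actual numbers vary with model and hardware.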

Exercises

Implement mixed precision training for a CNN. Measure memory usage and speed improvement.
Simulate a batch size of 256 with an actual batch size of 64 using 4 accumulation steps.
Set up DistributedDataParallel training on 2 GPUs. Compare throughput with DataParallel.

What’s Next

Module 20: Transfer Learning & Fine-tuning

Leverage pretrained models effectively — feature extraction, fine-tuning, and adaptation.