
Training Strategies at Scale

The Scale Challenge

Modern models are huge:
  • GPT-3: 175B parameters
  • LLaMA 70B: 70B parameters
  • Vision models: trained on billions of images
How do we train these efficiently?

Mixed Precision Training (FP16/BF16)

Use lower precision for faster computation while maintaining accuracy.
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

model = Model().cuda()             # Model and dataloader assumed defined elsewhere
criterion = nn.CrossEntropyLoss()  # loss used in the loop below
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler()              # scales the loss to avoid FP16 gradient underflow

for batch in dataloader:
    x, y = batch
    x, y = x.cuda(), y.cuda()
    
    optimizer.zero_grad()
    
    # Forward pass in mixed precision
    with autocast():
        output = model(x)
        loss = criterion(output, y)
    
    # Backward pass with scaling
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Precision Comparison

Type   Bits   Range      Use Case
FP32   32     ±3.4e38    Accumulation, critical ops
FP16   16     ±65504     Most compute, may need loss scaling
BF16   16     ±3.4e38    Modern GPUs, no loss scaling needed
FP8    8      Limited    Inference, cutting edge
BF16 is preferred on modern GPUs (A100, H100) — same range as FP32, no loss scaling needed.
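
Because BF16 keeps FP32's dynamic range, the GradScaler can usually be dropped. A minimal sketch of the same loop in BF16, reusing the model, criterion, optimizer, dataloader, and autocast import from above:
# BF16 autocast: same exponent range as FP32, so no loss scaling is needed
for batch in dataloader:
    x, y = batch
    x, y = x.cuda(), y.cuda()

    optimizer.zero_grad()
    with autocast(dtype=torch.bfloat16):
        output = model(x)
        loss = criterion(output, y)
    loss.backward()
    optimizer.step()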

Gradient Accumulation

Simulate larger batch sizes without more GPU memory:
accumulation_steps = 4
effective_batch_size = batch_size * accumulation_steps

optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    x, y = batch
    
    with autocast():
        output = model(x)
        loss = criterion(output, y) / accumulation_steps  # Scale loss
    
    scaler.scale(loss).backward()
    
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

Gradient Checkpointing

Trade compute for memory — recompute activations during backward:
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    """Wraps a block so its activations are recomputed during the backward pass."""
    def __init__(self, block):
        super().__init__()
        self.block = block
    
    def forward(self, x):
        return checkpoint(self.block, x, use_reentrant=False)

# Apply to model (assumes model.layers is an iterable of blocks)
model.layers = nn.ModuleList([
    CheckpointedBlock(layer) for layer in model.layers
])
Memory savings: ~2-3x, at the cost of ~30% slower training.
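
For purely sequential models, torch.utils.checkpoint.checkpoint_sequential gives the same trade-off without a wrapper class. A minimal sketch, assuming the model is an nn.Sequential:
from torch.utils.checkpoint import checkpoint_sequential

# Splits the sequential model into `segments` chunks; only activations at the chunk
# boundaries are stored, everything else is recomputed in the backward pass.
def forward_with_checkpointing(sequential_model, x, segments=4):
    return checkpoint_sequential(sequential_model, segments, x, use_reentrant=False)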

Data Parallel Training

DataParallel (Single Machine, Multi-GPU)

model = nn.DataParallel(model)  # Simple but inefficient: single process, GIL-bound; prefer DDP below
output = model(input)

DistributedDataParallel (Preferred)

import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def train(rank, world_size):
    setup(rank, world_size)
    
    model = Model().to(rank)
    model = DDP(model, device_ids=[rank])
    
    # Each process (one per GPU) sees a different shard of the dataset
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
    
    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # Shuffle differently each epoch
        for batch in dataloader:
            # Normal training code
            ...
    
    dist.destroy_process_group()

if __name__ == "__main__":
    # torchrun sets RANK and WORLD_SIZE in the environment
    train(int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]))

# Launch with:
# torchrun --nproc_per_node=4 train.py

Model Parallelism

Split model across GPUs when it doesn’t fit on one:

Pipeline Parallelism

import torch.nn as nn

# Split the model into stages, one per GPU
class PipelinedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.stage1 = nn.Sequential(...).to('cuda:0')
        self.stage2 = nn.Sequential(...).to('cuda:1')
        self.stage3 = nn.Sequential(...).to('cuda:2')
    
    def forward(self, x):
        x = self.stage1(x.to('cuda:0'))
        x = self.stage2(x.to('cuda:1'))
        x = self.stage3(x.to('cuda:2'))
        return x
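
Run this way, only one GPU is busy at a time. Below is a rough, hypothetical sketch of GPipe-style micro-batching that splits the batch so the stages can overlap; real setups would use a dedicated pipeline engine (e.g. DeepSpeed's).
import torch

# Because CUDA kernels launch asynchronously, cuda:0 can already start
# micro-batch i+1 while cuda:1 and cuda:2 are still working on micro-batch i.
def pipelined_forward(model, x, num_microbatches=4):
    outputs = [model(chunk) for chunk in x.chunk(num_microbatches)]
    return torch.cat(outputs, dim=0)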

Tensor Parallelism

Split individual layers across GPUs (used in LLM training):
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

# Column-parallel linear layer: each rank holds a slice of the output features
class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        assert out_features % world_size == 0
        self.local_out = out_features // world_size
        self.weight = nn.Parameter(torch.randn(self.local_out, in_features))
    
    def forward(self, x):
        local_out = F.linear(x, self.weight)  # (batch, local_out)
        # All-gather the per-rank outputs and concatenate along the feature dim
        # (for training, use an autograd-aware gather such as torch.distributed.nn.functional.all_gather)
        gathered = [torch.empty_like(local_out) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, local_out)
        return torch.cat(gathered, dim=-1)
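
A complementary, hypothetical row-parallel sketch (reusing the imports from the column-parallel example above): each rank holds a slice of the input features and the partial products are summed. Megatron-style transformers pair a column-parallel layer with a row-parallel one so only a single all-reduce is needed for the pair.
# Row-parallel linear layer: each rank holds a slice of the input features
class RowParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        assert in_features % world_size == 0
        self.local_in = in_features // world_size
        self.weight = nn.Parameter(torch.randn(out_features, self.local_in))
    
    def forward(self, x_local):
        # x_local: (batch, local_in), this rank's slice of the input features
        partial = F.linear(x_local, self.weight)
        # Sum partial products across ranks (autograd-aware collectives live in torch.distributed.nn)
        dist.all_reduce(partial, op=dist.ReduceOp.SUM)
        return partial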

FSDP (Fully Sharded Data Parallel)

Shard model parameters, gradients, and optimizer states across GPUs:
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Wrap model with FSDP (TransformerBlock stands in for your transformer layer class)
model = FSDP(
    model,
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock},
    ),
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    ),
)
Parallelism     Memory Savings   Communication
DDP             None             AllReduce gradients
FSDP (ZeRO-3)   ~Nx (N = GPUs)   AllGather + ReduceScatter
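
A rough back-of-the-envelope sketch of why sharding matters, using the usual ~16 bytes per parameter for mixed-precision Adam training (fp16/bf16 params and grads plus fp32 master weights, momentum, and variance); the 7B-parameter model and 8-GPU setup are just illustrative numbers:
# Training-state memory, ignoring activations and buffers
def training_state_gb(num_params, num_gpus, sharded):
    total_gb = num_params * 16 / 1e9       # ~16 bytes per parameter
    return total_gb / num_gpus if sharded else total_gb

print(training_state_gb(7e9, 8, sharded=False))  # DDP: ~112 GB replicated on every GPU
print(training_state_gb(7e9, 8, sharded=True))   # ZeRO-3: ~14 GB per GPU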

DeepSpeed Integration

import deepspeed

# DeepSpeed builds the optimizer described in the config from the raw parameters
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config={
        "train_batch_size": 256,
        "gradient_accumulation_steps": 4,
        "fp16": {"enabled": True},
        "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
        "zero_optimization": {
            "stage": 3,  # ZeRO Stage 3: shard params, gradients, and optimizer states
            "offload_param": {"device": "cpu"},
            "offload_optimizer": {"device": "cpu"},
        },
    },
)

for batch in dataloader:
    loss = model_engine(batch)      # assumes the model returns the loss directly
    model_engine.backward(loss)     # handles loss scaling and gradient accumulation
    model_engine.step()             # optimizer step + zero_grad

Efficient Data Loading

from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=8,  # Parallel data loading
    pin_memory=True,  # Faster GPU transfer
    prefetch_factor=2,  # Prefetch batches
    persistent_workers=True,  # Keep workers alive
)

Training Recipe Summary

Scale        Technique               Memory                   Speed
1 GPU        Mixed precision         ~2x savings              ~2x
1 GPU        + Gradient checkpoint   ~4x savings              ~0.7x
1 GPU        + Gradient accum        Simulates bigger batch   -
Multi-GPU    DDP                     1x                       ~Nx
Multi-GPU    FSDP/ZeRO               ~Nx savings              ~Nx
Multi-node   + CPU offload           Even larger models       Slower

Exercises

  1. Implement mixed precision training for a CNN. Measure the memory usage and speed improvement.
  2. Simulate an effective batch size of 256 from an actual batch size of 64 using 4 accumulation steps.
  3. Set up DistributedDataParallel training on 2 GPUs. Compare throughput with DataParallel.

What’s Next