Training Strategies at Scale
The Scale Challenge
Modern models are huge:
- GPT-3: 175B parameters
- LLaMA 70B: 70B parameters
- Vision models: trained on billions of images
How do we train these efficiently?
Mixed Precision Training (FP16/BF16)
Use lower precision for faster computation while maintaining accuracy.
```python
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

model = Model().cuda()
criterion = nn.CrossEntropyLoss()  # example loss; use whatever fits the task
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler()  # rescales the loss to avoid FP16 gradient underflow

for batch in dataloader:
    x, y = batch
    x, y = x.cuda(), y.cuda()
    optimizer.zero_grad()

    # Forward pass in mixed precision
    with autocast():
        output = model(x)
        loss = criterion(output, y)

    # Backward pass with loss scaling
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
Precision Comparison
| Type | Bits | Range | Use Case |
|---|---|---|---|
| FP32 | 32 | ±3.4e38 | Accumulation, critical ops |
| FP16 | 16 | ±65504 | Most compute, may need loss scaling |
| BF16 | 16 | ±3.4e38 | Modern GPUs, no scaling needed |
| FP8 | 8 | Limited | Inference, cutting edge |
BF16 is preferred on modern GPUs (A100, H100) — same range as FP32, no loss scaling needed.
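A minimal sketch of the same loop in BF16 (assuming an Ampere-or-newer GPU and a recent PyTorch): because BF16 keeps FP32's dynamic range, the `GradScaler` can be dropped.

```python
# BF16 autocast: no GradScaler needed
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    output = model(x)
    loss = criterion(output, y)

loss.backward()
optimizer.step()
optimizer.zero_grad()
```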
Gradient Accumulation
Simulate larger batch sizes without more GPU memory:
```python
accumulation_steps = 4
effective_batch_size = batch_size * accumulation_steps

optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    x, y = batch

    with autocast():
        output = model(x)
        loss = criterion(output, y) / accumulation_steps  # Scale loss

    # Gradients accumulate across micro-batches
    scaler.scale(loss).backward()

    # Step the optimizer only once per accumulation_steps micro-batches
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```
Gradient Checkpointing
Trade compute for memory — recompute activations during backward:
```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        # Activations inside the block are discarded and recomputed in backward
        return checkpoint(self.block, x, use_reentrant=False)

# Apply to the model's layers
model.layers = nn.ModuleList([
    CheckpointedBlock(layer) for layer in model.layers
])
```
Typical trade-off: roughly 2-3x lower activation memory at the cost of roughly 30% slower training.
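For strictly sequential models, recent PyTorch versions also offer `checkpoint_sequential`, which splits an `nn.Sequential` into segments and checkpoints each one. A minimal sketch, with the segment count chosen arbitrarily:

```python
from torch.utils.checkpoint import checkpoint_sequential

# Only the activations at segment boundaries are kept; the rest are recomputed
seq_model = nn.Sequential(*model.layers)
out = checkpoint_sequential(seq_model, 4, x, use_reentrant=False)
```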
Data Parallel Training
DataParallel (Single Machine, Multi-GPU)
```python
# Simple but inefficient: a single process replicates the model to every GPU
# on each forward pass and gathers all outputs back on the default device
model = nn.DataParallel(model)
output = model(input)
```
DistributedDataParallel (Preferred)
```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def train(rank, world_size):
    setup(rank, world_size)
    model = Model().to(rank)
    model = DDP(model, device_ids=[rank])

    # Training loop (each GPU gets a different shard of the data)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)

    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # Shuffle differently each epoch
        for batch in dataloader:
            # Normal training code
            ...

# Launch with:
# torchrun --nproc_per_node=4 train.py
```
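When launched with `torchrun`, each process receives its rank and world size through environment variables, so a common pattern (sketch, single node assumed) is to read them instead of passing them by hand:

```python
import os
import torch
import torch.distributed as dist

# torchrun sets RANK, LOCAL_RANK and WORLD_SIZE for every process it spawns
dist.init_process_group("nccl")  # picks rank/world size up from the environment
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
model = DDP(Model().cuda(), device_ids=[local_rank])
```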
Model Parallelism
Split model across GPUs when it doesn’t fit on one:
Pipeline Parallelism
```python
# Split model into stages, one per GPU
class PipelinedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.stage1 = nn.Sequential(...).to('cuda:0')
        self.stage2 = nn.Sequential(...).to('cuda:1')
        self.stage3 = nn.Sequential(...).to('cuda:2')

    def forward(self, x):
        x = self.stage1(x.to('cuda:0'))
        x = self.stage2(x.to('cuda:1'))
        x = self.stage3(x.to('cuda:2'))
        return x
```
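As written this is naive model parallelism: while one stage computes, the other GPUs sit idle. True pipeline parallelism (GPipe-style) splits each batch into micro-batches and feeds them through the stages in a staggered schedule so that all stages stay busy most of the time.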
Tensor Parallelism
Split individual layers across GPUs (used in LLM training):
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Column-parallel linear layer: each rank owns a slice of the output features
class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        assert out_features % world_size == 0
        self.local_out = out_features // world_size
        self.weight = nn.Parameter(torch.randn(self.local_out, in_features))

    def forward(self, x):
        local_out = F.linear(x, self.weight)
        # All-gather (placeholder) concatenates the slices into the full output
        return all_gather(local_out)
```
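The `all_gather` above is a placeholder. A sketch of how it could be written with `torch.distributed` (note that `dist.all_gather` does not propagate gradients, so real tensor-parallel frameworks use an autograd-aware version):

```python
import torch
import torch.distributed as dist

def all_gather(local_out: torch.Tensor) -> torch.Tensor:
    """Concatenate each rank's output slice along the feature dimension."""
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(local_out) for _ in range(world_size)]
    dist.all_gather(gathered, local_out)
    return torch.cat(gathered, dim=-1)
```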
FSDP (Fully Sharded Data Parallel)
Shard model parameters, gradients, and optimizer states across GPUs:
```python
import functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# The wrap policy must be told which block class to shard around;
# TransformerBlock is a placeholder for the model's own block class
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)

# Wrap model with FSDP
model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    ),
)
```
| Parallelism | Memory Savings | Communication |
|---|---|---|
| DDP | None | AllReduce gradients |
| FSDP (ZeRO-3) | ~Nx (N = GPUs) | AllGather + ReduceScatter |
DeepSpeed Integration
```python
import deepspeed

# ZeRO Stage 3 shards params, grads and optimizer state; CPU offload frees more GPU memory
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    config={
        "train_batch_size": 256,
        "gradient_accumulation_steps": 4,
        "fp16": {"enabled": True},
        "zero_optimization": {
            "stage": 3,  # ZeRO Stage 3
            "offload_param": {"device": "cpu"},
            "offload_optimizer": {"device": "cpu"},
        },
    },
)

for batch in dataloader:
    loss = model_engine(batch)   # assumes the wrapped model returns the loss
    model_engine.backward(loss)
    model_engine.step()          # handles optimizer step, scaling and accumulation
```
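DeepSpeed ships its own launcher, so a script like this is typically started with something along the lines of `deepspeed --num_gpus=4 train.py` (the flag value here is illustrative).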
Efficient Data Loading
```python
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=8,            # Parallel data loading workers
    pin_memory=True,          # Page-locked memory for faster GPU transfer
    prefetch_factor=2,        # Batches prefetched per worker
    persistent_workers=True,  # Keep workers alive between epochs
)
```
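With `pin_memory=True`, host-to-device copies can also be made asynchronous so they overlap with computation; a small sketch:

```python
for x, y in dataloader:
    # non_blocking only helps when the source tensor is in pinned memory
    x = x.cuda(non_blocking=True)
    y = y.cuda(non_blocking=True)
    ...
```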
Training Recipe Summary
| Scale | Technique | Memory savings | Speed vs. baseline |
|---|---|---|---|
| 1 GPU | Mixed precision | ~2x | ~2x |
| 1 GPU | + Gradient checkpointing | ~4x | ~0.7x |
| 1 GPU | + Gradient accumulation | Simulates a bigger batch | - |
| Multi-GPU | DDP | 1x | ~Nx |
| Multi-GPU | FSDP/ZeRO | ~Nx | ~Nx |
| Multi-node | + CPU offload | Fits even larger models | Slower |
Exercises
Exercise 1: Mixed Precision
Implement mixed precision training for a CNN. Measure memory usage and speed improvement.
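One way to measure both, as a sketch (the training steps themselves are elided):

```python
import time
import torch

torch.cuda.reset_peak_memory_stats()
start = time.time()

# ... run a fixed number of training steps (once in FP32, once with autocast) ...

torch.cuda.synchronize()
print(f"time: {time.time() - start:.1f}s, "
      f"peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
```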
Exercise 2: Gradient Accumulation
Simulate an effective batch size of 256 using an actual batch size of 64 and 4 accumulation steps.
Exercise 3: Multi-GPU Setup
Set up DistributedDataParallel training on 2 GPUs. Compare throughput with DataParallel.
What’s Next