# Training Strategies at Scale

> Mixed precision, gradient accumulation, gradient checkpointing, distributed training, and efficient data loading

<Frame>
  <img src="https://mintlify.s3.us-west-1.amazonaws.com/devweeekends/images/courses/deep-learning-mastery/training-scale-concept.svg" alt="Training at Scale" />
</Frame>

## The Scale Challenge

Here is the central tension of modern deep learning: the best models are enormous, but GPU memory and compute are finite. A 70B-parameter model in FP32 requires 280 GB just for the weights (70B values at 4 bytes each), which is 3.5x the memory of an 80 GB A100. And that is only the parameters; you also need memory for activations, gradients, and optimizer states (Adam keeps two extra tensors the size of the parameters, its first- and second-moment estimates).

Modern models are **huge**:

* GPT-3: 175B parameters (\~700 GB in FP32, impossible on any single GPU)
* LLaMA 70B: 70B parameters (\~280 GB in FP32)
* Vision models: trained on billions of images, requiring weeks of compute

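The arithmetic behind these figures is parameter count times bytes per value. The sketch below makes three assumptions: decimal gigabytes (1 GB = 10^9 bytes), plain FP32 training with Adam, and no activations, temporary buffers, or framework overhead, so real usage is higher. `weights_gb` and `fp32_adam_training_gb` are illustrative helpers:

```python
# Bytes needed to store one value in each format
BYTES_PER_VALUE = {"fp32": 4, "fp16": 2, "bf16": 2}


def weights_gb(num_params: float, dtype: str = "fp32") -> float:
    """Memory for the parameters alone, in decimal GB."""
    return num_params * BYTES_PER_VALUE[dtype] / 1e9


def fp32_adam_training_gb(num_params: float) -> float:
    """Weights + gradients + Adam's two moment buffers, all stored in FP32.

    Activations are excluded; they depend on batch size and sequence length.
    """
    bytes_per_param = 4 + 4 + 4 + 4  # weights, gradients, exp_avg, exp_avg_sq
    return num_params * bytes_per_param / 1e9


if __name__ == "__main__":
    for name, n in [("GPT-3", 175e9), ("LLaMA 70B", 70e9)]:
        print(f"{name}: weights {weights_gb(n):.0f} GB, "
              f"weights + grads + Adam state {fp32_adam_training_gb(n):.0f} GB")
```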
How do we train these efficiently? The answer is a toolkit of techniques -- mixed precision, gradient accumulation, checkpointing, and various parallelism strategies -- that trade off between memory, speed, and implementation complexity. A senior ML engineer needs to know when to reach for each tool.

<Note>
  **Think of it like packing for a trip.** A single suitcase (one GPU) can only hold so much. Mixed precision is like using vacuum bags to compress your clothes (same content, half the space). Gradient accumulation is like making multiple trips to the car. Model parallelism is like using multiple suitcases. FSDP/ZeRO is like having each family member carry a different part of a modular tent -- nobody carries the full thing, but you can assemble it when needed.
</Note>

***

## Mixed Precision Training (FP16/BF16)

This is the single highest-impact optimization for most practitioners. Most operations in a neural network do not need 32-bit floating-point precision. Matrix multiplications, convolutions, and attention computations work well in 16-bit. Moving to 16-bit halves the memory per value and is often much faster, because the tensor cores on modern GPUs have far higher 16-bit throughput than FP32 throughput.

The "mixed" part means the weights that the optimizer updates stay in FP32, because precision matters when accumulating tiny changes. The forward and backward passes run in FP16/BF16. With PyTorch autocast, the model's parameters simply remain FP32, and eligible operations cast their inputs to 16-bit on the fly.

```python theme={null}
import torch
import torch.nn as nn

# Assumes `Model` and `dataloader` are defined elsewhere (placeholders)
model = Model().cuda()
criterion = nn.CrossEntropyLoss()  # example loss; substitute your own
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# GradScaler prevents FP16 underflow by scaling the loss up
# before backward, then scaling gradients back down before the step.
# Only needed for FP16; BF16 has enough dynamic range to skip this.
scaler = torch.amp.GradScaler("cuda")  # torch.cuda.amp.GradScaler() on PyTorch < 2.3

for x, y in dataloader:
    x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)

    optimizer.zero_grad(set_to_none=True)

    # Forward pass in mixed precision -- autocast runs eligible ops
    # (matmuls, convolutions) in FP16 and keeps precision-sensitive ops
    # (e.g., softmax, layer norm, loss computation) in FP32
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        output = model(x)
        loss = criterion(output, y)

    # Backward pass with loss scaling to prevent gradient underflow
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # Unscales gradients; skips the step if they contain inf/NaN
    scaler.update()         # Adjusts scale factor for next iteration
```
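Why keep the optimizer update in FP32? The toy NumPy sketch below uses scalar stand-ins and does not show how PyTorch implements mixed precision. It demonstrates that a small update added directly to an FP16 weight rounds away, because FP16 values near 1.0 are spaced about 0.001 apart. An FP32 master weight, by contrast, keeps accumulating the update:

```python
import numpy as np

update = 1e-4   # a typical tiny step: learning rate times gradient
steps = 1000

# Applying updates directly to an FP16 weight: each 1.0 + 1e-4 rounds back to 1.0
w_fp16 = np.float16(1.0)
for _ in range(steps):
    w_fp16 = np.float16(w_fp16 + np.float16(update))

# Mixed-precision pattern: accumulate updates in an FP32 master weight and
# cast to FP16 only when the weight is needed for 16-bit compute
master_fp32 = np.float32(1.0)
for _ in range(steps):
    master_fp32 = np.float32(master_fp32 + np.float32(update))
w_compute = np.float16(master_fp32)

print(w_fp16, master_fp32, w_compute)  # w_fp16 is still 1.0
```

The FP16-only weight never moves. The master-weight version ends near 1.1, and its FP16 copy is what the 16-bit compute would use.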

### Precision Comparison

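| Type | Bits | Exponent / mantissa bits     | Max magnitude              | Use Case                                                           |
| ---- | ---- | ---------------------------- | -------------------------- | ------------------------------------------------------------------ |
| FP32 | 32   | 8 / 23                       | ±3.4e38                    | Accumulation, precision-critical ops                               |
| FP16 | 16   | 5 / 10                       | ±65504                     | Most compute; needs loss scaling                                   |
| BF16 | 16   | 8 / 7                        | ±3.4e38                    | Most compute on Ampere and newer GPUs; no loss scaling needed       |
| FP8  | 8    | 4 / 3 (E4M3) or 5 / 2 (E5M2) | ±448 (E4M3), ±57344 (E5M2) | Inference; training on H100-class GPUs with specialized libraries  |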

<Tip>
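  **BF16 is preferred** on GPUs that support it (Ampere and newer, e.g., A100 and H100). It has the same exponent range as FP32, so no loss scaling is needed, at the cost of fewer mantissa bits than FP16.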
</Tip>

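The range column decides whether you need loss scaling. The sketch below uses NumPy's `float16` for FP16. NumPy has no BF16 type, so `to_bf16` is a hypothetical helper that approximates BF16 by truncating an FP32 bit pattern to its top 16 bits (hardware rounds to nearest instead):

```python
import struct

import numpy as np


def to_bf16(x: float) -> float:
    """Approximate BF16 by keeping the top 16 bits of the FP32 encoding.

    Real hardware rounds to nearest; truncation is used here for simplicity.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= 0xFFFF0000  # keep sign, 8 exponent bits, top 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


with np.errstate(over="ignore"):
    big_fp16 = np.float16(70000.0)      # above FP16's max of 65504 -> inf
tiny_fp16 = np.float16(1e-8)            # below FP16's smallest subnormal -> 0.0

big_bf16 = to_bf16(70000.0)             # finite, but coarsely rounded
tiny_bf16 = to_bf16(1e-8)               # still nonzero: BF16 shares FP32's exponent range

# Loss scaling: multiply by a large factor so the value fits FP16's range,
# then divide the factor back out in FP32
scale = 65536.0
scaled_fp16 = np.float16(1e-8 * scale)
recovered = float(scaled_fp16) / scale

print(big_fp16, tiny_fp16, big_bf16, tiny_bf16, recovered)
```

FP16 overflows at 70,000 and flushes 1e-8 to zero. BF16 keeps both values finite but stores 70,000 as 69,632, because it has only 7 mantissa bits. Scaling the value up before the FP16 cast and dividing afterward in FP32 recovers 1e-8 almost exactly. GradScaler automates this for real gradients.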
***

## Gradient Accumulation

When you cannot fit the batch size you want into GPU memory, gradient accumulation lets you simulate it. You split the batch across multiple forward-backward passes and accumulate the gradients before taking a single optimizer step. Assume a mean-reduced loss, equal-sized mini-batches, and no batch-dependent layers such as BatchNorm. Under those conditions, the accumulated gradient equals the full-batch gradient up to floating-point rounding; it is just computed in pieces.

Think of it like filling a swimming pool with a garden hose instead of a fire hose. It takes longer per "batch" but uses far less water pressure (memory) at any given moment.

```python theme={null}
import torch

# Assumes model, criterion, optimizer, and dataloader are set up as in the
# mixed-precision example above, with mini-batches of 64
batch_size = 64
accumulation_steps = 4  # 4 mini-batches of 64 = effective batch of 256
effective_batch_size = batch_size * accumulation_steps
scaler = torch.amp.GradScaler("cuda")

optimizer.zero_grad(set_to_none=True)
for i, (x, y) in enumerate(dataloader):
    x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        output = model(x)
        # CRITICAL: divide loss by accumulation_steps so the accumulated
        # gradient matches what the full batch would produce.
        # Forgetting this division is a common bug that silently multiplies
        # the gradient by accumulation_steps (with SGD, the same as scaling
        # the learning rate by accumulation_steps).
        loss = criterion(output, y) / accumulation_steps

    scaler.scale(loss).backward()  # Gradients accumulate in .grad

    # Only step the optimizer every accumulation_steps mini-batches
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)  # Reset gradients for next accumulation cycle
```
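You can check numerically that accumulation reproduces the full-batch gradient, and what happens if you forget the division. The NumPy sketch below uses a linear model with a mean-squared-error loss, standing in for any mean-reduced loss. It splits one 256-example batch into 4 equal mini-batches:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(256, 10))   # one "full" batch of 256 examples
y = rng.normal(size=256)
w = rng.normal(size=10)


def mse_grad(Xb, yb, w):
    """Gradient of the mean squared error mean((Xb @ w - yb) ** 2) w.r.t. w."""
    residual = Xb @ w - yb
    return 2.0 * Xb.T @ residual / len(yb)


full_batch_grad = mse_grad(X, y, w)

accumulation_steps = 4
accumulated = np.zeros_like(w)          # plays the role of param.grad
forgot_to_divide = np.zeros_like(w)
for Xb, yb in zip(np.split(X, accumulation_steps), np.split(y, accumulation_steps)):
    g = mse_grad(Xb, yb, w)             # gradient of one 64-example mini-batch
    accumulated += g / accumulation_steps
    forgot_to_divide += g

print(np.allclose(accumulated, full_batch_grad))                           # True
print(np.allclose(forgot_to_divide, accumulation_steps * full_batch_grad))  # True
```

The divided sum equals the full-batch gradient, while the undivided sum is exactly `accumulation_steps` times too large. The equality relies on equal-sized mini-batches. With a ragged final mini-batch, you would need to weight each piece by its size.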

<Warning>
  **Pitfall -- BatchNorm with gradient accumulation:** BatchNorm computes statistics per mini-batch, not per accumulated batch. With accumulation\_steps=4, BatchNorm sees a batch of 64 even though your effective batch is 256, so its statistics are noisier than the effective batch size suggests. SyncBatchNorm does not fix this: it synchronizes statistics across GPUs within a single step, not across accumulation steps. If this matters for your model, switch to GroupNorm or LayerNorm, which do not depend on batch size.
</Warning>

***

## Gradient Checkpointing

During the forward pass, PyTorch saves the intermediate activations that backpropagation will need. For a 100-layer Transformer, that means several tensors per layer (attention inputs and scores, MLP hidden states, and so on), which can dwarf the model parameters in memory usage. Gradient checkpointing (also called activation checkpointing) trades compute for memory. It keeps only the inputs at checkpoint boundaries and discards the activations inside each checkpointed block during forward. During backward, it *recomputes* those activations on the fly.

The analogy: instead of photocopying every page of a book as you read it (so you can refer back later), you just remember the page numbers and re-read the pages when you need them. Slower, but requires almost no shelf space.

```python theme={null}
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    """Wraps a block with gradient checkpointing.

    During forward: runs the block but keeps only its input, not its internal activations.
    During backward: re-runs the block's forward pass to recompute those activations.
    """
    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        # use_reentrant=False is the variant PyTorch recommends: unlike the
        # reentrant version, it supports keyword arguments, torch.autograd.grad,
        # and inputs that do not require grad.
        return checkpoint(self.block, x, use_reentrant=False)

# Apply to model -- typically checkpoint every Transformer block
# (assumes the model keeps its blocks in an nn.ModuleList named `layers`)
model.layers = nn.ModuleList([
    CheckpointedBlock(layer) for layer in model.layers
])
```

**Memory savings**: typically \~2-3x less activation memory, at the cost of roughly \~30% slower training (about one extra forward pass per checkpointed block).

<Tip>
  **Selective checkpointing:** You do not have to checkpoint every layer. In practice, checkpointing every 2nd or 3rd block gives most of the memory benefit (\~60-70%) with less speed penalty (\~10-15%). Profile your specific model to find the sweet spot. The layers with the largest activations (usually the attention layers in Transformers) benefit most from checkpointing.
</Tip>

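You can see the trade without PyTorch. The pure-Python sketch below builds a chain of toy scalar layers `tanh(w * x)`, which are hypothetical stand-ins for Transformer blocks. Standard backprop stores every activation. The checkpointed version stores only the input of every `segment`-th layer and re-runs each segment's forward pass during backward. This illustrates the idea behind `torch.utils.checkpoint`, not its implementation:

```python
import math


def layer_forward(w, x):
    """One toy 'layer': y = tanh(w * x)."""
    return math.tanh(w * x)


def layer_backward(w, x, grad_out):
    """Return (grad w.r.t. x, grad w.r.t. w); needs the layer's saved input x."""
    local = 1.0 - math.tanh(w * x) ** 2
    return grad_out * w * local, grad_out * x * local


def backprop_full(weights, x):
    """Standard backprop: save every activation during the forward pass."""
    acts = [x]
    for w in weights:
        acts.append(layer_forward(w, acts[-1]))
    grad, grads_w = 1.0, [0.0] * len(weights)   # treat the final output as the loss
    for i in reversed(range(len(weights))):
        grad, grads_w[i] = layer_backward(weights[i], acts[i], grad)
    return grads_w, len(acts)                    # gradients, activations stored


def backprop_checkpointed(weights, x, segment):
    """Save only the input of every `segment`-th layer; recompute the rest."""
    checkpoints, h = {}, x
    for i, w in enumerate(weights):
        if i % segment == 0:
            checkpoints[i] = h
        h = layer_forward(w, h)                  # interior activations are discarded
    grad, grads_w, recomputed = 1.0, [0.0] * len(weights), 0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, len(weights))
        acts = [checkpoints[start]]              # re-run this segment's forward pass
        for i in range(start, end - 1):
            acts.append(layer_forward(weights[i], acts[-1]))
            recomputed += 1
        for i in reversed(range(start, end)):
            grad, grads_w[i] = layer_backward(weights[i], acts[i - start], grad)
    return grads_w, len(checkpoints), recomputed


weights = [0.5 + 0.1 * i for i in range(12)]
full_grads, full_stored = backprop_full(weights, 0.3)
ckpt_grads, ckpt_stored, ckpt_recomputed = backprop_checkpointed(weights, 0.3, segment=4)
print(full_stored, ckpt_stored, ckpt_recomputed)
```

With 12 layers and `segment=4`, the gradients are identical. The checkpointed version keeps only 3 checkpoints instead of 13 activations, plus one segment's 4 activations at a time during backward. It pays for this with 9 extra layer evaluations.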
***

## Data Parallel Training

### DataParallel (Single Machine, Multi-GPU)

```python theme={null}
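import torch.nn as nn

# Simple but inefficient: one process drives all GPUs with threads, replicates
# the model on every forward pass, and gathers outputs on GPU 0. Prefer DDP.
model = nn.DataParallel(model)
output = model(input)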
```

### DistributedDataParallel (Preferred)

```python theme={null}
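import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def train():
    # torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for each process;
    # init_process_group reads them through the default env:// method.
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Model, dataset, batch_size, and epochs are placeholders defined elsewhere
    model = Model().to(local_rank)
    model = DDP(model, device_ids=[local_rank])

    # Each GPU gets a different shard of the data; num_replicas and rank
    # default to the values of the initialized process group
    sampler = DistributedSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)

    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # Shuffle differently each epoch
        for batch in dataloader:
            # Normal training code; DDP averages gradients across GPUs in backward()
            ...

    dist.destroy_process_group()

if __name__ == "__main__":
    train()

# Launch with:
# torchrun --nproc_per_node=4 train.py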
```
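DDP's correctness rests on one identity. Suppose every process computes the mean-loss gradient on an equal-sized, disjoint shard. Averaging those gradients (an all-reduce sum divided by the world size) then reproduces the full-batch gradient. The NumPy sketch below demonstrates this idea. `shard_indices` is a hypothetical helper written to resemble DistributedSampler's strided split with wrap-around padding (without shuffling); it is not the sampler's actual code:

```python
import numpy as np


def shard_indices(num_samples, world_size, rank):
    """Strided sharding with wrap-around padding so every rank gets equal work."""
    per_rank = -(-num_samples // world_size)            # ceiling division
    total = per_rank * world_size
    indices = list(range(num_samples))
    indices += indices[: total - num_samples]           # pad by repeating from the start
    return indices[rank:total:world_size]


def mse_grad(Xb, yb, w):
    residual = Xb @ w - yb
    return 2.0 * Xb.T @ residual / len(yb)


rng = np.random.default_rng(0)
X, y, w = rng.normal(size=(256, 10)), rng.normal(size=256), rng.normal(size=10)
world_size = 4

# Each simulated "GPU" computes a local gradient on its own shard...
shards = [shard_indices(len(y), world_size, rank) for rank in range(world_size)]
local_grads = [mse_grad(X[idx], y[idx], w) for idx in shards]
# ...then an all-reduce sum divided by world_size gives every rank the same average
averaged = sum(local_grads) / world_size

print(np.allclose(averaged, mse_grad(X, y, w)))  # True
```

This is why DDP can give each process a different shard. DDP all-reduces and averages gradients during `backward()`, which lands every replica on the same full-batch gradient, so the replicas stay in sync. When the dataset size is not divisible by the world size, the padding duplicates a few samples, so the equality becomes approximate.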

***

## Model Parallelism

Split the model across GPUs when it does not fit on one:

### Pipeline Parallelism

```python theme={null}
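import torch.nn as nn

# Naive model parallelism: each stage lives on its own GPU (assumes 3 GPUs).
# Stages run one after another, so only one GPU is busy at any moment.
class PipelinedModel(nn.Module):
    def __init__(self, d_model=1024):
        super().__init__()
        # Placeholder layers; substitute each stage's real sub-network
        self.stage1 = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU()).to('cuda:0')
        self.stage2 = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU()).to('cuda:1')
        self.stage3 = nn.Sequential(nn.Linear(d_model, d_model)).to('cuda:2')

    def forward(self, x):
        # Activations are copied to the next device at each stage boundary
        x = self.stage1(x.to('cuda:0'))
        x = self.stage2(x.to('cuda:1'))
        x = self.stage3(x.to('cuda:2'))
        return x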
```
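Pipeline schedules such as GPipe split each batch into micro-batches so that different stages work on different micro-batches at the same time. The pure-Python sketch below shows a GPipe-style forward schedule; it illustrates the idea and is not any library's scheduler. It assumes every stage takes one time step per micro-batch. It also reports the idle "bubble" fraction, (p - 1) / (m + p - 1) for p stages and m micro-batches:

```python
def gpipe_forward_schedule(num_stages, num_microbatches):
    """timeline[t][s] = micro-batch processed by stage s at time step t (None = idle).

    Assumes micro-batch j reaches stage s at step j + s.
    """
    total_steps = num_microbatches + num_stages - 1
    timeline = []
    for t in range(total_steps):
        row = []
        for s in range(num_stages):
            j = t - s
            row.append(j if 0 <= j < num_microbatches else None)
        timeline.append(row)
    return timeline


def bubble_fraction(timeline):
    """Share of (time step, stage) slots where a GPU sits idle."""
    slots = [slot for row in timeline for slot in row]
    return sum(slot is None for slot in slots) / len(slots)


schedule = gpipe_forward_schedule(num_stages=3, num_microbatches=4)
for t, row in enumerate(schedule):
    print(t, ["--" if j is None else f"m{j}" for j in row])
print(bubble_fraction(schedule))   # (p - 1) / (m + p - 1) = 2 / 6
```

Each row is a time step and each column a GPU. Using more micro-batches relative to the number of stages shrinks the bubble. This sketch covers only the forward pass; real schedules interleave the backward passes too.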

### Tensor Parallelism

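Split individual layers across GPUs (used in LLM training, e.g., in Megatron-LM). In a column-parallel linear layer, each GPU holds a slice of the weight matrix's output features and computes the matching slice of the output. The slices are then all-gathered: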

```python theme={null}
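import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.nn.functional import all_gather  # autograd-aware all-gather

# Column-parallel linear layer (requires an initialized process group)
class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        assert out_features % world_size == 0
        self.local_out = out_features // world_size
        # This rank's shard: a slice of the full (out_features, in_features) weight
        self.weight = nn.Parameter(torch.randn(self.local_out, in_features) / in_features ** 0.5)

    def forward(self, x):
        local_out = F.linear(x, self.weight)  # (..., local_out)
        # All-gather every rank's slice and concatenate along the feature dim
        return torch.cat(all_gather(local_out), dim=-1)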
```
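Megatron-style tensor parallelism rests on two simple block-matrix identities. The NumPy sketch below simulates 4 ranks in one process. Column-parallel shards need an all-gather, which is a concatenation. The complementary row-parallel split, which the code above does not show, needs an all-reduce, which is a sum:

```python
import numpy as np

rng = np.random.default_rng(0)
world_size = 4
x = rng.normal(size=(8, 16))          # (batch, in_features)
W = rng.normal(size=(32, 16))         # full weight, (out_features, in_features) like nn.Linear
full = x @ W.T                        # what a single-GPU F.linear would compute

# Column parallel: rank r holds a block of W's rows, i.e. a slice of the outputs
col_shards = np.split(W, world_size, axis=0)
col_outputs = [x @ Wr.T for Wr in col_shards]          # each (batch, out / world_size)
col_result = np.concatenate(col_outputs, axis=-1)     # the all-gather step

# Row parallel: rank r holds a slice of the input features of W and the
# matching slice of x; each rank produces a partial sum over the full output width
row_shards = np.split(W, world_size, axis=1)
x_shards = np.split(x, world_size, axis=1)
partials = [xr @ Wr.T for xr, Wr in zip(x_shards, row_shards)]  # each (batch, out)
row_result = sum(partials)                                       # the all-reduce step

print(np.allclose(col_result, full), np.allclose(row_result, full))  # True True
```

This is why Megatron-style MLPs pair a column-parallel layer with a row-parallel one. The column-parallel output can stay sharded through the elementwise activation and feed the row-parallel layer directly. That leaves a single all-reduce at the end instead of an all-gather in the middle.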

***

## FSDP (Fully Sharded Data Parallel)

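Shard model parameters, gradients, and optimizer states across GPUs. Each GPU stores only its 1/N shard. Just before a wrapped unit runs, FSDP all-gathers that unit's full parameters, and it frees them again afterwards: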

```python theme={null}
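import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# transformer_auto_wrap_policy needs to know which layer class to wrap;
# TransformerBlock is a placeholder for your model's repeated block class
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)

# Requires an initialized process group (e.g., launched with torchrun)
model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,   # gathered parameters used for compute in BF16
        reduce_dtype=torch.float32,   # gradients reduce-scattered in FP32
    ),
)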
```

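| Parallelism   | Memory savings per GPU                                                                             | Communication                                       |
| ------------- | -------------------------------------------------------------------------------------------------- | --------------------------------------------------- |
| DDP           | None (full replica on every GPU)                                                                   | AllReduce of gradients                              |
| FSDP (ZeRO-3) | Up to \~Nx (N = GPUs) for parameters, gradients, and optimizer states; activations are not sharded | AllGather of parameters + ReduceScatter of gradients |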

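The "\~Nx" comes from the ZeRO paper's accounting for mixed-precision Adam, which counts 16 bytes per parameter. Those bytes are 2 for 16-bit weights, 2 for gradients, and 12 for optimizer state (FP32 master weights plus Adam's two moments). Each ZeRO stage shards one more of these across N GPUs. The sketch below follows that per-GPU arithmetic. `zero_memory_per_gpu_gb` is an illustrative helper, and it excludes activations and buffers:

```python
def zero_memory_per_gpu_gb(num_params, num_gpus, stage):
    """Per-GPU model-state memory in decimal GB for mixed-precision Adam.

    stage 0 = plain data parallel, 1 = shard optimizer state, 2 = + gradients,
    3 = + parameters (what FSDP's full sharding corresponds to).
    """
    weights, grads, optim = 2 * num_params, 2 * num_params, 12 * num_params
    if stage >= 1:
        optim /= num_gpus
    if stage >= 2:
        grads /= num_gpus
    if stage >= 3:
        weights /= num_gpus
    return (weights + grads + optim) / 1e9


for stage in range(4):
    print(stage, round(zero_memory_per_gpu_gb(7.5e9, 64, stage), 2))
```

Consider a 7.5B-parameter model on 64 GPUs. This gives roughly 120, 31.4, 16.6, and 1.9 GB per GPU for stages 0 through 3, which lines up with the worked example in the ZeRO paper.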
***

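## DeepSpeed Integration

DeepSpeed bundles ZeRO sharding, mixed precision, gradient accumulation, and CPU offload behind a single engine configured with a dict or JSON file: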

```python theme={null}
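import deepspeed

# Assumes `model(batch)` returns the loss and the script is started with the
# `deepspeed` launcher so the distributed environment is set up.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config={
        "train_batch_size": 256,  # micro-batch per GPU x accumulation steps x number of GPUs
        "gradient_accumulation_steps": 4,
        "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
        "fp16": {"enabled": True},  # or "bf16": {"enabled": True} on Ampere and newer
        "zero_optimization": {
            "stage": 3,  # ZeRO Stage 3: shard parameters, gradients, and optimizer states
            "offload_param": {"device": "cpu"},
            "offload_optimizer": {"device": "cpu"},
        },
    },
)

for batch in dataloader:
    loss = model_engine(batch)
    model_engine.backward(loss)  # applies loss scaling and accumulates gradients
    model_engine.step()          # steps the optimizer only at accumulation boundaries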
```
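DeepSpeed checks that its batch-size settings are consistent: `train_batch_size` must equal `train_micro_batch_size_per_gpu` times `gradient_accumulation_steps` times the number of GPUs. The small helper below works out the micro-batch for the config above. `micro_batch_per_gpu` is hypothetical and not part of DeepSpeed:

```python
def micro_batch_per_gpu(train_batch_size, grad_accum_steps, world_size):
    """Solve train_batch_size = micro_batch * grad_accum_steps * world_size for micro_batch."""
    per_step = grad_accum_steps * world_size
    if train_batch_size % per_step != 0:
        raise ValueError(
            f"train_batch_size={train_batch_size} is not divisible by "
            f"{grad_accum_steps} accumulation steps x {world_size} GPUs"
        )
    return train_batch_size // per_step


print(micro_batch_per_gpu(256, 4, 8))   # 8 examples per GPU per forward pass
```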

***

## Efficient Data Loading

A surprisingly common bottleneck: your GPU sits idle waiting for data. If your data loading is slower than your model's forward-backward pass, you are **data-bound** -- upgrading your GPU will not help at all. The fix is to overlap data loading with computation using multiple worker processes.

```python theme={null}
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=8,           # Parallel data loading -- 4-8 workers per GPU is a common start.
                             # Too many workers wastes CPU/RAM; too few starves the GPU.
    pin_memory=True,         # Returns batches in page-locked (pinned) memory, which lets
                             # .cuda(non_blocking=True) copy asynchronously, overlapping compute.
    prefetch_factor=2,       # Each worker loads 2 batches ahead (requires num_workers > 0).
    persistent_workers=True, # Keep worker processes alive between epochs (requires
                             # num_workers > 0), avoiding the cost of respawning them.
)
```

<Warning>
  **Pitfall -- num\_workers on Windows:** On Windows, multiprocessing uses `spawn` instead of `fork`. Spawning is significantly slower to start workers, and it re-imports your main script in every worker, so your training entry point must sit under `if __name__ == "__main__":`. If you see long pauses at epoch boundaries, set `persistent_workers=True` (as above) and consider reducing `num_workers` to 4. Also, ensure your dataset code is picklable: lambdas in transforms will crash the data loader with a cryptic pickling error.
</Warning>

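To check whether you are data-bound, time how long the loop blocks waiting for the next batch versus how long the training step takes. The sketch below assumes `data_wait_fraction` and `slow_loader` as hypothetical helpers; any iterable, including a DataLoader, works as `loader`:

```python
import time


def data_wait_fraction(loader, step_fn, max_steps=50):
    """Fraction of wall time spent waiting for the next batch.

    `step_fn` stands in for forward/backward/optimizer. With CUDA it should
    call torch.cuda.synchronize() before returning; otherwise asynchronous
    kernels make the compute time look near zero.
    """
    wait = compute = 0.0
    it = iter(loader)
    for _ in range(max_steps):
        t0 = time.perf_counter()
        try:
            batch = next(it)
        except StopIteration:
            break
        t1 = time.perf_counter()
        step_fn(batch)
        t2 = time.perf_counter()
        wait += t1 - t0
        compute += t2 - t1
    total = wait + compute
    return wait / total if total > 0 else 0.0


def slow_loader(n, delay):
    """Hypothetical loader that takes `delay` seconds to produce each batch."""
    for i in range(n):
        time.sleep(delay)
        yield i


frac = data_wait_fraction(slow_loader(10, 0.02), lambda batch: time.sleep(0.002))
print(f"{frac:.0%} of loop time spent waiting for data")
```

A fraction near zero means the workers keep up with the GPU. A high fraction means the GPU is idling, and adding workers or speeding up the dataset code will help more than a faster GPU.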
***

## Training Recipe Summary

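| Scale                  | Technique                | Memory headroom (vs. FP32, 1 GPU) | Throughput (approx.)             |
| ---------------------- | ------------------------ | --------------------------------- | -------------------------------- |
| 1 GPU                  | Mixed precision          | \~2x                              | \~2x or more                     |
| 1 GPU                  | + Gradient checkpointing | \~4x                              | \~0.7x                           |
| 1 GPU                  | + Gradient accumulation  | Simulates a bigger batch          | Unchanged per sample             |
| Multi-GPU              | DDP                      | 1x (full replica per GPU)         | \~Nx                             |
| Multi-GPU              | FSDP/ZeRO-3              | Up to \~Nx                        | \~Nx minus communication overhead |
| Multi-GPU / multi-node | + CPU offload            | Larger still (bounded by CPU RAM) | Slower (PCIe transfers)          |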

***

## Exercises

<AccordionGroup>
  <Accordion title="Exercise 1: Mixed Precision">
    Implement mixed precision training for a CNN. Measure memory usage and speed improvement.
  </Accordion>

  <Accordion title="Exercise 2: Gradient Accumulation">
    Simulate a batch size of 256 using mini-batches of 64 and 4 accumulation steps.
  </Accordion>

  <Accordion title="Exercise 3: Multi-GPU Setup">
    Set up DistributedDataParallel training on 2 GPUs. Compare throughput with DataParallel.
  </Accordion>
</AccordionGroup>

***

## What's Next

<CardGroup cols={1}>
  <Card title="Module 20: Transfer Learning & Fine-tuning" icon="arrows-repeat" href="/courses/deep-learning-mastery/20-transfer-learning">
    Leverage pretrained models effectively — feature extraction, fine-tuning, and adaptation.
  </Card>
</CardGroup>
