Memory-Efficient Training

Understanding GPU Memory

GPU memory during training is like a hotel with a fixed number of rooms: your model parameters check in, then their gradients need rooms, then the optimizer’s bookkeeping (momentum, variance) needs even more rooms, and finally every intermediate computation (activation) holds a room open until the backward pass checks it out. When the hotel is full, training crashes with the dreaded CUDA out of memory error. Understanding who occupies which rooms — and for how long — is the key to training larger models on the hardware you actually have. Where does memory go during training?
| Component | Memory usage | Analogy |
|---|---|---|
| Model parameters | 4 bytes × params (fp32) | The permanent residents |
| Gradients | 4 bytes × params (fp32) | Their shadows (same size) |
| Optimizer states | 8 bytes × params (Adam) | Adam keeps two running averages per parameter |
| Activations | Depends on batch size and model depth | The guests: they check in during forward, check out during backward |
| Temporary buffers | Variable | Room service supplies |
For a 7B-parameter model (fp32 weights and gradients, Adam optimizer):
  • Parameters: 28 GB
  • Gradients: 28 GB
  • Adam states: 56 GB
  • Total: ~112 GB (without activations!)
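The 7B figures are straightforward bytes-per-parameter arithmetic. The helper below (training_memory_gb is a hypothetical name, not a library function) reproduces them, assuming fp32 weights and gradients plus two fp32 Adam states per parameter. It ignores activations, buffers, and allocator overhead, so treat its output as a lower bound.

```python
def training_memory_gb(num_params: float, bytes_per_param: int = 4, adam: bool = True) -> dict:
    """Static training memory (params + grads + Adam states) in GB, excluding activations."""
    GB = 1e9
    params = num_params * bytes_per_param
    # One gradient per parameter, stored in the same dtype as the parameter
    grads = num_params * bytes_per_param
    # Adam keeps two extra tensors (m and v) per parameter
    optimizer = 2 * num_params * bytes_per_param if adam else 0
    return {
        "parameters": params / GB,
        "gradients": grads / GB,
        "adam_states": optimizer / GB,
        "total": (params + grads + optimizer) / GB,
    }

print(training_memory_gb(7e9))
# {'parameters': 28.0, 'gradients': 28.0, 'adam_states': 56.0, 'total': 112.0}
```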
The single most impactful optimization for most teams is mixed precision training (FP16/BF16). It roughly halves activation memory with minimal accuracy impact, and it is a few-line change in most frameworks. (Under standard autocast, weights, gradients, and optimizer states stay in FP32, so those do not shrink.) Start there before reaching for more complex techniques like gradient checkpointing or model parallelism.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from typing import Optional, Tuple, List, Callable
import numpy as np  # used by np.memmap in the data-loading section
import gc
import functools

torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Memory Profiling

class MemoryProfiler:
    """
    Profile GPU memory usage during training.
    """
    
    def __init__(self, device: torch.device = device):
        self.device = device
        self.snapshots = []
    
    def get_memory_stats(self) -> dict:
        """Get current memory statistics in MB (zeros when CUDA is unavailable)."""
        if not torch.cuda.is_available():
            # Same keys as the CUDA branch so print_summary() still works
            return {'allocated_mb': 0.0, 'cached_mb': 0.0, 'max_allocated_mb': 0.0}
        
        return {
            'allocated_mb': torch.cuda.memory_allocated(self.device) / 1e6,
            # memory_reserved = memory held by PyTorch's caching allocator
            'cached_mb': torch.cuda.memory_reserved(self.device) / 1e6,
            'max_allocated_mb': torch.cuda.max_memory_allocated(self.device) / 1e6,
        }
    
    def snapshot(self, label: str = ""):
        """Take a memory snapshot."""
        stats = self.get_memory_stats()
        stats['label'] = label
        self.snapshots.append(stats)
        return stats
    
    def reset_peak_stats(self):
        """Reset peak memory tracking."""
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(self.device)
    
    def print_summary(self):
        """Print memory usage summary."""
        print("\n" + "="*60)
        print("MEMORY USAGE SUMMARY")
        print("="*60)
        
        for snap in self.snapshots:
            print(f"{snap['label']:30} | "
                  f"Allocated: {snap['allocated_mb']:8.1f} MB | "
                  f"Peak: {snap['max_allocated_mb']:8.1f} MB")
    
    @staticmethod
    def estimate_model_memory(model: nn.Module) -> dict:
        """Estimate memory requirements for a model."""
        param_mem = sum(p.numel() * p.element_size() for p in model.parameters())
        buffer_mem = sum(b.numel() * b.element_size() for b in model.buffers())
        
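        # Gradients exist only for trainable parameters (same size/dtype as the params)
        grad_mem = sum(p.numel() * p.element_size()
                       for p in model.parameters() if p.requires_grad)
        
        # Optimizer states (Adam: m and v for each trainable parameter)
        adam_mem = 2 * grad_mem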
        
        return {
            'parameters_mb': param_mem / 1e6,
            'buffers_mb': buffer_mem / 1e6,
            'gradients_mb': grad_mem / 1e6,
            'adam_states_mb': adam_mem / 1e6,
            'total_mb': (param_mem + buffer_mem + grad_mem + adam_mem) / 1e6
        }


# Per-layer activation analysis
def analyze_memory_per_layer(model: nn.Module, input_shape: tuple):
    """Report the output (activation) size of each submodule for one forward pass."""
    profiler = MemoryProfiler()
    
    if not torch.cuda.is_available():
        print("CUDA not available for memory analysis")
        return
    
    model = model.to(device)
    x = torch.randn(*input_shape, device=device)
    
    profiler.reset_peak_stats()
    profiler.snapshot("After model loading")
    
    # Record each submodule's output size via forward hooks
    activations = {}
    handles = []
    
    def save_activation(name):
        def hook(module, input, output):
            if isinstance(output, torch.Tensor):
                activations[name] = output.numel() * output.element_size() / 1e6
        return hook
    
    for name, module in model.named_modules():
        if name:  # skip the root module (its name is the empty string)
            handles.append(module.register_forward_hook(save_activation(name)))
    
    # Forward pass (no_grad: we only measure output sizes here)
    with torch.no_grad():
        _ = model(x)
    
    # Remove the hooks so they don't fire on later forward passes
    for handle in handles:
        handle.remove()
    
    profiler.snapshot("After forward pass")
    profiler.print_summary()
    
    # Print the ten largest activations
    print("\nActivation memory per layer (top 10):")
    for name, mem in sorted(activations.items(), key=lambda kv: -kv[1])[:10]:
        print(f"  {name}: {mem:.2f} MB")

Gradient Checkpointing

During the forward pass, PyTorch saves the intermediate activations that the backward pass needs to compute gradients. For a 24-layer transformer, this means holding 24 layers' worth of intermediate tensors at the same time. Gradient checkpointing (also called activation checkpointing or rematerialization) offers a simple trade: do not save those intermediate activations; instead, recompute them on the fly during the backward pass. You save memory at the cost of running parts of the forward pass twice.

The math is elegant: for $n$ layers, standard training stores $O(n)$ activations. With a checkpoint every $\sqrt{n}$ layers, you store only $O(\sqrt{n})$ segment-boundary tensors, plus the recomputed activations of one segment at a time during backward. The cost is roughly one extra forward pass, about 33% more compute, since the backward pass costs about twice as much as the forward. For a 24-layer model, that means keeping about 5 checkpoints plus the activations of one segment of about 5 layers, instead of the activations of all 24 layers.
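To see that checkpointing changes memory but not the math, the sketch below runs the same deterministic stack of layers with and without torch.utils.checkpoint.checkpoint_sequential and compares the resulting gradients. It assumes PyTorch 2.x (where checkpoint_sequential accepts use_reentrant); the helper name compare_checkpoint_gradients and the toy sizes are illustrative.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

def compare_checkpoint_gradients(num_layers: int = 16, width: int = 64, segments: int = 4) -> float:
    """Max abs difference between parameter gradients with and without checkpointing."""
    torch.manual_seed(0)
    # Dropout-free layers so both runs perform identical computations
    model = nn.Sequential(*[nn.Sequential(nn.Linear(width, width), nn.Tanh())
                            for _ in range(num_layers)])
    x = torch.randn(8, width)

    # Baseline: every intermediate activation is kept for backward
    model(x).sum().backward()
    ref_grads = [p.grad.clone() for p in model.parameters()]
    model.zero_grad()

    # Checkpointed: only segment inputs are kept; the rest is recomputed in backward
    checkpoint_sequential(model, segments, x, use_reentrant=False).sum().backward()
    ckpt_grads = [p.grad for p in model.parameters()]

    return max((a - b).abs().max().item() for a, b in zip(ref_grads, ckpt_grads))

print(compare_checkpoint_gradients())
```

With dropout in the model, checkpoint's default preserve_rng_state=True restores the RNG state so the recomputed forward sees the same dropout mask.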
class GradientCheckpointing:
    """
    Trade compute for memory by recomputing activations.
    
    Instead of storing all activations for backward pass,
    we store only at checkpoint boundaries and recompute
    intermediate activations during backward.
    
    Memory savings: O(sqrt(n)) vs O(n) for n layers
    Compute overhead: roughly one extra forward pass (~30-40% more total compute)
    
    When to use: When activation memory (not parameter/optimizer memory)
    is your bottleneck. This is common for models with many layers
    (deep transformers), large spatial dimensions (segmentation), or
    long sequences (NLP with 4K+ tokens).
    """
    
    @staticmethod
    def checkpoint_sequential(
        functions: List[nn.Module],
        segments: int,
        input: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply checkpointing to sequential modules.
        
        Args:
            functions: List of modules to apply sequentially
            segments: Number of checkpoint segments
            input: Input tensor
        """
        return checkpoint_sequential(functions, segments, input, use_reentrant=False)
    
    @staticmethod
    def checkpoint_function(
        function: Callable,
        *args,
        use_reentrant: bool = False
    ) -> torch.Tensor:
        """
        Checkpoint a single function.
        
        Args:
            function: Function to checkpoint
            *args: Function arguments
            use_reentrant: Use reentrant version (False recommended)
        """
        return checkpoint(function, *args, use_reentrant=use_reentrant)


# Transformer with gradient checkpointing
class CheckpointedTransformerBlock(nn.Module):
    """
    Transformer block with optional gradient checkpointing.
    """
    
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        use_checkpoint: bool = False
    ):
        super().__init__()
        
        self.use_checkpoint = use_checkpoint
        
        # Self-attention
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        
        # Feedforward
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
            nn.Dropout(dropout)
        )
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def _attention_block(self, x: torch.Tensor) -> torch.Tensor:
        """Self-attention sub-block."""
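        # need_weights=False avoids materializing and returning the attention
        # weights, letting PyTorch use scaled_dot_product_attention internally
        attn_out, _ = self.self_attn(x, x, x, need_weights=False)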
        return self.dropout(attn_out)
    
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        """Feedforward sub-block."""
        return self.ff(x)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_checkpoint and self.training:
            # Checkpoint attention
            attn_out = checkpoint(
                self._attention_block, 
                x,
                use_reentrant=False
            )
            x = self.norm1(x + attn_out)
            
            # Checkpoint feedforward
            ff_out = checkpoint(
                self._ff_block,
                x,
                use_reentrant=False
            )
            x = self.norm2(x + ff_out)
        else:
            # Standard forward
            x = self.norm1(x + self._attention_block(x))
            x = self.norm2(x + self._ff_block(x))
        
        return x


class CheckpointedTransformer(nn.Module):
    """
    Full transformer with gradient checkpointing.
    """
    
    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        num_layers: int = 6,
        vocab_size: int = 10000,
        use_checkpoint: bool = True
    ):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        self.layers = nn.ModuleList([
            CheckpointedTransformerBlock(
                d_model, nhead, 
                use_checkpoint=use_checkpoint
            )
            for _ in range(num_layers)
        ])
        
        self.output = nn.Linear(d_model, vocab_size)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embedding(x)
        
        for layer in self.layers:
            x = layer(x)
        
        return self.output(x)


# Compare memory with/without checkpointing
def compare_checkpointing_memory():
    """Compare memory usage with and without checkpointing."""
    
    if not torch.cuda.is_available():
        print("CUDA not available")
        return
    
    batch_size, seq_len = 32, 512
    d_model, num_layers = 512, 12
    
    # Without checkpointing
    model_no_ckpt = CheckpointedTransformer(
        d_model=d_model, num_layers=num_layers,
        use_checkpoint=False
    ).to(device)
    
    torch.cuda.reset_peak_memory_stats()
    x = torch.randint(0, 10000, (batch_size, seq_len)).to(device)
    
    out = model_no_ckpt(x)
    loss = out.sum()
    loss.backward()
    
    mem_no_ckpt = torch.cuda.max_memory_allocated() / 1e9
    
    del model_no_ckpt, out, loss
    torch.cuda.empty_cache()
    
    # With checkpointing
    model_ckpt = CheckpointedTransformer(
        d_model=d_model, num_layers=num_layers,
        use_checkpoint=True
    ).to(device)
    
    torch.cuda.reset_peak_memory_stats()
    
    out = model_ckpt(x)
    loss = out.sum()
    loss.backward()
    
    mem_ckpt = torch.cuda.max_memory_allocated() / 1e9
    
    print(f"Without checkpointing: {mem_no_ckpt:.2f} GB")
    print(f"With checkpointing: {mem_ckpt:.2f} GB")
    print(f"Memory saved: {(1 - mem_ckpt/mem_no_ckpt)*100:.1f}%")

Mixed Precision Training

Mixed precision training is the single highest-impact memory optimization for most practitioners. The idea: use 16-bit floating point (FP16 or BF16) for the computationally intensive operations (matrix multiplies, convolutions) while keeping 32-bit precision where it matters (loss computation, gradient accumulation, optimizer states). Modern NVIDIA GPUs have dedicated hardware (Tensor Cores) that run FP16 matrix multiplies several times faster than FP32 (roughly 2-8x, depending on the GPU and workload), so you get both memory and speed benefits.

The tricky part is avoiding numerical issues. FP16 has a very narrow dynamic range (its largest finite value is 65,504), so gradients can underflow to zero (too small to represent) or overflow to infinity. The solution is gradient scaling: multiply the loss by a large number before backward, then divide the gradients by the same number before the optimizer step. This shifts the gradient values into FP16's representable range; PyTorch's GradScaler also adjusts the scale factor dynamically and skips steps whose gradients overflow. BFloat16 (BF16) avoids this entirely because it has the same 8-bit exponent as FP32, at the cost of a shorter mantissa (7 bits versus FP16's 10), and therefore less precision.
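The range problems described above are easy to observe directly. This sketch (precision_range_demo is a hypothetical helper, and 65536 is just an example loss scale) shows a tiny gradient underflowing in FP16, surviving once it is scaled up and later unscaled in FP32, surviving unscaled in BF16, and a large value overflowing FP16 but not BF16.

```python
import math
import torch

def precision_range_demo(grad_value: float = 1e-8, loss_scale: float = 65536.0) -> dict:
    """Show FP16 underflow/overflow, how loss scaling helps, and BF16's wider range."""
    # A tiny gradient stored directly in FP16 underflows to zero
    fp16_raw = torch.tensor(grad_value, dtype=torch.float16).item()
    # Scaling first moves it into FP16's range; unscaling happens in FP32
    fp16_scaled = torch.tensor(grad_value * loss_scale, dtype=torch.float16)
    fp16_recovered = (fp16_scaled.float() / loss_scale).item()
    # BF16 shares FP32's exponent range, so the unscaled value survives
    bf16_raw = torch.tensor(grad_value, dtype=torch.bfloat16).item()
    # Values above FP16's max finite value (65504) become inf; BF16 just rounds
    fp16_overflow = torch.tensor(70000.0, dtype=torch.float16).item()
    bf16_large = torch.tensor(70000.0, dtype=torch.bfloat16).item()
    return {
        "fp16_raw": fp16_raw,
        "fp16_recovered": fp16_recovered,
        "bf16_raw": bf16_raw,
        "fp16_overflow_is_inf": math.isinf(fp16_overflow),
        "bf16_large": bf16_large,
    }

print(precision_range_demo())
```

BF16's 70000 comes back slightly rounded, which is the precision cost of its shorter mantissa.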
Avoid FP16 mixed precision for fine-tuning large language models when BF16 is available. BF16 is strongly preferred because FP16's narrow range makes large activations overflow and small gradients underflow, and loss scaling only partly compensates; many recent LLMs were pretrained in BF16, and their activations can exceed FP16's maximum. If your GPU supports BF16 (NVIDIA Ampere or newer), prefer it.
class MixedPrecisionTraining:
    """
    Train with mixed FP16/FP32 precision.
    
    Benefits:
    - ~2x less activation memory (weights stay FP32 under autocast)
    - 2-8x faster matrix multiplications on Tensor Core GPUs
    - Minimal accuracy loss with proper gradient scaling
    
    Key components:
    - Automatic casting (autocast): PyTorch decides which ops run in FP16
    - Gradient scaling (GradScaler): prevents FP16 gradient underflow
    
    Typical setup: 3 lines of code change (autocast context, scale loss,
    scaler.step instead of optimizer.step).
    """
    
    def __init__(self, enabled: bool = True):
        self.enabled = enabled
        self.scaler = torch.amp.GradScaler('cuda') if enabled else None
    
    def training_step(
        self,
        model: nn.Module,
        batch: tuple,
        optimizer: optim.Optimizer,
        criterion: nn.Module
    ) -> float:
        """
        Perform one training step with mixed precision.
        """
        inputs, targets = batch
        
        optimizer.zero_grad()
        
        if self.enabled:
            # Mixed precision forward
            with torch.amp.autocast('cuda'):
                outputs = model(inputs)
                loss = criterion(outputs, targets)
            
            # Scaled backward
            self.scaler.scale(loss).backward()
            
            # Unscale gradients for clipping
            self.scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            # Optimizer step with scaling
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            # Standard precision
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        
        return loss.item()


# BFloat16 for better stability
class BFloat16Training:
    """
    BFloat16 training (available on Ampere+ GPUs).
    
    Advantages over FP16:
    - Same exponent range as FP32
    - No gradient scaling needed
    - Better numerical stability
    """
    
    @staticmethod
    def is_available() -> bool:
        """Check if BF16 is available."""
        if not torch.cuda.is_available():
            return False
        return torch.cuda.get_device_capability()[0] >= 8
    
    @staticmethod
    def training_step(
        model: nn.Module,
        batch: tuple,
        optimizer: optim.Optimizer,
        criterion: nn.Module
    ) -> float:
        """Training step with BF16."""
        inputs, targets = batch
        
        optimizer.zero_grad()
        
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        
        # No scaling needed for BF16!
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        return loss.item()


# Precision comparison
def demonstrate_precision_formats():
    """Show memory usage for different precisions."""
    
    num_params = 1_000_000  # 1M parameters
    
    formats = {
        'float32': (4, "Standard precision"),
        'float16': (2, "Half precision"),
        'bfloat16': (2, "Brain float"),
        'int8': (1, "Quantized"),
    }
    
    print("Memory per million parameters:")
    print("-" * 45)
    for name, (bytes_per, desc) in formats.items():
        mem = num_params * bytes_per / 1e6
        print(f"{name:10} ({desc:20}): {mem:.1f} MB")

demonstrate_precision_formats()

CPU Offloading

CPU offloading exploits the fact that a machine often has several times more CPU RAM than GPU memory. The idea is simple: keep only what the GPU needs right now on the GPU, and store everything else on the CPU (or even NVMe). This is the same principle behind ZeRO-Offload in DeepSpeed and FSDP's CPU offloading in PyTorch. The trade-off is straightforward: you give up training speed (because of PCIe data transfers) in exchange for the ability to train models that otherwise would not fit at all.
CPU offloading is most effective for optimizer states (which are only needed during the parameter update step, not during forward/backward). Offloading activations is less effective because they are accessed frequently during the backward pass, making the PCIe bandwidth a bottleneck. If you are using DeepSpeed, start with ZeRO Stage 2 + CPU offloading of optimizer states before trying full activation offloading.
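For the DeepSpeed route, ZeRO Stage 2 with optimizer-state offloading is a configuration change rather than new training code. A minimal sketch of such a config built as a Python dict, assuming BF16 training and an illustrative batch size of 32 (adjust both for your setup):

```python
import json

# Sketch of a DeepSpeed config: ZeRO Stage 2 with optimizer states offloaded to CPU.
# train_batch_size and the bf16 setting are illustrative assumptions.
ds_config = {
    "train_batch_size": 32,
    "bf16": {"enabled": True},
    "zero_optimization": {
        # Stage 2 partitions optimizer states and gradients across data-parallel ranks
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",      # keep Adam's states (and its update step) in CPU RAM
            "pin_memory": True,   # page-locked host memory speeds up PCIe transfers
        },
    },
}

config_json = json.dumps(ds_config, indent=2)
print(config_json)
```

The resulting JSON (or the dict itself) is what you hand to DeepSpeed when building the engine; check the DeepSpeed configuration documentation for the full set of offload options, since available keys and defaults can change between releases.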
class CPUOffloadOptimizer:
    """
    Offload optimizer states to CPU to save GPU memory.
    
    Trade-off: Slower training (CPU<->GPU transfers over PCIe) but can
    train models 2-4x larger than what fits in GPU memory alone.
    
    When to use: When your model fits on the GPU for forward/backward
    but OOMs during the optimizer step (because Adam doubles the memory
    footprint with its momentum and variance states).
    """
    
    def __init__(
        self,
        model: nn.Module,
        lr: float = 1e-4,
        offload: bool = True
    ):
        self.model = model
        self.offload = offload
        
        if offload:
            # Keep parameters on GPU but states on CPU
            self.param_groups = [
                {
                    'params': list(model.parameters()),
                    'lr': lr
                }
            ]
            
            # Initialize optimizer states on CPU
            self.state = {}
            for p in model.parameters():
                if p.requires_grad:
                    # Adam states
                    self.state[p] = {
                        'm': torch.zeros_like(p.data, device='cpu'),
                        'v': torch.zeros_like(p.data, device='cpu'),
                        't': 0
                    }
        else:
            # Plain Adam (no weight decay) so both code paths apply the same update rule
            self.optimizer = optim.Adam(model.parameters(), lr=lr)
    
    def zero_grad(self):
        """Zero gradients."""
        if self.offload:
            for p in self.model.parameters():
                if p.grad is not None:
                    p.grad.zero_()
        else:
            self.optimizer.zero_grad()
    
    def step(self):
        """Perform optimizer step."""
        if not self.offload:
            self.optimizer.step()
            return
        
        lr = self.param_groups[0]['lr']
        beta1, beta2 = 0.9, 0.999
        eps = 1e-8
        
        for p in self.model.parameters():
            if not p.requires_grad or p.grad is None:
                continue
            
            state = self.state[p]
            state['t'] += 1
            t = state['t']
            
            # Move gradient to CPU
            grad_cpu = p.grad.data.cpu()
            
            # Update moments on CPU
            state['m'].mul_(beta1).add_(grad_cpu, alpha=1 - beta1)
            state['v'].mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2)
            
            # Bias correction
            m_hat = state['m'] / (1 - beta1 ** t)
            v_hat = state['v'] / (1 - beta2 ** t)
            
            # Compute update on CPU
            update = m_hat / (v_hat.sqrt() + eps)
            
            # Apply update to GPU parameter
            p.data.add_(update.to(p.device), alpha=-lr)


class GradientOffloading:
    """
    Offload gradients to CPU during backward pass.
    
    Useful for very large models where even gradients don't fit.
    """
    
    def __init__(self, model: nn.Module):
        self.model = model
        self.cpu_gradients = {}
        self._register_hooks()
    
    def _register_hooks(self):
        """Register hooks to offload gradients."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.register_post_accumulate_grad_hook(
                    self._make_hook(name)
                )
    
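    def _make_hook(self, name: str):
        # register_post_accumulate_grad_hook requires PyTorch 2.1+
        def hook(param):
            if param.grad is not None:
                grad_cpu = param.grad.cpu()
                # Accumulate on CPU so multi-step gradient accumulation still works
                if name in self.cpu_gradients:
                    self.cpu_gradients[name].add_(grad_cpu)
                else:
                    self.cpu_gradients[name] = grad_cpu
                # Free the GPU copy
                param.grad = None
        return hook
    
    def restore_gradients(self):
        """Move gradients back to the parameters' device for the optimizer step."""
        for name, param in self.model.named_parameters():
            if name in self.cpu_gradients:
                param.grad = self.cpu_gradients[name].to(param.device)
        # Start the next accumulation window from zero
        self.cpu_gradients.clear()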

Activation Recomputation Strategies

class SelectiveCheckpointing:
    """
    Smart checkpointing strategies.
    
    Not all layers benefit equally from checkpointing.
    Focus on high-memory layers.
    """
    
    @staticmethod
    def should_checkpoint(module: nn.Module) -> bool:
        """
        Decide if a module should be checkpointed.
        
        Rules:
        - Checkpoint attention (quadratic memory)
        - Checkpoint large FFN layers
        - Don't checkpoint small layers (overhead not worth it)
        """
        if isinstance(module, nn.MultiheadAttention):
            return True
        
        # Check for large linear layers
        if isinstance(module, nn.Linear):
            return module.in_features * module.out_features > 1e6
        
        return False
    
    @staticmethod
    def estimate_activation_memory(
        module: nn.Module,
        input_shape: tuple
    ) -> float:
        """
        Estimate activation memory for a module.
        
        Returns memory in MB.
        """
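        # Simplified estimate: fp32 (4-byte) outputs only
        batch_size = input_shape[0]
        
        if isinstance(module, nn.Linear):
            # Output activations: all leading dims (batch, sequence, ...) x out_features
            num_rows = 1
            for dim in input_shape[:-1]:
                num_rows *= dim
            return num_rows * module.out_features * 4 / 1e6
        
        if isinstance(module, nn.MultiheadAttention):
            # Attention weights, quadratic in sequence length
            # (assumes batch_first input of shape (batch, seq_len, d_model))
            seq_len = input_shape[1]
            num_heads = module.num_heads
            return batch_size * num_heads * seq_len * seq_len * 4 / 1e6
        
        return 0.0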


class LayerWiseCheckpointing(nn.Module):
    """
    Apply checkpointing at layer granularity.
    """
    
    def __init__(
        self,
        layers: nn.ModuleList,
        checkpoint_ratio: float = 0.5
    ):
        super().__init__()
        self.layers = layers
        
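        # Checkpoint round(checkpoint_ratio * n) layers, spread evenly
        # (0.5 -> every other layer, 1.0 -> every layer, 0.0 -> none)
        n = len(layers)
        n_ckpt = max(0, min(n, round(checkpoint_ratio * n)))
        ckpt_idx = {int(k * n / n_ckpt) for k in range(n_ckpt)} if n_ckpt else set()
        self.checkpoint_mask = [i in ckpt_idx for i in range(n)]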
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            if self.checkpoint_mask[i] and self.training:
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        
        return x

Memory-Efficient Attention

Attention is often the single largest memory consumer in transformer models. Standard self-attention materializes an $N \times N$ attention matrix, where $N$ is the sequence length. For a single 4096-token sequence with 32 attention heads in float16, that matrix consumes $32 \times 4096 \times 4096 \times 2 = 2^{30}$ bytes, i.e. 1 GiB (about 1.07 GB) per layer. Double the sequence length and the memory quadruples. The methods below tackle this bottleneck from different angles.
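The 1 GiB figure is easy to verify and to extrapolate. The hypothetical helper below counts one materialized copy of the (batch, heads, N, N) score matrix; a naive training implementation may keep several such tensors (pre-softmax scores, softmax output, dropout mask), so real usage can be a multiple of this.

```python
def attention_matrix_bytes(batch: int, heads: int, seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to materialize one (batch, heads, seq_len, seq_len) score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

GiB = 2 ** 30
# One 4096-token sequence, 32 heads, float16 (2 bytes per element)
print(attention_matrix_bytes(1, 32, 4096) / GiB)  # 1.0
# Doubling the sequence length quadruples the memory
print(attention_matrix_bytes(1, 32, 8192) / GiB)  # 4.0
```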
If you are on PyTorch 2.0 or later, use torch.nn.functional.scaled_dot_product_attention, passing is_causal=True for causal masking or an attn_mask for other masking patterns. PyTorch automatically selects the most efficient available backend (FlashAttention, memory-efficient attention, or the math fallback) based on your hardware, dtypes, input shapes, and whether you need dropout or a custom mask. This is almost always the right choice; only implement custom attention when you need something the fused kernels do not support.
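A minimal sketch of the recommended call, assuming PyTorch 2.0+ and tensors laid out as (batch, heads, seq_len, head_dim). It checks that F.scaled_dot_product_attention with is_causal=True matches a naive implementation that builds the full masked score matrix; naive_causal_attention is written here only for comparison.

```python
import torch
import torch.nn.functional as F

def naive_causal_attention(q, k, v):
    """Reference attention that materializes the full (N x N) score matrix."""
    n = q.shape[-2]
    scores = (q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5
    # Mask out future positions (strictly above the diagonal)
    future = torch.ones(n, n, dtype=torch.bool).triu(1)
    scores = scores.masked_fill(future, float("-inf"))
    return scores.softmax(dim=-1) @ v

torch.manual_seed(0)
B, H, N, D = 2, 4, 128, 32
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))

# Fused call: PyTorch picks the backend and uses the default 1/sqrt(D) scale
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
reference = naive_causal_attention(q, k, v)
print(torch.allclose(fused, reference, atol=1e-4))
```

On CPU this simply checks numerical agreement; the memory savings of the fused kernels matter most on GPUs with long sequences.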
class MemoryEfficientAttention(nn.Module):
    """
    Memory-efficient attention implementations.
    
    Standard attention: O(n^2) memory for the full attention matrix.
    Chunked attention: O(n * chunk_size) memory -- process queries in chunks,
    computing attention scores against all keys/values for each chunk.
    Flash Attention: O(n) memory -- never materializes the full matrix at all,
    using a tiling algorithm that stays in GPU SRAM.
    """
    
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        chunk_size: int = 1024
    ):
        super().__init__()
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.chunk_size = chunk_size
        
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        
        # Compute Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, N, D)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # Chunked attention to reduce peak memory
        output = self.chunked_attention(q, k, v)
        
        output = output.transpose(1, 2).reshape(B, N, C)
        return self.out_proj(output)
    
    def chunked_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute attention in chunks.
        
        Instead of computing full NxN attention matrix,
        process in chunks to reduce memory.
        """
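        B, H, N, D = q.shape
        scale = D ** -0.5
        
        outputs = []
        for i in range(0, N, self.chunk_size):
            q_chunk = q[:, :, i:i + self.chunk_size]
            
            # Scores for this query chunk only: (B, H, chunk, N) instead of (B, H, N, N)
            attn = torch.matmul(q_chunk, k.transpose(-2, -1)) * scale
            attn = attn.softmax(dim=-1)
            
            outputs.append(torch.matmul(attn, v))
        
        # In training mode autograd still saves every chunk's attention matrix
        # for backward, so peak training memory only drops if each chunk is
        # also wrapped in torch.utils.checkpoint.checkpoint
        return torch.cat(outputs, dim=2)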


class FlashAttention(nn.Module):
    """
    Flash Attention wrapper (requires flash-attn package).
    
    Key innovations:
    - Tiling to fit in SRAM
    - No materialization of attention matrix
    - IO-aware algorithm design
    
    Memory: O(n) instead of O(n²)
    Speed: 2-4x faster than standard attention
    """
    
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.dropout = dropout
        
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        
        self._check_flash_available()
    
    def _check_flash_available(self):
        """Check if Flash Attention is available."""
        try:
            from flash_attn import flash_attn_func
            self.flash_attn = flash_attn_func
            self.use_flash = True
        except ImportError:
            print("Flash Attention not available. Using standard attention.")
            self.use_flash = False
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        
        # flash-attn kernels need CUDA tensors in fp16 or bf16
        if self.use_flash and x.is_cuda and qkv.dtype in (torch.float16, torch.bfloat16):
            # flash_attn_func expects (B, N, H, D), which is already the layout
            # of each qkv slice, so no transpose is needed
            q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
            
            out = self.flash_attn(q, k, v, dropout_p=self.dropout if self.training else 0.0)
            out = out.reshape(B, N, C)
        else:
            # Standard attention
            qkv = qkv.permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]
            
            scale = self.head_dim ** -0.5
            attn = (q @ k.transpose(-2, -1)) * scale
            attn = attn.softmax(dim=-1)
            
            out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        
        return self.out_proj(out)

Efficient Batch Processing

class GradientAccumulation:
    """
    Accumulate gradients over multiple mini-batches.
    
    Simulates larger batch sizes without memory increase.
    """
    
    def __init__(
        self,
        model: nn.Module,
        optimizer: optim.Optimizer,
        accumulation_steps: int = 4
    ):
        self.model = model
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps
        self.current_step = 0
    
    def backward(self, loss: torch.Tensor):
        """
        Backward with gradient accumulation.
        """
        # Scale loss by accumulation steps
        scaled_loss = loss / self.accumulation_steps
        scaled_loss.backward()
        
        self.current_step += 1
        
        if self.current_step >= self.accumulation_steps:
            # Perform optimizer step
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.current_step = 0
            return True  # Step was performed
        
        return False  # Step not performed yet


class MicroBatching:
    """
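    Split a batch into micro-batches for the forward pass.
    
    Useful for inference/evaluation when a full batch doesn't fit.
    In training mode every micro-batch's autograd graph stays alive
    until backward, so peak memory is not reduced; use
    GradientAccumulation for training instead. Layers such as
    BatchNorm also see only micro-batch statistics.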
    """
    
    @staticmethod
    def micro_batch_forward(
        model: nn.Module,
        batch: torch.Tensor,
        micro_batch_size: int
    ) -> torch.Tensor:
        """
        Forward in micro-batches.
        """
        outputs = []
        
        for i in range(0, batch.size(0), micro_batch_size):
            micro_batch = batch[i:i + micro_batch_size]
            
            with torch.no_grad() if not model.training else torch.enable_grad():
                output = model(micro_batch)
            
            outputs.append(output)
        
        return torch.cat(outputs, dim=0)


class InPlaceOperations:
    """
    In-place operations to reduce memory allocations.
    
    Use with caution - can break gradient computation!
    """
    
    @staticmethod
    def inplace_relu(x: torch.Tensor) -> torch.Tensor:
        """In-place ReLU."""
        return torch.relu_(x)
    
    @staticmethod
    def inplace_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """In-place addition."""
        return x.add_(y)
    
    @staticmethod
    def check_safe_inplace(tensor: torch.Tensor) -> bool:
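        """
        Cheap pre-check for in-place operations.
        
        Returns False for a leaf tensor that requires grad; PyTorch raises
        an error for in-place ops on these. It cannot detect the other
        unsafe case, a tensor saved by autograd for the backward pass.
        Autograd's version counter reports that as a RuntimeError when
        backward() runs.
        """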
        if tensor.requires_grad and tensor.is_leaf:
            return False
        return True
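check_safe_inplace only catches the leaf case. The other case is overwriting a tensor that autograd saved for backward. Autograd's version counter catches it, but only once backward() runs. The minimal demonstration below assumes a recent PyTorch. It uses sigmoid because sigmoid saves its output for the backward pass.

```python
import torch

# Case 1: an in-place op on a leaf that requires grad fails immediately.
w = torch.ones(3, requires_grad=True)
try:
    w.add_(1.0)
    leaf_error = None
except RuntimeError as err:
    leaf_error = str(err)

# Case 2: y is not a leaf, so it passes the leaf check, but sigmoid saved
# y for its backward pass; modifying y in place invalidates that saved value.
x = torch.randn(3, requires_grad=True)
y = torch.sigmoid(x)
y.mul_(2)  # in-place: bumps y's version counter
try:
    y.sum().backward()
    saved_error = None
except RuntimeError as err:
    saved_error = str(err)

print(leaf_error is not None, saved_error is not None)  # True True
```

Because the second error surfaces only at backward time, possibly far from the offending line, it is safest to restrict in-place ops to tensors whose values you know are not needed for gradients.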

Memory-Efficient Data Loading

import numpy as np


class MemoryEfficientDataLoader:
    """
    Reduce data loading memory overhead.
    """
    
    @staticmethod
    def create_efficient_loader(
        dataset,
        batch_size: int,
        num_workers: int = 4,
        pin_memory: bool = True
    ):
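        """
        Create a memory-efficient data loader.
        
        Tips:
        - pin_memory speeds up host-to-GPU copies (at the cost of page-locked host RAM)
        - num_workers: about 4 per GPU is a common starting point; tune for your CPU
        - prefetch_factor is the number of batches each worker loads in advance,
          so up to num_workers * prefetch_factor batches are held in memory at once
        """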
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=pin_memory,
            # Batches loaded in advance per worker; lower values use less memory
            prefetch_factor=2 if num_workers > 0 else None,
            # Keep worker processes alive between epochs instead of re-spawning them
            persistent_workers=num_workers > 0,
        )
    
    @staticmethod
    def use_memory_mapping(file_path: str):
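        """
        Use memory-mapped files for large datasets.
        
        Data stays on disk and pages are read on demand. The file is
        assumed to hold raw float32 values with no header; the result is
        a flat 1-D array, so reshape it to the sample layout.
        """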
        return np.memmap(file_path, dtype='float32', mode='r')
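To put numbers on the prefetch_factor trade-off, note that the PyTorch DataLoader documentation defines it as the number of batches each worker loads in advance. So roughly num_workers * prefetch_factor batches sit in host memory at once. The back-of-envelope sketch below uses a hypothetical helper name and an illustrative image shape.

```python
def prefetched_bytes(batch_size: int, bytes_per_sample: int,
                     num_workers: int, prefetch_factor: int) -> int:
    """Approximate host memory held by batches prefetched by DataLoader workers."""
    batches_in_flight = num_workers * prefetch_factor
    return batches_in_flight * batch_size * bytes_per_sample


# Hypothetical workload: 3x224x224 float32 images (4 bytes per element), batch size 64.
sample_bytes = 3 * 224 * 224 * 4
total = prefetched_bytes(64, sample_bytes, num_workers=4, prefetch_factor=2)
print(f"{total / 1e6:.1f} MB")  # 308.3 MB
```

This ignores per-worker process overhead and the copy of the dataset object each worker holds, so treat it as a lower bound. Halving prefetch_factor or num_workers halves this figure.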


class LowMemoryDataset(torch.utils.data.Dataset):
    """
    Dataset that loads data on-demand.
    """
    
    def __init__(self, file_paths: List[str]):
        self.file_paths = file_paths
    
    def __len__(self):
        return len(self.file_paths)
    
    def __getitem__(self, idx):
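        # Load only when needed; map_location='cpu' keeps tensors that were
        # saved from a GPU on the CPU, where DataLoader workers expect them
        data = torch.load(self.file_paths[idx], map_location='cpu')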
        return data
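use_memory_mapping and LowMemoryDataset can be combined. Store all samples in one raw float32 file, and have __getitem__ read a single row through np.memmap, so only the rows actually requested are paged in from disk. This is a sketch under stated assumptions: the file layout (rows of num_features float32 values, no header) and the class name MemmapRowDataset are hypothetical. The class implements the same __len__/__getitem__ protocol as LowMemoryDataset.

```python
import os
import tempfile

import numpy as np


class MemmapRowDataset:
    """Map-style dataset that reads one row at a time from a raw float32 file."""

    def __init__(self, path: str, num_features: int):
        # mode='r' maps the file read-only; nothing is read until it is indexed.
        flat = np.memmap(path, dtype="float32", mode="r")
        self.rows = flat.reshape(-1, num_features)

    def __len__(self) -> int:
        return self.rows.shape[0]

    def __getitem__(self, idx: int) -> np.ndarray:
        # np.array(...) copies the row out of the mapping into ordinary memory.
        return np.array(self.rows[idx])


# Write a small demo file: 1000 samples x 16 features.
path = os.path.join(tempfile.mkdtemp(), "features.f32")
writer = np.memmap(path, dtype="float32", mode="w+", shape=(1000, 16))
writer[:] = np.arange(1000 * 16, dtype="float32").reshape(1000, 16)
writer.flush()
del writer

ds = MemmapRowDataset(path, num_features=16)
print(len(ds), ds[2][:3])  # 1000 [32. 33. 34.]
```

Copying the row in __getitem__ means a returned sample does not keep the mapping alive. It also lets a framework convert the sample into its own tensor type without touching the file again.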

Memory Debugging Tools

class MemoryDebugger:
    """
    Tools for debugging memory issues.
    """
    
    @staticmethod
    def find_memory_leaks():
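        """
        List live CUDA tensors, largest first.

        This is a snapshot, not a leak detector: tensors whose count keeps
        growing between snapshots taken at the same point of consecutive
        iterations are the leak candidates.
        """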
        if not torch.cuda.is_available():
            return []
        
        gc.collect()
        torch.cuda.empty_cache()
        
        leaks = []
        
        for obj in gc.get_objects():
            if isinstance(obj, torch.Tensor):
                if obj.is_cuda:
                    leaks.append({
                        'shape': obj.shape,
                        'dtype': obj.dtype,
                        'size_mb': obj.numel() * obj.element_size() / 1e6,
                        'requires_grad': obj.requires_grad
                    })
        
        return sorted(leaks, key=lambda x: -x['size_mb'])
    
    @staticmethod
    def memory_summary():
        """Print detailed memory summary."""
        if not torch.cuda.is_available():
            print("CUDA not available")
            return
        
        print(torch.cuda.memory_summary())
    
    @staticmethod
    def clear_memory():
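        """
        Return cached, unused GPU memory to the driver.

        empty_cache() cannot free tensors that are still referenced;
        delete those references first (e.g. del model, optimizer).
        """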
        gc.collect()
        
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
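find_memory_leaks lists every live CUDA tensor, which is large and noisy. A leak usually shows up as growth between two snapshots taken at the same point of consecutive iterations. The sketch below applies that diffing idea and runs on CPU. The helper name live_tensor_counts is hypothetical. It uses the classic leak of storing the loss tensor instead of loss.item(), which keeps each loss tensor and its attached autograd graph alive.

```python
import gc
from collections import Counter

import torch


def live_tensor_counts() -> Counter:
    """Count live tensors visible to the garbage collector, keyed by (shape, dtype)."""
    gc.collect()
    counts = Counter()
    for obj in gc.get_objects():
        try:
            if isinstance(obj, torch.Tensor):
                counts[(tuple(obj.shape), str(obj.dtype))] += 1
        except Exception:
            continue  # some objects raise on inspection; skip them
    return counts


model = torch.nn.Linear(8, 1)
history = []  # the leak: grows every iteration

before = live_tensor_counts()
for _ in range(5):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    history.append(loss)  # leaks; history.append(loss.item()) would not
after = live_tensor_counts()

growth = after - before  # Counter subtraction keeps only positive differences
print(growth)
```

On a GPU, the same diff can be restricted to obj.is_cuda tensors. Only tensors that have a Python object are visible this way. Activations saved inside the autograd graph are not listed individually, but they stay alive as long as the tensor holding the graph does.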


# Usage demonstration
def memory_optimization_demo():
    """Demonstrate memory optimization techniques."""
    
    summary = """
    ╔════════════════════════════════════════════════════════════════╗
    ║           MEMORY OPTIMIZATION TECHNIQUES                        ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                 ║
    ║  TECHNIQUE              MEMORY SAVING    COMPUTE COST           ║
    ║  ─────────────────────────────────────────────────────────────  ║
    ║  Mixed Precision        ~50%             ~0% (often faster!)    ║
    ║  Gradient Checkpointing ~60-70%          ~30-40%                ║
    ║  Gradient Accumulation  Linear in steps  ~0%                    ║
    ║  Flash Attention        ~50% for attn    ~0% (often faster!)    ║
    ║  CPU Offloading         ~50-70%          ~20-50%                ║
    ║                                                                 ║
    ╠════════════════════════════════════════════════════════════════╣
    ║  RECOMMENDATION ORDER:                                          ║
    ║  1. Enable mixed precision (always)                             ║
    ║  2. Use Flash Attention (if available)                          ║
    ║  3. Add gradient checkpointing                                  ║
    ║  4. Gradient accumulation for larger effective batch            ║
    ║  5. CPU offloading as last resort                               ║
    ╚════════════════════════════════════════════════════════════════╝
    """
    print(summary)

memory_optimization_demo()
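The checkpointing row of the table can be estimated from two simple assumptions. First, a backward pass costs about twice a forward pass (a common rule of thumb, not a measurement), so re-running the forward once adds about 1/(1+2), roughly 33%, to each step. Second, with n equally sized layers split into k segments, as checkpoint_sequential does, peak activation memory is roughly k stored segment inputs plus one segment's n/k activations while it is recomputed, which is smallest near k = sqrt(n). The sketch below uses hypothetical helper names and ignores parameters, gradients, and optimizer states, which checkpointing does not touch.

```python
def checkpoint_compute_overhead(backward_to_forward: float = 2.0) -> float:
    """Extra compute from re-running the forward pass once during backward,
    relative to a normal forward + backward step."""
    return 1.0 / (1.0 + backward_to_forward)


def checkpointed_activation_units(num_layers: int, segments: int) -> float:
    """Activations held at peak, in units of one layer's activations:
    one stored input per segment plus one segment being recomputed."""
    return segments + num_layers / segments


n = 64
best_k = min(range(1, n + 1), key=lambda k: checkpointed_activation_units(n, k))
print(f"overhead ~{checkpoint_compute_overhead():.0%}")          # overhead ~33%
print(best_k, checkpointed_activation_units(n, best_k), "vs", n)  # 8 16.0 vs 64
```

Real layers differ in activation size and cost, so in practice the segment boundaries are usually chosen around the most memory-hungry layers rather than spaced evenly.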

Exercises

Use the MemoryProfiler to analyze a model’s memory usage at each layer. Identify the biggest memory consumers and optimize them.
Create a system that automatically decides which layers to checkpoint based on their memory usage vs. compute cost.
Implement training that stays within a fixed memory budget by dynamically adjusting batch size and checkpointing.

What’s Next?

Quantization Deep Dive

Post-training and quantization-aware training

Knowledge Distillation

Transfer knowledge to smaller models