Memory-Efficient Training

Understanding GPU Memory

Where does memory go during training?
Component            Memory Usage
Model parameters     4 bytes × params (fp32)
Gradients            4 bytes × params (fp32)
Optimizer states     8 bytes × params (Adam)
Activations          Depends on batch size, model depth
Temporary buffers    Variable
For a 7B parameter model:
  • Parameters: 28 GB
  • Gradients: 28 GB
  • Adam states: 56 GB
  • Total: ~112 GB (without activations!)
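
These totals are simple byte counts; here is a quick back-of-the-envelope sketch in plain Python:

# Rough estimate for a 7B-parameter model trained with Adam in fp32
params = 7e9
weights_gb   = 4 * params / 1e9   # 28 GB
gradients_gb = 4 * params / 1e9   # 28 GB
adam_gb      = 8 * params / 1e9   # 56 GB (first and second moments)
print(f"Total without activations: {weights_gb + gradients_gb + adam_gb:.0f} GB")  # ~112 GB
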
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from typing import Optional, Tuple, List, Callable
import gc
import functools
import numpy as np  # needed for the memory-mapped loader in MemoryEfficientDataLoader

torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Memory Profiling

class MemoryProfiler:
    """
    Profile GPU memory usage during training.
    """
    
    def __init__(self, device: torch.device = device):
        self.device = device
        self.snapshots = []
    
    def get_memory_stats(self) -> dict:
        """Get current memory statistics."""
        if not torch.cuda.is_available():
            return {'allocated_mb': 0.0, 'cached_mb': 0.0, 'max_allocated_mb': 0.0}
        
        return {
            'allocated_mb': torch.cuda.memory_allocated(self.device) / 1e6,
            'cached_mb': torch.cuda.memory_reserved(self.device) / 1e6,
            'max_allocated_mb': torch.cuda.max_memory_allocated(self.device) / 1e6,
        }
    
    def snapshot(self, label: str = ""):
        """Take a memory snapshot."""
        stats = self.get_memory_stats()
        stats['label'] = label
        self.snapshots.append(stats)
        return stats
    
    def reset_peak_stats(self):
        """Reset peak memory tracking."""
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
    
    def print_summary(self):
        """Print memory usage summary."""
        print("\n" + "="*60)
        print("MEMORY USAGE SUMMARY")
        print("="*60)
        
        for snap in self.snapshots:
            print(f"{snap['label']:30} | "
                  f"Allocated: {snap['allocated_mb']:8.1f} MB | "
                  f"Peak: {snap['max_allocated_mb']:8.1f} MB")
    
    @staticmethod
    def estimate_model_memory(model: nn.Module) -> dict:
        """Estimate memory requirements for a model."""
        param_mem = sum(p.numel() * p.element_size() for p in model.parameters())
        buffer_mem = sum(b.numel() * b.element_size() for b in model.buffers())
        
        # Gradient memory (same as params for most cases)
        grad_mem = param_mem
        
        # Optimizer states (Adam: 2x for m and v)
        adam_mem = 2 * param_mem
        
        return {
            'parameters_mb': param_mem / 1e6,
            'buffers_mb': buffer_mem / 1e6,
            'gradients_mb': grad_mem / 1e6,
            'adam_states_mb': adam_mem / 1e6,
            'total_mb': (param_mem + buffer_mem + grad_mem + adam_mem) / 1e6
        }


# Memory-efficient model analysis
def analyze_memory_per_layer(model: nn.Module, input_shape: tuple):
    """Analyze memory usage per layer."""
    profiler = MemoryProfiler()
    
    if not torch.cuda.is_available():
        print("CUDA not available for memory analysis")
        return
    
    model = model.to(device)
    x = torch.randn(*input_shape).to(device)
    
    profiler.reset_peak_stats()
    profiler.snapshot("After model loading")
    
    # Forward with hooks
    activations = {}
    
    def save_activation(name):
        def hook(module, input, output):
            if isinstance(output, torch.Tensor):
                activations[name] = output.numel() * output.element_size() / 1e6
        return hook
    
    hook_handles = []
    for name, module in model.named_modules():
        hook_handles.append(module.register_forward_hook(save_activation(name)))
    
    # Forward pass
    with torch.no_grad():
        _ = model(x)
    
    # Remove the hooks so they do not linger on the model
    for handle in hook_handles:
        handle.remove()
    
    profiler.snapshot("After forward pass")
    
    # Print activation memory
    print("\nActivation memory per layer:")
    for name, mem in sorted(activations.items(), key=lambda x: -x[1])[:10]:
        print(f"  {name}: {mem:.2f} MB")

Gradient Checkpointing

class GradientCheckpointing:
    """
    Trade compute for memory by recomputing activations.
    
    Instead of storing all activations for backward pass,
    we store only at checkpoint boundaries and recompute
    intermediate activations during backward.
    
    Memory savings: O(sqrt(n)) vs O(n) for n layers
    Compute overhead: ~30-40% extra compute (activations are recomputed in backward)
    """
    
    @staticmethod
    def checkpoint_sequential(
        functions: List[nn.Module],
        segments: int,
        input: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply checkpointing to sequential modules.
        
        Args:
            functions: List of modules to apply sequentially
            segments: Number of checkpoint segments
            input: Input tensor
        """
        return checkpoint_sequential(functions, segments, input, use_reentrant=False)
    
    @staticmethod
    def checkpoint_function(
        function: Callable,
        *args,
        use_reentrant: bool = False
    ) -> torch.Tensor:
        """
        Checkpoint a single function.
        
        Args:
            function: Function to checkpoint
            *args: Function arguments
            use_reentrant: Use reentrant version (False recommended)
        """
        return checkpoint(function, *args, use_reentrant=use_reentrant)
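
A minimal sketch of the sequential variant on a hypothetical stack of feed-forward blocks; with two segments, only the boundary activations are stored and the rest are recomputed:

# Sketch: checkpoint a stack of blocks in two segments
ff_stack = nn.Sequential(
    *[nn.Sequential(nn.Linear(256, 256), nn.ReLU()) for _ in range(8)]
).to(device)
ckpt_input = torch.randn(16, 256, device=device, requires_grad=True)

ckpt_out = GradientCheckpointing.checkpoint_sequential(list(ff_stack), segments=2, input=ckpt_input)
ckpt_out.sum().backward()  # intermediate activations are recomputed during backward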


# Transformer with gradient checkpointing
class CheckpointedTransformerBlock(nn.Module):
    """
    Transformer block with optional gradient checkpointing.
    """
    
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        use_checkpoint: bool = False
    ):
        super().__init__()
        
        self.use_checkpoint = use_checkpoint
        
        # Self-attention
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        
        # Feedforward
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
            nn.Dropout(dropout)
        )
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def _attention_block(self, x: torch.Tensor) -> torch.Tensor:
        """Self-attention sub-block."""
        attn_out, _ = self.self_attn(x, x, x)
        return self.dropout(attn_out)
    
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        """Feedforward sub-block."""
        return self.ff(x)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_checkpoint and self.training:
            # Checkpoint attention
            attn_out = checkpoint(
                self._attention_block, 
                x,
                use_reentrant=False
            )
            x = self.norm1(x + attn_out)
            
            # Checkpoint feedforward
            ff_out = checkpoint(
                self._ff_block,
                x,
                use_reentrant=False
            )
            x = self.norm2(x + ff_out)
        else:
            # Standard forward
            x = self.norm1(x + self._attention_block(x))
            x = self.norm2(x + self._ff_block(x))
        
        return x


class CheckpointedTransformer(nn.Module):
    """
    Full transformer with gradient checkpointing.
    """
    
    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        num_layers: int = 6,
        vocab_size: int = 10000,
        use_checkpoint: bool = True
    ):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        self.layers = nn.ModuleList([
            CheckpointedTransformerBlock(
                d_model, nhead, 
                use_checkpoint=use_checkpoint
            )
            for _ in range(num_layers)
        ])
        
        self.output = nn.Linear(d_model, vocab_size)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embedding(x)
        
        for layer in self.layers:
            x = layer(x)
        
        return self.output(x)


# Compare memory with/without checkpointing
def compare_checkpointing_memory():
    """Compare memory usage with and without checkpointing."""
    
    if not torch.cuda.is_available():
        print("CUDA not available")
        return
    
    batch_size, seq_len = 32, 512
    d_model, num_layers = 512, 12
    
    # Without checkpointing
    model_no_ckpt = CheckpointedTransformer(
        d_model=d_model, num_layers=num_layers,
        use_checkpoint=False
    ).to(device)
    
    torch.cuda.reset_peak_memory_stats()
    x = torch.randint(0, 10000, (batch_size, seq_len)).to(device)
    
    out = model_no_ckpt(x)
    loss = out.sum()
    loss.backward()
    
    mem_no_ckpt = torch.cuda.max_memory_allocated() / 1e9
    
    del model_no_ckpt, out, loss
    torch.cuda.empty_cache()
    
    # With checkpointing
    model_ckpt = CheckpointedTransformer(
        d_model=d_model, num_layers=num_layers,
        use_checkpoint=True
    ).to(device)
    
    torch.cuda.reset_peak_memory_stats()
    
    out = model_ckpt(x)
    loss = out.sum()
    loss.backward()
    
    mem_ckpt = torch.cuda.max_memory_allocated() / 1e9
    
    print(f"Without checkpointing: {mem_no_ckpt:.2f} GB")
    print(f"With checkpointing: {mem_ckpt:.2f} GB")
    print(f"Memory saved: {(1 - mem_ckpt/mem_no_ckpt)*100:.1f}%")

Mixed Precision Training

class MixedPrecisionTraining:
    """
    Train with mixed FP16/FP32 precision.
    
    Benefits:
    - 2x memory reduction for activations
    - Faster matrix multiplications on modern GPUs
    - Minimal accuracy loss with proper scaling
    
    Key components:
    - Automatic casting (autocast)
    - Gradient scaling (prevent underflow)
    """
    
    def __init__(self, enabled: bool = True):
        self.enabled = enabled
        self.scaler = torch.amp.GradScaler('cuda') if enabled else None
    
    def training_step(
        self,
        model: nn.Module,
        batch: tuple,
        optimizer: optim.Optimizer,
        criterion: nn.Module
    ) -> float:
        """
        Perform one training step with mixed precision.
        """
        inputs, targets = batch
        
        optimizer.zero_grad()
        
        if self.enabled:
            # Mixed precision forward
            with torch.amp.autocast('cuda'):
                outputs = model(inputs)
                loss = criterion(outputs, targets)
            
            # Scaled backward
            self.scaler.scale(loss).backward()
            
            # Unscale gradients for clipping
            self.scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            # Optimizer step with scaling
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            # Standard precision
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        
        return loss.item()
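
A minimal usage sketch; the toy classifier, random batch, and hyperparameters are placeholders, and AMP is only enabled when CUDA is present:

# Hypothetical usage of the mixed-precision helper on a toy classifier
toy_model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10)).to(device)
toy_optimizer = optim.AdamW(toy_model.parameters(), lr=1e-3)
toy_criterion = nn.CrossEntropyLoss()

amp_trainer = MixedPrecisionTraining(enabled=torch.cuda.is_available())
toy_batch = (
    torch.randn(32, 128, device=device),
    torch.randint(0, 10, (32,), device=device),
)
print(f"loss: {amp_trainer.training_step(toy_model, toy_batch, toy_optimizer, toy_criterion):.4f}")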


# BFloat16 for better stability
class BFloat16Training:
    """
    BFloat16 training (available on Ampere+ GPUs).
    
    Advantages over FP16:
    - Same exponent range as FP32
    - No gradient scaling needed
    - Better numerical stability
    """
    
    @staticmethod
    def is_available() -> bool:
        """Check if BF16 is available."""
        if not torch.cuda.is_available():
            return False
        return torch.cuda.get_device_capability()[0] >= 8
    
    @staticmethod
    def training_step(
        model: nn.Module,
        batch: tuple,
        optimizer: optim.Optimizer,
        criterion: nn.Module
    ) -> float:
        """Training step with BF16."""
        inputs, targets = batch
        
        optimizer.zero_grad()
        
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        
        # No scaling needed for BF16!
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        return loss.item()


# Precision comparison
def demonstrate_precision_formats():
    """Show memory usage for different precisions."""
    
    num_params = 1_000_000  # 1M parameters
    
    formats = {
        'float32': (4, "Standard precision"),
        'float16': (2, "Half precision"),
        'bfloat16': (2, "Brain float"),
        'int8': (1, "Quantized"),
    }
    
    print("Memory per million parameters:")
    print("-" * 45)
    for name, (bytes_per, desc) in formats.items():
        mem = num_params * bytes_per / 1e6
        print(f"{name:10} ({desc:20}): {mem:.1f} MB")

demonstrate_precision_formats()

CPU Offloading

class CPUOffloadOptimizer:
    """
    Offload optimizer states to CPU to save GPU memory.
    
    Trade-off: Slower training but can train larger models.
    """
    
    def __init__(
        self,
        model: nn.Module,
        lr: float = 1e-4,
        offload: bool = True
    ):
        self.model = model
        self.offload = offload
        
        if offload:
            # Keep parameters on GPU but states on CPU
            self.param_groups = [
                {
                    'params': list(model.parameters()),
                    'lr': lr
                }
            ]
            
            # Initialize optimizer states on CPU
            self.state = {}
            for p in model.parameters():
                if p.requires_grad:
                    # Adam states
                    self.state[p] = {
                        'm': torch.zeros_like(p.data, device='cpu'),
                        'v': torch.zeros_like(p.data, device='cpu'),
                        't': 0
                    }
        else:
            self.optimizer = optim.AdamW(model.parameters(), lr=lr)
    
    def zero_grad(self):
        """Zero gradients."""
        if self.offload:
            for p in self.model.parameters():
                if p.grad is not None:
                    p.grad.zero_()
        else:
            self.optimizer.zero_grad()
    
    def step(self):
        """Perform optimizer step."""
        if not self.offload:
            self.optimizer.step()
            return
        
        lr = self.param_groups[0]['lr']
        beta1, beta2 = 0.9, 0.999
        eps = 1e-8
        
        for p in self.model.parameters():
            if not p.requires_grad or p.grad is None:
                continue
            
            state = self.state[p]
            state['t'] += 1
            t = state['t']
            
            # Move gradient to CPU
            grad_cpu = p.grad.data.cpu()
            
            # Update moments on CPU
            state['m'].mul_(beta1).add_(grad_cpu, alpha=1 - beta1)
            state['v'].mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2)
            
            # Bias correction
            m_hat = state['m'] / (1 - beta1 ** t)
            v_hat = state['v'] / (1 - beta2 ** t)
            
            # Compute update on CPU
            update = m_hat / (v_hat.sqrt() + eps)
            
            # Apply update to GPU parameter
            p.data.add_(update.to(p.device), alpha=-lr)
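
A hedged sketch of one step with the offloaded optimizer; the one-layer model is a stand-in, and the extra CPU-GPU transfers make each step noticeably slower:

# Sketch: one training step with Adam moments kept on the CPU
offload_model = nn.Linear(512, 512).to(device)
cpu_adam = CPUOffloadOptimizer(offload_model, lr=1e-3, offload=True)

cpu_adam.zero_grad()
offload_model(torch.randn(8, 512, device=device)).sum().backward()
cpu_adam.step()  # moments are updated on the CPU, then the update is applied to the GPU weights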


class GradientOffloading:
    """
    Offload gradients to CPU during backward pass.
    
    Useful for very large models where even gradients don't fit.
    """
    
    def __init__(self, model: nn.Module):
        self.model = model
        self.cpu_gradients = {}
        self._register_hooks()
    
    def _register_hooks(self):
        """Register hooks to offload gradients."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.register_post_accumulate_grad_hook(
                    self._make_hook(name)
                )
    
    def _make_hook(self, name: str):
        def hook(param):
            if param.grad is not None:
                # Move gradient to CPU
                self.cpu_gradients[name] = param.grad.cpu()
                # Free GPU memory
                param.grad = None
        return hook
    
    def restore_gradients(self):
        """Move gradients back to GPU for optimizer."""
        for name, param in self.model.named_parameters():
            if name in self.cpu_gradients:
                param.grad = self.cpu_gradients[name].to(param.device)
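
A sketch of how the offloader fits into a step (the tiny model and SGD optimizer are placeholders): gradients land on the CPU during backward and must be restored before the optimizer runs.

# Hypothetical usage: offload gradients during backward, restore before the update
hooked_model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10)).to(device)
offloader = GradientOffloading(hooked_model)
hooked_optimizer = optim.SGD(hooked_model.parameters(), lr=0.1)

hooked_model(torch.randn(8, 256, device=device)).sum().backward()  # hooks move each gradient to CPU
offloader.restore_gradients()   # copy gradients back to the parameter device for the optimizer
hooked_optimizer.step()
hooked_optimizer.zero_grad()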

Activation Recomputation Strategies

class SelectiveCheckpointing:
    """
    Smart checkpointing strategies.
    
    Not all layers benefit equally from checkpointing.
    Focus on high-memory layers.
    """
    
    @staticmethod
    def should_checkpoint(module: nn.Module) -> bool:
        """
        Decide if a module should be checkpointed.
        
        Rules:
        - Checkpoint attention (quadratic memory)
        - Checkpoint large FFN layers
        - Don't checkpoint small layers (overhead not worth it)
        """
        if isinstance(module, nn.MultiheadAttention):
            return True
        
        # Check for large linear layers
        if isinstance(module, nn.Linear):
            return module.in_features * module.out_features > 1e6
        
        return False
    
    @staticmethod
    def estimate_activation_memory(
        module: nn.Module,
        input_shape: tuple
    ) -> float:
        """
        Estimate activation memory for a module.
        
        Returns memory in MB.
        """
        # This is a simplified estimate
        batch_size = input_shape[0]
        
        if isinstance(module, nn.Linear):
            # Output activations
            return batch_size * module.out_features * 4 / 1e6
        
        if isinstance(module, nn.MultiheadAttention):
            # Attention weights (quadratic in sequence length)
            seq_len = input_shape[1]
            num_heads = module.num_heads
            return batch_size * num_heads * seq_len * seq_len * 4 / 1e6
        
        return 0.0
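
The heuristic can be applied by scanning a model's modules; here is a small sketch using the CheckpointedTransformerBlock defined earlier:

# Sketch: list the sub-modules the heuristic would checkpoint
scan_block = CheckpointedTransformerBlock(d_model=512, nhead=8)
for name, module in scan_block.named_modules():
    if SelectiveCheckpointing.should_checkpoint(module):
        print(f"would checkpoint: {name} ({type(module).__name__})")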


class LayerWiseCheckpointing(nn.Module):
    """
    Apply checkpointing at layer granularity.
    """
    
    def __init__(
        self,
        layers: nn.ModuleList,
        checkpoint_ratio: float = 0.5
    ):
        super().__init__()
        self.layers = layers
        
        # Checkpoint every k-th layer, where k ≈ 1 / checkpoint_ratio
        stride = max(1, round(1 / checkpoint_ratio))
        self.checkpoint_mask = [
            i % stride == 0
            for i in range(len(layers))
        ]
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            if self.checkpoint_mask[i] and self.training:
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        
        return x
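
A short sketch wrapping a stack of the transformer blocks defined above, checkpointing every other layer:

# Sketch: checkpoint half of an 8-block stack
block_list = nn.ModuleList([CheckpointedTransformerBlock(256, 4) for _ in range(8)])
layer_stack = LayerWiseCheckpointing(block_list, checkpoint_ratio=0.5).to(device)
layer_stack.train()

stack_in = torch.randn(4, 32, 256, device=device, requires_grad=True)
layer_stack(stack_in).sum().backward()  # every other block is recomputed in backward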

Memory-Efficient Attention

class MemoryEfficientAttention(nn.Module):
    """
    Memory-efficient attention implementations.
    
    Standard attention: O(n²) memory for attention matrix
    Efficient versions reduce this significantly.
    """
    
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        chunk_size: int = 1024
    ):
        super().__init__()
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.chunk_size = chunk_size
        
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        
        # Compute Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, N, D)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # Chunked attention to reduce peak memory
        output = self.chunked_attention(q, k, v)
        
        output = output.transpose(1, 2).reshape(B, N, C)
        return self.out_proj(output)
    
    def chunked_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute attention in chunks.
        
        Instead of computing full NxN attention matrix,
        process in chunks to reduce memory.
        """
        B, H, N, D = q.shape
        scale = D ** -0.5
        
        output = torch.zeros_like(v)
        
        for i in range(0, N, self.chunk_size):
            end_i = min(i + self.chunk_size, N)
            q_chunk = q[:, :, i:end_i]
            
            # Compute attention for this query chunk
            attn = torch.matmul(q_chunk, k.transpose(-2, -1)) * scale
            attn = attn.softmax(dim=-1)
            
            output[:, :, i:end_i] = torch.matmul(attn, v)
        
        return output


class FlashAttention(nn.Module):
    """
    Flash Attention wrapper (requires flash-attn package).
    
    Key innovations:
    - Tiling to fit in SRAM
    - No materialization of attention matrix
    - IO-aware algorithm design
    
    Memory: O(n) instead of O(n²)
    Speed: 2-4x faster than standard attention
    """
    
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.dropout = dropout
        
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        
        self._check_flash_available()
    
    def _check_flash_available(self):
        """Check if Flash Attention is available."""
        try:
            from flash_attn import flash_attn_func
            self.flash_attn = flash_attn_func
            self.use_flash = True
        except ImportError:
            print("Flash Attention not available. Using standard attention.")
            self.use_flash = False
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        
        if self.use_flash:
            # flash_attn_func expects (B, N, H, D) and fp16/bf16 inputs
            q = qkv[:, :, 0]
            k = qkv[:, :, 1]
            v = qkv[:, :, 2]
            
            out = self.flash_attn(q, k, v, dropout_p=self.dropout if self.training else 0.0)
            out = out.reshape(B, N, C)
        else:
            # Standard attention
            qkv = qkv.permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]
            
            scale = self.head_dim ** -0.5
            attn = (q @ k.transpose(-2, -1)) * scale
            attn = attn.softmax(dim=-1)
            
            out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        
        return self.out_proj(out)

Efficient Batch Processing

class GradientAccumulation:
    """
    Accumulate gradients over multiple mini-batches.
    
    Simulates larger batch sizes without memory increase.
    """
    
    def __init__(
        self,
        model: nn.Module,
        optimizer: optim.Optimizer,
        accumulation_steps: int = 4
    ):
        self.model = model
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps
        self.current_step = 0
    
    def backward(self, loss: torch.Tensor):
        """
        Backward with gradient accumulation.
        """
        # Scale loss by accumulation steps
        scaled_loss = loss / self.accumulation_steps
        scaled_loss.backward()
        
        self.current_step += 1
        
        if self.current_step >= self.accumulation_steps:
            # Perform optimizer step
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.current_step = 0
            return True  # Step was performed
        
        return False  # Step not performed yet
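
A minimal loop sketch (the toy model and random data are placeholders); backward() returns True on the micro-batches where an optimizer step actually happens:

# Sketch: accumulate gradients over 4 micro-batches per optimizer step
acc_model = nn.Linear(64, 1).to(device)
acc_optimizer = optim.SGD(acc_model.parameters(), lr=0.01)
accumulator = GradientAccumulation(acc_model, acc_optimizer, accumulation_steps=4)

for micro_step in range(8):
    micro_loss = acc_model(torch.randn(16, 64, device=device)).pow(2).mean()
    if accumulator.backward(micro_loss):
        print(f"optimizer step after micro-batch {micro_step + 1}")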


class MicroBatching:
    """
    Split batches into micro-batches for forward pass.
    
    Useful when a single batch does not fit in memory. Note that
    normalization layers such as BatchNorm still see per-micro-batch
    statistics, so results can differ slightly from a full-batch forward.
    """
    
    @staticmethod
    def micro_batch_forward(
        model: nn.Module,
        batch: torch.Tensor,
        micro_batch_size: int
    ) -> torch.Tensor:
        """
        Forward in micro-batches.
        """
        outputs = []
        
        for i in range(0, batch.size(0), micro_batch_size):
            micro_batch = batch[i:i + micro_batch_size]
            
            with torch.no_grad() if not model.training else torch.enable_grad():
                output = model(micro_batch)
            
            outputs.append(output)
        
        return torch.cat(outputs, dim=0)


class InPlaceOperations:
    """
    In-place operations to reduce memory allocations.
    
    Use with caution - can break gradient computation!
    """
    
    @staticmethod
    def inplace_relu(x: torch.Tensor) -> torch.Tensor:
        """In-place ReLU."""
        return torch.relu_(x)
    
    @staticmethod
    def inplace_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """In-place addition."""
        return x.add_(y)
    
    @staticmethod
    def check_safe_inplace(tensor: torch.Tensor) -> bool:
        """
        Check if in-place operation is safe.
        
        Not safe if:
        - Tensor requires grad and is a leaf
        - Tensor is used elsewhere in computation graph
        """
        if tensor.requires_grad and tensor.is_leaf:
            return False
        return True

Memory-Efficient Data Loading

class MemoryEfficientDataLoader:
    """
    Reduce data loading memory overhead.
    """
    
    @staticmethod
    def create_efficient_loader(
        dataset,
        batch_size: int,
        num_workers: int = 4,
        pin_memory: bool = True
    ):
        """
        Create memory-efficient data loader.
        
        Tips:
        - pin_memory for faster GPU transfer
        - Appropriate num_workers (usually 4 per GPU)
        - prefetch_factor controls memory usage
        """
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=pin_memory,
            # Limit prefetching to reduce memory
            prefetch_factor=2 if num_workers > 0 else None,
            # Keep worker processes alive between epochs
            persistent_workers=True if num_workers > 0 else False,
        )
    
    @staticmethod
    def use_memory_mapping(file_path: str):
        """
        Use memory-mapped files for large datasets.
        
        Data stays on disk, loaded on demand.
        """
        return np.memmap(file_path, dtype='float32', mode='r')


class LowMemoryDataset(torch.utils.data.Dataset):
    """
    Dataset that loads data on-demand.
    """
    
    def __init__(self, file_paths: List[str]):
        self.file_paths = file_paths
    
    def __len__(self):
        return len(self.file_paths)
    
    def __getitem__(self, idx):
        # Load only when needed
        data = torch.load(self.file_paths[idx])
        return data

Memory Debugging Tools

class MemoryDebugger:
    """
    Tools for debugging memory issues.
    """
    
    @staticmethod
    def find_memory_leaks():
        """Find potential memory leaks."""
        if not torch.cuda.is_available():
            return []
        
        gc.collect()
        torch.cuda.empty_cache()
        
        leaks = []
        
        for obj in gc.get_objects():
            if isinstance(obj, torch.Tensor):
                if obj.is_cuda:
                    leaks.append({
                        'shape': obj.shape,
                        'dtype': obj.dtype,
                        'size_mb': obj.numel() * obj.element_size() / 1e6,
                        'requires_grad': obj.requires_grad
                    })
        
        return sorted(leaks, key=lambda x: -x['size_mb'])
    
    @staticmethod
    def memory_summary():
        """Print detailed memory summary."""
        if not torch.cuda.is_available():
            print("CUDA not available")
            return
        
        print(torch.cuda.memory_summary())
    
    @staticmethod
    def clear_memory():
        """Aggressively clear GPU memory."""
        gc.collect()
        
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
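
For example, the leak finder can be used to print the largest live CUDA tensors (it returns an empty list on CPU-only machines):

# Sketch: show the five largest live CUDA tensors
for info in MemoryDebugger.find_memory_leaks()[:5]:
    print(f"{info['size_mb']:8.2f} MB  shape={tuple(info['shape'])}  dtype={info['dtype']}")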


# Usage demonstration
def memory_optimization_demo():
    """Demonstrate memory optimization techniques."""
    
    summary = """
    ╔════════════════════════════════════════════════════════════════╗
    ║           MEMORY OPTIMIZATION TECHNIQUES                        ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                 ║
    ║  TECHNIQUE              MEMORY SAVING    COMPUTE COST           ║
    ║  ─────────────────────────────────────────────────────────────  ║
    ║  Mixed Precision        ~50%             ~0% (often faster!)    ║
    ║  Gradient Checkpointing ~60-70%          ~30-40%                ║
    ║  Gradient Accumulation  Linear in steps  ~0%                    ║
    ║  Flash Attention        ~50% for attn    ~0% (often faster!)    ║
    ║  CPU Offloading         ~50-70%          ~20-50%                ║
    ║                                                                 ║
    ╠════════════════════════════════════════════════════════════════╣
    ║  RECOMMENDATION ORDER:                                          ║
    ║  1. Enable mixed precision (always)                             ║
    ║  2. Use Flash Attention (if available)                          ║
    ║  3. Add gradient checkpointing                                  ║
    ║  4. Gradient accumulation for larger effective batch            ║
    ║  5. CPU offloading as last resort                               ║
    ╚════════════════════════════════════════════════════════════════╝
    """
    print(summary)

memory_optimization_demo()

Exercises

1. Use the MemoryProfiler to analyze a model’s memory usage at each layer. Identify the biggest memory consumers and optimize them.
2. Create a system that automatically decides which layers to checkpoint based on their memory usage vs. compute cost.
3. Implement training that stays within a fixed memory budget by dynamically adjusting batch size and checkpointing.

What’s Next?