Memory-Efficient Training
Understanding GPU Memory
Where does memory go during training?

| Component | Memory Usage |
|---|---|
| Model parameters | 4 bytes × #params (fp32) |
| Gradients | 4 bytes × #params (fp32) |
| Optimizer states | 8 bytes × #params (Adam: m and v in fp32) |
| Activations | Depends on batch size, sequence length, and model depth |
| Temporary buffers | Variable |
For example, a 7B-parameter model trained in fp32 with Adam:
- Parameters: 28 GB
- Gradients: 28 GB
- Adam states: 56 GB
- Total: ~112 GB (without activations!)
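The same accounting takes only a few lines of Python; the sketch below simply multiplies a parameter count by the per-component byte costs from the table above (activations excluded):

def estimate_training_memory_gb(num_params: int) -> dict:
    """Rough fp32 + Adam memory accounting; activations are not included."""
    bytes_per_param = {
        "parameters": 4,   # fp32 weights
        "gradients": 4,    # fp32 gradients
        "adam_states": 8,  # first and second moments, fp32 each
    }
    breakdown = {name: num_params * b / 1e9 for name, b in bytes_per_param.items()}
    breakdown["total"] = sum(breakdown.values())
    return breakdown

print(estimate_training_memory_gb(7_000_000_000))
# {'parameters': 28.0, 'gradients': 28.0, 'adam_states': 56.0, 'total': 112.0}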
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from typing import Optional, Tuple, List, Callable
import gc
import numpy as np  # used later for memory-mapped datasets
import functools
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Memory Profiling
class MemoryProfiler:
"""
Profile GPU memory usage during training.
"""
def __init__(self, device: torch.device = device):
self.device = device
self.snapshots = []
def get_memory_stats(self) -> dict:
"""Get current memory statistics."""
        if not torch.cuda.is_available():
            return {'allocated_mb': 0.0, 'cached_mb': 0.0, 'max_allocated_mb': 0.0}
return {
'allocated_mb': torch.cuda.memory_allocated(self.device) / 1e6,
'cached_mb': torch.cuda.memory_reserved(self.device) / 1e6,
'max_allocated_mb': torch.cuda.max_memory_allocated(self.device) / 1e6,
}
def snapshot(self, label: str = ""):
"""Take a memory snapshot."""
stats = self.get_memory_stats()
stats['label'] = label
self.snapshots.append(stats)
return stats
def reset_peak_stats(self):
"""Reset peak memory tracking."""
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
def print_summary(self):
"""Print memory usage summary."""
print("\n" + "="*60)
print("MEMORY USAGE SUMMARY")
print("="*60)
for snap in self.snapshots:
print(f"{snap['label']:30} | "
f"Allocated: {snap['allocated_mb']:8.1f} MB | "
f"Peak: {snap['max_allocated_mb']:8.1f} MB")
@staticmethod
def estimate_model_memory(model: nn.Module) -> dict:
"""Estimate memory requirements for a model."""
param_mem = sum(p.numel() * p.element_size() for p in model.parameters())
buffer_mem = sum(b.numel() * b.element_size() for b in model.buffers())
# Gradient memory (same as params for most cases)
grad_mem = param_mem
# Optimizer states (Adam: 2x for m and v)
adam_mem = 2 * param_mem
return {
'parameters_mb': param_mem / 1e6,
'buffers_mb': buffer_mem / 1e6,
'gradients_mb': grad_mem / 1e6,
'adam_states_mb': adam_mem / 1e6,
'total_mb': (param_mem + buffer_mem + grad_mem + adam_mem) / 1e6
}
# Memory-efficient model analysis
def analyze_memory_per_layer(model: nn.Module, input_shape: tuple):
"""Analyze memory usage per layer."""
profiler = MemoryProfiler()
if not torch.cuda.is_available():
print("CUDA not available for memory analysis")
return
model = model.to(device)
x = torch.randn(*input_shape).to(device)
profiler.reset_peak_stats()
profiler.snapshot("After model loading")
# Forward with hooks
activations = {}
def save_activation(name):
def hook(module, input, output):
if isinstance(output, torch.Tensor):
activations[name] = output.numel() * output.element_size() / 1e6
return hook
    handles = [
        module.register_forward_hook(save_activation(name))
        for name, module in model.named_modules()
    ]
    # Forward pass
    with torch.no_grad():
        _ = model(x)
    # Remove hooks so repeated analyses do not leave them attached to the model
    for handle in handles:
        handle.remove()
profiler.snapshot("After forward pass")
# Print activation memory
print("\nActivation memory per layer:")
for name, mem in sorted(activations.items(), key=lambda x: -x[1])[:10]:
print(f" {name}: {mem:.2f} MB")
Gradient Checkpointing
class GradientCheckpointing:
"""
Trade compute for memory by recomputing activations.
Instead of storing all activations for backward pass,
we store only at checkpoint boundaries and recompute
intermediate activations during backward.
Memory savings: O(sqrt(n)) vs O(n) for n layers
Compute overhead: ~30-40% more forward passes
"""
@staticmethod
def checkpoint_sequential(
functions: List[nn.Module],
segments: int,
input: torch.Tensor
) -> torch.Tensor:
"""
Apply checkpointing to sequential modules.
Args:
functions: List of modules to apply sequentially
segments: Number of checkpoint segments
input: Input tensor
"""
        return checkpoint_sequential(functions, segments, input, use_reentrant=False)
@staticmethod
def checkpoint_function(
function: Callable,
*args,
use_reentrant: bool = False
) -> torch.Tensor:
"""
Checkpoint a single function.
Args:
function: Function to checkpoint
*args: Function arguments
use_reentrant: Use reentrant version (False recommended)
"""
return checkpoint(function, *args, use_reentrant=use_reentrant)
# Transformer with gradient checkpointing
class CheckpointedTransformerBlock(nn.Module):
"""
Transformer block with optional gradient checkpointing.
"""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
use_checkpoint: bool = False
):
super().__init__()
self.use_checkpoint = use_checkpoint
# Self-attention
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
self.norm1 = nn.LayerNorm(d_model)
# Feedforward
self.ff = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
nn.Dropout(dropout)
)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def _attention_block(self, x: torch.Tensor) -> torch.Tensor:
"""Self-attention sub-block."""
attn_out, _ = self.self_attn(x, x, x)
return self.dropout(attn_out)
def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
"""Feedforward sub-block."""
return self.ff(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_checkpoint and self.training:
# Checkpoint attention
attn_out = checkpoint(
self._attention_block,
x,
use_reentrant=False
)
x = self.norm1(x + attn_out)
# Checkpoint feedforward
ff_out = checkpoint(
self._ff_block,
x,
use_reentrant=False
)
x = self.norm2(x + ff_out)
else:
# Standard forward
x = self.norm1(x + self._attention_block(x))
x = self.norm2(x + self._ff_block(x))
return x
class CheckpointedTransformer(nn.Module):
"""
Full transformer with gradient checkpointing.
"""
def __init__(
self,
d_model: int = 512,
nhead: int = 8,
num_layers: int = 6,
vocab_size: int = 10000,
use_checkpoint: bool = True
):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.layers = nn.ModuleList([
CheckpointedTransformerBlock(
d_model, nhead,
use_checkpoint=use_checkpoint
)
for _ in range(num_layers)
])
self.output = nn.Linear(d_model, vocab_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.embedding(x)
for layer in self.layers:
x = layer(x)
return self.output(x)
# Compare memory with/without checkpointing
def compare_checkpointing_memory():
"""Compare memory usage with and without checkpointing."""
if not torch.cuda.is_available():
print("CUDA not available")
return
batch_size, seq_len = 32, 512
d_model, num_layers = 512, 12
# Without checkpointing
model_no_ckpt = CheckpointedTransformer(
d_model=d_model, num_layers=num_layers,
use_checkpoint=False
).to(device)
torch.cuda.reset_peak_memory_stats()
x = torch.randint(0, 10000, (batch_size, seq_len)).to(device)
out = model_no_ckpt(x)
loss = out.sum()
loss.backward()
mem_no_ckpt = torch.cuda.max_memory_allocated() / 1e9
del model_no_ckpt, out, loss
torch.cuda.empty_cache()
# With checkpointing
model_ckpt = CheckpointedTransformer(
d_model=d_model, num_layers=num_layers,
use_checkpoint=True
).to(device)
torch.cuda.reset_peak_memory_stats()
out = model_ckpt(x)
loss = out.sum()
loss.backward()
mem_ckpt = torch.cuda.max_memory_allocated() / 1e9
print(f"Without checkpointing: {mem_no_ckpt:.2f} GB")
print(f"With checkpointing: {mem_ckpt:.2f} GB")
print(f"Memory saved: {(1 - mem_ckpt/mem_no_ckpt)*100:.1f}%")
Mixed Precision Training
class MixedPrecisionTraining:
"""
Train with mixed FP16/FP32 precision.
Benefits:
- 2x memory reduction for activations
- Faster matrix multiplications on modern GPUs
- Minimal accuracy loss with proper scaling
Key components:
- Automatic casting (autocast)
- Gradient scaling (prevent underflow)
"""
def __init__(self, enabled: bool = True):
self.enabled = enabled
self.scaler = torch.amp.GradScaler('cuda') if enabled else None
def training_step(
self,
model: nn.Module,
batch: tuple,
optimizer: optim.Optimizer,
criterion: nn.Module
) -> float:
"""
Perform one training step with mixed precision.
"""
inputs, targets = batch
optimizer.zero_grad()
if self.enabled:
# Mixed precision forward
with torch.amp.autocast('cuda'):
outputs = model(inputs)
loss = criterion(outputs, targets)
# Scaled backward
self.scaler.scale(loss).backward()
# Unscale gradients for clipping
self.scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# Optimizer step with scaling
self.scaler.step(optimizer)
self.scaler.update()
else:
# Standard precision
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
return loss.item()
# BFloat16 for better stability
class BFloat16Training:
"""
BFloat16 training (available on Ampere+ GPUs).
Advantages over FP16:
- Same exponent range as FP32
- No gradient scaling needed
- Better numerical stability
"""
@staticmethod
def is_available() -> bool:
"""Check if BF16 is available."""
if not torch.cuda.is_available():
return False
return torch.cuda.get_device_capability()[0] >= 8
@staticmethod
def training_step(
model: nn.Module,
batch: tuple,
optimizer: optim.Optimizer,
criterion: nn.Module
) -> float:
"""Training step with BF16."""
inputs, targets = batch
optimizer.zero_grad()
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
outputs = model(inputs)
loss = criterion(outputs, targets)
# No scaling needed for BF16!
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
return loss.item()
# Precision comparison
def demonstrate_precision_formats():
"""Show memory usage for different precisions."""
num_params = 1_000_000 # 1M parameters
formats = {
'float32': (4, "Standard precision"),
'float16': (2, "Half precision"),
'bfloat16': (2, "Brain float"),
'int8': (1, "Quantized"),
}
print("Memory per million parameters:")
print("-" * 45)
for name, (bytes_per, desc) in formats.items():
mem = num_params * bytes_per / 1e6
print(f"{name:10} ({desc:20}): {mem:.1f} MB")
demonstrate_precision_formats()
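Putting the two helpers together, a training step can prefer BF16 when the hardware supports it and fall back to FP16 with gradient scaling (or plain FP32 on CPU) otherwise; the model and data below are placeholders:

model = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 10)).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

inputs = torch.randn(32, 512, device=device)
targets = torch.randint(0, 10, (32,), device=device)

if BFloat16Training.is_available():
    loss_value = BFloat16Training.training_step(model, (inputs, targets), optimizer, criterion)
else:
    trainer = MixedPrecisionTraining(enabled=torch.cuda.is_available())
    loss_value = trainer.training_step(model, (inputs, targets), optimizer, criterion)
print(f"loss: {loss_value:.4f}")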
CPU Offloading
class CPUOffloadOptimizer:
"""
Offload optimizer states to CPU to save GPU memory.
Trade-off: Slower training but can train larger models.
"""
def __init__(
self,
model: nn.Module,
lr: float = 1e-4,
offload: bool = True
):
self.model = model
self.offload = offload
if offload:
# Keep parameters on GPU but states on CPU
self.param_groups = [
{
'params': list(model.parameters()),
'lr': lr
}
]
# Initialize optimizer states on CPU
self.state = {}
for p in model.parameters():
if p.requires_grad:
# Adam states
self.state[p] = {
'm': torch.zeros_like(p.data, device='cpu'),
'v': torch.zeros_like(p.data, device='cpu'),
't': 0
}
else:
            # Fallback: plain Adam on the GPU, matching the hand-rolled update in step() (no weight decay)
            self.optimizer = optim.Adam(model.parameters(), lr=lr)
def zero_grad(self):
"""Zero gradients."""
if self.offload:
for p in self.model.parameters():
if p.grad is not None:
p.grad.zero_()
else:
self.optimizer.zero_grad()
def step(self):
"""Perform optimizer step."""
if not self.offload:
self.optimizer.step()
return
lr = self.param_groups[0]['lr']
beta1, beta2 = 0.9, 0.999
eps = 1e-8
for p in self.model.parameters():
if not p.requires_grad or p.grad is None:
continue
state = self.state[p]
state['t'] += 1
t = state['t']
# Move gradient to CPU
grad_cpu = p.grad.data.cpu()
# Update moments on CPU
state['m'].mul_(beta1).add_(grad_cpu, alpha=1 - beta1)
state['v'].mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2)
# Bias correction
m_hat = state['m'] / (1 - beta1 ** t)
v_hat = state['v'] / (1 - beta2 ** t)
# Compute update on CPU
update = m_hat / (v_hat.sqrt() + eps)
# Apply update to GPU parameter
p.data.add_(update.to(p.device), alpha=-lr)
class GradientOffloading:
"""
Offload gradients to CPU during backward pass.
Useful for very large models where even gradients don't fit.
"""
def __init__(self, model: nn.Module):
self.model = model
self.cpu_gradients = {}
self._register_hooks()
def _register_hooks(self):
"""Register hooks to offload gradients."""
for name, param in self.model.named_parameters():
if param.requires_grad:
param.register_post_accumulate_grad_hook(
self._make_hook(name)
)
def _make_hook(self, name: str):
def hook(param):
if param.grad is not None:
# Move gradient to CPU
self.cpu_gradients[name] = param.grad.cpu()
# Free GPU memory
param.grad = None
return hook
def restore_gradients(self):
"""Move gradients back to GPU for optimizer."""
for name, param in self.model.named_parameters():
if name in self.cpu_gradients:
param.grad = self.cpu_gradients[name].to(param.device)
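A usage sketch for the offloaded optimizer (the model and data are placeholders): weights stay on the GPU while the Adam moments live in host memory, so expect each step to be slower than a regular GPU optimizer.

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).to(device)
offload_opt = CPUOffloadOptimizer(model, lr=1e-4, offload=True)

for _ in range(3):  # a few dummy steps
    x = torch.randn(16, 1024, device=device)
    loss = model(x).pow(2).mean()
    offload_opt.zero_grad()
    loss.backward()
    offload_opt.step()  # moments updated on CPU, parameters updated on GPU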
Activation Recomputation Strategies
class SelectiveCheckpointing:
"""
Smart checkpointing strategies.
Not all layers benefit equally from checkpointing.
Focus on high-memory layers.
"""
@staticmethod
def should_checkpoint(module: nn.Module) -> bool:
"""
Decide if a module should be checkpointed.
Rules:
- Checkpoint attention (quadratic memory)
- Checkpoint large FFN layers
- Don't checkpoint small layers (overhead not worth it)
"""
if isinstance(module, nn.MultiheadAttention):
return True
# Check for large linear layers
if isinstance(module, nn.Linear):
return module.in_features * module.out_features > 1e6
return False
@staticmethod
def estimate_activation_memory(
module: nn.Module,
input_shape: tuple
) -> float:
"""
Estimate activation memory for a module.
Returns memory in MB.
"""
# This is a simplified estimate
batch_size = input_shape[0]
if isinstance(module, nn.Linear):
# Output activations
return batch_size * module.out_features * 4 / 1e6
if isinstance(module, nn.MultiheadAttention):
# Attention weights (quadratic in sequence length)
seq_len = input_shape[1]
num_heads = module.num_heads
return batch_size * num_heads * seq_len * seq_len * 4 / 1e6
return 0.0
class LayerWiseCheckpointing(nn.Module):
"""
Apply checkpointing at layer granularity.
"""
def __init__(
self,
layers: nn.ModuleList,
checkpoint_ratio: float = 0.5
):
super().__init__()
self.layers = layers
# Checkpoint every other layer (or based on ratio)
self.checkpoint_mask = [
i % int(1 / checkpoint_ratio) == 0
for i in range(len(layers))
]
def forward(self, x: torch.Tensor) -> torch.Tensor:
for i, layer in enumerate(self.layers):
if self.checkpoint_mask[i] and self.training:
x = checkpoint(layer, x, use_reentrant=False)
else:
x = layer(x)
return x
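For example, a stack of the earlier transformer blocks can be wrapped so that every other block is recomputed during backward (the sizes here are arbitrary):

blocks = nn.ModuleList([
    CheckpointedTransformerBlock(d_model=256, nhead=8, use_checkpoint=False)
    for _ in range(8)
])
stack = LayerWiseCheckpointing(blocks, checkpoint_ratio=0.5).to(device)
stack.train()

x = torch.randn(4, 64, 256, device=device)
out = stack(x)   # blocks 0, 2, 4, 6 run under checkpoint()
out.sum().backward()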
Memory-Efficient Attention
class MemoryEfficientAttention(nn.Module):
"""
Memory-efficient attention implementations.
Standard attention: O(n²) memory for attention matrix
Efficient versions reduce this significantly.
"""
def __init__(
self,
d_model: int,
num_heads: int,
chunk_size: int = 1024
):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.chunk_size = chunk_size
self.qkv = nn.Linear(d_model, 3 * d_model)
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
# Compute Q, K, V
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, N, D)
q, k, v = qkv[0], qkv[1], qkv[2]
# Chunked attention to reduce peak memory
output = self.chunked_attention(q, k, v)
output = output.transpose(1, 2).reshape(B, N, C)
return self.out_proj(output)
def chunked_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor
) -> torch.Tensor:
"""
Compute attention in chunks.
Instead of computing full NxN attention matrix,
process in chunks to reduce memory.
"""
B, H, N, D = q.shape
scale = D ** -0.5
output = torch.zeros_like(v)
for i in range(0, N, self.chunk_size):
end_i = min(i + self.chunk_size, N)
q_chunk = q[:, :, i:end_i]
# Compute attention for this query chunk
attn = torch.matmul(q_chunk, k.transpose(-2, -1)) * scale
attn = attn.softmax(dim=-1)
output[:, :, i:end_i] = torch.matmul(attn, v)
return output
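# Usage sketch (illustrative sizes): the chunked module is a drop-in replacement
# for a standard self-attention layer; chunk_size bounds the width of each score block.
chunked_attn = MemoryEfficientAttention(d_model=256, num_heads=8, chunk_size=256).to(device)
tokens = torch.randn(2, 2048, 256, device=device)
print(chunked_attn(tokens).shape)  # torch.Size([2, 2048, 256])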
class FlashAttention(nn.Module):
"""
Flash Attention wrapper (requires flash-attn package).
Key innovations:
- Tiling to fit in SRAM
- No materialization of attention matrix
- IO-aware algorithm design
Memory: O(n) instead of O(n²)
Speed: 2-4x faster than standard attention
"""
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.dropout = dropout
self.qkv = nn.Linear(d_model, 3 * d_model)
self.out_proj = nn.Linear(d_model, d_model)
self._check_flash_available()
def _check_flash_available(self):
"""Check if Flash Attention is available."""
try:
from flash_attn import flash_attn_func
self.flash_attn = flash_attn_func
self.use_flash = True
except ImportError:
print("Flash Attention not available. Using standard attention.")
self.use_flash = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        if self.use_flash:
            # flash_attn_func expects (B, N, H, D) tensors (and fp16/bf16 inputs),
            # so no further transpose is needed after the reshape above
            q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
            out = self.flash_attn(q, k, v, dropout_p=self.dropout if self.training else 0.0)
            out = out.reshape(B, N, C)
else:
# Standard attention
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
scale = self.head_dim ** -0.5
attn = (q @ k.transpose(-2, -1)) * scale
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, N, C)
return self.out_proj(out)
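If the flash-attn package is not installed, recent PyTorch releases ship a built-in memory-efficient kernel behind torch.nn.functional.scaled_dot_product_attention, which avoids materializing the full attention matrix on supported backends (a sketch with arbitrary shapes):

import torch.nn.functional as F

B, H, N, D = 2, 8, 1024, 64
q = torch.randn(B, H, N, D, device=device)
k = torch.randn(B, H, N, D, device=device)
v = torch.randn(B, H, N, D, device=device)

out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 1024, 64])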
Efficient Batch Processing
class GradientAccumulation:
"""
Accumulate gradients over multiple mini-batches.
Simulates larger batch sizes without memory increase.
"""
def __init__(
self,
model: nn.Module,
optimizer: optim.Optimizer,
accumulation_steps: int = 4
):
self.model = model
self.optimizer = optimizer
self.accumulation_steps = accumulation_steps
self.current_step = 0
def backward(self, loss: torch.Tensor):
"""
Backward with gradient accumulation.
"""
# Scale loss by accumulation steps
scaled_loss = loss / self.accumulation_steps
scaled_loss.backward()
self.current_step += 1
if self.current_step >= self.accumulation_steps:
# Perform optimizer step
self.optimizer.step()
self.optimizer.zero_grad()
self.current_step = 0
return True # Step was performed
return False # Step not performed yet
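# Usage sketch for GradientAccumulation (model, data, and sizes are placeholders):
# 8 micro-batches of 16 samples behave like 2 optimizer steps at batch size 64.
acc_model = nn.Linear(128, 10).to(device)
acc_optimizer = optim.AdamW(acc_model.parameters(), lr=1e-3)
accumulator = GradientAccumulation(acc_model, acc_optimizer, accumulation_steps=4)
acc_criterion = nn.CrossEntropyLoss()
for _ in range(8):
    xb = torch.randn(16, 128, device=device)
    yb = torch.randint(0, 10, (16,), device=device)
    stepped = accumulator.backward(acc_criterion(acc_model(xb), yb))  # True every 4th call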
class MicroBatching:
"""
Split batches into micro-batches for forward pass.
    Useful when a single large batch does not fit in memory for the forward pass.
    Note that normalization layers such as BatchNorm compute their statistics
    per micro-batch, not over the full logical batch.
"""
@staticmethod
def micro_batch_forward(
model: nn.Module,
batch: torch.Tensor,
micro_batch_size: int
) -> torch.Tensor:
"""
Forward in micro-batches.
"""
outputs = []
for i in range(0, batch.size(0), micro_batch_size):
micro_batch = batch[i:i + micro_batch_size]
with torch.no_grad() if not model.training else torch.enable_grad():
output = model(micro_batch)
outputs.append(output)
return torch.cat(outputs, dim=0)
class InPlaceOperations:
"""
In-place operations to reduce memory allocations.
Use with caution - can break gradient computation!
"""
@staticmethod
def inplace_relu(x: torch.Tensor) -> torch.Tensor:
"""In-place ReLU."""
return torch.relu_(x)
@staticmethod
def inplace_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""In-place addition."""
return x.add_(y)
@staticmethod
def check_safe_inplace(tensor: torch.Tensor) -> bool:
"""
Check if in-place operation is safe.
Not safe if:
- Tensor requires grad and is a leaf
- Tensor is used elsewhere in computation graph
"""
if tensor.requires_grad and tensor.is_leaf:
return False
return True
Memory-Efficient Data Loading
class MemoryEfficientDataLoader:
"""
Reduce data loading memory overhead.
"""
@staticmethod
def create_efficient_loader(
dataset,
batch_size: int,
num_workers: int = 4,
pin_memory: bool = True
):
"""
Create memory-efficient data loader.
Tips:
- pin_memory for faster GPU transfer
- Appropriate num_workers (usually 4 per GPU)
- prefetch_factor controls memory usage
"""
return torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=pin_memory,
# Limit prefetching to reduce memory
prefetch_factor=2 if num_workers > 0 else None,
            # Keep worker processes alive between epochs to avoid re-spawn overhead
            persistent_workers=num_workers > 0,
)
@staticmethod
def use_memory_mapping(file_path: str):
"""
Use memory-mapped files for large datasets.
Data stays on disk, loaded on demand.
"""
return np.memmap(file_path, dtype='float32', mode='r')
class LowMemoryDataset(torch.utils.data.Dataset):
"""
Dataset that loads data on-demand.
"""
def __init__(self, file_paths: List[str]):
self.file_paths = file_paths
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
# Load only when needed
data = torch.load(self.file_paths[idx])
return data
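Combining the two pieces above (the file list is hypothetical; each .pt file is assumed to hold one pre-processed sample):

file_paths = [f"data/sample_{i}.pt" for i in range(10_000)]  # hypothetical paths
dataset = LowMemoryDataset(file_paths)

loader = MemoryEfficientDataLoader.create_efficient_loader(
    dataset,
    batch_size=32,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)
# Samples are only read from disk when their batch is assembled:
# for batch in loader:
#     ...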
Memory Debugging Tools
class MemoryDebugger:
"""
Tools for debugging memory issues.
"""
@staticmethod
def find_memory_leaks():
"""Find potential memory leaks."""
if not torch.cuda.is_available():
return []
gc.collect()
torch.cuda.empty_cache()
leaks = []
for obj in gc.get_objects():
if isinstance(obj, torch.Tensor):
if obj.is_cuda:
leaks.append({
'shape': obj.shape,
'dtype': obj.dtype,
'size_mb': obj.numel() * obj.element_size() / 1e6,
'requires_grad': obj.requires_grad
})
return sorted(leaks, key=lambda x: -x['size_mb'])
@staticmethod
def memory_summary():
"""Print detailed memory summary."""
if not torch.cuda.is_available():
print("CUDA not available")
return
print(torch.cuda.memory_summary())
@staticmethod
def clear_memory():
"""Aggressively clear GPU memory."""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
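# Quick check when a job runs out of memory unexpectedly: list the largest
# CUDA tensors still alive, then release cached blocks back to the driver.
for tensor_info in MemoryDebugger.find_memory_leaks()[:5]:
    print(tensor_info)
MemoryDebugger.clear_memory()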
# Usage demonstration
def memory_optimization_demo():
"""Demonstrate memory optimization techniques."""
summary = """
╔════════════════════════════════════════════════════════════════╗
║ MEMORY OPTIMIZATION TECHNIQUES ║
╠════════════════════════════════════════════════════════════════╣
║ ║
║ TECHNIQUE MEMORY SAVING COMPUTE COST ║
║ ───────────────────────────────────────────────────────────── ║
║ Mixed Precision ~50% ~0% (often faster!) ║
║ Gradient Checkpointing ~60-70% ~30-40% ║
║ Gradient Accumulation Linear in steps ~0% ║
║ Flash Attention ~50% for attn ~0% (often faster!) ║
║ CPU Offloading ~50-70% ~20-50% ║
║ ║
╠════════════════════════════════════════════════════════════════╣
║ RECOMMENDATION ORDER: ║
║ 1. Enable mixed precision (always) ║
║ 2. Use Flash Attention (if available) ║
║ 3. Add gradient checkpointing ║
║ 4. Gradient accumulation for larger effective batch ║
║ 5. CPU offloading as last resort ║
╚════════════════════════════════════════════════════════════════╝
"""
print(summary)
memory_optimization_demo()
Exercises
Exercise 1: Profile Your Model
Use the MemoryProfiler to analyze a model’s memory usage at each layer.
Identify the biggest memory consumers and optimize them.
Exercise 2: Implement Selective Checkpointing
Create a system that automatically decides which layers to checkpoint
based on their memory usage vs. compute cost.
Exercise 3: Memory Budget Training
Implement training that stays within a fixed memory budget by
dynamically adjusting batch size and checkpointing.