Memory-Efficient Training
Understanding GPU Memory
GPU memory during training is like a hotel with a fixed number of rooms: your model parameters check in, then their gradients need rooms, then the optimizer’s bookkeeping (momentum, variance) needs even more rooms, and finally every intermediate computation (activation) holds a room open until the backward pass checks it out. When the hotel is full, training crashes with the dreaded CUDA out of memory error. Understanding who occupies which rooms, and for how long, is the key to training larger models on the hardware you actually have.
Where does memory go during training?
| Component | Memory usage (bytes) | Analogy |
|---|---|---|
| Model parameters | 4×params (fp32) | The permanent residents |
| Gradients | 4×params (fp32) | Their shadows (same size) |
| Optimizer states | 8×params (Adam) | Adam keeps two running averages per parameter |
| Activations | Depends on batch size, model depth | The guests — they check in during forward, check out during backward |
| Temporary buffers | Variable | Room service supplies |
Example: a 7B-parameter model trained in fp32 with Adam:
- Parameters: 28 GB
- Gradients: 28 GB
- Adam states: 56 GB
- Total: ~112 GB (without activations!)
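The figures in the list above are consistent with a model of roughly 7 billion parameters held in fp32 (28 GB / 4 bytes per parameter). The sketch below reproduces that arithmetic from the per-parameter costs in the table. The function name and the 7e9 parameter count are illustrative assumptions, and 1 GB is taken as 1e9 bytes.

```python
def training_memory_gb(num_params: float,
                       bytes_per_param: int = 4,
                       optimizer_states_per_param: int = 2) -> dict:
    """Static training memory in GB (1 GB = 1e9 bytes), excluding activations."""
    params = num_params * bytes_per_param / 1e9
    grads = params                                   # one gradient per parameter
    optimizer = optimizer_states_per_param * params  # Adam: momentum + variance
    return {
        "parameters": params,
        "gradients": grads,
        "optimizer": optimizer,
        "total": params + grads + optimizer,
    }

# Assumed model size: 7 billion parameters in fp32 with Adam.
estimate = training_memory_gb(7e9)
for name, gb in estimate.items():
    print(f"{name:>10}: {gb:6.1f} GB")
```

Changing `bytes_per_param` to 2 gives a quick feel for what storing weights and gradients in 16-bit would save, although activations still have to be budgeted separately.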
The single most impactful optimization for most teams is mixed precision training (FP16/BF16). It nearly halves activation memory with minimal accuracy impact, and it is a change of only a few lines in most frameworks. Weight memory shrinks too only if the weights themselves are stored in 16-bit; under torch.amp.autocast they stay in FP32. Start there before reaching for more complex techniques like gradient checkpointing or model parallelism.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from typing import Optional, Tuple, List, Callable
import gc
import functools
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Memory Profiling
class MemoryProfiler:
"""
Profile GPU memory usage during training.
"""
def __init__(self, device: torch.device = device):
self.device = device
self.snapshots = []
def get_memory_stats(self) -> dict:
"""Get current memory statistics."""
if not torch.cuda.is_available():
return {'allocated_mb': 0.0, 'cached_mb': 0.0, 'max_allocated_mb': 0.0}
return {
'allocated_mb': torch.cuda.memory_allocated(self.device) / 1e6,
'cached_mb': torch.cuda.memory_reserved(self.device) / 1e6,
'max_allocated_mb': torch.cuda.max_memory_allocated(self.device) / 1e6,
}
def snapshot(self, label: str = ""):
"""Take a memory snapshot."""
stats = self.get_memory_stats()
stats['label'] = label
self.snapshots.append(stats)
return stats
def reset_peak_stats(self):
"""Reset peak memory tracking."""
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats(self.device)
def print_summary(self):
"""Print memory usage summary."""
print("\n" + "="*60)
print("MEMORY USAGE SUMMARY")
print("="*60)
for snap in self.snapshots:
print(f"{snap['label']:30} | "
f"Allocated: {snap['allocated_mb']:8.1f} MB | "
f"Peak: {snap['max_allocated_mb']:8.1f} MB")
@staticmethod
def estimate_model_memory(model: nn.Module) -> dict:
"""Estimate memory requirements for a model."""
param_mem = sum(p.numel() * p.element_size() for p in model.parameters())
buffer_mem = sum(b.numel() * b.element_size() for b in model.buffers())
# Gradient memory (same as params for most cases)
grad_mem = param_mem
# Optimizer states (Adam: 2x for m and v)
adam_mem = 2 * param_mem
return {
'parameters_mb': param_mem / 1e6,
'buffers_mb': buffer_mem / 1e6,
'gradients_mb': grad_mem / 1e6,
'adam_states_mb': adam_mem / 1e6,
'total_mb': (param_mem + buffer_mem + grad_mem + adam_mem) / 1e6
}
# Memory-efficient model analysis
def analyze_memory_per_layer(model: nn.Module, input_shape: tuple):
"""Analyze memory usage per layer."""
profiler = MemoryProfiler()
if not torch.cuda.is_available():
print("CUDA not available for memory analysis")
return
model = model.to(device)
x = torch.randn(*input_shape).to(device)
profiler.reset_peak_stats()
profiler.snapshot("After model loading")
# Forward with hooks
activations = {}
def save_activation(name):
def hook(module, input, output):
if isinstance(output, torch.Tensor):
activations[name] = output.numel() * output.element_size() / 1e6
return hook
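handles = [
module.register_forward_hook(save_activation(name))
for name, module in model.named_modules()
]
# Forward pass
with torch.no_grad():
_ = model(x)
# Remove the hooks so they do not fire on later forward passes
for handle in handles:
handle.remove()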
profiler.snapshot("After forward pass")
# Print activation memory
print("\nActivation memory per layer:")
for name, mem in sorted(activations.items(), key=lambda x: -x[1])[:10]:
print(f" {name}: {mem:.2f} MB")
Gradient Checkpointing
During the forward pass, PyTorch saves every intermediate activation because it needs them during the backward pass to compute gradients. For a 24-layer transformer, this means storing 24 layers’ worth of intermediate tensors simultaneously. Gradient checkpointing (also called activation checkpointing or rematerialization) offers a simple trade: do not save intermediate activations; instead, recompute them on the fly during the backward pass. You save memory at the cost of running parts of the forward pass twice.
The math is elegant: for n layers, standard training stores O(n) activations. With a checkpoint every √n layers, you store only O(√n) activations and recompute the rest, at the cost of about 33% extra compute. For a 24-layer model, that is storing 5 checkpoints instead of 24 activation sets, a nearly 5x reduction in stored activations.
class GradientCheckpointing:
"""
Trade compute for memory by recomputing activations.
Instead of storing all activations for backward pass,
we store only at checkpoint boundaries and recompute
intermediate activations during backward.
Memory savings: O(sqrt(n)) vs O(n) for n layers
Compute overhead: ~30-40% more forward passes
When to use: When activation memory (not parameter/optimizer memory)
is your bottleneck. This is common for models with many layers
(deep transformers), large spatial dimensions (segmentation), or
long sequences (NLP with 4K+ tokens).
"""
@staticmethod
def checkpoint_sequential(
functions: List[nn.Module],
segments: int,
input: torch.Tensor
) -> torch.Tensor:
"""
Apply checkpointing to sequential modules.
Args:
functions: List of modules to apply sequentially
segments: Number of checkpoint segments
input: Input tensor
"""
return checkpoint_sequential(functions, segments, input)
@staticmethod
def checkpoint_function(
function: Callable,
*args,
use_reentrant: bool = False
) -> torch.Tensor:
"""
Checkpoint a single function.
Args:
function: Function to checkpoint
*args: Function arguments
use_reentrant: Use reentrant version (False recommended)
"""
return checkpoint(function, *args, use_reentrant=use_reentrant)
# Transformer with gradient checkpointing
class CheckpointedTransformerBlock(nn.Module):
"""
Transformer block with optional gradient checkpointing.
"""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
use_checkpoint: bool = False
):
super().__init__()
self.use_checkpoint = use_checkpoint
# Self-attention
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
self.norm1 = nn.LayerNorm(d_model)
# Feedforward
self.ff = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
nn.Dropout(dropout)
)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def _attention_block(self, x: torch.Tensor) -> torch.Tensor:
"""Self-attention sub-block."""
attn_out, _ = self.self_attn(x, x, x)
return self.dropout(attn_out)
def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
"""Feedforward sub-block."""
return self.ff(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_checkpoint and self.training:
# Checkpoint attention
attn_out = checkpoint(
self._attention_block,
x,
use_reentrant=False
)
x = self.norm1(x + attn_out)
# Checkpoint feedforward
ff_out = checkpoint(
self._ff_block,
x,
use_reentrant=False
)
x = self.norm2(x + ff_out)
else:
# Standard forward
x = self.norm1(x + self._attention_block(x))
x = self.norm2(x + self._ff_block(x))
return x
class CheckpointedTransformer(nn.Module):
"""
Full transformer with gradient checkpointing.
"""
def __init__(
self,
d_model: int = 512,
nhead: int = 8,
num_layers: int = 6,
vocab_size: int = 10000,
use_checkpoint: bool = True
):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.layers = nn.ModuleList([
CheckpointedTransformerBlock(
d_model, nhead,
use_checkpoint=use_checkpoint
)
for _ in range(num_layers)
])
self.output = nn.Linear(d_model, vocab_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.embedding(x)
for layer in self.layers:
x = layer(x)
return self.output(x)
# Compare memory with/without checkpointing
def compare_checkpointing_memory():
"""Compare memory usage with and without checkpointing."""
if not torch.cuda.is_available():
print("CUDA not available")
return
batch_size, seq_len = 32, 512
d_model, num_layers = 512, 12
# Without checkpointing
model_no_ckpt = CheckpointedTransformer(
d_model=d_model, num_layers=num_layers,
use_checkpoint=False
).to(device)
torch.cuda.reset_peak_memory_stats()
x = torch.randint(0, 10000, (batch_size, seq_len)).to(device)
out = model_no_ckpt(x)
loss = out.sum()
loss.backward()
mem_no_ckpt = torch.cuda.max_memory_allocated() / 1e9
del model_no_ckpt, out, loss
torch.cuda.empty_cache()
# With checkpointing
model_ckpt = CheckpointedTransformer(
d_model=d_model, num_layers=num_layers,
use_checkpoint=True
).to(device)
torch.cuda.reset_peak_memory_stats()
out = model_ckpt(x)
loss = out.sum()
loss.backward()
mem_ckpt = torch.cuda.max_memory_allocated() / 1e9
print(f"Without checkpointing: {mem_no_ckpt:.2f} GB")
print(f"With checkpointing: {mem_ckpt:.2f} GB")
print(f"Memory saved: {(1 - mem_ckpt/mem_no_ckpt)*100:.1f}%")
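The √n trade-off described in this section can be made concrete with a small counting sketch. It assumes a simplified model in which every layer produces one equally sized activation set, and it counts both the stored segment boundaries and the one segment that is recomputed and held during backward. The helper name and this accounting are illustrative assumptions, not measurements from the code above.

```python
import math

def checkpoint_footprint(num_layers: int, segment_len: int) -> dict:
    """Count activation sets under a simplified, uniform-layer model."""
    num_segments = math.ceil(num_layers / segment_len)
    return {
        # One stored activation set per segment boundary.
        "stored_checkpoints": num_segments,
        # During backward, one segment at a time is recomputed and held in full.
        "peak_live_sets": num_segments + segment_len,
        # Without checkpointing, every layer's activations are kept.
        "baseline_sets": num_layers,
    }

num_layers = 24
segment_len = round(math.sqrt(num_layers))  # about sqrt(n) layers per segment
stats = checkpoint_footprint(num_layers, segment_len)
print(segment_len, stats)
```

Under this model the peak count is larger than the number of stored checkpoints alone, because the segment currently being recomputed also occupies memory, so measured savings are usually somewhat below the headline ratio.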
Mixed Precision Training
Mixed precision training is the single highest-impact memory optimization for most practitioners. The idea: use 16-bit floating point (FP16 or BF16) for the computationally intensive operations (matrix multiplies, convolutions) while keeping 32-bit precision where it matters (loss computation, gradient accumulation, optimizer states). Modern GPUs have dedicated hardware (Tensor Cores) that run FP16 matrix multiplies at 2-8x the speed of FP32, so you get both memory and speed benefits.
The tricky part is avoiding numerical issues. FP16 has a very narrow dynamic range: gradients can underflow to zero (too small to represent) or overflow to infinity. The solution is gradient scaling: multiply the loss by a large number before backward, then divide the gradients by the same number before the optimizer step. This shifts the gradient values into FP16’s representable range. BFloat16 (BF16) avoids this entirely because it has the same exponent range as FP32, at the cost of less precision in the mantissa.
Do not use FP16 mixed precision for fine-tuning large language models. BF16 is strongly preferred because FP16’s limited range causes gradient underflow at the very small learning rates (1e-5 range) that are typical for LLM fine-tuning. If your GPU supports BF16 (Ampere or newer), always prefer it.
class MixedPrecisionTraining:
"""
Train with mixed FP16/FP32 precision.
Benefits:
- Up to ~2x memory reduction for activations (under autocast, weights stay in FP32)
- 2-8x faster matrix multiplications on Tensor Core GPUs
- Minimal accuracy loss with proper gradient scaling
Key components:
- Automatic casting (autocast): PyTorch decides which ops run in FP16
- Gradient scaling (GradScaler): prevents FP16 gradient underflow
Typical setup: 3 lines of code change (autocast context, scale loss,
scaler.step instead of optimizer.step).
"""
def __init__(self, enabled: bool = True):
self.enabled = enabled
self.scaler = torch.amp.GradScaler('cuda') if enabled else None
def training_step(
self,
model: nn.Module,
batch: tuple,
optimizer: optim.Optimizer,
criterion: nn.Module
) -> float:
"""
Perform one training step with mixed precision.
"""
inputs, targets = batch
optimizer.zero_grad()
if self.enabled:
# Mixed precision forward
with torch.amp.autocast('cuda'):
outputs = model(inputs)
loss = criterion(outputs, targets)
# Scaled backward
self.scaler.scale(loss).backward()
# Unscale gradients for clipping
self.scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# Optimizer step with scaling
self.scaler.step(optimizer)
self.scaler.update()
else:
# Standard precision
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
return loss.item()
# BFloat16 for better stability
class BFloat16Training:
"""
BFloat16 training (available on Ampere+ GPUs).
Advantages over FP16:
- Same exponent range as FP32
- No gradient scaling needed
- Better numerical stability
"""
@staticmethod
def is_available() -> bool:
"""Check if BF16 is available."""
if not torch.cuda.is_available():
return False
return torch.cuda.get_device_capability()[0] >= 8
@staticmethod
def training_step(
model: nn.Module,
batch: tuple,
optimizer: optim.Optimizer,
criterion: nn.Module
) -> float:
"""Training step with BF16."""
inputs, targets = batch
optimizer.zero_grad()
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
outputs = model(inputs)
loss = criterion(outputs, targets)
# No scaling needed for BF16!
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
return loss.item()
# Precision comparison
def demonstrate_precision_formats():
"""Show memory usage for different precisions."""
num_params = 1_000_000 # 1M parameters
formats = {
'float32': (4, "Standard precision"),
'float16': (2, "Half precision"),
'bfloat16': (2, "Brain float"),
'int8': (1, "Quantized"),
}
print("Memory per million parameters:")
print("-" * 45)
for name, (bytes_per, desc) in formats.items():
mem = num_params * bytes_per / 1e6
print(f"{name:10} ({desc:20}): {mem:.1f} MB")
demonstrate_precision_formats()
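The underflow problem and the loss-scaling fix described in this section can be seen with a single number. The sketch below runs on CPU and only casts scalars between dtypes; the gradient magnitude 1e-8 and the scale factor 2**16 are illustrative values chosen for the demonstration, not settings taken from the code above.

```python
import torch

tiny_grad = 1e-8           # illustrative gradient magnitude, too small for FP16
loss_scale = 2.0 ** 16     # illustrative scale factor

# Cast directly to FP16: the value is below FP16's smallest subnormal and is lost.
fp16_plain = torch.tensor(tiny_grad, dtype=torch.float16)

# Scale first, cast to FP16, then unscale in FP32: the value survives.
fp16_scaled = torch.tensor(tiny_grad * loss_scale, dtype=torch.float16)
recovered = fp16_scaled.float().item() / loss_scale

# BF16 shares FP32's exponent range, so no scaling is needed (fewer mantissa bits).
bf16_plain = torch.tensor(tiny_grad, dtype=torch.bfloat16)

print("FP16 without scaling:", fp16_plain.item())
print("FP16 with scaling   :", recovered)
print("BF16 without scaling:", bf16_plain.item())
print("FP16 max / BF16 max :", torch.finfo(torch.float16).max,
      torch.finfo(torch.bfloat16).max)
```

This is the same idea `GradScaler` automates: it scales the loss, unscales the gradients before the optimizer step, and adjusts the scale factor when overflows are detected.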
CPU Offloading
CPU offloading exploits the fact that your machine typically has 10-50x more CPU RAM than GPU memory. The idea is simple: keep only what the GPU needs right now on the GPU, and store everything else on the CPU (or even NVMe). This is the same principle behind ZeRO-Offload in DeepSpeed and FSDP’s CPU offloading in PyTorch. The trade-off is straightforward: you trade training speed (because of PCIe data transfers) for the ability to train models that otherwise would not fit at all.
CPU offloading is most effective for optimizer states (which are only needed during the parameter update step, not during forward/backward). Offloading activations is less effective because they are accessed frequently during the backward pass, making the PCIe bandwidth a bottleneck. If you are using DeepSpeed, start with ZeRO Stage 2 + CPU offloading of optimizer states before trying full activation offloading.
class CPUOffloadOptimizer:
"""
Offload optimizer states to CPU to save GPU memory.
Trade-off: Slower training (CPU<->GPU transfers over PCIe) but can
train models 2-4x larger than what fits in GPU memory alone.
When to use: When your model fits on the GPU for forward/backward
but OOMs during the optimizer step (because Adam doubles the memory
footprint with its momentum and variance states).
"""
def __init__(
self,
model: nn.Module,
lr: float = 1e-4,
offload: bool = True
):
self.model = model
self.offload = offload
if offload:
# Keep parameters on GPU but states on CPU
self.param_groups = [
{
'params': list(model.parameters()),
'lr': lr
}
]
# Initialize optimizer states on CPU
self.state = {}
for p in model.parameters():
if p.requires_grad:
# Adam states
self.state[p] = {
'm': torch.zeros_like(p.data, device='cpu'),
'v': torch.zeros_like(p.data, device='cpu'),
't': 0
}
else:
# Plain Adam (no weight decay) so both paths apply the same update rule as step()
self.optimizer = optim.Adam(model.parameters(), lr=lr)
def zero_grad(self):
"""Zero gradients."""
if self.offload:
for p in self.model.parameters():
if p.grad is not None:
p.grad.zero_()
else:
self.optimizer.zero_grad()
def step(self):
"""Perform optimizer step."""
if not self.offload:
self.optimizer.step()
return
lr = self.param_groups[0]['lr']
beta1, beta2 = 0.9, 0.999
eps = 1e-8
for p in self.model.parameters():
if not p.requires_grad or p.grad is None:
continue
state = self.state[p]
state['t'] += 1
t = state['t']
# Move gradient to CPU
grad_cpu = p.grad.data.cpu()
# Update moments on CPU
state['m'].mul_(beta1).add_(grad_cpu, alpha=1 - beta1)
state['v'].mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2)
# Bias correction
m_hat = state['m'] / (1 - beta1 ** t)
v_hat = state['v'] / (1 - beta2 ** t)
# Compute update on CPU
update = m_hat / (v_hat.sqrt() + eps)
# Apply update to GPU parameter
p.data.add_(update.to(p.device), alpha=-lr)
class GradientOffloading:
"""
Offload gradients to CPU during backward pass.
Useful for very large models where even gradients don't fit.
"""
def __init__(self, model: nn.Module):
self.model = model
self.cpu_gradients = {}
self._register_hooks()
def _register_hooks(self):
"""Register hooks to offload gradients."""
for name, param in self.model.named_parameters():
if param.requires_grad:
param.register_post_accumulate_grad_hook(
self._make_hook(name)
)
def _make_hook(self, name: str):
def hook(param):
if param.grad is not None:
# Move gradient to CPU
self.cpu_gradients[name] = param.grad.cpu()
# Free GPU memory
param.grad = None
return hook
def restore_gradients(self):
"""Move gradients back to GPU for optimizer."""
for name, param in self.model.named_parameters():
if name in self.cpu_gradients:
param.grad = self.cpu_gradients[name].to(param.device)
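To get a feel for the PCIe cost mentioned in this section, a back-of-the-envelope estimate is often enough. The sketch below divides the bytes moved per optimizer step by an assumed effective bandwidth. Both the 7-billion-parameter model size and the 25 GB/s figure are hypothetical inputs; real throughput depends on the PCIe generation, pinned memory, and overlap with compute, so it is worth measuring on your own hardware.

```python
def transfer_seconds(num_bytes: float, bandwidth_gb_per_s: float) -> float:
    """Time to move num_bytes over a link with the given effective bandwidth."""
    return num_bytes / (bandwidth_gb_per_s * 1e9)

num_params = 7e9            # hypothetical model size
bytes_per_value = 4         # fp32 gradients and fp32 updates
assumed_bandwidth = 25.0    # GB/s, an assumed effective rate, not a measurement

# CPUOffloadOptimizer-style step: gradients go to the CPU, updates come back.
one_way = transfer_seconds(num_params * bytes_per_value, assumed_bandwidth)
round_trip = 2 * one_way

print(f"one way   : {one_way:.2f} s per step")
print(f"round trip: {round_trip:.2f} s per step")
```

If the round-trip time is comparable to the forward and backward time, offloading will dominate the step unless the transfers are overlapped with computation.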
Activation Recomputation Strategies
class SelectiveCheckpointing:
"""
Smart checkpointing strategies.
Not all layers benefit equally from checkpointing.
Focus on high-memory layers.
"""
@staticmethod
def should_checkpoint(module: nn.Module) -> bool:
"""
Decide if a module should be checkpointed.
Rules:
- Checkpoint attention (quadratic memory)
- Checkpoint large FFN layers
- Don't checkpoint small layers (overhead not worth it)
"""
if isinstance(module, nn.MultiheadAttention):
return True
# Check for large linear layers
if isinstance(module, nn.Linear):
return module.in_features * module.out_features > 1e6
return False
@staticmethod
def estimate_activation_memory(
module: nn.Module,
input_shape: tuple
) -> float:
"""
Estimate activation memory for a module.
Returns memory in MB.
"""
# This is a simplified estimate
batch_size = input_shape[0]
if isinstance(module, nn.Linear):
# Output activations
return batch_size * module.out_features * 4 / 1e6
if isinstance(module, nn.MultiheadAttention):
# Attention weights (quadratic in sequence length)
seq_len = input_shape[1]
num_heads = module.num_heads
return batch_size * num_heads * seq_len * seq_len * 4 / 1e6
return 0.0
class LayerWiseCheckpointing(nn.Module):
"""
Apply checkpointing at layer granularity.
"""
def __init__(
self,
layers: nn.ModuleList,
checkpoint_ratio: float = 0.5
):
super().__init__()
self.layers = layers
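# Checkpoint every k-th layer, where k = int(1 / ratio), clamped to at least 1.
# A ratio of 0 (or less) disables checkpointing instead of dividing by zero.
stride = max(1, int(1 / checkpoint_ratio)) if checkpoint_ratio > 0 else 0
self.checkpoint_mask = [
stride > 0 and i % stride == 0
for i in range(len(layers))
]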
def forward(self, x: torch.Tensor) -> torch.Tensor:
for i, layer in enumerate(self.layers):
if self.checkpoint_mask[i] and self.training:
x = checkpoint(layer, x, use_reentrant=False)
else:
x = layer(x)
return x
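As a hand check of the quadratic attention estimate in estimate_activation_memory above, the sketch below evaluates the same product (batch × heads × seq_len² × 4 bytes) for a few sequence lengths. The shapes are hypothetical values chosen for illustration, and the helper name is not part of the code above.

```python
def attention_scores_mb(batch_size: int, num_heads: int, seq_len: int,
                        bytes_per_element: int = 4) -> float:
    """Size of a (batch, heads, seq, seq) score tensor in MB (1 MB = 1e6 bytes)."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element / 1e6

# Hypothetical shapes: one sequence, 32 heads, fp32 scores.
sizes = {n: attention_scores_mb(1, 32, n) for n in (1024, 2048, 4096)}
for n, mb in sizes.items():
    print(f"seq_len={n:5d}: {mb:8.1f} MB")
```

Each doubling of the sequence length multiplies the score tensor by four, which is why attention is usually the first candidate for checkpointing in long-sequence models.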
Memory-Efficient Attention
Attention is often the single largest memory consumer in transformer models. Standard self-attention materializes an N×N attention matrix, where N is the sequence length. For a 4096-token sequence with 32 attention heads in float16, that single matrix consumes 32×4096×4096×2 bytes = 1 GiB per sequence. Double the sequence length and the memory quadruples. The methods below tackle this bottleneck from different angles.
If you are using PyTorch 2.0 or later, use torch.nn.functional.scaled_dot_product_attention with is_causal=True or attn_mask. PyTorch will automatically select the most efficient backend (FlashAttention, memory-efficient attention, or math fallback) based on your hardware, input shapes, and whether you need dropout. This is almost always the right choice; only implement custom attention when you need something the fused kernel does not support.
class MemoryEfficientAttention(nn.Module):
"""
Memory-efficient attention implementations.
Standard attention: O(n^2) memory for the full attention matrix.
Chunked attention: O(n * chunk_size) memory -- process queries in chunks,
computing attention scores against all keys/values for each chunk.
Flash Attention: O(n) memory -- never materializes the full matrix at all,
using a tiling algorithm that stays in GPU SRAM.
"""
def __init__(
self,
d_model: int,
num_heads: int,
chunk_size: int = 1024
):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.chunk_size = chunk_size
self.qkv = nn.Linear(d_model, 3 * d_model)
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
# Compute Q, K, V
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, N, D)
q, k, v = qkv[0], qkv[1], qkv[2]
# Chunked attention to reduce peak memory
output = self.chunked_attention(q, k, v)
output = output.transpose(1, 2).reshape(B, N, C)
return self.out_proj(output)
def chunked_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor
) -> torch.Tensor:
"""
Compute attention in chunks.
Instead of computing full NxN attention matrix,
process in chunks to reduce memory.
"""
B, H, N, D = q.shape
scale = D ** -0.5
output = torch.zeros_like(v)
for i in range(0, N, self.chunk_size):
end_i = min(i + self.chunk_size, N)
q_chunk = q[:, :, i:end_i]
# Compute attention for this query chunk
attn = torch.matmul(q_chunk, k.transpose(-2, -1)) * scale
attn = attn.softmax(dim=-1)
output[:, :, i:end_i] = torch.matmul(attn, v)
return output
class FlashAttention(nn.Module):
"""
Flash Attention wrapper (requires flash-attn package).
Key innovations:
- Tiling to fit in SRAM
- No materialization of attention matrix
- IO-aware algorithm design
Memory: O(n) instead of O(n²)
Speed: 2-4x faster than standard attention
"""
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.dropout = dropout
self.qkv = nn.Linear(d_model, 3 * d_model)
self.out_proj = nn.Linear(d_model, d_model)
self._check_flash_available()
def _check_flash_available(self):
"""Check if Flash Attention is available."""
try:
from flash_attn import flash_attn_func
self.flash_attn = flash_attn_func
self.use_flash = True
except ImportError:
print("Flash Attention not available. Using standard attention.")
self.use_flash = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
if self.use_flash:
# flash_attn_func expects (B, N, H, D) tensors in fp16/bf16 on a CUDA device.
# qkv is (B, N, 3, H, D), so slicing dim 2 already gives that layout.
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
out = self.flash_attn(q, k, v, dropout_p=self.dropout if self.training else 0.0)
# Output is (B, N, H, D); merge the heads back into the model dimension
out = out.reshape(B, N, C)
else:
# Standard attention
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
scale = self.head_dim ** -0.5
attn = (q @ k.transpose(-2, -1)) * scale
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, N, C)
return self.out_proj(out)
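The built-in fused kernel recommended at the start of this section can be checked against a hand-written reference. The sketch below compares torch.nn.functional.scaled_dot_product_attention with is_causal=True to an explicit implementation that materializes the full score matrix. The tensor shapes are small hypothetical values so that it runs on CPU, and it assumes PyTorch 2.0 or later.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, D = 2, 4, 16, 8   # small hypothetical shapes: batch, heads, tokens, head dim
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))

# Reference: build the full N x N score matrix and apply a causal mask.
scores = (q @ k.transpose(-2, -1)) / math.sqrt(D)
causal = torch.ones(N, N, dtype=torch.bool).tril()
scores = scores.masked_fill(~causal, float("-inf"))
reference = scores.softmax(dim=-1) @ v

# Fused path: PyTorch selects a backend for the current device and shapes.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

max_diff = (fused - reference).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

The two paths agree up to floating-point rounding; the difference in practice is that a fused backend can avoid holding the N×N matrix in memory.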
Efficient Batch Processing
class GradientAccumulation:
"""
Accumulate gradients over multiple mini-batches.
Simulates larger batch sizes without memory increase.
"""
def __init__(
self,
model: nn.Module,
optimizer: optim.Optimizer,
accumulation_steps: int = 4
):
self.model = model
self.optimizer = optimizer
self.accumulation_steps = accumulation_steps
self.current_step = 0
def backward(self, loss: torch.Tensor):
"""
Backward with gradient accumulation.
"""
# Scale loss by accumulation steps
scaled_loss = loss / self.accumulation_steps
scaled_loss.backward()
self.current_step += 1
if self.current_step >= self.accumulation_steps:
# Perform optimizer step
self.optimizer.step()
self.optimizer.zero_grad()
self.current_step = 0
return True # Step was performed
return False # Step not performed yet
class MicroBatching:
"""
Split batches into micro-batches for forward pass.
Useful when a single batch doesn't fit in memory. Note that
batch normalization statistics are then computed per micro-batch,
not over the full batch.
"""
@staticmethod
def micro_batch_forward(
model: nn.Module,
batch: torch.Tensor,
micro_batch_size: int
) -> torch.Tensor:
"""
Forward in micro-batches.
"""
outputs = []
for i in range(0, batch.size(0), micro_batch_size):
micro_batch = batch[i:i + micro_batch_size]
with torch.no_grad() if not model.training else torch.enable_grad():
output = model(micro_batch)
outputs.append(output)
return torch.cat(outputs, dim=0)
class InPlaceOperations:
"""
In-place operations to reduce memory allocations.
Use with caution - can break gradient computation!
"""
@staticmethod
def inplace_relu(x: torch.Tensor) -> torch.Tensor:
"""In-place ReLU."""
return torch.relu_(x)
@staticmethod
def inplace_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""In-place addition."""
return x.add_(y)
@staticmethod
def check_safe_inplace(tensor: torch.Tensor) -> bool:
"""
Check if in-place operation is safe.
Not safe if:
- Tensor requires grad and is a leaf
- Tensor is used elsewhere in computation graph
"""
if tensor.requires_grad and tensor.is_leaf:
return False
return True
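The loss scaling inside GradientAccumulation.backward earlier in this section is what makes accumulated gradients match a single large-batch step. The sketch below checks that on CPU with a tiny linear model. The model, data shapes, and mean-reduced MSE loss are illustrative assumptions; the equivalence relies on equal-sized micro-batches and on the absence of batch-dependent layers such as BatchNorm.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 1)
criterion = nn.MSELoss()              # mean reduction over the batch
x, y = torch.randn(16, 8), torch.randn(16, 1)
accumulation_steps = 4

# Gradient from one pass over the full batch of 16.
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Gradient accumulated over 4 micro-batches of 4, each loss divided by 4.
model.zero_grad()
for xb, yb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    (criterion(model(xb), yb) / accumulation_steps).backward()
accum_grad = model.weight.grad.clone()

print("max abs difference:", (full_grad - accum_grad).abs().max().item())
```

Without the division by `accumulation_steps`, the accumulated gradient would be four times larger, which effectively multiplies the learning rate.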
Memory-Efficient Data Loading
class MemoryEfficientDataLoader:
"""
Reduce data loading memory overhead.
"""
@staticmethod
def create_efficient_loader(
dataset,
batch_size: int,
num_workers: int = 4,
pin_memory: bool = True
):
"""
Create memory-efficient data loader.
Tips:
- pin_memory for faster GPU transfer
- Appropriate num_workers (usually 4 per GPU)
- prefetch_factor controls memory usage
"""
return torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=pin_memory,
# Limit prefetching to reduce memory
prefetch_factor=2 if num_workers > 0 else None,
# Keep worker processes alive between epochs instead of re-spawning them
persistent_workers=True if num_workers > 0 else False,
)
@staticmethod
def use_memory_mapping(file_path: str):
"""
Use memory-mapped files for large datasets.
Data stays on disk, loaded on demand.
"""
import numpy as np  # imported here because numpy is not among the setup imports
return np.memmap(file_path, dtype='float32', mode='r')
class LowMemoryDataset(torch.utils.data.Dataset):
"""
Dataset that loads data on-demand.
"""
def __init__(self, file_paths: List[str]):
self.file_paths = file_paths
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
# Load only when needed
data = torch.load(self.file_paths[idx])
return data
Memory Debugging Tools
class MemoryDebugger:
"""
Tools for debugging memory issues.
"""
@staticmethod
def find_memory_leaks():
"""List live CUDA tensors, largest first, as candidates for memory leaks."""
if not torch.cuda.is_available():
return []
gc.collect()
torch.cuda.empty_cache()
leaks = []
for obj in gc.get_objects():
if isinstance(obj, torch.Tensor):
if obj.is_cuda:
leaks.append({
'shape': obj.shape,
'dtype': obj.dtype,
'size_mb': obj.numel() * obj.element_size() / 1e6,
'requires_grad': obj.requires_grad
})
return sorted(leaks, key=lambda x: -x['size_mb'])
@staticmethod
def memory_summary():
"""Print detailed memory summary."""
if not torch.cuda.is_available():
print("CUDA not available")
return
print(torch.cuda.memory_summary())
@staticmethod
def clear_memory():
"""Aggressively clear GPU memory."""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Usage demonstration
def memory_optimization_demo():
"""Demonstrate memory optimization techniques."""
summary = """
╔════════════════════════════════════════════════════════════════╗
║ MEMORY OPTIMIZATION TECHNIQUES ║
╠════════════════════════════════════════════════════════════════╣
║ ║
║ TECHNIQUE MEMORY SAVING COMPUTE COST ║
║ ───────────────────────────────────────────────────────────── ║
║ Mixed Precision ~50% ~0% (often faster!) ║
║ Gradient Checkpointing ~60-70% ~30-40% ║
║ Gradient Accumulation Linear in steps ~0% ║
║ Flash Attention ~50% for attn ~0% (often faster!) ║
║ CPU Offloading ~50-70% ~20-50% ║
║ ║
╠════════════════════════════════════════════════════════════════╣
║ RECOMMENDATION ORDER: ║
║ 1. Enable mixed precision (always) ║
║ 2. Use Flash Attention (if available) ║
║ 3. Add gradient checkpointing ║
║ 4. Gradient accumulation for larger effective batch ║
║ 5. CPU offloading as last resort ║
╚════════════════════════════════════════════════════════════════╝
"""
print(summary)
memory_optimization_demo()
Exercises
Exercise 1: Profile Your Model
Use the MemoryProfiler to analyze a model’s memory usage at each layer.
Identify the biggest memory consumers and optimize them.
Exercise 2: Implement Selective Checkpointing
Create a system that automatically decides which layers to checkpoint
based on their memory usage vs. compute cost.
Exercise 3: Memory Budget Training
Implement training that stays within a fixed memory budget by
dynamically adjusting batch size and checkpointing.
What’s Next?
Quantization Deep Dive
Post-training and quantization-aware training
Knowledge Distillation
Transfer knowledge to smaller models