Debugging Deep Learning

Common Training Failures

Deep learning debugging is notoriously difficult — models fail silently or in confusing ways.
Symptom | Possible Causes
Loss is NaN | Exploding gradients, bad learning rate, log(0)
Loss doesn’t decrease | LR too low, bug in loss, wrong labels
Loss decreases then plateaus | Needs LR decay, underfitting
Val loss increases (train decreases) | Overfitting
Accuracy stuck at random | Labels shuffled wrong, bug in model
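
As a concrete example of the log(0) failure mode from the table: a hand-rolled loss that takes the log of a probability that can reach exactly zero blows up, while clamping (or using the built-in losses, which work on raw logits) keeps everything finite. A minimal sketch:

import torch

probs = torch.tensor([0.0, 0.3, 0.7], requires_grad=True)
bad_loss = -torch.log(probs).mean()                   # log(0) makes the loss and gradients non-finite
safe_loss = -torch.log(probs.clamp(min=1e-8)).mean()  # clamped: finite loss and gradients
# Prefer F.cross_entropy / F.binary_cross_entropy_with_logits, which take raw
# logits and handle this numerically.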

Gradient Health Checks

Monitor Gradient Norms

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

def get_gradient_norms(model):
    """Get gradient norms per layer."""
    grad_norms = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norms[name] = param.grad.norm().item()
    return grad_norms

# During training (assumes model, criterion, optimizer, and dataloader exist)
for epoch in range(epochs):
    for batch in dataloader:
        optimizer.zero_grad()
        x, y = batch
        loss = criterion(model(x), y)
        loss.backward()
        
        # Check gradients before step
        grad_norms = get_gradient_norms(model)
        if any(norm > 100 for norm in grad_norms.values()):
            print("Warning: Large gradients detected!")
        
        # Clip gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()

Visualize Gradient Flow

def plot_grad_flow(model):
    """Visualize gradient flow through layers."""
    layers = []
    avg_grads = []
    max_grads = []
    
    for name, param in model.named_parameters():
        if param.requires_grad and param.grad is not None:
            layers.append(name)
            avg_grads.append(param.grad.abs().mean().item())
            max_grads.append(param.grad.abs().max().item())
    
    plt.figure(figsize=(12, 6))
    plt.bar(range(len(layers)), max_grads, alpha=0.5, label='Max')
    plt.bar(range(len(layers)), avg_grads, alpha=0.5, label='Mean')
    plt.xticks(range(len(layers)), layers, rotation=90)
    plt.xlabel('Layers')
    plt.ylabel('Gradient magnitude')
    plt.legend()
    plt.tight_layout()
    plt.savefig('gradient_flow.png')
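
A sketch of how plot_grad_flow might be called during training, reusing the loop from the previous section; call it after loss.backward() and only every so often, since plotting each step is slow:

for step, batch in enumerate(dataloader):
    optimizer.zero_grad()
    x, y = batch
    loss = criterion(model(x), y)
    loss.backward()
    if step % 100 == 0:   # plot occasionally to keep overhead low
        plot_grad_flow(model)
    optimizer.step()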

Detecting NaN/Inf

def check_for_nan(model, loss):
    """Check for NaN/Inf in the loss and gradients."""
    if torch.isnan(loss) or torch.isinf(loss):
        print(f"Loss is {loss.item()}")
        return True
    
    for name, param in model.named_parameters():
        if param.grad is not None:
            if torch.isnan(param.grad).any():
                print(f"NaN gradient in {name}")
                return True
            if torch.isinf(param.grad).any():
                print(f"Inf gradient in {name}")
                return True
    
    return False

# Use anomaly detection
torch.autograd.set_detect_anomaly(True)  # Slow but catches issues

try:
    loss.backward()
except RuntimeError as e:
    # With anomaly detection enabled, the traceback points at the forward op
    # that produced the offending value.
    print(f"Backward pass failed: {e}")

Sanity Checks

1. Overfit a Single Batch

def overfit_single_batch(model, dataloader, epochs=100):
    """If model can't overfit one batch, something is wrong."""
    batch = next(iter(dataloader))
    x, y = batch
    
    criterion = torch.nn.CrossEntropyLoss()  # assumes a classification task
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    
    for epoch in range(epochs):
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        
        if epoch % 10 == 0:
            acc = (output.argmax(1) == y).float().mean()
            print(f"Epoch {epoch}: Loss={loss.item():.4f}, Acc={acc.item():.4f}")
    
    # Should reach ~100% accuracy
    with torch.no_grad():
        final_acc = (model(x).argmax(1) == y).float().mean()
    assert final_acc > 0.99, f"Failed to overfit! Acc={final_acc:.4f}"
    print("✓ Model can overfit a single batch")

2. Check Data Pipeline

def verify_data_pipeline(dataloader):
    """Verify data loading is correct."""
    batch = next(iter(dataloader))
    x, y = batch
    
    print(f"Batch shape: {x.shape}")
    print(f"Labels shape: {y.shape}")
    print(f"Label distribution: {torch.bincount(y)}")
    print(f"Input range: [{x.min():.2f}, {x.max():.2f}]")
    print(f"Input mean: {x.mean():.2f}, std: {x.std():.2f}")
    
    # Visualize a few samples (assumes image data in CHW layout)
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    for i, ax in enumerate(axes.flat):
        img = x[i].permute(1, 2, 0).numpy()
        img = (img - img.min()) / (img.max() - img.min())  # Normalize for display
        ax.imshow(img)
        ax.set_title(f"Label: {y[i].item()}")
        ax.axis('off')
    plt.savefig('data_samples.png')

3. Verify Loss at Initialization

def check_initial_loss(model, dataloader, num_classes):
    """Loss should be ~log(num_classes) for random weights."""
    model.eval()
    batch = next(iter(dataloader))
    x, y = batch
    
    with torch.no_grad():
        output = model(x)
        loss = F.cross_entropy(output, y)
    
    expected = -torch.log(torch.tensor(1.0 / num_classes))
    print(f"Initial loss: {loss.item():.4f}")
    print(f"Expected (random): {expected.item():.4f}")
    
    if abs(loss.item() - expected.item()) > 0.5:
        print("⚠ Initial loss is unexpected - check model initialization")

Loss Landscape Visualization

def plot_loss_landscape(model, dataloader, resolution=20):
    """Visualize 2D loss landscape around current parameters."""
    
    # Get two random directions
    direction1 = [torch.randn_like(p) for p in model.parameters()]
    direction2 = [torch.randn_like(p) for p in model.parameters()]
    
    # Normalize directions
    d1_norm = torch.sqrt(sum((d ** 2).sum() for d in direction1))
    d2_norm = torch.sqrt(sum((d ** 2).sum() for d in direction2))
    direction1 = [d / d1_norm for d in direction1]
    direction2 = [d / d2_norm for d in direction2]
    
    # Save original parameters
    original_params = [p.clone() for p in model.parameters()]
    
    losses = torch.zeros(resolution, resolution)
    alphas = torch.linspace(-1, 1, resolution)
    betas = torch.linspace(-1, 1, resolution)
    
    batch = next(iter(dataloader))
    x, y = batch
    
    for i, alpha in enumerate(alphas):
        for j, beta in enumerate(betas):
            # Perturb parameters
            for p, orig, d1, d2 in zip(model.parameters(), original_params, direction1, direction2):
                p.data = orig + alpha * d1 + beta * d2
            
            with torch.no_grad():
                loss = F.cross_entropy(model(x), y)
                losses[i, j] = loss.item()
    
    # Restore original parameters
    for p, orig in zip(model.parameters(), original_params):
        p.data = orig
    
    # Plot (transpose so alphas run along the x-axis and betas along the y-axis)
    plt.figure(figsize=(8, 6))
    plt.contourf(alphas.numpy(), betas.numpy(), losses.numpy().T, levels=50)
    plt.colorbar(label='Loss')
    plt.xlabel('Direction 1')
    plt.ylabel('Direction 2')
    plt.title('Loss Landscape')
    plt.savefig('loss_landscape.png')
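
A usage sketch, assuming a trained model and a val_loader; note that the two random directions here are normalized globally, whereas per-filter normalization is commonly used when more faithful landscape comparisons are needed:

model.eval()  # freeze dropout / batch-norm statistics while probing the landscape
plot_loss_landscape(model, val_loader, resolution=25)
# call model.train() afterwards if you intend to keep training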

Common Fixes

Problem | Fix
Exploding gradients | Gradient clipping, lower LR, layer norm
Vanishing gradients | Residual connections, better initialization
Loss is NaN | Check for log(0), division by zero
Not learning | Verify data, check loss function, increase LR
Overfitting | Regularization, more data, smaller model
Underfitting | Larger model, more training, check data
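
Several of the fixes above are one-liners in PyTorch. A sketch of how gradient clipping, weight decay, and learning-rate decay might be combined (the hyperparameter values are illustrative, not recommendations):

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

for epoch in range(epochs):
    for x, y in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # tame exploding gradients
        optimizer.step()
    scheduler.step()  # decay the LR over training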

Debugging Toolkit

class DebugHooks:
    """Attach hooks to monitor activations and gradients."""
    
    def __init__(self, model):
        self.activations = {}
        self.gradients = {}
        
        for name, module in model.named_modules():
            # Hook only leaf modules so container outputs aren't double-counted
            if len(list(module.children())) > 0:
                continue
            module.register_forward_hook(self._get_activation_hook(name))
            module.register_full_backward_hook(self._get_gradient_hook(name))
    
    def _get_activation_hook(self, name):
        def hook(module, input, output):
            if isinstance(output, torch.Tensor):
                self.activations[name] = output.detach()
        return hook
    
    def _get_gradient_hook(self, name):
        def hook(module, grad_input, grad_output):
            if grad_output[0] is not None:
                self.gradients[name] = grad_output[0].detach()
        return hook
    
    def print_stats(self):
        print("\n=== Activation Stats ===")
        for name, act in self.activations.items():
            # "dead" = fraction of exactly-zero activations (meaningful after ReLU)
            print(f"{name}: mean={act.mean():.4f}, std={act.std():.4f}, "
                  f"dead={((act == 0).sum() / act.numel() * 100):.1f}%")
        
        print("\n=== Gradient Stats ===")
        for name, grad in self.gradients.items():
            print(f"{name}: |mean|={grad.abs().mean():.4f}, max={grad.abs().max():.4f}")

Exercises

1. Given a model that produces NaN loss, use the debugging techniques above to find and fix the issue.
2. Implement gradient flow visualization for a deep network. Identify vanishing gradients.
3. Generate loss landscape visualizations for networks with and without batch normalization.

What’s Next