Debugging Deep Learning

Common Training Failures

Deep learning debugging is notoriously difficult. Unlike a segfault, which points you to the offending line, a neural network fails silently: your training loop runs without errors, the loss decreases smoothly, and three days later you discover the model has learned something completely useless. There are no compiler warnings for “your labels are accidentally shuffled” or “your learning rate is 1000x too high.” Debugging a neural network is like diagnosing a patient who cannot describe their symptoms. You have to run tests (sanity checks), monitor vital signs (gradient norms, loss curves), and work by process of elimination. The best debuggers are not the ones who can read error messages; they are the ones who have a systematic checklist of things to verify before they ever start training.
| Symptom | Possible Causes |
|---|---|
| Loss is NaN | Exploding gradients, learning rate too high, log(0) |
| Loss doesn’t decrease | LR too low, bug in loss, wrong labels |
| Loss decreases then plateaus | Needs LR decay, underfitting |
| Val loss increases (train decreases) | Overfitting |
| Accuracy stuck at random | Labels misaligned with inputs (e.g., shuffled separately), bug in model |
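The table can be turned into a rough automated check on logged loss histories. The sketch below is a heuristic, not a standard tool: `diagnose_loss_curves` is a hypothetical helper, and its `window` and `tol` thresholds are arbitrary assumptions you would tune for your task. It compares averages over the first, second-to-last, and last windows of each history.

```python
import math

def diagnose_loss_curves(train_losses, val_losses, window=5, tol=1e-3):
    """Map loss histories to the symptoms in the table (heuristic sketch).
    `window` and `tol` are arbitrary assumptions, not recommended values."""
    def avg(xs):
        return sum(xs) / len(xs)

    if any(math.isnan(v) or math.isinf(v) for v in train_losses):
        return "loss is NaN"
    if len(train_losses) < 2 * window:
        return "not enough history"

    first = avg(train_losses[:window])
    prev = avg(train_losses[-2 * window:-window])
    last = avg(train_losses[-window:])

    if first - last < tol:
        return "loss doesn't decrease"
    if len(val_losses) >= 2 * window:
        val_prev = avg(val_losses[-2 * window:-window])
        val_last = avg(val_losses[-window:])
        # Validation loss rising while training loss is still falling
        if val_last > val_prev + tol and last < prev - tol:
            return "val loss increases while train loss decreases"
    if prev - last < tol:
        return "loss decreased, then plateaued"
    return "no obvious symptom"

train = [1.0 - 0.05 * i for i in range(20)]
val = [0.5 + 0.01 * i for i in range(20)]
print(diagnose_loss_curves(train, val))
```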

Gradient Health Checks

Monitor Gradient Norms

import torch
import matplotlib.pyplot as plt

def get_gradient_norms(model):
    """Return the L2 norm of each parameter's gradient, keyed by parameter name."""
    grad_norms = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norms[name] = param.grad.norm().item()
    return grad_norms

# During training (assumes model, dataloader, criterion, optimizer, and epochs are defined)
for epoch in range(epochs):
    for x, y in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()

        # Inspect gradients after backward() and before clipping/stepping
        grad_norms = get_gradient_norms(model)
        if any(norm > 100 for norm in grad_norms.values()):  # threshold is a heuristic
            print("Warning: Large gradients detected!")

        # Clip the global gradient norm; the return value is the norm *before* clipping
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
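To see what clipping actually does, the toy example below compares the global gradient norm before and after `torch.nn.utils.clip_grad_norm_`. The linear model and the deliberately oversized inputs are illustrative assumptions, chosen only to produce large gradients. The function rescales all gradients by one shared factor and returns the total norm measured before clipping, which makes that return value a convenient quantity to log.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
x = torch.randn(64, 10) * 100  # deliberately large inputs -> large gradients
y = torch.randn(64, 1)

loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# Rescales gradients in place; returns the global norm measured BEFORE clipping
pre_clip = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Recompute the global norm from the (now clipped) gradients
post_clip = torch.sqrt(sum(p.grad.norm() ** 2 for p in model.parameters()))

print(f"norm before clipping: {pre_clip:.1f}, after: {post_clip:.4f}")
```

Because clipping preserves the gradient direction and only shrinks its length, logging `pre_clip` every step shows how often, and how hard, clipping is kicking in.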

Visualize Gradient Flow

import matplotlib.pyplot as plt

def plot_grad_flow(model):
    """Visualize gradient flow through layers. Call after loss.backward()."""
    layers = []
    avg_grads = []
    max_grads = []

    for name, param in model.named_parameters():
        if param.requires_grad and param.grad is not None:
            layers.append(name)
            avg_grads.append(param.grad.abs().mean().item())
            max_grads.append(param.grad.abs().max().item())

    plt.figure(figsize=(12, 6))
    plt.bar(range(len(layers)), max_grads, alpha=0.5, label='Max')
    plt.bar(range(len(layers)), avg_grads, alpha=0.5, label='Mean')
    plt.xticks(range(len(layers)), layers, rotation=90)
    plt.xlabel('Layers')
    plt.ylabel('Gradient magnitude')
    plt.legend()
    plt.tight_layout()
    plt.savefig('gradient_flow.png')
    plt.close()  # free the figure when this is called repeatedly during training
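The bar chart makes vanishing gradients visible; the same diagnosis can be done numerically. The sketch below builds a hypothetical 10-layer sigmoid MLP with PyTorch's default initialization and compares the mean absolute weight gradient of the first and last `nn.Linear` layers. Because the sigmoid derivative is at most 0.25, the backpropagated signal shrinks at every layer, so the first-layer gradient ends up orders of magnitude smaller than the last.

```python
import torch
import torch.nn as nn

def layer_grad_means(model):
    """Mean absolute weight gradient of each nn.Linear layer, in forward order."""
    return [m.weight.grad.abs().mean().item()
            for m in model.modules() if isinstance(m, nn.Linear)]

torch.manual_seed(0)
depth, width = 10, 64
layers = []
for _ in range(depth):
    layers += [nn.Linear(width, width), nn.Sigmoid()]
model = nn.Sequential(*layers, nn.Linear(width, 1))

x = torch.randn(128, width)
loss = model(x).pow(2).mean()  # any scalar loss works for this check
loss.backward()

grads = layer_grad_means(model)
ratio = grads[0] / grads[-1]
print(f"first layer: {grads[0]:.2e}, last layer: {grads[-1]:.2e}, ratio: {ratio:.2e}")
```

A first-to-last ratio many orders of magnitude below 1 is the numerical signature of vanishing gradients; swapping the activations and initialization, as the Common Fixes table suggests, is the usual next experiment.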

Detecting NaN/Inf

import torch

def check_for_nan(model, loss):
    """Return True if the loss or any parameter gradient contains NaN/Inf."""
    if torch.isnan(loss) or torch.isinf(loss):
        print(f"Loss is {loss.item()}")
        return True

    for name, param in model.named_parameters():
        if param.grad is not None:
            if torch.isnan(param.grad).any():
                print(f"NaN gradient in {name}")
                return True
            if torch.isinf(param.grad).any():
                print(f"Inf gradient in {name}")
                return True

    return False

# Anomaly detection: backward() raises an error at the first backward op that
# produces NaN. It slows training considerably, so enable it only while debugging.
torch.autograd.set_detect_anomaly(True)

try:
    loss.backward()  # `loss` comes from your forward pass
except RuntimeError as e:
    print(f"Backward pass failed: {e}")
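A minimal, self-contained way to see anomaly detection fire: the gradient of `sqrt(x)` is 1/(2√x), which is infinite at x = 0, so a zero upstream gradient turns it into 0/0 = NaN in the backward pass. The toy tensors below are purely illustrative. Running the forward pass inside `torch.autograd.detect_anomaly()` (the context-manager form of `set_detect_anomaly`) makes `backward()` raise a `RuntimeError` that names the backward function producing the NaN.

```python
import torch

x = torch.tensor([0.0, 4.0], requires_grad=True)
mask = torch.tensor([0.0, 1.0])  # zero upstream gradient for the x = 0 entry

error = None
with torch.autograd.detect_anomaly():
    y = (torch.sqrt(x) * mask).sum()  # forward inside the context for a full traceback
    try:
        y.backward()  # sqrt backward computes 0 / (2 * sqrt(0)) = 0 / 0 = NaN
    except RuntimeError as e:
        error = str(e)

print(error)
```

Without anomaly detection the same code runs silently and leaves `x.grad[0]` as NaN, which is exactly the kind of failure that otherwise surfaces steps later as a NaN loss.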

Sanity Checks

1. Overfit a Single Batch

This is the single most important debugging technique in deep learning. Before training on the full dataset, verify that your model can memorize a single batch of data to near-perfect accuracy. If it cannot, something is fundamentally broken — a bug in the model architecture, the loss function, the data pipeline, or the optimizer. Do not waste hours training on the full dataset until this test passes. Think of it as a smoke test: if the car will not start in the driveway, do not take it on the highway.
import torch
import torch.nn as nn

def overfit_single_batch(model, dataloader, epochs=100, criterion=None, lr=1e-3):
    """THE most important sanity check in deep learning.
    If the model cannot overfit one batch, something is fundamentally broken.
    This test catches: wrong loss function, broken forward pass,
    mismatched input/output dimensions, label encoding errors, and more.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()  # default assumes a classification task
    x, y = next(iter(dataloader))  # reuse the SAME batch at every step

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()

    for epoch in range(epochs):
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        if epoch % 10 == 0:
            acc = (output.argmax(1) == y).float().mean()
            print(f"Epoch {epoch}: Loss={loss.item():.4f}, Acc={acc.item():.4f}")

    # Should reach ~100% accuracy on this single batch
    model.eval()  # disable dropout for the final check
    with torch.no_grad():
        final_acc = (model(x).argmax(1) == y).float().mean().item()
    assert final_acc > 0.99, f"Failed to overfit! Acc={final_acc:.4f}"
    print("Model can overfit a single batch -- forward pass and loss are correct")
If this test fails, check in this order: (1) Are the labels correct? Print a few and verify them visually. (2) Are the input dimensions correct? Print the output shape of each layer. (3) Is the loss function appropriate for your task? A common mistake is using BCE for a multi-class problem instead of cross-entropy. (4) Is the learning rate too low? Try 1e-2 or even 1e-1.
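For step (2), forward hooks can print the output shape of every layer without editing the model. The helper `print_layer_shapes` below is a hypothetical sketch that hooks only leaf modules and removes its hooks afterwards; the small `nn.Sequential` model in the usage lines is just an example.

```python
import torch
import torch.nn as nn

def print_layer_shapes(model, x):
    """Run one forward pass and print each leaf module's output shape.
    Returns the (name, shape) pairs for programmatic checks."""
    shapes, handles = [], []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            def hook(mod, inp, out, name=name):
                if isinstance(out, torch.Tensor):
                    shapes.append((name, tuple(out.shape)))
            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:  # always detach the temporary hooks
            h.remove()
    for name, shape in shapes:
        print(f"{name:>10}: {shape}")
    return shapes

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
shapes = print_layer_shapes(model, torch.randn(4, 3, 32, 32))
```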

2. Check Data Pipeline

import torch
import matplotlib.pyplot as plt

def verify_data_pipeline(dataloader):
    """Print batch statistics and save a grid of samples for visual inspection.
    Assumes image inputs shaped (N, C, H, W) and integer class labels."""
    x, y = next(iter(dataloader))

    print(f"Batch shape: {x.shape}")
    print(f"Labels shape: {y.shape}")
    print(f"Label distribution: {torch.bincount(y)}")
    print(f"Input range: [{x.min():.2f}, {x.max():.2f}]")
    print(f"Input mean: {x.mean():.2f}, std: {x.std():.2f}")

    # Visualize up to 8 samples with their labels
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= len(x):
            continue
        img = x[i].permute(1, 2, 0).cpu().numpy()
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)  # rescale to [0, 1] for display
        ax.imshow(img.squeeze())  # squeeze handles single-channel images
        ax.set_title(f"Label: {y[i].item()}")
    plt.savefig('data_samples.png')
    plt.close(fig)
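The mean and std printed above are global across channels. Per-channel statistics are often more telling: if you normalize each channel with its dataset mean and std, every channel should come out near mean 0 and std 1, while unnormalized pixel values in [0, 1] will show means well above 0. `channel_stats` is a hypothetical helper that assumes `(N, C, H, W)` image batches and `(inputs, labels)` pairs from the loader; the synthetic `TensorDataset` stands in for your data.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def channel_stats(dataloader, max_batches=10):
    """Per-channel mean/std over the first few batches of (N, C, H, W) images."""
    total, total_sq, count = 0.0, 0.0, 0
    for i, (x, _) in enumerate(dataloader):
        if i >= max_batches:
            break
        x = x.double()  # accumulate in float64 to limit rounding error
        total = total + x.sum(dim=(0, 2, 3))
        total_sq = total_sq + (x ** 2).sum(dim=(0, 2, 3))
        count += x.shape[0] * x.shape[2] * x.shape[3]
    mean = total / count
    std = (total_sq / count - mean ** 2).sqrt()  # population std: E[x^2] - E[x]^2
    return mean, std

torch.manual_seed(0)
data = TensorDataset(torch.randn(64, 3, 8, 8) * 2 + 5, torch.zeros(64, dtype=torch.long))
mean, std = channel_stats(DataLoader(data, batch_size=16))
print(f"per-channel mean: {mean}, std: {std}")
```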

3. Verify Loss at Initialization

A randomly initialized classifier should assign roughly equal probability to all classes. For cross-entropy loss with $K$ classes, this means the initial loss should be approximately $-\log(1/K) = \log(K)$. For CIFAR-10 (10 classes), expect ~2.30. For ImageNet (1000 classes), expect ~6.91. If your initial loss is significantly different, something is wrong with the model or the loss function.
import torch
import torch.nn.functional as F

def check_initial_loss(model, dataloader, num_classes):
    """Loss should be ~log(num_classes) for random weights.
    This catches: wrong number of output classes, broken softmax,
    incorrect loss function, biased initialization.
    """
    model.eval()
    x, y = next(iter(dataloader))

    with torch.no_grad():
        output = model(x)  # raw logits; F.cross_entropy applies log-softmax itself
        loss = F.cross_entropy(output, y)
    model.train()

    expected = -torch.log(torch.tensor(1.0 / num_classes))
    print(f"Initial loss: {loss.item():.4f}")
    print(f"Expected (random): {expected.item():.4f}")

    if abs(loss.item() - expected.item()) > 0.5:
        print("WARNING: Initial loss is unexpected - check model initialization")
        print("  If loss is much HIGHER: logits too large, or output layer has wrong dimensions")
        print("  If loss is much LOWER: possible label leakage or a bug in the loss computation")
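Why this matters: An untrained model has no information about the labels, so it should not beat chance. If the initial loss is 0.1 when it should be 2.3, suspect label leakage (the label is somehow visible in the input), a bug in how the loss is computed, or weights that were not actually freshly initialized. If the initial loss is 15.0 when it should be 6.9, the model is confidently wrong from the start: the output logits are too large (for example, the final layer is initialized at too large a scale), or the loss is aggregated incorrectly (e.g., summed over the batch instead of averaged).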
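These numbers can be reproduced with synthetic logits. In the sketch below (random labels and logits, chosen only for illustration), all-zero logits give a uniform softmax and hence exactly $\log K$; small random logits, as from a healthy initialization, land close to it; and large random logits make the model confidently wrong, pushing the loss far above $\log K$.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, batch = 10, 256
y = torch.randint(0, num_classes, (batch,))

logits = {
    "uniform": torch.zeros(batch, num_classes),       # every class equally likely
    "small": 0.01 * torch.randn(batch, num_classes),  # scale typical of a healthy init
    "large": 10.0 * torch.randn(batch, num_classes),  # overconfident, badly scaled init
}
losses = {name: F.cross_entropy(z, y).item() for name, z in logits.items()}

print(f"log K = {math.log(num_classes):.3f}")
for name, value in losses.items():
    print(f"{name:>8}: {value:.3f}")
```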

Loss Landscape Visualization

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

def plot_loss_landscape(model, dataloader, resolution=20):
    """Visualize a 2D slice of the loss landscape around the current parameters.
    Uses two random directions, each normalized to unit length over all parameters."""
    model.eval()  # fixed BatchNorm statistics, no dropout
    params = list(model.parameters())

    # Get two random directions
    direction1 = [torch.randn_like(p) for p in params]
    direction2 = [torch.randn_like(p) for p in params]

    # Normalize directions
    d1_norm = torch.sqrt(sum((d ** 2).sum() for d in direction1))
    d2_norm = torch.sqrt(sum((d ** 2).sum() for d in direction2))
    direction1 = [d / d1_norm for d in direction1]
    direction2 = [d / d2_norm for d in direction2]

    # Save original parameters (detached copies, outside the autograd graph)
    original_params = [p.detach().clone() for p in params]

    losses = torch.zeros(resolution, resolution)
    alphas = torch.linspace(-1, 1, resolution)
    betas = torch.linspace(-1, 1, resolution)

    x, y = next(iter(dataloader))

    with torch.no_grad():
        for i, alpha in enumerate(alphas):
            for j, beta in enumerate(betas):
                # Move to theta + alpha * d1 + beta * d2
                for p, orig, d1, d2 in zip(params, original_params, direction1, direction2):
                    p.copy_(orig + alpha * d1 + beta * d2)
                losses[i, j] = F.cross_entropy(model(x), y).item()

        # Restore original parameters
        for p, orig in zip(params, original_params):
            p.copy_(orig)

    # contourf expects Z[row = y index, col = x index]; losses is [alpha, beta], so transpose
    plt.figure(figsize=(8, 6))
    plt.contourf(alphas.numpy(), betas.numpy(), losses.numpy().T, levels=50)
    plt.colorbar(label='Loss')
    plt.xlabel('Direction 1')
    plt.ylabel('Direction 2')
    plt.title('Loss Landscape')
    plt.savefig('loss_landscape.png')
    plt.close()
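One caveat with normalizing each random direction to unit length over all parameters: in a large network the resulting perturbation is tiny relative to the weights, and its effect on each layer depends on that layer's weight scale. A common refinement is filter normalization, proposed by Li et al. (2018) in “Visualizing the Loss Landscape of Neural Nets”: each filter of the random direction is rescaled to the norm of the corresponding filter in the weights. The sketch below assumes the first dimension of every weight tensor indexes filters (output channels or output units) and, as a simplifying assumption, zeroes out 1-D parameters such as biases so they stay fixed. `filter_normalized_direction` is a hypothetical name; its outputs could stand in for the normalized `direction1`/`direction2` in the landscape code.

```python
import torch
import torch.nn as nn

def filter_normalized_direction(model):
    """Random direction whose per-filter norms match the weights' per-filter norms.
    1-D parameters (biases, norm-layer scales) are zeroed out (an assumption)."""
    direction = []
    for p in model.parameters():
        d = torch.randn_like(p)
        if p.dim() <= 1:
            d.zero_()
        else:
            shape = (-1,) + (1,) * (p.dim() - 1)  # broadcast one norm per filter
            d_norm = d.flatten(1).norm(dim=1).view(shape)
            p_norm = p.detach().flatten(1).norm(dim=1).view(shape)
            d = d / (d_norm + 1e-10) * p_norm
        direction.append(d)
    return direction

torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.Flatten(), nn.Linear(4 * 6 * 6, 2))
direction1 = filter_normalized_direction(model)
direction2 = filter_normalized_direction(model)
```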

Common Fixes

| Problem | Symptoms | Fix | What to Check First |
|---|---|---|---|
| Exploding gradients | Loss spikes to NaN, grad norms > 100 | Gradient clipping, lower LR, layer norm | Did you forget gradient clipping? Is LR too high? |
| Vanishing gradients | Early layers stop updating, loss plateaus | Residual connections, better init (He/Xavier) | Are you using sigmoid/tanh? Switch to ReLU/GELU. |
| Loss is NaN | Loss becomes NaN after a few steps | Check for log(0), division by zero, add eps | Add + 1e-8 inside log() and sqrt() calls. |
| Not learning | Loss stays flat, accuracy at random chance | Verify data, check loss function, increase LR | Are labels correct? Try overfitting a single batch. |
| Overfitting | Train loss decreasing, val loss increasing | Regularization, more data, smaller model | Add dropout, data augmentation, weight decay. |
| Underfitting | Both train and val loss high | Larger model, more training, check data quality | Is the model too small? Is the data noisy or mislabeled? |
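The “add eps” fix is easiest to see on a hand-written binary cross-entropy. In the hypothetical `manual_bce` below, a prediction of exactly 1.0 for a positive example makes the `(1 - t) * log(1 - p)` term evaluate to 0 × log(0) = 0 × (−inf) = NaN, even though the prediction is perfect. Here the epsilon is applied by clamping probabilities into [eps, 1 − eps], which keeps every log argument positive.

```python
import torch

def manual_bce(p, t, eps=0.0):
    """Hand-written binary cross-entropy; eps > 0 clamps p into [eps, 1 - eps]."""
    p = p.clamp(eps, 1 - eps)
    return -(t * torch.log(p) + (1 - t) * torch.log(1 - p)).mean()

p = torch.tensor([1.0, 0.3])  # first prediction is exactly 1.0 (a perfect positive)
t = torch.tensor([1.0, 0.0])

unsafe = manual_bce(p, t)          # NaN: (1 - t) * log(1 - p) = 0 * log(0) = 0 * -inf
safe = manual_bce(p, t, eps=1e-7)  # finite
print(unsafe, safe)
```

In practice, PyTorch's `F.binary_cross_entropy_with_logits` takes raw logits and is documented as more numerically stable than applying a sigmoid followed by BCE, so it sidesteps hand-rolled epsilons entirely.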
The number one debugging rule: Change one thing at a time. If you simultaneously increase the learning rate, add dropout, and change the architecture, you will never know which change had what effect. Disciplined, isolated experiments save more time than they cost.
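Isolated experiments only work if two runs differ in nothing but the change you made, and that includes random initialization and data shuffling. A minimal sketch, assuming your code draws randomness from Python's `random`, NumPy, and PyTorch: seed all three before each run. `set_seed` is a hypothetical helper name.

```python
import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    """Seed Python, NumPy, and PyTorch RNGs so runs differ only in what you change."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the RNG on all devices, including CUDA

set_seed(0)
w1 = torch.nn.Linear(4, 2).weight.detach().clone()
set_seed(0)
w2 = torch.nn.Linear(4, 2).weight.detach().clone()
print(torch.equal(w1, w2))
```

Seeding alone does not make every GPU kernel deterministic; PyTorch's reproducibility notes also cover `torch.use_deterministic_algorithms(True)` for that.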

Debugging Toolkit

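import torch

class DebugHooks:
    """Attach hooks to record activations and output gradients of every leaf module.
    Note: full backward hooks raise an error on modules that modify their
    inputs/outputs in place (e.g. nn.ReLU(inplace=True))."""

    def __init__(self, model):
        self.activations = {}
        self.gradients = {}
        self.handles = []

        for name, module in model.named_modules():
            if len(list(module.children())) > 0:
                continue  # skip containers; hook only leaf layers
            self.handles.append(module.register_forward_hook(self._get_activation_hook(name)))
            self.handles.append(module.register_full_backward_hook(self._get_gradient_hook(name)))

    def _get_activation_hook(self, name):
        def hook(module, input, output):
            if isinstance(output, torch.Tensor):
                self.activations[name] = output.detach()
        return hook

    def _get_gradient_hook(self, name):
        def hook(module, grad_input, grad_output):
            if grad_output[0] is not None:
                self.gradients[name] = grad_output[0].detach()
        return hook

    def print_stats(self):
        print("\n=== Activation Stats ===")
        for name, act in self.activations.items():
            act = act.float()
            # "dead" = share of activations that are exactly zero
            print(f"{name}: mean={act.mean():.4f}, std={act.std():.4f}, "
                  f"dead={((act == 0).sum() / act.numel() * 100):.1f}%")
        print("\n=== Gradient Stats ===")
        for name, grad in self.gradients.items():
            print(f"{name}: grad norm={grad.norm():.4e}")

    def remove(self):
        """Detach all hooks when done debugging."""
        for handle in self.handles:
            handle.remove()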
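The `dead=` statistic reports the share of activations that are exactly zero. After a ReLU some zeros are normal, but a layer stuck near 100% has effectively stopped learning, since ReLU passes no gradient wherever its input is negative. The self-contained sketch below simulates this by setting one layer's bias to a large negative value, an artificial stand-in for a bad update; `dead_fraction` is a hypothetical helper.

```python
import torch
import torch.nn as nn

def dead_fraction(layer, x):
    """Fraction of ReLU outputs that are exactly zero for the input batch x."""
    with torch.no_grad():
        return (layer(x) == 0).float().mean().item()

torch.manual_seed(0)
x = torch.randn(256, 16)

healthy = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
dead = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
with torch.no_grad():
    dead[0].bias.fill_(-100.0)  # push every pre-activation far below zero

healthy_frac = dead_fraction(healthy, x)
dead_frac = dead_fraction(dead, x)
print(f"healthy: {healthy_frac:.0%} zeros, dead: {dead_frac:.0%} zeros")
```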

Exercises

Given a model that produces NaN loss, use debugging techniques to find and fix the issue.
Implement gradient flow visualization for a deep network. Identify vanishing gradients.
Generate loss landscape visualizations for networks with and without batch normalization.

What’s Next

Module 23: Vision Transformers

ViT, DeiT, Swin Transformer, and modern vision architectures.