
Gradient Flow: The Lifeblood of Deep Learning

Understanding Gradient Dynamics

Gradients are how neural networks learn. They flow backward through the network, telling each parameter how to change. When this flow is disrupted, learning stops. By the chain rule, the gradient that reaches the first layer is a product of per-layer factors:

$$\frac{\partial \mathcal{L}}{\partial W^{(1)}} = \frac{\partial \mathcal{L}}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial a^{(L)}} \cdot \frac{\partial a^{(L)}}{\partial a^{(L-1)}} \cdots \frac{\partial a^{(2)}}{\partial W^{(1)}}$$

This chain of multiplications is the crux of all gradient problems.
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from collections import defaultdict

torch.manual_seed(42)
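
To see why this repeated multiplication matters before any network code, here is a minimal numeric sketch (the per-layer factors 0.25 and 1.1 are illustrative assumptions, not measurements):

# Repeatedly multiplying per-layer gradient factors drives the product
# toward zero (factor < 1) or toward infinity (factor > 1).
for factor, label in [(0.25, "shrinking (saturated sigmoid)"), (1.1, "growing (large weights)")]:
    product = 1.0
    for _ in range(50):
        product *= factor
    print(f"{label}: factor^50 = {product:.2e}")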

The Vanishing Gradient Problem

Mathematical Analysis

For a sigmoid activation $\sigma(x) = \frac{1}{1 + e^{-x}}$, the derivative is bounded:

$$\sigma'(x) = \sigma(x)(1 - \sigma(x)) \leq 0.25$$

Through $L$ layers, these factors compound:

$$\left\|\frac{\partial \mathcal{L}}{\partial W^{(1)}}\right\| \leq 0.25^L \cdot \left\|\frac{\partial \mathcal{L}}{\partial W^{(L)}}\right\|$$
def analyze_vanishing_gradients():
    """Demonstrate vanishing gradients mathematically."""
    
    # Sigmoid gradient analysis
    x = np.linspace(-10, 10, 1000)
    sigmoid = 1 / (1 + np.exp(-x))
    sigmoid_grad = sigmoid * (1 - sigmoid)
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    # Sigmoid function
    axes[0].plot(x, sigmoid, 'b-', linewidth=2)
    axes[0].set_title('Sigmoid Activation')
    axes[0].set_xlabel('x')
    axes[0].set_ylabel('σ(x)')
    axes[0].grid(True, alpha=0.3)
    axes[0].axhline(y=0.5, color='r', linestyle='--', alpha=0.5)
    
    # Sigmoid gradient
    axes[1].plot(x, sigmoid_grad, 'g-', linewidth=2)
    axes[1].set_title("Sigmoid Gradient (max=0.25)")
    axes[1].set_xlabel('x')
    axes[1].set_ylabel("σ'(x)")
    axes[1].grid(True, alpha=0.3)
    axes[1].axhline(y=0.25, color='r', linestyle='--', alpha=0.5, label='max=0.25')
    axes[1].legend()
    
    # Gradient decay through layers
    layers = np.arange(1, 51)
    max_grad_factor = 0.25 ** layers
    
    axes[2].semilogy(layers, max_grad_factor, 'r-', linewidth=2)
    axes[2].set_title('Gradient Decay Through Layers')
    axes[2].set_xlabel('Number of Layers')
    axes[2].set_ylabel('Max Gradient Factor')
    axes[2].grid(True, alpha=0.3)
    
    # Annotate specific points
    for l in [10, 20, 30, 40, 50]:
        axes[2].annotate(f'{0.25**l:.2e}', (l, 0.25**l), 
                        textcoords="offset points", xytext=(0,10), ha='center')
    
    plt.tight_layout()
    plt.show()
    
    print("Gradient Decay Analysis")
    print("="*50)
    for depth in [10, 20, 50, 100]:
        factor = 0.25 ** depth
        print(f"After {depth} layers: gradient ≤ {factor:.2e}")

analyze_vanishing_gradients()

Live Demonstration

def vanishing_gradient_demo():
    """Watch gradients vanish in a deep sigmoid network."""
    
    class DeepSigmoid(nn.Module):
        def __init__(self, depth, width=100):
            super().__init__()
            layers = []
            for _ in range(depth):
                layers.append(nn.Linear(width, width))
                layers.append(nn.Sigmoid())
            layers.append(nn.Linear(width, 1))
            self.net = nn.Sequential(*layers)
        
        def forward(self, x):
            return self.net(x)
    
    depths = [5, 10, 20, 50]
    
    print("Vanishing Gradient Demonstration")
    print("="*60)
    
    results = {}
    
    for depth in depths:
        model = DeepSigmoid(depth)
        x = torch.randn(32, 100)
        y = torch.randn(32, 1)
        
        # Forward and backward
        output = model(x)
        loss = nn.MSELoss()(output, y)
        loss.backward()
        
        # Collect gradient norms per layer
        grad_norms = []
        for name, param in model.named_parameters():
            if 'weight' in name and param.grad is not None:
                grad_norms.append(param.grad.norm().item())
        
        results[depth] = grad_norms
        
        print(f"\nDepth {depth}:")
        print(f"  First layer grad norm: {grad_norms[0]:.2e}")
        print(f"  Last layer grad norm:  {grad_norms[-1]:.2e}")
        print(f"  Ratio (first/last):    {grad_norms[0]/grad_norms[-1]:.2e}")
    
    # Plot
    plt.figure(figsize=(12, 5))
    for depth, grads in results.items():
        plt.semilogy(range(len(grads)), grads, 'o-', label=f'Depth={depth}')
    
    plt.xlabel('Layer Index')
    plt.ylabel('Gradient Norm (log scale)')
    plt.title('Gradient Norms Through Network Depth')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

vanishing_gradient_demo()

The Exploding Gradient Problem

When Gradients Explode

If the weight matrices have an eigenvalue with $|\lambda| > 1$, then along the corresponding eigenvector $v$ the layer-by-layer product grows geometrically:

$$\prod_{i=1}^{L} W^{(i)} v \approx \lambda^L v$$

so gradients grow exponentially with depth.
def exploding_gradient_demo():
    """Demonstrate exploding gradients."""
    
    class DeepLinear(nn.Module):
        def __init__(self, depth, width=100):
            super().__init__()
            layers = []
            for _ in range(depth):
                layer = nn.Linear(width, width, bias=False)
                # Initialize with slightly too large weights
                nn.init.normal_(layer.weight, std=1.5 / np.sqrt(width))
                layers.append(layer)
            self.net = nn.Sequential(*layers)
        
        def forward(self, x):
            return self.net(x)
    
    print("Exploding Gradient Demonstration")
    print("="*60)
    
    for depth in [10, 20, 30]:
        model = DeepLinear(depth)
        x = torch.randn(32, 100)
        
        try:
            # Forward pass
            with torch.no_grad():
                activations = [x]
                current = x
                for layer in model.net:
                    current = layer(current)
                    activations.append(current)
            
            # Check for explosion
            output = model(x)
            
            print(f"\nDepth {depth}:")
            print(f"  Input norm:  {x.norm().item():.2e}")
            print(f"  Output norm: {output.norm().item():.2e}")
            
            # Track activation growth
            norms = [a.norm().item() for a in activations]
            growth_rate = norms[-1] / norms[0]
            print(f"  Growth rate: {growth_rate:.2e}")
            
            if np.isnan(output.norm().item()) or np.isinf(output.norm().item()):
                print("  ⚠ EXPLODED to NaN/Inf!")
            elif growth_rate > 1e6:
                print("  ⚠ Severe explosion detected!")
                
        except Exception as e:
            print(f"\nDepth {depth}: Failed - {str(e)[:50]}")

exploding_gradient_demo()

Gradient Clipping Solutions
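
Global-norm clipping, the most common strategy below, rescales the whole gradient vector only when its L2 norm exceeds a threshold $c$:

$$g \leftarrow g \cdot \min\!\left(1, \frac{c}{\lVert g \rVert_2}\right)$$

so the update direction is preserved and only its magnitude is capped.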

class GradientClipper:
    """Various gradient clipping strategies."""
    
    @staticmethod
    def clip_by_norm(parameters, max_norm):
        """Clip gradients by global norm (most common)."""
        total_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm)
        return total_norm
    
    @staticmethod
    def clip_by_value(parameters, clip_value):
        """Clip each gradient element to [-clip_value, clip_value]."""
        torch.nn.utils.clip_grad_value_(parameters, clip_value)
    
    @staticmethod
    def clip_by_global_norm_manual(parameters, max_norm):
        """Manual implementation of global norm clipping."""
        parameters = list(filter(lambda p: p.grad is not None, parameters))
        
        # Compute global norm
        total_norm = 0.0
        for p in parameters:
            total_norm += p.grad.data.norm(2).item() ** 2
        total_norm = np.sqrt(total_norm)
        
        # Compute clipping coefficient
        clip_coef = max_norm / (total_norm + 1e-6)
        
        if clip_coef < 1:
            for p in parameters:
                p.grad.data.mul_(clip_coef)
        
        return total_norm, clip_coef
    
    @staticmethod
    def adaptive_clipping(parameters, percentile=10):
        """Clip based on gradient distribution (AdaClip style)."""
        all_grads = []
        for p in parameters:
            if p.grad is not None:
                all_grads.append(p.grad.data.abs().flatten())
        
        all_grads = torch.cat(all_grads)
        threshold = torch.quantile(all_grads, 1 - percentile/100)
        
        for p in parameters:
            if p.grad is not None:
                p.grad.data.clamp_(-threshold, threshold)
        
        return threshold.item()


# Example usage
def gradient_clipping_example():
    """Demonstrate gradient clipping strategies."""
    
    model = nn.Linear(100, 100)
    x = torch.randn(32, 100)
    y = torch.randn(32, 100)
    
    def compute_large_grads():
        """Recompute large gradients from a fresh forward/backward pass."""
        model.zero_grad()
        loss = nn.MSELoss()(model(x), y) * 1000  # scale loss to simulate large gradients
        loss.backward()
    
    compute_large_grads()
    original_norm = model.weight.grad.norm().item()
    print(f"Original gradient norm: {original_norm:.2f}")
    
    # Clip by norm
    compute_large_grads()
    GradientClipper.clip_by_norm(model.parameters(), max_norm=1.0)
    print(f"After clip_by_norm(1.0): {model.weight.grad.norm().item():.2f}")
    
    # Clip by value
    compute_large_grads()
    GradientClipper.clip_by_value(model.parameters(), clip_value=0.1)
    print(f"After clip_by_value(0.1): {model.weight.grad.norm().item():.2f}")

gradient_clipping_example()
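
In a real training loop, clipping goes between loss.backward() and optimizer.step(). A minimal sketch (the model, data, and max_norm=1.0 are placeholder choices):

def training_step_with_clipping(model, optimizer, criterion, x, y, max_norm=1.0):
    """One training step with global-norm gradient clipping (sketch)."""
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    # Clip after the backward pass, before the optimizer update
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return loss.item()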

Gradient Flow Visualization

Building a Gradient Monitor

class GradientMonitor:
    """Comprehensive gradient monitoring toolkit."""
    
    def __init__(self, model):
        self.model = model
        self.gradient_history = defaultdict(list)
        self.activation_history = defaultdict(list)
        self.hooks = []
    
    def register_hooks(self):
        """Register forward and backward hooks."""
        
        def forward_hook(name):
            def hook(module, input, output):
                if isinstance(output, torch.Tensor):
                    self.activation_history[name].append({
                        'mean': output.mean().item(),
                        'std': output.std().item(),
                        'min': output.min().item(),
                        'max': output.max().item(),
                        'dead_fraction': (output == 0).float().mean().item()
                    })
            return hook
        
        def backward_hook(name):
            def hook(module, grad_input, grad_output):
                if grad_output[0] is not None:
                    grad = grad_output[0]
                    self.gradient_history[name].append({
                        'mean': grad.mean().item(),
                        'std': grad.std().item(),
                        'norm': grad.norm().item(),
                        'max_abs': grad.abs().max().item()
                    })
            return hook
        
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                self.hooks.append(
                    module.register_forward_hook(forward_hook(name))
                )
                self.hooks.append(
                    module.register_full_backward_hook(backward_hook(name))
                )
    
    def remove_hooks(self):
        """Clean up hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
    
    def plot_gradients(self, title="Gradient Flow"):
        """Visualize gradient statistics."""
        
        if not self.gradient_history:
            print("No gradients recorded. Did you run a backward pass?")
            return
        
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        
        layer_names = list(self.gradient_history.keys())
        
        # Get latest statistics
        means = [self.gradient_history[n][-1]['mean'] for n in layer_names]
        stds = [self.gradient_history[n][-1]['std'] for n in layer_names]
        norms = [self.gradient_history[n][-1]['norm'] for n in layer_names]
        max_abs = [self.gradient_history[n][-1]['max_abs'] for n in layer_names]
        
        x = range(len(layer_names))
        
        # Gradient means
        axes[0,0].bar(x, means, color='blue', alpha=0.7)
        axes[0,0].set_xlabel('Layer')
        axes[0,0].set_ylabel('Gradient Mean')
        axes[0,0].set_title('Gradient Means')
        axes[0,0].axhline(y=0, color='r', linestyle='--', alpha=0.5)
        
        # Gradient stds
        axes[0,1].bar(x, stds, color='green', alpha=0.7)
        axes[0,1].set_xlabel('Layer')
        axes[0,1].set_ylabel('Gradient Std')
        axes[0,1].set_title('Gradient Standard Deviations')
        axes[0,1].set_yscale('log')
        
        # Gradient norms
        axes[1,0].bar(x, norms, color='orange', alpha=0.7)
        axes[1,0].set_xlabel('Layer')
        axes[1,0].set_ylabel('Gradient Norm')
        axes[1,0].set_title('Gradient Norms per Layer')
        axes[1,0].set_yscale('log')
        
        # Max absolute gradient
        axes[1,1].bar(x, max_abs, color='red', alpha=0.7)
        axes[1,1].set_xlabel('Layer')
        axes[1,1].set_ylabel('Max |Gradient|')
        axes[1,1].set_title('Maximum Absolute Gradient')
        axes[1,1].set_yscale('log')
        
        plt.suptitle(title, fontsize=14)
        plt.tight_layout()
        plt.show()
    
    def plot_gradient_evolution(self, layer_name=None):
        """Plot how gradients evolve over training."""
        
        if layer_name is None:
            layer_name = list(self.gradient_history.keys())[0]
        
        history = self.gradient_history[layer_name]
        
        steps = range(len(history))
        norms = [h['norm'] for h in history]
        stds = [h['std'] for h in history]
        
        fig, axes = plt.subplots(1, 2, figsize=(12, 4))
        
        axes[0].plot(steps, norms, 'b-', linewidth=2)
        axes[0].set_xlabel('Training Step')
        axes[0].set_ylabel('Gradient Norm')
        axes[0].set_title(f'Gradient Norm Evolution - {layer_name}')
        axes[0].grid(True, alpha=0.3)
        
        axes[1].plot(steps, stds, 'g-', linewidth=2)
        axes[1].set_xlabel('Training Step')
        axes[1].set_ylabel('Gradient Std')
        axes[1].set_title(f'Gradient Std Evolution - {layer_name}')
        axes[1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()


# Example usage
def gradient_monitoring_example():
    """Demonstrate gradient monitoring."""
    
    # Create a model with potential gradient issues
    model = nn.Sequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, 10)
    )
    
    monitor = GradientMonitor(model)
    monitor.register_hooks()
    
    # Simulate training
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    
    for step in range(50):
        x = torch.randn(32, 784)
        y = torch.randint(0, 10, (32,))
        
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
    
    # Visualize
    monitor.plot_gradients("Gradient Flow After 50 Steps")
    monitor.remove_hooks()

gradient_monitoring_example()

Gradient Flow in Different Architectures

Residual Connections
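
Before the code, the one-line reason residual connections help: differentiating a skip connection $y = x + F(x)$ gives

$$\frac{\partial y}{\partial x} = I + \frac{\partial F(x)}{\partial x}$$

so even when $\frac{\partial F}{\partial x}$ is tiny, the identity term gives gradients a direct path back to earlier layers.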

def residual_gradient_flow():
    """Compare gradient flow with and without residual connections."""
    
    class PlainBlock(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.fc1 = nn.Linear(dim, dim)
            self.fc2 = nn.Linear(dim, dim)
            self.relu = nn.ReLU()
        
        def forward(self, x):
            return self.relu(self.fc2(self.relu(self.fc1(x))))
    
    class ResidualBlock(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.fc1 = nn.Linear(dim, dim)
            self.fc2 = nn.Linear(dim, dim)
            self.relu = nn.ReLU()
        
        def forward(self, x):
            residual = x
            out = self.relu(self.fc1(x))
            out = self.fc2(out)
            return self.relu(out + residual)  # Skip connection!
    
    def build_network(block_class, num_blocks, dim):
        layers = [block_class(dim) for _ in range(num_blocks)]
        return nn.Sequential(*layers)
    
    depth = 50
    dim = 128
    
    plain_net = build_network(PlainBlock, depth, dim)
    res_net = build_network(ResidualBlock, depth, dim)
    
    print("Gradient Flow: Plain vs Residual Networks")
    print("="*60)
    
    for name, net in [("Plain", plain_net), ("Residual", res_net)]:
        x = torch.randn(32, dim)
        y = torch.randn(32, dim)
        
        output = net(x)
        loss = nn.MSELoss()(output, y)
        loss.backward()
        
        # Collect gradients from each block
        grad_norms = []
        for module in net:
            if hasattr(module, 'fc1'):
                grad_norms.append(module.fc1.weight.grad.norm().item())
        
        print(f"\n{name} Network ({depth} blocks):")
        print(f"  First block gradient: {grad_norms[0]:.6f}")
        print(f"  Last block gradient:  {grad_norms[-1]:.6f}")
        print(f"  Ratio (first/last):   {grad_norms[0]/grad_norms[-1]:.4f}")
        
        # Plot
        plt.semilogy(grad_norms, label=name, marker='o')
    
    plt.xlabel('Block Index')
    plt.ylabel('Gradient Norm (log scale)')
    plt.title('Gradient Flow: Plain vs Residual')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

residual_gradient_flow()

Dense Connections (DenseNet style)
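
Dense connections attack the same problem differently: every block's output is concatenated into the input of all later blocks and of the classifier, so each intermediate feature $h_i$ receives a gradient term directly from the output in addition to the terms routed through later blocks.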

class DenseBlock(nn.Module):
    """DenseNet-style block with dense connections."""
    
    def __init__(self, dim, growth_rate=32):
        super().__init__()
        self.fc = nn.Linear(dim, growth_rate)
        self.relu = nn.ReLU()
    
    def forward(self, features):
        # features is a list of all previous feature maps
        concat = torch.cat(features, dim=1)
        out = self.relu(self.fc(concat))
        return out


def dense_gradient_flow():
    """Demonstrate gradient flow in DenseNet architecture."""
    
    class DenseNetwork(nn.Module):
        def __init__(self, input_dim, num_blocks, growth_rate=32):
            super().__init__()
            self.initial = nn.Linear(input_dim, growth_rate)
            
            # Dense blocks
            self.blocks = nn.ModuleList()
            in_features = growth_rate
            for _ in range(num_blocks):
                self.blocks.append(nn.Linear(in_features, growth_rate))
                in_features += growth_rate
            
            self.final = nn.Linear(in_features, 10)
        
        def forward(self, x):
            features = [self.initial(x)]
            
            for block in self.blocks:
                concat = torch.cat(features, dim=1)
                new_features = torch.relu(block(concat))
                features.append(new_features)
            
            concat = torch.cat(features, dim=1)
            return self.final(concat)
    
    model = DenseNetwork(input_dim=784, num_blocks=20, growth_rate=32)
    
    x = torch.randn(32, 784)
    y = torch.randint(0, 10, (32,))
    
    output = model(x)
    loss = nn.CrossEntropyLoss()(output, y)
    loss.backward()
    
    # Analyze gradients
    print("Dense Network Gradient Analysis")
    print("="*50)
    
    grad_norms = []
    for i, block in enumerate(model.blocks):
        grad_norms.append(block.weight.grad.norm().item())
        print(f"Block {i}: gradient norm = {grad_norms[-1]:.6f}")
    
    print(f"\nGradient variation: std = {np.std(grad_norms):.6f}")
    print(f"Ratio (first/last): {grad_norms[0]/grad_norms[-1]:.4f}")

dense_gradient_flow()

Advanced Analysis Techniques

Gradient Covariance Analysis

def gradient_covariance_analysis(model, data_loader, num_batches=10):
    """Analyze gradient covariance structure."""
    
    print("Gradient Covariance Analysis")
    print("="*50)
    
    # Collect gradients over multiple batches
    all_gradients = defaultdict(list)
    
    for batch_idx, (x, y) in enumerate(data_loader):
        if batch_idx >= num_batches:
            break
        
        model.zero_grad()
        output = model(x)
        loss = nn.CrossEntropyLoss()(output, y)
        loss.backward()
        
        for name, param in model.named_parameters():
            if param.grad is not None:
                all_gradients[name].append(param.grad.flatten().detach().cpu().numpy())
    
    # Analyze each layer
    for name, grads in all_gradients.items():
        grads = np.array(grads)  # [num_batches, num_params]
        
        # Mean gradient across batches (per parameter)
        mean_grad = grads.mean(axis=0)
        
        # Gradient variance across batches, averaged over parameters
        variance = np.var(grads, axis=0).mean()
        
        # Gradient correlation (sample a subset for large layers)
        if grads.shape[1] > 1000:
            idx = np.random.choice(grads.shape[1], 1000, replace=False)
            grads_sample = grads[:, idx]
        else:
            grads_sample = grads
        
        corr = np.corrcoef(grads_sample.T)
        avg_correlation = (corr.sum() - np.trace(corr)) / (corr.size - corr.shape[0])
        
        print(f"\n{name}:")
        print(f"  Mean gradient magnitude: {np.abs(mean_grad).mean():.6f}")
        print(f"  Gradient variance: {variance:.6f}")
        print(f"  Avg correlation between params: {avg_correlation:.4f}")


# Create a simple example
def run_covariance_analysis():
    model = nn.Sequential(
        nn.Linear(100, 50),
        nn.ReLU(),
        nn.Linear(50, 10)
    )
    
    # Simple dataset
    dataset = torch.utils.data.TensorDataset(
        torch.randn(1000, 100),
        torch.randint(0, 10, (1000,))
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    
    gradient_covariance_analysis(model, loader)

run_covariance_analysis()

Hessian Analysis

def hessian_analysis():
    """Analyze Hessian eigenspectrum for understanding loss landscape."""
    
    # Simple model for tractable Hessian computation
    model = nn.Linear(10, 5)
    
    # Sample data
    x = torch.randn(50, 10)
    y = torch.randint(0, 5, (50,))
    
    # Compute Hessian using autograd
    def compute_hessian_eigenvalues(model, x, y):
        """Compute all eigenvalues of the loss Hessian w.r.t. the parameters."""
        from torch.autograd.functional import hessian
        
        # Record parameter names and shapes so we can rebuild them from a flat vector
        names = [n for n, _ in model.named_parameters()]
        shapes = [p.shape for p in model.parameters()]
        numels = [p.numel() for p in model.parameters()]
        flat_params = torch.cat([p.detach().flatten() for p in model.parameters()])
        
        print(f"Computing Hessian for {flat_params.numel()} parameters...")
        
        def loss_fn(flat):
            # Rebuild a parameter dict and call the model functionally so the loss
            # stays differentiable w.r.t. `flat` (assigning to p.data would break
            # the autograd graph and give a zero Hessian).
            params, idx = {}, 0
            for name, shape, numel in zip(names, shapes, numels):
                params[name] = flat[idx:idx + numel].view(shape)
                idx += numel
            output = torch.func.functional_call(model, params, (x,))
            return nn.CrossEntropyLoss()(output, y)
        
        # Full Hessian (only feasible for small models)
        H = hessian(loss_fn, flat_params)
        
        # Eigenvalues of the symmetric Hessian
        eigenvalues = torch.linalg.eigvalsh(H)
        
        return eigenvalues
    
    eigenvalues = compute_hessian_eigenvalues(model, x, y)
    
    print("\nHessian Eigenvalue Analysis")
    print("="*50)
    print(f"Max eigenvalue: {eigenvalues.max().item():.4f}")
    print(f"Min eigenvalue: {eigenvalues.min().item():.4f}")
    print(f"Condition number: {eigenvalues.max().item() / (eigenvalues.min().item() + 1e-8):.2f}")
    
    # Plot eigenvalue distribution
    plt.figure(figsize=(10, 4))
    plt.hist(eigenvalues.numpy(), bins=50, edgecolor='black', alpha=0.7)
    plt.xlabel('Eigenvalue')
    plt.ylabel('Count')
    plt.title('Hessian Eigenvalue Distribution')
    plt.axvline(x=0, color='r', linestyle='--', alpha=0.5)
    plt.show()
    
    # Negative eigenvalues indicate saddle points
    n_negative = (eigenvalues < 0).sum().item()
    print(f"\nNegative eigenvalues: {n_negative} ({100*n_negative/len(eigenvalues):.1f}%)")
    if n_negative > 0:
        print("→ You might be at a saddle point!")

hessian_analysis()
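
The full Hessian above is only tractable for toy models. For larger networks, a common workaround is to estimate the top eigenvalue with Hessian-vector products and power iteration; the sketch below is one such approach (the closure-based interface and iteration count are illustrative choices, not part of the analysis above):

def top_hessian_eigenvalue(model, loss_closure, num_iters=20):
    """Estimate the largest Hessian eigenvalue via power iteration on
    Hessian-vector products (assumes a twice-differentiable loss)."""
    params = [p for p in model.parameters() if p.requires_grad]
    # Random starting direction, normalized
    v = [torch.randn_like(p) for p in params]
    norm = torch.sqrt(sum((x ** 2).sum() for x in v))
    v = [x / norm for x in v]
    
    eigenvalue = 0.0
    for _ in range(num_iters):
        loss = loss_closure()
        grads = torch.autograd.grad(loss, params, create_graph=True)
        # Hessian-vector product: differentiate g·v with respect to the parameters
        gv = sum((g * x).sum() for g, x in zip(grads, v))
        hv = torch.autograd.grad(gv, params)
        # Rayleigh quotient estimate, then re-normalize for the next iteration
        eigenvalue = sum((h * x).sum() for h, x in zip(hv, v)).item()
        norm = torch.sqrt(sum((h ** 2).sum() for h in hv))
        v = [h / (norm + 1e-12) for h in hv]
    return eigenvalue

# Usage sketch with a small model (names are illustrative):
# lam_max = top_hessian_eigenvalue(model, lambda: nn.CrossEntropyLoss()(model(x), y))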

Fixing Gradient Flow Issues

Comprehensive Diagnostic and Fix Toolkit

class GradientDoctor:
    """Diagnose and fix gradient flow issues."""
    
    def __init__(self, model):
        self.model = model
        self.issues = []
    
    def diagnose(self, sample_input, sample_target):
        """Run comprehensive gradient diagnostics."""
        
        print("╔══════════════════════════════════════════════════════════╗")
        print("║              GRADIENT FLOW DIAGNOSIS                     ║")
        print("╚══════════════════════════════════════════════════════════╝")
        
        self.issues = []
        
        # Forward pass
        output = self.model(sample_input)
        
        # Check for NaN in output
        if torch.isnan(output).any():
            self.issues.append(("CRITICAL", "NaN in forward pass output"))
            print("⚠ CRITICAL: NaN detected in output!")
            return self.issues
        
        # Backward pass
        if output.dim() == 2 and output.size(1) > 1:
            loss = nn.CrossEntropyLoss()(output, sample_target)
        else:
            loss = nn.MSELoss()(output.flatten(), sample_target.float().flatten())
        
        loss.backward()
        
        # Check each layer
        print("\nLayer-by-layer analysis:")
        print("-" * 60)
        
        layer_idx = 0
        for name, param in self.model.named_parameters():
            if param.grad is None:
                self.issues.append(("WARNING", f"{name}: No gradient computed"))
                print(f"⚠ {name}: No gradient")
                continue
            
            grad = param.grad
            grad_norm = grad.norm().item()
            grad_mean = grad.mean().item()
            grad_std = grad.std().item()
            
            # Check for issues
            status = "✓"
            
            if torch.isnan(grad).any():
                self.issues.append(("CRITICAL", f"{name}: NaN gradient"))
                status = "⚠ NaN"
            elif grad_norm < 1e-7:
                self.issues.append(("WARNING", f"{name}: Vanishing gradient (norm={grad_norm:.2e})"))
                status = "⚠ Vanishing"
            elif grad_norm > 1e4:
                self.issues.append(("WARNING", f"{name}: Exploding gradient (norm={grad_norm:.2e})"))
                status = "⚠ Exploding"
            
            print(f"{name:<40} norm={grad_norm:<10.2e} std={grad_std:<10.2e} {status}")
            layer_idx += 1
        
        print("\n" + "="*60)
        if self.issues:
            print(f"Found {len(self.issues)} issues:")
            for severity, msg in self.issues:
                print(f"  [{severity}] {msg}")
        else:
            print("✓ Gradient flow looks healthy!")
        
        return self.issues
    
    def suggest_fixes(self):
        """Suggest fixes based on diagnosed issues."""
        
        print("\n╔══════════════════════════════════════════════════════════╗")
        print("║              SUGGESTED FIXES                             ║")
        print("╚══════════════════════════════════════════════════════════╝")
        
        if not self.issues:
            print("No issues to fix!")
            return
        
        fixes = []
        
        for severity, msg in self.issues:
            if "Vanishing" in msg:
                fixes.extend([
                    "• Use He/Kaiming initialization for ReLU layers",
                    "• Add residual connections (skip connections)",
                    "• Replace sigmoid/tanh with ReLU/GELU",
                    "• Add BatchNorm or LayerNorm",
                    "• Use LSTM/GRU instead of vanilla RNN"
                ])
            
            elif "Exploding" in msg:
                fixes.extend([
                    "• Apply gradient clipping: torch.nn.utils.clip_grad_norm_(..., max_norm=1.0)",
                    "• Reduce learning rate",
                    "• Initialize weights with smaller variance",
                    "• Add weight decay regularization"
                ])
            
            elif "NaN" in msg:
                fixes.extend([
                    "• Check for numerical instability (log of 0, division by 0)",
                    "• Reduce learning rate significantly",
                    "• Use gradient clipping",
                    "• Check for correct loss function usage",
                    "• Verify data preprocessing (no NaN in inputs)"
                ])
            
            elif "No gradient" in msg:
                fixes.extend([
                    "• Ensure requires_grad=True for trainable parameters",
                    "• Check if the layer is actually used in forward pass",
                    "• Verify no torch.no_grad() context is active"
                ])
        
        # Remove duplicates
        fixes = list(dict.fromkeys(fixes))
        
        for fix in fixes:
            print(fix)
    
    def apply_quick_fixes(self):
        """Apply automatic quick fixes."""
        
        print("\nApplying quick fixes...")
        
        for name, module in self.model.named_modules():
            # Re-initialize layers that might have bad weights
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        
        print("✓ Re-initialized all Linear and Conv2d layers with He initialization")


# Example usage
def gradient_doctor_demo():
    # Create a problematic model
    model = nn.Sequential(
        nn.Linear(100, 256),
        nn.Sigmoid(),  # Problematic for deep networks
        nn.Linear(256, 256),
        nn.Sigmoid(),
        nn.Linear(256, 256),
        nn.Sigmoid(),
        nn.Linear(256, 10)
    )
    
    # Small weight initialization (will cause vanishing)
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.01)
    
    doctor = GradientDoctor(model)
    
    x = torch.randn(32, 100)
    y = torch.randint(0, 10, (32,))
    
    doctor.diagnose(x, y)
    doctor.suggest_fixes()

gradient_doctor_demo()

Exercises

Exercise 1: Compare gradient flow through different activation functions (the helper functions in the snippet below are yours to implement):
activations = [nn.Sigmoid(), nn.Tanh(), nn.ReLU(), nn.GELU(), nn.SiLU()]

for act in activations:
    model = build_deep_network(depth=30, activation=act)
    grad_norms = measure_gradient_flow(model)
    plot_gradient_profile(grad_norms, label=act.__class__.__name__)
Exercise 2: Estimate the gradient noise scale. The gradient noise scale (roughly, the variance of mini-batch gradients relative to the squared norm of the mean gradient) tells you which regime you are in (a possible starting point is sketched after the stub below):
  • Small-batch regime: high gradient noise, needs a smaller learning rate
  • Large-batch regime: low gradient noise, can use a larger learning rate
def compute_gradient_noise_scale(model, data_loader):
    # Compute gradient with full batch
    # Compute gradients with mini-batches
    # Compare noise levels
    pass
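
A possible starting point for this exercise, using one simple variant of the noise scale (average squared deviation of mini-batch gradients from their mean, divided by the squared norm of that mean):

def estimate_gradient_noise_scale(model, data_loader, criterion, num_batches=20):
    """Rough gradient-noise estimate: mini-batch gradient variance relative
    to the squared norm of the mean gradient (a sketch, not the only definition)."""
    per_batch = []
    for batch_idx, (x, y) in enumerate(data_loader):
        if batch_idx >= num_batches:
            break
        model.zero_grad()
        criterion(model(x), y).backward()
        g = torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None])
        per_batch.append(g.clone())
    
    grads = torch.stack(per_batch)                          # [num_batches, num_params]
    mean_grad = grads.mean(dim=0)
    noise = ((grads - mean_grad) ** 2).sum(dim=1).mean()    # avg squared deviation per batch
    signal = (mean_grad ** 2).sum()                         # squared mean-gradient norm
    return (noise / (signal + 1e-12)).item()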
Exercise 3: Create a real-time gradient monitoring dashboard using matplotlib:
class GradientDashboard:
    def __init__(self, model):
        self.fig, self.axes = plt.subplots(2, 2)
        plt.ion()  # Interactive mode
    
    def update(self, step):
        # Update gradient histograms
        # Update gradient norm curves
        # Update layer-wise statistics
        plt.draw()
        plt.pause(0.01)

What’s Next?