
Weight Initialization: The Foundation of Training

Why Initialization Matters

Poor weight initialization can cause:
  • Vanishing gradients: Signals shrink to zero, learning stops
  • Exploding gradients: Signals blow up, NaN everywhere
  • Symmetry: All neurons in a layer compute the same function and receive identical updates, so they never learn different features (see the short demo after the imports below)
  • Slow convergence: Training takes forever
Good initialization ensures:
  • Stable activations: Signals don’t vanish or explode
  • Broken symmetry: Each neuron learns something different
  • Fast convergence: Networks train efficiently
Reality Check: A neural network with poor initialization might never converge, while the same architecture with proper initialization trains smoothly. Initialization is that important!
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Reproducibility
np.random.seed(42)
torch.manual_seed(42)
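
To make the symmetry problem concrete, here is a minimal sketch (a tiny two-layer MLP with illustrative names): with constant weights, every hidden neuron computes the same output and receives the same gradient, so the neurons can never differentiate.

def demonstrate_symmetry_problem():
    """With constant initialization, all hidden neurons receive identical gradients."""
    layer1 = nn.Linear(4, 3)
    layer2 = nn.Linear(3, 1)
    nn.init.constant_(layer1.weight, 0.5)
    nn.init.zeros_(layer1.bias)
    nn.init.constant_(layer2.weight, 0.5)
    nn.init.zeros_(layer2.bias)
    
    x = torch.randn(8, 4)
    loss = layer2(torch.tanh(layer1(x))).sum()
    loss.backward()
    
    # Every row (one per hidden neuron) of the gradient is identical,
    # so gradient descent updates all three neurons in lockstep
    print("layer1 weight gradients (one row per hidden neuron):")
    print(layer1.weight.grad)

demonstrate_symmetry_problem()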

The Problem: Signal Propagation

Observing Vanishing/Exploding Activations

def demonstrate_initialization_problem():
    """Show how bad initialization kills gradients."""
    
    # Deep network with different initializations
    def create_network(init_std, depth=50, width=256):
        layers = []
        for i in range(depth):
            layer = nn.Linear(width, width, bias=False)
            # Initialize with given std
            nn.init.normal_(layer.weight, std=init_std)
            layers.append(layer)
            layers.append(nn.Tanh())
        return nn.Sequential(*layers)
    
    # Test different initialization scales
    init_stds = [0.01, 0.1, 1.0, 2.0]
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()
    
    for ax, std in zip(axes, init_stds):
        model = create_network(std)
        
        # Forward pass with random input
        x = torch.randn(100, 256)
        
        # Track activations at each layer
        activation_means = []
        activation_stds = []
        
        current = x
        with torch.no_grad():
            for layer in model:
                current = layer(current)
                if isinstance(layer, nn.Linear):
                    activation_means.append(current.mean().item())
                    activation_stds.append(current.std().item())
        
        # Plot
        layers_idx = range(len(activation_stds))
        ax.plot(layers_idx, activation_stds, 'b-', label='Std')
        ax.plot(layers_idx, activation_means, 'r--', label='Mean')
        ax.set_xlabel('Layer')
        ax.set_ylabel('Activation Statistics')
        ax.set_title(f'Init std = {std}')
        ax.legend()
        ax.set_yscale('symlog')  # Symmetric log scale
        ax.grid(True, alpha=0.3)
        
        # Check if vanishing or exploding
        final_std = activation_stds[-1] if activation_stds else 0
        if final_std < 1e-5:
            ax.text(0.5, 0.5, 'VANISHING!', transform=ax.transAxes, 
                   fontsize=20, color='red', alpha=0.5, ha='center')
        elif final_std > 1e5:
            ax.text(0.5, 0.5, 'EXPLODING!', transform=ax.transAxes,
                   fontsize=20, color='red', alpha=0.5, ha='center')
    
    plt.tight_layout()
    plt.suptitle('Effect of Initialization Scale on Deep Networks', y=1.02)
    plt.show()

demonstrate_initialization_problem()

Mathematical Analysis

For a layer $\mathbf{y} = \mathbf{W}\mathbf{x}$, where $\mathbf{W}$ has shape $(n_{out}, n_{in})$ and the weights and inputs are independent with zero mean, each output $y_j$ is a sum of $n_{in}$ terms $W_{ji} x_i$, so

$$\text{Var}(y_j) = n_{in} \cdot \text{Var}(W) \cdot \text{Var}(x)$$

To maintain stable variance across layers, we need:

$$\text{Var}(W) = \frac{1}{n_{in}}$$
def variance_propagation_analysis():
    """Analyze how variance propagates through layers."""
    
    n_in = 256
    n_out = 256
    n_samples = 10000
    
    # Different initialization strategies
    strategies = {
        'Small (0.01)': 0.01,
        'Medium (0.1)': 0.1,
        'Correct (1/√n)': 1.0 / np.sqrt(n_in),
        'Large (1.0)': 1.0
    }
    
    print("Variance Propagation Analysis")
    print("="*60)
    print(f"Input dimension: {n_in}")
    print(f"Expected variance to maintain: Var(W) = 1/{n_in} = {1/n_in:.6f}")
    print()
    
    x = np.random.randn(n_samples, n_in)  # Input with unit variance
    
    for name, std in strategies.items():
        W = np.random.randn(n_in, n_out) * std
        y = x @ W
        
        var_W = np.var(W)
        var_y = np.var(y)
        expected_var_y = n_in * var_W * np.var(x)
        
        print(f"{name}:")
        print(f"  Var(W) = {var_W:.6f}")
        print(f"  Var(y) = {var_y:.4f} (expected: {expected_var_y:.4f})")
        
        # With unit-variance input, var_y is the per-layer variance gain;
        # compounding it over 50 layers shows whether signals vanish or explode
        var_after_50 = var_y ** 50
        print(f"  After 50 layers: {var_after_50:.2e}")
        
        if var_after_50 < 1e-10:
            print(f"  → VANISHING")
        elif var_after_50 > 1e10:
            print(f"  → EXPLODING")
        else:
            print(f"  → STABLE ✓")
        print()

variance_propagation_analysis()

Classic Initialization Methods

1. Xavier/Glorot Initialization

Designed for tanh and sigmoid activations:

$$W \sim \mathcal{N}\left(0, \frac{2}{n_{in} + n_{out}}\right) \quad \text{or} \quad W \sim \mathcal{U}\left(-\sqrt{\frac{6}{n_{in} + n_{out}}}, \sqrt{\frac{6}{n_{in} + n_{out}}}\right)$$
def xavier_initialization():
    """Xavier/Glorot initialization for tanh/sigmoid."""
    
    n_in, n_out = 512, 256
    
    # Normal variant
    std = np.sqrt(2.0 / (n_in + n_out))
    W_normal = np.random.randn(n_in, n_out) * std
    
    # Uniform variant
    limit = np.sqrt(6.0 / (n_in + n_out))
    W_uniform = np.random.uniform(-limit, limit, (n_in, n_out))
    
    print("Xavier/Glorot Initialization")
    print("="*50)
    print(f"n_in={n_in}, n_out={n_out}")
    print(f"\nNormal variant: std = √(2/(n_in+n_out)) = {std:.6f}")
    print(f"  Actual Var(W) = {np.var(W_normal):.6f}")
    
    print(f"\nUniform variant: limit = √(6/(n_in+n_out)) = {limit:.6f}")
    print(f"  Actual Var(W) = {np.var(W_uniform):.6f}")
    
    # PyTorch implementation
    linear = nn.Linear(n_in, n_out)
    nn.init.xavier_normal_(linear.weight)
    print(f"\nPyTorch xavier_normal_: Var = {linear.weight.var().item():.6f}")
    
    nn.init.xavier_uniform_(linear.weight)
    print(f"PyTorch xavier_uniform_: Var = {linear.weight.var().item():.6f}")

xavier_initialization()
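
Xavier initialization is often combined with an activation-specific gain. A minimal sketch (assuming a tanh layer; PyTorch's nn.init.calculate_gain supplies the factor):

# Xavier with an explicit gain for tanh (PyTorch uses gain = 5/3 for tanh)
linear_tanh = nn.Linear(512, 256)
gain = nn.init.calculate_gain('tanh')
nn.init.xavier_normal_(linear_tanh.weight, gain=gain)
print(f"tanh gain = {gain:.4f}, Var(W) = {linear_tanh.weight.var().item():.6f}")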

2. He/Kaiming Initialization

Designed for ReLU and its variants:

$$W \sim \mathcal{N}\left(0, \frac{2}{n_{in}}\right)$$

The factor of 2 compensates for ReLU zeroing out half the activations.
def he_initialization():
    """He/Kaiming initialization for ReLU."""
    
    n_in, n_out = 512, 256
    
    # For ReLU
    std = np.sqrt(2.0 / n_in)
    W_relu = np.random.randn(n_in, n_out) * std
    
    # For Leaky ReLU (negative_slope = 0.01)
    negative_slope = 0.01
    std_leaky = np.sqrt(2.0 / (1 + negative_slope**2) / n_in)
    W_leaky = np.random.randn(n_in, n_out) * std_leaky
    
    print("He/Kaiming Initialization")
    print("="*50)
    print(f"n_in={n_in}")
    
    print(f"\nFor ReLU: std = √(2/n_in) = {std:.6f}")
    print(f"  Var(W) = {np.var(W_relu):.6f}")
    
    print(f"\nFor Leaky ReLU (slope={negative_slope}):")
    print(f"  std = √(2/(1+slope²)/n_in) = {std_leaky:.6f}")
    print(f"  Var(W) = {np.var(W_leaky):.6f}")
    
    # PyTorch implementation
    linear = nn.Linear(n_in, n_out)
    
    nn.init.kaiming_normal_(linear.weight, mode='fan_in', nonlinearity='relu')
    print(f"\nPyTorch kaiming_normal_ (fan_in, relu): Var = {linear.weight.var().item():.6f}")
    
    nn.init.kaiming_normal_(linear.weight, mode='fan_out', nonlinearity='relu')
    print(f"PyTorch kaiming_normal_ (fan_out, relu): Var = {linear.weight.var().item():.6f}")

he_initialization()
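
The factor of 2 can be checked empirically: ReLU keeps only the positive half of a zero-mean signal, so its second moment is roughly half the input variance. A quick sketch (variable names are illustrative):

# For z ~ N(0, 1): E[z²] ≈ 1 but E[ReLU(z)²] ≈ 0.5, which He init's factor of 2 undoes
z = torch.randn(1_000_000)
print(f"E[z²]       = {(z ** 2).mean().item():.4f}")
print(f"E[ReLU(z)²] = {(torch.relu(z) ** 2).mean().item():.4f}")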

3. Orthogonal Initialization

Initializes weights as orthogonal matrices, which preserve norms exactly:

$$\mathbf{W}^T\mathbf{W} = \mathbf{I}$$
def orthogonal_initialization():
    """Orthogonal initialization for stable signal propagation."""
    
    n = 256
    
    # Create orthogonal matrix
    W = nn.Linear(n, n)
    nn.init.orthogonal_(W.weight)
    
    # Verify orthogonality
    WtW = W.weight.T @ W.weight
    print("Orthogonal Initialization")
    print("="*50)
    print(f"\nW^T @ W should be identity:")
    print(f"  Diagonal mean: {torch.diag(WtW).mean().item():.6f} (should be 1)")
    print(f"  Off-diagonal std: {(WtW - torch.eye(n)).std().item():.6f} (should be 0)")
    
    # Signal preservation
    x = torch.randn(100, n)
    y = W(x)
    
    print(f"\nSignal preservation:")
    print(f"  Input norm mean: {torch.norm(x, dim=1).mean().item():.4f}")
    print(f"  Output norm mean: {torch.norm(y, dim=1).mean().item():.4f}")
    
    # Through many layers
    print("\nThrough 50 orthogonal layers:")
    current = x
    for _ in range(50):
        layer = nn.Linear(n, n, bias=False)
        nn.init.orthogonal_(layer.weight)
        current = layer(current)
    
    print(f"  Final norm mean: {torch.norm(current, dim=1).mean().item():.4f}")
    print("  (Should be similar to input norm)")

orthogonal_initialization()

Advanced Initialization Techniques

4. LSUV (Layer-Sequential Unit-Variance)

A data-driven approach that iteratively normalizes each layer:
def lsuv_initialization(model, data_batch, tol=0.1, max_iter=10):
    """
    Layer-Sequential Unit-Variance initialization.
    
    Iteratively adjusts weights so each layer has unit variance output.
    """
    print("LSUV Initialization")
    print("="*50)
    
    model.eval()
    
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            print(f"\nProcessing layer: {name}")
            
            # Orthogonal init first
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight)
            else:
                nn.init.orthogonal_(module.weight.view(module.weight.size(0), -1))
            
            for iteration in range(max_iter):
                with torch.no_grad():
                    # Forward pass up to this layer
                    output = data_batch
                    for n, m in model.named_modules():
                        if isinstance(m, (nn.Linear, nn.Conv2d, nn.ReLU, nn.BatchNorm2d)):
                            output = m(output)
                            if n == name:
                                break
                    
                    variance = output.var().item()
                    
                    if abs(variance - 1.0) < tol:
                        print(f"  Iteration {iteration}: Var = {variance:.4f} ✓")
                        break
                    
                    # Rescale weights
                    module.weight.data /= np.sqrt(variance)
                    print(f"  Iteration {iteration}: Var = {variance:.4f} → rescaling")
    
    return model


# Example usage
class SimpleMLP(nn.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        layers = []
        for i in range(len(layer_sizes) - 1):
            layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1]))
            if i < len(layer_sizes) - 2:
                layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.layers(x)


# Apply LSUV
model = SimpleMLP([784, 256, 256, 256, 10])
dummy_batch = torch.randn(32, 784)
model = lsuv_initialization(model, dummy_batch)
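
A quick sanity check (a sketch reusing the SimpleMLP and dummy_batch above): after LSUV, each Linear layer's output variance on the calibration batch should be close to 1.

# Verify LSUV: each Linear output's variance on the calibration batch should be ≈ 1
with torch.no_grad():
    out = dummy_batch
    for layer in model.layers:
        out = layer(out)
        if isinstance(layer, nn.Linear):
            print(f"{layer}: output variance = {out.var().item():.4f}")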

5. Data-Dependent Initialization

def data_dependent_init(layer, data_batch, target_std=1.0):
    """
    Initialize weights based on actual data statistics.
    """
    with torch.no_grad():
        # Compute output with current weights
        output = layer(data_batch)
        current_std = output.std().item()
        
        # Scale weights to achieve target std
        scale = target_std / (current_std + 1e-8)
        layer.weight.data *= scale
        
        print(f"Data-dependent init:")
        print(f"  Initial output std: {current_std:.4f}")
        
        output = layer(data_batch)
        print(f"  Final output std: {output.std().item():.4f}")
    
    return layer


# Example
layer = nn.Linear(256, 128)
data = torch.randn(100, 256) * 5  # Data with different scale
layer = data_dependent_init(layer, data)

6. Fixup Initialization

Enables training very deep networks without normalization:
def fixup_initialization(model, num_layers):
    """
    Fixup initialization for residual networks without BatchNorm.
    
    Key ideas:
    - Scale down the last layer of each residual block
    - Zero-initialize certain layers
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            # Standard initialization
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            
            # Scale down if it's the last conv in a residual block
            if 'last_conv' in name or 'conv2' in name:
                module.weight.data.mul_(num_layers ** (-0.5))
        
        elif isinstance(module, nn.Linear):
            nn.init.constant_(module.weight, 0)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
    
    print(f"Fixup initialization applied for {num_layers} layers")
    print("  - Residual branch weights scaled by L^(-0.5)")
    print("  - Final layer initialized to zero")

Initialization for Specific Architectures

Transformer Initialization

def transformer_initialization():
    """Special initialization for Transformer models."""
    
    d_model = 512
    n_layers = 12
    n_heads = 8
    
    class TransformerBlock(nn.Module):
        def __init__(self, d_model, n_heads):
            super().__init__()
            self.attention = nn.MultiheadAttention(d_model, n_heads)
            self.ffn = nn.Sequential(
                nn.Linear(d_model, 4 * d_model),
                nn.GELU(),
                nn.Linear(4 * d_model, d_model)
            )
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
    
    def init_transformer_weights(model, n_layers):
        """
        GPT-2 style initialization, applied to every submodule.
        """
        for module in model.modules():
            if isinstance(module, nn.Linear):
                # Standard normal initialization
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            
            elif isinstance(module, nn.LayerNorm):
                torch.nn.init.zeros_(module.bias)
                torch.nn.init.ones_(module.weight)
        
        # Scale down residual projections
        # This prevents the output from growing with depth
        for name, p in model.named_parameters():
            if name.endswith('out_proj.weight') or name.endswith('ffn.2.weight'):
                # Scale by 1/√(2*n_layers)
                p.data.div_(np.sqrt(2 * n_layers))
    
    # Apply initialization
    block = TransformerBlock(d_model, n_heads)
    init_transformer_weights(block, n_layers)
    
    print("Transformer Initialization (GPT-2 style)")
    print("="*50)
    print(f"  - Linear weights: N(0, 0.02)")
    print(f"  - Residual outputs scaled by 1/√(2×{n_layers})")
    print(f"  - LayerNorm: weight=1, bias=0")

transformer_initialization()

CNN Initialization

def cnn_initialization():
    """Initialization strategies for CNNs."""
    
    class ConvBlock(nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
            self.bn = nn.BatchNorm2d(out_channels)
            self.relu = nn.ReLU()
    
    # Different initialization strategies
    conv = nn.Conv2d(64, 128, 3, padding=1)
    
    print("CNN Initialization Strategies")
    print("="*50)
    
    # 1. Kaiming for ReLU
    nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
    print(f"\n1. Kaiming (fan_out, ReLU):")
    print(f"   Var = {conv.weight.var().item():.6f}")
    
    # 2. Xavier for no activation / before BN
    nn.init.xavier_uniform_(conv.weight)
    print(f"\n2. Xavier Uniform:")
    print(f"   Var = {conv.weight.var().item():.6f}")
    
    # 3. Delta initialization for identity
    def delta_init(conv):
        """Initialize conv to approximate identity."""
        nn.init.zeros_(conv.weight)
        # Set center of kernel to 1 for each in/out channel pair
        c_out, c_in, h, w = conv.weight.shape
        center_h, center_w = h // 2, w // 2
        for i in range(min(c_in, c_out)):
            conv.weight.data[i, i, center_h, center_w] = 1.0
    
    conv_delta = nn.Conv2d(64, 64, 3, padding=1)
    delta_init(conv_delta)
    print(f"\n3. Delta (identity) initialization:")
    print(f"   Center weights = 1, others = 0")
    
    # Test identity property
    x = torch.randn(1, 64, 8, 8)
    y = conv_delta(x)
    print(f"   Input-Output difference: {(x - y).abs().mean().item():.6f}")

cnn_initialization()

Practical Guidelines

Choosing the Right Initialization

def initialization_decision_tree():
    """Guide for choosing initialization."""
    
    print("""
    ╔════════════════════════════════════════════════════════════════╗
    ║              WEIGHT INITIALIZATION DECISION TREE               ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                ║
    ║  What activation function are you using?                       ║
    ║                                                                ║
    ║  ├── ReLU / Leaky ReLU / ELU                                  ║
    ║  │   └── Use He (Kaiming) initialization                      ║
    ║  │       • PyTorch: kaiming_normal_(weight, nonlinearity='relu')║
    ║  │                                                             ║
    ║  ├── Sigmoid / Tanh                                           ║
    ║  │   └── Use Xavier (Glorot) initialization                   ║
    ║  │       • PyTorch: xavier_uniform_(weight)                   ║
    ║  │                                                             ║
    ║  ├── GELU / SiLU / Swish                                      ║
    ║  │   └── Use He initialization (similar to ReLU)              ║
    ║  │                                                             ║
    ║  └── Linear (no activation)                                   ║
    ║      └── Use Xavier or small normal                           ║
    ║                                                                ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                ║
    ║  Special cases:                                                ║
    ║                                                                ║
    ║  ├── Transformers                                             ║
    ║  │   └── N(0, 0.02) + scale residual projections              ║
    ║  │                                                             ║
    ║  ├── Very deep networks (100+ layers)                         ║
    ║  │   └── Orthogonal or LSUV                                   ║
    ║  │                                                             ║
    ║  ├── RNNs / LSTMs                                             ║
    ║  │   └── Orthogonal for hidden-to-hidden weights              ║
    ║  │                                                             ║
    ║  ├── GANs                                                      ║
    ║  │   └── N(0, 0.02) often works well                          ║
    ║  │                                                             ║
    ║  └── Residual Networks without BatchNorm                      ║
    ║      └── Fixup initialization                                  ║
    ║                                                                ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                ║
    ║  Bias initialization:                                         ║
    ║  • Usually initialize to 0                                    ║
    ║  • For ReLU, small positive (0.01) can help                   ║
    ║  • For LSTM forget gate, initialize to 1                      ║
    ║                                                                ║
    ╚════════════════════════════════════════════════════════════════╝
    """)

initialization_decision_tree()

Complete Initialization Function

def initialize_model(model, init_type='auto'):
    """
    Comprehensive weight initialization.
    
    Args:
        model: PyTorch model
        init_type: 'auto', 'he', 'xavier', 'orthogonal'
    """
    
    def get_activation(module):
        """Detect activation function following this layer."""
        # This is a simplified heuristic
        return 'relu'  # Default assumption
    
    for name, module in model.named_modules():
        
        if isinstance(module, nn.Linear):
            if init_type == 'auto':
                # Matches PyTorch's default nn.Linear init: Kaiming uniform with a=√5
                nn.init.kaiming_uniform_(module.weight, a=np.sqrt(5))
            elif init_type == 'he':
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
            elif init_type == 'xavier':
                nn.init.xavier_uniform_(module.weight)
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(module.weight)
            
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        
        elif isinstance(module, nn.Conv2d):
            if init_type in ['auto', 'he']:
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif init_type == 'xavier':
                nn.init.xavier_uniform_(module.weight)
            
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=0.02)
        
        elif isinstance(module, nn.LSTM):
            for param_name, param in module.named_parameters():
                if 'weight_ih' in param_name:
                    nn.init.xavier_uniform_(param)
                elif 'weight_hh' in param_name:
                    nn.init.orthogonal_(param)
                elif 'bias' in param_name:
                    nn.init.zeros_(param)
                    # Set forget gate bias to 1
                    n = param.size(0)
                    param.data[n//4:n//2].fill_(1.0)
    
    print(f"Model initialized with '{init_type}' strategy")
    return model


# Example usage
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

model = initialize_model(model, init_type='he')

Diagnosing Initialization Problems

def diagnose_initialization(model, sample_input):
    """
    Diagnose if initialization is causing issues.
    """
    print("Initialization Diagnostics")
    print("="*60)
    
    model.eval()
    
    # Track statistics through layers
    activations = {}
    gradients = {}
    
    def save_activation(name):
        def hook(model, input, output):
            activations[name] = output.detach()
        return hook
    
    def save_gradient(name):
        def hook(model, grad_input, grad_output):
            if grad_output[0] is not None:
                gradients[name] = grad_output[0].detach()
        return hook
    
    # Register hooks
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            handles.append(module.register_forward_hook(save_activation(name)))
            handles.append(module.register_full_backward_hook(save_gradient(name)))
    
    # Forward pass
    x = sample_input.requires_grad_(True)
    output = model(x)
    
    # Backward pass
    loss = output.sum()
    loss.backward()
    
    # Analyze
    print("\nActivation Statistics:")
    print("-" * 60)
    print(f"{'Layer':<30} {'Mean':>10} {'Std':>10} {'%Dead':>10}")
    print("-" * 60)
    
    issues = []
    
    for name, act in activations.items():
        mean = act.mean().item()
        std = act.std().item()
        dead_fraction = (act == 0).float().mean().item() * 100
        
        print(f"{name:<30} {mean:>10.4f} {std:>10.4f} {dead_fraction:>9.1f}%")
        
        if std < 0.01:
            issues.append(f"  ⚠ {name}: Very small std (vanishing activations)")
        if std > 10:
            issues.append(f"  ⚠ {name}: Very large std (exploding activations)")
        if dead_fraction > 50:
            issues.append(f"  ⚠ {name}: >50% dead neurons")
    
    print("\nGradient Statistics:")
    print("-" * 60)
    print(f"{'Layer':<30} {'Mean':>10} {'Std':>10}")
    print("-" * 60)
    
    for name, grad in gradients.items():
        mean = grad.mean().item()
        std = grad.std().item()
        print(f"{name:<30} {mean:>10.6f} {std:>10.6f}")
        
        if std < 1e-6:
            issues.append(f"  ⚠ {name}: Very small gradient std (vanishing)")
        if std > 100:
            issues.append(f"  ⚠ {name}: Very large gradient std (exploding)")
    
    # Clean up hooks
    for handle in handles:
        handle.remove()
    
    print("\n" + "="*60)
    if issues:
        print("ISSUES DETECTED:")
        for issue in issues:
            print(issue)
    else:
        print("✓ Initialization looks healthy!")
    
    return activations, gradients


# Example usage
model = nn.Sequential(
    nn.Linear(784, 512),
    nn.ReLU(),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)

sample = torch.randn(32, 784)
diagnose_initialization(model, sample)

Exercises

Exercise 1: Train the same network with different initialization methods and compare:
  • Training curves
  • Final accuracy
  • Time to converge
def exercise_1():
    init_methods = ['he', 'xavier', 'orthogonal', 'small_normal']
    
    for method in init_methods:
        model = create_model()
        initialize_with(model, method)
        history = train(model, epochs=20)
        plot(history, label=method)
Exercise 2: Implement Layer-Sequential Unit-Variance initialization:
def exercise_2():
    def lsuv(model, data):
        for layer in model.layers:
            # Initialize orthogonally
            orthogonal_init(layer)
            
            # Iteratively scale to unit variance
            for _ in range(max_iter):
                output = forward_to_layer(model, data, layer)
                variance = output.var()
                
                if abs(variance - 1.0) < tol:
                    break
                
                layer.weight /= sqrt(variance)
Exercise 3: For a 100-layer network, visualize how gradients flow backward:
def exercise_3():
    model = DeepNetwork(num_layers=100)
    
    for init_method in ['he', 'xavier', 'orthogonal']:
        initialize(model, init_method)
        
        # Forward and backward pass
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        
        # Plot gradient magnitudes per layer
        gradient_norms = [layer.weight.grad.norm() for layer in model.layers]
        plot(gradient_norms, label=init_method)

What’s Next?