Weight Initialization: The Foundation of Training

Why Initialization Matters

Think of weight initialization like tuning a guitar before a concert. If the strings are too loose (weights too small), you get silence — the signal vanishes as it passes through layers. If they are too tight (weights too large), you get screeching feedback — the signal explodes into NaN. Only when each string is tuned to the right tension does the instrument produce music. Poor weight initialization can cause:
  • Vanishing gradients: Signals shrink to zero as they pass through layers, so early layers never learn
  • Exploding gradients: Signals blow up exponentially, producing NaN values that crash training
  • Symmetry: If all weights start identical, all neurons compute the same thing forever — you have a deep network with the effective capacity of a single neuron
  • Slow convergence: Even when training works, bad initialization can make the network converge much more slowly
Good initialization ensures:
  • Stable activations: Signal magnitude stays roughly constant across layers
  • Broken symmetry: Each neuron starts slightly different, so they specialize during training
  • Fast convergence: The network is already “in the right neighborhood” to start learning
Reality Check: A neural network with poor initialization might never converge, while the same architecture with proper initialization trains smoothly. Initialization is that important!
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from scipy import stats

# Reproducibility
np.random.seed(42)
torch.manual_seed(42)

The Problem: Signal Propagation

Observing Vanishing/Exploding Activations

def demonstrate_initialization_problem():
    """Show how bad initialization kills gradients."""
    
    # Deep network with different initializations
    def create_network(init_std, depth=50, width=256):
        layers = []
        for i in range(depth):
            layer = nn.Linear(width, width, bias=False)
            # Initialize with given std
            nn.init.normal_(layer.weight, std=init_std)
            layers.append(layer)
            layers.append(nn.Tanh())
        return nn.Sequential(*layers)
    
    # Test different initialization scales
    init_stds = [0.01, 0.1, 1.0, 2.0]
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()
    
    for ax, std in zip(axes, init_stds):
        model = create_network(std)
        
        # Forward pass with random input
        x = torch.randn(100, 256)
        
        # Track activations at each layer
        activation_means = []
        activation_stds = []
        
        current = x
        with torch.no_grad():
            for layer in model:
                current = layer(current)
                if isinstance(layer, nn.Linear):
                    activation_means.append(current.mean().item())
                    activation_stds.append(current.std().item())
        
        # Plot
        layers_idx = range(len(activation_stds))
        ax.plot(layers_idx, activation_stds, 'b-', label='Std')
        ax.plot(layers_idx, activation_means, 'r--', label='Mean')
        ax.set_xlabel('Layer')
        ax.set_ylabel('Activation Statistics')
        ax.set_title(f'Init std = {std}')
        ax.legend()
        ax.set_yscale('symlog')  # Symmetric log scale
        ax.grid(True, alpha=0.3)
        
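        # Check if vanishing or saturating. tanh outputs are bounded in [-1, 1],
        # so large weights saturate tanh (pre-activations far beyond ±1) instead
        # of overflowing; gradients still vanish because tanh'(x) ≈ 0 there.
        final_std = activation_stds[-1] if activation_stds else 0
        if final_std < 1e-5:
            ax.text(0.5, 0.5, 'VANISHING!', transform=ax.transAxes, 
                   fontsize=20, color='red', alpha=0.5, ha='center')
        elif final_std > 5:
            ax.text(0.5, 0.5, 'SATURATING!', transform=ax.transAxes,
                   fontsize=20, color='red', alpha=0.5, ha='center')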
    
    plt.tight_layout()
    plt.suptitle('Effect of Initialization Scale on Deep Networks', y=1.02)
    plt.show()

demonstrate_initialization_problem()

Mathematical Analysis

For a layer $\mathbf{y} = \mathbf{W}\mathbf{x}$ where $\mathbf{W}$ has shape $(n_{out}, n_{in})$:

$$\text{Var}(y_j) = n_{in} \cdot \text{Var}(W) \cdot \text{Var}(x)$$

To maintain stable variance across layers, we need:

$$\text{Var}(W) = \frac{1}{n_{in}}$$
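The variance formula follows from a short calculation. The derivation below assumes the weights $W_{ji}$ are independent, zero-mean, share a common variance $\text{Var}(W)$, and are independent of the inputs, which are themselves zero-mean. Because the weights are zero-mean and independent, the terms of the sum are uncorrelated, so their variances add:

```latex
\begin{aligned}
\text{Var}(y_j) &= \text{Var}\Big(\sum_{i=1}^{n_{in}} W_{ji}\, x_i\Big)
  = \sum_{i=1}^{n_{in}} \text{Var}(W_{ji}\, x_i) \\
&= \sum_{i=1}^{n_{in}} \mathbb{E}[W_{ji}^2]\,\mathbb{E}[x_i^2]
  = n_{in}\,\text{Var}(W)\,\text{Var}(x)
\end{aligned}
```

Setting $\text{Var}(y_j) = \text{Var}(x)$ and solving gives $\text{Var}(W) = 1/n_{in}$.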
def variance_propagation_analysis():
    """Analyze how variance propagates through layers."""
    
    n_in = 256
    n_out = 256
    n_samples = 10000
    
    # Different initialization strategies
    strategies = {
        'Small (0.01)': 0.01,
        'Medium (0.1)': 0.1,
        'Correct (1/√n)': 1.0 / np.sqrt(n_in),
        'Large (1.0)': 1.0
    }
    
    print("Variance Propagation Analysis")
    print("="*60)
    print(f"Input dimension: {n_in}")
    print(f"Expected variance to maintain: Var(W) = 1/{n_in} = {1/n_in:.6f}")
    print()
    
    x = np.random.randn(n_samples, n_in)  # Input with unit variance
    
    for name, std in strategies.items():
        W = np.random.randn(n_in, n_out) * std
        y = x @ W
        
        var_W = np.var(W)
        var_y = np.var(y)
        expected_var_y = n_in * var_W * np.var(x)
        
        print(f"{name}:")
        print(f"  Var(W) = {var_W:.6f}")
        print(f"  Var(y) = {var_y:.4f} (expected: {expected_var_y:.4f})")
        
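        # After 50 identical linear layers (no activation), the per-layer
        # gain Var(y)/Var(x) compounds multiplicatively
        gain = var_y / np.var(x)
        var_after_50 = np.var(x) * gain ** 50
        print(f"  Var after 50 layers: {var_after_50:.2e}")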
        
        if var_after_50 < 1e-10:
            print(f"  → VANISHING")
        elif var_after_50 > 1e10:
            print(f"  → EXPLODING")
        else:
            print(f"  → STABLE ✓")
        print()

variance_propagation_analysis()

Classic Initialization Methods

1. Xavier/Glorot Initialization

Designed for tanh and sigmoid activations. The key insight from Glorot and Bengio (2010): to keep training stable, the variance of each layer's output should equal the variance of its input, and the same should hold for gradients flowing backward. In the forward pass this requires $\text{Var}(W) = 1/n_{in}$. In the backward pass, the gradient for each input is a sum over the $n_{out}$ units it feeds, which requires $\text{Var}(W) = 1/n_{out}$. Xavier compromises between the two by averaging fan-in and fan-out:

$$W \sim \mathcal{N}\left(0, \frac{2}{n_{in} + n_{out}}\right) \quad \text{or} \quad W \sim \mathcal{U}\left(-\sqrt{\frac{6}{n_{in} + n_{out}}}, \sqrt{\frac{6}{n_{in} + n_{out}}}\right)$$

Here the second argument of $\mathcal{N}$ is the variance.
def xavier_initialization():
    """Xavier/Glorot initialization for tanh/sigmoid."""
    
    n_in, n_out = 512, 256
    
    # Normal variant
    std = np.sqrt(2.0 / (n_in + n_out))
    W_normal = np.random.randn(n_in, n_out) * std
    
    # Uniform variant
    limit = np.sqrt(6.0 / (n_in + n_out))
    W_uniform = np.random.uniform(-limit, limit, (n_in, n_out))
    
    print("Xavier/Glorot Initialization")
    print("="*50)
    print(f"n_in={n_in}, n_out={n_out}")
    print(f"\nNormal variant: std = √(2/(n_in+n_out)) = {std:.6f}")
    print(f"  Actual Var(W) = {np.var(W_normal):.6f}")
    
    print(f"\nUniform variant: limit = √(6/(n_in+n_out)) = {limit:.6f}")
    print(f"  Actual Var(W) = {np.var(W_uniform):.6f}")
    
    # PyTorch implementation
    linear = nn.Linear(n_in, n_out)
    nn.init.xavier_normal_(linear.weight)
    print(f"\nPyTorch xavier_normal_: Var = {linear.weight.var().item():.6f}")
    
    nn.init.xavier_uniform_(linear.weight)
    print(f"PyTorch xavier_uniform_: Var = {linear.weight.var().item():.6f}")

xavier_initialization()
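The uniform limit is chosen so that both variants have the same variance. A uniform distribution on $[-a, a]$ has variance $a^2/3$, so:

```latex
a = \sqrt{\frac{6}{n_{in} + n_{out}}}
\quad\Longrightarrow\quad
\text{Var}(W) = \frac{a^2}{3} = \frac{2}{n_{in} + n_{out}}
```

This is why the two "Actual Var(W)" values printed above come out nearly equal.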

2. He/Kaiming Initialization

Designed for ReLU and its variants. Kaiming He et al. (2015) showed that Xavier initialization is poorly suited to ReLU networks because it does not account for ReLU zeroing out every negative pre-activation. For a zero-mean, symmetric pre-activation $z$, $\mathbb{E}[\text{ReLU}(z)^2] = \frac{1}{2}\text{Var}(z)$, so each layer passes on only half the signal power, and activations and gradients shrink layer by layer even with Xavier init.

$$W \sim \mathcal{N}\left(0, \frac{2}{n_{in}}\right)$$

The factor of 2 compensates for this halving. Without it, a 50-layer ReLU network with Xavier init (and equal fan-in and fan-out) would see its activation variance shrink by a factor of $0.5^{50} \approx 10^{-15}$, which is effectively zero.
def he_initialization():
    """He/Kaiming initialization for ReLU."""
    
    n_in, n_out = 512, 256
    
    # For ReLU
    std = np.sqrt(2.0 / n_in)
    W_relu = np.random.randn(n_in, n_out) * std
    
    # For Leaky ReLU (negative_slope = 0.01)
    negative_slope = 0.01
    std_leaky = np.sqrt(2.0 / (1 + negative_slope**2) / n_in)
    W_leaky = np.random.randn(n_in, n_out) * std_leaky
    
    print("He/Kaiming Initialization")
    print("="*50)
    print(f"n_in={n_in}")
    
    print(f"\nFor ReLU: std = √(2/n_in) = {std:.6f}")
    print(f"  Var(W) = {np.var(W_relu):.6f}")
    
    print(f"\nFor Leaky ReLU (slope={negative_slope}):")
    print(f"  std = √(2/(1+slope²)/n_in) = {std_leaky:.6f}")
    print(f"  Var(W) = {np.var(W_leaky):.6f}")
    
    # PyTorch implementation
    linear = nn.Linear(n_in, n_out)
    
    nn.init.kaiming_normal_(linear.weight, mode='fan_in', nonlinearity='relu')
    print(f"\nPyTorch kaiming_normal_ (fan_in, relu): Var = {linear.weight.var().item():.6f}")
    
    nn.init.kaiming_normal_(linear.weight, mode='fan_out', nonlinearity='relu')
    print(f"PyTorch kaiming_normal_ (fan_out, relu): Var = {linear.weight.var().item():.6f}")

he_initialization()
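A quick numerical check of the halving argument. The sketch below uses assumed settings (width 256, 50 bias-free layers, Gaussian inputs, a fixed seed), and `final_activation_rms` is a hypothetical helper written for this illustration. It pushes the same inputs through a ReLU stack initialized with Xavier and then with He and reports the root-mean-square of the final activations:

```python
import torch
import torch.nn as nn

def final_activation_rms(init_fn, depth=50, width=256, batch=512, seed=0):
    """Hypothetical helper: push unit-variance inputs through `depth`
    bias-free Linear + ReLU layers and return the RMS of the final output."""
    torch.manual_seed(seed)
    h = torch.randn(batch, width)
    with torch.no_grad():
        for _ in range(depth):
            layer = nn.Linear(width, width, bias=False)
            init_fn(layer.weight)  # overwrite PyTorch's default init
            h = torch.relu(layer(h))
    return h.pow(2).mean().sqrt().item()

xavier_rms = final_activation_rms(nn.init.xavier_normal_)
he_rms = final_activation_rms(
    lambda w: nn.init.kaiming_normal_(w, nonlinearity='relu'))

print(f"Xavier + ReLU after 50 layers: RMS = {xavier_rms:.2e}")  # roughly 0.5**25
print(f"He     + ReLU after 50 layers: RMS = {he_rms:.2e}")      # stays O(1)
```

The RMS is used instead of the standard deviation because ReLU outputs are not zero-mean, and the halving argument is about the second moment.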

3. Orthogonal Initialization

Initializes weights as orthogonal matrices, which preserve vector norms exactly. An orthogonal matrix is like a perfect mirror system: it can rotate and reflect vectors, but never stretches or shrinks them. Signals therefore pass through the linear part of each layer without growing or decaying (the nonlinearities can still change their scale). This makes orthogonal init especially valuable for very deep networks and for RNNs, where the same hidden-to-hidden matrix is applied at every time step.

$$\mathbf{W}^T\mathbf{W} = \mathbf{I}$$
def orthogonal_initialization():
    """Orthogonal initialization for stable signal propagation."""
    
    n = 256
    
    # Create orthogonal matrix (no bias, so the layer is a pure rotation/reflection)
    W = nn.Linear(n, n, bias=False)
    nn.init.orthogonal_(W.weight)
    
    # Verify orthogonality
    WtW = W.weight.T @ W.weight
    print("Orthogonal Initialization")
    print("="*50)
    print(f"\nW^T @ W should be identity:")
    print(f"  Diagonal mean: {torch.diag(WtW).mean().item():.6f} (should be 1)")
    print(f"  Deviation from identity (std): {(WtW - torch.eye(n)).std().item():.6f} (should be 0)")
    
    # Signal preservation
    x = torch.randn(100, n)
    y = W(x)
    
    print(f"\nSignal preservation:")
    print(f"  Input norm mean: {torch.norm(x, dim=1).mean().item():.4f}")
    print(f"  Output norm mean: {torch.norm(y, dim=1).mean().item():.4f}")
    
    # Through many layers
    print("\nThrough 50 orthogonal layers:")
    current = x
    for _ in range(50):
        layer = nn.Linear(n, n, bias=False)
        nn.init.orthogonal_(layer.weight)
        current = layer(current)
    
    print(f"  Final norm mean: {torch.norm(current, dim=1).mean().item():.4f}")
    print("  (Should be similar to input norm)")

orthogonal_initialization()
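"Never stretches or shrinks" can be stated in terms of singular values. The sketch below (matrix size and seed are arbitrary choices) compares an orthogonal matrix with a Gaussian matrix of the same per-entry variance $1/n$. Every singular value of the orthogonal matrix equals 1, while the Gaussian matrix stretches some directions (singular values near 2) and nearly collapses others:

```python
import torch

torch.manual_seed(0)
n = 256

W_orth = torch.empty(n, n)
torch.nn.init.orthogonal_(W_orth)
W_gauss = torch.randn(n, n) / n ** 0.5  # same per-entry variance (1/n), not orthogonal

# Singular values = how much the matrix stretches each input direction
sv_orth = torch.linalg.svdvals(W_orth)
sv_gauss = torch.linalg.svdvals(W_gauss)

print(f"orthogonal: min {sv_orth.min().item():.3f}, max {sv_orth.max().item():.3f}")
print(f"gaussian:   min {sv_gauss.min().item():.3f}, max {sv_gauss.max().item():.3f}")
```

Repeated multiplication amplifies this spread: after many Gaussian layers, a few directions dominate and the rest fade, whereas orthogonal layers keep every direction's scale intact.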

Advanced Initialization Techniques

4. LSUV (Layer-Sequential Unit-Variance)

A data-driven approach that normalizes one layer at a time. Whereas Xavier and He use closed-form formulas that assume particular activation distributions, LSUV runs a batch of real data through the network and rescales each layer's weights until that layer's output variance is close to 1.0 (within a tolerance). This can be more robust because it accounts for the actual data distribution and for architectural quirks that the formulas do not capture.
def lsuv_initialization(model, data_batch, tol=0.1, max_iter=10):
    """
    Layer-Sequential Unit-Variance initialization.
    
    Iteratively adjusts weights so each layer has unit variance output.
    """
    print("LSUV Initialization")
    print("="*50)
    
    model.eval()
    
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            print(f"\nProcessing layer: {name}")
            
            # Orthogonal init first (orthogonal_ flattens the trailing
            # dimensions of conv kernels automatically)
            nn.init.orthogonal_(module.weight)
            
            # Capture this layer's output with a forward hook, so the model's
            # own forward() is used to reach it (works for any architecture)
            captured = {}
            def hook(mod, inp, out):
                captured['out'] = out
            handle = module.register_forward_hook(hook)
            
            for iteration in range(max_iter):
                with torch.no_grad():
                    model(data_batch)
                    variance = captured['out'].var().item()
                    
                    if abs(variance - 1.0) < tol:
                        print(f"  Iteration {iteration}: Var = {variance:.4f} ✓")
                        break
                    
                    # Rescale weights toward unit output variance
                    module.weight.div_(np.sqrt(variance))
                    print(f"  Iteration {iteration}: Var = {variance:.4f} → rescaling")
            
            handle.remove()
    
    return model


# Example usage
class SimpleMLP(nn.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        layers = []
        for i in range(len(layer_sizes) - 1):
            layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1]))
            if i < len(layer_sizes) - 2:
                layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.layers(x)


# Apply LSUV
model = SimpleMLP([784, 256, 256, 256, 10])
dummy_batch = torch.randn(32, 784)
model = lsuv_initialization(model, dummy_batch)

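5. Data-Dependent Initialization

A single-layer version of the same idea: push one batch of real data through a layer and rescale its weights so the output reaches a target standard deviation. The bias is not rescaled, so the result lands close to, but not exactly on, the target.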

def data_dependent_init(layer, data_batch, target_std=1.0):
    """
    Initialize weights based on actual data statistics.
    """
    with torch.no_grad():
        # Compute output with current weights
        output = layer(data_batch)
        current_std = output.std().item()
        
        # Scale weights to achieve target std
        scale = target_std / (current_std + 1e-8)
        layer.weight.data *= scale
        
        print(f"Data-dependent init:")
        print(f"  Initial output std: {current_std:.4f}")
        
        output = layer(data_batch)
        print(f"  Final output std: {output.std().item():.4f}")
    
    return layer


# Example
layer = nn.Linear(256, 128)
data = torch.randn(100, 256) * 5  # Data with different scale
layer = data_dependent_init(layer, data)

6. Fixup Initialization

Enables training very deep residual networks without BatchNorm. The idea: in a residual network, each block adds its branch output to the skip connection, so without normalization the variance of the residual stream keeps growing with the number of blocks $L$ unless you compensate. Fixup zero-initializes the last layer of each residual branch (and the final classification layer), so every residual block initially behaves as an identity function, and rescales the remaining weight layers inside each branch by $L^{-1/(2m-2)}$, where $m$ is the number of layers per branch; for the common $m = 2$ case this is $L^{-0.5}$. This is particularly useful when BatchNorm is undesirable (e.g., in small-batch or online learning settings).
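def fixup_initialization(model, num_layers):
    """
    Fixup initialization for residual networks without BatchNorm.
    
    Simplified sketch: assumes each residual branch has two convs whose
    module names end in '.conv1' and '.conv2' (m = 2), and that num_layers
    is L, the number of residual branches.
    
    Key ideas:
    - Zero-initialize the last conv of each residual branch
    - Scale the other conv in each branch by L^(-1/(2m-2)) = L^(-0.5)
    - Zero-initialize the final classification layer
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            if name.endswith('.conv2'):
                # Last conv in a residual branch: start at zero (identity block)
                nn.init.zeros_(module.weight)
            else:
                # Standard initialization
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                # Scale down the first conv of each residual branch
                if name.endswith('.conv1'):
                    module.weight.data.mul_(num_layers ** (-0.5))
        
        elif isinstance(module, nn.Linear):
            nn.init.constant_(module.weight, 0)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
    
    print(f"Fixup initialization applied for L = {num_layers} residual branches")
    print("  - First conv of each branch scaled by L^(-0.5)")
    print("  - Last conv of each branch and final layer initialized to zero")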
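A minimal check of the identity-at-initialization property. `ToyResidualBlock` is a hypothetical two-conv block (not from any library) initialized by hand following the Fixup rules described above, with an illustrative $L = 16$. The block's output equals its input exactly, and the zero-initialized conv still receives a non-zero gradient:

```python
import torch
import torch.nn as nn

class ToyResidualBlock(nn.Module):
    """Hypothetical residual block without normalization: x + conv2(relu(conv1(x)))."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)

    def forward(self, x):
        return x + self.conv2(torch.relu(self.conv1(x)))

num_blocks = 16  # L, the number of residual branches (illustrative value)
torch.manual_seed(0)
block = ToyResidualBlock(8)
nn.init.kaiming_normal_(block.conv1.weight, nonlinearity='relu')
block.conv1.weight.data.mul_(num_blocks ** -0.5)  # L^(-1/(2m-2)) with m = 2
nn.init.zeros_(block.conv2.weight)                 # branch output starts at zero

x = torch.randn(4, 8, 16, 16)
y = block(x)
print("identity at init:", torch.equal(y, x))

# conv2's gradient pairs the incoming gradient with relu(conv1(x)), which is
# non-zero; conv1 gets a zero gradient on this first step because the
# zero weights of conv2 block the backward signal.
y.pow(2).sum().backward()
print("conv2 grad norm:", block.conv2.weight.grad.norm().item())
print("conv1 grad norm:", block.conv1.weight.grad.norm().item())
```

After the first update conv2 is no longer zero, so gradients start flowing into conv1 as well.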

Initialization for Specific Architectures

Transformer Initialization

def transformer_initialization():
    """Special initialization for Transformer models."""
    
    d_model = 512
    n_layers = 12
    n_heads = 8
    
    class TransformerBlock(nn.Module):
        def __init__(self, d_model, n_heads):
            super().__init__()
            self.attention = nn.MultiheadAttention(d_model, n_heads)
            self.ffn = nn.Sequential(
                nn.Linear(d_model, 4 * d_model),
                nn.GELU(),
                nn.Linear(4 * d_model, d_model)
            )
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
    
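    def init_transformer_weights(module):
        """
        GPT-2 style initialization. Use with model.apply(), which calls
        this function on every submodule.
        """
        if isinstance(module, nn.Linear):
            # Standard deviation 0.02
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        
        elif isinstance(module, nn.MultiheadAttention):
            # The fused Q/K/V projection is a raw parameter, not an nn.Linear
            torch.nn.init.normal_(module.in_proj_weight, mean=0.0, std=0.02)
            torch.nn.init.zeros_(module.in_proj_bias)
        
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
    
    def scale_residual_projections(model, n_layers):
        """
        Scale down residual projections by 1/√(2*n_layers) (two residual
        sublayers, attention and FFN, per block). This prevents the residual
        stream from growing with depth. Runs on the top-level model so the
        full parameter names (e.g. 'ffn.2.weight') are visible.
        """
        with torch.no_grad():
            for name, p in model.named_parameters():
                if name.endswith('out_proj.weight') or name.endswith('ffn.2.weight'):
                    p.div_(np.sqrt(2 * n_layers))
    
    # Apply initialization
    block = TransformerBlock(d_model, n_heads)
    block.apply(init_transformer_weights)
    scale_residual_projections(block, n_layers)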
    
    print("Transformer Initialization (GPT-2 style)")
    print("="*50)
    print(f"  - Linear weights: N(0, 0.02)")
    print(f"  - Residual outputs scaled by 1/√(2×{n_layers})")
    print(f"  - LayerNorm: weight=1, bias=0")

transformer_initialization()
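Why scale by $1/\sqrt{2N}$? A toy simulation, assuming each of the $2N$ sublayers (attention and FFN in each of $N$ blocks) adds an independent, zero-mean contribution to the residual stream. This idealization ignores how real sublayers depend on their inputs, but it shows the variance bookkeeping: unit-variance branches add variance 1 each, so the stream's standard deviation grows like $\sqrt{1 + 2N}$, while branches scaled by $1/\sqrt{2N}$ add a total variance of about 1 regardless of depth. `residual_stream_std` is a hypothetical helper:

```python
import numpy as np

def residual_stream_std(n_layers, branch_std, n_samples=10_000, width=64, seed=0):
    """Toy model: add 2 * n_layers independent Gaussian branch outputs
    (attention + FFN per block) to a unit-variance residual stream."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n_samples, width))
    for _ in range(2 * n_layers):
        x = x + branch_std * rng.standard_normal((n_samples, width))
    return x.std()

n_layers = 12
unscaled = residual_stream_std(n_layers, branch_std=1.0)
scaled = residual_stream_std(n_layers, branch_std=1.0 / np.sqrt(2 * n_layers))
print(f"Unit-variance branches:  std = {unscaled:.3f}")  # about sqrt(1 + 24) = 5
print(f"Branches scaled 1/√(2N): std = {scaled:.3f}")    # about sqrt(2) ≈ 1.41
```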

CNN Initialization

def cnn_initialization():
    """Initialization strategies for CNNs."""
    
    class ConvBlock(nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
            self.bn = nn.BatchNorm2d(out_channels)
            self.relu = nn.ReLU()
    
    # Different initialization strategies
    conv = nn.Conv2d(64, 128, 3, padding=1)
    
    print("CNN Initialization Strategies")
    print("="*50)
    
    # 1. Kaiming for ReLU
    nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
    print(f"\n1. Kaiming (fan_out, ReLU):")
    print(f"   Var = {conv.weight.var().item():.6f}")
    
    # 2. Xavier for no activation / before BN
    nn.init.xavier_uniform_(conv.weight)
    print(f"\n2. Xavier Uniform:")
    print(f"   Var = {conv.weight.var().item():.6f}")
    
    # 3. Delta initialization for identity
    def delta_init(conv):
        """Initialize conv to the identity map (exact when in_channels == out_channels)."""
        nn.init.zeros_(conv.weight)
        if conv.bias is not None:
            nn.init.zeros_(conv.bias)  # a random bias would break the identity
        # Set the kernel center to 1 for each matching channel pair (i, i)
        c_out, c_in, h, w = conv.weight.shape
        center_h, center_w = h // 2, w // 2
        for i in range(min(c_in, c_out)):
            conv.weight.data[i, i, center_h, center_w] = 1.0
    
    conv_delta = nn.Conv2d(64, 64, 3, padding=1)
    delta_init(conv_delta)
    print(f"\n3. Delta (identity) initialization:")
    print(f"   Center weights = 1, others = 0")
    
    # Test identity property
    x = torch.randn(1, 64, 8, 8)
    y = conv_delta(x)
    print(f"   Input-Output difference: {(x - y).abs().mean().item():.6f}")

cnn_initialization()

Practical Guidelines

Choosing the Right Initialization

def initialization_decision_tree():
    """Guide for choosing initialization."""
    
    print("""
    ╔════════════════════════════════════════════════════════════════╗
    ║              WEIGHT INITIALIZATION DECISION TREE               ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                ║
    ║  What activation function are you using?                       ║
    ║                                                                ║
    ║  ├── ReLU / Leaky ReLU / ELU                                  ║
    ║  │   └── Use He (Kaiming) initialization                      ║
    ║  │       • PyTorch: kaiming_normal_(weight, nonlinearity='relu')║
    ║  │                                                             ║
    ║  ├── Sigmoid / Tanh                                           ║
    ║  │   └── Use Xavier (Glorot) initialization                   ║
    ║  │       • PyTorch: xavier_uniform_(weight)                   ║
    ║  │                                                             ║
    ║  ├── GELU / SiLU / Swish                                      ║
    ║  │   └── Use He initialization (similar to ReLU)              ║
    ║  │                                                             ║
    ║  └── Linear (no activation)                                   ║
    ║      └── Use Xavier or small normal                           ║
    ║                                                                ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                ║
    ║  Special cases:                                                ║
    ║                                                                ║
    ║  ├── Transformers                                             ║
    ║  │   └── N(0, 0.02) + scale residual projections              ║
    ║  │                                                             ║
    ║  ├── Very deep networks (100+ layers)                         ║
    ║  │   └── Orthogonal or LSUV                                   ║
    ║  │                                                             ║
    ║  ├── RNNs / LSTMs                                             ║
    ║  │   └── Orthogonal for hidden-to-hidden weights              ║
    ║  │                                                             ║
    ║  ├── GANs                                                      ║
    ║  │   └── N(0, 0.02) often works well                          ║
    ║  │                                                             ║
    ║  └── Residual Networks without BatchNorm                      ║
    ║      └── Fixup initialization                                  ║
    ║                                                                ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                ║
    ║  Bias initialization:                                         ║
    ║  • Usually initialize to 0                                    ║
    ║  • For ReLU, small positive (0.01) can help                   ║
    ║  • For LSTM forget gate, initialize to 1                      ║
    ║                                                                ║
    ╚════════════════════════════════════════════════════════════════╝
    """)

initialization_decision_tree()
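The LSTM forget-gate advice depends on how PyTorch lays out the gate biases: `nn.LSTM` stores each bias vector as four stacked chunks in the order input, forget, cell, output, and it keeps two bias vectors per layer (`bias_ih_l0` and `bias_hh_l0`) that are added together. A minimal sketch, with arbitrary sizes, that sets the effective forget-gate bias to 1:

```python
import torch
import torch.nn as nn

hidden = 32
lstm = nn.LSTM(input_size=16, hidden_size=hidden)

with torch.no_grad():
    lstm.bias_ih_l0.zero_()
    lstm.bias_hh_l0.zero_()
    # Chunks are ordered (input, forget, cell, output), so rows
    # hidden:2*hidden belong to the forget gate. Only b_ih gets the 1,
    # so the effective bias b_ih + b_hh is exactly 1 for that gate.
    lstm.bias_ih_l0[hidden:2 * hidden].fill_(1.0)

effective = (lstm.bias_ih_l0 + lstm.bias_hh_l0).detach()
print("forget gate:", effective[hidden:2 * hidden].unique().tolist())  # [1.0]
print("other gates:", effective[:hidden].unique().tolist(),
      effective[2 * hidden:].unique().tolist())                       # [0.0] [0.0]
```

Setting the slice to 1 in both bias vectors, a common slip, would give an effective forget-gate bias of 2.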

Complete Initialization Function

def initialize_model(model, init_type='auto'):
    """
    Comprehensive weight initialization.
    
    Args:
        model: PyTorch model
        init_type: 'auto', 'he', 'xavier', 'orthogonal'
    """
    
    for name, module in model.named_modules():
        
        if isinstance(module, nn.Linear):
            if init_type in ['auto', 'he']:
                # He init, assuming a typical ReLU network. (PyTorch's own
                # nn.Linear default, kaiming_uniform_ with a=√5, gives
                # Var(W) = 1/(3·fan_in), which is not He init.)
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
            elif init_type == 'xavier':
                nn.init.xavier_uniform_(module.weight)
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(module.weight)
            
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        
        elif isinstance(module, nn.Conv2d):
            if init_type in ['auto', 'he']:
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif init_type == 'xavier':
                nn.init.xavier_uniform_(module.weight)
            
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=0.02)
        
        elif isinstance(module, nn.LSTM):
            for param_name, param in module.named_parameters():
                if 'weight_ih' in param_name:
                    nn.init.xavier_uniform_(param)
                elif 'weight_hh' in param_name:
                    nn.init.orthogonal_(param)
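                elif 'bias' in param_name:
                    nn.init.zeros_(param)
                    # Forget-gate slice is the second quarter (gate order: i, f, g, o).
                    # Set it only in bias_ih so the effective bias b_ih + b_hh is 1.
                    if 'bias_ih' in param_name:
                        n = param.size(0)
                        param.data[n//4:n//2].fill_(1.0)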
    
    print(f"Model initialized with '{init_type}' strategy")
    return model


# Example usage
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

model = initialize_model(model, init_type='he')

Diagnosing Initialization Problems

def diagnose_initialization(model, sample_input):
    """
    Diagnose if initialization is causing issues.
    """
    print("Initialization Diagnostics")
    print("="*60)
    
    model.eval()
    
    # Track statistics through layers
    activations = {}
    gradients = {}
    
    def save_activation(name):
        def hook(model, input, output):
            activations[name] = output.detach()
        return hook
    
    def save_gradient(name):
        def hook(model, grad_input, grad_output):
            if grad_output[0] is not None:
                gradients[name] = grad_output[0].detach()
        return hook
    
    # Register hooks
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            handles.append(module.register_forward_hook(save_activation(name)))
            handles.append(module.register_full_backward_hook(save_gradient(name)))
    
    # Forward pass
    x = sample_input.requires_grad_(True)
    output = model(x)
    
    # Backward pass
    loss = output.sum()
    loss.backward()
    
    # Analyze
    print("\nActivation Statistics:")
    print("-" * 60)
    print(f"{'Layer':<30} {'Mean':>10} {'Std':>10} {'%Dead':>10}")
    print("-" * 60)
    
    issues = []
    
    for name, act in activations.items():
        mean = act.mean().item()
        std = act.std().item()
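        # Linear/Conv outputs are pre-activations, so exact zeros are rare.
        # Count units that are <= 0 for every sample in the batch: a following
        # ReLU would output zero for them (dead on this batch).
        dead_fraction = (act <= 0).all(dim=0).float().mean().item() * 100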
        
        print(f"{name:<30} {mean:>10.4f} {std:>10.4f} {dead_fraction:>9.1f}%")
        
        if std < 0.01:
            issues.append(f"  ⚠ {name}: Very small std (vanishing activations)")
        if std > 10:
            issues.append(f"  ⚠ {name}: Very large std (exploding activations)")
        if dead_fraction > 50:
            issues.append(f"  ⚠ {name}: >50% dead neurons")
    
    print("\nGradient Statistics:")
    print("-" * 60)
    print(f"{'Layer':<30} {'Mean':>10} {'Std':>10}")
    print("-" * 60)
    
    for name, grad in gradients.items():
        mean = grad.mean().item()
        std = grad.std().item()
        print(f"{name:<30} {mean:>10.6f} {std:>10.6f}")
        
        if std < 1e-6:
            issues.append(f"  ⚠ {name}: Very small gradient std (vanishing)")
        if std > 100:
            issues.append(f"  ⚠ {name}: Very large gradient std (exploding)")
    
    # Clean up hooks
    for handle in handles:
        handle.remove()
    
    print("\n" + "="*60)
    if issues:
        print("ISSUES DETECTED:")
        for issue in issues:
            print(issue)
    else:
        print("✓ Initialization looks healthy!")
    
    return activations, gradients


# Example usage
model = nn.Sequential(
    nn.Linear(784, 512),
    nn.ReLU(),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)

sample = torch.randn(32, 784)
diagnose_initialization(model, sample)

Exercises

Train the same network with different initialization methods and compare:
  • Training curves
  • Final accuracy
  • Time to converge
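def exercise_1(X, y, epochs=20):
    """X: (N, 784) inputs, y: (N,) labels, e.g. flattened MNIST (not provided here)."""
    init_methods = {'he': lambda w: nn.init.kaiming_normal_(w, nonlinearity='relu'),
                    'xavier': nn.init.xavier_uniform_, 'orthogonal': nn.init.orthogonal_,
                    'small_normal': lambda w: nn.init.normal_(w, std=0.01)}
    for method, init_fn in init_methods.items():
        model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
        for layer in model[::2]:  # the two Linear layers
            init_fn(layer.weight)
        opt, history = torch.optim.SGD(model.parameters(), lr=0.1), []
        for _ in range(epochs):  # one full-batch gradient step per epoch, for brevity
            opt.zero_grad()
            loss = nn.functional.cross_entropy(model(X), y)
            loss.backward(); opt.step()
            history.append(loss.item())
        plt.plot(history, label=method)
    plt.legend(); plt.show()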
Implement Layer-Sequential Unit-Variance initialization:
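def exercise_2(model, data, tol=0.1, max_iter=10):
    def lsuv(model, data):
        for layer in [m for m in model.modules() if isinstance(m, (nn.Linear, nn.Conv2d))]:
            # Initialize orthogonally
            nn.init.orthogonal_(layer.weight)
            
            # Record this layer's output during a full forward pass
            captured = {}
            handle = layer.register_forward_hook(
                lambda mod, inp, out: captured.update(out=out))
            
            # Iteratively scale to unit variance
            with torch.no_grad():
                for _ in range(max_iter):
                    model(data)
                    variance = captured['out'].var().item()
                    
                    if abs(variance - 1.0) < tol:
                        break
                    
                    layer.weight.div_(variance ** 0.5)
            handle.remove()
        return model
    
    return lsuv(model, data)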
For a 100-layer network, visualize how gradients flow backward:
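def exercise_3(depth=100, width=256):
    x, y = torch.randn(64, width), torch.randint(0, 10, (64,))  # random stand-in batch
    criterion = nn.CrossEntropyLoss()
    init_fns = {'he': lambda w: nn.init.kaiming_normal_(w, nonlinearity='relu'),
                'xavier': nn.init.xavier_uniform_,
                'orthogonal': nn.init.orthogonal_}
    
    for init_method, init_fn in init_fns.items():
        # Fresh 100-layer ReLU MLP per method, so gradients do not accumulate
        layers = []
        for _ in range(depth):
            layers += [nn.Linear(width, width), nn.ReLU()]
        model = nn.Sequential(*layers, nn.Linear(width, 10))
        linears = [m for m in model if isinstance(m, nn.Linear)]
        for layer in linears:
            init_fn(layer.weight)
            nn.init.zeros_(layer.bias)
        
        # Forward and backward pass
        loss = criterion(model(x), y)
        loss.backward()
        
        # Plot gradient magnitudes per layer
        gradient_norms = [layer.weight.grad.norm().item() for layer in linears]
        plt.plot(gradient_norms, label=init_method)
    
    plt.yscale('log'); plt.xlabel('Layer'); plt.ylabel('Weight gradient norm')
    plt.legend(); plt.show()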

What’s Next?

Gradient Flow Analysis

Understand gradient dynamics in detail

Advanced CNN Architectures

VGG, Inception, ResNets, EfficientNets

Interview Deep-Dive

Strong Answer: Both methods set the initial weight variance to keep activations and gradients stable across layers, but they make different assumptions about the activation function.

Xavier (Glorot) initialization sets Var(W) = 2 / (fan_in + fan_out). It assumes the activation function is approximately linear around zero, which holds for tanh and sigmoid in their active region. By averaging fan_in and fan_out, it balances variance preservation in the forward pass (activations) and the backward pass (gradients).

He (Kaiming) initialization sets Var(W) = 2 / fan_in. The factor of 2 compensates for ReLU zeroing out half the activations. Without this correction, each layer halves the activation variance, and after 50 layers the signal has been attenuated by 0.5^50, which is approximately 10^-15, effectively zero.

Using Xavier with ReLU will not cause an immediate crash, but training will be noticeably slower and may stagnate in deeper networks (20+ layers): the activation variance shrinks layer by layer, so the gradients reaching the early layers vanish. Using He with sigmoid is less common but causes a different problem: the weights are too large, pushing sigmoid inputs into the saturation regions where gradients are near zero, which produces vanishing gradients through a different mechanism.

Where it fails completely: a 100-layer network with ReLU activations and Xavier initialization. The activation variance after 100 layers is roughly 0.5^100 times the initial variance, on the order of 10^-30. Gradients in the first layer will be vanishingly small, and the network will not learn. Switching to He initialization fixes this immediately. I have seen this exact failure in practice when someone used the default PyTorch initialization (which historically differed between layer types) without checking.

Follow-up: Modern transformer models typically use N(0, 0.02) initialization instead of He or Xavier.
Why does this work despite seemingly ignoring the fan-in/fan-out theory?It works because transformers use LayerNorm (or RMSNorm) before every attention and FFN sublayer. LayerNorm re-normalizes the activations to zero mean and unit variance at every layer, regardless of what the weights do. This acts as a safety net that prevents both vanishing and exploding activations, making the initialization less critical. The 0.02 standard deviation is empirically tuned to give a good starting point, but LayerNorm would rescue training even if you used 0.01 or 0.05.The one place where initialization still matters critically in transformers is the residual output projections. GPT-2 scales these by 1/sqrt(2*N_layers) to prevent the residual stream’s variance from growing with depth. Without this scaling, the output of the transformer would grow proportionally to sqrt(N_layers), eventually causing numerical instability even with LayerNorm.
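The variance arguments above can be checked numerically. The sketch below is a simplified NumPy simulation under stated assumptions, not a real model: a stack of bias-free ReLU layers fed with standard-normal inputs, and a toy pre-norm "residual stream" in which each sublayer is modeled as a random projection of LayerNorm(x) with weight std 0.02. The helper names (`relu_mlp_final_std`, `residual_stream_var`) and all sizes are illustrative assumptions.

```python
import numpy as np

def xavier_std(fan_in, fan_out):
    # Var(W) = 2 / (fan_in + fan_out)
    return np.sqrt(2.0 / (fan_in + fan_out))

def he_std(fan_in, fan_out):
    # Var(W) = 2 / fan_in (fan_out is unused)
    return np.sqrt(2.0 / fan_in)

def relu_mlp_final_std(depth, width, std_fn, batch=256, seed=0):
    """Std of the last activations of a bias-free ReLU MLP at initialization."""
    rng = np.random.default_rng(seed)
    h = rng.standard_normal((batch, width))
    for _ in range(depth):
        W = rng.standard_normal((width, width)) * std_fn(width, width)
        h = np.maximum(0.0, h @ W.T)
    return h.std()

def layer_norm(x, eps=1e-5):
    # Normalize each row to zero mean and unit variance (no learned scale/shift)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def residual_stream_var(n_layer, d_model=768, std=0.02, scale_residual=True, batch=64, seed=0):
    """Toy pre-norm residual stream: each of the 2 * n_layer sublayers adds LN(x) @ W.T."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((batch, d_model))
    n_sublayers = 2 * n_layer  # one attention + one MLP residual addition per block
    proj_std = std / np.sqrt(n_sublayers) if scale_residual else std
    for _ in range(n_sublayers):
        W = rng.standard_normal((d_model, d_model)) * proj_std
        x = x + layer_norm(x) @ W.T
    return x.var()

print("Xavier, 50 ReLU layers:", relu_mlp_final_std(50, 512, xavier_std))
print("He,     50 ReLU layers:", relu_mlp_final_std(50, 512, he_std))
print("residual var, unscaled:", residual_stream_var(12, scale_residual=False))
print("residual var, scaled:  ", residual_stream_var(12, scale_residual=True))
```

In this toy model each unscaled sublayer adds about 768 × 0.02² ≈ 0.31 to the variance, so 24 sublayers push it to roughly 1 + 24 × 0.31 ≈ 8.4, while dividing the projection std by sqrt(24) keeps the total near 1.3. The Xavier stack's final std collapses by many orders of magnitude, while the He stack's stays of order 1.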
Strong Answer: If all weights in a layer are initialized to the same value (including zero), every neuron in that layer computes exactly the same function of the input. During the backward pass, they receive exactly the same gradient. After the weight update, they still have identical weights (just shifted by the same amount). This symmetry persists forever: the layer has N neurons but the effective capacity of just one. Depth does not help either, because every layer still behaves like a single computation copied N times.

What breaks symmetry is randomness in the initialization. By drawing weights from a distribution (Gaussian or uniform), each neuron starts with a slightly different linear function. During training, different neurons receive different gradients because they compute different activations, and they diverge further with each update. The network develops specialized neurons: some detect edges, others detect textures, others detect specific patterns.

A subtlety that catches people: zero-initializing biases is fine and common. Biases do not participate in the symmetry problem, because the asymmetry comes from the weight matrix (each neuron’s different weight vector gives it a different “view” of the input). Setting all biases to zero just gives every neuron the same starting activation threshold; the different weight vectors still produce different pre-activation values.

One important exception: zero initialization inside residual blocks. Some schemes (Fixup, or the common ResNet trick of zero-initializing the final BatchNorm scale in each block) deliberately initialize the last layer of each residual branch to zero. This does not create a symmetry problem, because the skip connection makes each block compute the identity function initially, and the other layers in the branch are still randomly initialized. The zero-initialized layer picks up distinct values as soon as it receives its first non-zero gradient. This technique actually helps training stability by making the network start out effectively shallow and gradually become deeper.

Follow-up: Could you use a structured (non-random) initialization that still breaks symmetry? When would this be preferable?

Yes. Orthogonal initialization sits part of the way there: the matrix is usually still sampled at random (for example, from the QR decomposition of a Gaussian matrix), but it is constrained so that its columns are mutually orthogonal. It breaks symmetry because each neuron’s weight vector points in a unique orthogonal direction. Orthogonal initialization has a specific advantage: a square orthogonal matrix preserves norms exactly (all singular values are 1.0), so signals neither shrink nor grow as they pass through the linear part of the layer. For very deep networks (100+ layers) and RNNs, this exact norm preservation can be the difference between training and not training.

A fully deterministic approach is delta initialization for convolutional layers: initializing the kernel to an identity mapping (center pixel = 1, rest = 0). This is useful in super-resolution networks, where the network needs to learn a small residual correction rather than reconstruct the entire image from scratch. Symmetry is broken because each output channel starts as a copy of a different input channel, so no two filters are identical.
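A small NumPy sketch of the three ideas in this answer, under illustrative assumptions: a toy two-layer tanh network with a mean-squared-error loss shows that constant initialization gives every hidden unit the same gradient, a random orthogonal matrix is built from the QR decomposition of a Gaussian matrix (with the usual sign correction), and a delta kernel is laid out as (out_channels, in_channels, k, k). The function names `w1_gradient`, `orthogonal`, and `delta_kernel` are hypothetical.

```python
import numpy as np

def w1_gradient(W1, W2, x, y):
    """dL/dW1 for y_hat = tanh(x @ W1.T) @ W2.T with mean squared error.
    Row i of the result is the update direction for hidden unit i."""
    h = np.tanh(x @ W1.T)                 # (batch, hidden)
    y_hat = h @ W2.T                      # (batch, 1)
    d_yhat = 2.0 * (y_hat - y) / len(x)   # dL/dy_hat
    d_pre = (d_yhat @ W2) * (1.0 - h**2)  # back through W2, then tanh'
    return d_pre.T @ x                    # (hidden, in_features)

def orthogonal(n, rng):
    """Random n x n orthogonal matrix via QR of a Gaussian matrix (sign-corrected)."""
    q, r = np.linalg.qr(rng.standard_normal((n, n)))
    return q * np.sign(np.diag(r))  # flip column signs so the sample is uniform

def delta_kernel(channels, k=3):
    """Identity conv kernel, shape (out, in, k, k): output channel c copies input channel c."""
    w = np.zeros((channels, channels, k, k))
    for c in range(channels):
        w[c, c, k // 2, k // 2] = 1.0
    return w

rng = np.random.default_rng(0)
x, y = rng.standard_normal((32, 3)), rng.standard_normal((32, 1))

# Constant init: all 4 hidden units are identical and get identical gradients
g_const = w1_gradient(np.full((4, 3), 0.5), np.full((1, 4), 0.5), x, y)
# Random init: each hidden unit gets its own gradient
g_rand = w1_gradient(0.5 * rng.standard_normal((4, 3)), 0.5 * rng.standard_normal((1, 4)), x, y)
print(np.allclose(g_const, g_const[0]))  # True
print(np.allclose(g_rand, g_rand[0]))    # False

Q = orthogonal(64, rng)
v = rng.standard_normal(64)
print(np.isclose(np.linalg.norm(Q @ v), np.linalg.norm(v)))  # True: norm preserved
print(np.allclose(np.linalg.svd(Q, compute_uv=False), 1.0))  # True: all singular values are 1

w = delta_kernel(4)
print(len(np.unique(w.reshape(4, -1), axis=0)))  # 4: every filter is distinct
```

Biases are omitted from the toy network for brevity; adding zero biases would not change the outcome, for the reason given above.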
Strong Answer: I would approach this systematically, spending the first hour diagnosing before changing anything.

Step one: instrument the network. Add forward hooks to every layer (or every 10th layer) that log activation mean, standard deviation, and percentage of dead neurons (activations exactly zero for ReLU). Run a single forward pass on one batch and plot the activation statistics by layer depth. This immediately tells you whether activations are vanishing (std dropping toward zero in later layers), exploding (std growing exponentially), or collapsing (all neurons producing the same output, a sign of symmetry or mode collapse).

Step two: check gradients. Add backward hooks and compute the gradient norm at each layer after one backward pass, then plot these norms. In a healthy network, gradient norms should be roughly constant across layers. If they decay by orders of magnitude from the output to the input, you have vanishing gradients; if they grow, exploding gradients.

Step three: verify that the initialization scheme matches the activation functions. If the network uses ReLU with Xavier initialization, that is the first thing to fix: switch to He (Kaiming). Check whether the team is using PyTorch’s default initialization (which varies by layer type) or a custom scheme.

Step four: check for architectural issues that initialization alone cannot fix. A 200-layer CNN without residual connections is essentially untrainable regardless of initialization: the gradient signal must traverse 200 multiplicative layers, and no initialization keeps all 200 Jacobians at exactly 1.0. Recommend adding skip connections. If skip connections exist, verify they are implemented correctly (a common bug is applying normalization on the skip path, which can disrupt gradient flow).

Step five: try LSUV, a data-driven initialization that empirically normalizes each layer’s output variance to 1.0. This accounts for non-linearities, batch norm interactions, and any architectural quirks that analytical formulas miss.

In my experience, the root cause for a 200-layer network is almost always missing or broken skip connections, not initialization. But proper initialization is still necessary for the skip-connected network to train quickly rather than slowly.

Follow-up: How would you diagnose whether the problem is initialization versus learning rate versus architecture?

Quick differential diagnosis: (1) If the first forward pass already shows vanishing or exploding activations (before any training), the problem is initialization. (2) If the first forward pass looks healthy but gradients explode after a few steps, the learning rate is too high. (3) If activations and gradients look reasonable but the loss does not decrease even after 1000 steps with multiple learning rates, the architecture is the bottleneck, likely missing skip connections or too many sequential non-linearities without normalization. You can test this by training only the last 10 layers (freezing the first 190): if that works, the architecture is preventing gradient flow to the early layers.
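Steps one and two can be sketched in PyTorch with `register_forward_hook` and the `.grad` attributes that `backward()` populates. This is a minimal sketch under stated assumptions: a plain `nn.Sequential` stack of `Linear`/`ReLU` layers with no skip connections stands in for the real network, and a dummy mean-squared-output loss is used only to produce gradients. The helpers `make_mlp`, `activation_report`, and `gradient_norms` are hypothetical names.

```python
import torch
import torch.nn as nn

def make_mlp(depth, width, init_fn):
    """Plain Linear/ReLU stack (no skip connections), weights set by init_fn, biases zeroed."""
    layers = []
    for _ in range(depth):
        linear = nn.Linear(width, width)
        init_fn(linear.weight)
        nn.init.zeros_(linear.bias)
        layers += [linear, nn.ReLU()]
    return nn.Sequential(*layers)

def activation_report(model, x):
    """Step one: per-layer mean, std, and dead fraction of ReLU outputs for one batch."""
    stats = []

    def hook(module, inputs, output):
        stats.append({
            "mean": output.mean().item(),
            "std": output.std().item(),
            "dead_frac": (output == 0).float().mean().item(),
        })

    handles = [m.register_forward_hook(hook) for m in model if isinstance(m, nn.ReLU)]
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:
            h.remove()  # detach hooks so later forward passes are not logged
    return stats

def gradient_norms(model, x):
    """Step two: gradient norm of each Linear weight after one backward pass."""
    model.zero_grad()
    loss = model(x).pow(2).mean()  # dummy loss, used only for diagnostics
    loss.backward()
    return [m.weight.grad.norm().item() for m in model if isinstance(m, nn.Linear)]

torch.manual_seed(0)
x = torch.randn(128, 256)
xavier = make_mlp(30, 256, nn.init.xavier_normal_)
he = make_mlp(30, 256, lambda w: nn.init.kaiming_normal_(w, nonlinearity="relu"))

for name, model in [("xavier", xavier), ("he", he)]:
    acts = activation_report(model, x)
    grads = gradient_norms(model, x)
    print(f"{name}: last act std={acts[-1]['std']:.2e}, dead={acts[-1]['dead_frac']:.2f}, "
          f"grad norm first/last={grads[0]:.2e}/{grads[-1]:.2e}")
```

With these settings the Xavier stack's last-layer activation std should come out several orders of magnitude below the He stack's, while both show roughly half of the ReLU outputs at zero, which is the expected dead fraction for a healthy ReLU layer at initialization.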
Strong Answer: LSUV (Layer-Sequential Unit-Variance) is a data-driven initialization method that iteratively rescales each layer’s weights until its output variance on a real data batch is 1.0, within a chosen tolerance. The process: initialize each layer with orthogonal weights; then, for each layer in order, run a forward pass on real data, measure that layer’s output variance, and divide its weights by sqrt(variance). Repeat until the variance is within tolerance of 1.0, then move on to the next layer.

The advantage over He and Xavier is that LSUV makes no assumptions about the activation function, the interactions between layers, or the data distribution. He initialization assumes ReLU and independent Gaussian inputs. Xavier assumes approximately linear activations. Neither accounts for the actual data, batch normalization interactions, or architectural peculiarities like unusual skip connection patterns.

I would choose LSUV in three situations. First, when using unconventional activation functions (PReLU with a learned slope, Mish, SELU) where neither He’s nor Xavier’s assumptions hold. Second, when the architecture has complex interactions between layers (attention mechanisms, gating, feature concatenation) that make analytical variance computation intractable. Third, when training without batch normalization. BN acts as a per-layer variance normalizer during training and masks bad initialization; without BN, initialization quality matters much more, and LSUV provides the data-aware normalization that BN would otherwise have provided.

The downside is that LSUV adds a few seconds to model initialization and requires access to a data batch at initialization time. In practice this is almost never a problem, but it means you cannot initialize the model before the data pipeline is ready.

Follow-up: If you use LSUV on a network with batch normalization, does it have any effect?

Very little, and this illustrates an important point. Batch normalization normalizes each layer’s output to zero mean and unit variance during training (before its learned scale and shift), effectively redoing what LSUV did at initialization, but continuously, at every training step. So BN largely negates the benefits of careful initialization after the first few gradient steps. This is why BN made training deep networks much more forgiving of initialization choices, and why researchers sometimes describe BN as “making initialization not matter.”

However, LSUV can still help with BN in one specific way: it gives the network a better starting point, so the first few training steps are more productive. In practice this can translate into faster early convergence (perhaps reaching a given loss level 10-20% sooner), though the final converged accuracy is usually the same regardless of initialization when BN is present.
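The LSUV procedure described above can be sketched in NumPy. Assumptions, all illustrative: a bias-free MLP stored as a list of weight matrices, Mish as the "unconventional" activation, variance measured on each layer's linear output (before the activation), and a synthetic batch with non-unit scale standing in for real data. `lsuv_init` and the other helpers are hypothetical names, not a library API.

```python
import numpy as np

def mish(x):
    # Mish: x * tanh(softplus(x)); softplus computed stably via logaddexp
    return x * np.tanh(np.logaddexp(0.0, x))

def orthogonal(rows, cols, rng):
    """Random (semi-)orthogonal matrix of shape (rows, cols) via QR."""
    a = rng.standard_normal((max(rows, cols), min(rows, cols)))
    q, r = np.linalg.qr(a)
    q = q * np.sign(np.diag(r))
    return q if rows >= cols else q.T

def layer_output_variances(weights, x, activation):
    """Variance of each layer's linear output (pre-activation) on batch x."""
    variances, h = [], x
    for w in weights:
        z = h @ w.T
        variances.append(z.var())
        h = activation(z)
    return variances

def lsuv_init(weights, x, activation, tol=0.05, max_iters=10):
    """Layer-sequential unit variance: rescale each layer in order until Var(output) is ~1."""
    h = x
    for w in weights:
        for _ in range(max_iters):
            var = (h @ w.T).var()
            if abs(var - 1.0) < tol:
                break
            w /= np.sqrt(var)  # in place, so the list entry is updated
        h = activation(h @ w.T)  # the rescaled layer's output feeds the next layer
    return weights

rng = np.random.default_rng(0)
x = 3.0 * rng.standard_normal((256, 64)) + 1.0  # stand-in for a real data batch
weights = [orthogonal(64, 64, rng) for _ in range(10)]

before = layer_output_variances(weights, x, mish)
lsuv_init(weights, x, mish)
after = layer_output_variances(weights, x, mish)
print("before:", np.round(before, 3))
print("after: ", np.round(after, 3))
```

Because these layers have no bias and the same batch is reused, a single rescaling step already lands on unit variance; the loop and tolerance start to matter once biases are present or a fresh batch is drawn for each iteration.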