Weight Initialization: The Foundation of Training
Why Initialization Matters
Poor weight initialization can cause:
- Vanishing gradients: Signals shrink to zero, learning stops
- Exploding gradients: Signals blow up, NaN everywhere
- Symmetry: All neurons learn the same thing (a short demo follows the imports below)
- Slow convergence: Training takes forever
Good initialization, by contrast, gives you:
- Stable activations: Signals don't vanish or explode
- Broken symmetry: Each neuron learns something different
- Fast convergence: Networks train efficiently
Reality Check: A neural network with poor initialization might never converge, while the same architecture with proper initialization trains smoothly. Initialization is that important!
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from scipy import stats
# Reproducibility
np.random.seed(42)
torch.manual_seed(42)
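The symmetry item in the list above deserves a quick look before moving on. A minimal sketch (the tiny two-layer network here is purely illustrative, not part of the later examples): when every weight starts at the same constant, every hidden neuron computes the same function, receives the same gradient, and never differentiates from its siblings.

# Minimal symmetry demo: constant init makes hidden units identical.
layer1 = nn.Linear(4, 3, bias=False)
layer2 = nn.Linear(3, 1, bias=False)
nn.init.constant_(layer1.weight, 0.5)   # every weight identical
nn.init.constant_(layer2.weight, 0.5)

x = torch.randn(8, 4)
loss = layer2(torch.tanh(layer1(x))).sum()
loss.backward()

# Every row of layer1's gradient is the same, so all hidden neurons
# receive the identical update and stay interchangeable forever.
print(layer1.weight.grad)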
The Problem: Signal Propagation
Observing Vanishing/Exploding Activations
def demonstrate_initialization_problem():
"""Show how bad initialization kills gradients."""
# Deep network with different initializations
def create_network(init_std, depth=50, width=256):
layers = []
for i in range(depth):
layer = nn.Linear(width, width, bias=False)
# Initialize with given std
nn.init.normal_(layer.weight, std=init_std)
layers.append(layer)
layers.append(nn.Tanh())
return nn.Sequential(*layers)
# Test different initialization scales
init_stds = [0.01, 0.1, 1.0, 2.0]
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()
for ax, std in zip(axes, init_stds):
model = create_network(std)
# Forward pass with random input
x = torch.randn(100, 256)
# Track activations at each layer
activation_means = []
activation_stds = []
current = x
with torch.no_grad():
for layer in model:
current = layer(current)
if isinstance(layer, nn.Linear):
activation_means.append(current.mean().item())
activation_stds.append(current.std().item())
# Plot
layers_idx = range(len(activation_stds))
ax.plot(layers_idx, activation_stds, 'b-', label='Std')
ax.plot(layers_idx, activation_means, 'r--', label='Mean')
ax.set_xlabel('Layer')
ax.set_ylabel('Activation Statistics')
ax.set_title(f'Init std = {std}')
ax.legend()
ax.set_yscale('symlog') # Symmetric log scale
ax.grid(True, alpha=0.3)
# Check if vanishing or exploding
final_std = activation_stds[-1] if activation_stds else 0
if final_std < 1e-5:
ax.text(0.5, 0.5, 'VANISHING!', transform=ax.transAxes,
fontsize=20, color='red', alpha=0.5, ha='center')
elif final_std > 1e5:
ax.text(0.5, 0.5, 'EXPLODING!', transform=ax.transAxes,
fontsize=20, color='red', alpha=0.5, ha='center')
plt.tight_layout()
plt.suptitle('Effect of Initialization Scale on Deep Networks', y=1.02)
plt.show()
demonstrate_initialization_problem()
Mathematical Analysis
For a layer $y = Wx$ where $W$ has shape $(n_{out}, n_{in})$:

$$\mathrm{Var}(y_j) = n_{in} \cdot \mathrm{Var}(W) \cdot \mathrm{Var}(x)$$

To maintain stable variance across layers, we need:

$$\mathrm{Var}(W) = \frac{1}{n_{in}}$$
def variance_propagation_analysis():
"""Analyze how variance propagates through layers."""
n_in = 256
n_out = 256
n_samples = 10000
# Different initialization strategies
strategies = {
'Small (0.01)': 0.01,
'Medium (0.1)': 0.1,
'Correct (1/√n)': 1.0 / np.sqrt(n_in),
'Large (1.0)': 1.0
}
print("Variance Propagation Analysis")
print("="*60)
print(f"Input dimension: {n_in}")
print(f"Expected variance to maintain: Var(W) = 1/{n_in} = {1/n_in:.6f}")
print()
x = np.random.randn(n_samples, n_in) # Input with unit variance
for name, std in strategies.items():
W = np.random.randn(n_in, n_out) * std
y = x @ W
var_W = np.var(W)
var_y = np.var(y)
expected_var_y = n_in * var_W * np.var(x)
print(f"{name}:")
print(f" Var(W) = {var_W:.6f}")
print(f" Var(y) = {var_y:.4f} (expected: {expected_var_y:.4f})")
# After 50 layers
var_after_50 = var_y ** 50
print(f" After 50 layers: {var_after_50:.2e}")
if var_after_50 < 1e-10:
print(f" → VANISHING")
elif var_after_50 > 1e10:
print(f" → EXPLODING")
else:
print(f" → STABLE ✓")
print()
variance_propagation_analysis()
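The same accounting applies to the backward pass, where the fan-out takes the place of the fan-in. This is the standard Glorot argument, and keeping both the forward and backward signals stable is exactly the trade-off the Xavier formula in the next section makes:

$$\mathrm{Var}\!\left(\frac{\partial L}{\partial x_i}\right) = n_{out} \cdot \mathrm{Var}(W) \cdot \mathrm{Var}\!\left(\frac{\partial L}{\partial y}\right) \quad\Rightarrow\quad \mathrm{Var}(W) = \frac{2}{n_{in} + n_{out}} \text{ balances the two constraints.}$$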
Classic Initialization Methods
1. Xavier/Glorot Initialization
Designed for tanh and sigmoid activations:

$$W \sim \mathcal{N}\!\left(0,\ \frac{2}{n_{in} + n_{out}}\right) \quad \text{or} \quad W \sim \mathcal{U}\!\left(-\sqrt{\frac{6}{n_{in} + n_{out}}},\ \sqrt{\frac{6}{n_{in} + n_{out}}}\right)$$
def xavier_initialization():
"""Xavier/Glorot initialization for tanh/sigmoid."""
n_in, n_out = 512, 256
# Normal variant
std = np.sqrt(2.0 / (n_in + n_out))
W_normal = np.random.randn(n_in, n_out) * std
# Uniform variant
limit = np.sqrt(6.0 / (n_in + n_out))
W_uniform = np.random.uniform(-limit, limit, (n_in, n_out))
print("Xavier/Glorot Initialization")
print("="*50)
print(f"n_in={n_in}, n_out={n_out}")
print(f"\nNormal variant: std = √(2/(n_in+n_out)) = {std:.6f}")
print(f" Actual Var(W) = {np.var(W_normal):.6f}")
print(f"\nUniform variant: limit = √(6/(n_in+n_out)) = {limit:.6f}")
print(f" Actual Var(W) = {np.var(W_uniform):.6f}")
# PyTorch implementation
linear = nn.Linear(n_in, n_out)
nn.init.xavier_normal_(linear.weight)
print(f"\nPyTorch xavier_normal_: Var = {linear.weight.var().item():.6f}")
nn.init.xavier_uniform_(linear.weight)
print(f"PyTorch xavier_uniform_: Var = {linear.weight.var().item():.6f}")
xavier_initialization()
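To connect this back to the failing 50-layer tanh network from the first demo, here is a short sketch (built inline rather than reusing the earlier helper) showing that the correct scale keeps the signal alive through the full stack:

# Rebuild the 50-layer tanh stack, but with Xavier initialization.
layers = []
for _ in range(50):
    lin = nn.Linear(256, 256, bias=False)
    nn.init.xavier_normal_(lin.weight)   # std = sqrt(2/(256+256))
    layers.extend([lin, nn.Tanh()])
xavier_net = nn.Sequential(*layers)

with torch.no_grad():
    out = xavier_net(torch.randn(100, 256))
# The std shrinks a little through the tanh layers but settles at a
# stable, non-vanishing value instead of collapsing to zero.
print(f"Output std after 50 tanh layers: {out.std().item():.4f}")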
2. He/Kaiming Initialization
Designed for ReLU and its variants:

$$W \sim \mathcal{N}\!\left(0,\ \frac{2}{n_{in}}\right)$$

The factor of 2 compensates for ReLU zeroing out half the activations.
def he_initialization():
"""He/Kaiming initialization for ReLU."""
n_in, n_out = 512, 256
# For ReLU
std = np.sqrt(2.0 / n_in)
W_relu = np.random.randn(n_in, n_out) * std
# For Leaky ReLU (negative_slope = 0.01)
negative_slope = 0.01
std_leaky = np.sqrt(2.0 / (1 + negative_slope**2) / n_in)
W_leaky = np.random.randn(n_in, n_out) * std_leaky
print("He/Kaiming Initialization")
print("="*50)
print(f"n_in={n_in}")
print(f"\nFor ReLU: std = √(2/n_in) = {std:.6f}")
print(f" Var(W) = {np.var(W_relu):.6f}")
print(f"\nFor Leaky ReLU (slope={negative_slope}):")
print(f" std = √(2/(1+slope²)/n_in) = {std_leaky:.6f}")
print(f" Var(W) = {np.var(W_leaky):.6f}")
# PyTorch implementation
linear = nn.Linear(n_in, n_out)
nn.init.kaiming_normal_(linear.weight, mode='fan_in', nonlinearity='relu')
print(f"\nPyTorch kaiming_normal_ (fan_in, relu): Var = {linear.weight.var().item():.6f}")
nn.init.kaiming_normal_(linear.weight, mode='fan_out', nonlinearity='relu')
print(f"PyTorch kaiming_normal_ (fan_out, relu): Var = {linear.weight.var().item():.6f}")
he_initialization()
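A quick sketch of why the factor of 2 matters in practice (assumptions: a 30-layer ReLU stack of width 256, purely for illustration):

# Compare activation scale through a deep ReLU stack under He vs Xavier.
def relu_stack_std(init_fn, depth=30, width=256):
    x = torch.randn(256, width)
    with torch.no_grad():
        for _ in range(depth):
            lin = nn.Linear(width, width, bias=False)
            init_fn(lin.weight)
            x = torch.relu(lin(x))
    return x.std().item()

print("He    :", relu_stack_std(lambda w: nn.init.kaiming_normal_(w, nonlinearity='relu')))
print("Xavier:", relu_stack_std(nn.init.xavier_normal_))
# Under ReLU, Xavier loses roughly a factor of sqrt(2) in std per layer,
# so after 30 layers its activations are orders of magnitude smaller.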
3. Orthogonal Initialization
Initializes weights as orthogonal matrices, which preserve norms exactly:

$$W^T W = I$$
def orthogonal_initialization():
"""Orthogonal initialization for stable signal propagation."""
n = 256
# Create orthogonal matrix
W = nn.Linear(n, n)
nn.init.orthogonal_(W.weight)
# Verify orthogonality
    WtW = W.weight.T @ W.weight
print("Orthogonal Initialization")
print("="*50)
print(f"\nW^T @ W should be identity:")
print(f" Diagonal mean: {torch.diag(WtW).mean().item():.6f} (should be 1)")
print(f" Off-diagonal std: {(WtW - torch.eye(n)).std().item():.6f} (should be 0)")
# Signal preservation
x = torch.randn(100, n)
y = W(x)
print(f"\nSignal preservation:")
print(f" Input norm mean: {torch.norm(x, dim=1).mean().item():.4f}")
print(f" Output norm mean: {torch.norm(y, dim=1).mean().item():.4f}")
# Through many layers
print("\nThrough 50 orthogonal layers:")
current = x
for _ in range(50):
layer = nn.Linear(n, n, bias=False)
nn.init.orthogonal_(layer.weight)
current = layer(current)
print(f" Final norm mean: {torch.norm(current, dim=1).mean().item():.4f}")
print(" (Should be similar to input norm)")
orthogonal_initialization()
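Note that the norm preservation above holds for a purely linear stack. Once a nonlinearity sits between the layers, the orthogonal matrix is usually scaled by the activation's gain; a minimal sketch using PyTorch's built-in gain table:

# Orthogonal init scaled for ReLU: gain = sqrt(2) compensates for the
# variance lost when ReLU zeroes the negative half of the activations.
gain = nn.init.calculate_gain('relu')          # ≈ 1.414
layer = nn.Linear(256, 256, bias=False)
nn.init.orthogonal_(layer.weight, gain=gain)
print(f"gain = {gain:.4f}, weight std = {layer.weight.std().item():.4f}")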
Advanced Initialization Techniques
4. LSUV (Layer-Sequential Unit-Variance)
A data-driven approach that iteratively normalizes each layer:
def lsuv_initialization(model, data_batch, tol=0.1, max_iter=10):
"""
Layer-Sequential Unit-Variance initialization.
Iteratively adjusts weights so each layer has unit variance output.
"""
print("LSUV Initialization")
print("="*50)
model.eval()
for name, module in model.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
print(f"\nProcessing layer: {name}")
# Orthogonal init first
if isinstance(module, nn.Linear):
nn.init.orthogonal_(module.weight)
else:
nn.init.orthogonal_(module.weight.view(module.weight.size(0), -1))
for iteration in range(max_iter):
with torch.no_grad():
# Forward pass up to this layer
output = data_batch
for n, m in model.named_modules():
if isinstance(m, (nn.Linear, nn.Conv2d, nn.ReLU, nn.BatchNorm2d)):
output = m(output)
if n == name:
break
variance = output.var().item()
if abs(variance - 1.0) < tol:
print(f" Iteration {iteration}: Var = {variance:.4f} ✓")
break
# Rescale weights
module.weight.data /= np.sqrt(variance)
print(f" Iteration {iteration}: Var = {variance:.4f} → rescaling")
return model
# Example usage
class SimpleMLP(nn.Module):
def __init__(self, layer_sizes):
super().__init__()
layers = []
for i in range(len(layer_sizes) - 1):
layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1]))
if i < len(layer_sizes) - 2:
layers.append(nn.ReLU())
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
# Apply LSUV
model = SimpleMLP([784, 256, 256, 256, 10])
dummy_batch = torch.randn(32, 784)
model = lsuv_initialization(model, dummy_batch)
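A quick check that the LSUV pass did what it promised: re-run the same dummy batch and print the variance after each Linear layer (a small sketch that reuses model and dummy_batch from above).

# Verify: after LSUV, each Linear layer's output should have variance near 1.
with torch.no_grad():
    h = dummy_batch
    for layer in model.layers:
        h = layer(h)
        if isinstance(layer, nn.Linear):
            print(f"{layer}: output var = {h.var().item():.3f}")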
5. Data-Dependent Initialization
def data_dependent_init(layer, data_batch, target_std=1.0):
"""
Initialize weights based on actual data statistics.
"""
with torch.no_grad():
# Compute output with current weights
output = layer(data_batch)
current_std = output.std().item()
# Scale weights to achieve target std
scale = target_std / (current_std + 1e-8)
layer.weight.data *= scale
print(f"Data-dependent init:")
print(f" Initial output std: {current_std:.4f}")
output = layer(data_batch)
print(f" Final output std: {output.std().item():.4f}")
return layer
# Example
layer = nn.Linear(256, 128)
data = torch.randn(100, 256) * 5 # Data with different scale
layer = data_dependent_init(layer, data)
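For a multi-layer network the same idea is applied sequentially: calibrate layer 1 on the raw batch, push the batch through it, calibrate layer 2 on that output, and so on. A sketch reusing data_dependent_init from above (the widths and the ReLU between layers are illustrative choices):

# Sequential data-dependent init: each layer is calibrated on the
# activations produced by the already-calibrated layers before it.
widths = [256, 128, 64]
batch = torch.randn(100, 256) * 5
for w_in, w_out in zip(widths[:-1], widths[1:]):
    layer = nn.Linear(w_in, w_out)
    layer = data_dependent_init(layer, batch)
    with torch.no_grad():
        batch = torch.relu(layer(batch))   # feed calibrated output forward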
6. Fixup Initialization
Enables training very deep networks without normalization:
def fixup_initialization(model, num_layers):
"""
Fixup initialization for residual networks without BatchNorm.
Key ideas:
- Scale down the last layer of each residual block
- Zero-initialize certain layers
"""
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d):
# Standard initialization
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
# Scale down if it's the last conv in a residual block
if 'last_conv' in name or 'conv2' in name:
module.weight.data.mul_(num_layers ** (-0.5))
elif isinstance(module, nn.Linear):
nn.init.constant_(module.weight, 0)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
print(f"Fixup initialization applied for {num_layers} layers")
print(" - Residual branch weights scaled by L^(-0.5)")
print(" - Final layer initialized to zero")
Initialization for Specific Architectures
Transformer Initialization
def transformer_initialization():
"""Special initialization for Transformer models."""
d_model = 512
n_layers = 12
n_heads = 8
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.attention = nn.MultiheadAttention(d_model, n_heads)
self.ffn = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
nn.GELU(),
nn.Linear(4 * d_model, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
    def init_transformer_weights(model, n_layers):
        """
        GPT-2 style initialization, applied over all submodules.
        """
        for module in model.modules():
            if isinstance(module, nn.Linear):
                # Standard normal initialization
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                torch.nn.init.zeros_(module.bias)
                torch.nn.init.ones_(module.weight)
        # Scale down residual projections (attention output and second FFN layer)
        # This prevents the residual stream from growing with depth
        for name, p in model.named_parameters():
            if name.endswith('out_proj.weight') or name.endswith('ffn.2.weight'):
                # Scale by 1/√(2*n_layers)
                p.data.div_(np.sqrt(2 * n_layers))
    # Apply initialization
    block = TransformerBlock(d_model, n_heads)
    init_transformer_weights(block, n_layers)
print("Transformer Initialization (GPT-2 style)")
print("="*50)
print(f" - Linear weights: N(0, 0.02)")
print(f" - Residual outputs scaled by 1/√(2×{n_layers})")
print(f" - LayerNorm: weight=1, bias=0")
transformer_initialization()
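A standalone sanity check of the scaling rule (the layer below is a stand-in for one residual projection, not taken from the block above): after N(0, 0.02) initialization and division by √(2·n_layers), the weight std should land near 0.02/√24 ≈ 0.004 for 12 layers.

# Standalone check of the residual-scaling rule.
proj = nn.Linear(512, 512)
nn.init.normal_(proj.weight, mean=0.0, std=0.02)
proj.weight.data.div_(np.sqrt(2 * 12))
print(f"scaled residual projection std: {proj.weight.std().item():.5f}")  # ≈ 0.0041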
CNN Initialization
def cnn_initialization():
"""Initialization strategies for CNNs."""
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
# Different initialization strategies
conv = nn.Conv2d(64, 128, 3, padding=1)
print("CNN Initialization Strategies")
print("="*50)
# 1. Kaiming for ReLU
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
print(f"\n1. Kaiming (fan_out, ReLU):")
print(f" Var = {conv.weight.var().item():.6f}")
# 2. Xavier for no activation / before BN
nn.init.xavier_uniform_(conv.weight)
print(f"\n2. Xavier Uniform:")
print(f" Var = {conv.weight.var().item():.6f}")
# 3. Delta initialization for identity
def delta_init(conv):
"""Initialize conv to approximate identity."""
nn.init.zeros_(conv.weight)
# Set center of kernel to 1 for each in/out channel pair
c_out, c_in, h, w = conv.weight.shape
center_h, center_w = h // 2, w // 2
for i in range(min(c_in, c_out)):
conv.weight.data[i, i, center_h, center_w] = 1.0
    conv_delta = nn.Conv2d(64, 64, 3, padding=1, bias=False)  # no bias, so the identity check is exact
delta_init(conv_delta)
print(f"\n3. Delta (identity) initialization:")
print(f" Center weights = 1, others = 0")
# Test identity property
x = torch.randn(1, 64, 8, 8)
y = conv_delta(x)
print(f" Input-Output difference: {(x - y).abs().mean().item():.6f}")
cnn_initialization()
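For the Kaiming fan_out numbers above, it helps to see where the fan comes from: for a Conv2d, fan_out = out_channels × kernel_height × kernel_width, so the expected variance is 2 / fan_out. A short check using the same 64→128, 3×3 convolution:

# Expected variance under kaiming_normal_(mode='fan_out', nonlinearity='relu'):
# Var(W) = 2 / fan_out, with fan_out = out_channels * kH * kW for a conv.
conv = nn.Conv2d(64, 128, 3, padding=1)
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
fan_out = 128 * 3 * 3
print(f"expected Var = {2 / fan_out:.6f}, actual Var = {conv.weight.var().item():.6f}")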
Practical Guidelines
Choosing the Right Initialization
def initialization_decision_tree():
"""Guide for choosing initialization."""
print("""
╔════════════════════════════════════════════════════════════════╗
║ WEIGHT INITIALIZATION DECISION TREE ║
╠════════════════════════════════════════════════════════════════╣
║ ║
║ What activation function are you using? ║
║ ║
║ ├── ReLU / Leaky ReLU / ELU ║
║ │ └── Use He (Kaiming) initialization ║
║ │ • PyTorch: kaiming_normal_(weight, nonlinearity='relu')║
║ │ ║
║ ├── Sigmoid / Tanh ║
║ │ └── Use Xavier (Glorot) initialization ║
║ │ • PyTorch: xavier_uniform_(weight) ║
║ │ ║
║ ├── GELU / SiLU / Swish ║
║ │ └── Use He initialization (similar to ReLU) ║
║ │ ║
║ └── Linear (no activation) ║
║ └── Use Xavier or small normal ║
║ ║
╠════════════════════════════════════════════════════════════════╣
║ ║
║ Special cases: ║
║ ║
║ ├── Transformers ║
║ │ └── N(0, 0.02) + scale residual projections ║
║ │ ║
║ ├── Very deep networks (100+ layers) ║
║ │ └── Orthogonal or LSUV ║
║ │ ║
║ ├── RNNs / LSTMs ║
║ │ └── Orthogonal for hidden-to-hidden weights ║
║ │ ║
║ ├── GANs ║
║ │ └── N(0, 0.02) often works well ║
║ │ ║
║ └── Residual Networks without BatchNorm ║
║ └── Fixup initialization ║
║ ║
╠════════════════════════════════════════════════════════════════╣
║ ║
║ Bias initialization: ║
║ • Usually initialize to 0 ║
║ • For ReLU, small positive (0.01) can help ║
║ • For LSTM forget gate, initialize to 1 ║
║ ║
╚════════════════════════════════════════════════════════════════╝
""")
initialization_decision_tree()
Complete Initialization Function
def initialize_model(model, init_type='auto'):
"""
Comprehensive weight initialization.
Args:
model: PyTorch model
init_type: 'auto', 'he', 'xavier', 'orthogonal'
"""
def get_activation(module):
"""Detect activation function following this layer."""
# This is a simplified heuristic
return 'relu' # Default assumption
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
if init_type == 'auto':
                # PyTorch's default Linear init (kaiming_uniform_ with a=√5); a reasonable general default
                nn.init.kaiming_uniform_(module.weight, a=np.sqrt(5))
elif init_type == 'he':
nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
elif init_type == 'xavier':
nn.init.xavier_uniform_(module.weight)
elif init_type == 'orthogonal':
nn.init.orthogonal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv2d):
if init_type in ['auto', 'he']:
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
elif init_type == 'xavier':
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=0.02)
elif isinstance(module, nn.LSTM):
for param_name, param in module.named_parameters():
if 'weight_ih' in param_name:
nn.init.xavier_uniform_(param)
elif 'weight_hh' in param_name:
nn.init.orthogonal_(param)
elif 'bias' in param_name:
nn.init.zeros_(param)
# Set forget gate bias to 1
n = param.size(0)
param.data[n//4:n//2].fill_(1.0)
print(f"Model initialized with '{init_type}' strategy")
return model
# Example usage
model = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, 10)
)
model = initialize_model(model, init_type='he')
Diagnosing Initialization Problems
def diagnose_initialization(model, sample_input):
"""
Diagnose if initialization is causing issues.
"""
print("Initialization Diagnostics")
print("="*60)
model.eval()
# Track statistics through layers
activations = {}
gradients = {}
def save_activation(name):
def hook(model, input, output):
activations[name] = output.detach()
return hook
def save_gradient(name):
def hook(model, grad_input, grad_output):
if grad_output[0] is not None:
gradients[name] = grad_output[0].detach()
return hook
# Register hooks
handles = []
for name, module in model.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
handles.append(module.register_forward_hook(save_activation(name)))
handles.append(module.register_full_backward_hook(save_gradient(name)))
# Forward pass
x = sample_input.requires_grad_(True)
output = model(x)
# Backward pass
loss = output.sum()
loss.backward()
# Analyze
print("\nActivation Statistics:")
print("-" * 60)
print(f"{'Layer':<30} {'Mean':>10} {'Std':>10} {'%Dead':>10}")
print("-" * 60)
issues = []
for name, act in activations.items():
mean = act.mean().item()
std = act.std().item()
dead_fraction = (act == 0).float().mean().item() * 100
print(f"{name:<30} {mean:>10.4f} {std:>10.4f} {dead_fraction:>9.1f}%")
if std < 0.01:
issues.append(f" ⚠ {name}: Very small std (vanishing activations)")
if std > 10:
issues.append(f" ⚠ {name}: Very large std (exploding activations)")
if dead_fraction > 50:
issues.append(f" ⚠ {name}: >50% dead neurons")
print("\nGradient Statistics:")
print("-" * 60)
print(f"{'Layer':<30} {'Mean':>10} {'Std':>10}")
print("-" * 60)
for name, grad in gradients.items():
mean = grad.mean().item()
std = grad.std().item()
print(f"{name:<30} {mean:>10.6f} {std:>10.6f}")
if std < 1e-6:
issues.append(f" ⚠ {name}: Very small gradient std (vanishing)")
if std > 100:
issues.append(f" ⚠ {name}: Very large gradient std (exploding)")
# Clean up hooks
for handle in handles:
handle.remove()
print("\n" + "="*60)
if issues:
print("ISSUES DETECTED:")
for issue in issues:
print(issue)
else:
print("✓ Initialization looks healthy!")
return activations, gradients
# Example usage
model = nn.Sequential(
nn.Linear(784, 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 10)
)
sample = torch.randn(32, 784)
diagnose_initialization(model, sample)
Exercises
Exercise 1: Compare Initialization Methods
Train the same network with different initialization methods and compare:
- Training curves
- Final accuracy
- Time to converge
def exercise_1():
init_methods = ['he', 'xavier', 'orthogonal', 'small_normal']
for method in init_methods:
model = create_model()
initialize_with(model, method)
history = train(model, epochs=20)
plot(history, label=method)
Exercise 2: Implement LSUV from Scratch
Implement Layer-Sequential Unit-Variance initialization:
def exercise_2():
def lsuv(model, data):
for layer in model.layers:
# Initialize orthogonally
orthogonal_init(layer)
# Iteratively scale to unit variance
for _ in range(max_iter):
output = forward_to_layer(model, data, layer)
variance = output.var()
if abs(variance - 1.0) < tol:
break
layer.weight /= sqrt(variance)
Exercise 3: Analyze Gradient Flow
For a 100-layer network, visualize how gradients flow backward:
def exercise_3():
    model = DeepNetwork(num_layers=100)
for init_method in ['he', 'xavier', 'orthogonal']:
initialize(model, init_method)
# Forward and backward pass
output = model(x)
loss = criterion(output, y)
loss.backward()
# Plot gradient magnitudes per layer
gradient_norms = [layer.weight.grad.norm() for layer in model.layers]
plot(gradient_norms, label=init_method)