Normalization Techniques
Why Normalize?
Deep networks suffer from internal covariate shift — the distribution of inputs to each layer changes during training as weights update.
This causes:
- Slower training (need smaller learning rates)
- Difficulty with saturating activations
- Careful initialization requirements
Normalization stabilizes these distributions, enabling:
- Higher learning rates
- Faster convergence
- Reduced sensitivity to initialization
Batch Normalization
Normalize each feature across the batch dimension (for convolutional inputs, across the batch and spatial dimensions, per channel):

$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} \cdot \gamma + \beta$$

where $\mu_B$ and $\sigma_B^2$ are the batch mean and variance, and $\gamma$, $\beta$ are learnable parameters.
import torch
import torch.nn as nn

class BatchNorm(nn.Module):
    """Batch Normalization from scratch, for 4D inputs of shape (N, C, H, W)."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        # Learnable scale and shift
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # Running statistics (not learnable, but saved in the state dict)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        if self.training:
            # Compute batch statistics per channel: mean/variance over batch, H, W
            mean = x.mean(dim=(0, 2, 3))
            var = x.var(dim=(0, 2, 3), unbiased=False)
            # Update running statistics outside the autograd graph
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var
        # Normalize
        x = (x - mean.view(1, -1, 1, 1)) / torch.sqrt(var.view(1, -1, 1, 1) + self.eps)
        # Scale and shift
        return self.gamma.view(1, -1, 1, 1) * x + self.beta.view(1, -1, 1, 1)
BatchNorm behaves differently at train and eval time: always call `model.eval()` before inference so the running statistics are used instead of batch statistics.
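If the from-scratch layer above is correct, it should agree with PyTorch's built-in `nn.BatchNorm2d` in training mode (same `eps` and `momentum`). A minimal sanity-check sketch, with an arbitrary input shape:

# Sanity check: compare against the built-in nn.BatchNorm2d (same eps and momentum).
torch.manual_seed(0)
x = torch.randn(8, 3, 16, 16)              # (N, C, H, W)

ours = BatchNorm(3)
ref = nn.BatchNorm2d(3, eps=1e-5, momentum=0.1)

ours.train()
ref.train()
print(torch.allclose(ours(x), ref(x), atol=1e-5))   # expected: True

Note that `nn.BatchNorm2d` tracks `running_var` with the unbiased variance estimate, so eval-mode outputs of the two layers can differ very slightly.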
Layer Normalization
Normalize across the feature dimension (independent of batch):
$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta$$

where $\mu$ and $\sigma^2$ are computed per sample across all features.
class LayerNorm(nn.Module):
    """Layer Normalization."""

    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(normalized_shape))
        self.beta = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta
Used in: Transformers and RNNs (works with any batch size, including a single sample, since statistics are computed per sample)
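As with BatchNorm, a quick sketch (shapes are illustrative) checks this against the built-in `nn.LayerNorm` on a transformer-style activation:

# Sketch: compare against nn.LayerNorm on a (batch, seq_len, d_model) activation.
x = torch.randn(4, 10, 512)

ln = LayerNorm(512)
ref = nn.LayerNorm(512, eps=1e-5)
print(torch.allclose(ln(x), ref(x), atol=1e-5))   # expected: True

# Works even with a single sample
print(ln(torch.randn(1, 1, 512)).shape)           # torch.Size([1, 1, 512])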
Comparison of Normalization Types
| Type | Normalize Over | Best For |
|---|---|---|
| Batch Norm | Batch, H, W | CNNs with large batches |
| Layer Norm | C, H, W (per sample) | Transformers, RNNs |
| Instance Norm | H, W (per channel) | Style transfer |
| Group Norm | Groups of channels | Small batches |
| RMSNorm | Features (no mean) | LLMs (faster) |
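To make the table concrete, here is a short sketch (illustrative shapes) using PyTorch's built-in layers; each one leaves the activation shape unchanged and differs only in which dimensions the statistics are computed over:

# The built-in layers corresponding to the table, applied to one (N, C, H, W) activation.
x = torch.randn(8, 32, 14, 14)

bn = nn.BatchNorm2d(32)             # stats over (N, H, W), per channel
ln = nn.LayerNorm([32, 14, 14])     # stats over (C, H, W), per sample
inorm = nn.InstanceNorm2d(32)       # stats over (H, W), per sample and channel
gn = nn.GroupNorm(8, 32)            # stats over (H, W) and groups of 4 channels

for layer in (bn, ln, inorm, gn):
    print(type(layer).__name__, layer(x).shape)   # torch.Size([8, 32, 14, 14])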
RMSNorm (Modern LLMs)
$$\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot \gamma, \qquad \text{RMS}(x) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} x_i^2}$$
class RMSNorm(nn.Module):
    """Root Mean Square Normalization (used in LLaMA, etc.)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # No mean subtraction: rescale by the root mean square only
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight
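A short usage sketch (the dimension is chosen arbitrarily); note that, unlike LayerNorm, the output is rescaled but not re-centered:

# Sketch: RMSNorm on a transformer-style activation.
x = torch.randn(2, 16, 4096)
norm = RMSNorm(4096)
y = norm(x)
print(y.shape)                     # torch.Size([2, 16, 4096])

# No mean subtraction, so the per-token mean is generally not zero
print(y.mean(dim=-1).abs().max())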
When to Use What
| Scenario | Recommendation |
|---|---|
| CNN with batch ≥ 32 | Batch Norm |
| CNN with small batch | Group Norm |
| Transformer/Attention | Layer Norm or RMSNorm |
| RNN/LSTM | Layer Norm |
| Style Transfer | Instance Norm |
| Modern LLM | RMSNorm |
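One way to encode these recommendations is a small factory helper; the function below is a hypothetical sketch (its name and argument choices are illustrative, not a library API):

# Hypothetical helper that maps the table above to concrete layers.
def make_norm(kind: str, num_channels: int, num_groups: int = 32) -> nn.Module:
    if kind == "batch":        # CNN with batch >= 32
        return nn.BatchNorm2d(num_channels)
    if kind == "group":        # CNN with small batches; num_groups must divide num_channels
        return nn.GroupNorm(num_groups, num_channels)
    if kind == "layer":        # Transformers, RNNs
        return nn.LayerNorm(num_channels)
    if kind == "instance":     # Style transfer
        return nn.InstanceNorm2d(num_channels)
    if kind == "rms":          # Modern LLMs (RMSNorm defined above)
        return RMSNorm(num_channels)
    raise ValueError(f"Unknown normalization kind: {kind}")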
Exercises
Exercise 1: BatchNorm Analysis
Train the same CNN with and without BatchNorm. Compare learning curves, final accuracy, and sensitivity to learning rate.
Exercise 2: Small Batch Study
Compare BatchNorm vs GroupNorm with batch sizes of 2, 4, 8, 16, 32.
Exercise 3: Pre-Norm vs Post-Norm
In transformers, compare placing LayerNorm before vs after attention/FFN blocks.
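For reference, a minimal sketch of the two placements, assuming `attn` and `ffn` are any drop-in sub-layers mapping (batch, seq, d_model) to the same shape (the class names here are illustrative):

# Exercise 3 starting point: pre-norm vs post-norm residual blocks.
class PostNormBlock(nn.Module):
    """Original Transformer style: normalize after the residual addition."""
    def __init__(self, d_model, attn, ffn):
        super().__init__()
        self.attn, self.ffn = attn, ffn
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = self.norm1(x + self.attn(x))
        return self.norm2(x + self.ffn(x))

class PreNormBlock(nn.Module):
    """Pre-norm style: normalize the input of each sub-layer."""
    def __init__(self, d_model, attn, ffn):
        super().__init__()
        self.attn, self.ffn = attn, ffn
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        return x + self.ffn(self.norm2(x))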
What’s Next