Normalization Techniques

Why Normalize?

Deep networks suffer from internal covariate shift: the distribution of inputs to each layer changes during training as the weights update. This causes:
  • Slower training (smaller learning rates are needed)
  • Difficulty with saturating activations
  • A need for careful initialization
Normalization stabilizes these distributions (see the sketch after this list), enabling:
  • Higher learning rates
  • Faster convergence
  • Reduced sensitivity to initialization
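
To see this concretely, here is a minimal sketch (not from the original text) that pushes one batch through a small untrained MLP and prints the standard deviation of each hidden activation, with and without BatchNorm. With the default initialization, the unnormalized activations tend to drift in scale from layer to layer, while the normalized ones stay at a roughly constant scale.

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(64, 256)

def hidden_stds(use_norm):
    # Five linear layers; optionally insert BatchNorm1d before each ReLU.
    h = x
    stds = []
    for _ in range(5):
        h = nn.Linear(256, 256)(h)
        if use_norm:
            h = nn.BatchNorm1d(256)(h)
        h = torch.relu(h)
        stds.append(round(h.std().item(), 3))
    return stds

print("without norm:", hidden_stds(False))
print("with norm:   ", hidden_stds(True))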

Batch Normalization

Normalize across the batch dimension:

$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} \cdot \gamma + \beta$$

where $\mu_B$ and $\sigma_B^2$ are the batch mean and variance, and $\gamma$, $\beta$ are learnable parameters.
import torch
import torch.nn as nn

class BatchNorm(nn.Module):
    """Batch Normalization from scratch."""
    
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        
        # Learnable parameters
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        
        # Running statistics (not learnable)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
    
    def forward(self, x):
        if self.training:
            # Compute batch statistics
            mean = x.mean(dim=(0, 2, 3))  # Mean over batch, H, W
            var = x.var(dim=(0, 2, 3), unbiased=False)
            
            # Update running statistics, kept out of the autograd graph.
            # (PyTorch's built-in BatchNorm uses the unbiased variance here;
            # the biased estimate is kept for simplicity.)
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var
        
        # Normalize
        x = (x - mean.view(1, -1, 1, 1)) / torch.sqrt(var.view(1, -1, 1, 1) + self.eps)
        
        # Scale and shift
        return self.gamma.view(1, -1, 1, 1) * x + self.beta.view(1, -1, 1, 1)
BatchNorm behaves differently at train vs eval! Always call model.eval() before inference, so the running statistics are used instead of per-batch statistics.
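
As a quick sanity check (a sketch, not from the original text), the from-scratch module can be compared against PyTorch's built-in nn.BatchNorm2d. In training mode both normalize with the current batch statistics and start from γ = 1, β = 0, so their outputs should agree:

torch.manual_seed(0)
x = torch.randn(8, 3, 16, 16)              # (N, C, H, W)
custom = BatchNorm(3)                      # the class defined above
reference = nn.BatchNorm2d(3)              # same defaults: eps=1e-5, momentum=0.1
print(torch.allclose(custom(x), reference(x), atol=1e-5))  # expected: True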

Layer Normalization

Normalize across the feature dimension (independently of the batch):

$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta$$

where $\mu$ and $\sigma^2$ are computed per sample across the features.
class LayerNorm(nn.Module):
    """Layer Normalization."""
    
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(normalized_shape))
        self.beta = nn.Parameter(torch.zeros(normalized_shape))
    
    def forward(self, x):
        # Statistics are computed per sample over the last (feature) dimension
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta
Used in Transformers and RNNs; because the statistics are computed per sample, it works with any batch size, including batch size 1 at inference.
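
The same kind of check as above (again a sketch) works against the built-in nn.LayerNorm, since both normalize each sample over the last dimension:

torch.manual_seed(0)
x = torch.randn(4, 10, 64)                 # (batch, sequence, features)
custom = LayerNorm(64)                     # the class defined above
reference = nn.LayerNorm(64)               # same default eps=1e-5
print(torch.allclose(custom(x), reference(x), atol=1e-5))  # expected: True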

Comparison of Normalization Types

Type          | Normalize Over                 | Best For
Batch Norm    | Batch, H, W                    | CNNs with large batches
Layer Norm    | C, H, W (per sample)           | Transformers, RNNs
Instance Norm | H, W (per channel)             | Style transfer
Group Norm    | Groups of channels             | Small batches
RMSNorm       | Features (no mean subtraction) | LLMs (faster)
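
To make the "Normalize Over" column concrete, here is a short illustrative sketch using PyTorch's built-in layers on a single (N, C, H, W) tensor. The output shapes are unchanged; only the axes over which statistics are computed differ:

x = torch.randn(8, 16, 32, 32)               # (batch, channels, height, width)
print(nn.BatchNorm2d(16)(x).shape)           # stats over (N, H, W), one set per channel
print(nn.LayerNorm([16, 32, 32])(x).shape)   # stats over (C, H, W), one set per sample
print(nn.InstanceNorm2d(16)(x).shape)        # stats over (H, W), per sample and channel
print(nn.GroupNorm(4, 16)(x).shape)          # stats over groups of channels (4 groups here)
# Each line prints torch.Size([8, 16, 32, 32])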

RMSNorm (Modern LLMs)

RMSNorm simplifies LayerNorm by dropping the mean subtraction and the bias term, normalizing each vector only by its root mean square (in practice a small $\epsilon$ is added under the square root, as in the code below):

$$\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot \gamma, \qquad \text{RMS}(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2}$$
class RMSNorm(nn.Module):
    """Root Mean Square Normalization (used in LLaMA, etc.)."""
    
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    
    def forward(self, x):
        # Root mean square over the last (feature) dimension; no mean subtraction
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight
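
A small usage sketch: with γ at its initial value of 1, each feature vector has a root mean square close to 1 after RMSNorm:

torch.manual_seed(0)
x = torch.randn(2, 10, 512)                  # (batch, sequence, model dimension)
norm = RMSNorm(512)
y = norm(x)
print(y.pow(2).mean(dim=-1).sqrt()[0, :3])   # values close to 1.0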

When to Use What

Scenario               | Recommendation
CNN with batch ≥ 32    | Batch Norm
CNN with small batch   | Group Norm
Transformer/Attention  | Layer Norm or RMSNorm
RNN/LSTM               | Layer Norm
Style transfer         | Instance Norm
Modern LLM             | RMSNorm
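
The table can be folded into a small factory helper. This is an illustrative sketch, not from the original text; the function name make_norm and the choice of 8 groups are assumptions made here:

def make_norm(kind, num_features):
    """Map a scenario from the table above onto a normalization module (illustrative)."""
    if kind == "batch":                      # CNN with large batches
        return nn.BatchNorm2d(num_features)
    if kind == "group":                      # CNN with small batches; 8 groups is an arbitrary choice
        return nn.GroupNorm(8, num_features)
    if kind == "instance":                   # style transfer
        return nn.InstanceNorm2d(num_features)
    if kind == "layer":                      # Transformer / RNN (num_features = last dimension)
        return nn.LayerNorm(num_features)
    if kind == "rms":                        # modern LLMs, reusing the RMSNorm class above
        return RMSNorm(num_features)
    raise ValueError(f"unknown normalization kind: {kind}")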

Exercises

  1. Train the same CNN with and without BatchNorm. Compare learning curves, final accuracy, and sensitivity to learning rate.
  2. Compare BatchNorm vs. GroupNorm at batch sizes of 2, 4, 8, 16, and 32.
  3. In transformers, compare placing LayerNorm before (pre-norm) vs. after (post-norm) the attention and FFN blocks.
