
Transformers: Attention Is All You Need

The Architecture That Changed Everything

In 2017, the paper “Attention Is All You Need” introduced the Transformer, a model that:
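  • Dispensed with recurrence (and convolutions) entirely: information moves between positions only through attention
  • Enabled massive parallelization: all positions are processed at once during training, so training became much faster
  • Captured long-range dependencies directly: any two positions are connected by a single attention step, with no recurrent bottleneck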
Today, Transformers power virtually all state-of-the-art NLP models: GPT, BERT, T5, LLaMA, and many more.
The Core Insight: Why use recurrence at all? Self-attention can capture dependencies between any positions in a sequence, regardless of distance. Combined with positional encoding, we get all the benefits of sequence modeling without the sequential bottleneck.
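To make the "any two positions" claim concrete, here is a minimal sketch of scaled dot-product self-attention, $\text{softmax}(QK^\top/\sqrt{d_k})V$, on random data. The projection matrices are random stand-ins for learned weights (an assumption for illustration only). The point is the shape of the weight matrix: it is seq_len × seq_len, so every position gets a direct weight on every other position in one step.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_k = 6, 8
x = torch.randn(seq_len, d_k)  # one sequence of 6 token vectors

# Random stand-ins for the learned projections (scaled to keep scores moderate)
W_q, W_k, W_v = (torch.randn(d_k, d_k) / math.sqrt(d_k) for _ in range(3))
Q, K, V = x @ W_q, x @ W_k, x @ W_v

# softmax(Q K^T / sqrt(d_k)): one weight for every (query, key) pair of positions
weights = F.softmax(Q @ K.T / math.sqrt(d_k), dim=-1)  # (seq_len, seq_len)
context = weights @ V                                   # (seq_len, d_k)

print(weights.shape)  # torch.Size([6, 6])
# Row 0 mixes information from every position, including the last one,
# in a single step; no recurrence over the positions in between is needed
print(weights[0])
```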
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math

def transformer_overview():
    """
    The Transformer architecture at a glance.
    
    Two main components:
    1. ENCODER: Processes the input sequence
       - Self-attention + Feed-forward
       - Stacked N times
       
    2. DECODER: Generates the output sequence
       - Masked self-attention (causal)
       - Cross-attention to encoder
       - Feed-forward
       - Stacked N times
    """
    
    print("Transformer Architecture Overview")
    print("=" * 60)
    print()
    print("ENCODER (processes input):")
    print("  Input → Embedding + Positional Encoding")
    print("  → [Self-Attention → Add & Norm → FFN → Add & Norm] × N")
    print("  → Encoder Output")
    print()
    print("DECODER (generates output):")
    print("  Output (shifted) → Embedding + Positional Encoding")
    print("  → [Masked Self-Attn → Add & Norm")
    print("     → Cross-Attn (to Encoder) → Add & Norm")
    print("     → FFN → Add & Norm] × N")
    print("  → Linear → Softmax → Predictions")
    print()
    print("Key innovations:")
    print("  • Multi-head self-attention (parallelizable)")
    print("  • Layer normalization for stability")
    print("  • Residual connections for gradient flow")
    print("  • Positional encoding for sequence order")

transformer_overview()
Figure: Complete Transformer Architecture
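For comparison, PyTorch ships a reference implementation of this encoder-decoder stack as `torch.nn.Transformer`. The sketch below instantiates it with the base-model hyperparameters from the paper. Note that it expects already-embedded inputs (embedding and positional encoding are left to you), and `batch_first=True` is a choice made here so shapes read (batch, seq_len, d_model).

```python
import torch
import torch.nn as nn

# Base configuration from the paper: d_model=512, 8 heads, 6+6 layers, d_ff=2048
model = nn.Transformer(
    d_model=512, nhead=8,
    num_encoder_layers=6, num_decoder_layers=6,
    dim_feedforward=2048, dropout=0.1,
    batch_first=True,
)

# nn.Transformer works on embeddings, not token ids
src = torch.randn(2, 30, 512)  # (batch, src_len, d_model)
tgt = torch.randn(2, 25, 512)  # (batch, tgt_len, d_model)

# Float mask with -inf above the diagonal: position i cannot attend to j > i
tgt_mask = model.generate_square_subsequent_mask(25)

out = model(src, tgt, tgt_mask=tgt_mask)
print(out.shape)  # torch.Size([2, 25, 512])
```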

Building Blocks

Multi-Head Attention (Revisited)

class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention mechanism.
    
    Allows the model to jointly attend to information from
    different representation subspaces at different positions.
    """
    
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)
    
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        # Linear projections and reshape for multi-head
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention = self.dropout(F.softmax(scores, dim=-1))
        
        # Apply attention to values
        context = torch.matmul(attention, V)
        
        # Reshape and project
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(context)
        
        return output, attention
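The `view`/`transpose` calls in `forward` only rearrange numbers: each head gets a contiguous `d_k`-sized slice of the `d_model` features, and merging the heads back restores the original layout exactly. A quick check on random data, with shapes chosen to match the base model:

```python
import torch

batch, seq_len, d_model, num_heads = 2, 10, 512, 8
d_k = d_model // num_heads  # 64 features per head

x = torch.randn(batch, seq_len, d_model)

# Split: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
heads = x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)

# Head h sees features [h * d_k, (h + 1) * d_k) of every token
assert torch.equal(heads[:, 3], x[..., 3 * d_k:4 * d_k])

# Merge: back to (batch, seq_len, d_model); contiguous() is needed before view
merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, d_model)

print(heads.shape)             # torch.Size([2, 8, 10, 64])
print(torch.equal(x, merged))  # True
```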

Position-wise Feed-Forward Network

A simple two-layer MLP applied to each position independently:

$$\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$$
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise Feed-Forward Network.
    
    Applied independently to each position (token).
    Acts as a nonlinear transformation of the attention output.
    
    Typically expands dimension by 4x, then projects back.
    """
    
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, d_model)
        
        Returns:
            output: (batch, seq_len, d_model)
        """
        # Expand, apply non-linearity, project back
        x = self.linear1(x)
        x = F.relu(x)  # Original paper uses ReLU
        x = self.dropout(x)
        x = self.linear2(x)
        return x


# Modern variant: GELU activation (used in BERT, GPT)
class PositionwiseFeedForwardGELU(nn.Module):
    """FFN with GELU activation (used in modern transformers)."""
    
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        x = F.gelu(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        return x
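"Position-wise" means the same two-layer MLP is applied to every token, with no interaction between positions. The sketch below builds the FFN formula directly with `nn.Sequential` (a stand-in for `PositionwiseFeedForward`, with dropout omitted so the check is deterministic) and confirms that running it on the whole sequence matches running it one token at a time.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 16, 64  # d_ff = 4 * d_model, as in the base model

# FFN(x) = max(0, x W1 + b1) W2 + b2; dropout omitted for a deterministic check
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.randn(2, 5, d_model)  # (batch, seq_len, d_model)
whole = ffn(x)

# Apply the same MLP to each position separately and restack along seq_len
per_token = torch.stack([ffn(x[:, t]) for t in range(x.size(1))], dim=1)

print(torch.allclose(whole, per_token, atol=1e-6))  # True
```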

Layer Normalization

class LayerNorm(nn.Module):
    """
    Layer Normalization.
    
    Normalizes across the feature dimension (not the batch), so it does not
    depend on batch statistics and behaves the same at training and inference
    time. This makes it better suited than batch norm to variable-length sequences.
    
    LayerNorm(x) = γ * (x - μ) / sqrt(σ² + ε) + β,  where σ² is the biased variance
    """
    
    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps
    
    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)  # biased, like nn.LayerNorm
        return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta


# Note: PyTorch's built-in nn.LayerNorm computes the same thing (eps defaults to 1e-5)
layer_norm = nn.LayerNorm(512)
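It is worth being precise about what `nn.LayerNorm` computes: it uses the biased variance (dividing by d_model, not d_model - 1) and adds ε inside the square root, $\gamma \odot (x-\mu)/\sqrt{\sigma^2+\epsilon} + \beta$. A small numerical check against `torch.nn.functional.layer_norm` with its default `eps=1e-5`, plus a variant using the unbiased `x.std` with ε outside the root for contrast:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model, eps = 512, 1e-5
x = torch.randn(4, 10, d_model)
gamma = torch.randn(d_model)
beta = torch.randn(d_model)

# Normalize over the last (feature) dimension only
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)  # biased variance
manual = gamma * (x - mean) / torch.sqrt(var + eps) + beta

builtin = F.layer_norm(x, (d_model,), weight=gamma, bias=beta, eps=eps)
print(torch.allclose(manual, builtin, atol=1e-5))  # True

# Unbiased std with eps outside the root: close, but not the same function
std_version = gamma * (x - mean) / (x.std(dim=-1, keepdim=True) + eps) + beta
print(torch.allclose(std_version, builtin, atol=1e-5))
```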

Positional Encoding

class PositionalEncoding(nn.Module):
    """
    Sinusoidal Positional Encoding.
    
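    Adds position information using sin/cos functions at different frequencies.
    Because PE(pos + k) is a linear function of PE(pos) for any fixed offset k,
    the paper hypothesized this makes it easy to attend by relative position.
    Requires an even d_model (sin and cos fill alternating dimensions).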
    """
    
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        
        self.dropout = nn.Dropout(dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """Add positional encoding to input embeddings."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
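The encoding above implements $PE_{(pos,2i)} = \sin(pos/10000^{2i/d_{model}})$ and $PE_{(pos,2i+1)} = \cos(pos/10000^{2i/d_{model}})$. Each (sin, cos) pair is a point rotating at its own frequency, so $PE_{pos+k}$ is a fixed rotation of $PE_{pos}$ that depends only on the offset $k$. A small sketch that rebuilds the same matrix (the helper `sinusoidal_pe` is introduced here for illustration) and checks the rotation identity:

```python
import math
import torch

def sinusoidal_pe(max_len, d_model):
    """Same construction as PositionalEncoding, returned with its frequencies."""
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe, div_term

pe, freqs = sinusoidal_pe(max_len=100, d_model=16)
pos, k = 10, 7
sin_p, cos_p = pe[pos, 0::2], pe[pos, 1::2]

# Rotate every (sin, cos) pair by the angle k * freq, which depends only on k
angle = k * freqs
sin_shift = sin_p * torch.cos(angle) + cos_p * torch.sin(angle)  # sin(a + b)
cos_shift = cos_p * torch.cos(angle) - sin_p * torch.sin(angle)  # cos(a + b)

print(torch.allclose(sin_shift, pe[pos + k, 0::2], atol=1e-5))  # True
print(torch.allclose(cos_shift, pe[pos + k, 1::2], atol=1e-5))  # True
```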

The Encoder

class EncoderLayer(nn.Module):
    """
    Single Transformer Encoder Layer.
    
    Structure:
    x → Self-Attention → Add & Norm → FFN → Add & Norm → output
        └───────────────────┘         └─────────────────┘
              (residual)                   (residual)
    """
    
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        """
        Args:
            x: Input (batch, seq_len, d_model)
            mask: Attention mask (batch, 1, 1, seq_len) or (batch, 1, seq_len, seq_len)
        """
        # Self-attention with residual connection
        attn_output, _ = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        return x


class TransformerEncoder(nn.Module):
    """
    Stack of N Encoder Layers.
    """
    
    def __init__(self, vocab_size, d_model, num_heads, d_ff, 
                 num_layers, max_len=5000, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)
        
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, src, src_mask=None):
        """
        Args:
            src: Source token indices (batch, src_len)
            src_mask: Mask for padding (batch, 1, 1, src_len)
        
        Returns:
            encoder_output: (batch, src_len, d_model)
        """
        # Embed and add positional encoding
        x = self.embedding(src) * math.sqrt(self.d_model)
        x = self.positional_encoding(x)
        
        # Pass through encoder layers
        for layer in self.layers:
            x = layer(x, src_mask)
        
        return self.norm(x)


# Test encoder
encoder = TransformerEncoder(
    vocab_size=10000,
    d_model=512,
    num_heads=8,
    d_ff=2048,
    num_layers=6
)

src = torch.randint(0, 10000, (2, 30))
output = encoder(src)

print(f"Input: {src.shape}")
print(f"Encoder output: {output.shape}")
print(f"Encoder parameters: {sum(p.numel() for p in encoder.parameters()):,}")
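The encoder test above runs without a mask, but real batches contain padding. A minimal sketch of what a `(batch, 1, 1, src_len)` padding mask does, assuming pad index 0 as in `Transformer.make_src_mask`: it broadcasts over heads and query positions, and after `masked_fill(-inf)` and softmax every query puts exactly zero weight on padded keys. The attention scores here are random stand-ins.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pad_idx, num_heads = 0, 8
src = torch.tensor([[5, 7, 9, 0, 0],
                    [3, 4, 6, 8, 2]])  # first sequence ends with 2 pad tokens
batch, src_len = src.shape

# (batch, src_len) -> (batch, 1, 1, src_len): one mask row shared by all heads and queries
src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)

scores = torch.randn(batch, num_heads, src_len, src_len)  # stand-in attention scores
scores = scores.masked_fill(src_mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)

print(src_mask.shape)                     # torch.Size([2, 1, 1, 5])
print(weights[0, :, :, 3:].abs().max())   # tensor(0.): padded keys get no weight
```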

The Decoder

class DecoderLayer(nn.Module):
    """
    Single Transformer Decoder Layer.
    
    Structure:
    x → Masked Self-Attention → Add & Norm 
      → Cross-Attention (to encoder) → Add & Norm 
      → FFN → Add & Norm → output
    """
    
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        
        # Masked self-attention (causal)
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        
        # Cross-attention to encoder output
        self.cross_attention = MultiHeadAttention(d_model, num_heads, dropout)
        
        # Feed-forward network
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        
        # Layer normalizations
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: Decoder input (batch, tgt_len, d_model)
            encoder_output: Encoder output (batch, src_len, d_model)
            src_mask: Source padding mask
            tgt_mask: Target causal mask
        """
        # Masked self-attention
        self_attn_output, _ = self.self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attn_output))
        
        # Cross-attention to encoder
        cross_attn_output, attention_weights = self.cross_attention(
            x, encoder_output, encoder_output, src_mask
        )
        x = self.norm2(x + self.dropout(cross_attn_output))
        
        # Feed-forward
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        
        return x, attention_weights


class TransformerDecoder(nn.Module):
    """
    Stack of N Decoder Layers.
    """
    
    def __init__(self, vocab_size, d_model, num_heads, d_ff,
                 num_layers, max_len=5000, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)
        
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.norm = nn.LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, vocab_size)
    
    def forward(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            tgt: Target token indices (batch, tgt_len)
            encoder_output: (batch, src_len, d_model)
            src_mask: Source padding mask
            tgt_mask: Target causal mask
        """
        # Embed and add positional encoding
        x = self.embedding(tgt) * math.sqrt(self.d_model)
        x = self.positional_encoding(x)
        
        # Pass through decoder layers
        attention_weights = None
        for layer in self.layers:
            x, attention_weights = layer(x, encoder_output, src_mask, tgt_mask)
        
        x = self.norm(x)
        
        # Project to vocabulary
        logits = self.output_projection(x)
        
        return logits, attention_weights
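The causal mask in the decoder's self-attention is what allows training on all target positions in parallel without "seeing the answer": the output at position t depends only on inputs at positions ≤ t. A self-contained single-head sketch (random projection matrices stand in for learned weights) checks this by rewriting the "future" positions and confirming that earlier outputs do not move:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d = 6, 16
# Random stand-ins for learned projections, scaled so attention stays smooth
W_q, W_k, W_v = (torch.randn(d, d) / d ** 0.5 for _ in range(3))
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

def causal_self_attention(x):
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = q @ k.transpose(-2, -1) / d ** 0.5
    scores = scores.masked_fill(~causal, float('-inf'))  # hide future positions
    return F.softmax(scores, dim=-1) @ v

x = torch.randn(1, seq_len, d)
x_future_changed = x.clone()
x_future_changed[:, 4:] = torch.randn(1, seq_len - 4, d)  # rewrite positions 4 and 5

out_a = causal_self_attention(x)
out_b = causal_self_attention(x_future_changed)

print(torch.allclose(out_a[:, :4], out_b[:, :4]))  # True: positions 0-3 never see 4-5
print(torch.allclose(out_a[:, 4:], out_b[:, 4:]))  # False: positions 4-5 did change
```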

The Complete Transformer

class Transformer(nn.Module):
    """
    Complete Transformer model for sequence-to-sequence tasks.
    
    This is the architecture from "Attention Is All You Need" (2017).
    """
    
    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        d_model=512,
        num_heads=8,
        d_ff=2048,
        num_encoder_layers=6,
        num_decoder_layers=6,
        max_len=5000,
        dropout=0.1,
        share_embeddings=False
    ):
        super().__init__()
        
        self.d_model = d_model
        
        # Encoder
        self.encoder = TransformerEncoder(
            src_vocab_size, d_model, num_heads, d_ff,
            num_encoder_layers, max_len, dropout
        )
        
        # Decoder
        self.decoder = TransformerDecoder(
            tgt_vocab_size, d_model, num_heads, d_ff,
            num_decoder_layers, max_len, dropout
        )
        
        # Optionally share embeddings between encoder and decoder
        if share_embeddings and src_vocab_size == tgt_vocab_size:
            self.decoder.embedding = self.encoder.embedding
        
        # Initialize parameters
        self._init_parameters()
    
    def _init_parameters(self):
        """Initialize parameters with Xavier uniform."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def make_src_mask(self, src, pad_idx=0):
        """Create source padding mask."""
        src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask
    
    def make_tgt_mask(self, tgt, pad_idx=0):
        """Create target mask (padding + causal)."""
        batch_size, tgt_len = tgt.shape
        
        # Padding mask
        pad_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(2)
        
        # Causal mask
        causal_mask = torch.tril(torch.ones(tgt_len, tgt_len, device=tgt.device)).bool()
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
        
        # Combine masks
        tgt_mask = pad_mask & causal_mask
        
        return tgt_mask
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        Forward pass through the Transformer.
        
        Args:
            src: Source sequence (batch, src_len)
            tgt: Target sequence (batch, tgt_len)
            src_mask: Source mask (optional, will be created if not provided)
            tgt_mask: Target mask (optional, will be created if not provided)
        
        Returns:
            logits: Output logits (batch, tgt_len, tgt_vocab_size)
            attention_weights: Cross-attention weights from last decoder layer
        """
        if src_mask is None:
            src_mask = self.make_src_mask(src)
        
        if tgt_mask is None:
            tgt_mask = self.make_tgt_mask(tgt)
        
        # Encode source
        encoder_output = self.encoder(src, src_mask)
        
        # Decode target
        logits, attention_weights = self.decoder(
            tgt, encoder_output, src_mask, tgt_mask
        )
        
        return logits, attention_weights
    
    @torch.no_grad()
    def generate(self, src, max_len=50, start_token=1, end_token=2, pad_idx=0):
        """
        Generate output sequences autoregressively with greedy decoding.
        
        Args:
            src: Source sequence (batch, src_len)
            max_len: Maximum number of tokens to generate
            start_token: Start of sequence token index
            end_token: End of sequence token index
            pad_idx: Padding index, written after a sequence has emitted <EOS>
        
        Returns:
            generated: Generated token indices (batch, gen_len), starting with <SOS>
        """
        self.eval()
        batch_size = src.size(0)
        device = src.device
        
        # Encode source once
        src_mask = self.make_src_mask(src, pad_idx)
        encoder_output = self.encoder(src, src_mask)
        
        # Start with <SOS> token
        generated = torch.full((batch_size, 1), start_token, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        
        for _ in range(max_len):
            tgt_mask = self.make_tgt_mask(generated, pad_idx)
            
            logits, _ = self.decoder(generated, encoder_output, src_mask, tgt_mask)
            
            # Greedy choice for the last position
            next_token = logits[:, -1, :].argmax(dim=-1)
            
            # Sequences that already produced <EOS> only receive padding
            next_token = next_token.masked_fill(finished, pad_idx)
            generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)
            
            # Stop once every sequence has produced <EOS>, not necessarily at the same step
            finished |= next_token == end_token
            if finished.all():
                break
        
        return generated


# Create a Transformer
transformer = Transformer(
    src_vocab_size=10000,
    tgt_vocab_size=8000,
    d_model=512,
    num_heads=8,
    d_ff=2048,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dropout=0.1
)

# Test forward pass
src = torch.randint(1, 10000, (4, 30))  # Batch of 4, source length 30
tgt = torch.randint(1, 8000, (4, 25))   # Target length 25

logits, attention = transformer(src, tgt)

print(f"Source: {src.shape}")
print(f"Target: {tgt.shape}")
print(f"Output logits: {logits.shape}")
print(f"Cross-attention: {attention.shape}")
print(f"\nTotal parameters: {sum(p.numel() for p in transformer.parameters()):,}")
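`make_tgt_mask` combines a padding mask (which key positions hold real tokens) with a lower-triangular causal mask using `&`. Reproducing the same lines on a tiny padded target, assuming pad index 0 and `<SOS>`=1 as in the defaults above, shows the combined pattern:

```python
import torch

pad_idx = 0
tgt = torch.tensor([[1, 5, 6, 0]])  # <SOS>, two tokens, one pad
tgt_len = tgt.size(1)

pad_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(2)         # (1, 1, 1, 4)
causal_mask = torch.tril(torch.ones(tgt_len, tgt_len)).bool()  # (4, 4)
tgt_mask = pad_mask & causal_mask.unsqueeze(0).unsqueeze(0)    # (1, 1, 4, 4)

print(tgt_mask[0, 0].long())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0]])
```

Row 3 is the padding position itself; it still sees positions 0 to 2. As long as the first target token (`<SOS>`) is never padding, no row is entirely masked, so the softmax never receives a row of all `-inf` values (which would produce NaNs).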

Training the Transformer

Label Smoothing

class LabelSmoothingLoss(nn.Module):
    """
    Label Smoothing Cross-Entropy Loss.
    
    Instead of hard targets (0 or 1), use soft targets:
    - True class: 1 - smoothing
    - Other classes: smoothing / (num_classes - 1)
    
    This discourages over-confident predictions; in the original paper it hurt
    perplexity but improved accuracy and BLEU score.
    """
    
    def __init__(self, vocab_size, smoothing=0.1, ignore_index=-100):
        super().__init__()
        
        self.vocab_size = vocab_size
        self.smoothing = smoothing
        self.ignore_index = ignore_index
        self.confidence = 1.0 - smoothing
    
    def forward(self, logits, target):
        """
        Args:
            logits: (batch, seq_len, vocab_size)
            target: (batch, seq_len)
        """
        logits = logits.reshape(-1, self.vocab_size)
        target = target.reshape(-1)
        
        # Create smoothed distribution
        smooth_target = torch.zeros_like(logits)
        smooth_target.fill_(self.smoothing / (self.vocab_size - 1))
        
        # Set confidence on true class
        mask = target != self.ignore_index
        indices = target[mask]
        smooth_target[mask] = smooth_target[mask].scatter(
            1, indices.unsqueeze(1), self.confidence
        )
        
        # Cross-entropy with soft targets (equals the KL divergence up to a constant)
        log_probs = F.log_softmax(logits, dim=-1)
        loss = -(smooth_target * log_probs).sum(dim=-1)
        
        return loss[mask].mean()
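`LabelSmoothingLoss` spreads ε over the K − 1 wrong classes. PyTorch's built-in `F.cross_entropy(..., label_smoothing=ε)` (available since PyTorch 1.10) instead mixes the one-hot target with a uniform distribution over all K classes, so the true class gets 1 − ε + ε/K. The two losses are therefore close but not identical. A small sketch comparing both target distributions on random logits:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
K, eps = 5, 0.1
logits = torch.randn(3, K)
target = torch.tensor([1, 4, 0])
log_probs = F.log_softmax(logits, dim=-1)

# Scheme used by LabelSmoothingLoss: true class 1 - eps, others eps / (K - 1)
q_custom = torch.full((3, K), eps / (K - 1))
q_custom.scatter_(1, target.unsqueeze(1), 1 - eps)

# Scheme used by F.cross_entropy(label_smoothing=eps): (1 - eps) * one_hot + eps / K
q_builtin = (1 - eps) * F.one_hot(target, K).float() + eps / K

loss_custom = -(q_custom * log_probs).sum(dim=-1).mean()
loss_manual_builtin = -(q_builtin * log_probs).sum(dim=-1).mean()
loss_builtin = F.cross_entropy(logits, target, label_smoothing=eps)

print(loss_custom.item(), loss_builtin.item())  # similar, not equal
```

For this tutorial's setup the built-in call would be `F.cross_entropy(logits.reshape(-1, vocab_size), target.reshape(-1), ignore_index=0, label_smoothing=0.1)`, assuming padding index 0.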

Learning Rate Schedule

The original Transformer uses a special learning rate schedule, linear warmup followed by inverse-square-root decay:

$$lr = d_{model}^{-0.5} \cdot \min(step^{-0.5}, step \cdot warmup\_steps^{-1.5})$$
class TransformerLRScheduler:
    """
    Learning rate scheduler from "Attention Is All You Need".
    
    Increases linearly during warmup, then decreases proportional to
    the inverse square root of the step number.
    """
    
    def __init__(self, optimizer, d_model, warmup_steps=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0
    
    def step(self):
        self.step_num += 1
        lr = self._get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr
    
    def _get_lr(self):
        step = max(1, self.step_num)
        return (self.d_model ** -0.5) * min(
            step ** -0.5,
            step * (self.warmup_steps ** -1.5)
        )


def visualize_lr_schedule():
    """Visualize the Transformer learning rate schedule."""
    
    # Dummy optimizer
    model = nn.Linear(512, 512)
    optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
    
    scheduler = TransformerLRScheduler(optimizer, d_model=512, warmup_steps=4000)
    
    lrs = []
    for _ in range(100000):
        lrs.append(scheduler.step())
    
    plt.figure(figsize=(10, 5))
    plt.plot(lrs)
    plt.xlabel('Step')
    plt.ylabel('Learning Rate')
    plt.title('Transformer Learning Rate Schedule')
    plt.axvline(x=4000, color='r', linestyle='--', label='Warmup ends')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

# visualize_lr_schedule()
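The same schedule can also be expressed with PyTorch's built-in `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's base learning rate by a user function of the step count. A sketch under two assumptions: the base lr is 1.0 (so the function's value is the learning rate itself), and the function shifts its argument by one, because `LambdaLR` evaluates it at step 0 when constructed. With the conventional order `optimizer.step()` then `scheduler.step()`, the first update then uses the step-1 rate. The loop records the rate each update uses and finds the peak:

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

d_model, warmup_steps = 512, 4000

def noam_lambda(step):
    # LambdaLR calls this with 0 at construction; shift by one so the first
    # optimizer update uses the schedule's value for step 1
    step = step + 1
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

model = nn.Linear(8, 8)
# Base lr of 1.0, so the lambda's value is the learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = LambdaLR(optimizer, lr_lambda=noam_lambda)

lrs = []
for _ in range(10000):
    lrs.append(optimizer.param_groups[0]['lr'])  # lr used by this update
    optimizer.step()  # in real training, loss.backward() comes first
    scheduler.step()

peak_step = max(range(len(lrs)), key=lrs.__getitem__) + 1
print(peak_step)                      # 4000
print(f"{lrs[peak_step - 1]:.3e}")    # 6.988e-04
```

At the peak, both branches of the `min` agree, giving lr = (d_model · warmup_steps)^-0.5 = (512 · 4000)^-0.5 ≈ 7.0e-4.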

Training Loop

def train_transformer():
    """Complete training example for Transformer."""
    
    # Hyperparameters
    src_vocab_size = 10000
    tgt_vocab_size = 8000
    d_model = 256
    num_heads = 8
    d_ff = 1024
    num_layers = 4
    dropout = 0.1
    warmup_steps = 4000
    epochs = 10
    batch_size = 32
    
    # Model
    model = Transformer(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        d_model=d_model,
        num_heads=num_heads,
        d_ff=d_ff,
        num_encoder_layers=num_layers,
        num_decoder_layers=num_layers,
        dropout=dropout
    )
    
    # Move to device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    # Optimizer
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=1.0,  # Will be overridden by scheduler
        betas=(0.9, 0.98),
        eps=1e-9
    )
    
    # Scheduler
    scheduler = TransformerLRScheduler(optimizer, d_model, warmup_steps)
    
    # Loss function
    criterion = LabelSmoothingLoss(tgt_vocab_size, smoothing=0.1, ignore_index=0)  # 0 = padding index used by the mask helpers
    
    # Training loop
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        num_batches = 100  # Simulated
        
        for batch_idx in range(num_batches):
            # Generate dummy data (replace with real data loader)
            src = torch.randint(1, src_vocab_size, (batch_size, 30)).to(device)
            tgt = torch.randint(1, tgt_vocab_size, (batch_size, 25)).to(device)
            
            # Shift target for teacher forcing
            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]
            
            # Forward pass
            logits, _ = model(src, tgt_input)
            
            # Compute loss
            loss = criterion(logits, tgt_output)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
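            # Set this step's learning rate before the update; otherwise the very
            # first update would use the placeholder lr=1.0 passed to Adam
            lr = scheduler.step()
            optimizer.step()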
            
            total_loss += loss.item()
            
            if batch_idx % 20 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}, LR: {lr:.6f}")
        
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1} complete. Average Loss: {avg_loss:.4f}")
    
    return model

# model = train_transformer()
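The `tgt[:, :-1]` / `tgt[:, 1:]` split in the loop is teacher forcing: the decoder reads the ground-truth prefix, and each position is trained to predict the next token. On a tiny hand-made batch (the token ids are hypothetical; `<SOS>`=1, `<EOS>`=2 and pad 0 follow the defaults of `generate` and the mask helpers):

```python
import torch

# Hypothetical ids: <SOS>=1, <EOS>=2, <PAD>=0
tgt = torch.tensor([[1, 11, 12, 13, 2],
                    [1, 21, 22, 2, 0]])

tgt_input = tgt[:, :-1]   # what the decoder reads
tgt_output = tgt[:, 1:]   # what it must predict at each position

print(tgt_input.tolist())   # [[1, 11, 12, 13], [1, 21, 22, 2]]
print(tgt_output.tolist())  # [[11, 12, 13, 2], [21, 22, 2, 0]]
```

The second row shows why padding must be excluded from the loss: its last target is a pad token, which is exactly what the loss's `ignore_index` is for.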

Encoder-Only: BERT

BERT uses only the Transformer encoder for bidirectional language understanding:
class BERTEncoder(nn.Module):
    """
    BERT-style encoder.
    
    Key differences from original Transformer encoder:
    - Learned positional embeddings (instead of sinusoidal)
    - [CLS] token for classification
    - Segment embeddings for sentence pairs
    - GELU activation in FFN
    """
    
    def __init__(
        self,
        vocab_size,
        d_model=768,
        num_heads=12,
        d_ff=3072,
        num_layers=12,
        max_len=512,
        dropout=0.1
    ):
        super().__init__()
        
        self.d_model = d_model
        
        # Embeddings
        self.word_embeddings = nn.Embedding(vocab_size, d_model)
        self.position_embeddings = nn.Embedding(max_len, d_model)
        self.segment_embeddings = nn.Embedding(2, d_model)  # For sentence pairs
        
        self.embedding_norm = nn.LayerNorm(d_model)
        self.embedding_dropout = nn.Dropout(dropout)
        
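        # Transformer layers; BERT uses GELU in the FFN, so EncoderLayer's
        # default ReLU feed-forward is swapped for the GELU variant
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            layer = EncoderLayer(d_model, num_heads, d_ff, dropout)
            layer.feed_forward = PositionwiseFeedForwardGELU(d_model, d_ff, dropout)
            self.layers.append(layer)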
        
        # Pooler for [CLS] token
        self.pooler = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Tanh()
        )
    
    def forward(self, input_ids, segment_ids=None, attention_mask=None):
        """
        Args:
            input_ids: Token indices (batch, seq_len)
            segment_ids: Segment indices for sentence pairs (batch, seq_len)
            attention_mask: Mask for padding (batch, seq_len)
        
        Returns:
            sequence_output: All token representations (batch, seq_len, d_model)
            pooled_output: [CLS] representation (batch, d_model)
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        
        # Create position indices
        position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        
        # Default segment ids
        if segment_ids is None:
            segment_ids = torch.zeros_like(input_ids)
        
        # Compute embeddings
        word_embeds = self.word_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)
        segment_embeds = self.segment_embeddings(segment_ids)
        
        embeddings = word_embeds + position_embeds + segment_embeds
        embeddings = self.embedding_norm(embeddings)
        embeddings = self.embedding_dropout(embeddings)
        
        # Create attention mask
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        
        # Pass through layers
        hidden_states = embeddings
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)
        
        # Pool [CLS] token
        cls_token = hidden_states[:, 0]
        pooled_output = self.pooler(cls_token)
        
        return hidden_states, pooled_output


# BERT configurations
bert_configs = {
    'bert-base': {'d_model': 768, 'num_heads': 12, 'd_ff': 3072, 'num_layers': 12},
    'bert-large': {'d_model': 1024, 'num_heads': 16, 'd_ff': 4096, 'num_layers': 24},
}

# Create BERT-base
bert = BERTEncoder(vocab_size=30522, **bert_configs['bert-base'])
print(f"BERT-base parameters: {sum(p.numel() for p in bert.parameters()):,}")
Figure: BERT Architecture
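`BERTEncoder.forward` expects `input_ids`, `segment_ids` and an `attention_mask` of 1s and 0s. For a sentence pair, the BERT layout is `[CLS] A [SEP] B [SEP]`: segment 0 covers `[CLS]`, sentence A and the first `[SEP]`, and segment 1 covers sentence B and the final `[SEP]`. A small sketch that builds all three; the helper `encode_pair` is hypothetical, and the special-token ids (101, 102, 0) are assumed from the bert-base-uncased vocabulary:

```python
import torch

# Assumed special-token ids (those of the bert-base-uncased vocabulary)
PAD, CLS, SEP = 0, 101, 102

def encode_pair(tokens_a, tokens_b, max_len):
    """Build [CLS] A [SEP] B [SEP] plus segment ids and a padding mask."""
    ids = [CLS] + tokens_a + [SEP] + tokens_b + [SEP]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    attention_mask = [1] * len(ids)
    pad = max_len - len(ids)
    return (torch.tensor(ids + [PAD] * pad),
            torch.tensor(segment_ids + [0] * pad),
            torch.tensor(attention_mask + [0] * pad))

# Hypothetical word-piece ids for two short sentences
input_ids, segment_ids, attention_mask = encode_pair([2023, 2003], [2009, 2001, 3835], max_len=10)
print(input_ids.tolist())       # [101, 2023, 2003, 102, 2009, 2001, 3835, 102, 0, 0]
print(segment_ids.tolist())     # [0, 0, 0, 0, 1, 1, 1, 1, 0, 0]
print(attention_mask.tolist())  # [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
```

To run these through the encoder above, add a batch dimension to each tensor (`.unsqueeze(0)`) and call `bert(input_ids, segment_ids, attention_mask)`.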

Decoder-Only: GPT

GPT uses only the Transformer decoder for autoregressive language modeling:
class GPTDecoder(nn.Module):
    """
    GPT-style decoder.
    
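    Key differences:
    - Decoder-only (no cross-attention to encoder)
    - Causal masking for autoregressive generation
    - Learned positional embeddings
    - Output projection tied to the token embeddings
    
    The layers here are post-norm, like the original Transformer; GPT-2 itself
    moves LayerNorm to the input of each sub-block (pre-norm).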
    """
    
    def __init__(
        self,
        vocab_size,
        d_model=768,
        num_heads=12,
        d_ff=3072,
        num_layers=12,
        max_len=1024,
        dropout=0.1
    ):
        super().__init__()
        
        self.d_model = d_model
        
        # Embeddings
        self.token_embeddings = nn.Embedding(vocab_size, d_model)
        self.position_embeddings = nn.Embedding(max_len, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
        # Transformer layers (decoder without cross-attention)
        self.layers = nn.ModuleList([
            GPTDecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.norm = nn.LayerNorm(d_model)
        
        # Output projection
        self.output_projection = nn.Linear(d_model, vocab_size, bias=False)
        
        # Tie embedding weights
        self.output_projection.weight = self.token_embeddings.weight
        
        # Cache for causal mask
        self.register_buffer(
            'causal_mask',
            torch.tril(torch.ones(max_len, max_len)).bool()
        )
    
    def forward(self, input_ids):
        """
        Args:
            input_ids: Token indices (batch, seq_len), with seq_len <= max_len
        
        Returns:
            logits: (batch, seq_len, vocab_size)
        
        Key/value caching is not implemented here, so during generation
        every step re-processes the whole prefix.
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        
        # Position indices 0 .. seq_len - 1
        position_ids = torch.arange(seq_len, device=device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
        
        # Embeddings
        hidden_states = self.token_embeddings(input_ids) + self.position_embeddings(position_ids)
        hidden_states = self.dropout(hidden_states)
        
        # Causal mask, broadcast over batch and heads: (1, 1, seq_len, seq_len)
        causal_mask = self.causal_mask[:seq_len, :seq_len]
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
        
        # Process through layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, causal_mask)
        
        hidden_states = self.norm(hidden_states)
        
        # Project to vocabulary (weights tied to token_embeddings)
        logits = self.output_projection(hidden_states)
        
        return logits


class GPTDecoderLayer(nn.Module):
    """GPT decoder layer (no cross-attention)."""
    
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForwardGELU(d_model, d_ff, dropout)
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None, past_key_value=None):
        # past_key_value is accepted but unused: this layer has no KV cache
        # Masked (causal) self-attention
        attn_output, _ = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Feed-forward
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        return x


# GPT configurations
gpt_configs = {
    'gpt2': {'d_model': 768, 'num_heads': 12, 'd_ff': 3072, 'num_layers': 12},
    'gpt2-medium': {'d_model': 1024, 'num_heads': 16, 'd_ff': 4096, 'num_layers': 24},
    'gpt2-large': {'d_model': 1280, 'num_heads': 20, 'd_ff': 5120, 'num_layers': 36},
    'gpt2-xl': {'d_model': 1600, 'num_heads': 25, 'd_ff': 6400, 'num_layers': 48},
}

gpt = GPTDecoder(vocab_size=50257, **gpt_configs['gpt2'])
print(f"GPT-2 parameters: {sum(p.numel() for p in gpt.parameters()):,}")
[Figure: GPT architecture]
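As a sanity check on the printed count, the parameters of a GPT-2-style model can be tallied by hand. This sketch assumes the layout of the released GPT-2 checkpoints: biased linear layers, d_ff = 4 · d_model, learned positional embeddings for a 1024-token context, and an output projection tied to the token embedding. The helper name `gpt2_param_count` is hypothetical. If `GPTDecoder` above uses an untied `output_projection`, its printed total will be larger by roughly vocab_size × d_model.

```python
def gpt2_param_count(d_model, num_layers, vocab_size=50257, n_ctx=1024):
    """Hypothetical helper: parameter count of a GPT-2-style decoder.

    Assumes biased Linear layers, d_ff = 4 * d_model, two LayerNorms per layer,
    learned positional embeddings, and tied input/output embeddings.
    """
    d = d_model
    attn = 4 * (d * d + d)                        # Q, K, V, output projections (W + b)
    ffn = (d * 4 * d + 4 * d) + (4 * d * d + d)   # d -> 4d -> d (W + b)
    norms = 2 * (2 * d)                           # two LayerNorms (gain + bias each)
    per_layer = attn + ffn + norms                # = 12*d^2 + 13*d
    embeddings = vocab_size * d + n_ctx * d       # token + position tables
    final_norm = 2 * d                            # final LayerNorm after the stack
    return num_layers * per_layer + embeddings + final_norm


for name, d_model, num_layers in [("gpt2", 768, 12), ("gpt2-medium", 1024, 24),
                                  ("gpt2-large", 1280, 36), ("gpt2-xl", 1600, 48)]:
    print(f"{name}: {gpt2_param_count(d_model, num_layers):,}")
```

Under these assumptions the totals come out to 124,439,808 for `gpt2` and 1,557,611,200 for `gpt2-xl`, which is where the commonly quoted "124M" and "1.5B" figures come from.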

BERT vs GPT Comparison

def compare_bert_gpt():
    """Compare BERT and GPT architectures."""
    
    comparison = """
    ┌──────────────────────────────────────────────────────────┐
    │                  BERT vs GPT Comparison                  │
    ├───────────────┬─────────────────┬────────────────────────┤
    │ Aspect        │ BERT            │ GPT                    │
    ├───────────────┼─────────────────┼────────────────────────┤
    │ Architecture  │ Encoder-only    │ Decoder-only           │
    │ Attention     │ Bidirectional   │ Causal (left-to-right) │
    │ Pre-training  │ MLM + NSP       │ Language Modeling      │
    │ Best for      │ Understanding   │ Generation             │
    │ Example tasks │ Classification, │ Text generation,       │
    │               │ NER, QA         │ Summarization          │
    │ Context       │ Sees all tokens │ Only sees past tokens  │
    └───────────────┴─────────────────┴────────────────────────┘
    
    BERT Pre-training:
    - Masked Language Modeling (MLM): predict the original tokens at masked positions
    - Next Sentence Prediction (NSP): does sentence B actually follow sentence A?
    
    GPT Pre-training:
    - Autoregressive LM: Predict next token given previous
    - P(token | all previous tokens)
    """
    
    print(comparison)

compare_bert_gpt()
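The "Attention" and "Context" rows of the comparison come down to a single mask. Below is a minimal PyTorch sketch, assuming the convention that `True` means "may attend"; the helper names `attention_mask` and `masked_softmax` are hypothetical. Uniform scores are used so that the resulting weights reveal the mask alone.

```python
import torch


def attention_mask(seq_len, causal):
    """Hypothetical helper: boolean (seq_len, seq_len) mask, True = may attend.

    Row i lists which key positions the query at position i can see.
    """
    if causal:
        # GPT-style: position i sees positions 0..i (lower triangle incl. diagonal)
        return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # BERT-style: every position sees every position
    return torch.ones(seq_len, seq_len, dtype=torch.bool)


def masked_softmax(scores, mask):
    # Disallowed positions get -inf, so they receive zero attention weight
    return torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)


scores = torch.zeros(4, 4)  # uniform scores
print(masked_softmax(scores, attention_mask(4, causal=False)))  # every row: 0.25 each
print(masked_softmax(scores, attention_mask(4, causal=True)))   # row i: 1/(i+1) over 0..i
```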
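| Model | Parameters | Training Data | Release |
|---|---|---|---|
| BERT-base | 110M | 16GB text | 2018 |
| GPT-2 | 1.5B | 40GB text | 2019 |
| GPT-3 | 175B | 570GB text | 2020 |
| GPT-4 | Not disclosed (~1T, unofficial est.) | Not disclosed | 2023 |
| LLaMA-2 | 7B–70B | 2T tokens | 2023 |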

Modern Transformer Improvements

Pre-Norm vs Post-Norm

class PreNormTransformerLayer(nn.Module):
    """
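    Pre-normalization: Apply LayerNorm before attention/FFN.
    
    Improves training stability for deep models.
    Used in GPT-2/3, LLaMA, and many modern transformers.
    Because the residual stream itself is never normalized, pre-norm
    stacks apply one final LayerNorm after the last layer.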
    """
    
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Pre-norm: normalize before attention
        normed = self.norm1(x)
        attn_out, _ = self.self_attention(normed, normed, normed, mask)
        x = x + self.dropout(attn_out)
        
        # Pre-norm: normalize before FFN
        normed = self.norm2(x)
        ff_out = self.feed_forward(normed)
        x = x + self.dropout(ff_out)
        
        return x
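The heading contrasts two placements of LayerNorm, but only the pre-norm variant is shown above. The difference fits in two lines. This sketch wraps an arbitrary sublayer; the function names are hypothetical and dropout is omitted. It also illustrates one common explanation for pre-norm's stability: when a sublayer contributes nothing (here a zero-initialized `nn.Linear`), the pre-norm block passes its input through unchanged, while the post-norm block still renormalizes it.

```python
import torch
import torch.nn as nn


def post_norm_block(x, sublayer, norm):
    # Original Transformer (2017): add the residual, then normalize
    return norm(x + sublayer(x))


def pre_norm_block(x, sublayer, norm):
    # GPT-2 style: normalize the sublayer input, leave the residual path untouched
    return x + sublayer(norm(x))


torch.manual_seed(0)
d_model = 8
x = torch.randn(2, 5, d_model) * 3 + 1  # deliberately not zero-mean / unit-variance
norm = nn.LayerNorm(d_model)

# A sublayer that contributes nothing (zero weights and bias)
zero_sublayer = nn.Linear(d_model, d_model)
nn.init.zeros_(zero_sublayer.weight)
nn.init.zeros_(zero_sublayer.bias)

pre = pre_norm_block(x, zero_sublayer, norm)
post = post_norm_block(x, zero_sublayer, norm)
print(torch.equal(pre, x))         # identity path preserved
print(torch.allclose(post, x))     # input was renormalized, so not the identity
```

Because the pre-norm residual stream is never normalized, a final LayerNorm is needed at the top of the stack, which is the role of the `self.norm` call in the GPT decoder's forward pass above.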

RMSNorm

class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.
    
    Simpler than LayerNorm: only normalizes by RMS, no mean subtraction.
    Used in LLaMA and other efficient transformers.
    """
    
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps
    
    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight
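One way to see what RMSNorm drops: when each feature vector already has zero mean, its RMS equals its (biased) standard deviation, so RMSNorm with unit weights and LayerNorm without affine parameters give the same result. This sketch re-implements the formula above as a plain function (named `rms_norm` here for illustration) and compares it with `torch.nn.functional.layer_norm`.

```python
import torch
import torch.nn.functional as F


def rms_norm(x, weight, eps=1e-6):
    # Same formula as the RMSNorm module: x / sqrt(mean(x^2) + eps) * weight
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight


torch.manual_seed(0)
d_model = 16
x = torch.randn(4, d_model)
weight = torch.ones(d_model)

# Zero-mean input: RMS == biased std, so the two normalizations coincide
x_centered = x - x.mean(dim=-1, keepdim=True)
print(torch.allclose(rms_norm(x_centered, weight),
                     F.layer_norm(x_centered, (d_model,), eps=1e-6), atol=1e-5))

# General input: LayerNorm also subtracts the mean, RMSNorm does not
print(torch.allclose(rms_norm(x, weight),
                     F.layer_norm(x, (d_model,), eps=1e-6), atol=1e-5))
```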

Rotary Position Embedding (RoPE)

class RotaryPositionEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE).
    
    Encodes position by rotating query/key vectors by position-dependent
    angles, so the query-key dot product depends only on their relative offset.
    Used in LLaMA, GPT-NeoX, and many modern models.
    
    head_dim is the per-head dimension d_k (must be even), not d_model,
    because RoPE is applied to each head's queries and keys separately.
    """
    
    def __init__(self, head_dim, max_len=2048):
        super().__init__()
        
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer('inv_freq', inv_freq)
        
        # Precompute cos/sin tables for every position: (max_len, head_dim)
        t = torch.arange(max_len, dtype=inv_freq.dtype)
        freqs = torch.outer(t, inv_freq)          # (max_len, head_dim / 2)
        emb = torch.cat([freqs, freqs], dim=-1)   # (max_len, head_dim)
        self.register_buffer('cos', emb.cos())
        self.register_buffer('sin', emb.sin())
    
    def forward(self, x, position_ids):
        """Apply rotary embeddings to queries or keys.
        
        x: (batch, num_heads, seq_len, head_dim); position_ids: (batch, seq_len)
        """
        cos = self.cos[position_ids].unsqueeze(1)  # (batch, 1, seq_len, head_dim)
        sin = self.sin[position_ids].unsqueeze(1)
        
        # Rotate half: dimension i is paired with dimension i + head_dim/2
        x1, x2 = x[..., :x.size(-1)//2], x[..., x.size(-1)//2:]
        rotated = torch.cat([-x2, x1], dim=-1)
        
        return x * cos + rotated * sin
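The property that makes RoPE useful is that the rotated query-key dot product depends only on the offset between the two positions. This standalone sketch re-implements the same rotate-half formulation as the class above for single head vectors (the helper `apply_rope` is hypothetical; float64 is used only to make the comparison tight) and checks that shifting both positions by the same amount leaves the score unchanged.

```python
import torch


def apply_rope(x, pos, base=10000.0):
    """Hypothetical helper: rotate a (..., dim) vector to position `pos`.

    Same rotate-half convention as RotaryPositionEmbedding: dimension i is
    paired with dimension i + dim/2 and rotated by angle pos * inv_freq[i].
    """
    dim = x.size(-1)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    angles = torch.cat([pos * inv_freq, pos * inv_freq])  # (dim,)
    x1, x2 = x[..., : dim // 2], x[..., dim // 2 :]
    rotated = torch.cat([-x2, x1], dim=-1)
    return x * angles.cos() + rotated * angles.sin()


def score(q, k, m, n):
    # Attention logit between a query at position m and a key at position n
    return torch.dot(apply_rope(q, m), apply_rope(k, n)).item()


torch.manual_seed(0)
q = torch.randn(64, dtype=torch.float64)
k = torch.randn(64, dtype=torch.float64)

print(score(q, k, 5, 2), score(q, k, 105, 102))  # same offset (3): equal up to rounding
print(score(q, k, 5, 2), score(q, k, 5, 4))      # different offset: generally different
```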

SwiGLU Activation

class SwiGLU(nn.Module):
    """
    SwiGLU feed-forward block.
    
    Combines the Swish activation (SiLU, i.e. Swish with beta = 1) with a
    gated linear unit. Used in PaLM, LLaMA, and other modern transformers.
    
    FFN_SwiGLU(x) = (Swish(xW1) ⊙ xV) W2
    """
    
    def __init__(self, d_model, d_ff):
        super().__init__()
        
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # W1: Swish-activated gate branch
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # W2: projection back to d_model
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # V: linear (value) branch
    
    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
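SwiGLU uses three weight matrices instead of the two in a standard FFN, so at the same hidden width it costs 50% more parameters. LLaMA compensates by shrinking the hidden size to about 2/3 · 4 · d_model (for LLaMA-7B, d_model = 4096 and d_ff = 11008, i.e. the value below rounded up to a multiple of 256). The arithmetic as a quick sketch; the helper names are hypothetical and biases are omitted, as in the module above:

```python
def ffn_params(d_model, d_ff):
    # Standard FFN without biases: W_in (d_model x d_ff) and W_out (d_ff x d_model)
    return 2 * d_model * d_ff


def swiglu_params(d_model, d_ff):
    # SwiGLU without biases: W1 and V (d_model x d_ff each) plus W2 (d_ff x d_model)
    return 3 * d_model * d_ff


d_model = 4096
standard = ffn_params(d_model, 4 * d_model)        # classic 4x expansion
same_width = swiglu_params(d_model, 4 * d_model)   # SwiGLU at the same width
scaled_d_ff = int(2 / 3 * 4 * d_model)             # 2/3 scaling of the hidden size
matched = swiglu_params(d_model, scaled_d_ff)

print(f"standard FFN:       {standard:,}")
print(f"SwiGLU, d_ff = 4d:  {same_width:,} ({same_width / standard:.2f}x)")
print(f"SwiGLU, d_ff = {scaled_d_ff}: {matched:,} ({matched / standard:.3f}x)")
```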

Exercises

Exercise 1. Implement the complete Transformer without looking at the code above:
  1. Multi-head attention
  2. Position-wise feed-forward
  3. Encoder and decoder layers
  4. Full encoder-decoder model
Test it on a simple copy task: input “ABC” → output “ABC”.

Exercise 2. Implement BERT’s pre-training objectives:
  1. Masked Language Modeling (MLM)
    • Randomly select 15% of tokens as prediction targets
    • Of those, replace 80% with [MASK], 10% with a random token, and leave 10% unchanged
  2. Next Sentence Prediction (NSP)
Train on a small corpus and visualize the attention patterns.

Exercise 3. Build a small GPT model for text generation:
  1. Implement causal masking
  2. Train on a small text corpus
  3. Implement nucleus (top-p) sampling
  4. Generate text and analyze its quality

Exercise 4. Implement efficient attention variants:
  1. Linear attention
  2. Sliding window attention
  3. FlashAttention (conceptually)
Compare memory usage and speed on long sequences.

Exercise 5. Fine-tune a pre-trained transformer for text classification:
  1. Load a pre-trained model (e.g., from Hugging Face)
  2. Add a classification head
  3. Fine-tune on IMDB or AG News
  4. Analyze attention patterns for interpretability
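The 80/10/10 rule in the BERT exercise is easy to get subtly wrong, so here is one way to express it in PyTorch to check an implementation against. It is a sketch, not BERT's reference code: the helper name `mlm_mask`, the token ids, and the use of id 0 as the [MASK] id are illustrative assumptions; special tokens are not excluded; and -100 is used as the "ignore" label because that is the default `ignore_index` of `torch.nn.functional.cross_entropy`.

```python
import torch


def mlm_mask(input_ids, mask_token_id, vocab_size, mlm_prob=0.15):
    """Hypothetical helper: BERT-style 80/10/10 masking.

    Returns (corrupted_ids, labels); labels are -100 (ignored by
    F.cross_entropy) everywhere except the positions chosen for prediction.
    """
    labels = input_ids.clone()
    corrupted = input_ids.clone()

    # 1. Choose ~15% of positions as prediction targets
    selected = torch.rand(input_ids.shape) < mlm_prob
    labels[~selected] = -100

    # 2. Of the selected positions, 80% become [MASK]
    to_mask = selected & (torch.rand(input_ids.shape) < 0.8)
    corrupted[to_mask] = mask_token_id

    # 3. Half of the remaining 20% (10% overall) become a random token
    to_random = selected & ~to_mask & (torch.rand(input_ids.shape) < 0.5)
    corrupted[to_random] = torch.randint(vocab_size, input_ids.shape)[to_random]

    # 4. The last 10% keep their original token but are still predicted
    return corrupted, labels


torch.manual_seed(0)
ids = torch.randint(1, 1000, (64, 128))  # ids start at 1 so 0 can act as [MASK]
corrupted, labels = mlm_mask(ids, mask_token_id=0, vocab_size=1000)
selected = labels != -100
print(f"selected fraction:     {selected.float().mean():.3f}")
print(f"[MASK] among selected: {(corrupted[selected] == 0).float().mean():.3f}")
```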

Key Takeaways

| Concept | Key Insight |
|---|---|
| Self-Attention | O(1) path length between any two positions; fully parallel across the sequence |
| Multi-Head | Multiple attention patterns for different relationships |
| Encoder | Bidirectional understanding of the input |
| Decoder | Causal generation, one token at a time at inference |
| Positional Encoding | Injects sequence-order information |
| Layer Normalization | Stabilizes training, enabling deep models |
| BERT | Encoder-only, bidirectional, for understanding |
| GPT | Decoder-only, causal, for generation |
The Transformer is the foundation of modern large language models: GPT-4, Claude, LLaMA, and Gemini are all built on it. Understanding it deeply is essential for anyone working in AI.

What’s Next

Congratulations! You’ve completed the core architecture modules of the Deep Learning Mastery course. You now understand:
  • Neural network fundamentals (perceptrons, backprop, activations, loss functions)
  • Convolutional networks for images
  • Recurrent networks for sequences
  • Attention and Transformers for everything