Attention Mechanism

The Bottleneck Problem

In the previous chapters, we built sequence-to-sequence models using encoder-decoder LSTMs. But there’s a fundamental problem: The entire source sequence is compressed into a single fixed-size vector.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

def visualize_bottleneck():
    """
    The Encoder-Decoder Bottleneck Problem
    
    Source: "The quick brown fox jumps over the lazy dog"

    LSTM Encoder → Single vector (e.g., 512 dimensions)

    LSTM Decoder → "Le rapide renard brun saute par-dessus le chien paresseux"
    
    Problem: ALL information about a 9-word sentence must fit in 512 numbers!
    For longer sentences, this becomes impossible.
    """
    
    print("The Bottleneck Problem:")
    print("=" * 60)
    print()
    print("Source sentence (English):")
    print("  'The quick brown fox jumps over the lazy dog'")
    print()
    print("Compressed to: vector of shape (512,)")
    print()
    print("Then decoded to (French):")
    print("  'Le rapide renard brun saute par-dessus le chien paresseux'")
    print()
    print("❌ Problems:")
    print("  - Information loss for long sequences")
    print("  - Decoder must 'guess' what encoder meant")
    print("  - Early encoder states forgotten by final state")

visualize_bottleneck()
Figure: Encoder Bottleneck Problem
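
To make the bottleneck concrete, here is a minimal sketch (hyperparameters chosen arbitrarily) showing that an LSTM encoder's final hidden state has the same fixed size no matter how long the input is:
# Minimal sketch of the bottleneck: the final encoder state has the same
# size whether the input has 5 tokens or 500. (Dimensions are arbitrary.)
encoder = nn.LSTM(input_size=64, hidden_size=512, batch_first=True)

for seq_len in [5, 50, 500]:
    x = torch.randn(1, seq_len, 64)      # one sequence of seq_len embeddings
    outputs, (h_n, c_n) = encoder(x)
    print(f"seq_len={seq_len:3d} → outputs {tuple(outputs.shape)}, final hidden {tuple(h_n.shape)}")

# The final hidden state is always (1, 1, 512): every sequence, however long,
# must be summarized in the same 512 numbers.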

Evidence of the Problem

def show_translation_quality_vs_length():
    """
    Real observation: Translation quality degrades with sentence length
    when using fixed-size encoding.
    """
    
    # Illustrative BLEU scores (not exact measurements) showing the commonly reported trend
    lengths = [10, 20, 30, 40, 50, 60]
    bleu_without_attention = [35, 32, 28, 22, 16, 12]
    bleu_with_attention = [38, 37, 36, 35, 34, 33]
    
    plt.figure(figsize=(10, 5))
    plt.plot(lengths, bleu_without_attention, 'o-', label='Without Attention', linewidth=2)
    plt.plot(lengths, bleu_with_attention, 's-', label='With Attention', linewidth=2)
    
    plt.xlabel('Source Sentence Length')
    plt.ylabel('BLEU Score')
    plt.title('Translation Quality vs Sentence Length')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
    
    print("Key insight: Attention maintains quality even for long sentences!")

# show_translation_quality_vs_length()

The Attention Solution

Intuition: Looking Back at the Source

Instead of forcing the decoder to use only the final encoder state, let it look at all encoder states and focus on the relevant ones. For example, when translating “dog” to “chien”:
  • Look at all encoder states
  • Focus attention on the state corresponding to “dog”
  • Use that information to generate “chien”
Key Insight: Attention computes a weighted combination of all encoder states, where the weights indicate how relevant each state is to the current decoding step.
def attention_intuition():
    """
    Attention Intuition: The Query-Key-Value Framework
    
    Imagine you're at a library:
    - Query: "I need information about neural networks"
    - Keys: Book titles on the shelves
    - Values: The actual content of each book
    
    Process:
    1. Compare your query with all keys (book titles)
    2. Compute relevance scores
    3. Retrieve values (content) weighted by relevance
    """
    
    print("Attention as Information Retrieval:")
    print("=" * 60)
    print()
    print("Query (what we're looking for):")
    print("  → Current decoder state: 'trying to translate word X'")
    print()
    print("Keys (what we're comparing against):")
    print("  → All encoder hidden states")
    print()
    print("Values (what we retrieve):")
    print("  → Same encoder hidden states (or transformed versions)")
    print()
    print("Output:")
    print("  → Weighted sum of values, weights from query-key similarity")

attention_intuition()
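
To make the "weighted combination" idea concrete, here is a minimal numeric sketch (dimensions made up) in which a single decoder state scores every encoder state, the scores are softmaxed into weights, and the context vector is the weighted sum of the encoder states:
# Minimal numeric sketch of attention as a weighted combination of encoder states.
torch.manual_seed(0)

encoder_states = torch.randn(6, 8)   # 6 source positions, 8-dim states (keys = values here)
decoder_state = torch.randn(8)       # current decoder state (the query)

scores = encoder_states @ decoder_state   # (6,) one relevance score per source position
weights = F.softmax(scores, dim=-1)       # (6,) weights sum to 1
context = weights @ encoder_states        # (8,) weighted sum of encoder states

print("Attention weights:", weights)
print("Weights sum to:", weights.sum().item())
print("Context vector shape:", context.shape)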

Attention Mechanisms in Detail

Dot-Product Attention

The simplest form of attention:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Scaled Dot-Product Attention.
    
    Args:
        query: (batch, num_queries, d_k)
        key: (batch, num_keys, d_k)
        value: (batch, num_keys, d_v)
        mask: Optional mask to prevent attention to certain positions
    
    Returns:
        output: (batch, num_queries, d_v)
        attention_weights: (batch, num_queries, num_keys)
    """
    d_k = query.size(-1)
    
    # Compute attention scores
    # (batch, num_queries, d_k) @ (batch, d_k, num_keys) → (batch, num_queries, num_keys)
    scores = torch.matmul(query, key.transpose(-2, -1))
    
    # Scale by sqrt(d_k) to prevent softmax saturation
    scores = scores / np.sqrt(d_k)
    
    # Apply mask if provided
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)
    
    # Apply attention weights to values
    output = torch.matmul(attention_weights, value)
    
    return output, attention_weights


# Example
batch_size = 2
num_queries = 4  # e.g., decoder positions
num_keys = 6     # e.g., encoder positions
d_k = 64         # dimension of queries and keys
d_v = 64         # dimension of values

query = torch.randn(batch_size, num_queries, d_k)
key = torch.randn(batch_size, num_keys, d_k)
value = torch.randn(batch_size, num_keys, d_v)

output, weights = scaled_dot_product_attention(query, key, value)

print(f"Query shape:  {query.shape}")
print(f"Key shape:    {key.shape}")
print(f"Value shape:  {value.shape}")
print(f"Output shape: {output.shape}")
print(f"Weights shape: {weights.shape}")
print(f"\nAttention weights sum to 1: {weights.sum(dim=-1)}")

Visualizing Attention Weights

def visualize_attention():
    """Visualize what attention 'looks at'."""
    
    # Simulated attention weights for translation
    source = ["The", "cat", "sat", "on", "the", "mat", "."]
    target = ["Le", "chat", "était", "assis", "sur", "le", "tapis", "."]
    
    # Attention matrix (which source words each target word attends to)
    attention = np.array([
        [0.8, 0.1, 0.0, 0.0, 0.1, 0.0, 0.0],  # Le → The
        [0.1, 0.8, 0.0, 0.0, 0.0, 0.1, 0.0],  # chat → cat
        [0.0, 0.1, 0.7, 0.1, 0.0, 0.0, 0.1],  # était → sat
        [0.0, 0.1, 0.7, 0.1, 0.0, 0.0, 0.1],  # assis → sat
        [0.0, 0.0, 0.0, 0.8, 0.1, 0.0, 0.1],  # sur → on
        [0.1, 0.0, 0.0, 0.0, 0.8, 0.0, 0.1],  # le → the
        [0.0, 0.1, 0.0, 0.0, 0.0, 0.8, 0.1],  # tapis → mat
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],  # . → .
    ])
    
    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(attention, cmap='Blues')
    
    ax.set_xticks(np.arange(len(source)))
    ax.set_yticks(np.arange(len(target)))
    ax.set_xticklabels(source, fontsize=12)
    ax.set_yticklabels(target, fontsize=12)
    
    ax.set_xlabel("Source (English)", fontsize=14)
    ax.set_ylabel("Target (French)", fontsize=14)
    ax.set_title("Attention Weights in Translation", fontsize=16)
    
    # Add attention values in cells
    for i in range(len(target)):
        for j in range(len(source)):
            text = ax.text(j, i, f'{attention[i, j]:.1f}',
                          ha='center', va='center', color='black' if attention[i,j] < 0.5 else 'white')
    
    plt.colorbar(im)
    plt.tight_layout()
    plt.show()

visualize_attention()
Figure: Attention Weight Visualization

Types of Attention

Additive (Bahdanau) Attention

The original attention mechanism, introduced by Bahdanau et al. (2014):

e_{ij} = v^T \tanh(W_1 h_i + W_2 s_j)

\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_k \exp(e_{kj})}

where h_i is an encoder hidden state and s_j is the decoder state at step j.
class AdditiveAttention(nn.Module):
    """
    Bahdanau (Additive) Attention.
    
    Uses a learned alignment model to compute attention scores.
    More expressive but slower than dot-product attention.
    """
    
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        
        # Project encoder states
        self.W1 = nn.Linear(encoder_dim, attention_dim, bias=False)
        
        # Project decoder state
        self.W2 = nn.Linear(decoder_dim, attention_dim, bias=False)
        
        # Compute scalar attention score
        self.v = nn.Linear(attention_dim, 1, bias=False)
    
    def forward(self, encoder_outputs, decoder_hidden):
        """
        Args:
            encoder_outputs: (batch, src_len, encoder_dim)
            decoder_hidden: (batch, decoder_dim)
        
        Returns:
            context: (batch, encoder_dim)
            attention_weights: (batch, src_len)
        """
        # Project encoder outputs
        encoder_proj = self.W1(encoder_outputs)  # (batch, src_len, attention_dim)
        
        # Project decoder hidden (add dimension for broadcasting)
        decoder_proj = self.W2(decoder_hidden).unsqueeze(1)  # (batch, 1, attention_dim)
        
        # Compute alignment scores
        scores = self.v(torch.tanh(encoder_proj + decoder_proj))  # (batch, src_len, 1)
        scores = scores.squeeze(-1)  # (batch, src_len)
        
        # Softmax to get attention weights
        attention_weights = F.softmax(scores, dim=-1)
        
        # Compute context vector
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        context = context.squeeze(1)  # (batch, encoder_dim)
        
        return context, attention_weights


# Test
attention = AdditiveAttention(encoder_dim=512, decoder_dim=512, attention_dim=256)

encoder_outputs = torch.randn(8, 20, 512)  # 8 sequences, 20 timesteps
decoder_hidden = torch.randn(8, 512)

context, weights = attention(encoder_outputs, decoder_hidden)

print(f"Encoder outputs: {encoder_outputs.shape}")
print(f"Decoder hidden:  {decoder_hidden.shape}")
print(f"Context vector:  {context.shape}")
print(f"Attention weights: {weights.shape}")

Multiplicative (Luong) Attention

Simpler and faster variants:
class MultiplicativeAttention(nn.Module):
    """
    Luong (Multiplicative) Attention variants.
    
    Variants:
    - dot: score = h_s · h_t
    - general: score = h_s · W · h_t
    - concat: score = v · tanh(W · [h_s; h_t])
    """
    
    def __init__(self, encoder_dim, decoder_dim, method='general'):
        super().__init__()
        
        self.method = method
        
        if method == 'general':
            self.W = nn.Linear(decoder_dim, encoder_dim, bias=False)
        elif method == 'concat':
            self.W = nn.Linear(encoder_dim + decoder_dim, decoder_dim, bias=False)
            self.v = nn.Linear(decoder_dim, 1, bias=False)
    
    def forward(self, encoder_outputs, decoder_hidden):
        """
        Args:
            encoder_outputs: (batch, src_len, encoder_dim)
            decoder_hidden: (batch, decoder_dim)
        """
        if self.method == 'dot':
            # Simple dot product
            scores = torch.bmm(encoder_outputs, decoder_hidden.unsqueeze(-1))
            scores = scores.squeeze(-1)
            
        elif self.method == 'general':
            # Transform then dot product
            decoder_proj = self.W(decoder_hidden)  # (batch, encoder_dim)
            scores = torch.bmm(encoder_outputs, decoder_proj.unsqueeze(-1))
            scores = scores.squeeze(-1)
            
        elif self.method == 'concat':
            # Concatenate and project
            src_len = encoder_outputs.size(1)
            decoder_expanded = decoder_hidden.unsqueeze(1).expand(-1, src_len, -1)
            concat = torch.cat([encoder_outputs, decoder_expanded], dim=-1)
            scores = self.v(torch.tanh(self.W(concat))).squeeze(-1)
        
        attention_weights = F.softmax(scores, dim=-1)
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)
        
        return context, attention_weights


# Compare methods
for method in ['dot', 'general', 'concat']:
    attn = MultiplicativeAttention(512, 512, method=method)
    params = sum(p.numel() for p in attn.parameters())
    print(f"{method:8s}: {params:,} parameters")

Self-Attention

Attending to the Same Sequence

Self-attention allows each position in a sequence to attend to all other positions:
class SelfAttention(nn.Module):
    """
    Self-Attention: Sequence attends to itself.
    
    Each position can gather information from all other positions,
    allowing for direct modeling of dependencies regardless of distance.
    """
    
    def __init__(self, d_model):
        super().__init__()
        
        self.d_model = d_model
        
        # Project to query, key, value
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
    
    def forward(self, x, mask=None):
        """
        Args:
            x: Input sequence (batch, seq_len, d_model)
            mask: Optional attention mask
        
        Returns:
            output: (batch, seq_len, d_model)
            attention_weights: (batch, seq_len, seq_len)
        """
        # Project to Q, K, V
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Scaled dot-product attention
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        
        return output, attention_weights


# Example: Self-attention on a sentence
self_attn = SelfAttention(d_model=256)

# 4 sentences, 10 words each, 256-dim embeddings
x = torch.randn(4, 10, 256)

output, weights = self_attn(x)

print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
print(f"Attention weights: {weights.shape}")
print(f"\nEach word attends to all words (including itself)")

Why Self-Attention Matters

def compare_rnn_vs_self_attention():
    """
    Self-attention vs RNN for capturing dependencies.
    """
    
    print("Self-Attention Advantages over RNNs:")
    print("=" * 60)
    print()
    print("1. PATH LENGTH for long-range dependencies:")
    print("   RNN:  O(n) - information must flow through n steps")
    print("   Attn: O(1) - direct connection between any two positions")
    print()
    print("2. PARALLELIZATION:")
    print("   RNN:  Sequential - must process one step at a time")
    print("   Attn: Fully parallel - all positions computed simultaneously")
    print()
    print("3. INTERPRETABILITY:")
    print("   RNN:  Hidden state is opaque")
    print("   Attn: Attention weights show what the model focuses on")
    print()
    
    # Computational complexity comparison
    n = 1000  # sequence length
    d = 512   # hidden dimension
    
    rnn_complexity = n * d * d  # O(n·d²) sequential steps
    attn_complexity = n * n * d  # O(n²·d) but parallelizable
    
    print(f"For sequence length {n}, hidden dim {d}:")
    print(f"   RNN:  ~{rnn_complexity:,} ops (sequential)")
    print(f"   Attn: ~{attn_complexity:,} ops (parallel)")

compare_rnn_vs_self_attention()
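
As a rough, hardware-dependent illustration of the parallelization point, the sketch below times a single forward pass of an LSTM against the SelfAttention module defined above on the same batch (absolute numbers will vary by machine; this is not a rigorous benchmark):
import time

seq_len, d = 512, 256
x = torch.randn(8, seq_len, d)

lstm = nn.LSTM(d, d, batch_first=True)   # must process timesteps sequentially
attn = SelfAttention(d_model=d)          # processes all positions at once

with torch.no_grad():
    start = time.perf_counter()
    lstm(x)
    lstm_time = time.perf_counter() - start

    start = time.perf_counter()
    attn(x)
    attn_time = time.perf_counter() - start

print(f"LSTM forward:           {lstm_time * 1000:.1f} ms")
print(f"Self-attention forward: {attn_time * 1000:.1f} ms")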

Multi-Head Attention

Multiple Attention “Perspectives”

Instead of one attention function, use multiple “heads” that each learn different relationships:

\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O

\text{where head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention.
    
    Multiple attention heads allow the model to jointly attend to
    information from different representation subspaces.
    
    For example, in "The animal didn't cross the street because it was too tired":
    - One head might focus on "it" → "animal" (coreference)
    - Another head might focus on "tired" → "animal" (attribute)
    """
    
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        
        # Output projection
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (batch, seq_len_q, d_model)
            key: (batch, seq_len_k, d_model)
            value: (batch, seq_len_v, d_model)
            mask: Optional mask (batch, 1, seq_len_q, seq_len_k) or (batch, 1, 1, seq_len_k)
        
        Returns:
            output: (batch, seq_len_q, d_model)
            attention_weights: (batch, num_heads, seq_len_q, seq_len_k)
        """
        batch_size = query.size(0)
        
        # 1. Linear projections
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        
        # 2. Split into multiple heads
        # (batch, seq_len, d_model) → (batch, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # 3. Scaled dot-product attention for each head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        context = torch.matmul(attention_weights, V)
        
        # 4. Concatenate heads
        # (batch, num_heads, seq_len, d_k) → (batch, seq_len, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        
        # 5. Final linear projection
        output = self.W_o(context)
        
        return output, attention_weights


# Test multi-head attention
mha = MultiHeadAttention(d_model=512, num_heads=8)

# Self-attention: Q=K=V
x = torch.randn(2, 10, 512)
output, weights = mha(x, x, x)

print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
print(f"Attention weights: {weights.shape}")
print(f"  → {weights.shape[1]} heads, each with {weights.shape[2]}×{weights.shape[3]} attention matrix")

Visualizing Multi-Head Attention

def visualize_multi_head_attention():
    """
    Visualize what different attention heads learn.
    """
    
    sentence = ["The", "cat", "sat", "on", "the", "mat", ".", "<pad>"]
    n = len(sentence)
    
    # Simulated attention patterns for different heads
    # In practice, these are learned
    
    # Head 1: Focuses on adjacent words
    head1 = np.eye(n) * 0.4 + np.eye(n, k=1) * 0.3 + np.eye(n, k=-1) * 0.3
    head1 = head1 / head1.sum(axis=1, keepdims=True)
    
    # Head 2: Focuses on nouns (positions 1, 5 = cat, mat)
    head2 = np.zeros((n, n))
    for i in range(n):
        head2[i, [1, 5]] = [0.5, 0.5]
    
    # Head 3: Focuses on structure (The...the, cat...mat)
    head3 = np.eye(n) * 0.5
    head3[0, 4] = 0.5
    head3[4, 0] = 0.5
    head3[1, 5] = 0.25
    head3[5, 1] = 0.25
    head3 = head3 / head3.sum(axis=1, keepdims=True)
    
    # Head 4: Focuses on end of sentence
    head4 = np.zeros((n, n))
    head4[:, -2] = 1.0  # Everyone attends to "."
    
    heads = [head1, head2, head3, head4]
    titles = ["Adjacent Words", "Nouns", "Structural", "Punctuation"]
    
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    
    for ax, head, title in zip(axes, heads, titles):
        im = ax.imshow(head, cmap='Blues', vmin=0, vmax=1)
        ax.set_xticks(range(n))
        ax.set_yticks(range(n))
        ax.set_xticklabels(sentence, rotation=45, ha='right')
        ax.set_yticklabels(sentence)
        ax.set_title(f"Head: {title}")
    
    plt.tight_layout()
    plt.show()
    
    print("Different heads learn to attend to different types of relationships!")

visualize_multi_head_attention()
Figure: Multi-Head Attention Patterns

Positional Encoding

The Problem: Attention is Permutation-Invariant

Unlike RNNs, self-attention has no inherent notion of position:
def demonstrate_position_invariance():
    """
    Self-attention treats "The cat sat" and "sat cat The" the same way!
    We need to explicitly add position information.
    """
    
    print("Problem: Attention is Position-Blind")
    print("=" * 60)
    print()
    print("Self-attention on 'The cat sat':")
    print("  Q_cat @ K_The → score (same regardless of position)")
    print()
    print("Self-attention on 'sat cat The':")
    print("  Q_cat @ K_The → SAME score!")
    print()
    print("Solution: Add positional encoding to embeddings")
    print("  embedding + position_encoding → position-aware representation")

demonstrate_position_invariance()
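
We can verify this directly with the SelfAttention module from above (no positional information added): permuting the input positions simply permutes the outputs in the same way, so the model cannot distinguish word orders. A small sketch:
# Without positional encoding, self-attention is permutation-equivariant:
# shuffling the input rows shuffles the output rows identically.
torch.manual_seed(0)

sa = SelfAttention(d_model=32)
x = torch.randn(1, 5, 32)       # a 5-"word" sequence
perm = torch.randperm(5)        # a random reordering of the words

with torch.no_grad():
    out_original, _ = sa(x)
    out_shuffled, _ = sa(x[:, perm, :])

# The shuffled output equals the original output, reordered the same way
print("Permutation-equivariant:",
      torch.allclose(out_shuffled, out_original[:, perm, :], atol=1e-5))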

Sinusoidal Positional Encoding

The original Transformer uses sinusoidal functions:

PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)

PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)
class SinusoidalPositionalEncoding(nn.Module):
    """
    Sinusoidal Positional Encoding from "Attention Is All You Need".
    
    Properties:
    - Deterministic (no learned parameters)
    - Can extrapolate to longer sequences than seen during training
    - Relative positions can be represented as linear functions
    """
    
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        
        self.dropout = nn.Dropout(dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        
        # Register as buffer (not a parameter, but saved with model)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, d_model)
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


# Visualize positional encoding
def visualize_positional_encoding():
    d_model = 128
    max_len = 100
    
    pe_layer = SinusoidalPositionalEncoding(d_model, max_len, dropout=0)
    
    # Get positional encodings directly from the registered buffer
    pe = pe_layer.pe[0, :max_len, :].numpy()
    
    plt.figure(figsize=(12, 6))
    plt.imshow(pe.T, aspect='auto', cmap='RdBu')
    plt.xlabel('Position')
    plt.ylabel('Dimension')
    plt.title('Sinusoidal Positional Encoding')
    plt.colorbar(label='Encoding Value')
    plt.show()
    
    # Show specific dimensions
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    
    for i, ax in enumerate(axes.flat):
        dim = i * 20
        ax.plot(pe[:, dim], label=f'dim {dim} (sin)')
        ax.plot(pe[:, dim+1], label=f'dim {dim+1} (cos)')
        ax.set_xlabel('Position')
        ax.set_ylabel('Value')
        ax.legend()
        ax.set_title(f'Dimensions {dim}, {dim+1}')
        ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

visualize_positional_encoding()
Figure: Positional Encoding Visualization
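
One property worth checking numerically: the dot product between two sinusoidal encodings depends only on the distance between the positions, not on the absolute positions, which is one reason relative offsets are easy for the model to pick up. A quick check using the pe buffer from the class above:
# PE(p) · PE(p + k) depends only on the offset k, not on p.
pe = SinusoidalPositionalEncoding(d_model=128, max_len=200, dropout=0).pe[0]  # (200, 128)

k = 7
for p in [0, 25, 90]:
    print(f"PE({p:2d}) · PE({p + k:2d}) = {torch.dot(pe[p], pe[p + k]).item():.3f}")
# All three values agree (up to floating-point error), regardless of p.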

Learned Positional Embeddings

An alternative: learn position embeddings like word embeddings:
class LearnedPositionalEncoding(nn.Module):
    """
    Learned positional embeddings.
    
    Properties:
    - Can learn task-specific position patterns
    - Cannot extrapolate beyond max_len
    - More parameters to learn
    """
    
    def __init__(self, d_model, max_len=512, dropout=0.1):
        super().__init__()
        
        self.position_embeddings = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout)
        
        # Initialize with small values
        nn.init.normal_(self.position_embeddings.weight, std=0.02)
    
    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, d_model)
        """
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        position_embeddings = self.position_embeddings(positions)
        
        x = x + position_embeddings
        return self.dropout(x)


# Compare parameter counts
sinusoidal = SinusoidalPositionalEncoding(512, 1000)
learned = LearnedPositionalEncoding(512, 1000)

print(f"Sinusoidal PE params: {sum(p.numel() for p in sinusoidal.parameters()):,}")
print(f"Learned PE params:    {sum(p.numel() for p in learned.parameters()):,}")

Attention with Masking

Causal (Autoregressive) Mask

For language models, we need to prevent attending to future positions:
def create_causal_mask(seq_len):
    """
    Create a causal mask to prevent attending to future positions.
    
    For position i, can only attend to positions 0, 1, ..., i
    """
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask


def visualize_causal_mask():
    seq_len = 6
    mask = create_causal_mask(seq_len)
    
    # Convert to a 0/1 matrix for visualization (1 = masked, 0 = can attend)
    viz_mask = torch.zeros_like(mask)
    viz_mask[mask == float('-inf')] = 1
    
    plt.figure(figsize=(6, 6))
    plt.imshow(viz_mask.numpy(), cmap='Greys')
    plt.xlabel('Key Position (can attend to)')
    plt.ylabel('Query Position')
    plt.title('Causal Mask\n(black = masked, white = can attend)')
    
    for i in range(seq_len):
        for j in range(seq_len):
            color = 'white' if viz_mask[i, j] == 1 else 'black'
            plt.text(j, i, '✓' if viz_mask[i,j] == 0 else '✗', 
                    ha='center', va='center', color=color, fontsize=16)
    
    plt.show()

visualize_causal_mask()


class CausalSelfAttention(nn.Module):
    """Self-attention with causal masking for autoregressive models."""
    
    def __init__(self, d_model, num_heads, max_len=512, dropout=0.1):
        super().__init__()
        
        self.mha = MultiHeadAttention(d_model, num_heads, dropout)
        
        # Register causal mask as buffer
        mask = torch.triu(torch.ones(max_len, max_len), diagonal=1).bool()
        self.register_buffer('mask', mask)
    
    def forward(self, x):
        seq_len = x.size(1)
        
        # Get appropriate size mask and expand for batch and heads
        mask = self.mask[:seq_len, :seq_len]
        mask = mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)
        mask = ~mask  # Invert: True = can attend, False = masked
        
        return self.mha(x, x, x, mask=mask)
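
A quick way to confirm the causal mask works: changing a later token must not change the outputs at earlier positions. A small sketch (dropout set to 0 so the comparison is deterministic):
# Sanity check: perturbing only the last token leaves earlier outputs unchanged.
causal_attn = CausalSelfAttention(d_model=64, num_heads=4, dropout=0.0)
causal_attn.eval()

x = torch.randn(1, 8, 64)
x_perturbed = x.clone()
x_perturbed[:, -1, :] += 10.0        # change only the final position

with torch.no_grad():
    out1, _ = causal_attn(x)
    out2, _ = causal_attn(x_perturbed)

print("Earlier positions unchanged:", torch.allclose(out1[:, :-1], out2[:, :-1], atol=1e-6))
print("Last position changed:      ", not torch.allclose(out1[:, -1], out2[:, -1]))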

Padding Mask

For variable-length sequences with padding:
def create_padding_mask(seq, pad_idx=0):
    """
    Create mask to prevent attending to padding tokens.
    
    Args:
        seq: Token indices (batch, seq_len)
        pad_idx: Index of padding token
    
    Returns:
        mask: (batch, 1, 1, seq_len) - True where valid, False where padding
    """
    mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2)
    return mask


# Example
batch = torch.tensor([
    [5, 3, 7, 2, 0, 0],  # Padded with 2 zeros
    [8, 1, 4, 6, 9, 0],  # Padded with 1 zero
])

padding_mask = create_padding_mask(batch, pad_idx=0)
print("Sequence:")
print(batch)
print("\nPadding mask:")
print(padding_mask.squeeze())
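
Feeding this mask into the scaled_dot_product_attention function from earlier (squeezed to a shape that broadcasts against the score matrix) drives the attention paid to padding positions to zero. A small sketch with random embeddings:
# Apply the padding mask with the earlier scaled_dot_product_attention.
# The (batch, 1, 1, seq_len) mask is squeezed to (batch, 1, seq_len) so it
# broadcasts against scores of shape (batch, num_queries, num_keys).
embeddings = torch.randn(2, 6, 16)   # pretend embeddings for the padded batch above

out, attn_weights = scaled_dot_product_attention(
    embeddings, embeddings, embeddings,
    mask=padding_mask.squeeze(1)     # (2, 1, 6)
)

print("Attention paid to padding positions (should be all zeros):")
print(attn_weights[0, :, 4:])        # first sequence: positions 4-5 are padding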

Complete Attention-Based Seq2Seq

class AttentionSeq2Seq(nn.Module):
    """
    Sequence-to-sequence model with attention.
    
    Unlike vanilla seq2seq, the decoder can look at all encoder states
    at each decoding step, selecting relevant information.
    """
    
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, 
                 num_heads, num_layers, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        
        # Embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        
        # Positional encoding
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, dropout=dropout)
        
        # Encoder (bidirectional LSTM + self-attention)
        self.encoder_lstm = nn.LSTM(
            d_model, d_model // 2, num_layers,
            batch_first=True, bidirectional=True, dropout=dropout
        )
        self.encoder_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.encoder_norm = nn.LayerNorm(d_model)
        
        # Decoder
        self.decoder_lstm = nn.LSTM(
            d_model, d_model, num_layers,
            batch_first=True, dropout=dropout
        )
        
        # Cross-attention (decoder attends to encoder)
        self.cross_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_norm = nn.LayerNorm(d_model)
        
        # Output projection
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)
        
        self.dropout = nn.Dropout(dropout)
    
    def encode(self, src, src_mask=None):
        """Encode source sequence."""
        # Embed and add positional encoding
        x = self.src_embedding(src) * np.sqrt(self.d_model)
        x = self.pos_encoding(x)
        
        # LSTM encoding
        x, _ = self.encoder_lstm(x)
        
        # Self-attention
        attn_out, _ = self.encoder_attention(x, x, x, mask=src_mask)
        x = self.encoder_norm(x + self.dropout(attn_out))
        
        return x
    
    def decode(self, tgt, encoder_output, src_mask=None):
        """Decode target sequence."""
        # Embed and add positional encoding
        x = self.tgt_embedding(tgt) * np.sqrt(self.d_model)
        x = self.pos_encoding(x)
        
        # LSTM decoding
        x, _ = self.decoder_lstm(x)
        
        # Cross-attention to encoder output
        attn_out, attention_weights = self.cross_attention(
            x, encoder_output, encoder_output, mask=src_mask
        )
        x = self.cross_norm(x + self.dropout(attn_out))
        
        # Project to vocabulary
        logits = self.output_projection(x)
        
        return logits, attention_weights
    
    def forward(self, src, tgt, src_mask=None):
        encoder_output = self.encode(src, src_mask)
        logits, attention = self.decode(tgt, encoder_output, src_mask)
        return logits, attention


# Test
model = AttentionSeq2Seq(
    src_vocab_size=10000,
    tgt_vocab_size=8000,
    d_model=256,
    num_heads=8,
    num_layers=2
)

src = torch.randint(0, 10000, (4, 30))
tgt = torch.randint(0, 8000, (4, 25))

logits, attention = model(src, tgt)

print(f"Source: {src.shape}")
print(f"Target: {tgt.shape}")
print(f"Output logits: {logits.shape}")
print(f"Attention weights: {attention.shape}")

Exercises

Implement scaled dot-product attention without using any PyTorch attention functions:
  1. Implement the forward pass with proper scaling
  2. Add masking support (padding and causal)
  3. Verify gradients flow correctly
  4. Test on a simple sequence copying task
Using a pre-trained model (or train your own):
  1. Extract attention weights for various inputs
  2. Create visualizations showing what each head attends to
  3. Identify heads with interpretable patterns
  4. Compare attention patterns for different input types
Implement and compare:
  1. Dot-product attention
  2. Additive (Bahdanau) attention
  3. Multiplicative (Luong) attention
Train each on a translation task and compare:
  • Training speed
  • Final BLEU score
  • Attention patterns
Implement relative positional encoding:
  1. Instead of absolute positions, encode relative distances
  2. Modify attention scores to include position bias
  3. Compare with sinusoidal on long sequences
  4. Test extrapolation to longer sequences
Implement a more efficient attention mechanism:
  1. Implement local attention (attend only to nearby positions)
  2. Implement sparse attention patterns
  3. Compare memory usage and speed with full attention
  4. Evaluate impact on model quality

Key Takeaways

Concept | Key Insight
Attention | Weighted combination of values based on query-key similarity
Dot-Product | \text{softmax}(QK^T/\sqrt{d})V - simple and efficient
Multi-Head | Multiple attention functions for different relationship types
Self-Attention | Sequence attends to itself - O(1) dependency paths
Positional Encoding | Add position information, since attention is permutation-invariant
Masking | Causal for autoregressive models, padding for variable lengths
The attention mechanism is the foundation of modern NLP. Understanding it deeply will help you grasp Transformers, BERT, GPT, and virtually all state-of-the-art language models.

What’s Next