Attention Mechanism

The Bottleneck Problem

In the previous chapters, we built sequence-to-sequence models using encoder-decoder LSTMs. But there’s a fundamental problem: the entire source sequence is compressed into a single fixed-size vector. This is like asking someone to summarize a 300-page novel in a single tweet, and then expecting another person to reconstruct the entire plot from that tweet.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

def visualize_bottleneck():
    """
    The Encoder-Decoder Bottleneck Problem
    
    Source: "The quick brown fox jumps over the lazy dog"

    LSTM Encoder → Single vector (e.g., 512 dimensions)

    LSTM Decoder → "Le rapide renard brun saute par-dessus le chien paresseux"
    
    Problem: ALL information about a 9-word sentence must fit in 512 numbers!
    For longer sentences, the compression becomes increasingly lossy.
    """
    
    print("The Bottleneck Problem:")
    print("=" * 60)
    print()
    print("Source sentence (English):")
    print("  'The quick brown fox jumps over the lazy dog'")
    print()
    print("Compressed to: vector of shape (512,)")
    print()
    print("Then decoded to (French):")
    print("  'Le rapide renard brun saute par-dessus le chien paresseux'")
    print()
    print("❌ Problems:")
    print("  - Information loss for long sequences")
    print("  - Decoder must 'guess' what encoder meant")
    print("  - Early encoder states forgotten by final state")

visualize_bottleneck()
Figure: Encoder Bottleneck Problem

Evidence of the Problem

def show_translation_quality_vs_length():
    """
    Real observation: Translation quality degrades with sentence length
    when using fixed-size encoding.
    """
    
    # Approximate BLEU scores from research papers
    lengths = [10, 20, 30, 40, 50, 60]
    bleu_without_attention = [35, 32, 28, 22, 16, 12]
    bleu_with_attention = [38, 37, 36, 35, 34, 33]
    
    plt.figure(figsize=(10, 5))
    plt.plot(lengths, bleu_without_attention, 'o-', label='Without Attention', linewidth=2)
    plt.plot(lengths, bleu_with_attention, 's-', label='With Attention', linewidth=2)
    
    plt.xlabel('Source Sentence Length')
    plt.ylabel('BLEU Score')
    plt.title('Translation Quality vs Sentence Length')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
    
    print("Key insight: Attention maintains quality even for long sentences!")

# show_translation_quality_vs_length()

The Attention Solution

Intuition: Looking Back at the Source

Instead of forcing the decoder to use only the final encoder state, let it look at all encoder states and focus on the relevant ones.

When translating “dog” to “chien”:
  • Look at all encoder states
  • Focus attention on the state corresponding to “dog”
  • Use that information to generate “chien”
Key Insight: Attention computes a weighted combination of all encoder states, where weights indicate relevance to the current decoding step.
def attention_intuition():
    """
    Attention Intuition: The Query-Key-Value Framework
    
    Imagine you're at a library:
    - Query: "I need information about neural networks"
    - Keys: Book titles on the shelves
    - Values: The actual content of each book
    
    Process:
    1. Compare your query with all keys (book titles)
    2. Compute relevance scores
    3. Retrieve values (content) weighted by relevance
    """
    
    print("Attention as Information Retrieval:")
    print("=" * 60)
    print()
    print("Query (what we're looking for):")
    print("  → Current decoder state: 'trying to translate word X'")
    print()
    print("Keys (what we're comparing against):")
    print("  → All encoder hidden states")
    print()
    print("Values (what we retrieve):")
    print("  → Same encoder hidden states (or transformed versions)")
    print()
    print("Output:")
    print("  → Weighted sum of values, weights from query-key similarity")

attention_intuition()
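The library analogy can be made concrete with a few lines of plain Python. This is only an illustrative sketch: the two-dimensional vectors, the values, and the `soft_lookup` helper are made up for the example and do not come from a trained model.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def soft_lookup(query, keys, values):
    """Return (weights, output): a relevance-weighted average of the values."""
    # 1. Compare the query with every key (dot product = similarity)
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    # 2. Turn the scores into weights that sum to 1
    weights = softmax(scores)
    # 3. Retrieve the values, weighted by relevance
    dim = len(values[0])
    output = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return weights, output

# Hypothetical toy data: three "books", each with a key (title) and a value (content)
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[10.0], [20.0], [30.0]]
query = [2.0, 0.0]  # points in the direction of the first and third keys

weights, output = soft_lookup(query, keys, values)
print("weights:", [round(w, 3) for w in weights])
print("output: ", [round(o, 3) for o in output])
```

Unlike a dictionary lookup, no single entry is selected: every value contributes, in proportion to how well its key matches the query.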

Attention Mechanisms in Detail

Dot-Product Attention

The simplest form of attention:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Why divide by $\sqrt{d_k}$? This is not an arbitrary choice; it keeps the softmax well-conditioned during training. When $d_k$ is large (e.g., 512), each dot product $q \cdot k$ is the sum of 512 terms. If the entries of $q$ and $k$ are independent with zero mean and unit variance, each term has variance 1, and because the variances of independent terms add, the dot product has variance $d_k$. So the raw scores have standard deviation $\sqrt{d_k} \approx 22.6$. Softmax applied to values this large pushes almost all the probability mass onto a single key: the attention becomes close to a hard argmax, and its gradient effectively vanishes. Dividing by $\sqrt{d_k}$ rescales the scores to unit variance, keeping the softmax in its “useful” regime where gradients flow to multiple keys. Without this scaling, training larger models would be dramatically harder. This is one of those small details that separates a paper implementation from a working one.
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Scaled Dot-Product Attention -- the core computation inside every transformer.
    
    Args:
        query: (batch, num_queries, d_k) -- "what am I looking for?"
        key: (batch, num_keys, d_k) -- "what do I contain?"
        value: (batch, num_keys, d_v) -- "what information do I provide?"
        mask: Optional mask to prevent attention to certain positions
    
    Returns:
        output: (batch, num_queries, d_v) -- weighted combination of values
        attention_weights: (batch, num_queries, num_keys) -- where we looked
    """
    d_k = query.size(-1)
    
    # Step 1: Compute attention scores via dot product
    # High dot product = query and key point in similar directions = "relevant"
    # (batch, num_queries, d_k) @ (batch, d_k, num_keys) -> (batch, num_queries, num_keys)
    scores = torch.matmul(query, key.transpose(-2, -1))
    
    # Step 2: Scale by sqrt(d_k) -- this is critical and often misunderstood
    # Without scaling, the variance of the dot products grows proportionally
    # to d_k (their typical magnitude grows as sqrt(d_k)), pushing softmax
    # into saturation where gradients are near-zero. Dividing by sqrt(d_k)
    # keeps the variance of scores at ~1 regardless of dimension.
    scores = scores / np.sqrt(d_k)
    
    # Apply mask if provided
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)
    
    # Apply attention weights to values
    output = torch.matmul(attention_weights, value)
    
    return output, attention_weights


# Example
batch_size = 2
num_queries = 4  # e.g., decoder positions
num_keys = 6     # e.g., encoder positions
d_k = 64         # dimension of queries and keys
d_v = 64         # dimension of values

query = torch.randn(batch_size, num_queries, d_k)
key = torch.randn(batch_size, num_keys, d_k)
value = torch.randn(batch_size, num_keys, d_v)

output, weights = scaled_dot_product_attention(query, key, value)

print(f"Query shape:  {query.shape}")
print(f"Key shape:    {key.shape}")
print(f"Value shape:  {value.shape}")
print(f"Output shape: {output.shape}")
print(f"Weights shape: {weights.shape}")
print(f"\nAttention weights sum to 1: {weights.sum(dim=-1)}")
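The scaling argument can be checked empirically. The sketch below uses plain Python rather than PyTorch so the arithmetic stays visible; it assumes standard-normal entries for queries and keys, and the helper names (`sample_scores`, `compare_scaling`) and the trial count are choices made for this illustration.

```python
import math
import random
import statistics

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def sample_scores(d_k, num_keys, rng):
    """Dot products between one random query and num_keys random keys."""
    q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
    keys = [[rng.gauss(0.0, 1.0) for _ in range(d_k)] for _ in range(num_keys)]
    return [sum(a * b for a, b in zip(q, k)) for k in keys]

def compare_scaling(d_k=512, num_keys=6, trials=200, seed=0):
    """Return (std of raw scores, mean peak weight unscaled, mean peak weight scaled)."""
    rng = random.Random(seed)
    raw_scores, raw_peak, scaled_peak = [], [], []
    for _ in range(trials):
        scores = sample_scores(d_k, num_keys, rng)
        raw_scores.extend(scores)
        # Largest attention weight without and with the 1/sqrt(d_k) factor
        raw_peak.append(max(softmax(scores)))
        scaled_peak.append(max(softmax([s / math.sqrt(d_k) for s in scores])))
    return (statistics.pstdev(raw_scores),
            statistics.mean(raw_peak),
            statistics.mean(scaled_peak))

raw_std, raw_peak, scaled_peak = compare_scaling()
print(f"std of raw scores:        {raw_std:.1f} (sqrt(512) is about 22.6)")
print(f"mean peak weight, raw:    {raw_peak:.2f}")
print(f"mean peak weight, scaled: {scaled_peak:.2f}")
```

A peak weight close to 1 means the softmax has collapsed onto a single key; the scaled version should spread the weight over several keys.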

Visualizing Attention Weights

def visualize_attention():
    """Visualize what attention 'looks at'."""
    
    # Simulated attention weights for translation
    source = ["The", "cat", "sat", "on", "the", "mat", "."]
    target = ["Le", "chat", "était", "assis", "sur", "le", "tapis", "."]
    
    # Attention matrix (which source words each target word attends to)
    attention = np.array([
        [0.8, 0.1, 0.0, 0.0, 0.1, 0.0, 0.0],  # Le → The
        [0.1, 0.8, 0.0, 0.0, 0.0, 0.1, 0.0],  # chat → cat
        [0.0, 0.1, 0.7, 0.1, 0.0, 0.0, 0.1],  # était → sat
        [0.0, 0.1, 0.7, 0.1, 0.0, 0.0, 0.1],  # assis → sat
        [0.0, 0.0, 0.0, 0.8, 0.1, 0.0, 0.1],  # sur → on
        [0.1, 0.0, 0.0, 0.0, 0.8, 0.0, 0.1],  # le → the
        [0.0, 0.1, 0.0, 0.0, 0.0, 0.8, 0.1],  # tapis → mat
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],  # . → .
    ])
    
    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(attention, cmap='Blues')
    
    ax.set_xticks(np.arange(len(source)))
    ax.set_yticks(np.arange(len(target)))
    ax.set_xticklabels(source, fontsize=12)
    ax.set_yticklabels(target, fontsize=12)
    
    ax.set_xlabel("Source (English)", fontsize=14)
    ax.set_ylabel("Target (French)", fontsize=14)
    ax.set_title("Attention Weights in Translation", fontsize=16)
    
    # Add attention values in cells
    for i in range(len(target)):
        for j in range(len(source)):
            ax.text(j, i, f'{attention[i, j]:.1f}',
                    ha='center', va='center',
                    color='black' if attention[i, j] < 0.5 else 'white')
    
    plt.colorbar(im)
    plt.tight_layout()
    plt.show()

visualize_attention()
Figure: Attention Weight Visualization
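An attention matrix like the one above can also be read as a soft word alignment. As a rough sketch, taking the largest weight in each row gives a hard alignment; the `hard_alignment` helper is a name introduced here for illustration, and the weights are the same simulated values used in the plot, not the output of a trained model.

```python
source = ["The", "cat", "sat", "on", "the", "mat", "."]
target = ["Le", "chat", "était", "assis", "sur", "le", "tapis", "."]

# Same simulated weights as in the visualization (one row per target word)
attention = [
    [0.8, 0.1, 0.0, 0.0, 0.1, 0.0, 0.0],
    [0.1, 0.8, 0.0, 0.0, 0.0, 0.1, 0.0],
    [0.0, 0.1, 0.7, 0.1, 0.0, 0.0, 0.1],
    [0.0, 0.1, 0.7, 0.1, 0.0, 0.0, 0.1],
    [0.0, 0.0, 0.0, 0.8, 0.1, 0.0, 0.1],
    [0.1, 0.0, 0.0, 0.0, 0.8, 0.0, 0.1],
    [0.0, 0.1, 0.0, 0.0, 0.0, 0.8, 0.1],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
]

def hard_alignment(attention, source, target):
    """For each target word, return (target, most-attended source word, weight)."""
    pairs = []
    for tgt_word, row in zip(target, attention):
        best = max(range(len(row)), key=row.__getitem__)
        pairs.append((tgt_word, source[best], row[best]))
    return pairs

pairs = hard_alignment(attention, source, target)
for tgt_word, src_word, weight in pairs:
    print(f"{tgt_word:6s} <- {src_word:4s} (weight {weight:.1f})")
```

Note that two target words (“était” and “assis”) align to the same source word (“sat”), something a strict one-to-one alignment could not express.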

Types of Attention

Additive (Bahdanau) Attention

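The original attention mechanism from the 2014 paper:

$$e_{ij} = v^T \tanh(W_1 h_i + W_2 s_j)$$

$$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_k \exp(e_{kj})}$$

Here $h_i$ is the encoder state at source position $i$, $s_j$ is the decoder state at step $j$, and the softmax normalizes over the source positions $k$.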
class AdditiveAttention(nn.Module):
    """
    Bahdanau (Additive) Attention (2014).
    
    Uses a learned alignment model (a small neural network) to compute
    attention scores. More expressive but slower than dot-product attention
    because it requires an extra matrix multiplication and tanh activation.
    
    Why "additive"? The score adds two projections (W1*encoder + W2*decoder)
    then passes through tanh. This allows non-linear interactions between
    query and key, unlike the dot product, which is bilinear (linear in
    each argument).
    """
    
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        
        # Project encoder states
        self.W1 = nn.Linear(encoder_dim, attention_dim, bias=False)
        
        # Project decoder state
        self.W2 = nn.Linear(decoder_dim, attention_dim, bias=False)
        
        # Compute scalar attention score
        self.v = nn.Linear(attention_dim, 1, bias=False)
    
    def forward(self, encoder_outputs, decoder_hidden):
        """
        Args:
            encoder_outputs: (batch, src_len, encoder_dim)
            decoder_hidden: (batch, decoder_dim)
        
        Returns:
            context: (batch, encoder_dim)
            attention_weights: (batch, src_len)
        """
        # Project encoder outputs
        encoder_proj = self.W1(encoder_outputs)  # (batch, src_len, attention_dim)
        
        # Project decoder hidden (add dimension for broadcasting)
        decoder_proj = self.W2(decoder_hidden).unsqueeze(1)  # (batch, 1, attention_dim)
        
        # Compute alignment scores
        scores = self.v(torch.tanh(encoder_proj + decoder_proj))  # (batch, src_len, 1)
        scores = scores.squeeze(-1)  # (batch, src_len)
        
        # Softmax to get attention weights
        attention_weights = F.softmax(scores, dim=-1)
        
        # Compute context vector
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        context = context.squeeze(1)  # (batch, encoder_dim)
        
        return context, attention_weights


# Test
attention = AdditiveAttention(encoder_dim=512, decoder_dim=512, attention_dim=256)

encoder_outputs = torch.randn(8, 20, 512)  # 8 sequences, 20 timesteps
decoder_hidden = torch.randn(8, 512)

context, weights = attention(encoder_outputs, decoder_hidden)

print(f"Encoder outputs: {encoder_outputs.shape}")
print(f"Decoder hidden:  {decoder_hidden.shape}")
print(f"Context vector:  {context.shape}")
print(f"Attention weights: {weights.shape}")
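To see the additive score formula at work, here is a hand-sized numeric instance in plain Python. Everything in it is an assumption made for illustration: the identity matrices for W1 and W2, the all-ones vector v, and the three 2-dimensional encoder states are hand-picked, not learned.

```python
import math

def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * x for m, x in zip(row, vector)) for row in matrix]

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def additive_scores(encoder_states, decoder_state, W1, W2, v):
    """e_i = v^T tanh(W1 h_i + W2 s) for every encoder state h_i."""
    dec_proj = matvec(W2, decoder_state)
    scores = []
    for h in encoder_states:
        enc_proj = matvec(W1, h)
        hidden = [math.tanh(a + b) for a, b in zip(enc_proj, dec_proj)]
        scores.append(sum(vi * hi for vi, hi in zip(v, hidden)))
    return scores

# Hypothetical hand-picked parameters and states
W1 = [[1.0, 0.0], [0.0, 1.0]]
W2 = [[1.0, 0.0], [0.0, 1.0]]
v = [1.0, 1.0]
encoder_states = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
decoder_state = [1.0, 0.0]

scores = additive_scores(encoder_states, decoder_state, W1, W2, v)
weights = softmax(scores)
# Context vector: attention-weighted sum of the encoder states
context = [sum(w * h[d] for w, h in zip(weights, encoder_states)) for d in range(2)]

print("scores: ", [round(s, 4) for s in scores])
print("weights:", [round(w, 4) for w in weights])
print("context:", [round(c, 4) for c in context])
```

With these particular numbers the encoder state identical to the decoder state does not get the highest score, because tanh saturates each component separately. This is one way the additive score differs from a plain dot product.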

Multiplicative (Luong) Attention

Simpler and faster variants:
class MultiplicativeAttention(nn.Module):
    """
    Luong (Multiplicative) Attention variants.
    
    Variants:
    - dot: score = h_s · h_t
    - general: score = h_s · W · h_t
    - concat: score = v · tanh(W · [h_s; h_t])
    """
    
    def __init__(self, encoder_dim, decoder_dim, method='general'):
        super().__init__()
        
        self.method = method
        
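        if method == 'general':
            self.W = nn.Linear(decoder_dim, encoder_dim, bias=False)
        elif method == 'concat':
            self.W = nn.Linear(encoder_dim + decoder_dim, decoder_dim, bias=False)
            self.v = nn.Linear(decoder_dim, 1, bias=False)
        elif method != 'dot':
            # Fail early instead of hitting an undefined `scores` in forward()
            raise ValueError(f"Unknown attention method: {method!r}")
        # Note: 'dot' has no parameters and requires encoder_dim == decoder_dim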
    
    def forward(self, encoder_outputs, decoder_hidden):
        """
        Args:
            encoder_outputs: (batch, src_len, encoder_dim)
            decoder_hidden: (batch, decoder_dim)
        """
        if self.method == 'dot':
            # Simple dot product
            scores = torch.bmm(encoder_outputs, decoder_hidden.unsqueeze(-1))
            scores = scores.squeeze(-1)
            
        elif self.method == 'general':
            # Transform then dot product
            decoder_proj = self.W(decoder_hidden)  # (batch, encoder_dim)
            scores = torch.bmm(encoder_outputs, decoder_proj.unsqueeze(-1))
            scores = scores.squeeze(-1)
            
        elif self.method == 'concat':
            # Concatenate and project
            src_len = encoder_outputs.size(1)
            decoder_expanded = decoder_hidden.unsqueeze(1).expand(-1, src_len, -1)
            concat = torch.cat([encoder_outputs, decoder_expanded], dim=-1)
            scores = self.v(torch.tanh(self.W(concat))).squeeze(-1)
        
        attention_weights = F.softmax(scores, dim=-1)
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)
        
        return context, attention_weights


# Compare methods
for method in ['dot', 'general', 'concat']:
    attn = MultiplicativeAttention(512, 512, method=method)
    params = sum(p.numel() for p in attn.parameters())
    print(f"{method:8s}: {params:,} parameters")
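The parameter counts printed by the loop above can be derived by hand from the layer shapes in `MultiplicativeAttention`. The following sketch reproduces that arithmetic without PyTorch; `luong_param_count` is a helper name introduced here for illustration.

```python
def luong_param_count(encoder_dim, decoder_dim, method):
    """Parameter count implied by the bias-free linear layers of each variant."""
    if method == "dot":
        return 0  # no learned weights at all
    if method == "general":
        # W: decoder_dim -> encoder_dim
        return decoder_dim * encoder_dim
    if method == "concat":
        # W: (encoder_dim + decoder_dim) -> decoder_dim, plus v: decoder_dim -> 1
        return (encoder_dim + decoder_dim) * decoder_dim + decoder_dim
    raise ValueError(f"Unknown attention method: {method!r}")

counts = {m: luong_param_count(512, 512, m) for m in ("dot", "general", "concat")}
for method, count in counts.items():
    print(f"{method:8s}: {count:,} parameters")
```

For 512-dimensional states this gives 0, 262,144 and 524,800 parameters respectively, which is the trade-off behind the “simpler and faster” description: the dot variant is free, while concat costs about twice as much as general.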

Self-Attention

Attending to the Same Sequence

Self-attention allows each position in a sequence to attend to every position in the same sequence, including itself:
class SelfAttention(nn.Module):
    """
    Self-Attention: Sequence attends to itself.
    
    Each position can gather information from all other positions,
    allowing for direct modeling of dependencies regardless of distance.
    """
    
    def __init__(self, d_model):
        super().__init__()
        
        self.d_model = d_model
        
        # Project to query, key, value
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
    
    def forward(self, x, mask=None):
        """
        Args:
            x: Input sequence (batch, seq_len, d_model)
            mask: Optional attention mask
        
        Returns:
            output: (batch, seq_len, d_model)
            attention_weights: (batch, seq_len, seq_len)
        """
        # Project to Q, K, V
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Scaled dot-product attention
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        
        return output, attention_weights


# Example: Self-attention on a sentence
self_attn = SelfAttention(d_model=256)

# 4 sentences, 10 words each, 256-dim embeddings
x = torch.randn(4, 10, 256)

output, weights = self_attn(x)

print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
print(f"Attention weights: {weights.shape}")
print(f"\nEach word attends to all words (including itself)")

Why Self-Attention Matters

def compare_rnn_vs_self_attention():
    """
    Self-attention vs RNN for capturing dependencies.
    """
    
    print("Self-Attention Advantages over RNNs:")
    print("=" * 60)
    print()
    print("1. PATH LENGTH for long-range dependencies:")
    print("   RNN:  O(n) - information must flow through n steps")
    print("   Attn: O(1) - direct connection between any two positions")
    print()
    print("2. PARALLELIZATION:")
    print("   RNN:  Sequential - must process one step at a time")
    print("   Attn: Fully parallel - all positions computed simultaneously")
    print()
    print("3. INTERPRETABILITY:")
    print("   RNN:  Hidden state is opaque")
    print("   Attn: Attention weights show what the model focuses on")
    print()
    
    # Computational complexity comparison
    n = 1000  # sequence length
    d = 512   # hidden dimension
    
    rnn_complexity = n * d * d  # O(n·d²) operations, spread over n sequential steps
    attn_complexity = n * n * d  # O(n²·d) operations, but parallelizable
    
    print(f"For sequence length {n}, hidden dim {d}:")
    print(f"   RNN:  ~{rnn_complexity:,} ops (sequential)")
    print(f"   Attn: ~{attn_complexity:,} ops (parallel)")

compare_rnn_vs_self_attention()
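The two operation counts above cross over when the sequence length equals the hidden dimension. The sketch below tabulates this using the same rough per-layer estimates (constant factors ignored); the function names are introduced here for illustration.

```python
def rnn_ops(n, d):
    """Rough per-layer cost of an RNN: n steps, each a d x d matrix-vector product."""
    return n * d * d

def attn_ops(n, d):
    """Rough per-layer cost of self-attention: n x n scores over d dimensions."""
    return n * n * d

d = 512
for n in (128, 512, 2048):
    r, a = rnn_ops(n, d), attn_ops(n, d)
    if a < r:
        verdict = "attention cheaper"
    elif a == r:
        verdict = "equal"
    else:
        verdict = "RNN cheaper"
    print(f"n={n:5d}: RNN ~{r:>13,}  Attn ~{a:>13,}  -> {verdict}")
```

By this count, attention needs fewer operations when n < d and more when n > d. Because the attention operations are not sequential, it can still be faster in wall-clock time on parallel hardware even for longer sequences.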

Multi-Head Attention

Multiple Attention “Perspectives”

Instead of one attention function, use multiple “heads” that each learn different relationships:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$

$$\text{where head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

Why multiple heads instead of one big attention? A single attention head can only compute one weighted average per query position: it assigns a single scalar relevance to each key. But language has multiple simultaneous relationships: “it” relates to “animal” syntactically (coreference), to “tired” semantically (attribute), and to “cross” structurally (subject-verb). Multiple heads let the model maintain all these relationships simultaneously, with each head free to specialize in a different type of dependency.

The extra computational cost is negligible: if $d_{model} = 512$ and you use 8 heads, each head operates in a $d_k = 64$ dimensional subspace. The total computation is roughly the same as a single head with $d_k = 512$, but you get 8 independent attention patterns instead of 1. The output projection $W^O$ then learns how to combine these perspectives. This is one of the best “free lunches” in deep learning architecture design.
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention.
    
    Multiple attention heads allow the model to jointly attend to
    information from different representation subspaces.
    
    Analogy: Imagine a team of analysts reading the same document.
    One analyst focuses on who did what (subject-verb relationships).
    Another tracks pronoun references (coreference).
    A third watches for temporal ordering (before/after).
    Each analyst sees the same text but extracts different relationships.
    
    For example, in "The animal didn't cross the street because it was too tired":
    - One head might focus on "it" -> "animal" (coreference resolution)
    - Another head might focus on "tired" -> "animal" (attribute binding)
    - A third might focus on "cross" -> "street" (action-object pairing)
    
    The multi-head output combines all these perspectives.
    """
    
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        
        # Output projection
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (batch, seq_len_q, d_model)
            key: (batch, seq_len_k, d_model)
            value: (batch, seq_len_v, d_model)
            mask: Optional mask (batch, 1, seq_len_q, seq_len_k) or (batch, 1, 1, seq_len_k)
        
        Returns:
            output: (batch, seq_len_q, d_model)
            attention_weights: (batch, num_heads, seq_len_q, seq_len_k)
        """
        batch_size = query.size(0)
        
        # 1. Linear projections
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        
        # 2. Split into multiple heads
        # (batch, seq_len, d_model) → (batch, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # 3. Scaled dot-product attention for each head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        context = torch.matmul(attention_weights, V)
        
        # 4. Concatenate heads
        # (batch, num_heads, seq_len, d_k) → (batch, seq_len, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        
        # 5. Final linear projection
        output = self.W_o(context)
        
        return output, attention_weights


# Test multi-head attention
mha = MultiHeadAttention(d_model=512, num_heads=8)

# Self-attention: Q=K=V
x = torch.randn(2, 10, 512)
output, weights = mha(x, x, x)

print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
print(f"Attention weights: {weights.shape}")
print(f"  → {weights.shape[1]} heads, each with {weights.shape[2]}×{weights.shape[3]} attention matrix")
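The claim that extra heads cost almost nothing can be checked with simple arithmetic on the `MultiHeadAttention` layer shapes. This is a back-of-the-envelope sketch: the function names are introduced here, and the multiplication count covers only the QK^T score computation.

```python
def mha_param_count(d_model):
    """Four nn.Linear(d_model, d_model) layers with bias: W_q, W_k, W_v, W_o."""
    return 4 * (d_model * d_model + d_model)

def score_multiplications(seq_len, d_model, num_heads):
    """Multiplications needed for the QK^T scores across all heads."""
    d_k = d_model // num_heads
    return num_heads * seq_len * seq_len * d_k

d_model, seq_len = 512, 10
print(f"parameters (any head count): {mha_param_count(d_model):,}")
for num_heads in (1, 2, 8):
    mults = score_multiplications(seq_len, d_model, num_heads)
    print(f"{num_heads} head(s): d_k = {d_model // num_heads:3d}, "
          f"score multiplications = {mults:,}")
```

The parameter count does not depend on `num_heads` at all, and the score computation stays at seq_len² × d_model multiplications, because splitting into heads only reshapes the same projected tensors.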

Visualizing Multi-Head Attention

def visualize_multi_head_attention():
    """
    Visualize what different attention heads learn.
    """
    
    sentence = ["The", "cat", "sat", "on", "the", "mat", ".", "<pad>"]
    n = len(sentence)
    
    # Simulated attention patterns for different heads
    # In practice, these are learned
    
    # Head 1: Focuses on adjacent words
    head1 = np.eye(n) * 0.4 + np.eye(n, k=1) * 0.3 + np.eye(n, k=-1) * 0.3
    head1 = head1 / head1.sum(axis=1, keepdims=True)
    
    # Head 2: Focuses on nouns (positions 1, 5 = cat, mat)
    head2 = np.zeros((n, n))
    for i in range(n):
        head2[i, [1, 5]] = [0.5, 0.5]
    
    # Head 3: Focuses on structure (The...the, cat...mat)
    head3 = np.eye(n) * 0.5
    head3[0, 4] = 0.5
    head3[4, 0] = 0.5
    head3[1, 5] = 0.25
    head3[5, 1] = 0.25
    head3 = head3 / head3.sum(axis=1, keepdims=True)
    
    # Head 4: Focuses on end of sentence
    head4 = np.zeros((n, n))
    head4[:, -2] = 1.0  # Everyone attends to "."
    
    heads = [head1, head2, head3, head4]
    titles = ["Adjacent Words", "Nouns", "Structural", "Punctuation"]
    
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    
    for ax, head, title in zip(axes, heads, titles):
        im = ax.imshow(head, cmap='Blues', vmin=0, vmax=1)
        ax.set_xticks(range(n))
        ax.set_yticks(range(n))
        ax.set_xticklabels(sentence, rotation=45, ha='right')
        ax.set_yticklabels(sentence)
        ax.set_title(f"Head: {title}")
    
    plt.tight_layout()
    plt.show()
    
    print("Different heads learn to attend to different types of relationships!")

visualize_multi_head_attention()
Figure: Multi-Head Attention Patterns

Positional Encoding

The Problem: Attention is Permutation-Invariant

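Unlike RNNs, self-attention has no inherent notion of position. Reordering the input tokens simply reorders the outputs in the same way (strictly speaking, the layer is permutation-equivariant):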
def demonstrate_position_invariance():
    """
    Self-attention treats "The cat sat" and "sat cat The" the same way!
    We need to explicitly add position information.
    """
    
    print("Problem: Attention is Position-Blind")
    print("=" * 60)
    print()
    print("Self-attention on 'The cat sat':")
    print("  Q_cat @ K_The → score (same regardless of position)")
    print()
    print("Self-attention on 'sat cat The':")
    print("  Q_cat @ K_The → SAME score!")
    print()
    print("Solution: Add positional encoding to embeddings")
    print("  embedding + position_encoding → position-aware representation")

demonstrate_position_invariance()
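The position-blindness described above can be verified numerically. The sketch below is a simplified self-attention in plain Python with identity projections (Q = K = V = input), which is an assumption made to keep the example short; the 2-dimensional “embeddings” are made up.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(xs):
    """Unprojected scaled dot-product self-attention over a list of vectors."""
    d = len(xs[0])
    outputs = []
    for q in xs:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in xs]
        weights = softmax(scores)
        outputs.append([sum(w * v[i] for w, v in zip(weights, xs)) for i in range(d)])
    return outputs

# Hypothetical embeddings, with no positional information added
the, cat, sat = [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]

out_a = self_attention([the, cat, sat])   # "The cat sat"
out_b = self_attention([sat, cat, the])   # "sat cat The"

print("output for 'The' in 'The cat sat':", [round(x, 4) for x in out_a[0]])
print("output for 'The' in 'sat cat The':", [round(x, 4) for x in out_b[2]])
```

Each word gets the same output vector in both orderings, so without positional encoding nothing downstream can tell the two sentences apart.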

Sinusoidal Positional Encoding

The original Transformer uses sinusoidal functions:

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$
class SinusoidalPositionalEncoding(nn.Module):
    """
    Sinusoidal Positional Encoding from "Attention Is All You Need".
    
    Properties:
    - Deterministic (no learned parameters)
    - Defined for every position, so it can in principle be applied to
      sequences longer than those seen during training
    - For a fixed offset k, PE(pos + k) is a linear function of PE(pos)
    """
    
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        
        self.dropout = nn.Dropout(dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        
        # Register as buffer (not a parameter, but saved with model)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, d_model)
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


# Visualize positional encoding
def visualize_positional_encoding():
    d_model = 128
    max_len = 100
    
    pe_layer = SinusoidalPositionalEncoding(d_model, max_len, dropout=0)
    
    # Get positional encodings directly from the registered buffer
    pe = pe_layer.pe[0, :max_len, :].numpy()
    
    plt.figure(figsize=(12, 6))
    plt.imshow(pe.T, aspect='auto', cmap='RdBu')
    plt.xlabel('Position')
    plt.ylabel('Dimension')
    plt.title('Sinusoidal Positional Encoding')
    plt.colorbar(label='Encoding Value')
    plt.show()
    
    # Show specific dimensions
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    
    for i, ax in enumerate(axes.flat):
        dim = i * 20
        ax.plot(pe[:, dim], label=f'dim {dim} (sin)')
        ax.plot(pe[:, dim+1], label=f'dim {dim+1} (cos)')
        ax.set_xlabel('Position')
        ax.set_ylabel('Value')
        ax.legend()
        ax.set_title(f'Dimensions {dim}, {dim+1}')
        ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

visualize_positional_encoding()
Figure: Positional Encoding Visualization
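The “linear function” property of sinusoidal encodings follows from the angle-addition identities: each (sin, cos) pair is rotated by an angle that depends only on the offset k, not on the position. The plain-Python sketch below checks this numerically; the function names and the choice of d_model = 16 are assumptions made for the illustration.

```python
import math

def positional_encoding(pos, d_model):
    """PE vector for one position, interleaving sin (even dims) and cos (odd dims)."""
    pe = []
    for i in range(0, d_model, 2):
        angle = pos / (10000 ** (i / d_model))
        pe.extend([math.sin(angle), math.cos(angle)])
    return pe

def shift(pe, k, d_model):
    """Map PE(pos) to PE(pos + k) using only k (a rotation per sin/cos pair)."""
    shifted = []
    for i in range(0, d_model, 2):
        w = 1.0 / (10000 ** (i / d_model))  # frequency of this pair
        s, c = pe[i], pe[i + 1]
        # sin(a + b) = sin a cos b + cos a sin b
        shifted.append(s * math.cos(w * k) + c * math.sin(w * k))
        # cos(a + b) = cos a cos b - sin a sin b
        shifted.append(c * math.cos(w * k) - s * math.sin(w * k))
    return shifted

d_model, pos, k = 16, 7, 5
direct = positional_encoding(pos + k, d_model)
rotated = shift(positional_encoding(pos, d_model), k, d_model)

max_error = max(abs(a - b) for a, b in zip(direct, rotated))
print(f"max |PE(pos + k) - shift(PE(pos), k)| = {max_error:.2e}")
```

Because `shift` never looks at `pos`, the same linear map works at every position, which is what makes relative offsets easy for the model to represent.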

Learned Positional Embeddings

An alternative: learn position embeddings like word embeddings:
class LearnedPositionalEncoding(nn.Module):
    """
    Learned positional embeddings.
    
    Properties:
    - Can learn task-specific position patterns
    - Cannot extrapolate beyond max_len
    - More parameters to learn
    """
    
    def __init__(self, d_model, max_len=512, dropout=0.1):
        super().__init__()
        
        self.position_embeddings = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout)
        
        # Initialize with small values
        nn.init.normal_(self.position_embeddings.weight, std=0.02)
    
    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, d_model)
        """
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        position_embeddings = self.position_embeddings(positions)
        
        x = x + position_embeddings
        return self.dropout(x)


# Compare parameter counts
sinusoidal = SinusoidalPositionalEncoding(512, 1000)
learned = LearnedPositionalEncoding(512, 1000)

print(f"Sinusoidal PE params: {sum(p.numel() for p in sinusoidal.parameters()):,}")
print(f"Learned PE params:    {sum(p.numel() for p in learned.parameters()):,}")

Attention with Masking

Causal (Autoregressive) Mask

For language models, we need to prevent attending to future positions:
def create_causal_mask(seq_len):
    """
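    Create an additive causal mask to prevent attending to future positions.
    
    For position i, can only attend to positions 0, 1, ..., i.
    
    Returns a float tensor with 0.0 at allowed positions and -inf at future
    positions; add it to the attention scores before the softmax. It is NOT
    compatible with the `mask == 0` convention used by MultiHeadAttention,
    where nonzero/True means "can attend".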
    """
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask


def visualize_causal_mask():
    seq_len = 6
    mask = create_causal_mask(seq_len)
    
    # Convert to a 0/1 array for plotting (0 = can attend, 1 = masked)
    viz_mask = torch.zeros_like(mask)
    viz_mask[mask == float('-inf')] = 1
    
    plt.figure(figsize=(6, 6))
    plt.imshow(viz_mask.numpy(), cmap='Greys')
    plt.xlabel('Key Position (can attend to)')
    plt.ylabel('Query Position')
    plt.title('Causal Mask\n(black = masked, white = can attend)')
    
    for i in range(seq_len):
        for j in range(seq_len):
            color = 'white' if viz_mask[i, j] == 1 else 'black'
            plt.text(j, i, '✓' if viz_mask[i,j] == 0 else '✗', 
                    ha='center', va='center', color=color, fontsize=16)
    
    plt.show()

visualize_causal_mask()


class CausalSelfAttention(nn.Module):
    """Self-attention with causal masking for autoregressive models."""
    
    def __init__(self, d_model, num_heads, max_len=512, dropout=0.1):
        super().__init__()
        
        self.mha = MultiHeadAttention(d_model, num_heads, dropout)
        
        # Register causal mask as buffer
        mask = torch.triu(torch.ones(max_len, max_len), diagonal=1).bool()
        self.register_buffer('mask', mask)
    
    def forward(self, x):
        seq_len = x.size(1)
        
        # Get appropriate size mask and expand for batch and heads
        mask = self.mask[:seq_len, :seq_len]
        mask = mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)
        mask = ~mask  # Invert: True = can attend, False = masked
        
        return self.mha(x, x, x, mask=mask)
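Two mask conventions appear in this section: `create_causal_mask` returns an additive mask (0 or -inf, added to the scores), while `MultiHeadAttention` expects a mask where nonzero/True means "can attend" and fills the rest with -inf. The sketch below, which assumes PyTorch is installed, shows that both give the same attention weights when each is applied in its own way; the two helper names are introduced here for illustration.

```python
import torch

def boolean_causal_mask(seq_len):
    """True where attention is allowed (key position <= query position)."""
    return torch.tril(torch.ones(seq_len, seq_len)).bool()

def additive_causal_mask(seq_len):
    """0.0 where attention is allowed, -inf on future positions."""
    future = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return torch.zeros(seq_len, seq_len).masked_fill(future, float("-inf"))

torch.manual_seed(0)
scores = torch.randn(4, 4)  # stand-in for Q @ K^T / sqrt(d_k)

# Boolean convention: overwrite disallowed scores, then softmax
allowed = boolean_causal_mask(4)
weights_bool = torch.softmax(scores.masked_fill(~allowed, float("-inf")), dim=-1)

# Additive convention: add the mask to the scores, then softmax
weights_add = torch.softmax(scores + additive_causal_mask(4), dim=-1)

print(weights_bool)
print("same result:", torch.allclose(weights_bool, weights_add))
```

Mixing the two is an easy mistake: passing an additive mask to code that tests `mask == 0` would mask exactly the positions that should stay visible.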

Padding Mask

For variable-length sequences with padding:
def create_padding_mask(seq, pad_idx=0):
    """
    Create mask to prevent attending to padding tokens.
    
    Args:
        seq: Token indices (batch, seq_len)
        pad_idx: Index of padding token
    
    Returns:
        mask: (batch, 1, 1, seq_len) - True where valid, False where padding
    """
    mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2)
    return mask


# Example
batch = torch.tensor([
    [5, 3, 7, 2, 0, 0],  # Padded with 2 zeros
    [8, 1, 4, 6, 9, 0],  # Padded with 1 zero
])

padding_mask = create_padding_mask(batch, pad_idx=0)
print("Sequence:")
print(batch)
print("\nPadding mask:")
print(padding_mask.squeeze())
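In a decoder, the padding mask and the causal mask are usually needed at the same time. One possible way to combine them, assuming PyTorch and the boolean "True = can attend" convention, is a broadcasted logical AND; `boolean_causal_mask` is a helper introduced here for illustration.

```python
import torch

def create_padding_mask(seq, pad_idx=0):
    """(batch, 1, 1, seq_len): True where the key is a real token."""
    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)

def boolean_causal_mask(seq_len):
    """(1, 1, seq_len, seq_len): True where key position <= query position."""
    return torch.tril(torch.ones(seq_len, seq_len)).bool().unsqueeze(0).unsqueeze(0)

batch = torch.tensor([
    [5, 3, 7, 2, 0, 0],  # padded with 2 zeros
    [8, 1, 4, 6, 9, 0],  # padded with 1 zero
])

# Broadcasting: (batch, 1, 1, seq) & (1, 1, seq, seq) -> (batch, 1, seq, seq)
combined = create_padding_mask(batch) & boolean_causal_mask(batch.size(1))

print(combined.shape)
print(combined[0, 0].int())  # mask for the first sequence
```

The result has the `(batch, 1, seq_len_q, seq_len_k)` shape that `MultiHeadAttention` accepts. Every query row here keeps at least one visible key (the first token), which avoids rows that are entirely -inf and would turn into NaN after the softmax.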

Complete Attention-Based Seq2Seq

class AttentionSeq2Seq(nn.Module):
    """
    Sequence-to-sequence model with attention.
    
    Unlike vanilla seq2seq, the decoder can look at all encoder states
    at each decoding step, selecting relevant information.
    """
    
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, 
                 num_heads, num_layers, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        
        # Embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        
        # Positional encoding
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, dropout=dropout)
        
        # Encoder (bidirectional LSTM + self-attention)
        self.encoder_lstm = nn.LSTM(
            d_model, d_model // 2, num_layers,
            batch_first=True, bidirectional=True, dropout=dropout
        )
        self.encoder_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.encoder_norm = nn.LayerNorm(d_model)
        
        # Decoder
        self.decoder_lstm = nn.LSTM(
            d_model, d_model, num_layers,
            batch_first=True, dropout=dropout
        )
        
        # Cross-attention (decoder attends to encoder)
        self.cross_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_norm = nn.LayerNorm(d_model)
        
        # Output projection
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)
        
        self.dropout = nn.Dropout(dropout)
    
    def encode(self, src, src_mask=None):
        """Encode source sequence."""
        # Embed and add positional encoding
        x = self.src_embedding(src) * np.sqrt(self.d_model)
        x = self.pos_encoding(x)
        
        # LSTM encoding
        x, _ = self.encoder_lstm(x)
        
        # Self-attention
        attn_out, _ = self.encoder_attention(x, x, x, mask=src_mask)
        x = self.encoder_norm(x + self.dropout(attn_out))
        
        return x
    
    def decode(self, tgt, encoder_output, src_mask=None):
        """Decode target sequence."""
        # Embed and add positional encoding
        x = self.tgt_embedding(tgt) * np.sqrt(self.d_model)
        x = self.pos_encoding(x)
        
        # LSTM decoding
        x, _ = self.decoder_lstm(x)
        
        # Cross-attention to encoder output
        attn_out, attention_weights = self.cross_attention(
            x, encoder_output, encoder_output, mask=src_mask
        )
        x = self.cross_norm(x + self.dropout(attn_out))
        
        # Project to vocabulary
        logits = self.output_projection(x)
        
        return logits, attention_weights
    
    def forward(self, src, tgt, src_mask=None):
        encoder_output = self.encode(src, src_mask)
        logits, attention = self.decode(tgt, encoder_output, src_mask)
        return logits, attention


# Test
model = AttentionSeq2Seq(
    src_vocab_size=10000,
    tgt_vocab_size=8000,
    d_model=256,
    num_heads=8,
    num_layers=2
)

src = torch.randint(0, 10000, (4, 30))
tgt = torch.randint(0, 8000, (4, 25))

logits, attention = model(src, tgt)

print(f"Source: {src.shape}")
print(f"Target: {tgt.shape}")
print(f"Output logits: {logits.shape}")
print(f"Attention weights: {attention.shape}")
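
The test above only checks shapes. For training, a model like this is usually fed the target shifted by one position (teacher forcing) and scored with cross-entropy that ignores padding. The sketch below illustrates that step only; it uses random logits as a stand-in for `model(src, decoder_input)`, and it assumes `pad_idx=0` as in the padding-mask example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pad_idx = 0  # assumption: same padding index as create_padding_mask
batch_size, tgt_len, tgt_vocab_size = 4, 25, 8000

tgt = torch.randint(1, tgt_vocab_size, (batch_size, tgt_len))
tgt[0, 20:] = pad_idx  # pretend the first sequence is shorter

# Teacher forcing: the decoder sees tokens 0..T-2 and predicts tokens 1..T-1.
decoder_input = tgt[:, :-1]  # (4, 24)
labels = tgt[:, 1:]          # (4, 24)

# Stand-in for: logits, _ = model(src, decoder_input)
logits = torch.randn(
    batch_size, decoder_input.size(1), tgt_vocab_size, requires_grad=True
)

# Flatten to (batch * steps, vocab) and skip padded label positions.
loss = F.cross_entropy(
    logits.reshape(-1, tgt_vocab_size),
    labels.reshape(-1),
    ignore_index=pad_idx,
)
loss.backward()

# Padded positions contribute no gradient.
pad_grad = logits.grad[0, 19:].abs().sum().item()
print(f"loss: {loss.item():.3f}")
print(f"gradient mass on padded positions: {pad_grad}")
```

Because the decoder LSTM above is unidirectional, it cannot see later target tokens, so the shift by one is what keeps the model from simply copying its input.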

Exercises

Implement scaled dot-product attention without using any PyTorch attention functions:
  1. Implement the forward pass with proper scaling
  2. Add masking support (padding and causal)
  3. Verify gradients flow correctly
  4. Test on a simple sequence copying task

Using a pre-trained model (or one you train yourself):
  1. Extract attention weights for various inputs
  2. Create visualizations showing what each head attends to
  3. Identify heads with interpretable patterns
  4. Compare attention patterns for different input types

Implement and compare:
  1. Dot-product attention
  2. Additive (Bahdanau) attention
  3. Multiplicative (Luong) attention
Train each on a translation task and compare:
  • Training speed
  • Final BLEU score
  • Attention patterns

Implement relative positional encoding:
  1. Instead of absolute positions, encode relative distances
  2. Modify attention scores to include position bias
  3. Compare with sinusoidal on long sequences
  4. Test extrapolation to longer sequences

Implement a more efficient attention mechanism:
  1. Implement local attention (attend only to nearby positions)
  2. Implement sparse attention patterns
  3. Compare memory usage and speed with full attention
  4. Evaluate impact on model quality

Training Pitfalls and Debugging Hints

Attention weights sum to 1, but that does not mean they are correct. A common mistake is to interpret attention weights as an "explanation" of model behavior. Attention weights tell you where the model looked, not why it made its decision. Two models with identical outputs can have completely different attention patterns. Use attention visualizations for debugging intuition, not for mechanistic explanations.

Uniform attention is a red flag. If your attention weights look roughly uniform (every position gets about $1/n$ of the attention), the model has not learned meaningful attention patterns. Common causes: (1) the learning rate is too low, (2) embedding quality is poor (the model cannot distinguish queries from keys), (3) the task does not actually require attention (try a simpler baseline).

Attention memory scales quadratically with sequence length. For sequence length $n$ and $h$ heads, the attention weight matrix consumes $O(n^2 \cdot h)$ memory per sequence. At $n = 4096$ with 12 heads in float32, that is $4096^2 \times 12 \times 4$ bytes, about 805 MB (768 MiB), per layer for each sequence in the batch, just for attention scores. If you run out of memory during training, sequence length is almost always the bottleneck; reduce it before reducing batch size, since halving $n$ cuts this term by a factor of four.

Mask shape mismatches are the most frustrating attention bug. PyTorch broadcasting rules mean a wrong mask shape often produces no error but silently computes the wrong thing. Always verify: padding masks should broadcast across heads and queries, causal masks should be square and lower-triangular, and the combination should use logical AND (with the "True = can attend" convention). Print mask.shape and scores.shape during debugging and verify they broadcast correctly.

Multi-head attention head collapse. Sometimes all heads learn nearly identical attention patterns, wasting capacity. This can happen with high dropout on attention weights or a very small $d_k$ (dimension per head). Monitor head diversity by computing the cosine similarity between the attention patterns of different heads; if it is consistently above 0.9, you have collapse. Fix: reduce attention dropout, increase $d_k$, or add auxiliary losses that encourage head diversity.
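
Several of the checks above are easy to automate. The sketch below is one possible set of diagnostics, not a standard API: all three function names are hypothetical, and the attention weights are assumed to have shape (batch, heads, query_len, key_len).

```python
import math
import torch
import torch.nn.functional as F

def attention_scores_bytes(seq_len, num_heads, batch_size=1, bytes_per_element=4):
    """Memory for one layer's attention weight matrix (float32 by default)."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element

def mean_normalized_entropy(weights, eps=1e-9):
    """1.0 means uniform attention; values near 0 mean sharply peaked attention."""
    entropy = -(weights * (weights + eps).log()).sum(dim=-1)
    return (entropy / math.log(weights.size(-1))).mean().item()

def mean_head_similarity(weights):
    """Average cosine similarity between attention maps of different heads."""
    batch, heads = weights.shape[:2]
    flat = F.normalize(weights.reshape(batch, heads, -1), dim=-1)
    sim = flat @ flat.transpose(-2, -1)             # (batch, heads, heads)
    off_diag = ~torch.eye(heads, dtype=torch.bool)  # drop each head vs. itself
    return sim[:, off_diag].mean().item()

# Memory estimate for n = 4096, 12 heads, one sequence
mib = attention_scores_bytes(4096, 12) / 2**20
print(f"attention scores per layer: {mib:.0f} MiB")

# Worst case: every head attends uniformly (uniform AND collapsed)
uniform = torch.full((1, 4, 6, 6), 1.0 / 6)
uniform_entropy = mean_normalized_entropy(uniform)
uniform_similarity = mean_head_similarity(uniform)

# Opposite extreme: head i always attends only to key i (peaked AND diverse)
peaked = torch.zeros(1, 4, 6, 6)
for head in range(4):
    peaked[0, head, :, head] = 1.0
peaked_entropy = mean_normalized_entropy(peaked)
peaked_similarity = mean_head_similarity(peaked)

print(f"uniform: entropy={uniform_entropy:.3f}, head similarity={uniform_similarity:.3f}")
print(f"peaked:  entropy={peaked_entropy:.3f}, head similarity={peaked_similarity:.3f}")
```

In practice you would log these values on the attention weights returned by your own attention layers every few hundred steps; the 0.9 similarity threshold mentioned above is a rule of thumb, not a hard limit.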

Key Takeaways

| Concept | Key Insight |
|---|---|
| Attention | Weighted combination of values based on query-key similarity |
| Dot-Product | $\text{softmax}(QK^T/\sqrt{d})V$: simple and efficient |
| Multi-Head | Multiple attention functions for different relationship types |
| Self-Attention | Sequence attends to itself; any two positions are connected by an $O(1)$ path |
| Positional Encoding | Adds position information, since attention by itself ignores token order |
| Masking | Causal for autoregressive models, padding for variable lengths |

The attention mechanism is the foundation of modern NLP. Understanding it deeply will help you grasp Transformers, BERT, GPT, and virtually all state-of-the-art language models.

What’s Next

Module 11: Transformers

Build the complete Transformer architecture — combine attention with feed-forward networks, layer normalization, and residual connections to create the model that revolutionized NLP.