Attention Mechanism
The Bottleneck Problem
In the previous chapters, we built sequence-to-sequence models using encoder-decoder LSTMs. But there’s a fundamental problem: the entire source sequence is compressed into a single fixed-size vector.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
def visualize_bottleneck():
"""
The Encoder-Decoder Bottleneck Problem
Source: "The quick brown fox jumps over the lazy dog"
↓
LSTM Encoder → Single vector (e.g., 512 dimensions)
↓
LSTM Decoder → "Le rapide renard brun saute par-dessus le chien paresseux"
Problem: ALL information about a 9-word sentence must fit in 512 numbers!
For longer sentences, this becomes impossible.
"""
print("The Bottleneck Problem:")
print("=" * 60)
print()
print("Source sentence (English):")
print(" 'The quick brown fox jumps over the lazy dog'")
print()
print("Compressed to: vector of shape (512,)")
print()
print("Then decoded to (French):")
print(" 'Le rapide renard brun saute par-dessus le chien paresseux'")
print()
print("❌ Problems:")
print(" - Information loss for long sequences")
print(" - Decoder must 'guess' what encoder meant")
print(" - Early encoder states forgotten by final state")
visualize_bottleneck()
Evidence of the Problem
def show_translation_quality_vs_length():
"""
Real observation: Translation quality degrades with sentence length
when using fixed-size encoding.
"""
# Approximate BLEU scores from research papers
lengths = [10, 20, 30, 40, 50, 60]
bleu_without_attention = [35, 32, 28, 22, 16, 12]
bleu_with_attention = [38, 37, 36, 35, 34, 33]
plt.figure(figsize=(10, 5))
plt.plot(lengths, bleu_without_attention, 'o-', label='Without Attention', linewidth=2)
plt.plot(lengths, bleu_with_attention, 's-', label='With Attention', linewidth=2)
plt.xlabel('Source Sentence Length')
plt.ylabel('BLEU Score')
plt.title('Translation Quality vs Sentence Length')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
print("Key insight: Attention maintains quality even for long sentences!")
# show_translation_quality_vs_length()
The Attention Solution
Intuition: Looking Back at the Source
Instead of forcing the decoder to use only the final encoder state, let it look at all encoder states and focus on the relevant ones. When translating “dog” to “chien”:
- Look at all encoder states
- Focus attention on the state corresponding to “dog”
- Use that information to generate “chien”
Key Insight: Attention computes a weighted combination of all encoder states, where weights indicate relevance to the current decoding step.
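As a tiny worked example, suppose the encoder produced three 2-dimensional states and the attention weights for the current decoding step put most of their mass on the state for “dog” (the numbers are invented purely for illustration):
# Hypothetical encoder states for "the", "lazy", "dog" (2-dimensional for readability)
toy_states = torch.tensor([[1.0, 0.0],   # "the"
                           [0.0, 1.0],   # "lazy"
                           [2.0, 2.0]])  # "dog"
# Attention weights for the current decoding step (sum to 1, mostly on "dog")
toy_weights = torch.tensor([0.05, 0.05, 0.90])
toy_context = toy_weights @ toy_states   # weighted combination of encoder states
print(toy_context)  # tensor([1.8500, 1.8500]) -- dominated by the "dog" state
The context vector is what the decoder consumes at that step; the rest of this chapter is about how those weights are computed.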
def attention_intuition():
"""
Attention Intuition: The Query-Key-Value Framework
Imagine you're at a library:
- Query: "I need information about neural networks"
- Keys: Book titles on the shelves
- Values: The actual content of each book
Process:
1. Compare your query with all keys (book titles)
2. Compute relevance scores
3. Retrieve values (content) weighted by relevance
"""
print("Attention as Information Retrieval:")
print("=" * 60)
print()
print("Query (what we're looking for):")
print(" → Current decoder state: 'trying to translate word X'")
print()
print("Keys (what we're comparing against):")
print(" → All encoder hidden states")
print()
print("Values (what we retrieve):")
print(" → Same encoder hidden states (or transformed versions)")
print()
print("Output:")
print(" → Weighted sum of values, weights from query-key similarity")
attention_intuition()
Attention Mechanisms in Detail
Dot-Product Attention
The simplest form of attention is the scaled dot product:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
def scaled_dot_product_attention(query, key, value, mask=None):
"""
Scaled Dot-Product Attention.
Args:
query: (batch, num_queries, d_k)
key: (batch, num_keys, d_k)
value: (batch, num_keys, d_v)
mask: Optional mask to prevent attention to certain positions
Returns:
output: (batch, num_queries, d_v)
attention_weights: (batch, num_queries, num_keys)
"""
d_k = query.size(-1)
# Compute attention scores
# (batch, num_queries, d_k) @ (batch, d_k, num_keys) → (batch, num_queries, num_keys)
scores = torch.matmul(query, key.transpose(-2, -1))
# Scale by sqrt(d_k) to prevent softmax saturation
scores = scores / np.sqrt(d_k)
# Apply mask if provided
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax to get attention weights
attention_weights = F.softmax(scores, dim=-1)
# Apply attention weights to values
output = torch.matmul(attention_weights, value)
return output, attention_weights
# Example
batch_size = 2
num_queries = 4 # e.g., decoder positions
num_keys = 6 # e.g., encoder positions
d_k = 64 # dimension of queries and keys
d_v = 64 # dimension of values
query = torch.randn(batch_size, num_queries, d_k)
key = torch.randn(batch_size, num_keys, d_k)
value = torch.randn(batch_size, num_keys, d_v)
output, weights = scaled_dot_product_attention(query, key, value)
print(f"Query shape: {query.shape}")
print(f"Key shape: {key.shape}")
print(f"Value shape: {value.shape}")
print(f"Output shape: {output.shape}")
print(f"Weights shape: {weights.shape}")
print(f"\nAttention weights sum to 1: {weights.sum(dim=-1)}")
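Masking is covered in detail later in this chapter; as a quick illustration, here is the same function with a lower-triangular (causal) mask that blocks attention to future positions. The mask convention matches the code above: positions where mask == 0 receive -inf scores.
# Illustrative causal mask: query i may only attend to keys 0..i
causal_mask = torch.tril(torch.ones(num_queries, num_queries)).unsqueeze(0)  # (1, Q, Q)
q = torch.randn(batch_size, num_queries, d_k)
k = torch.randn(batch_size, num_queries, d_k)
v = torch.randn(batch_size, num_queries, d_v)
masked_out, masked_weights = scaled_dot_product_attention(q, k, v, mask=causal_mask)
print(masked_weights[0])
print("Upper triangle is zero:", torch.all(masked_weights.triu(diagonal=1) == 0).item())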
Visualizing Attention Weights
def visualize_attention():
"""Visualize what attention 'looks at'."""
# Simulated attention weights for translation
source = ["The", "cat", "sat", "on", "the", "mat", "."]
target = ["Le", "chat", "était", "assis", "sur", "le", "tapis", "."]
# Attention matrix (which source words each target word attends to)
attention = np.array([
[0.8, 0.1, 0.0, 0.0, 0.1, 0.0, 0.0], # Le → The
[0.1, 0.8, 0.0, 0.0, 0.0, 0.1, 0.0], # chat → cat
[0.0, 0.1, 0.7, 0.1, 0.0, 0.0, 0.1], # était → sat
[0.0, 0.1, 0.7, 0.1, 0.0, 0.0, 0.1], # assis → sat
[0.0, 0.0, 0.0, 0.8, 0.1, 0.0, 0.1], # sur → on
[0.1, 0.0, 0.0, 0.0, 0.8, 0.0, 0.1], # le → the
[0.0, 0.1, 0.0, 0.0, 0.0, 0.8, 0.1], # tapis → mat
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # . → .
])
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(attention, cmap='Blues')
ax.set_xticks(np.arange(len(source)))
ax.set_yticks(np.arange(len(target)))
ax.set_xticklabels(source, fontsize=12)
ax.set_yticklabels(target, fontsize=12)
ax.set_xlabel("Source (English)", fontsize=14)
ax.set_ylabel("Target (French)", fontsize=14)
ax.set_title("Attention Weights in Translation", fontsize=16)
# Add attention values in cells
for i in range(len(target)):
for j in range(len(source)):
text = ax.text(j, i, f'{attention[i, j]:.1f}',
ha='center', va='center', color='black' if attention[i,j] < 0.5 else 'white')
plt.colorbar(im)
plt.tight_layout()
plt.show()
visualize_attention()
Types of Attention
Additive (Bahdanau) Attention
The original attention mechanism from Bahdanau et al.’s 2014 paper:

e_ij = v^T tanh(W_1 h_i + W_2 s_j)
α_ij = exp(e_ij) / Σ_k exp(e_kj)
class AdditiveAttention(nn.Module):
"""
Bahdanau (Additive) Attention.
Uses a learned alignment model to compute attention scores.
More expressive but slower than dot-product attention.
"""
def __init__(self, encoder_dim, decoder_dim, attention_dim):
super().__init__()
# Project encoder states
self.W1 = nn.Linear(encoder_dim, attention_dim, bias=False)
# Project decoder state
self.W2 = nn.Linear(decoder_dim, attention_dim, bias=False)
# Compute scalar attention score
self.v = nn.Linear(attention_dim, 1, bias=False)
def forward(self, encoder_outputs, decoder_hidden):
"""
Args:
encoder_outputs: (batch, src_len, encoder_dim)
decoder_hidden: (batch, decoder_dim)
Returns:
context: (batch, encoder_dim)
attention_weights: (batch, src_len)
"""
# Project encoder outputs
encoder_proj = self.W1(encoder_outputs) # (batch, src_len, attention_dim)
# Project decoder hidden (add dimension for broadcasting)
decoder_proj = self.W2(decoder_hidden).unsqueeze(1) # (batch, 1, attention_dim)
# Compute alignment scores
scores = self.v(torch.tanh(encoder_proj + decoder_proj)) # (batch, src_len, 1)
scores = scores.squeeze(-1) # (batch, src_len)
# Softmax to get attention weights
attention_weights = F.softmax(scores, dim=-1)
# Compute context vector
context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
context = context.squeeze(1) # (batch, encoder_dim)
return context, attention_weights
# Test
attention = AdditiveAttention(encoder_dim=512, decoder_dim=512, attention_dim=256)
encoder_outputs = torch.randn(8, 20, 512) # 8 sequences, 20 timesteps
decoder_hidden = torch.randn(8, 512)
context, weights = attention(encoder_outputs, decoder_hidden)
print(f"Encoder outputs: {encoder_outputs.shape}")
print(f"Decoder hidden: {decoder_hidden.shape}")
print(f"Context vector: {context.shape}")
print(f"Attention weights: {weights.shape}")
Multiplicative (Luong) Attention
Luong et al. (2015) proposed simpler and faster variants:
class MultiplicativeAttention(nn.Module):
"""
Luong (Multiplicative) Attention variants.
Variants:
- dot: score = h_s · h_t
- general: score = h_s · W · h_t
- concat: score = v · tanh(W · [h_s; h_t])
"""
def __init__(self, encoder_dim, decoder_dim, method='general'):
super().__init__()
self.method = method
if method == 'general':
self.W = nn.Linear(decoder_dim, encoder_dim, bias=False)
elif method == 'concat':
self.W = nn.Linear(encoder_dim + decoder_dim, decoder_dim, bias=False)
self.v = nn.Linear(decoder_dim, 1, bias=False)
def forward(self, encoder_outputs, decoder_hidden):
"""
Args:
encoder_outputs: (batch, src_len, encoder_dim)
decoder_hidden: (batch, decoder_dim)
"""
if self.method == 'dot':
# Simple dot product
scores = torch.bmm(encoder_outputs, decoder_hidden.unsqueeze(-1))
scores = scores.squeeze(-1)
elif self.method == 'general':
# Transform then dot product
decoder_proj = self.W(decoder_hidden) # (batch, encoder_dim)
scores = torch.bmm(encoder_outputs, decoder_proj.unsqueeze(-1))
scores = scores.squeeze(-1)
elif self.method == 'concat':
# Concatenate and project
src_len = encoder_outputs.size(1)
decoder_expanded = decoder_hidden.unsqueeze(1).expand(-1, src_len, -1)
concat = torch.cat([encoder_outputs, decoder_expanded], dim=-1)
scores = self.v(torch.tanh(self.W(concat))).squeeze(-1)
attention_weights = F.softmax(scores, dim=-1)
context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)
return context, attention_weights
# Compare methods
for method in ['dot', 'general', 'concat']:
attn = MultiplicativeAttention(512, 512, method=method)
params = sum(p.numel() for p in attn.parameters())
print(f"{method:8s}: {params:,} parameters")
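The variants differ in parameter count but all produce the same output shapes; a quick forward pass with the tensors from the additive-attention test above confirms this:
# Shape check: every variant returns a context vector and one weight per source position
for method in ['dot', 'general', 'concat']:
    attn = MultiplicativeAttention(512, 512, method=method)
    context, weights = attn(encoder_outputs, decoder_hidden)
    print(f"{method:8s}: context {tuple(context.shape)}, weights {tuple(weights.shape)}")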
Self-Attention
Attending to the Same Sequence
Self-attention allows each position in a sequence to attend to all other positions:
class SelfAttention(nn.Module):
"""
Self-Attention: Sequence attends to itself.
Each position can gather information from all other positions,
allowing for direct modeling of dependencies regardless of distance.
"""
def __init__(self, d_model):
super().__init__()
self.d_model = d_model
# Project to query, key, value
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
"""
Args:
x: Input sequence (batch, seq_len, d_model)
mask: Optional attention mask
Returns:
output: (batch, seq_len, d_model)
attention_weights: (batch, seq_len, seq_len)
"""
# Project to Q, K, V
Q = self.W_q(x)
K = self.W_k(x)
V = self.W_v(x)
# Scaled dot-product attention
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, V)
return output, attention_weights
# Example: Self-attention on a sentence
self_attn = SelfAttention(d_model=256)
# 4 sentences, 10 words each, 256-dim embeddings
x = torch.randn(4, 10, 256)
output, weights = self_attn(x)
print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
print(f"Attention weights: {weights.shape}")
print(f"\nEach word attends to all words (including itself)")
Why Self-Attention Matters
def compare_rnn_vs_self_attention():
"""
Self-attention vs RNN for capturing dependencies.
"""
print("Self-Attention Advantages over RNNs:")
print("=" * 60)
print()
print("1. PATH LENGTH for long-range dependencies:")
print(" RNN: O(n) - information must flow through n steps")
print(" Attn: O(1) - direct connection between any two positions")
print()
print("2. PARALLELIZATION:")
print(" RNN: Sequential - must process one step at a time")
print(" Attn: Fully parallel - all positions computed simultaneously")
print()
print("3. INTERPRETABILITY:")
print(" RNN: Hidden state is opaque")
print(" Attn: Attention weights show what the model focuses on")
print()
# Computational complexity comparison
n = 1000 # sequence length
d = 512 # hidden dimension
rnn_complexity = n * d * d # O(n·d²) sequential steps
attn_complexity = n * n * d # O(n²·d) but parallelizable
print(f"For sequence length {n}, hidden dim {d}:")
print(f" RNN: ~{rnn_complexity:,} ops (sequential)")
print(f" Attn: ~{attn_complexity:,} ops (parallel)")
compare_rnn_vs_self_attention()
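To make the parallelization point concrete, here is a rough timing sketch comparing one forward pass of an LSTM and of the SelfAttention module defined above; absolute numbers depend entirely on your hardware, so treat them as indicative only.
import time
x_long = torch.randn(8, 256, 512)           # batch of 8, sequence length 256
rnn = nn.LSTM(512, 512, batch_first=True)
attn_layer = SelfAttention(d_model=512)
with torch.no_grad():
    start = time.time()
    rnn(x_long)                             # recurrent: 256 sequential steps
    rnn_ms = (time.time() - start) * 1000
    start = time.time()
    attn_layer(x_long)                      # attention: all positions at once
    attn_ms = (time.time() - start) * 1000
print(f"LSTM forward:           ~{rnn_ms:.1f} ms")
print(f"Self-attention forward: ~{attn_ms:.1f} ms")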
Multi-Head Attention
Multiple Attention “Perspectives”
Instead of one attention function, use multiple “heads” that each learn different relationships:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention.
Multiple attention heads allow the model to jointly attend to
information from different representation subspaces.
For example, in "The animal didn't cross the street because it was too tired":
- One head might focus on "it" → "animal" (coreference)
- Another head might focus on "tired" → "animal" (attribute)
"""
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# Linear projections for Q, K, V
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
# Output projection
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
"""
Args:
query: (batch, seq_len_q, d_model)
key: (batch, seq_len_k, d_model)
value: (batch, seq_len_v, d_model)
mask: Optional mask (batch, 1, seq_len_q, seq_len_k) or (batch, 1, 1, seq_len_k)
Returns:
output: (batch, seq_len_q, d_model)
attention_weights: (batch, num_heads, seq_len_q, seq_len_k)
"""
batch_size = query.size(0)
# 1. Linear projections
Q = self.W_q(query)
K = self.W_k(key)
V = self.W_v(value)
# 2. Split into multiple heads
# (batch, seq_len, d_model) → (batch, num_heads, seq_len, d_k)
Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 3. Scaled dot-product attention for each head
scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
context = torch.matmul(attention_weights, V)
# 4. Concatenate heads
# (batch, num_heads, seq_len, d_k) → (batch, seq_len, d_model)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# 5. Final linear projection
output = self.W_o(context)
return output, attention_weights
# Test multi-head attention
mha = MultiHeadAttention(d_model=512, num_heads=8)
# Self-attention: Q=K=V
x = torch.randn(2, 10, 512)
output, weights = mha(x, x, x)
print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
print(f"Attention weights: {weights.shape}")
print(f" → {weights.shape[1]} heads, each with {weights.shape[2]}×{weights.shape[3]} attention matrix")
Visualizing Multi-Head Attention
def visualize_multi_head_attention():
"""
Visualize what different attention heads learn.
"""
sentence = ["The", "cat", "sat", "on", "the", "mat", ".", "<pad>"]
n = len(sentence)
# Simulated attention patterns for different heads
# In practice, these are learned
# Head 1: Focuses on adjacent words
head1 = np.eye(n) * 0.4 + np.eye(n, k=1) * 0.3 + np.eye(n, k=-1) * 0.3
head1 = head1 / head1.sum(axis=1, keepdims=True)
# Head 2: Focuses on nouns (positions 1, 5 = cat, mat)
head2 = np.zeros((n, n))
for i in range(n):
head2[i, [1, 5]] = [0.5, 0.5]
# Head 3: Focuses on structure (The...the, cat...mat)
head3 = np.eye(n) * 0.5
head3[0, 4] = 0.5
head3[4, 0] = 0.5
head3[1, 5] = 0.25
head3[5, 1] = 0.25
head3 = head3 / head3.sum(axis=1, keepdims=True)
# Head 4: Focuses on end of sentence
head4 = np.zeros((n, n))
head4[:, -2] = 1.0 # Everyone attends to "."
heads = [head1, head2, head3, head4]
titles = ["Adjacent Words", "Nouns", "Structural", "Punctuation"]
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for ax, head, title in zip(axes, heads, titles):
im = ax.imshow(head, cmap='Blues', vmin=0, vmax=1)
ax.set_xticks(range(n))
ax.set_yticks(range(n))
ax.set_xticklabels(sentence, rotation=45, ha='right')
ax.set_yticklabels(sentence)
ax.set_title(f"Head: {title}")
plt.tight_layout()
plt.show()
print("Different heads learn to attend to different types of relationships!")
visualize_multi_head_attention()
Positional Encoding
The Problem: Attention is Permutation-Invariant
Unlike RNNs, self-attention has no inherent notion of position:
def demonstrate_position_invariance():
"""
Self-attention treats "The cat sat" and "sat cat The" the same way!
We need to explicitly add position information.
"""
print("Problem: Attention is Position-Blind")
print("=" * 60)
print()
print("Self-attention on 'The cat sat':")
print(" Q_cat @ K_The → score (same regardless of position)")
print()
print("Self-attention on 'sat cat The':")
print(" Q_cat @ K_The → SAME score!")
print()
print("Solution: Add positional encoding to embeddings")
print(" embedding + position_encoding → position-aware representation")
demonstrate_position_invariance()
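We can verify this numerically with the SelfAttention module from earlier: permuting the input sequence simply permutes the output, so without extra position information the model cannot distinguish different word orders.
# Permutation check: self-attention (without positional encoding) is permutation-equivariant
torch.manual_seed(0)
attn_check = SelfAttention(d_model=16)
tokens = torch.randn(1, 5, 16)              # a "sentence" of 5 token embeddings
perm = torch.tensor([4, 2, 0, 3, 1])        # shuffle the word order
out_original, _ = attn_check(tokens)
out_shuffled, _ = attn_check(tokens[:, perm, :])
# The shuffled output is just the original output, shuffled the same way
print(torch.allclose(out_original[:, perm, :], out_shuffled, atol=1e-6))  # True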
Sinusoidal Positional Encoding
The original Transformer uses sinusoidal functions:

PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
class SinusoidalPositionalEncoding(nn.Module):
"""
Sinusoidal Positional Encoding from "Attention Is All You Need".
Properties:
- Deterministic (no learned parameters)
- Can extrapolate to longer sequences than seen during training
- Relative positions can be represented as linear functions
"""
def __init__(self, d_model, max_len=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)
# Create positional encoding matrix
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, max_len, d_model)
# Register as buffer (not a parameter, but saved with model)
self.register_buffer('pe', pe)
def forward(self, x):
"""
Args:
x: (batch, seq_len, d_model)
"""
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
# Visualize positional encoding
def visualize_positional_encoding():
d_model = 128
max_len = 100
pe_layer = SinusoidalPositionalEncoding(d_model, max_len, dropout=0)
# Get positional encodings
x = torch.zeros(1, max_len, d_model)
pe = pe_layer.pe[0, :max_len, :].numpy()
plt.figure(figsize=(12, 6))
plt.imshow(pe.T, aspect='auto', cmap='RdBu')
plt.xlabel('Position')
plt.ylabel('Dimension')
plt.title('Sinusoidal Positional Encoding')
plt.colorbar(label='Encoding Value')
plt.show()
# Show specific dimensions
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
for i, ax in enumerate(axes.flat):
dim = i * 20
ax.plot(pe[:, dim], label=f'dim {dim} (sin)')
ax.plot(pe[:, dim+1], label=f'dim {dim+1} (cos)')
ax.set_xlabel('Position')
ax.set_ylabel('Value')
ax.legend()
ax.set_title(f'Dimensions {dim}, {dim+1}')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
visualize_positional_encoding()
Learned Positional Embeddings
An alternative: learn position embeddings like word embeddings:
class LearnedPositionalEncoding(nn.Module):
"""
Learned positional embeddings.
Properties:
- Can learn task-specific position patterns
- Cannot extrapolate beyond max_len
- More parameters to learn
"""
def __init__(self, d_model, max_len=512, dropout=0.1):
super().__init__()
self.position_embeddings = nn.Embedding(max_len, d_model)
self.dropout = nn.Dropout(dropout)
# Initialize with small values
nn.init.normal_(self.position_embeddings.weight, std=0.02)
def forward(self, x):
"""
Args:
x: (batch, seq_len, d_model)
"""
seq_len = x.size(1)
positions = torch.arange(seq_len, device=x.device)
position_embeddings = self.position_embeddings(positions)
x = x + position_embeddings
return self.dropout(x)
# Compare parameter counts
sinusoidal = SinusoidalPositionalEncoding(512, 1000)
learned = LearnedPositionalEncoding(512, 1000)
print(f"Sinusoidal PE params: {sum(p.numel() for p in sinusoidal.parameters()):,}")
print(f"Learned PE params: {sum(p.numel() for p in learned.parameters()):,}")
Attention with Masking
Causal (Autoregressive) Mask
For language models, we need to prevent attending to future positions:
def create_causal_mask(seq_len):
"""
Create a causal mask to prevent attending to future positions.
For position i, can only attend to positions 0, 1, ..., i
"""
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
return mask
def visualize_causal_mask():
seq_len = 6
mask = create_causal_mask(seq_len)
# Convert to a 0/1 matrix for visualization (0 = can attend, 1 = masked)
viz_mask = torch.zeros_like(mask)
viz_mask[mask == float('-inf')] = 1
plt.figure(figsize=(6, 6))
plt.imshow(viz_mask.numpy(), cmap='Greys')
plt.xlabel('Key Position (can attend to)')
plt.ylabel('Query Position')
plt.title('Causal Mask\n(black = masked, white = can attend)')
for i in range(seq_len):
for j in range(seq_len):
color = 'white' if viz_mask[i, j] == 1 else 'black'
plt.text(j, i, '✓' if viz_mask[i,j] == 0 else '✗',
ha='center', va='center', color=color, fontsize=16)
plt.show()
visualize_causal_mask()
class CausalSelfAttention(nn.Module):
"""Self-attention with causal masking for autoregressive models."""
def __init__(self, d_model, num_heads, max_len=512, dropout=0.1):
super().__init__()
self.mha = MultiHeadAttention(d_model, num_heads, dropout)
# Register causal mask as buffer
mask = torch.triu(torch.ones(max_len, max_len), diagonal=1).bool()
self.register_buffer('mask', mask)
def forward(self, x):
seq_len = x.size(1)
# Get appropriate size mask and expand for batch and heads
mask = self.mask[:seq_len, :seq_len]
mask = mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
mask = ~mask # Invert: True = can attend, False = masked
return self.mha(x, x, x, mask=mask)
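A quick shape check for CausalSelfAttention, in the same spirit as the earlier tests:
causal_attn = CausalSelfAttention(d_model=64, num_heads=4, max_len=128)
causal_attn.eval()  # disable dropout for a deterministic check
x = torch.randn(2, 10, 64)
out, attn_weights = causal_attn(x)
print(f"Output: {out.shape}")                      # (2, 10, 64)
print(f"Attention weights: {attn_weights.shape}")  # (2, 4, 10, 10)
print("Future positions masked:", torch.all(attn_weights.triu(diagonal=1) == 0).item())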
Padding Mask
For variable-length sequences with padding:
def create_padding_mask(seq, pad_idx=0):
"""
Create mask to prevent attending to padding tokens.
Args:
seq: Token indices (batch, seq_len)
pad_idx: Index of padding token
Returns:
mask: (batch, 1, 1, seq_len) - True where valid, False where padding
"""
mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2)
return mask
# Example
batch = torch.tensor([
[5, 3, 7, 2, 0, 0], # Padded with 2 zeros
[8, 1, 4, 6, 9, 0], # Padded with 1 zero
])
padding_mask = create_padding_mask(batch, pad_idx=0)
print("Sequence:")
print(batch)
print("\nPadding mask:")
print(padding_mask.squeeze())
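A decoder often needs both masks at the same time. One simple approach (an illustrative sketch, not the only convention) is to combine them with a logical AND, so a position can be attended to only if it is non-padding and not in the future:
# Combine padding and causal masks; broadcasting gives (batch, 1, seq_len, seq_len)
seq_len = batch.size(1)
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # True = can attend
combined_mask = padding_mask & causal        # (2, 1, 1, 6) & (6, 6) → (2, 1, 6, 6)
print("Combined mask shape:", combined_mask.shape)
print(combined_mask[0, 0].int())             # row i: positions that query i may attend to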
Complete Attention-Based Seq2Seq
class AttentionSeq2Seq(nn.Module):
"""
Sequence-to-sequence model with attention.
Unlike vanilla seq2seq, the decoder can look at all encoder states
at each decoding step, selecting relevant information.
"""
def __init__(self, src_vocab_size, tgt_vocab_size, d_model,
num_heads, num_layers, dropout=0.1):
super().__init__()
self.d_model = d_model
# Embeddings
self.src_embedding = nn.Embedding(src_vocab_size, d_model)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
# Positional encoding
self.pos_encoding = SinusoidalPositionalEncoding(d_model, dropout=dropout)
# Encoder (bidirectional LSTM + self-attention)
self.encoder_lstm = nn.LSTM(
d_model, d_model // 2, num_layers,
batch_first=True, bidirectional=True, dropout=dropout
)
self.encoder_attention = MultiHeadAttention(d_model, num_heads, dropout)
self.encoder_norm = nn.LayerNorm(d_model)
# Decoder
self.decoder_lstm = nn.LSTM(
d_model, d_model, num_layers,
batch_first=True, dropout=dropout
)
# Cross-attention (decoder attends to encoder)
self.cross_attention = MultiHeadAttention(d_model, num_heads, dropout)
self.cross_norm = nn.LayerNorm(d_model)
# Output projection
self.output_projection = nn.Linear(d_model, tgt_vocab_size)
self.dropout = nn.Dropout(dropout)
def encode(self, src, src_mask=None):
"""Encode source sequence."""
# Embed and add positional encoding
x = self.src_embedding(src) * np.sqrt(self.d_model)
x = self.pos_encoding(x)
# LSTM encoding
x, _ = self.encoder_lstm(x)
# Self-attention
attn_out, _ = self.encoder_attention(x, x, x, mask=src_mask)
x = self.encoder_norm(x + self.dropout(attn_out))
return x
def decode(self, tgt, encoder_output, src_mask=None):
"""Decode target sequence."""
# Embed and add positional encoding
x = self.tgt_embedding(tgt) * np.sqrt(self.d_model)
x = self.pos_encoding(x)
# LSTM decoding
x, _ = self.decoder_lstm(x)
# Cross-attention to encoder output
attn_out, attention_weights = self.cross_attention(
x, encoder_output, encoder_output, mask=src_mask
)
x = self.cross_norm(x + self.dropout(attn_out))
# Project to vocabulary
logits = self.output_projection(x)
return logits, attention_weights
def forward(self, src, tgt, src_mask=None):
encoder_output = self.encode(src, src_mask)
logits, attention = self.decode(tgt, encoder_output, src_mask)
return logits, attention
# Test
model = AttentionSeq2Seq(
src_vocab_size=10000,
tgt_vocab_size=8000,
d_model=256,
num_heads=8,
num_layers=2
)
src = torch.randint(0, 10000, (4, 30))
tgt = torch.randint(0, 8000, (4, 25))
logits, attention = model(src, tgt)
print(f"Source: {src.shape}")
print(f"Target: {tgt.shape}")
print(f"Output logits: {logits.shape}")
print(f"Attention weights: {attention.shape}")
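At inference time the target sequence is not known in advance. A minimal greedy-decoding sketch, assuming token id 1 is the start-of-sequence symbol and ignoring end-of-sequence handling, reuses encode() and decode() like this:
# Greedy decoding sketch: feed the growing prefix back in and keep the argmax token
model.eval()  # disable dropout
with torch.no_grad():
    encoder_output = model.encode(src)
    generated = torch.ones(src.size(0), 1, dtype=torch.long)  # assumed <bos> id = 1
    for _ in range(10):                                        # fixed number of steps
        logits, _ = model.decode(generated, encoder_output)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
print(f"Generated tokens: {generated.shape}")  # (4, 11): <bos> plus 10 generated tokens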
Exercises
Exercise 1: Implement Attention from Scratch
Implement scaled dot-product attention without using any PyTorch attention functions:
- Implement the forward pass with proper scaling
- Add masking support (padding and causal)
- Verify gradients flow correctly
- Test on a simple sequence copying task
Exercise 2: Visualize Real Attention Patterns
Using a pre-trained model (or train your own):
- Extract attention weights for various inputs
- Create visualizations showing what each head attends to
- Identify heads with interpretable patterns
- Compare attention patterns for different input types
Exercise 3: Compare Attention Variants
Implement and compare:
- Dot-product attention
- Additive (Bahdanau) attention
- Multiplicative (Luong) attention
Evaluate each variant on:
- Training speed
- Final BLEU score
- Attention patterns
Exercise 4: Relative Position Encoding
Implement relative positional encoding:
- Instead of absolute positions, encode relative distances
- Modify attention scores to include position bias
- Compare with sinusoidal on long sequences
- Test extrapolation to longer sequences
Exercise 5: Efficient Attention
Implement a more efficient attention mechanism:
- Implement local attention (attend only to nearby positions)
- Implement sparse attention patterns
- Compare memory usage and speed with full attention
- Evaluate impact on model quality
Key Takeaways
| Concept | Key Insight |
|---|---|
| Attention | Weighted combination of values based on query-key similarity |
| Dot-Product | softmax(QK^T / sqrt(d_k)) V - simple and efficient |
| Multi-Head | Multiple attention functions for different relationship types |
| Self-Attention | Sequence attends to itself - O(1) dependency paths |
| Positional Encoding | Add position information since attention is permutation-invariant |
| Masking | Causal for autoregressive, padding for variable lengths |
The attention mechanism is the foundation of modern NLP. Understanding it deeply will help you grasp Transformers, BERT, GPT, and virtually all state-of-the-art language models.