Transformers: Attention Is All You Need
The Architecture That Changed Everything
In 2017, the paper “Attention Is All You Need” introduced the Transformer, a model that:
- Removed RNNs entirely, using only attention mechanisms
- Enabled massive parallelization, making training much faster
- Captured long-range dependencies directly, without information bottlenecks
The Core Insight: Why use recurrence at all? Self-attention can capture dependencies between any positions in a sequence, regardless of distance. Combined with position encoding, we get all the benefits of sequence modeling without the sequential bottleneck.
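As a tiny, self-contained sketch of that insight (illustrative values only), a single matrix multiplication compares every position with every other position, so even the first and last tokens are connected in one step:
import torch
import torch.nn.functional as F
# Toy sequence: batch of 1, 6 tokens, 8-dimensional embeddings
x = torch.randn(1, 6, 8)
# Scaled dot-product self-attention with queries = keys = values = x
scores = x @ x.transpose(-2, -1) / (8 ** 0.5)   # (1, 6, 6): every pair of positions compared directly
weights = F.softmax(scores, dim=-1)             # attention distribution for each position
output = weights @ x                            # (1, 6, 8): each token mixes information from all positions
print(weights[0, 0])  # token 0 attends to all six positions, including the last, in a single step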
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
def transformer_overview():
"""
The Transformer architecture at a glance.
Two main components:
1. ENCODER: Processes the input sequence
- Self-attention + Feed-forward
- Stacked N times
2. DECODER: Generates the output sequence
- Masked self-attention (causal)
- Cross-attention to encoder
- Feed-forward
- Stacked N times
"""
print("Transformer Architecture Overview")
print("=" * 60)
print()
print("ENCODER (processes input):")
print(" Input → Embedding + Positional Encoding")
print(" → [Self-Attention → Add & Norm → FFN → Add & Norm] × N")
print(" → Encoder Output")
print()
print("DECODER (generates output):")
print(" Output (shifted) → Embedding + Positional Encoding")
print(" → [Masked Self-Attn → Add & Norm")
print(" → Cross-Attn (to Encoder) → Add & Norm")
print(" → FFN → Add & Norm] × N")
print(" → Linear → Softmax → Predictions")
print()
print("Key innovations:")
print(" • Multi-head self-attention (parallelizable)")
print(" • Layer normalization for stability")
print(" • Residual connections for gradient flow")
print(" • Positional encoding for sequence order")
transformer_overview()
Building Blocks
Multi-Head Attention (Revisited)
class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention mechanism.
Allows the model to jointly attend to information from
different representation subspaces at different positions.
"""
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.scale = math.sqrt(self.d_k)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# Linear projections and reshape for multi-head
Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# Attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention = self.dropout(F.softmax(scores, dim=-1))
# Apply attention to values
context = torch.matmul(attention, V)
# Reshape and project
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
output = self.W_o(context)
return output, attention
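A quick shape check for the block above (dummy tensors, illustrative only):
# Self-attention: query = key = value
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)                      # (batch, seq_len, d_model)
out, attn = mha(x, x, x)
print(f"Output: {out.shape}")                    # (2, 10, 512)
print(f"Attention: {attn.shape}")                # (2, 8, 10, 10): one attention map per head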
Position-wise Feed-Forward Network
A simple two-layer MLP applied to each position independently:

FFN(x) = max(0, xW1 + b1)W2 + b2
class PositionwiseFeedForward(nn.Module):
"""
Position-wise Feed-Forward Network.
Applied independently to each position (token).
Acts as a nonlinear transformation of the attention output.
Typically expands dimension by 4x, then projects back.
"""
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
"""
Args:
x: (batch, seq_len, d_model)
Returns:
output: (batch, seq_len, d_model)
"""
# Expand, apply non-linearity, project back
x = self.linear1(x)
x = F.relu(x) # Original paper uses ReLU
x = self.dropout(x)
x = self.linear2(x)
return x
# Modern variant: GELU activation (used in BERT, GPT)
class PositionwiseFeedForwardGELU(nn.Module):
"""FFN with GELU activation (used in modern transformers)."""
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = F.gelu(self.linear1(x))
x = self.dropout(x)
x = self.linear2(x)
return x
Layer Normalization
class LayerNorm(nn.Module):
"""
Layer Normalization.
Normalizes across the feature dimension (not batch).
More stable than batch norm for variable-length sequences.
LayerNorm(x) = γ * (x - μ) / (σ + ε) + β
"""
def __init__(self, d_model, eps=1e-6):
super().__init__()
self.gamma = nn.Parameter(torch.ones(d_model))
self.beta = nn.Parameter(torch.zeros(d_model))
self.eps = eps
def forward(self, x):
mean = x.mean(dim=-1, keepdim=True)
std = x.std(dim=-1, keepdim=True)
return self.gamma * (x - mean) / (std + self.eps) + self.beta
# Note: PyTorch's built-in nn.LayerNorm is nearly equivalent (it uses the biased variance and adds eps inside the square root)
layer_norm = nn.LayerNorm(512)
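A small sanity check (illustrative): after layer normalization, each token's features have roughly zero mean and unit spread across the feature dimension.
ln = LayerNorm(512)
x = torch.randn(2, 10, 512) * 3 + 5              # arbitrary scale and shift
y = ln(x)
print(y.mean(dim=-1)[0, :3])                     # ≈ 0 for every token
print(y.std(dim=-1)[0, :3])                      # ≈ 1 for every token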
Positional Encoding
class PositionalEncoding(nn.Module):
"""
Sinusoidal Positional Encoding.
Adds position information using sin/cos functions at different frequencies.
Allows the model to learn relative positions.
"""
def __init__(self, d_model, max_len=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)
# Create positional encoding matrix
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
"""Add positional encoding to input embeddings."""
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
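To see what these encodings look like, a short plotting sketch (hypothetical, reusing the matplotlib import above) draws a few channels; low dimensions oscillate quickly, high dimensions slowly:
pos_enc = PositionalEncoding(d_model=64, dropout=0.0)
pe = pos_enc.pe[0, :100]                         # (100, 64): encodings for the first 100 positions
plt.figure(figsize=(10, 4))
for dim in [0, 8, 16, 32]:
    plt.plot(pe[:, dim], label=f'dim {dim}')
plt.xlabel('Position')
plt.ylabel('Encoding value')
plt.title('Sinusoidal positional encodings at different frequencies')
plt.legend()
plt.show()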
The Encoder
class EncoderLayer(nn.Module):
"""
Single Transformer Encoder Layer.
Structure:
x → Self-Attention → Add & Norm → FFN → Add & Norm → output
└───────────────────┘ └─────────────────┘
(residual) (residual)
"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
"""
Args:
x: Input (batch, seq_len, d_model)
mask: Attention mask (batch, 1, 1, seq_len) or (batch, 1, seq_len, seq_len)
"""
# Self-attention with residual connection
attn_output, _ = self.self_attention(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
# Feed-forward with residual connection
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_output))
return x
class TransformerEncoder(nn.Module):
"""
Stack of N Encoder Layers.
"""
def __init__(self, vocab_size, d_model, num_heads, d_ff,
num_layers, max_len=5000, dropout=0.1):
super().__init__()
self.d_model = d_model
self.embedding = nn.Embedding(vocab_size, d_model)
self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(d_model)
def forward(self, src, src_mask=None):
"""
Args:
src: Source token indices (batch, src_len)
src_mask: Mask for padding (batch, 1, 1, src_len)
Returns:
encoder_output: (batch, src_len, d_model)
"""
# Embed and add positional encoding
x = self.embedding(src) * math.sqrt(self.d_model)
x = self.positional_encoding(x)
# Pass through encoder layers
for layer in self.layers:
x = layer(x, src_mask)
return self.norm(x)
# Test encoder
encoder = TransformerEncoder(
vocab_size=10000,
d_model=512,
num_heads=8,
d_ff=2048,
num_layers=6
)
src = torch.randint(0, 10000, (2, 30))
output = encoder(src)
print(f"Input: {src.shape}")
print(f"Encoder output: {output.shape}")
print(f"Encoder parameters: {sum(p.numel() for p in encoder.parameters()):,}")
The Decoder
class DecoderLayer(nn.Module):
"""
Single Transformer Decoder Layer.
Structure:
x → Masked Self-Attention → Add & Norm
→ Cross-Attention (to encoder) → Add & Norm
→ FFN → Add & Norm → output
"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
# Masked self-attention (causal)
self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
# Cross-attention to encoder output
self.cross_attention = MultiHeadAttention(d_model, num_heads, dropout)
# Feed-forward network
self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
# Layer normalizations
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
"""
Args:
x: Decoder input (batch, tgt_len, d_model)
encoder_output: Encoder output (batch, src_len, d_model)
src_mask: Source padding mask
tgt_mask: Target causal mask
"""
# Masked self-attention
self_attn_output, _ = self.self_attention(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout(self_attn_output))
# Cross-attention to encoder
cross_attn_output, attention_weights = self.cross_attention(
x, encoder_output, encoder_output, src_mask
)
x = self.norm2(x + self.dropout(cross_attn_output))
# Feed-forward
ff_output = self.feed_forward(x)
x = self.norm3(x + self.dropout(ff_output))
return x, attention_weights
class TransformerDecoder(nn.Module):
"""
Stack of N Decoder Layers.
"""
def __init__(self, vocab_size, d_model, num_heads, d_ff,
num_layers, max_len=5000, dropout=0.1):
super().__init__()
self.d_model = d_model
self.embedding = nn.Embedding(vocab_size, d_model)
self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)
self.layers = nn.ModuleList([
DecoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(d_model)
self.output_projection = nn.Linear(d_model, vocab_size)
def forward(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
"""
Args:
tgt: Target token indices (batch, tgt_len)
encoder_output: (batch, src_len, d_model)
src_mask: Source padding mask
tgt_mask: Target causal mask
"""
# Embed and add positional encoding
x = self.embedding(tgt) * math.sqrt(self.d_model)
x = self.positional_encoding(x)
# Pass through decoder layers
attention_weights = None
for layer in self.layers:
x, attention_weights = layer(x, encoder_output, src_mask, tgt_mask)
x = self.norm(x)
# Project to vocabulary
logits = self.output_projection(x)
return logits, attention_weights
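A quick shape check mirroring the encoder test above (it reuses the encoder output computed there; no masks are applied, so this is only a shape check):
decoder = TransformerDecoder(
    vocab_size=8000,
    d_model=512,
    num_heads=8,
    d_ff=2048,
    num_layers=6
)
tgt = torch.randint(0, 8000, (2, 25))
logits, cross_attn = decoder(tgt, output)        # `output` is the encoder output from the test above
print(f"Decoder logits: {logits.shape}")         # (2, 25, 8000)
print(f"Cross-attention: {cross_attn.shape}")    # (2, 8, 25, 30)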
The Complete Transformer
class Transformer(nn.Module):
"""
Complete Transformer model for sequence-to-sequence tasks.
This is the architecture from "Attention Is All You Need" (2017).
"""
def __init__(
self,
src_vocab_size,
tgt_vocab_size,
d_model=512,
num_heads=8,
d_ff=2048,
num_encoder_layers=6,
num_decoder_layers=6,
max_len=5000,
dropout=0.1,
share_embeddings=False
):
super().__init__()
self.d_model = d_model
# Encoder
self.encoder = TransformerEncoder(
src_vocab_size, d_model, num_heads, d_ff,
num_encoder_layers, max_len, dropout
)
# Decoder
self.decoder = TransformerDecoder(
tgt_vocab_size, d_model, num_heads, d_ff,
num_decoder_layers, max_len, dropout
)
# Optionally share embeddings between encoder and decoder
if share_embeddings and src_vocab_size == tgt_vocab_size:
self.decoder.embedding = self.encoder.embedding
# Initialize parameters
self._init_parameters()
def _init_parameters(self):
"""Initialize parameters with Xavier uniform."""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def make_src_mask(self, src, pad_idx=0):
"""Create source padding mask."""
src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)
return src_mask
def make_tgt_mask(self, tgt, pad_idx=0):
"""Create target mask (padding + causal)."""
batch_size, tgt_len = tgt.shape
# Padding mask
pad_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(2)
# Causal mask
causal_mask = torch.tril(torch.ones(tgt_len, tgt_len, device=tgt.device)).bool()
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
# Combine masks
tgt_mask = pad_mask & causal_mask
return tgt_mask
def forward(self, src, tgt, src_mask=None, tgt_mask=None):
"""
Forward pass through the Transformer.
Args:
src: Source sequence (batch, src_len)
tgt: Target sequence (batch, tgt_len)
src_mask: Source mask (optional, will be created if not provided)
tgt_mask: Target mask (optional, will be created if not provided)
Returns:
logits: Output logits (batch, tgt_len, tgt_vocab_size)
attention_weights: Cross-attention weights from last decoder layer
"""
if src_mask is None:
src_mask = self.make_src_mask(src)
if tgt_mask is None:
tgt_mask = self.make_tgt_mask(tgt)
# Encode source
encoder_output = self.encoder(src, src_mask)
# Decode target
logits, attention_weights = self.decoder(
tgt, encoder_output, src_mask, tgt_mask
)
return logits, attention_weights
def generate(self, src, max_len=50, start_token=1, end_token=2):
"""
Generate output sequence autoregressively.
Args:
src: Source sequence (batch, src_len)
max_len: Maximum generation length
start_token: Start of sequence token index
end_token: End of sequence token index
Returns:
generated: Generated token indices (batch, gen_len)
"""
self.eval()
batch_size = src.size(0)
device = src.device
# Encode source once
src_mask = self.make_src_mask(src)
encoder_output = self.encoder(src, src_mask)
# Start with <SOS> token
generated = torch.full((batch_size, 1), start_token, dtype=torch.long, device=device)
with torch.no_grad():
for _ in range(max_len):
tgt_mask = self.make_tgt_mask(generated)
logits, _ = self.decoder(generated, encoder_output, src_mask, tgt_mask)
# Get last token prediction
next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
generated = torch.cat([generated, next_token], dim=1)
# Check if all sequences have generated <EOS>
if (next_token == end_token).all():
break
return generated
# Create a Transformer
transformer = Transformer(
src_vocab_size=10000,
tgt_vocab_size=8000,
d_model=512,
num_heads=8,
d_ff=2048,
num_encoder_layers=6,
num_decoder_layers=6,
dropout=0.1
)
# Test forward pass
src = torch.randint(1, 10000, (4, 30)) # Batch of 4, source length 30
tgt = torch.randint(1, 8000, (4, 25)) # Target length 25
logits, attention = transformer(src, tgt)
print(f"Source: {src.shape}")
print(f"Target: {tgt.shape}")
print(f"Output logits: {logits.shape}")
print(f"Cross-attention: {attention.shape}")
print(f"\nTotal parameters: {sum(p.numel() for p in transformer.parameters()):,}")
Training the Transformer
Label Smoothing
class LabelSmoothingLoss(nn.Module):
"""
Label Smoothing Cross-Entropy Loss.
Instead of hard targets (0 or 1), use soft targets:
- True class: 1 - smoothing
- Other classes: smoothing / (num_classes - 1)
This prevents overconfidence and improves generalization.
"""
def __init__(self, vocab_size, smoothing=0.1, ignore_index=-100):
super().__init__()
self.vocab_size = vocab_size
self.smoothing = smoothing
self.ignore_index = ignore_index
self.confidence = 1.0 - smoothing
def forward(self, logits, target):
"""
Args:
logits: (batch, seq_len, vocab_size)
target: (batch, seq_len)
"""
logits = logits.reshape(-1, self.vocab_size)
target = target.reshape(-1)
# Create smoothed distribution
smooth_target = torch.zeros_like(logits)
smooth_target.fill_(self.smoothing / (self.vocab_size - 1))
# Set confidence on true class
mask = target != self.ignore_index
indices = target[mask]
smooth_target[mask] = smooth_target[mask].scatter(
1, indices.unsqueeze(1), self.confidence
)
# KL divergence (equivalent to cross-entropy with soft targets)
log_probs = F.log_softmax(logits, dim=-1)
loss = -(smooth_target * log_probs).sum(dim=-1)
return loss[mask].mean()
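A short usage sketch (random values, illustrative only), comparing the smoothed loss against standard cross-entropy on the same logits:
criterion = LabelSmoothingLoss(vocab_size=8000, smoothing=0.1)
logits = torch.randn(4, 25, 8000)
target = torch.randint(0, 8000, (4, 25))
smoothed = criterion(logits, target)
standard = F.cross_entropy(logits.reshape(-1, 8000), target.reshape(-1))
print(f"Label-smoothed loss: {smoothed.item():.4f}")
print(f"Standard CE loss: {standard.item():.4f}")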
Learning Rate Schedule
The original Transformer uses a special learning rate schedule:

lr = d_model^(-0.5) · min(step^(-0.5), step · warmup_steps^(-1.5))
class TransformerLRScheduler:
"""
Learning rate scheduler from "Attention Is All You Need".
Increases linearly during warmup, then decreases proportional to
the inverse square root of the step number.
"""
def __init__(self, optimizer, d_model, warmup_steps=4000):
self.optimizer = optimizer
self.d_model = d_model
self.warmup_steps = warmup_steps
self.step_num = 0
def step(self):
self.step_num += 1
lr = self._get_lr()
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
return lr
def _get_lr(self):
step = max(1, self.step_num)
return (self.d_model ** -0.5) * min(
step ** -0.5,
step * (self.warmup_steps ** -1.5)
)
def visualize_lr_schedule():
"""Visualize the Transformer learning rate schedule."""
# Dummy optimizer
model = nn.Linear(512, 512)
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
scheduler = TransformerLRScheduler(optimizer, d_model=512, warmup_steps=4000)
lrs = []
for _ in range(100000):
lrs.append(scheduler.step())
plt.figure(figsize=(10, 5))
plt.plot(lrs)
plt.xlabel('Step')
plt.ylabel('Learning Rate')
plt.title('Transformer Learning Rate Schedule')
plt.axvline(x=4000, color='r', linestyle='--', label='Warmup ends')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
# visualize_lr_schedule()
Training Loop
def train_transformer():
"""Complete training example for Transformer."""
# Hyperparameters
src_vocab_size = 10000
tgt_vocab_size = 8000
d_model = 256
num_heads = 8
d_ff = 1024
num_layers = 4
dropout = 0.1
warmup_steps = 4000
epochs = 10
batch_size = 32
# Model
model = Transformer(
src_vocab_size=src_vocab_size,
tgt_vocab_size=tgt_vocab_size,
d_model=d_model,
num_heads=num_heads,
d_ff=d_ff,
num_encoder_layers=num_layers,
num_decoder_layers=num_layers,
dropout=dropout
)
# Move to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Optimizer
optimizer = torch.optim.Adam(
model.parameters(),
lr=1.0, # Will be overridden by scheduler
betas=(0.9, 0.98),
eps=1e-9
)
# Scheduler
scheduler = TransformerLRScheduler(optimizer, d_model, warmup_steps)
# Loss function
criterion = LabelSmoothingLoss(tgt_vocab_size, smoothing=0.1)
# Training loop
model.train()
for epoch in range(epochs):
total_loss = 0
num_batches = 100 # Simulated
for batch_idx in range(num_batches):
# Generate dummy data (replace with real data loader)
src = torch.randint(1, src_vocab_size, (batch_size, 30)).to(device)
tgt = torch.randint(1, tgt_vocab_size, (batch_size, 25)).to(device)
# Shift target for teacher forcing
tgt_input = tgt[:, :-1]
tgt_output = tgt[:, 1:]
# Forward pass
logits, _ = model(src, tgt_input)
# Compute loss
loss = criterion(logits, tgt_output)
# Backward pass
optimizer.zero_grad()
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
lr = scheduler.step()
total_loss += loss.item()
if batch_idx % 20 == 0:
print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}, LR: {lr:.6f}")
avg_loss = total_loss / num_batches
print(f"Epoch {epoch+1} complete. Average Loss: {avg_loss:.4f}")
return model
# model = train_transformer()
Encoder-Only: BERT
BERT uses only the Transformer encoder for bidirectional language understanding:
class BERTEncoder(nn.Module):
"""
BERT-style encoder.
Key differences from original Transformer encoder:
- Learned positional embeddings (instead of sinusoidal)
- [CLS] token for classification
- Segment embeddings for sentence pairs
- GELU activation in FFN
"""
def __init__(
self,
vocab_size,
d_model=768,
num_heads=12,
d_ff=3072,
num_layers=12,
max_len=512,
dropout=0.1
):
super().__init__()
self.d_model = d_model
# Embeddings
self.word_embeddings = nn.Embedding(vocab_size, d_model)
self.position_embeddings = nn.Embedding(max_len, d_model)
self.segment_embeddings = nn.Embedding(2, d_model) # For sentence pairs
self.embedding_norm = nn.LayerNorm(d_model)
self.embedding_dropout = nn.Dropout(dropout)
# Transformer layers
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
# Pooler for [CLS] token
self.pooler = nn.Sequential(
nn.Linear(d_model, d_model),
nn.Tanh()
)
def forward(self, input_ids, segment_ids=None, attention_mask=None):
"""
Args:
input_ids: Token indices (batch, seq_len)
segment_ids: Segment indices for sentence pairs (batch, seq_len)
attention_mask: Mask for padding (batch, seq_len)
Returns:
sequence_output: All token representations (batch, seq_len, d_model)
pooled_output: [CLS] representation (batch, d_model)
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Create position indices
position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
# Default segment ids
if segment_ids is None:
segment_ids = torch.zeros_like(input_ids)
# Compute embeddings
word_embeds = self.word_embeddings(input_ids)
position_embeds = self.position_embeddings(position_ids)
segment_embeds = self.segment_embeddings(segment_ids)
embeddings = word_embeds + position_embeds + segment_embeds
embeddings = self.embedding_norm(embeddings)
embeddings = self.embedding_dropout(embeddings)
# Create attention mask
if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Pass through layers
hidden_states = embeddings
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask)
# Pool [CLS] token
cls_token = hidden_states[:, 0]
pooled_output = self.pooler(cls_token)
return hidden_states, pooled_output
# BERT configurations
bert_configs = {
'bert-base': {'d_model': 768, 'num_heads': 12, 'd_ff': 3072, 'num_layers': 12},
'bert-large': {'d_model': 1024, 'num_heads': 16, 'd_ff': 4096, 'num_layers': 24},
}
# Create BERT-base
bert = BERTEncoder(vocab_size=30522, **bert_configs['bert-base'])
print(f"BERT-base parameters: {sum(p.numel() for p in bert.parameters()):,}")
Decoder-Only: GPT
GPT uses only the Transformer decoder for autoregressive language modeling:
class GPTDecoder(nn.Module):
"""
GPT-style decoder.
Key differences:
- Decoder-only (no cross-attention to encoder)
- Causal masking for autoregressive generation
- Learned positional embeddings
"""
def __init__(
self,
vocab_size,
d_model=768,
num_heads=12,
d_ff=3072,
num_layers=12,
max_len=1024,
dropout=0.1
):
super().__init__()
self.d_model = d_model
# Embeddings
self.token_embeddings = nn.Embedding(vocab_size, d_model)
self.position_embeddings = nn.Embedding(max_len, d_model)
self.dropout = nn.Dropout(dropout)
# Transformer layers (decoder without cross-attention)
self.layers = nn.ModuleList([
GPTDecoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(d_model)
# Output projection
self.output_projection = nn.Linear(d_model, vocab_size, bias=False)
# Tie embedding weights
self.output_projection.weight = self.token_embeddings.weight
# Cache for causal mask
self.register_buffer(
'causal_mask',
torch.tril(torch.ones(max_len, max_len)).bool()
)
def forward(self, input_ids, past_key_values=None):
"""
Args:
input_ids: Token indices (batch, seq_len)
past_key_values: Cached key/values for efficient generation
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Get position indices (account for past tokens)
past_len = 0 if past_key_values is None else past_key_values[0][0].size(2)
position_ids = torch.arange(past_len, past_len + seq_len, device=device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Embeddings
hidden_states = self.token_embeddings(input_ids) + self.position_embeddings(position_ids)
hidden_states = self.dropout(hidden_states)
# Causal mask
causal_mask = self.causal_mask[past_len:past_len+seq_len, :past_len+seq_len]
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
# Process through layers
for i, layer in enumerate(self.layers):
past = past_key_values[i] if past_key_values is not None else None
hidden_states = layer(hidden_states, causal_mask, past)
hidden_states = self.norm(hidden_states)
# Project to vocabulary
logits = self.output_projection(hidden_states)
return logits
class GPTDecoderLayer(nn.Module):
"""GPT decoder layer (no cross-attention)."""
def __init__(self, d_model, num_heads, d_ff, dropout):
super().__init__()
self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
self.feed_forward = PositionwiseFeedForwardGELU(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None, past_key_value=None):
# Self-attention
attn_output, _ = self.self_attention(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
# Feed-forward
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_output))
return x
# GPT configurations
gpt_configs = {
'gpt2': {'d_model': 768, 'num_heads': 12, 'd_ff': 3072, 'num_layers': 12},
'gpt2-medium': {'d_model': 1024, 'num_heads': 16, 'd_ff': 4096, 'num_layers': 24},
'gpt2-large': {'d_model': 1280, 'num_heads': 20, 'd_ff': 5120, 'num_layers': 36},
'gpt2-xl': {'d_model': 1600, 'num_heads': 25, 'd_ff': 6400, 'num_layers': 48},
}
gpt = GPTDecoder(vocab_size=50257, **gpt_configs['gpt2'])
print(f"GPT-2 parameters: {sum(p.numel() for p in gpt.parameters()):,}")
BERT vs GPT Comparison
def compare_bert_gpt():
"""Compare BERT and GPT architectures."""
comparison = """
┌─────────────────────────────────────────────────────────────────┐
│ BERT vs GPT Comparison │
├─────────────────────────────────────────────────────────────────┤
│ Aspect │ BERT │ GPT │
├─────────────────────────────────────────────────────────────────┤
│ Architecture │ Encoder-only │ Decoder-only │
│ Attention │ Bidirectional │ Causal (left-to-right)│
│ Pre-training │ MLM + NSP │ Language Modeling │
│ Best for │ Understanding │ Generation │
│ Example tasks │ Classification, │ Text generation, │
│ │ NER, QA │ Summarization │
│ Context │ Sees all tokens │ Only sees past tokens│
└─────────────────────────────────────────────────────────────────┘
BERT Pre-training:
- Masked Language Modeling: Predict [MASK] tokens
- Next Sentence Prediction: Is sentence B after A?
GPT Pre-training:
- Autoregressive LM: Predict next token given previous
- P(token | all previous tokens)
"""
print(comparison)
compare_bert_gpt()
| Model | Parameters | Training Data | Release |
|---|---|---|---|
| BERT-base | 110M | 16GB text | 2018 |
| GPT-2 | 1.5B | 40GB text | 2019 |
| GPT-3 | 175B | 570GB text | 2020 |
| GPT-4 | ~1T (est.) | Unknown | 2023 |
| LLaMA-2 | 7B-70B | 2T tokens | 2023 |
Modern Transformer Improvements
Pre-Norm vs Post-Norm
class PreNormTransformerLayer(nn.Module):
"""
Pre-normalization: Apply LayerNorm before attention/FFN.
Improves training stability for deep models.
Used in GPT-2/3, LLaMA, and many modern transformers.
"""
def __init__(self, d_model, num_heads, d_ff, dropout):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Pre-norm: normalize before attention
normed = self.norm1(x)
attn_out, _ = self.self_attention(normed, normed, normed, mask)
x = x + self.dropout(attn_out)
# Pre-norm: normalize before FFN
normed = self.norm2(x)
ff_out = self.feed_forward(normed)
x = x + self.dropout(ff_out)
return x
RMSNorm
class RMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization.
Simpler than LayerNorm: only normalizes by RMS, no mean subtraction.
Used in LLaMA and other efficient transformers.
"""
def __init__(self, d_model, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(d_model))
self.eps = eps
def forward(self, x):
rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
return x / rms * self.weight
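A quick check (illustrative): after RMSNorm, each token's feature vector has a root-mean-square close to 1, but its mean is not forced to zero (no mean subtraction):
rms_norm = RMSNorm(512)
x = torch.randn(2, 10, 512) * 4 + 2
y = rms_norm(x)
print(y.pow(2).mean(dim=-1).sqrt()[0, :3])       # ≈ 1 for every token
print(y.mean(dim=-1)[0, :3])                     # generally not 0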
Rotary Position Embedding (RoPE)
class RotaryPositionEmbedding(nn.Module):
"""
Rotary Position Embedding (RoPE).
Encodes position by rotating query/key vectors.
Enables relative position understanding and better extrapolation.
Used in LLaMA, GPT-NeoX, and many modern models.
"""
def __init__(self, d_model, max_len=2048):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
self.register_buffer('inv_freq', inv_freq)
# Precompute rotation matrices
t = torch.arange(max_len)
freqs = torch.outer(t, inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
self.register_buffer('cos', emb.cos())
self.register_buffer('sin', emb.sin())
def forward(self, x, position_ids):
"""Apply rotary embeddings to queries and keys."""
# x: (batch, num_heads, seq_len, d_k)
cos = self.cos[position_ids].unsqueeze(1)
sin = self.sin[position_ids].unsqueeze(1)
# Rotate
x1, x2 = x[..., :x.size(-1)//2], x[..., x.size(-1)//2:]
rotated = torch.cat([-x2, x1], dim=-1)
return x * cos + rotated * sin
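A usage sketch (shapes as assumed in the docstring; this only checks that the rotation preserves shape): apply RoPE to per-head queries and keys before computing attention scores.
rope = RotaryPositionEmbedding(d_model=64, max_len=2048)     # d_model here is the per-head dimension d_k
q = torch.randn(2, 8, 10, 64)                                # (batch, num_heads, seq_len, d_k)
k = torch.randn(2, 8, 10, 64)
position_ids = torch.arange(10).unsqueeze(0).expand(2, -1)   # (batch, seq_len)
q_rot, k_rot = rope(q, position_ids), rope(k, position_ids)
print(q_rot.shape, k_rot.shape)                              # unchanged: (2, 8, 10, 64)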
SwiGLU Activation
class SwiGLU(nn.Module):
"""
SwiGLU activation function.
Combines Swish activation with gated linear unit.
Used in PaLM, LLaMA, and other modern transformers.
SwiGLU(x) = Swish(xW1) ⊙ (xV)
"""
def __init__(self, d_model, d_ff):
super().__init__()
self.w1 = nn.Linear(d_model, d_ff, bias=False)
self.w2 = nn.Linear(d_ff, d_model, bias=False)
self.w3 = nn.Linear(d_model, d_ff, bias=False) # Gate
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
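SwiGLU drops into the transformer block wherever the position-wise FFN would go; here is a quick shape check (illustrative). Since SwiGLU has three weight matrices instead of two, d_ff is often reduced (roughly two-thirds of the usual 4×d_model) to keep the parameter count comparable:
swiglu_ffn = SwiGLU(d_model=512, d_ff=1376)      # d_ff chosen arbitrarily for illustration
x = torch.randn(2, 10, 512)
print(swiglu_ffn(x).shape)                       # (2, 10, 512)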
Exercises
Exercise 1: Build Transformer from Scratch
Implement the complete Transformer without looking at the code above:
- Multi-head attention
- Position-wise feed-forward
- Encoder and decoder layers
- Full encoder-decoder model
Exercise 2: Implement BERT Pre-training
Implement BERT’s pre-training objectives:
- Masked Language Modeling (MLM)
- Randomly mask 15% of tokens
- 80% [MASK], 10% random, 10% unchanged
- Next Sentence Prediction (NSP)
Exercise 3: Text Generation with GPT
Build a small GPT model for text generation:
- Implement causal masking
- Train on a small text corpus
- Implement nucleus (top-p) sampling
- Generate text and analyze quality
Exercise 4: Efficient Attention
Implement efficient attention variants:
- Linear attention
- Sliding window attention
- Flash attention (conceptually)
Exercise 5: Fine-tune for Classification
Fine-tune a pre-trained transformer for text classification:
- Load a pre-trained model (HuggingFace)
- Add a classification head
- Fine-tune on IMDB or AG News
- Analyze attention patterns for interpretability
Key Takeaways
| Concept | Key Insight |
|---|---|
| Self-Attention | Enables O(1) dependency paths, full parallelization |
| Multi-Head | Multiple attention patterns for different relationships |
| Encoder | Bidirectional understanding of input |
| Decoder | Causal generation, one token at a time |
| Positional Encoding | Injects sequence order information |
| Layer Normalization | Stabilizes training, enables deep models |
| BERT | Encoder-only, bidirectional, for understanding |
| GPT | Decoder-only, causal, for generation |
The Transformer is the foundation of modern AI. Every major language model (GPT-4, Claude, LLaMA, Gemini) is built on this architecture. Understanding it deeply is essential for anyone working in AI.
What’s Next
Congratulations! You’ve completed the core architecture modules of the Deep Learning Mastery course. You now understand:
- Neural network fundamentals (perceptrons, backprop, activations, loss functions)
- Convolutional networks for images
- Recurrent networks for sequences
- Attention and Transformers for everything