LSTMs & GRUs: Gated Recurrent Networks
The Memory Problem Revisited
Vanilla RNNs suffer from a fundamental flaw: they can't remember things for long. The vanishing gradient problem means information from early time steps gets "washed out" as it is repeatedly squashed through tanh activations at every step. Real-world consequence: an RNN reading a book can't remember what happened in Chapter 1 by the time it reaches Chapter 10.
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
def demonstrate_memory_problem():
"""
The Long-Range Dependency Problem:
"The cat, which had been sleeping on the warm sunny
windowsill for the entire lazy afternoon while the
rain pattered gently against the glass, finally ___."
To predict "woke up" or "stretched", we need to remember
"cat" and "sleeping" from 25+ words ago!
"""
print("Long-range dependency example:")
print("-" * 60)
print("Sentence: 'The CAT, which had been SLEEPING on the warm...")
print(" sunny windowsill for the entire lazy afternoon...")
print(" while the rain pattered against the glass, finally ___'")
print()
print("To predict the blank, we need to remember:")
print(" - Subject: 'cat' (25 words back)")
print(" - State: 'sleeping' (23 words back)")
print()
print("Vanilla RNNs fail because gradients vanish over 25 steps!")
demonstrate_memory_problem()
The Solution: Instead of trying to force information through a single path, create multiple pathways for information flow, some of which can pass information unchanged. This is the key insight behind LSTMs and GRUs.
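Before looking at the gates in detail, here is a tiny numeric sketch of that insight (the numbers are purely illustrative, not from any model): a signal that is repeatedly squashed shrinks toward zero, while a signal carried by a gate that stays close to 1 survives many steps almost unchanged.
import torch

steps = 50
squashed = torch.tensor(1.0)   # signal forced through a shrinking factor each step
gated = torch.tensor(1.0)      # signal carried by a gate that stays near 1

for _ in range(steps):
    squashed = 0.8 * squashed      # like repeated saturated-tanh updates
    gated = 0.99 * gated + 0.0     # forget gate ~= 1, nothing new added

print(f"After {steps} steps: squashed path = {squashed.item():.6f}")  # ~1e-5
print(f"After {steps} steps: gated path    = {gated.item():.6f}")     # ~0.6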
Long Short-Term Memory (LSTM)
The Big Idea: A Memory Cell with Gates
An LSTM maintains two types of state:
- Cell State (C_t): the "long-term memory", a highway for information
- Hidden State (h_t): the "working memory", the current output
Three gates control how information moves through the cell:
- Forget Gate: what to erase from the cell state
- Input Gate: what new information to add
- Output Gate: what to output based on the cell state
LSTM Equations
$$
\begin{aligned}
f_t &= \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) && \text{(Forget gate)} \\
i_t &= \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) && \text{(Input gate)} \\
\tilde{C}_t &= \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) && \text{(Candidate values)} \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t && \text{(Cell state update)} \\
o_t &= \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) && \text{(Output gate)} \\
h_t &= o_t \odot \tanh(C_t) && \text{(Hidden state)}
\end{aligned}
$$
Where:
- $\sigma$ is the sigmoid function (outputs 0-1 for gating)
- $\odot$ is element-wise multiplication
- $[h_{t-1}, x_t]$ is the concatenation of the previous hidden state and the current input
class LSTMCell(nn.Module):
"""
LSTM Cell implemented from scratch.
The key insight: Cell state C_t flows through with only
element-wise operations (multiply by forget, add input).
This creates a "gradient highway" - gradients can flow
backward through time without vanishing!
"""
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
# All gates in one matrix for efficiency
# Order: input, forget, cell, output (i, f, g, o)
self.gates = nn.Linear(input_size + hidden_size, 4 * hidden_size)
def forward(self, x, hidden=None):
"""
Args:
x: Input (batch, input_size)
hidden: Tuple of (h, c), each (batch, hidden_size)
Returns:
h_new: New hidden state
c_new: New cell state
"""
batch_size = x.size(0)
# Initialize hidden states if not provided
if hidden is None:
h = torch.zeros(batch_size, self.hidden_size, device=x.device)
c = torch.zeros(batch_size, self.hidden_size, device=x.device)
else:
h, c = hidden
# Concatenate input and hidden state
combined = torch.cat([x, h], dim=1)
# Compute all gates at once
gates = self.gates(combined)
# Split into individual gates
i, f, g, o = gates.chunk(4, dim=1)
# Apply activations
i = torch.sigmoid(i) # Input gate
f = torch.sigmoid(f) # Forget gate
g = torch.tanh(g) # Candidate values
o = torch.sigmoid(o) # Output gate
# Update cell state
c_new = f * c + i * g
# Compute hidden state
h_new = o * torch.tanh(c_new)
return h_new, c_new
# Test our LSTM cell
cell = LSTMCell(input_size=32, hidden_size=64)
x = torch.randn(8, 32) # Batch of 8, input size 32
h, c = cell(x)
print(f"Input: {x.shape}")
print(f"Hidden state: {h.shape}")
print(f"Cell state: {c.shape}")
Complete LSTM Layer
class LSTM(nn.Module):
"""
Complete LSTM layer that processes sequences.
"""
def __init__(self, input_size, hidden_size, num_layers=1,
batch_first=True, dropout=0.0, bidirectional=False):
super().__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.batch_first = batch_first
self.bidirectional = bidirectional
self.num_directions = 2 if bidirectional else 1
# Stack of LSTM cells
self.cells = nn.ModuleList()
for layer in range(num_layers):
for direction in range(self.num_directions):
layer_input_size = input_size if layer == 0 else hidden_size * self.num_directions
self.cells.append(LSTMCell(layer_input_size, hidden_size))
self.dropout = nn.Dropout(dropout) if dropout > 0 else None
def forward(self, x, hidden=None):
"""
Process sequence through LSTM.
Args:
x: (batch, seq_len, input) if batch_first else (seq_len, batch, input)
hidden: Tuple of (h_0, c_0), each (num_layers * directions, batch, hidden)
Returns:
output: Hidden states at each time step
(h_n, c_n): Final hidden and cell states
"""
if self.batch_first:
batch_size, seq_len, _ = x.size()
else:
seq_len, batch_size, _ = x.size()
x = x.transpose(0, 1)
# Initialize hidden states
if hidden is None:
h = [torch.zeros(batch_size, self.hidden_size, device=x.device)
for _ in range(self.num_layers * self.num_directions)]
c = [torch.zeros(batch_size, self.hidden_size, device=x.device)
for _ in range(self.num_layers * self.num_directions)]
else:
h = [hidden[0][i] for i in range(self.num_layers * self.num_directions)]
c = [hidden[1][i] for i in range(self.num_layers * self.num_directions)]
# Process each layer
layer_input = x
for layer in range(self.num_layers):
outputs_forward = []
outputs_backward = []
# Forward direction
cell_idx = layer * self.num_directions
h_forward, c_forward = h[cell_idx], c[cell_idx]
for t in range(seq_len):
h_forward, c_forward = self.cells[cell_idx](
layer_input[:, t, :], (h_forward, c_forward)
)
outputs_forward.append(h_forward)
h[cell_idx], c[cell_idx] = h_forward, c_forward
# Backward direction (if bidirectional)
if self.bidirectional:
cell_idx = layer * self.num_directions + 1
h_backward, c_backward = h[cell_idx], c[cell_idx]
for t in range(seq_len - 1, -1, -1):
h_backward, c_backward = self.cells[cell_idx](
layer_input[:, t, :], (h_backward, c_backward)
)
outputs_backward.insert(0, h_backward)
h[cell_idx], c[cell_idx] = h_backward, c_backward
# Combine outputs
outputs_forward = torch.stack(outputs_forward, dim=1)
if self.bidirectional:
outputs_backward = torch.stack(outputs_backward, dim=1)
layer_output = torch.cat([outputs_forward, outputs_backward], dim=-1)
else:
layer_output = outputs_forward
# Apply dropout between layers (not on last layer)
if self.dropout and layer < self.num_layers - 1:
layer_output = self.dropout(layer_output)
layer_input = layer_output
# Stack hidden states
h_n = torch.stack(h, dim=0)
c_n = torch.stack(c, dim=0)
if not self.batch_first:
layer_output = layer_output.transpose(0, 1)
return layer_output, (h_n, c_n)
# Compare with PyTorch implementation
our_lstm = LSTM(input_size=32, hidden_size=64, num_layers=2, batch_first=True)
pytorch_lstm = nn.LSTM(input_size=32, hidden_size=64, num_layers=2, batch_first=True)
x = torch.randn(8, 20, 32) # batch=8, seq_len=20, input=32
our_output, (our_h, our_c) = our_lstm(x)
pt_output, (pt_h, pt_c) = pytorch_lstm(x)
print(f"Our output shape: {our_output.shape}")
print(f"PyTorch output shape: {pt_output.shape}")
print(f"Our hidden shape: {our_h.shape}")
print(f"PyTorch hidden shape: {pt_h.shape}")
Understanding the Gates
The Forget Gate: Learning What to Ignore
def visualize_forget_gate():
"""
The forget gate decides what information from the cell state
should be thrown away.
f_t = σ(W_f · [h_{t-1}, x_t] + b_f)
- f_t ≈ 1: Keep this information
- f_t ≈ 0: Forget this information
"""
# Example: Language model processing "John lives in Paris. He speaks ___"
# When we see "He", the forget gate might:
# - Keep: Information that subject is male ("John")
# - Forget: Specific details about Paris that aren't relevant
print("Forget Gate Example:")
print("-" * 50)
print("Context: 'John lives in Paris. He speaks ___'")
print()
print("Cell state before 'He':")
print(" [subject=John, location=Paris, job=unknown, ...]")
print()
print("Forget gate values (learned):")
print(" [subject: 0.95, location: 0.3, job: 0.1, ...]")
print()
print("Cell state after forget gate:")
print(" [subject=John(kept), location=faded, job=forgotten]")
visualize_forget_gate()
def forget_gate_experiment():
"""
Demonstrate how the forget gate responds to different inputs.
"""
lstm = nn.LSTM(10, 20, batch_first=True)
# Create two sequences
# Sequence 1: Normal input
x1 = torch.randn(1, 50, 10)
# Sequence 2: Input with a "reset signal" at position 25
x2 = x1.clone()
x2[0, 25, :] = 10.0 # Strong signal
# nn.LSTM does not expose per-gate activations, so instead of registering
# hooks we simply compare the final cell states of the two sequences below
# Process and compare cell states
_, (_, c1) = lstm(x1)
_, (_, c2) = lstm(x2)
print("Cell state comparison:")
print(f"Normal sequence - cell state norm: {c1.norm().item():.4f}")
print(f"With reset signal - cell state norm: {c2.norm().item():.4f}")
forget_gate_experiment()
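Since nn.LSTM never exposes its gate activations directly, another way to peek at the forget gate is to recompute it by hand from the layer's parameters. This is a minimal sketch with random inputs; it relies only on the fact that PyTorch stores the gates in (i, f, g, o) order inside weight_ih_l0 / weight_hh_l0.
# Recompute the forget gate for a single time step from nn.LSTM's weights
lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
x_t = torch.randn(1, 10)        # one time step of input
h_prev = torch.zeros(1, 20)     # previous hidden state

W_ih = lstm.weight_ih_l0        # (4 * hidden, input)
W_hh = lstm.weight_hh_l0        # (4 * hidden, hidden)
bias = lstm.bias_ih_l0 + lstm.bias_hh_l0

gates = x_t @ W_ih.T + h_prev @ W_hh.T + bias   # (1, 4 * hidden)
i, f, g, o = gates.chunk(4, dim=1)              # PyTorch gate order: i, f, g, o
f = torch.sigmoid(f)

print(f"Forget gate activations, mean: {f.mean().item():.3f} (near 0.5 at random init)")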
The Input Gate: Learning What to Remember
def visualize_input_gate():
"""
The input gate decides what new information to store.
i_t = σ(W_i · [h_{t-1}, x_t] + b_i) # What to update
g_t = tanh(W_g · [h_{t-1}, x_t] + b_g) # Candidate values
New information = i_t * g_t
"""
print("Input Gate Example:")
print("-" * 50)
print("Processing: 'Marie Curie won the Nobel Prize in ___ and ___'")
print()
print("When we see 'Nobel Prize':")
print(" Candidate values (tanh): [award_type: +0.9, person: -0.1, ...]")
print(" Input gate (sigmoid): [award_type: 0.95, person: 0.2, ...]")
print()
print("New cell state += [strong award signal, weak person signal]")
visualize_input_gate()
The Output Gate: Learning What to Reveal
def visualize_output_gate():
"""
The output gate decides what parts of the cell state to output.
o_t = σ(W_o · [h_{t-1}, x_t] + b_o)
h_t = o_t * tanh(C_t)
The cell state might store information that isn't immediately
relevant but will be needed later.
"""
print("Output Gate Example:")
print("-" * 50)
print("Cell state: [subject=John, visited_places=[Paris,London], mood=happy]")
print()
print("Current word: 'traveled' → Need location info")
print(" Output gate: [subject: 0.3, places: 0.9, mood: 0.1]")
print(" Hidden output emphasizes: places")
print()
print("Current word: 'smiled' → Need mood info")
print(" Output gate: [subject: 0.2, places: 0.1, mood: 0.95]")
print(" Hidden output emphasizes: mood")
visualize_output_gate()
Why LSTM Solves Vanishing Gradients
The Gradient Highway
def explain_gradient_highway():
"""
The key to LSTM's success: The cell state update
C_t = f_t * C_{t-1} + i_t * g_t
The gradient of C_t w.r.t. C_{t-1} is simply f_t (forget gate).
If f_t ≈ 1, gradients flow through unchanged!
This is like a "gradient highway" - information and gradients
can travel long distances without being squashed.
"""
print("Gradient Flow Comparison:")
print("-" * 60)
print()
print("Vanilla RNN gradient through T steps:")
print(" ∂h_T/∂h_1 = ∏(W_hh * diag(tanh')) → vanishes!")
print()
print("LSTM gradient through T steps:")
print(" ∂C_T/∂C_1 = ∏(f_t) → can stay close to 1!")
print()
print("Key insight: Forget gate learns to keep gradients alive")
print("for important long-range dependencies.")
explain_gradient_highway()
def compare_gradient_flow():
"""Compare gradient flow in RNN vs LSTM."""
seq_len = 100
hidden_size = 64
# Vanilla RNN
rnn = nn.RNN(32, hidden_size, batch_first=True)
# LSTM
lstm = nn.LSTM(32, hidden_size, batch_first=True)
x = torch.randn(1, seq_len, 32, requires_grad=True)
# RNN gradient flow
rnn_out, _ = rnn(x)
rnn_loss = rnn_out[:, -1, :].sum()
rnn_loss.backward()
rnn_grad = x.grad[:, 0, :].norm().item()
x.grad.zero_()
# LSTM gradient flow
lstm_out, _ = lstm(x)
lstm_loss = lstm_out[:, -1, :].sum()
lstm_loss.backward()
lstm_grad = x.grad[:, 0, :].norm().item()
print(f"\nGradient at first time step (sequence length {seq_len}):")
print(f" RNN: {rnn_grad:.6f}")
print(f" LSTM: {lstm_grad:.6f}")
print(f" Ratio (LSTM/RNN): {lstm_grad/rnn_grad:.2f}x")
compare_gradient_flow()
Gated Recurrent Unit (GRU)
A Simpler Alternative
GRU simplifies the LSTM by:
- Combining the forget and input gates into a single "update gate"
- Merging the cell state and hidden state
GRU Equations
$$
\begin{aligned}
z_t &= \sigma(W_z \cdot [h_{t-1}, x_t]) && \text{(Update gate)} \\
r_t &= \sigma(W_r \cdot [h_{t-1}, x_t]) && \text{(Reset gate)} \\
\tilde{h}_t &= \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t]) && \text{(Candidate)} \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t && \text{(Final state)}
\end{aligned}
$$
class GRUCell(nn.Module):
"""
GRU Cell implemented from scratch.
Compared to LSTM:
- 2 gates instead of 3
- No separate cell state
- Fewer parameters
- Often performs similarly
"""
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
# Update and reset gates
self.W_z = nn.Linear(input_size + hidden_size, hidden_size)
self.W_r = nn.Linear(input_size + hidden_size, hidden_size)
# Candidate hidden state
self.W_h = nn.Linear(input_size + hidden_size, hidden_size)
def forward(self, x, h=None):
"""
Args:
x: Input (batch, input_size)
h: Hidden state (batch, hidden_size)
Returns:
h_new: Updated hidden state
"""
batch_size = x.size(0)
if h is None:
h = torch.zeros(batch_size, self.hidden_size, device=x.device)
# Concatenate input and hidden
combined = torch.cat([x, h], dim=1)
# Update gate: how much of new state to use
z = torch.sigmoid(self.W_z(combined))
# Reset gate: how much of old state to forget when computing candidate
r = torch.sigmoid(self.W_r(combined))
# Candidate hidden state
combined_reset = torch.cat([x, r * h], dim=1)
h_candidate = torch.tanh(self.W_h(combined_reset))
# Final hidden state: interpolate between old and candidate
h_new = (1 - z) * h + z * h_candidate
return h_new
class GRU(nn.Module):
"""Complete GRU layer."""
def __init__(self, input_size, hidden_size, num_layers=1, batch_first=True):
super().__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.batch_first = batch_first
self.cells = nn.ModuleList([
GRUCell(input_size if i == 0 else hidden_size, hidden_size)
for i in range(num_layers)
])
def forward(self, x, h_0=None):
if self.batch_first:
batch_size, seq_len, _ = x.size()
else:
seq_len, batch_size, _ = x.size()
x = x.transpose(0, 1)
# Initialize hidden states
if h_0 is None:
h = [torch.zeros(batch_size, self.hidden_size, device=x.device)
for _ in range(self.num_layers)]
else:
h = [h_0[i] for i in range(self.num_layers)]
outputs = []
for t in range(seq_len):
x_t = x[:, t, :]
for layer, cell in enumerate(self.cells):
h[layer] = cell(x_t, h[layer])
x_t = h[layer]
outputs.append(h[-1])
outputs = torch.stack(outputs, dim=1)
h_n = torch.stack(h, dim=0)
if not self.batch_first:
outputs = outputs.transpose(0, 1)
return outputs, h_n
# Test GRU
gru = GRU(input_size=32, hidden_size=64, num_layers=2, batch_first=True)
x = torch.randn(8, 20, 32)
output, h_n = gru(x)
print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
print(f"Hidden: {h_n.shape}")
LSTM vs GRU Comparison
| Aspect | LSTM | GRU |
|---|---|---|
| Gates | 3 (forget, input, output) | 2 (update, reset) |
| States | 2 (hidden + cell) | 1 (hidden only) |
| Parameters | 4 × hidden² | 3 × hidden² |
| Training | Slightly slower | Faster |
| Performance | Often slightly better on complex tasks | Often comparable |
| When to use | Long sequences, complex dependencies | Faster training needed, simpler tasks |
def compare_lstm_gru():
"""Compare LSTM and GRU architectures."""
input_size = 128
hidden_size = 256
lstm = nn.LSTM(input_size, hidden_size, num_layers=2, batch_first=True)
gru = nn.GRU(input_size, hidden_size, num_layers=2, batch_first=True)
lstm_params = sum(p.numel() for p in lstm.parameters())
gru_params = sum(p.numel() for p in gru.parameters())
print("Parameter Comparison:")
print(f" LSTM: {lstm_params:,} parameters")
print(f" GRU: {gru_params:,} parameters")
print(f" GRU is {100*(1 - gru_params/lstm_params):.1f}% smaller")
# Speed comparison
import time
x = torch.randn(32, 100, input_size)
# Warm up
_ = lstm(x)
_ = gru(x)
# Time LSTM
start = time.time()
for _ in range(100):
_ = lstm(x)
lstm_time = time.time() - start
# Time GRU
start = time.time()
for _ in range(100):
_ = gru(x)
gru_time = time.time() - start
print(f"\nSpeed Comparison (100 forward passes):")
print(f" LSTM: {lstm_time:.3f}s")
print(f" GRU: {gru_time:.3f}s")
print(f" GRU is {100*(1 - gru_time/lstm_time):.1f}% faster")
compare_lstm_gru()
Practical Applications
Sentiment Analysis with LSTM
class SentimentLSTM(nn.Module):
"""
Sentiment analysis using bidirectional LSTM.
"""
def __init__(self, vocab_size, embed_dim, hidden_dim,
num_layers, num_classes, dropout=0.5):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
self.lstm = nn.LSTM(
embed_dim,
hidden_dim,
num_layers=num_layers,
batch_first=True,
bidirectional=True,
dropout=dropout if num_layers > 1 else 0
)
self.dropout = nn.Dropout(dropout)
# Bidirectional → hidden_dim * 2
self.fc = nn.Linear(hidden_dim * 2, num_classes)
def forward(self, x, lengths=None):
"""
Args:
x: Token indices (batch, seq_len)
lengths: Actual sequence lengths (for packing)
"""
# Embed tokens
embedded = self.dropout(self.embedding(x))
# Pack sequences if lengths provided
if lengths is not None:
packed = nn.utils.rnn.pack_padded_sequence(
embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
)
output, (hidden, cell) = self.lstm(packed)
else:
output, (hidden, cell) = self.lstm(embedded)
# Concatenate final forward and backward hidden states
# hidden shape: (num_layers * 2, batch, hidden_dim)
hidden_forward = hidden[-2, :, :] # Last layer, forward
hidden_backward = hidden[-1, :, :] # Last layer, backward
hidden_cat = torch.cat([hidden_forward, hidden_backward], dim=1)
# Classify
output = self.fc(self.dropout(hidden_cat))
return output
# Example usage
model = SentimentLSTM(
vocab_size=10000,
embed_dim=128,
hidden_dim=256,
num_layers=2,
num_classes=2
)
# Simulated batch
batch = torch.randint(1, 10000, (16, 100)) # 16 sequences, length 100
predictions = model(batch)
print(f"Input: {batch.shape}")
print(f"Predictions: {predictions.shape}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
Language Modeling with LSTM
class LanguageModelLSTM(nn.Module):
"""
Language model: Predict next word given previous words.
"""
def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers,
dropout=0.5, tie_weights=True):
super().__init__()
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.lstm = nn.LSTM(
embed_dim,
hidden_dim,
num_layers=num_layers,
batch_first=True,
dropout=dropout if num_layers > 1 else 0
)
self.dropout = nn.Dropout(dropout)
self.fc = nn.Linear(hidden_dim, vocab_size)
# Weight tying: share embedding and output weights
if tie_weights and embed_dim == hidden_dim:
self.fc.weight = self.embedding.weight
def forward(self, x, hidden=None):
"""
Args:
x: Token indices (batch, seq_len)
hidden: (h, c) tuple for stateful processing
Returns:
logits: (batch, seq_len, vocab_size)
hidden: Updated hidden state
"""
embedded = self.dropout(self.embedding(x))
output, hidden = self.lstm(embedded, hidden)
output = self.dropout(output)
logits = self.fc(output)
return logits, hidden
def init_hidden(self, batch_size, device='cpu'):
"""Initialize hidden state."""
h = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device)
c = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device)
return (h, c)
def generate(self, start_tokens, max_length=100, temperature=1.0):
"""Generate text continuation."""
self.eval()
batch_size = start_tokens.size(0)
device = start_tokens.device
hidden = self.init_hidden(batch_size, device)
generated = start_tokens.tolist()
current = start_tokens
with torch.no_grad():
for _ in range(max_length):
logits, hidden = self.forward(current, hidden)
# Get last token prediction
logits = logits[:, -1, :] / temperature
probs = torch.softmax(logits, dim=-1)
# Sample
next_token = torch.multinomial(probs, 1)
for i in range(batch_size):
generated[i].append(next_token[i].item())
current = next_token
return generated
# Create model
lm = LanguageModelLSTM(
vocab_size=50000,
embed_dim=512,
hidden_dim=512,
num_layers=3,
dropout=0.3
)
# Training example
x = torch.randint(0, 50000, (8, 35)) # batch=8, seq_len=35
y = torch.randint(0, 50000, (8, 35)) # targets (shifted by 1)
logits, _ = lm(x)
loss = nn.CrossEntropyLoss()(logits.view(-1, 50000), y.view(-1))
print(f"Input: {x.shape}")
print(f"Output: {logits.shape}")
print(f"Loss: {loss.item():.4f}")
print(f"Perplexity: {torch.exp(loss).item():.2f}")
Sequence-to-Sequence Translation
class Encoder(nn.Module):
"""LSTM Encoder for seq2seq."""
def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.lstm = nn.LSTM(
embed_dim, hidden_dim, num_layers,
batch_first=True, dropout=dropout, bidirectional=True
)
self.dropout = nn.Dropout(dropout)
# Project bidirectional to unidirectional for decoder
self.fc_h = nn.Linear(hidden_dim * 2, hidden_dim)
self.fc_c = nn.Linear(hidden_dim * 2, hidden_dim)
def forward(self, src):
embedded = self.dropout(self.embedding(src))
outputs, (hidden, cell) = self.lstm(embedded)
# Combine bidirectional states
# hidden: (num_layers * 2, batch, hidden_dim)
hidden = hidden.view(self.lstm.num_layers, 2, -1, self.lstm.hidden_size)
hidden = torch.cat([hidden[:, 0], hidden[:, 1]], dim=-1)
hidden = torch.tanh(self.fc_h(hidden))
cell = cell.view(self.lstm.num_layers, 2, -1, self.lstm.hidden_size)
cell = torch.cat([cell[:, 0], cell[:, 1]], dim=-1)
cell = torch.tanh(self.fc_c(cell))
return outputs, (hidden, cell)
class Decoder(nn.Module):
"""LSTM Decoder for seq2seq."""
def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout):
super().__init__()
self.vocab_size = vocab_size
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.lstm = nn.LSTM(
embed_dim, hidden_dim, num_layers,
batch_first=True, dropout=dropout
)
self.fc = nn.Linear(hidden_dim, vocab_size)
self.dropout = nn.Dropout(dropout)
def forward(self, tgt, hidden):
embedded = self.dropout(self.embedding(tgt))
output, hidden = self.lstm(embedded, hidden)
prediction = self.fc(output)
return prediction, hidden
class Seq2Seq(nn.Module):
"""Complete sequence-to-sequence model."""
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, src, tgt, teacher_forcing_ratio=0.5):
"""
Args:
src: Source sequence (batch, src_len)
tgt: Target sequence (batch, tgt_len)
teacher_forcing_ratio: Probability of using ground truth
"""
batch_size = src.size(0)
tgt_len = tgt.size(1)
tgt_vocab_size = self.decoder.vocab_size
# Store outputs
outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size, device=src.device)
# Encode source
encoder_outputs, hidden = self.encoder(src)
# First decoder input is <sos> token
decoder_input = tgt[:, 0:1]
for t in range(1, tgt_len):
output, hidden = self.decoder(decoder_input, hidden)
outputs[:, t] = output.squeeze(1)
# Teacher forcing
teacher_force = torch.rand(1).item() < teacher_forcing_ratio
if teacher_force:
decoder_input = tgt[:, t:t+1]
else:
decoder_input = output.argmax(-1)
return outputs
# Create seq2seq model
encoder = Encoder(vocab_size=10000, embed_dim=256, hidden_dim=512,
num_layers=2, dropout=0.3)
decoder = Decoder(vocab_size=8000, embed_dim=256, hidden_dim=512,
num_layers=2, dropout=0.3)
model = Seq2Seq(encoder, decoder)
# Example
src = torch.randint(0, 10000, (16, 30)) # Source: 16 sentences, max 30 tokens
tgt = torch.randint(0, 8000, (16, 25)) # Target: 16 sentences, max 25 tokens
output = model(src, tgt)
print(f"Source: {src.shape}")
print(f"Target: {tgt.shape}")
print(f"Output: {output.shape}")
Advanced LSTM Variants
Peephole Connections
class PeepholeLSTMCell(nn.Module):
"""
LSTM with peephole connections.
Gates can directly look at the cell state, not just hidden state.
This can help with precise timing and counting.
"""
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
# Standard weights
self.W_i = nn.Linear(input_size + hidden_size, hidden_size)
self.W_f = nn.Linear(input_size + hidden_size, hidden_size)
self.W_o = nn.Linear(input_size + hidden_size, hidden_size)
self.W_c = nn.Linear(input_size + hidden_size, hidden_size)
# Peephole weights (diagonal, so just vectors)
self.p_i = nn.Parameter(torch.randn(hidden_size))
self.p_f = nn.Parameter(torch.randn(hidden_size))
self.p_o = nn.Parameter(torch.randn(hidden_size))
def forward(self, x, hidden=None):
batch_size = x.size(0)
if hidden is None:
h = torch.zeros(batch_size, self.hidden_size, device=x.device)
c = torch.zeros(batch_size, self.hidden_size, device=x.device)
else:
h, c = hidden
combined = torch.cat([x, h], dim=1)
# Gates with peephole connections to cell state
f = torch.sigmoid(self.W_f(combined) + self.p_f * c)
i = torch.sigmoid(self.W_i(combined) + self.p_i * c)
c_candidate = torch.tanh(self.W_c(combined))
c_new = f * c + i * c_candidate
o = torch.sigmoid(self.W_o(combined) + self.p_o * c_new)
h_new = o * torch.tanh(c_new)
return h_new, c_new
Layer Normalization in LSTM
class LayerNormLSTMCell(nn.Module):
"""
LSTM with layer normalization for better training stability.
"""
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
self.W_ih = nn.Linear(input_size, 4 * hidden_size, bias=False)
self.W_hh = nn.Linear(hidden_size, 4 * hidden_size, bias=False)
# Layer normalization
self.ln_ih = nn.LayerNorm(4 * hidden_size)
self.ln_hh = nn.LayerNorm(4 * hidden_size)
self.ln_cell = nn.LayerNorm(hidden_size)
def forward(self, x, hidden=None):
batch_size = x.size(0)
if hidden is None:
h = torch.zeros(batch_size, self.hidden_size, device=x.device)
c = torch.zeros(batch_size, self.hidden_size, device=x.device)
else:
h, c = hidden
# Apply layer norm to gate inputs
gates = self.ln_ih(self.W_ih(x)) + self.ln_hh(self.W_hh(h))
i, f, g, o = gates.chunk(4, dim=1)
i = torch.sigmoid(i)
f = torch.sigmoid(f)
g = torch.tanh(g)
o = torch.sigmoid(o)
c_new = f * c + i * g
# Layer norm on cell state before output
h_new = o * torch.tanh(self.ln_cell(c_new))
return h_new, c_new
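Neither variant is exercised in the text above, so here is a quick shape check in the same style as the earlier cell tests (random inputs only).
# Sanity-check both variants on a random batch
x = torch.randn(8, 32)
peephole_cell = PeepholeLSTMCell(input_size=32, hidden_size=64)
ln_cell = LayerNormLSTMCell(input_size=32, hidden_size=64)
h_p, c_p = peephole_cell(x)
h_ln, c_ln = ln_cell(x)
print(f"Peephole cell: h {h_p.shape}, c {c_p.shape}")
print(f"LayerNorm cell: h {h_ln.shape}, c {c_ln.shape}")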
Best Practices and Tips
Training LSTM/GRU Models:
- Gradient clipping: Always use torch.nn.utils.clip_grad_norm_ with max_norm in the 1.0-5.0 range (see the sketch after this list)
- Learning rate: Start with 0.001 for Adam, 0.1-1.0 for SGD
- Dropout: Use 0.2-0.5 between layers, not within cells
- Initialization: PyTorch defaults are usually fine
- Bidirectional: Use for tasks where the full sequence is available (classification, tagging)
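A minimal training-step sketch showing where gradient clipping fits; the model, data, and loss below are placeholders chosen only for illustration.
demo_lstm = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)
optimizer = torch.optim.Adam(demo_lstm.parameters(), lr=0.001)
criterion = nn.MSELoss()

x = torch.randn(8, 20, 32)       # dummy batch
target = torch.randn(8, 20, 64)  # dummy regression target

optimizer.zero_grad()
output, _ = demo_lstm(x)
loss = criterion(output, target)
loss.backward()
# Clip the gradient norm before the optimizer step
torch.nn.utils.clip_grad_norm_(demo_lstm.parameters(), max_norm=1.0)
optimizer.step()
print(f"Loss after one step: {loss.item():.4f}")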
Common Mistakes:
- Not detaching hidden state: For stateful training, detach the hidden state between batches: hidden = (hidden[0].detach(), hidden[1].detach()) (see the sketch after this list)
- Ignoring sequence lengths: Use pack_padded_sequence for variable-length sequences (also sketched below)
- Wrong hidden state indexing: For bidirectional models, hidden[-2:] contains the last layer's forward and backward states
- Too many layers: 2-3 layers are usually sufficient; more can hurt
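The first two points can be sketched in a few lines; the sequence chunks and lengths below are illustrative only.
# 1) Stateful (truncated BPTT) training: detach the hidden state between chunks
stateful_lstm = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)
hidden = None
for _ in range(3):                         # consecutive chunks of one long sequence
    chunk = torch.randn(8, 20, 32)
    out, hidden = stateful_lstm(chunk, hidden)
    hidden = (hidden[0].detach(), hidden[1].detach())   # cut the graph between chunks
    out.sum().backward()                   # stand-in for a real loss

# 2) Variable-length batches: pack so padded steps are skipped
padded = torch.randn(4, 10, 32)            # padded batch
lengths = torch.tensor([10, 7, 5, 3])      # true lengths (illustrative)
packed = nn.utils.rnn.pack_padded_sequence(padded, lengths, batch_first=True,
                                           enforce_sorted=False)
packed_out, _ = stateful_lstm(packed)
unpacked, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
print(f"Unpacked output: {unpacked.shape}")   # (4, 10, 64)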
Exercises
Exercise 1: Implement LSTM from Scratch
Implement a complete LSTM without using nn.LSTM:
- Implement LSTMCell with all gates
- Stack cells into an LSTM layer
- Add bidirectional support
- Verify outputs match nn.LSTM
Exercise 2: Gradient Flow Analysis
Create a visualization of gradient flow through LSTM:
- Process sequences of length 10, 50, 100, 200
- Track gradient magnitude at each time step
- Compare with vanilla RNN
- Plot the results
Exercise 3: Named Entity Recognition
Build an NER tagger using BiLSTM:
- Load CoNLL-2003 dataset
- Implement BiLSTM-CRF model
- Train with proper evaluation (F1 score)
- Analyze errors by entity type
Exercise 4: Music Generation
Train an LSTM to generate music:
- Download MIDI files and convert to sequences
- Train character-level LSTM on ABC notation
- Generate new melodies
- Convert back to MIDI and listen
Exercise 5: Time Series Forecasting
Build a multivariate time series forecaster:
- Use a dataset like air quality or stock prices
- Implement encoder-decoder with LSTM
- Add attention (preview of next chapter!)
- Compare with simple baselines
Key Takeaways
| Concept | Key Insight |
|---|---|
| Cell State | Long-term memory highway - gradients flow unimpeded |
| Forget Gate | Learns what to erase - enables “forgetting” of irrelevant info |
| Input Gate | Learns what to remember - filters new information |
| Output Gate | Learns what to reveal - controls hidden state exposure |
| GRU | Simpler, fewer parameters, often similar performance |
| Gradient Highway | Cell state update is additive → gradients don’t vanish |