CNNs transformed computer vision, but the standard architectures have a fundamental limitation: they assume fixed-size inputs. What about:
Text: “The cat sat on the mat” (6 words)
Time series: Stock prices over months (varying length)
Audio: Speech of different durations
Video: Frames over time
These are sequences - data where order matters and length varies.
The Core Insight: Sequences have temporal dependencies. The word “sat” depends on knowing “cat” came before it. Today’s stock price depends on yesterday’s. We need networks that can remember.

The fundamental difference between a CNN and an RNN: a CNN asks “what is here?”, while an RNN asks “what just happened, and what does that mean for what comes next?” CNNs are spatial; RNNs are temporal.
Think of a feedforward network as someone with amnesia reading a book word by word. Each word is processed in isolation — when they reach “mat,” they have no memory of “cat” or “sat.” An RNN is like a normal reader: they carry a running mental summary of everything they have read so far. Each new word updates that summary. The summary is lossy — you cannot perfectly reconstruct every previous word from it — but it captures the gist that matters for understanding the next word. This running summary is the hidden state.
```python
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# A feedforward network processes each input independently
class FeedforwardClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Each input processed independently
        # No memory of previous inputs!
        return self.fc2(torch.relu(self.fc1(x)))

# Problem: Processing words one at a time loses context
words = ["The", "cat", "sat", "on", "the", "mat"]
# When processing "mat", the network has no memory of "cat"!
```
An RNN maintains a hidden state that carries information across time steps:

$$h_t = f(h_{t-1}, x_t)$$

Think of the hidden state as a person’s “running mental summary” while reading a book. After each sentence ($x_t$), the reader updates their understanding ($h_t$) based on what they just read and what they already knew ($h_{t-1}$). They cannot go back and re-read (that would be an attention mechanism), so their current understanding must compress everything important from the entire story so far into a fixed-size mental state.
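The recurrence $h_t = f(h_{t-1}, x_t)$ is a left fold over the sequence. The same update function is applied step by step, and the state is threaded through each call. The sketch below assumes a scalar state and a toy update `step` with hand-picked weights. The names `step` and `run_rnn` and the weight values are hypothetical and serve only as an illustration. Reversing the inputs changes the final state, which is the order sensitivity a per-word feedforward model lacks.

```python
import math
from functools import reduce

def step(h_prev: float, x_t: float, w_hh: float = 0.8, w_xh: float = 1.0) -> float:
    """Toy scalar update h_t = tanh(w_hh * h_{t-1} + w_xh * x_t) (hypothetical weights)."""
    return math.tanh(w_hh * h_prev + w_xh * x_t)

def run_rnn(xs, h0: float = 0.0) -> float:
    # reduce applies step(h, x) left to right: step(step(step(h0, x1), x2), x3), ...
    return reduce(step, xs, h0)

seq = [1.0, 0.0, -1.0]
h_forward = run_rnn(seq)
h_reversed = run_rnn(list(reversed(seq)))
print(f"forward: {h_forward:.4f}, reversed: {h_reversed:.4f}")
```

Now consider a model that scores each element independently and then sums the scores. It would give the same result for both orders, because a sum ignores order.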
```python
class SimpleRNNCell(nn.Module):
    """
    A single RNN cell - the building block of RNNs.

    At each time step:
    1. Combine current input with previous hidden state
    2. Apply non-linearity
    3. Output new hidden state

    The same weights are reused at every time step -- this is "weight sharing
    through time," analogous to how CNNs share weights across spatial positions.
    """
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Weights for input: transform the current observation
        self.W_xh = nn.Linear(input_size, hidden_size, bias=False)
        # Weights for hidden state: transform the memory of the past
        self.W_hh = nn.Linear(hidden_size, hidden_size, bias=True)

    def forward(self, x, h_prev):
        """
        Args:
            x: Current input (batch, input_size) -- what we see now
            h_prev: Previous hidden state (batch, hidden_size) -- what we remember
        Returns:
            h_new: Updated hidden state -- our new understanding
        """
        # New memory = tanh(transform(what we see) + transform(what we remember))
        # tanh squashes to [-1, 1], preventing the hidden state from growing unbounded
        h_new = torch.tanh(self.W_xh(x) + self.W_hh(h_prev))
        return h_new

# Demonstrate the recurrent connection
cell = SimpleRNNCell(input_size=10, hidden_size=20)

# Process a sequence of length 5
sequence = torch.randn(3, 5, 10)  # (batch=3, seq_len=5, input=10)
batch_size = sequence.size(0)
seq_len = sequence.size(1)

# Initialize hidden state
h = torch.zeros(batch_size, 20)

# Process each time step
hidden_states = []
for t in range(seq_len):
    x_t = sequence[:, t, :]  # Get input at time t
    h = cell(x_t, h)         # Update hidden state
    hidden_states.append(h)
    print(f"Time {t}: h shape = {h.shape}")

print("\nThe final hidden state is a fixed-size (lossy) summary of the entire sequence.")
```
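At each time step $t$:

$$h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$$
$$y_t = W_{hy} h_t + b_y$$

where:
$x_t \in \mathbb{R}^d$ is the input at time $t$
$h_t \in \mathbb{R}^h$ is the hidden state
$y_t \in \mathbb{R}^o$ is the output
$W_{xh} \in \mathbb{R}^{h \times d}$ transforms input to hidden
$W_{hh} \in \mathbb{R}^{h \times h}$ transforms previous hidden to current
$W_{hy} \in \mathbb{R}^{o \times h}$ transforms hidden to output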
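Here is a minimal NumPy sketch of one time step that makes the shapes concrete. The sizes ($d=3$, $h=4$, $o=2$), the random weights, and the helper name `rnn_step` are illustrative assumptions.

```python
import numpy as np

def rnn_step(x_t, h_prev, W_xh, W_hh, b_h, W_hy, b_y):
    """One step of the vanilla RNN equations (hypothetical helper)."""
    h_t = np.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)  # shape (h,)
    y_t = W_hy @ h_t + b_y                           # shape (o,)
    return h_t, y_t

d, h, o = 3, 4, 2
rng = np.random.default_rng(0)
W_xh = rng.normal(size=(h, d))  # R^{h x d}
W_hh = rng.normal(size=(h, h))  # R^{h x h}
W_hy = rng.normal(size=(o, h))  # R^{o x h}
b_h, b_y = np.zeros(h), np.zeros(o)

x_t = rng.normal(size=d)
h_prev = np.zeros(h)           # initial hidden state
h_t, y_t = rnn_step(x_t, h_prev, W_xh, W_hh, b_h, W_hy, b_y)
print(h_t.shape, y_t.shape)    # (4,) (2,)
```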
```python
class VanillaRNN(nn.Module):
    """
    Vanilla (Elman) RNN implemented from scratch.

    This is the simplest form of RNN, directly implementing
    the recurrence relation with tanh non-linearity.
    """
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Input to hidden weights
        self.W_xh = nn.ModuleList([
            nn.Linear(input_size if i == 0 else hidden_size, hidden_size)
            for i in range(num_layers)
        ])

        # Hidden to hidden weights
        self.W_hh = nn.ModuleList([
            nn.Linear(hidden_size, hidden_size)
            for _ in range(num_layers)
        ])

        # Hidden to output
        self.W_hy = nn.Linear(hidden_size, output_size)

    def forward(self, x, h_0=None):
        """
        Process entire sequence.

        Args:
            x: Input sequence (batch, seq_len, input_size)
            h_0: Initial hidden state (num_layers, batch, hidden_size)

        Returns:
            outputs: Output at each time step (batch, seq_len, output_size)
            h_n: Final hidden state (num_layers, batch, hidden_size)
        """
        batch_size, seq_len, _ = x.size()

        # Initialize hidden states
        if h_0 is None:
            h = [torch.zeros(batch_size, self.hidden_size, device=x.device)
                 for _ in range(self.num_layers)]
        else:
            h = [h_0[i] for i in range(self.num_layers)]

        outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :]

            # Process through each layer
            for layer in range(self.num_layers):
                h[layer] = torch.tanh(
                    self.W_xh[layer](x_t) + self.W_hh[layer](h[layer])
                )
                x_t = h[layer]  # Output of this layer is input to next

            # Compute output
            y_t = self.W_hy(h[-1])
            outputs.append(y_t)

        outputs = torch.stack(outputs, dim=1)
        h_n = torch.stack(h, dim=0)

        return outputs, h_n

# Test our implementation
rnn = VanillaRNN(input_size=10, hidden_size=32, output_size=5, num_layers=2)
x = torch.randn(4, 20, 10)  # batch=4, seq_len=20, input=10
outputs, h_n = rnn(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {outputs.shape}")    # (4, 20, 5)
print(f"Final hidden shape: {h_n.shape}")  # (2, 4, 32)
```
The gradient through time involves products of the recurrent weight matrix. Write $z_t = W_{xh} x_t + W_{hh} h_{t-1} + b_h$ for the pre-activation, so that $h_t = \tanh(z_t)$. Then:

$$\frac{\partial h_T}{\partial h_1} = \prod_{t=2}^{T} \frac{\partial h_t}{\partial h_{t-1}} = \prod_{t=2}^{T} \mathrm{diag}\big(\tanh'(z_t)\big)\, W_{hh}$$

The factors are multiplied in order, with $t = T$ on the left.

Mathematical intuition: this product is the crux of the problem. The tanh derivative is at most 1. Since $\tanh'(z) = 1 - \tanh^2(z)$, it equals 1 at $z = 0$ and falls toward 0 as $|z|$ grows, so during training it is often well below 1. At each step this factor is multiplied by $W_{hh}$. Suppose the largest singular value of $\mathrm{diag}(\tanh'(z_t))\, W_{hh}$ stays below 1, which is common. Then the product shrinks exponentially. For example, if each step scales the gradient by 0.25, then after 50 steps the scale is $0.25^{50} \approx 10^{-30}$. That is effectively zero. The network cannot learn from information 50 steps ago because the gradient signal has been annihilated.

Conversely, if the singular values exceed 1, the product can explode exponentially. This puts vanilla RNNs in a double bind: they need large weights to maintain gradients, but large weights cause explosions. The only stable regime is the knife-edge where the singular values stay close to 1, which is unrealistic to maintain during training.
```python
def analyze_gradient_flow():
    """
    Analyze why gradients vanish or explode in RNNs.
    """
    hidden_size = 100
    seq_length = 100

    # Different initialization scales
    scales = [0.5, 1.0, 1.5]
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    for ax, scale in zip(axes, scales):
        # Initialize W_hh with different scales
        W = torch.randn(hidden_size, hidden_size) * scale / np.sqrt(hidden_size)

        # Simulate gradient flow (simplified - ignoring tanh derivative)
        gradient = torch.eye(hidden_size)
        gradient_norms = [gradient.norm().item()]

        for t in range(seq_length):
            gradient = gradient @ W
            gradient_norms.append(gradient.norm().item())

        ax.semilogy(gradient_norms)
        ax.set_xlabel('Time steps back')
        ax.set_ylabel('Gradient norm (log scale)')
        ax.set_title(f'Scale = {scale}')
        ax.axhline(y=1, color='r', linestyle='--', alpha=0.5)
        # No fixed y-limits: the scale=0.5 curve can fall far below 1e-20 within 100 steps

    plt.tight_layout()
    plt.show()

analyze_gradient_flow()
```
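The simulation above ignores the $\tanh'$ factor. The sketch below keeps it. It assumes a bias-free tanh RNN with random weights and random inputs, and the function name and all sizes are hypothetical. It builds each per-step Jacobian $\mathrm{diag}(\tanh'(z_t))\,W_{hh}$ explicitly and tracks the spectral norm of their running product.

```python
import torch

def jacobian_product_norms(scale, hidden_size=64, seq_len=60, seed=0):
    """Spectral norm of dh_t/dh_0 for t = 1..seq_len in a bias-free tanh RNN.

    Hypothetical setup: W_hh entries ~ N(0, scale^2 / hidden_size), random inputs.
    Unlike the simplified simulation, this keeps the tanh' factor.
    """
    g = torch.Generator().manual_seed(seed)
    W_hh = torch.randn(hidden_size, hidden_size, generator=g) * scale / hidden_size ** 0.5
    W_xh = torch.randn(hidden_size, hidden_size, generator=g) / hidden_size ** 0.5
    xs = torch.randn(seq_len, hidden_size, generator=g)

    h = torch.zeros(hidden_size)
    J = torch.eye(hidden_size)           # running product dh_t / dh_0
    norms = []
    for t in range(seq_len):
        z = W_hh @ h + W_xh @ xs[t]      # pre-activation z_t
        h = torch.tanh(z)
        # Per-step Jacobian dh_t/dh_{t-1} = diag(1 - tanh(z_t)^2) @ W_hh
        J_t = (1 - h ** 2).unsqueeze(1) * W_hh
        J = J_t @ J                      # newest factor multiplies on the left
        norms.append(torch.linalg.matrix_norm(J, ord=2).item())
    return norms

for scale in [0.5, 1.0, 2.0]:
    norms = jacobian_product_norms(scale)
    print(f"scale={scale}: ||dh_10/dh_0|| = {norms[9]:.2e}, ||dh_60/dh_0|| = {norms[-1]:.2e}")
```

Because $|\tanh'| \le 1$, the derivative can only shrink each factor. At larger scales the pre-activations also drift into tanh's flat regions. Whether the product then grows or decays depends on how strongly the units saturate.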
Practical rule of thumb: If your sequence length is under 20 tokens, a vanilla RNN might work. For 20-200 tokens, use LSTM or GRU. For 200+ tokens, you almost certainly need attention or a transformer. The cutoffs are approximate, but the pattern holds: longer sequences need more sophisticated memory mechanisms.
```python
def gradient_clipping_demo():
    """Demonstrate gradient clipping to prevent explosion."""
    model = nn.RNN(10, 32, batch_first=True)
    x = torch.randn(1, 100, 10, requires_grad=True)
    output, _ = model(x)
    loss = output.sum()
    loss.backward()

    # Check gradient norms before clipping
    total_norm = 0
    for p in model.parameters():
        if p.grad is not None:
            total_norm += p.grad.norm().item() ** 2
    total_norm = total_norm ** 0.5
    print(f"Gradient norm before clipping: {total_norm:.4f}")

    # Apply gradient clipping
    max_norm = 1.0
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    # Check after clipping
    total_norm_after = 0
    for p in model.parameters():
        if p.grad is not None:
            total_norm_after += p.grad.norm().item() ** 2
    total_norm_after = total_norm_after ** 0.5
    print(f"Gradient norm after clipping: {total_norm_after:.4f}")

gradient_clipping_demo()
```
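The manual norm bookkeeping above can be shortened. `torch.nn.utils.clip_grad_norm_` returns the total gradient norm measured before clipping. It also rescales all gradients by a single factor, so their direction is preserved. Here is a short check of both properties, using the same toy `nn.RNN` setup (the seed and sizes are arbitrary).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.RNN(10, 32, batch_first=True)
x = torch.randn(1, 100, 10)
output, _ = model(x)
output.sum().backward()

# Snapshot the raw gradients (torch.cat copies them) to compare directions later
before = torch.cat([p.grad.flatten() for p in model.parameters()])

# clip_grad_norm_ rescales gradients in place and returns the norm measured *before* clipping
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

after = torch.cat([p.grad.flatten() for p in model.parameters()])
cosine = torch.nn.functional.cosine_similarity(before, after, dim=0)
print(f"before: {total_norm:.4f}, after: {after.norm():.4f}, cosine similarity: {cosine:.6f}")
```

A cosine similarity of 1 means only the length of the gradient changed. This is why clipping is a safe guard against explosions: it never changes where the update points.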
```python
class CharRNN(nn.Module):
    """
    Character-level language model.

    Given:   "Hello Worl"
    Predict: "ello World" (next character at each position)
    """
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.RNN(embed_size, hidden_size, num_layers,
                          batch_first=True, dropout=0.2)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        # x: (batch, seq_len) character indices
        embedded = self.embedding(x)                 # (batch, seq_len, embed_size)
        output, hidden = self.rnn(embedded, hidden)  # (batch, seq_len, hidden)
        logits = self.fc(output)                     # (batch, seq_len, vocab_size)
        return logits, hidden

    def generate(self, start_char, char_to_idx, idx_to_char, length=100, temperature=1.0):
        """Generate text character by character."""
        was_training = self.training
        self.eval()  # disable dropout while sampling

        # Start with initial character
        current = torch.tensor([[char_to_idx[start_char]]])
        hidden = None
        generated = [start_char]

        with torch.no_grad():
            for _ in range(length):
                logits, hidden = self.forward(current, hidden)
                # Apply temperature
                probs = torch.softmax(logits[0, -1] / temperature, dim=0)
                # Sample from distribution
                next_idx = torch.multinomial(probs, 1).item()
                next_char = idx_to_char[next_idx]
                generated.append(next_char)
                current = torch.tensor([[next_idx]])

        # Restore the previous mode, so dropout stays active if we were training
        self.train(was_training)
        return ''.join(generated)


def train_char_rnn():
    """Train a character-level RNN on sample text."""
    # Sample text (in practice, use a large corpus)
    text = """
    To be, or not to be, that is the question:
    Whether 'tis nobler in the mind to suffer
    The slings and arrows of outrageous fortune,
    Or to take arms against a sea of troubles
    And by opposing end them.
    """ * 100  # Repeat for more training data

    # Build vocabulary
    chars = sorted(set(text))
    char_to_idx = {c: i for i, c in enumerate(chars)}
    idx_to_char = {i: c for c, i in char_to_idx.items()}
    vocab_size = len(chars)

    print(f"Vocabulary size: {vocab_size}")
    print(f"Characters: {''.join(chars)}")

    # Prepare data
    seq_length = 50

    def create_sequences(text, seq_length):
        inputs = []
        targets = []
        for i in range(0, len(text) - seq_length):
            inputs.append([char_to_idx[c] for c in text[i:i+seq_length]])
            targets.append([char_to_idx[c] for c in text[i+1:i+seq_length+1]])
        return torch.tensor(inputs), torch.tensor(targets)

    inputs, targets = create_sequences(text, seq_length)

    # Create model
    model = CharRNN(vocab_size, embed_size=64, hidden_size=128, num_layers=2)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

    # Training loop: each step draws one random mini-batch
    # (not a full pass over the data, hence "step" rather than "epoch")
    batch_size = 64
    num_steps = 50

    for step in range(num_steps):
        # Random batch
        idx = torch.randint(0, len(inputs), (batch_size,))
        x_batch = inputs[idx]
        y_batch = targets[idx]

        # Forward
        logits, _ = model(x_batch)
        loss = criterion(logits.reshape(-1, vocab_size), y_batch.reshape(-1))

        # Backward
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

        if (step + 1) % 10 == 0:
            print(f"Step {step+1}, Loss: {loss.item():.4f}")
            # Generate sample
            sample = model.generate('T', char_to_idx, idx_to_char, length=100, temperature=0.8)
            print(f"Sample: {sample[:80]}...")
            print()

    return model, char_to_idx, idx_to_char

# Uncomment to train
# model, c2i, i2c = train_char_rnn()
```
Exercise 1: RNN from Scratch

```python
class MyRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Initialize W_xh, W_hh, bias
        # Implement forward pass with proper state management
        pass

    def forward(self, x, h_0=None):
        # Process sequence step by step
        # Return outputs and final hidden state
        pass
```
Verify it gives similar results to nn.RNN.
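One way to run that check is to copy the parameters out of an `nn.RNN` and run the recurrence by hand. The sketch below assumes a single-layer, unidirectional `nn.RNN` with the default tanh nonlinearity. PyTorch stores two bias vectors, `bias_ih_l0` and `bias_hh_l0`, and their sum plays the role of a single bias $b_h$. If your `MyRNN` reproduces this loop, its outputs should match `nn.RNN` up to floating-point error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ref = nn.RNN(input_size=10, hidden_size=20, batch_first=True)
x = torch.randn(3, 7, 10)  # (batch, seq_len, input_size)

with torch.no_grad():
    ref_out, ref_h = ref(x)

    # Manual recurrence using the *same* parameters as nn.RNN
    W_ih, W_hh = ref.weight_ih_l0, ref.weight_hh_l0  # (hidden, input), (hidden, hidden)
    b_ih, b_hh = ref.bias_ih_l0, ref.bias_hh_l0
    h = torch.zeros(3, 20)
    outs = []
    for t in range(x.size(1)):
        h = torch.tanh(x[:, t] @ W_ih.T + b_ih + h @ W_hh.T + b_hh)
        outs.append(h)
    my_out = torch.stack(outs, dim=1)

print(torch.allclose(my_out, ref_out, atol=1e-5))  # outputs at every step
print(torch.allclose(h, ref_h[0], atol=1e-5))      # final hidden state of layer 0
```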
Exercise 2: Adding Problem
Implement the adding problem to test long-range dependencies.
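Here is one common formulation of the task. Each time step carries two features: a random value in $[0, 1)$ and a 0/1 marker. Exactly two steps are marked, and the target is the sum of the two marked values. The generator below is a sketch under that assumption. The function name and the choice of placing one marker in each half of the sequence are illustrative, not part of the exercise statement.

```python
import torch

def make_adding_batch(batch_size, seq_len, generator=None):
    """Adding-problem batch under one common formulation (hypothetical helper).

    Each step has two features: a value drawn uniformly from [0, 1) and a 0/1 marker.
    Exactly two steps per sequence are marked; the target is the sum of their values.
    Requires seq_len >= 2. Returns x of shape (batch, seq_len, 2) and y of shape (batch, 1).
    """
    values = torch.rand(batch_size, seq_len, generator=generator)
    markers = torch.zeros(batch_size, seq_len)
    # One marker in each half, so the first one sits roughly half the sequence away
    # from the end: the model must carry its value that far to solve the task.
    half = seq_len // 2
    first = torch.randint(0, half, (batch_size,), generator=generator)
    second = torch.randint(half, seq_len, (batch_size,), generator=generator)
    rows = torch.arange(batch_size)
    markers[rows, first] = 1.0
    markers[rows, second] = 1.0
    x = torch.stack([values, markers], dim=-1)
    y = (values * markers).sum(dim=1, keepdim=True)
    return x, y

x, y = make_adding_batch(batch_size=4, seq_len=100)
print(x.shape, y.shape)  # torch.Size([4, 100, 2]) torch.Size([4, 1])
```

A model that always predicts 1.0 scores a mean squared error of about 1/6, the variance of the sum of two independent uniform values. That is the baseline a recurrent model has to beat.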
Gradient clipping is not optional for RNNs. Gradient explosion is rare in feedforward networks, but RNNs are prone to it on any non-trivial sequence length. Use torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0) as a default. If training is unstable, lower max_norm to 1.0. If the loss plateaus, you might be clipping too aggressively, so try raising it to 10.0.

Hidden state detachment between batches: When training on long documents split into chunks, you must detach() the hidden state between chunks. Otherwise, each backward pass tries to reach back into earlier chunks' computation graphs. PyTorch then either raises an error because those graphs were already freed, or, if you keep them alive, uses memory that grows with every chunk. Use hidden = hidden.detach(), or hidden = tuple(h.detach() for h in hidden) for LSTMs, whose state is an (h, c) tuple.

Pack padded sequences for variable lengths: If your batch contains sequences of different lengths, padding them and ignoring the padding in the loss is not enough. The RNN still steps through the padding tokens. As a result, the final hidden state is computed over padding, and so is every backward-direction state of a bidirectional RNN. Use nn.utils.rnn.pack_padded_sequence and pad_packed_sequence so the RNN skips padding entirely. This is one of the most common performance bugs in RNN training.

Loss is NaN after a few epochs: This is almost always caused by exploding gradients. Check: (1) Are you clipping gradients? (2) Is your learning rate too high? (3) Do any sequences contain extreme input values? Start with lr=0.001 for Adam and lr=0.1 for SGD with RNNs.

The hidden state initialization trap: Initializing hidden states to zero is standard, but in a bidirectional RNN the backward direction starts at the end of the sequence. If your sequences are padded, the backward RNN starts from a padding position rather than the last real token. Use packed sequences to avoid this.
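Here is a minimal sketch of packing. It assumes two right-padded sequences with true lengths 5 and 3, and the sizes and seed are arbitrary. With packing, `h_n` for the shorter sequence equals its output at its last real step. Running the same padded tensor without packing gives a different final state, because the RNN keeps stepping through the padding.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
rnn = nn.RNN(input_size=4, hidden_size=8, batch_first=True)

# Two sequences with true lengths 5 and 3, right-padded with zeros to length 5
lengths = torch.tensor([5, 3])
x = torch.randn(2, 5, 4)
x[1, 3:] = 0.0  # padding positions of the shorter sequence

packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed_out, h_n = rnn(packed)
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

# Packed: the short sequence's final state is its output at its last real step (t=2)
print(torch.allclose(h_n[0, 1], out[1, 2]))

# Unpacked: the RNN keeps stepping through the two padding positions
_, h_n_padded = rnn(x)
print(torch.allclose(h_n_padded[0, 1], out[1, 2]))
```

`pad_packed_sequence` fills the padded positions of `out` with zeros, so downstream layers see no spurious activations there.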
| Concept | Key idea |
|---|---|
| BPTT | Backprop through the unrolled network: gradients flow through time |
| Vanishing gradient | Over long sequences, gradients shrink and early inputs are effectively forgotten |
| Exploding gradient | Gradients grow and training becomes unstable; use gradient clipping |
| Bidirectional | Context from both past and future |
| Deep RNNs | Stack layers for hierarchical representations |
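`nn.RNN` exposes both the Bidirectional and Deep RNNs options, through `bidirectional` and `num_layers`. Below is a shape check with hypothetical sizes. Note that `h_n` is not batch-first even when `batch_first=True`.

```python
import torch
import torch.nn as nn

# Hypothetical sizes: 2 stacked layers, each bidirectional
rnn = nn.RNN(input_size=10, hidden_size=16, num_layers=2,
             bidirectional=True, batch_first=True)
x = torch.randn(4, 25, 10)  # (batch, seq_len, input_size)
output, h_n = rnn(x)

print(output.shape)  # torch.Size([4, 25, 32]): forward and backward states concatenated
print(h_n.shape)     # torch.Size([4, 4, 16]): (num_layers * num_directions, batch, hidden)

# Separate layers and directions: (num_layers, num_directions, batch, hidden)
h_n = h_n.reshape(2, 2, 4, 16)
top_fwd, top_bwd = h_n[-1, 0], h_n[-1, 1]

# `output` comes from the top layer: its forward final state is the first half of the
# output at the last step; its backward final state is the second half at step 0
print(torch.allclose(top_fwd, output[:, -1, :16]))
print(torch.allclose(top_bwd, output[:, 0, 16:]))
```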
Vanilla RNNs are rarely used in practice! The vanishing gradient problem makes it very hard for them to learn long-range dependencies. In the next chapter, we’ll learn about LSTMs and GRUs, architectures specifically designed to address this problem.
Why do vanilla RNNs fail on long sequences? Explain the vanishing gradient problem in the temporal context.
Strong Answer:
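In an RNN, the hidden state update is $h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b)$. During BPTT, the gradient of the loss with respect to an early hidden state requires multiplying the Jacobians $\partial h_t / \partial h_{t-1}$ across all intervening time steps: $\frac{\partial h_T}{\partial h_1} = \prod_{t=2}^{T} \frac{\partial h_t}{\partial h_{t-1}}$.
Each Jacobian equals $\mathrm{diag}(\tanh'(z_t))\, W_{hh}$, where $z_t = W_{hh} h_{t-1} + W_{xh} x_t + b$ is the pre-activation. Backpropagation multiplies the gradient vector by its transpose, $W_{hh}^\top \mathrm{diag}(\tanh'(z_t))$. The tanh derivative is at most 1 and typically 0.1-0.5. When the Jacobian norms stay below 1, the product of $T$ such terms shrinks toward zero exponentially. After 50 time steps, gradients can be $10^{-10}$ or smaller.
The consequence is selective amnesia: the RNN can learn short-range patterns (2-5 steps) but struggles to learn dependencies from 20+ steps ago. This differs from the vanishing gradient in feedforward networks because the SAME weight matrix $W_{hh}$ is applied at every step. The behavior is therefore governed by the spectral properties of $W_{hh}$, namely its singular values and spectral radius.
If the largest singular value of $W_{hh}$ is less than 1, gradients are guaranteed to vanish, because $|\tanh'| \le 1$. A largest singular value greater than 1 is necessary for gradients to explode, though not sufficient. There is no stable middle ground for vanilla RNNs, which is why gated architectures (LSTM, GRU) were necessary.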
Follow-up: Gradient clipping handles exploding gradients but not vanishing gradients. Why the asymmetry?
Gradient clipping rescales the gradient when its norm exceeds a threshold. This prevents catastrophically large updates while keeping the gradient's direction. With vanishing gradients, step size is not the problem. The contribution from distant time steps is tiny compared with the short-range contributions and the noise in the gradient estimate. Rescaling the whole gradient would amplify those other terms by the same factor, so the long-range signal would stay buried. The solution must be architectural (LSTM gates, skip connections) rather than optimization-level (clipping). Clipping handles the exploding case; gating handles the vanishing case. They are complementary.
Explain BPTT and truncated BPTT. What are the trade-offs of truncation length?
Strong Answer:
BPTT unrolls the RNN through time and applies standard backpropagation to the resulting feedforward graph. An RNN processing length T is equivalent to a T-layer feedforward network with shared weights. Gradients are computed backward through all T steps.
Full BPTT has two problems: memory grows linearly with T (all activations cached), and gradients vanish or explode through hundreds of multiplications.
Truncated BPTT limits backpropagation to a window of k steps. Every k steps, the hidden state is detached from the computation graph. Gradients only flow backward through the most recent k steps.
Trade-off: the model can only LEARN dependencies up to $k$ steps long. A truncation of 35 means a 50-step dependency is invisible to the optimizer. However, the model can still USE long-range information already encoded in the hidden state; it just cannot learn to encode it better. The practical sweet spot is $k = 35$ to $256$ for language modeling, balancing dependency learning against memory and stability.
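Here is a minimal truncated-BPTT loop. It assumes a toy language-model setup on a random token stream, and all sizes, names, and the truncation length $k = 35$ are illustrative. Truncation happens at the `detach()` call. The hidden state's value carries across chunks, but the computation graph stops at the chunk boundary, and so does the gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, k = 20, 35                              # k = truncation length (illustrative)
stream = torch.randint(0, vocab_size, (1, 1000))    # hypothetical token stream, batch of 1

embed = nn.Embedding(vocab_size, 16)
rnn = nn.RNN(16, 32, batch_first=True)
head = nn.Linear(32, vocab_size)
params = list(embed.parameters()) + list(rnn.parameters()) + list(head.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)
criterion = nn.CrossEntropyLoss()

hidden = None
for start in range(0, stream.size(1) - 1, k):
    targets = stream[:, start + 1:start + k + 1]
    inputs = stream[:, start:start + targets.size(1)]  # last chunk may be shorter
    if hidden is not None:
        hidden = hidden.detach()  # keep the value, cut the graph: gradients stop here
    out, hidden = rnn(embed(inputs), hidden)
    loss = criterion(head(out).reshape(-1, vocab_size), targets.reshape(-1))
    optimizer.zero_grad()
    loss.backward()               # backpropagates through at most k steps
    torch.nn.utils.clip_grad_norm_(params, 5.0)
    optimizer.step()
```

If you remove the `detach()` line, the second `loss.backward()` would try to traverse the previous chunk's graph, which PyTorch has already freed.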
Follow-up: How do transformers compare to truncated BPTT in terms of effective context?
Transformers avoid distance-dependent gradient decay. Self-attention creates direct connections between any two positions, so the gradient path between them is one hop per attention layer, however far apart they are. The limitation is computational: $O(n^2)$ attention scales poorly for very long contexts. But within the context window, every position has equally direct gradient access to every position it attends to. In RNNs, by contrast, gradient strength decays with distance. This is the fundamental reason transformers capture long-range dependencies more effectively.
When would you still choose an RNN/LSTM over a transformer today? Are there cases where recurrence is genuinely better?
Strong Answer:
Real-time streaming inference: RNNs process one token at a time with constant memory ($O(1)$ per step). This makes them ideal for streaming audio transcription, real-time sensor processing, or any setting where data arrives one sample at a time and outputs are needed immediately. A transformer can also decode incrementally with a KV-cache, but its per-step compute and memory grow with the length of the context it attends over.
Extremely long sequences with limited compute: for sequences of 100,000+ steps (e.g., long-duration biosignals, multi-day time series), the $O(n^2)$ cost of transformer attention becomes prohibitive. RNNs process these in $O(n)$ time with fixed memory, though they sacrifice long-range dependency quality.
Edge deployment with tight memory constraints: an LSTM cell has fixed-size state regardless of sequence length, making memory consumption predictable and small. A transformer’s KV-cache grows linearly with sequence length.
State Space Models (SSMs) like Mamba represent a modern compromise. They have RNN-like $O(n)$ inference with transformer-like training parallelism, and they have been reported to match transformer quality on many benchmarks. SSMs are increasingly the right answer to the “when would you use recurrence” question: they inherit the efficiency of RNNs without the gradient flow problems of vanilla recurrence.
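The constant-memory streaming argument can be checked directly. The sketch below assumes an `nn.LSTM` with arbitrary sizes and a random input stream. It feeds one sample at a time while carrying only the fixed-size `(h, c)` state, then compares the result with processing the whole sequence at once.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
signal = torch.randn(1, 200, 8)  # hypothetical sensor stream: 200 samples, 8 channels

with torch.no_grad():
    full_out, _ = lstm(signal)  # offline: whole sequence at once

    # Streaming: one sample per call, carrying only the fixed-size (h, c) state
    state = None
    stream_out = []
    for t in range(signal.size(1)):
        y_t, state = lstm(signal[:, t:t + 1], state)  # state sizes never depend on t
        stream_out.append(y_t)
    stream_out = torch.cat(stream_out, dim=1)

print(torch.allclose(full_out, stream_out, atol=1e-5))
print(state[0].shape, state[1].shape)  # torch.Size([1, 1, 16]) torch.Size([1, 1, 16])
```

The only thing kept between calls is `state`, whose size is set by `hidden_size` and not by how many samples have been seen.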
Follow-up: Why is Mamba considered a breakthrough for sequential modeling?
Mamba (Gu and Dao, 2023) achieves linear-time inference like an RNN, parallelizable training like a transformer, and competitive quality on language modeling benchmarks. The key innovation is selective state spaces. The SSM parameters (the discretization step size and the input and output projections) are computed from the input, so the effective state transition changes from token to token. This lets the model decide dynamically what to remember and forget, analogous to LSTM gating but formulated as a discretized continuous-time system. It bridges the recurrence vs. attention divide by providing the efficiency of recurrence with the expressiveness of attention-like selection.