# LSTMs & GRUs

> Master gated recurrent architectures - how LSTM and GRU cells solve the vanishing gradient problem with memory gates

<Frame>
  <img src="https://mintcdn.com/devweeekends/0kwJwOL2KCwg2YYu/images/courses/deep-learning-mastery/lstm-gru-concept.svg?fit=max&auto=format&n=0kwJwOL2KCwg2YYu&q=85&s=f27df9d264060bbba3d2cc87535656af" alt="LSTM and GRU Architectures" width="1080" height="1080" data-path="images/courses/deep-learning-mastery/lstm-gru-concept.svg" />
</Frame>

# LSTMs & GRUs: Gated Recurrent Networks

## The Memory Problem Revisited

Vanilla RNNs suffer from a fundamental flaw: they struggle to remember things for long. Because of the vanishing gradient problem, the training signal from early time steps gets "washed out" as it is propagated back through many time steps, each of which multiplies it by the recurrent weights and a tanh derivative.

**Real-world consequence**: An RNN reading a book can't remember what happened in Chapter 1 when it reaches Chapter 10.

```python theme={null}
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

def demonstrate_memory_problem():
    """
    The Long-Range Dependency Problem:
    
    "The cat, which had been sleeping on the warm sunny 
    windowsill for the entire lazy afternoon while the 
    rain pattered gently against the glass, finally ___."
    
    To predict "woke up" or "stretched", we need to remember
    "cat" and "sleeping" from 20+ words ago!
    """
    
    print("Long-range dependency example:")
    print("-" * 60)
    print("Sentence: 'The CAT, which had been SLEEPING on the warm...")
    print("         sunny windowsill for the entire lazy afternoon...")
    print("         while the rain pattered gently against the glass, finally ___'")
    print()
    print("To predict the blank, we need to remember:")
    print("  - Subject: 'cat' (24 words back)")
    print("  - State: 'sleeping' (20 words back)")
    print()
    print("Vanilla RNNs struggle because gradients shrink over 20+ steps!")

demonstrate_memory_problem()
```

<Note>
  **The Solution**: Instead of trying to force information through a single path, create **multiple pathways** for information flow, some of which can pass information unchanged. This is the key insight behind LSTMs and GRUs.
</Note>

***

## Long Short-Term Memory (LSTM)

### The Big Idea: A Memory Cell with Gates

An LSTM maintains two types of state:

1. **Cell State ($C_t$)**: The "long-term memory" -- a conveyor belt for information that flows through with minimal modification
2. **Hidden State ($h_t$)**: The "working memory" -- what the network is currently thinking about

Three gates control information flow:

1. **Forget Gate**: What to erase from long-term memory ("the subject changed, forget the old topic")
2. **Input Gate**: What new information to write to long-term memory ("this is a new character, store their name")
3. **Output Gate**: What to surface from long-term memory for the current decision ("for predicting the next word, I need the subject, not the setting")

The analogy: imagine a student taking notes during a lecture. The cell state is their notebook. The forget gate is crossing out old notes that are no longer relevant. The input gate is writing new notes. The output gate is deciding which notes to glance at to answer a question. The notebook persists across the entire lecture -- that is the key difference from a vanilla RNN, which is like a student trying to remember everything in their head without writing anything down.

<Frame>
  <img src="https://mintlify.s3.us-west-1.amazonaws.com/devweeekends/images/courses/deep-learning-mastery/lstm-cell.svg" alt="LSTM Cell Diagram" />
</Frame>

### LSTM Equations

$$
\begin{aligned}
f_t &= \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) & \text{(Forget gate)} \\
i_t &= \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) & \text{(Input gate)} \\
\tilde{C}_t &= \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) & \text{(Candidate values)} \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t & \text{(Cell state update)} \\
o_t &= \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) & \text{(Output gate)} \\
h_t &= o_t \odot \tanh(C_t) & \text{(Hidden state)}
\end{aligned}
$$

Where:

* $\sigma$ is the sigmoid function (outputs lie in $(0, 1)$, so each gate acts as a soft switch)
* $\odot$ is element-wise multiplication
* $[h_{t-1}, x_t]$ is the concatenation of the previous hidden state and the current input
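To make the update concrete, here is a small worked example for a single hidden unit. The pre-activation values are picked by hand purely for illustration (they do not come from a trained model). The point is to see how a forget gate near 1 combined with an input gate near 0 leaves the cell state almost unchanged.

```python
import math

def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))

# Hypothetical pre-activations for one hidden unit (hand-picked, not learned)
f_t = sigmoid(3.0)         # forget gate, about 0.953: keep most of the old memory
i_t = sigmoid(-2.0)        # input gate, about 0.119: write only a little
c_tilde = math.tanh(1.0)   # candidate value, about 0.762
o_t = sigmoid(0.5)         # output gate, about 0.622

c_prev = 2.0               # assumed previous cell state

# Cell state update: C_t = f_t * C_{t-1} + i_t * C~_t
c_t = f_t * c_prev + i_t * c_tilde

# Hidden state: h_t = o_t * tanh(C_t)
h_t = o_t * math.tanh(c_t)

print(f"f_t={f_t:.3f}  i_t={i_t:.3f}  c_tilde={c_tilde:.3f}  o_t={o_t:.3f}")
print(f"c_prev={c_prev:.3f} -> c_t={c_t:.3f}")
print(f"h_t={h_t:.3f}")
```

Notice that `c_t` stays close to `c_prev`, while `h_t` is a squashed and gated view of it. The cell state can hold values outside $[-1, 1]$, but the hidden state cannot.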

```python theme={null}
class LSTMCell(nn.Module):
    """
    LSTM Cell implemented from scratch.
    
    The key insight: Cell state C_t flows through with only
    element-wise operations (multiply by forget, add input).
    This creates a "gradient highway" - when the forget gate stays
    close to 1, gradients can flow backward through time with
    little attenuation.
    """
    
    def __init__(self, input_size, hidden_size):
        super().__init__()
        
        self.hidden_size = hidden_size
        
        # All gates in one matrix for efficiency
        # Order: input, forget, cell, output (i, f, g, o)
        self.gates = nn.Linear(input_size + hidden_size, 4 * hidden_size)
    
    def forward(self, x, hidden=None):
        """
        Args:
            x: Input (batch, input_size)
            hidden: Tuple of (h, c), each (batch, hidden_size)
        
        Returns:
            h_new: New hidden state
            c_new: New cell state
        """
        batch_size = x.size(0)
        
        # Initialize hidden states if not provided
        if hidden is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)
            c = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h, c = hidden
        
        # Concatenate input and hidden state
        combined = torch.cat([x, h], dim=1)
        
        # Compute all gates at once
        gates = self.gates(combined)
        
        # Split into individual gates
        i, f, g, o = gates.chunk(4, dim=1)
        
        # Apply activations -- sigmoid for gates (0-1 range = "how much to allow through")
        i = torch.sigmoid(i)  # Input gate: what fraction of new info to write
        f = torch.sigmoid(f)  # Forget gate: what fraction of old memory to keep
        g = torch.tanh(g)     # Candidate values: what to potentially write (tanh = [-1,1] range)
        o = torch.sigmoid(o)  # Output gate: what fraction of memory to reveal
        
        # Update cell state: selectively forget old + selectively write new
        # This is the "gradient highway" -- gradients flow through f * c with minimal decay
        c_new = f * c + i * g
        
        # Compute hidden state: selectively read from cell state
        h_new = o * torch.tanh(c_new)  # tanh squashes cell state before gating
        
        return h_new, c_new


# Test our LSTM cell
cell = LSTMCell(input_size=32, hidden_size=64)

x = torch.randn(8, 32)  # Batch of 8, input size 32
h, c = cell(x)

print(f"Input: {x.shape}")
print(f"Hidden state: {h.shape}")
print(f"Cell state: {c.shape}")
```
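A quick way to sanity-check these equations is to recompute one step of PyTorch's built-in `nn.LSTMCell` by hand from its weights. The sketch below assumes the parameter layout PyTorch documents for `nn.LSTMCell` (`weight_ih`, `weight_hh`, `bias_ih`, `bias_hh`, with the gates stacked in the order i, f, g, o). The helper name `manual_lstm_step` is made up for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def manual_lstm_step(cell: nn.LSTMCell, x, h, c):
    """Recompute one nn.LSTMCell step from its weights (gate order: i, f, g, o)."""
    gates = (x @ cell.weight_ih.T + cell.bias_ih
             + h @ cell.weight_hh.T + cell.bias_hh)
    i, f, g, o = gates.chunk(4, dim=1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)
    c_new = f * c + i * g            # cell state update
    h_new = o * torch.tanh(c_new)    # hidden state
    return h_new, c_new

ref_cell = nn.LSTMCell(input_size=32, hidden_size=64)
x_step = torch.randn(8, 32)
h0 = torch.randn(8, 64)
c0 = torch.randn(8, 64)

with torch.no_grad():
    h_ref, c_ref = ref_cell(x_step, (h0, c0))
    h_man, c_man = manual_lstm_step(ref_cell, x_step, h0, c0)

print(f"Max |h_ref - h_man|: {(h_ref - h_man).abs().max().item():.2e}")
print(f"Max |c_ref - c_man|: {(c_ref - c_man).abs().max().item():.2e}")
```

The differences should be at the level of floating-point noise, which confirms that the six equations above are all there is to a single LSTM step.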

### Complete LSTM Layer

```python theme={null}
class LSTM(nn.Module):
    """
    Complete LSTM layer that processes sequences.
    """
    
    def __init__(self, input_size, hidden_size, num_layers=1, 
                 batch_first=True, dropout=0.0, bidirectional=False):
        super().__init__()
        
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        
        # Stack of LSTM cells
        self.cells = nn.ModuleList()
        for layer in range(num_layers):
            for direction in range(self.num_directions):
                layer_input_size = input_size if layer == 0 else hidden_size * self.num_directions
                self.cells.append(LSTMCell(layer_input_size, hidden_size))
        
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
    
    def forward(self, x, hidden=None):
        """
        Process sequence through LSTM.
        
        Args:
            x: (batch, seq_len, input) if batch_first else (seq_len, batch, input)
            hidden: Tuple of (h_0, c_0), each (num_layers * directions, batch, hidden)
        
        Returns:
            output: Hidden states at each time step
            (h_n, c_n): Final hidden and cell states
        """
        if self.batch_first:
            batch_size, seq_len, _ = x.size()
        else:
            seq_len, batch_size, _ = x.size()
            x = x.transpose(0, 1)
        
        # Initialize hidden states
        if hidden is None:
            h = [torch.zeros(batch_size, self.hidden_size, device=x.device)
                 for _ in range(self.num_layers * self.num_directions)]
            c = [torch.zeros(batch_size, self.hidden_size, device=x.device)
                 for _ in range(self.num_layers * self.num_directions)]
        else:
            h = [hidden[0][i] for i in range(self.num_layers * self.num_directions)]
            c = [hidden[1][i] for i in range(self.num_layers * self.num_directions)]
        
        # Process each layer
        layer_input = x
        
        for layer in range(self.num_layers):
            outputs_forward = []
            outputs_backward = []
            
            # Forward direction
            cell_idx = layer * self.num_directions
            h_forward, c_forward = h[cell_idx], c[cell_idx]
            
            for t in range(seq_len):
                h_forward, c_forward = self.cells[cell_idx](
                    layer_input[:, t, :], (h_forward, c_forward)
                )
                outputs_forward.append(h_forward)
            
            h[cell_idx], c[cell_idx] = h_forward, c_forward
            
            # Backward direction (if bidirectional)
            if self.bidirectional:
                cell_idx = layer * self.num_directions + 1
                h_backward, c_backward = h[cell_idx], c[cell_idx]
                
                for t in range(seq_len - 1, -1, -1):
                    h_backward, c_backward = self.cells[cell_idx](
                        layer_input[:, t, :], (h_backward, c_backward)
                    )
                    outputs_backward.insert(0, h_backward)
                
                h[cell_idx], c[cell_idx] = h_backward, c_backward
            
            # Combine outputs
            outputs_forward = torch.stack(outputs_forward, dim=1)
            if self.bidirectional:
                outputs_backward = torch.stack(outputs_backward, dim=1)
                layer_output = torch.cat([outputs_forward, outputs_backward], dim=-1)
            else:
                layer_output = outputs_forward
            
            # Apply dropout between layers (not on last layer)
            if self.dropout is not None and layer < self.num_layers - 1:
                layer_output = self.dropout(layer_output)
            
            layer_input = layer_output
        
        # Stack hidden states
        h_n = torch.stack(h, dim=0)
        c_n = torch.stack(c, dim=0)
        
        if not self.batch_first:
            layer_output = layer_output.transpose(0, 1)
        
        return layer_output, (h_n, c_n)


# Compare output shapes with PyTorch's implementation
# (the two models have different random weights, so the values will not match)
our_lstm = LSTM(input_size=32, hidden_size=64, num_layers=2, batch_first=True)
pytorch_lstm = nn.LSTM(input_size=32, hidden_size=64, num_layers=2, batch_first=True)

x = torch.randn(8, 20, 32)  # batch=8, seq_len=20, input=32

our_output, (our_h, our_c) = our_lstm(x)
pt_output, (pt_h, pt_c) = pytorch_lstm(x)

print(f"Our output shape:     {our_output.shape}")
print(f"PyTorch output shape: {pt_output.shape}")
print(f"Our hidden shape:     {our_h.shape}")
print(f"PyTorch hidden shape: {pt_h.shape}")
```
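It helps to be clear about how the two return values relate. The sketch below uses PyTorch's `nn.LSTM` (unidirectional, with arbitrary example sizes) to show that `output` holds the top layer's hidden state at every time step, while `h_n` holds every layer's hidden state at the final step only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

lstm_ref = nn.LSTM(input_size=32, hidden_size=64, num_layers=2, batch_first=True)
seq = torch.randn(8, 20, 32)  # batch=8, seq_len=20, input=32

with torch.no_grad():
    out, (h_n, c_n) = lstm_ref(seq)

# out: top layer, all time steps   -> (batch, seq_len, hidden)
# h_n: all layers, last time step  -> (num_layers, batch, hidden)
print(f"out: {tuple(out.shape)}")
print(f"h_n: {tuple(h_n.shape)}")
print(f"c_n: {tuple(c_n.shape)}")

# For a unidirectional LSTM, the last time step of `out` is the top layer's final hidden state
last_step_matches = torch.allclose(out[:, -1, :], h_n[-1], atol=1e-6)
print(f"out[:, -1] equals h_n[-1]: {last_step_matches}")
```

This equivalence does not carry over directly to bidirectional models, where the backward direction's final state corresponds to the first time step of the sequence.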

***

## Understanding the Gates

### The Forget Gate: Learning What to Ignore

```python theme={null}
def visualize_forget_gate():
    """
    The forget gate decides what information from the cell state
    should be thrown away.
    
    f_t = σ(W_f · [h_{t-1}, x_t] + b_f)
    
    - f_t ≈ 1: Keep this information
    - f_t ≈ 0: Forget this information
    """
    
    # Example: Language model processing "John lives in Paris. He speaks ___"
    # When we see "He", the forget gate might:
    # - Keep: Information that subject is male ("John")
    # - Forget: Specific details about Paris that aren't relevant
    
    print("Forget Gate Example:")
    print("-" * 50)
    print("Context: 'John lives in Paris. He speaks ___'")
    print()
    print("Cell state before 'He':")
    print("  [subject=John, location=Paris, job=unknown, ...]")
    print()
    print("Forget gate values (learned):")
    print("  [subject: 0.95, location: 0.3, job: 0.1, ...]")
    print()
    print("Cell state after forget gate:")
    print("  [subject=John(kept), location=faded, job=forgotten]")
    
visualize_forget_gate()


def forget_gate_experiment():
    """
    Show how a strong input in the middle of a sequence changes
    the final cell state.
    
    nn.LSTM does not expose its gate activations, so the final cell
    state is used here as an indirect view of what was kept or overwritten.
    """
    lstm = nn.LSTM(10, 20, batch_first=True)
    
    # Create two sequences
    # Sequence 1: Normal input
    x1 = torch.randn(1, 50, 10)
    
    # Sequence 2: Same input with a strong "reset signal" at position 25
    x2 = x1.clone()
    x2[0, 25, :] = 10.0  # Strong signal
    
    # Process and compare final cell states (no gradients needed here)
    with torch.no_grad():
        _, (_, c1) = lstm(x1)
        _, (_, c2) = lstm(x2)
    
    print("Cell state comparison:")
    print(f"Normal sequence - cell state norm: {c1.norm().item():.4f}")
    print(f"With reset signal - cell state norm: {c2.norm().item():.4f}")
    print(f"Difference between final cell states: {(c1 - c2).norm().item():.4f}")
    
forget_gate_experiment()
```
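To look at the forget gate itself rather than only the cell state, one option is to step through a sequence with `nn.LSTMCell` and recompute the gate from the cell's weights. This is a minimal sketch with an untrained cell, so the numbers only show the mechanics, not learned behavior. The helper `forget_gate_trace` is a made-up name, and it relies on PyTorch's documented gate order (i, f, g, o).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def forget_gate_trace(cell: nn.LSTMCell, seq: torch.Tensor) -> torch.Tensor:
    """Mean forget-gate activation at each step of seq, shaped (seq_len, input_size)."""
    h = torch.zeros(1, cell.hidden_size)
    c = torch.zeros(1, cell.hidden_size)
    trace = []
    with torch.no_grad():
        for x_t in seq:
            x_t = x_t.unsqueeze(0)  # add a batch dimension
            gates = (x_t @ cell.weight_ih.T + cell.bias_ih
                     + h @ cell.weight_hh.T + cell.bias_hh)
            f = torch.sigmoid(gates.chunk(4, dim=1)[1])  # second chunk is the forget gate
            trace.append(f.mean())
            h, c = cell(x_t, (h, c))  # advance the state with the real cell
    return torch.stack(trace)

probe_cell = nn.LSTMCell(input_size=10, hidden_size=20)

seq_normal = torch.randn(50, 10)
seq_spike = seq_normal.clone()
seq_spike[25, :] = 10.0  # strong signal at step 25

trace_normal = forget_gate_trace(probe_cell, seq_normal)
trace_spike = forget_gate_trace(probe_cell, seq_spike)

print(f"Mean forget gate at step 24: {trace_normal[24]:.3f} vs {trace_spike[24]:.3f}")
print(f"Mean forget gate at step 25: {trace_normal[25]:.3f} vs {trace_spike[25]:.3f}")
```

Before step 25 the two traces are identical, because the inputs are identical. From step 25 onward they diverge, since the spike changes both the gate input and the hidden state that feeds later gates.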

### The Input Gate: Learning What to Remember

```python theme={null}
def visualize_input_gate():
    """
    The input gate decides what new information to store.
    
    i_t = σ(W_i · [h_{t-1}, x_t] + b_i)     # What to update
    g_t = tanh(W_g · [h_{t-1}, x_t] + b_g)  # Candidate values
    
    New information = i_t * g_t
    """
    
    print("Input Gate Example:")
    print("-" * 50)
    print("Processing: 'Marie Curie won the Nobel Prize in ___ and ___'")
    print()
    print("When we see 'Nobel Prize':")
    print("  Candidate values (tanh): [award_type: +0.9, person: -0.1, ...]")
    print("  Input gate (sigmoid):    [award_type: 0.95, person: 0.2, ...]")
    print()
    print("New cell state += [strong award signal, weak person signal]")
    
visualize_input_gate()
```

### The Output Gate: Learning What to Reveal

```python theme={null}
def visualize_output_gate():
    """
    The output gate decides what parts of the cell state to output.
    
    o_t = σ(W_o · [h_{t-1}, x_t] + b_o)
    h_t = o_t * tanh(C_t)
    
    The cell state might store information that isn't immediately
    relevant but will be needed later.
    """
    
    print("Output Gate Example:")
    print("-" * 50)
    print("Cell state: [subject=John, visited_places=[Paris,London], mood=happy]")
    print()
    print("Current word: 'traveled' → Need location info")
    print("  Output gate: [subject: 0.3, places: 0.9, mood: 0.1]")
    print("  Hidden output emphasizes: places")
    print()
    print("Current word: 'smiled' → Need mood info")  
    print("  Output gate: [subject: 0.2, places: 0.1, mood: 0.95]")
    print("  Hidden output emphasizes: mood")

visualize_output_gate()
```

<Frame>
  <img src="https://mintlify.s3.us-west-1.amazonaws.com/devweeekends/images/courses/deep-learning-mastery/lstm-gates-flow.svg" alt="LSTM Gates Information Flow" />
</Frame>

***

## Why LSTM Solves Vanishing Gradients

### The Gradient Highway

```python theme={null}
def explain_gradient_highway():
    """
    The key to LSTM's success: The cell state update
    
    C_t = f_t * C_{t-1} + i_t * g_t
    
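    Along the direct cell-state path, the gradient of C_t w.r.t.
    C_{t-1} is simply f_t (forget gate). The gates also depend on
    h_{t-1}, which adds further terms, but this direct path is the
    one that avoids repeated squashing.
    
    If f_t ≈ 1, gradients flow through nearly unchanged!
    This is like a "gradient highway" - information and gradients
    can travel long distances without being squashed.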
    """
    
    print("Gradient Flow Comparison:")
    print("-" * 60)
    print()
    print("Vanilla RNN gradient through T steps:")
    print("  ∂h_T/∂h_1 = ∏(diag(tanh') * W_hh) → usually vanishes (or explodes)!")
    print()
    print("LSTM gradient through T steps (direct cell-state path):")
    print("  ∂C_T/∂C_1 ≈ ∏(f_t) → can stay close to 1!")
    print()
    print("Key insight: Forget gate learns to keep gradients alive")
    print("for important long-range dependencies.")
    
explain_gradient_highway()


def compare_gradient_flow():
    """Compare gradient flow in RNN vs LSTM."""
    
    seq_len = 100
    hidden_size = 64
    
    # Vanilla RNN
    rnn = nn.RNN(32, hidden_size, batch_first=True)
    
    # LSTM
    lstm = nn.LSTM(32, hidden_size, batch_first=True)
    
    x = torch.randn(1, seq_len, 32, requires_grad=True)
    
    # RNN gradient flow
    rnn_out, _ = rnn(x)
    rnn_loss = rnn_out[:, -1, :].sum()
    rnn_loss.backward()
    rnn_grad = x.grad[:, 0, :].norm().item()
    
    x.grad.zero_()
    
    # LSTM gradient flow
    lstm_out, _ = lstm(x)
    lstm_loss = lstm_out[:, -1, :].sum()
    lstm_loss.backward()
    lstm_grad = x.grad[:, 0, :].norm().item()
    
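    # Scientific notation, because these norms can be far too small for fixed-point output
    print(f"\nGradient norm at first time step (sequence length {seq_len}):")
    print(f"  RNN:  {rnn_grad:.3e}")
    print(f"  LSTM: {lstm_grad:.3e}")
    # Guard against division by zero if the RNN gradient underflows
    if rnn_grad > 0:
        print(f"  Ratio (LSTM/RNN): {lstm_grad/rnn_grad:.2e}x")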

compare_gradient_flow()
```
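The product-of-forget-gates claim can be checked numerically with autograd. The sketch below makes a simplifying assumption: every forget gate is held at a fixed constant, and the written term `i_t * g_t` is replaced by random noise that does not depend on the cell state. In a real LSTM the gates depend on the hidden state, so this isolates only the direct cell-state path.

```python
import torch

torch.manual_seed(0)
T = 100  # number of time steps

def cell_path_gradient(forget_value: float) -> float:
    """d c_T / d c_0 when every forget gate is held at a fixed constant."""
    c0 = torch.tensor(1.0, requires_grad=True)
    c = c0
    for _ in range(T):
        write = 0.1 * torch.randn(())  # stands in for i_t * g_t (constant w.r.t. c)
        c = forget_value * c + write
    (grad,) = torch.autograd.grad(c, c0)
    return grad.item()

for f in (0.99, 0.9, 0.5):
    print(f"f = {f}: autograd {cell_path_gradient(f):.3e}   f**T = {f**T:.3e}")
```

The autograd result matches `f**T` in each case. With the gate at 0.99, about a third of the gradient survives 100 steps. At 0.9 it has already shrunk by several orders of magnitude, and at 0.5 it is effectively gone. This is why the network needs forget gates close to 1 for the dependencies it wants to preserve.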

***

## Gated Recurrent Unit (GRU)

### A Simpler Alternative

GRU simplifies LSTM by:

1. Combining forget and input gates into an "update gate"
2. Merging cell state and hidden state

The key insight is that forgetting and remembering are two sides of the same coin. In an LSTM, the forget gate and input gate are independent -- you could forget everything *and* write nothing new (losing information), or forget nothing *and* write a lot (letting the cell state grow without bound). GRU enforces a conservation law: the update gate $z_t$ controls a smooth interpolation between the old state and the new candidate. When $z_t = 1$, you fully adopt the new candidate. When $z_t = 0$, you keep the old state unchanged. There is no way to simultaneously forget and fail to replace -- which makes the GRU more constrained, and in practice often easier to train.

Think of it like a thermostat dial. LSTM gives you two separate dials (heating and cooling), which is more flexible but also lets you accidentally run both at once. GRU gives you a single dial that smoothly blends between "keep the old temperature" and "adopt the new temperature." Fewer controls, but harder to misconfigure.

<Frame>
  <img src="https://mintlify.s3.us-west-1.amazonaws.com/devweeekends/images/courses/deep-learning-mastery/gru-cell.svg" alt="GRU Cell Diagram" />
</Frame>

### GRU Equations

$$
\begin{aligned}
z_t &= \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) & \text{(Update gate)} \\
r_t &= \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) & \text{(Reset gate)} \\
\tilde{h}_t &= \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t] + b_h) & \text{(Candidate)} \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t & \text{(Final state)}
\end{aligned}
$$
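The final-state equation is just a per-element weighted average, which a few hand-picked numbers make easy to see. The vectors and the helper `gru_blend` below are illustrative only. In a real GRU, $z_t$ is a learned vector with a different value per hidden unit.

```python
import torch

h_prev = torch.tensor([1.0, -1.0, 0.5])   # old hidden state (made-up values)
h_cand = torch.tensor([0.0, 1.0, -0.5])   # candidate state (made-up values)

def gru_blend(z: float) -> torch.Tensor:
    """h_t = (1 - z) * h_prev + z * h_cand, using the same z for every unit."""
    z_t = torch.full_like(h_prev, z)
    return (1 - z_t) * h_prev + z_t * h_cand

keep = gru_blend(0.0)    # z = 0: keep the old state
adopt = gru_blend(1.0)   # z = 1: adopt the candidate
mix = gru_blend(0.25)    # z = 0.25: mostly old state, a quarter candidate

print(f"z=0.00 -> {keep.tolist()}")
print(f"z=1.00 -> {adopt.tolist()}")
print(f"z=0.25 -> {mix.tolist()}")
```

Because the two weights always sum to 1, every element of $h_t$ stays between the corresponding elements of $h_{t-1}$ and $\tilde{h}_t$. That bound is what keeps the GRU state from growing the way an LSTM cell state can.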

```python theme={null}
class GRUCell(nn.Module):
    """
    GRU Cell implemented from scratch.
    
    Compared to LSTM:
    - 2 gates instead of 3
    - No separate cell state
    - Fewer parameters
    - Often performs similarly
    """
    
    def __init__(self, input_size, hidden_size):
        super().__init__()
        
        self.hidden_size = hidden_size
        
        # Update and reset gates
        self.W_z = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_r = nn.Linear(input_size + hidden_size, hidden_size)
        
        # Candidate hidden state
        self.W_h = nn.Linear(input_size + hidden_size, hidden_size)
    
    def forward(self, x, h=None):
        """
        Args:
            x: Input (batch, input_size)
            h: Hidden state (batch, hidden_size)
        
        Returns:
            h_new: Updated hidden state
        """
        batch_size = x.size(0)
        
        if h is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)
        
        # Concatenate input and hidden
        combined = torch.cat([x, h], dim=1)
        
        # Update gate: how much of new state to use
        z = torch.sigmoid(self.W_z(combined))
        
        # Reset gate: how much of old state to forget when computing candidate
        r = torch.sigmoid(self.W_r(combined))
        
        # Candidate hidden state
        combined_reset = torch.cat([x, r * h], dim=1)
        h_candidate = torch.tanh(self.W_h(combined_reset))
        
        # Final hidden state: interpolate between old and candidate
        h_new = (1 - z) * h + z * h_candidate
        
        return h_new


class GRU(nn.Module):
    """Complete GRU layer."""
    
    def __init__(self, input_size, hidden_size, num_layers=1, batch_first=True):
        super().__init__()
        
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        
        self.cells = nn.ModuleList([
            GRUCell(input_size if i == 0 else hidden_size, hidden_size)
            for i in range(num_layers)
        ])
    
    def forward(self, x, h_0=None):
        if self.batch_first:
            batch_size, seq_len, _ = x.size()
        else:
            seq_len, batch_size, _ = x.size()
            x = x.transpose(0, 1)
        
        # Initialize hidden states
        if h_0 is None:
            h = [torch.zeros(batch_size, self.hidden_size, device=x.device)
                 for _ in range(self.num_layers)]
        else:
            h = [h_0[i] for i in range(self.num_layers)]
        
        outputs = []
        
        for t in range(seq_len):
            x_t = x[:, t, :]
            
            for layer, cell in enumerate(self.cells):
                h[layer] = cell(x_t, h[layer])
                x_t = h[layer]
            
            outputs.append(h[-1])
        
        outputs = torch.stack(outputs, dim=1)
        h_n = torch.stack(h, dim=0)
        
        if not self.batch_first:
            outputs = outputs.transpose(0, 1)
        
        return outputs, h_n


# Test GRU
gru = GRU(input_size=32, hidden_size=64, num_layers=2, batch_first=True)
x = torch.randn(8, 20, 32)
output, h_n = gru(x)

print(f"Input:  {x.shape}")
print(f"Output: {output.shape}")
print(f"Hidden: {h_n.shape}")
```
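One practical caveat: PyTorch's built-in `nn.GRUCell` follows a slightly different convention from the equations in this section. According to the PyTorch documentation, its update is $h' = (1 - z) \odot n + z \odot h$, so there a value of $z$ near 1 *keeps* the old state, and the reset gate is applied after the hidden-to-hidden matrix multiply rather than before it. The sketch below recomputes one `nn.GRUCell` step by hand under that documented layout (gates stacked as r, z, n). The helper name `manual_gru_step` is made up for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def manual_gru_step(cell: nn.GRUCell, x, h):
    """Recompute one nn.GRUCell step from its weights (gate order: r, z, n)."""
    gi = x @ cell.weight_ih.T + cell.bias_ih   # input contributions
    gh = h @ cell.weight_hh.T + cell.bias_hh   # hidden contributions
    i_r, i_z, i_n = gi.chunk(3, dim=1)
    h_r, h_z, h_n = gh.chunk(3, dim=1)
    r = torch.sigmoid(i_r + h_r)
    z = torch.sigmoid(i_z + h_z)
    n = torch.tanh(i_n + r * h_n)   # reset applied after the hidden matmul
    return (1 - z) * n + z * h      # here z near 1 keeps the OLD state

ref_gru_cell = nn.GRUCell(input_size=32, hidden_size=64)
x_step = torch.randn(8, 32)
h_prev = torch.randn(8, 64)

with torch.no_grad():
    h_ref = ref_gru_cell(x_step, h_prev)
    h_man = manual_gru_step(ref_gru_cell, x_step, h_prev)

print(f"Max |h_ref - h_man|: {(h_ref - h_man).abs().max().item():.2e}")
```

Both conventions describe the same family of models, since swapping $z$ for $1 - z$ just relabels the gate. It matters mainly when you inspect gate values or port weights between implementations.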

### LSTM vs GRU Comparison

| Aspect          | LSTM                                   | GRU                                   |
| --------------- | -------------------------------------- | ------------------------------------- |
| **Gates**       | 3 (forget, input, output)              | 2 (update, reset)                     |
| **States**      | 2 (hidden + cell)                      | 1 (hidden only)                       |
| **Parameters** (per layer; h = hidden size, d = input size) | ≈ 4 × h × (h + d), plus biases | ≈ 3 × h × (h + d), plus biases |
| **Training**    | Slightly slower                        | Faster                                |
| **Performance** | Often slightly better on complex tasks | Often comparable                      |
| **When to use** | Long sequences, complex dependencies   | Faster training needed, simpler tasks |

```python theme={null}
def compare_lstm_gru():
    """Compare LSTM and GRU architectures."""
    
    input_size = 128
    hidden_size = 256
    
    lstm = nn.LSTM(input_size, hidden_size, num_layers=2, batch_first=True)
    gru = nn.GRU(input_size, hidden_size, num_layers=2, batch_first=True)
    
    lstm_params = sum(p.numel() for p in lstm.parameters())
    gru_params = sum(p.numel() for p in gru.parameters())
    
    print("Parameter Comparison:")
    print(f"  LSTM: {lstm_params:,} parameters")
    print(f"  GRU:  {gru_params:,} parameters")
    print(f"  GRU is {100*(1 - gru_params/lstm_params):.1f}% smaller")
    
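    # Speed comparison (forward passes only, so gradient tracking is disabled)
    import time
    
    x = torch.randn(32, 100, input_size)
    
    with torch.no_grad():
        # Warm up
        _ = lstm(x)
        _ = gru(x)
        
        # Time LSTM (perf_counter is a monotonic, high-resolution clock)
        start = time.perf_counter()
        for _ in range(100):
            _ = lstm(x)
        lstm_time = time.perf_counter() - start
        
        # Time GRU
        start = time.perf_counter()
        for _ in range(100):
            _ = gru(x)
        gru_time = time.perf_counter() - start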
    
    print(f"\nSpeed Comparison (100 forward passes):")
    print(f"  LSTM: {lstm_time:.3f}s")
    print(f"  GRU:  {gru_time:.3f}s")
    print(f"  GRU is {100*(1 - gru_time/lstm_time):.1f}% faster")

compare_lstm_gru()
```
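The parameter counts can also be predicted in closed form. The sketch below assumes PyTorch's layout for `nn.LSTM` and `nn.GRU`, where each layer stores an input-to-hidden matrix, a hidden-to-hidden matrix, and two bias vectors, all stacked across the gates. Other implementations (for example, one with a single bias per gate) will give slightly different totals. The helper `rnn_param_count` is a made-up name.

```python
import torch.nn as nn

def rnn_param_count(num_gates: int, input_size: int, hidden_size: int, num_layers: int) -> int:
    """Expected parameter count for a unidirectional PyTorch LSTM (4 gates) or GRU (3 gates)."""
    total = 0
    for layer in range(num_layers):
        d = input_size if layer == 0 else hidden_size  # deeper layers read the layer below
        # weight_ih + weight_hh + bias_ih + bias_hh, each stacked over the gates
        total += num_gates * (hidden_size * d + hidden_size * hidden_size + 2 * hidden_size)
    return total

input_size, hidden_size, num_layers = 128, 256, 2

lstm_model = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)
gru_model = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True)

lstm_actual = sum(p.numel() for p in lstm_model.parameters())
gru_actual = sum(p.numel() for p in gru_model.parameters())
lstm_expected = rnn_param_count(4, input_size, hidden_size, num_layers)
gru_expected = rnn_param_count(3, input_size, hidden_size, num_layers)

print(f"LSTM: expected {lstm_expected:,}, actual {lstm_actual:,}")
print(f"GRU:  expected {gru_expected:,}, actual {gru_actual:,}")
print(f"GRU / LSTM ratio: {gru_actual / lstm_actual:.2f}")
```

Since the only difference is 3 gate blocks versus 4, the ratio is exactly 0.75 for matching sizes: a GRU has 25% fewer recurrent parameters than an LSTM of the same shape.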

***

## Practical Applications

### Sentiment Analysis with LSTM

```python theme={null}
class SentimentLSTM(nn.Module):
    """
    Sentiment analysis using bidirectional LSTM.
    """
    
    def __init__(self, vocab_size, embed_dim, hidden_dim, 
                 num_layers, num_classes, dropout=0.5):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        
        self.lstm = nn.LSTM(
            embed_dim, 
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0
        )
        
        self.dropout = nn.Dropout(dropout)
        
        # Bidirectional → hidden_dim * 2
        self.fc = nn.Linear(hidden_dim * 2, num_classes)
    
    def forward(self, x, lengths=None):
        """
        Args:
            x: Token indices (batch, seq_len)
            lengths: Actual sequence lengths (for packing)
        """
        # Embed tokens
        embedded = self.dropout(self.embedding(x))
        
        # Pack sequences if lengths provided
        if lengths is not None:
            packed = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            output, (hidden, cell) = self.lstm(packed)
        else:
            output, (hidden, cell) = self.lstm(embedded)
        
        # Concatenate final forward and backward hidden states
        # hidden shape: (num_layers * 2, batch, hidden_dim)
        hidden_forward = hidden[-2, :, :]  # Last layer, forward
        hidden_backward = hidden[-1, :, :]  # Last layer, backward
        hidden_cat = torch.cat([hidden_forward, hidden_backward], dim=1)
        
        # Classify
        output = self.fc(self.dropout(hidden_cat))
        
        return output


# Example usage
model = SentimentLSTM(
    vocab_size=10000,
    embed_dim=128,
    hidden_dim=256,
    num_layers=2,
    num_classes=2
)

# Simulated batch
batch = torch.randint(1, 10000, (16, 100))  # 16 sequences, length 100
predictions = model(batch)

print(f"Input: {batch.shape}")
print(f"Predictions: {predictions.shape}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
```
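The `lengths` argument matters more than it may seem. Without packing, the LSTM keeps updating its state on the padding positions, so the final hidden state no longer represents the real end of the sequence. The sketch below uses a tiny `nn.LSTM` with arbitrary sizes (and feeds raw feature vectors rather than embeddings) to compare the three cases.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)

tiny_lstm = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)

short = torch.randn(1, 3, 4)                               # real sequence, length 3
padded = torch.cat([short, torch.zeros(1, 2, 4)], dim=1)   # zero-padded to length 5

with torch.no_grad():
    # Reference: run only the real tokens
    _, (h_true, _) = tiny_lstm(short)

    # Unpacked: the two padding steps are processed like real input
    _, (h_padded, _) = tiny_lstm(padded)

    # Packed: the LSTM stops at the true length of each sequence
    packed = pack_padded_sequence(
        padded, torch.tensor([3]), batch_first=True, enforce_sorted=False
    )
    _, (h_packed, _) = tiny_lstm(packed)

print(f"|h_true - h_padded| = {(h_true - h_padded).norm().item():.4f}")
print(f"|h_true - h_packed| = {(h_true - h_packed).norm().item():.4f}")
```

Only the packed run reproduces the hidden state of the unpadded sequence. For a classifier that reads the final hidden state, as `SentimentLSTM` does, passing `lengths` therefore changes the result whenever a batch contains padding.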

### Language Modeling with LSTM

```python theme={null}
class LanguageModelLSTM(nn.Module):
    """
    Language model: Predict next word given previous words.
    """
    
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, 
                 dropout=0.5, tie_weights=True):
        super().__init__()
        
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )
        
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        
        # Weight tying: share embedding and output weights
        if tie_weights and embed_dim == hidden_dim:
            self.fc.weight = self.embedding.weight
    
    def forward(self, x, hidden=None):
        """
        Args:
            x: Token indices (batch, seq_len)
            hidden: (h, c) tuple for stateful processing
        
        Returns:
            logits: (batch, seq_len, vocab_size)
            hidden: Updated hidden state
        """
        embedded = self.dropout(self.embedding(x))
        output, hidden = self.lstm(embedded, hidden)
        output = self.dropout(output)
        logits = self.fc(output)
        
        return logits, hidden
    
    def init_hidden(self, batch_size, device='cpu'):
        """Initialize hidden state."""
        h = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device)
        c = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device)
        return (h, c)
    
    def generate(self, start_tokens, max_length=100, temperature=1.0):
        """Generate text continuation."""
        self.eval()
        
        batch_size = start_tokens.size(0)
        device = start_tokens.device
        
        hidden = self.init_hidden(batch_size, device)
        generated = start_tokens.tolist()
        
        current = start_tokens
        
        with torch.no_grad():
            for _ in range(max_length):
                logits, hidden = self(current, hidden)
                
                # Get last token prediction
                logits = logits[:, -1, :] / temperature
                probs = torch.softmax(logits, dim=-1)
                
                # Sample
                next_token = torch.multinomial(probs, 1)
                
                for i in range(batch_size):
                    generated[i].append(next_token[i].item())
                
                current = next_token
        
        return generated


# Create model
lm = LanguageModelLSTM(
    vocab_size=50000,
    embed_dim=512,
    hidden_dim=512,
    num_layers=3,
    dropout=0.3
)

# Training example (random token ids stand in for real data)
x = torch.randint(0, 50000, (8, 35))  # batch=8, seq_len=35
y = torch.randint(0, 50000, (8, 35))  # stand-in targets; in real training, y is x shifted by one token

logits, _ = lm(x)
loss = nn.CrossEntropyLoss()(logits.view(-1, 50000), y.view(-1))

print(f"Input: {x.shape}")
print(f"Output: {logits.shape}")
print(f"Loss: {loss.item():.4f}")
print(f"Perplexity: {torch.exp(loss).item():.2f}")
```
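In real training, inputs and targets come from the same token stream, offset by one position. The sketch below shows that shift on made-up token ids. It also computes the loss for a model that outputs uniform logits, which gives a useful baseline: the perplexity of uniform guessing equals the vocabulary size. The vocabulary size of 1000 is an arbitrary choice for the example.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size = 1000                                  # arbitrary example size
tokens = torch.randint(0, vocab_size, (8, 36))     # hypothetical token ids

# Next-token prediction: position t of the input predicts position t + 1
inputs = tokens[:, :-1]    # (8, 35)
targets = tokens[:, 1:]    # (8, 35)

# Baseline: a "model" that assigns equal probability to every token
uniform_logits = torch.zeros(8, 35, vocab_size)
loss = F.cross_entropy(
    uniform_logits.reshape(-1, vocab_size),  # (batch * seq_len, vocab)
    targets.reshape(-1),                     # (batch * seq_len,)
)
perplexity = math.exp(loss.item())

print(f"inputs: {tuple(inputs.shape)}, targets: {tuple(targets.shape)}")
print(f"Uniform-guess loss: {loss.item():.4f} (log(vocab) = {math.log(vocab_size):.4f})")
print(f"Uniform-guess perplexity: {perplexity:.1f}")
```

An untrained language model typically starts with a perplexity on the order of the vocabulary size, so any useful model should end up well below that.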

### Sequence-to-Sequence Translation

```python theme={null}
class Encoder(nn.Module):
    """LSTM Encoder for seq2seq."""
    
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers,
            batch_first=True, dropout=dropout, bidirectional=True
        )
        self.dropout = nn.Dropout(dropout)
        
        # Project bidirectional to unidirectional for decoder
        self.fc_h = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc_c = nn.Linear(hidden_dim * 2, hidden_dim)
    
    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        outputs, (hidden, cell) = self.lstm(embedded)
        
        # Combine bidirectional states
        # hidden: (num_layers * 2, batch, hidden_dim)
        hidden = hidden.view(self.lstm.num_layers, 2, -1, self.lstm.hidden_size)
        hidden = torch.cat([hidden[:, 0], hidden[:, 1]], dim=-1)
        hidden = torch.tanh(self.fc_h(hidden))
        
        cell = cell.view(self.lstm.num_layers, 2, -1, self.lstm.hidden_size)
        cell = torch.cat([cell[:, 0], cell[:, 1]], dim=-1)
        cell = torch.tanh(self.fc_c(cell))
        
        return outputs, (hidden, cell)


class Decoder(nn.Module):
    """LSTM Decoder for seq2seq."""
    
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        
        self.vocab_size = vocab_size
        
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers,
            batch_first=True, dropout=dropout
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, tgt, hidden):
        embedded = self.dropout(self.embedding(tgt))
        output, hidden = self.lstm(embedded, hidden)
        prediction = self.fc(output)
        return prediction, hidden


class Seq2Seq(nn.Module):
    """Complete sequence-to-sequence model."""
    
    def __init__(self, encoder, decoder):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        """
        Args:
            src: Source sequence (batch, src_len)
            tgt: Target sequence (batch, tgt_len)
            teacher_forcing_ratio: Probability of using ground truth
        """
        batch_size = src.size(0)
        tgt_len = tgt.size(1)
        tgt_vocab_size = self.decoder.vocab_size
        
        # Store outputs; position 0 is never written (there is no prediction for <sos>)
        outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size, device=src.device)
        
        # Encode source
        encoder_outputs, hidden = self.encoder(src)
        
        # First decoder input is <sos> token
        decoder_input = tgt[:, 0:1]
        
        for t in range(1, tgt_len):
            output, hidden = self.decoder(decoder_input, hidden)
            outputs[:, t] = output.squeeze(1)
            
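            # Teacher forcing: one coin flip per time step, shared across the batch
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            
            if teacher_force:
                # Feed the ground-truth token
                decoder_input = tgt[:, t:t+1]
            else:
                # Feed the model's own greedy prediction, shape (batch, 1)
                decoder_input = output.argmax(-1)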
        
        return outputs


# Create seq2seq model
encoder = Encoder(vocab_size=10000, embed_dim=256, hidden_dim=512, 
                  num_layers=2, dropout=0.3)
decoder = Decoder(vocab_size=8000, embed_dim=256, hidden_dim=512,
                  num_layers=2, dropout=0.3)

model = Seq2Seq(encoder, decoder)

# Example
src = torch.randint(0, 10000, (16, 30))  # Source: 16 sentences, max 30 tokens
tgt = torch.randint(0, 8000, (16, 25))   # Target: 16 sentences, max 25 tokens

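output = model(src, tgt)
# output[:, 0] is all zeros (decoding starts at t=1), so compare output[:, 1:]
# with tgt[:, 1:] when computing the loss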
print(f"Source: {src.shape}")
print(f"Target: {tgt.shape}")
print(f"Output: {output.shape}")
```
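
`Seq2Seq.forward` above takes a fixed `teacher_forcing_ratio`. One common option is to decay that ratio over training so the decoder gradually learns to consume its own predictions. The helper below is a minimal sketch of a linear decay; the name `linear_teacher_forcing` and its parameters are illustrative assumptions, not part of the model above.

```python
def linear_teacher_forcing(epoch: int, num_epochs: int,
                           start: float = 1.0, end: float = 0.0) -> float:
    """Linearly decay the teacher forcing ratio from `start` to `end`.

    Hypothetical helper: `epoch` is 0-based and `num_epochs` is the total
    number of training epochs.
    """
    if num_epochs <= 1:
        return end
    # Fraction of training completed, clamped to [0, 1]
    progress = min(max(epoch / (num_epochs - 1), 0.0), 1.0)
    return start + (end - start) * progress


schedule = [linear_teacher_forcing(e, 5) for e in range(5)]
print(schedule)  # [1.0, 0.75, 0.5, 0.25, 0.0]
```

In a training loop the value would be passed as `model(src, tgt, teacher_forcing_ratio=ratio)`. At evaluation time a ratio of 0.0 matches how the decoder is used in practice.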

***

## Advanced LSTM Variants

### Peephole Connections

```python theme={null}
class PeepholeLSTMCell(nn.Module):
    """
    LSTM with peephole connections.
    
    Gates can directly look at the cell state, not just hidden state.
    This can help with precise timing and counting.
    """
    
    def __init__(self, input_size, hidden_size):
        super().__init__()
        
        self.hidden_size = hidden_size
        
        # Standard weights
        self.W_i = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_f = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_o = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_c = nn.Linear(input_size + hidden_size, hidden_size)
        
        # Peephole weights (diagonal, so just vectors)
        self.p_i = nn.Parameter(torch.randn(hidden_size))
        self.p_f = nn.Parameter(torch.randn(hidden_size))
        self.p_o = nn.Parameter(torch.randn(hidden_size))
    
    def forward(self, x, hidden=None):
        batch_size = x.size(0)
        
        if hidden is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)
            c = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h, c = hidden
        
        combined = torch.cat([x, h], dim=1)
        
        # Gates with peephole connections to cell state
        f = torch.sigmoid(self.W_f(combined) + self.p_f * c)
        i = torch.sigmoid(self.W_i(combined) + self.p_i * c)
        
        c_candidate = torch.tanh(self.W_c(combined))
        c_new = f * c + i * c_candidate
        
        o = torch.sigmoid(self.W_o(combined) + self.p_o * c_new)
        h_new = o * torch.tanh(c_new)
        
        return h_new, c_new
```
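
For reference, the forward pass above corresponds roughly to the following equations, written in the notation used elsewhere on this page. Here $p_i$, $p_f$, $p_o$ stand for the peephole vectors (`p_i`, `p_f`, `p_o` in the code), and the order of concatenation inside the brackets is only a convention:

$$
\begin{aligned}
f_t &= \sigma(W_f [h_{t-1}, x_t] + b_f + p_f \odot C_{t-1}) \\
i_t &= \sigma(W_i [h_{t-1}, x_t] + b_i + p_i \odot C_{t-1}) \\
\tilde{C}_t &= \tanh(W_C [h_{t-1}, x_t] + b_C) \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\
o_t &= \sigma(W_o [h_{t-1}, x_t] + b_o + p_o \odot C_t) \\
h_t &= o_t \odot \tanh(C_t)
\end{aligned}
$$

The input and forget gates peek at the previous cell state $C_{t-1}$, while the output gate peeks at the freshly updated $C_t$.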

### Layer Normalization in LSTM

```python theme={null}
class LayerNormLSTMCell(nn.Module):
    """
    LSTM with layer normalization for better training stability.
    """
    
    def __init__(self, input_size, hidden_size):
        super().__init__()
        
        self.hidden_size = hidden_size
        
        self.W_ih = nn.Linear(input_size, 4 * hidden_size, bias=False)
        self.W_hh = nn.Linear(hidden_size, 4 * hidden_size, bias=False)
        
        # Layer normalization
        self.ln_ih = nn.LayerNorm(4 * hidden_size)
        self.ln_hh = nn.LayerNorm(4 * hidden_size)
        self.ln_cell = nn.LayerNorm(hidden_size)
    
    def forward(self, x, hidden=None):
        batch_size = x.size(0)
        
        if hidden is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)
            c = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h, c = hidden
        
        # Apply layer norm to gate inputs
        gates = self.ln_ih(self.W_ih(x)) + self.ln_hh(self.W_hh(h))
        
        i, f, g, o = gates.chunk(4, dim=1)
        
        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        g = torch.tanh(g)
        o = torch.sigmoid(o)
        
        c_new = f * c + i * g
        
        # Layer norm on cell state before output
        h_new = o * torch.tanh(self.ln_cell(c_new))
        
        return h_new, c_new
```
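
To see why normalizing the gate pre-activations can help stability, the sketch below applies `nn.LayerNorm` to deliberately badly scaled values. The tensor sizes, the scale factor, and the offset are arbitrary assumptions for illustration. After normalization each row has roughly zero mean and unit standard deviation, so the sigmoid and tanh gates are less likely to start out saturated.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

hidden_size = 8
# Simulated gate pre-activations with a large scale and offset (assumed values)
pre_activations = 50.0 * torch.randn(4, 4 * hidden_size) + 10.0

ln = nn.LayerNorm(4 * hidden_size)
normed = ln(pre_activations)

# Per-sample statistics across the 4 * hidden_size gate features
row_mean = normed.mean(dim=1)
row_std = normed.std(dim=1, unbiased=False)

print("max |value| before:", pre_activations.abs().max().item())
print("max |value| after: ", normed.abs().max().item())
print("row means:", row_mean)
print("row stds: ", row_std)
```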

***

## Best Practices and Tips

<Tip>
  **Training LSTM/GRU Models:**

  1. **Gradient clipping**: Always use `torch.nn.utils.clip_grad_norm_` with `max_norm` between 1.0 and 5.0
  2. **Learning rate**: Start with 0.001 for Adam, 0.1-1.0 for SGD
  3. **Dropout**: Use 0.2-0.5 between layers, not within cells
  4. **Initialization**: PyTorch defaults are usually fine
  5. **Bidirectional**: Use for tasks where the full sequence is available at prediction time (classification, tagging)
</Tip>
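
As a rough illustration of the first two tips, the sketch below runs one Adam step on a toy regression task and clips the global gradient norm before the optimizer update. The model sizes, the random data, and the scaled target are assumptions chosen only to produce large gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2,
               batch_first=True, dropout=0.3)
head = nn.Linear(16, 1)
params = list(lstm.parameters()) + list(head.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

x = torch.randn(4, 20, 8)   # toy batch: 4 sequences of 20 steps
y = torch.randn(4, 1)

optimizer.zero_grad()
out, _ = lstm(x)
# Scale the target so the gradients are large enough to trigger clipping
loss = nn.functional.mse_loss(head(out[:, -1]), 100.0 * y)
loss.backward()

max_norm = 1.0
# Returns the total norm measured before clipping
norm_before = torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm)
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in params))
optimizer.step()

print(f"grad norm before clipping: {norm_before.item():.3f}")
print(f"grad norm after clipping:  {norm_after.item():.3f}")
```

The clipping call goes between `loss.backward()` and `optimizer.step()`; placed anywhere else it has no effect on the update.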

<Warning>
  **Common Mistakes:**

  1. **Not detaching hidden state**: For stateful training, detach hidden state between batches:
     ```python theme={null}
     hidden = (hidden[0].detach(), hidden[1].detach())
     ```

  2. **Ignoring sequence lengths**: Use `pack_padded_sequence` for variable-length sequences

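  3. **Wrong hidden state indexing**: For a bidirectional LSTM, `h_n[-2]` and `h_n[-1]` hold the last layer's final forward and backward states; concatenate them rather than taking `h_n[-1]` alone

  4. **Too many layers**: 2-3 layers are usually sufficient; more can hurt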
</Warning>
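
Mistakes 2 and 3 can be illustrated together. The following sketch, on made-up tensor sizes, packs a padded batch before a bidirectional LSTM and then reads the top layer's final forward and backward states from `h_n`. The variable names and dimensions are assumptions for illustration.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

batch, max_len, embed_dim, hidden_dim = 3, 5, 8, 16
x = torch.randn(batch, max_len, embed_dim)   # padded batch
lengths = torch.tensor([5, 3, 2])            # true lengths (CPU tensor)

lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=2,
               batch_first=True, bidirectional=True)

# Pack so the LSTM skips padded positions
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

# h_n: (num_layers * 2, batch, hidden_dim)
# h_n[-2] = top layer, forward direction; h_n[-1] = top layer, backward direction
final = torch.cat([h_n[-2], h_n[-1]], dim=-1)  # (batch, hidden_dim * 2)

print(out.shape)    # torch.Size([3, 5, 32])
print(final.shape)  # torch.Size([3, 32])
```

With packing, the forward final state of each sequence comes from its last real token, not from a padding position, which is what a classifier on top of `final` usually needs.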

<Tip>
  **Advanced Debugging Hints for LSTM/GRU Training:**

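  **Forget gate bias initialization**: A well-known trick (recommended by Jozefowicz et al., 2015; the original 1997 LSTM had no forget gate): initialize the forget gate bias to 1.0 (or even 2.0). This biases the gate toward "remembering" at initialization, which helps gradient flow in early training. In PyTorch, after creating the LSTM, loop over `lstm.named_parameters()` and, for every parameter whose name contains `bias`, fill the slice `[n//4:n//2]` with 1.0, where `n = param.size(0)`. That slice targets the forget gate because PyTorch packs gates in the order (input, forget, cell, output). Note that `nn.LSTM` keeps two bias vectors per layer (`bias_ih` and `bias_hh`) that are summed, so filling both with 1.0 yields an effective forget bias of 2.0; fill only one of them if you want exactly 1.0.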

  **Cell state explosion**: If your loss becomes NaN but gradients look normal, check the cell state magnitude. Unlike the hidden state (bounded by tanh), the cell state $C_t$ is unbounded -- it can grow arbitrarily large if the forget gate stays near 1 and the input gate keeps adding. Monitor `cell_state.abs().max()` during training. As a rough rule of thumb, if it exceeds about 100, you likely need gradient clipping or a lower learning rate.

  **Teacher forcing ratio scheduling**: For seq2seq models, training with 100% teacher forcing and then decoding with 0% at inference creates a train/test mismatch, because the decoder never sees its own mistakes during training. Scheduling the ratio from 1.0 down toward 0.0 over the course of training (linear or exponential decay) is called "scheduled sampling" and can improve generation quality.

  **LSTM vs GRU selection heuristic**: Use LSTM as your default. Switch to GRU if: (a) you need faster training and your sequences are under 200 tokens, (b) you are memory-constrained (GRU uses 25% fewer parameters), or (c) your ablation shows no accuracy difference. For most NLP tasks with sequences over 100 tokens, LSTM has a slight edge; for shorter sequences, they are typically indistinguishable.
</Tip>
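
The forget gate bias tip can be written as a small helper. This is a sketch under the assumption that the layer is a standard `nn.LSTM` (no projection), whose biases are packed as (input, forget, cell, output); `init_forget_bias` is a hypothetical name. It writes the value into `bias_ih` and zeroes the matching slice of `bias_hh`, so the effective forget bias is exactly `value`.

```python
import torch
import torch.nn as nn


def init_forget_bias(lstm: nn.LSTM, value: float = 1.0) -> None:
    """Set the effective forget-gate bias (bias_ih + bias_hh) to `value`.

    Hypothetical helper; covers every layer and direction of the LSTM.
    """
    h = lstm.hidden_size
    with torch.no_grad():
        for name, param in lstm.named_parameters():
            # Each bias vector has 4 * hidden_size entries: [input | forget | cell | output]
            if name.startswith("bias_ih"):
                param[h:2 * h].fill_(value)
            elif name.startswith("bias_hh"):
                param[h:2 * h].fill_(0.0)


lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, bidirectional=True)
init_forget_bias(lstm, 1.0)

effective = lstm.bias_ih_l0[16:32] + lstm.bias_hh_l0[16:32]
print(effective)  # all entries equal to 1.0
```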

***

## Exercises

<AccordionGroup>
  <Accordion title="Exercise 1: Implement LSTM from Scratch">
    Implement a complete LSTM without using `nn.LSTM`:

    1. Implement `LSTMCell` with all gates
    2. Stack cells into `LSTM` layer
    3. Add bidirectional support
    4. Verify outputs match `nn.LSTM`

    Test on a simple sequence classification task.
  </Accordion>

  <Accordion title="Exercise 2: Gradient Flow Analysis">
    Create a visualization of gradient flow through LSTM:

    1. Process sequences of length 10, 50, 100, 200
    2. Track gradient magnitude at each time step
    3. Compare with vanilla RNN
    4. Plot the results

    What do you observe about the forget gate values in trained models?
  </Accordion>

  <Accordion title="Exercise 3: Named Entity Recognition">
    Build an NER tagger using BiLSTM:

    1. Load CoNLL-2003 dataset
    2. Implement BiLSTM-CRF model
    3. Train with proper evaluation (F1 score)
    4. Analyze errors by entity type

    Compare BiLSTM with BiGRU.
  </Accordion>

  <Accordion title="Exercise 4: Music Generation">
    Train an LSTM to generate music:

    1. Download MIDI files and convert to sequences
    2. Train character-level LSTM on ABC notation
    3. Generate new melodies
    4. Convert back to MIDI and listen

    Experiment with different temperatures.
  </Accordion>

  <Accordion title="Exercise 5: Time Series Forecasting">
    Build a multivariate time series forecaster:

    1. Use a dataset like air quality or stock prices
    2. Implement encoder-decoder with LSTM
    3. Add attention (preview of next chapter!)
    4. Compare with simple baselines

    Evaluate with proper time series cross-validation.
  </Accordion>
</AccordionGroup>

***

## Key Takeaways

| Concept              | Key Insight                                                    |
| -------------------- | -------------------------------------------------------------- |
| **Cell State**       | Long-term memory highway - gradients pass with little decay   |
| **Forget Gate**      | Learns what to erase - enables "forgetting" of irrelevant info |
| **Input Gate**       | Learns what to remember - filters new information              |
| **Output Gate**      | Learns what to reveal - controls hidden state exposure         |
| **GRU**              | Simpler, fewer parameters, often similar performance           |
| **Gradient Highway** | Additive cell state update → gradients decay slowly if f ≈ 1  |
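
The "fewer parameters" row for GRU can be checked directly. The sketch below counts parameters for a single-layer `nn.LSTM` and `nn.GRU` with the same sizes; the sizes are arbitrary assumptions. Because an LSTM has four gate blocks and a GRU three, the ratio comes out at 3/4.

```python
import torch.nn as nn


def count_params(module: nn.Module) -> int:
    """Total number of parameters in a module."""
    return sum(p.numel() for p in module.parameters())


input_size, hidden_size = 256, 512
lstm = nn.LSTM(input_size, hidden_size)
gru = nn.GRU(input_size, hidden_size)

n_lstm, n_gru = count_params(lstm), count_params(gru)
print(n_lstm, n_gru, n_gru / n_lstm)  # 1576960 1182720 0.75
```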

***

## What's Next

<CardGroup cols={1}>
  <Card title="Module 10: Attention Mechanism" icon="eye" href="/courses/deep-learning-mastery/10-attention">
    Go beyond sequential processing — learn how attention allows models to focus on relevant parts of the input, enabling breakthrough performance on translation, summarization, and more.
  </Card>
</CardGroup>

***

## Interview Deep-Dive

<AccordionGroup>
  <Accordion title="Walk through the LSTM cell equations. For each gate, explain what it does and why that specific activation function was chosen.">
    **Strong Answer:**

    * **Forget gate**: $f_t = \sigma(W_f [h_{t-1}, x_t] + b_f)$. Sigmoid outputs in (0, 1) act as a dimmer switch on each cell state element. Values near 1 keep information, values near 0 erase it. Sigmoid is chosen because we need a smooth, differentiable gate bounded between 0 and 1.
    * **Input gate + candidate**: $i_t = \sigma(W_i [h_{t-1}, x_t] + b_i)$ controls how much new information to write. Candidate values $\tilde{C}_t = \tanh(W_C [h_{t-1}, x_t] + b_C)$ use tanh because its (-1, 1) range allows both additive and subtractive modifications to the cell state.
    * **Cell state update**: $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$. This additive update is the key -- along the direct cell-state path, $\partial C_t / \partial C_{t-1} = f_t$, which can stay near 1 and preserve gradients across hundreds of steps. Compare to vanilla RNNs, where gradients typically decay exponentially.
    * **Output gate**: $o_t = \sigma(W_o [h_{t-1}, x_t] + b_o)$, $h_t = o_t \odot \tanh(C_t)$. The cell state may store information not needed for the current prediction. The output gate selectively exposes relevant parts.
    * The architectural insight: cell state is the memory bus, gates are learned read/write controllers. This separation of storage from computation enables long-term memory.

    **Follow-up: The forget gate bias is typically initialized to 1.0 instead of 0.0. Why?**

    With bias=0, the forget gate starts near sigmoid(0)=0.5, so the LSTM forgets half the cell state every step. Over 100 steps, only $(0.5)^{100} \approx 10^{-30}$ survives. Initializing to 1 gives sigmoid(1)≈0.73, which slows the default decay considerably (over 10 steps, $0.73^{10} \approx 0.04$ versus $0.5^{10} \approx 0.001$) and gives training a better starting point. The network can then push the gates it needs toward 1 and learn which dimensions to close, rather than having to first learn to keep them open. This trick (Jozefowicz et al., 2015) is widely considered good practice for tasks with long dependencies.
  </Accordion>

  <Accordion title="Compare LSTM and GRU architecturally. When would you choose one over the other?">
    **Strong Answer:**

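    * GRU merges forget and input gates into a single update gate and combines cell state with hidden state. At the same hidden size this reduces parameters by roughly 25% (three gate blocks instead of four) and usually makes training somewhat faster; the exact speedup depends on hardware and implementation.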
    * LSTM has a separate cell state providing a cleaner gradient highway, and separate forget/input gates give more fine-grained memory control.
    * **Choose GRU**: small datasets (fewer parameters reduce overfitting), speed-critical applications, moderate-length dependencies (under 200 steps). GRU performs comparably to LSTM on most benchmarks with shorter sequences.
    * **Choose LSTM**: very long dependencies (500+ steps), ample data to support extra parameters, or when you need the explicit cell state for inspection/interpretability.
    * In practice, the performance gap is usually small and task-dependent, and both have been largely superseded by transformers. The choice between them matters less than the choice between recurrence and attention.

    **Follow-up: GRU's reset gate has no direct equivalent in LSTM. What does it do?**

    The reset gate $r_t = \sigma(W_r [h_{t-1}, x_t])$ controls how much of the previous hidden state to expose when computing the candidate update. When $r_t \approx 0$, the candidate ignores previous state, allowing the model to write completely fresh information. LSTM achieves a similar effect through a low forget gate combined with a high input gate, but GRU's mechanism is more direct and parameter-efficient.
  </Accordion>

  <Accordion title="Why is the LSTM cell state update additive rather than multiplicative? Connect this to gradient flow and ResNet skip connections.">
    **Strong Answer:**

    * The cell state update $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$ is additive: new state = weighted old state + weighted new candidate. Vanilla RNNs use $h_t = \tanh(W_{hh} h_{t-1} + ...)$, a nonlinear (multiplicative) transformation.
    * Gradient consequence: along the direct cell-state path, $\partial C_T / \partial C_1 = \prod_{t=2}^{T} f_t$. With forget gates near 1, this product stays close to 1 across hundreds of steps. Vanilla RNNs instead have a product of terms $W_{hh}^\top \cdot \text{diag}(\tanh')$, one per time step, which typically decays (or explodes) exponentially.
    * This is the exact same principle as ResNet: $y = F(x) + x$ gives gradient $\partial y / \partial x = \partial F / \partial x + 1$. The additive identity provides a gradient highway. LSTM's forget gate modulates this highway ($f$ instead of fixed 1), but when $f \approx 1$, the effect is identical.
    * Both LSTM (1997) and ResNet (2015) independently discovered that additive shortcuts solve the gradient degradation problem in deep/long computation chains. The underlying math is the same: addition distributes gradients without attenuation.

    **Follow-up: Can the forget gate learn to be exactly 0 for some dimensions and exactly 1 for others simultaneously?**

    Effectively yes. A sigmoid never reaches exactly 0 or 1, but each dimension has its own gate value and can saturate arbitrarily close to either end, and this is what happens in practice. Different dimensions of the cell state specialize: some maintain $f \approx 1$ for hundreds of steps (long-term registers storing sentence subjects or global context), while others cycle between values near 0 and near 1 rapidly (short-term buffers for recent token information). Visualizing forget gate values across dimensions and time steps reveals this specialization clearly -- it is one of the most interpretable aspects of LSTM internals.
  </Accordion>
</AccordionGroup>
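
The retention arithmetic used in the forget gate answers above can be checked numerically. The sketch below makes a strong simplifying assumption: the forget gate is held constant at `sigmoid(forget_bias)` and nothing new is written to the cell, so the surviving fraction after `steps` updates is just the gate value raised to that power. Real LSTM gates vary per step and per dimension; the function names are illustrative.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def retention(forget_bias: float, steps: int) -> float:
    """Fraction of the initial cell state left after `steps` updates,
    assuming a constant forget gate f = sigmoid(forget_bias) and no new writes."""
    return sigmoid(forget_bias) ** steps


for bias in (0.0, 1.0, 2.0, 5.0):
    f = sigmoid(bias)
    print(f"bias={bias:.1f}  f={f:.3f}  "
          f"after 10 steps: {retention(bias, 10):.2e}  "
          f"after 100 steps: {retention(bias, 100):.2e}")
```

Under this simplification, a bias of 1 decays far more slowly than a bias of 0 but still shrinks substantially over 100 steps. Dimensions that carry long-range information therefore need gates that training pushes much closer to 1.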
