
LSTMs & GRUs: Gated Recurrent Networks

The Memory Problem Revisited

Vanilla RNNs suffer from a fundamental flaw: they can’t remember things for long. Because of the vanishing gradient problem, information from early time steps gets “washed out” as it is squashed through many successive tanh activations. Real-world consequence: an RNN reading a book can’t remember what happened in Chapter 1 by the time it reaches Chapter 10.
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

def demonstrate_memory_problem():
    """
    The Long-Range Dependency Problem:
    
    "The cat, which had been sleeping on the warm sunny 
    windowsill for the entire lazy afternoon while the 
    rain pattered gently against the glass, finally ___."
    
    To predict "woke up" or "stretched", we need to remember
    "cat" and "sleeping" from 25+ words ago!
    """
    
    print("Long-range dependency example:")
    print("-" * 60)
    print("Sentence: 'The CAT, which had been SLEEPING on the warm...")
    print("         sunny windowsill for the entire lazy afternoon...")
    print("         while the rain pattered against the glass, finally ___'")
    print()
    print("To predict the blank, we need to remember:")
    print("  - Subject: 'cat' (25 words back)")
    print("  - State: 'sleeping' (23 words back)")
    print()
    print("Vanilla RNNs fail because gradients vanish over 25 steps!")

demonstrate_memory_problem()
The Solution: Instead of trying to force information through a single path, create multiple pathways for information flow, some of which can pass information unchanged. This is the key insight behind LSTMs and GRUs.
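
Before looking at the full architecture, a tiny toy simulation makes the point (this is not a trained network; the weight scale and the 0.99/0.01 mixing factors are arbitrary values chosen purely for illustration): a signal pushed through a purely squashing recurrence fades away, while a mostly-additive path carries it across many steps.

torch.manual_seed(0)

steps = 100
W = 0.05 * torch.randn(64, 64)   # contractive recurrent weights (chosen for the demo)

h = torch.ones(64)   # "information" stored at t = 0, sent down a squashing path
c = torch.ones(64)   # the same information on a gated, mostly-additive path

for _ in range(steps):
    h = torch.tanh(W @ h)                      # vanilla-RNN-style update: squash everything
    c = 0.99 * c + 0.01 * torch.tanh(W @ c)    # highway: pass most of the old state through

print(f"Squashing path after {steps} steps: norm = {h.norm():.6f}")
print(f"Additive path after {steps} steps:  norm = {c.norm():.6f}")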

Long Short-Term Memory (LSTM)

The Big Idea: A Memory Cell with Gates

An LSTM maintains two types of state:
  1. Cell State ($C_t$): The “long-term memory” - a highway for information
  2. Hidden State ($h_t$): The “working memory” - the current output
Three gates control information flow:
  1. Forget Gate: What to erase from cell state
  2. Input Gate: What new information to add
  3. Output Gate: What to output based on cell state
LSTM Cell Diagram

LSTM Equations

$$
\begin{aligned}
f_t &= \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) & \text{(Forget gate)} \\
i_t &= \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) & \text{(Input gate)} \\
\tilde{C}_t &= \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) & \text{(Candidate values)} \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t & \text{(Cell state update)} \\
o_t &= \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) & \text{(Output gate)} \\
h_t &= o_t \odot \tanh(C_t) & \text{(Hidden state)}
\end{aligned}
$$

Where:
  • $\sigma$ is the sigmoid function (outputs values between 0 and 1 for gating)
  • $\odot$ is element-wise multiplication
  • $[h_{t-1}, x_t]$ is the concatenation of the previous hidden state and the current input
class LSTMCell(nn.Module):
    """
    LSTM Cell implemented from scratch.
    
    The key insight: Cell state C_t flows through with only
    element-wise operations (multiply by forget, add input).
    This creates a "gradient highway" - gradients can flow
    backward through time without vanishing!
    """
    
    def __init__(self, input_size, hidden_size):
        super().__init__()
        
        self.hidden_size = hidden_size
        
        # All gates in one matrix for efficiency
        # Order: input, forget, cell, output (i, f, g, o)
        self.gates = nn.Linear(input_size + hidden_size, 4 * hidden_size)
    
    def forward(self, x, hidden=None):
        """
        Args:
            x: Input (batch, input_size)
            hidden: Tuple of (h, c), each (batch, hidden_size)
        
        Returns:
            h_new: New hidden state
            c_new: New cell state
        """
        batch_size = x.size(0)
        
        # Initialize hidden states if not provided
        if hidden is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)
            c = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h, c = hidden
        
        # Concatenate input and hidden state
        combined = torch.cat([x, h], dim=1)
        
        # Compute all gates at once
        gates = self.gates(combined)
        
        # Split into individual gates
        i, f, g, o = gates.chunk(4, dim=1)
        
        # Apply activations
        i = torch.sigmoid(i)  # Input gate
        f = torch.sigmoid(f)  # Forget gate
        g = torch.tanh(g)     # Candidate values
        o = torch.sigmoid(o)  # Output gate
        
        # Update cell state
        c_new = f * c + i * g
        
        # Compute hidden state
        h_new = o * torch.tanh(c_new)
        
        return h_new, c_new


# Test our LSTM cell
cell = LSTMCell(input_size=32, hidden_size=64)

x = torch.randn(8, 32)  # Batch of 8, input size 32
h, c = cell(x)

print(f"Input: {x.shape}")
print(f"Hidden state: {h.shape}")
print(f"Cell state: {c.shape}")

Complete LSTM Layer

class LSTM(nn.Module):
    """
    Complete LSTM layer that processes sequences.
    """
    
    def __init__(self, input_size, hidden_size, num_layers=1, 
                 batch_first=True, dropout=0.0, bidirectional=False):
        super().__init__()
        
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        
        # Stack of LSTM cells
        self.cells = nn.ModuleList()
        for layer in range(num_layers):
            for direction in range(self.num_directions):
                layer_input_size = input_size if layer == 0 else hidden_size * self.num_directions
                self.cells.append(LSTMCell(layer_input_size, hidden_size))
        
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
    
    def forward(self, x, hidden=None):
        """
        Process sequence through LSTM.
        
        Args:
            x: (batch, seq_len, input) if batch_first else (seq_len, batch, input)
            hidden: Tuple of (h_0, c_0), each (num_layers * directions, batch, hidden)
        
        Returns:
            output: Hidden states at each time step
            (h_n, c_n): Final hidden and cell states
        """
        if self.batch_first:
            batch_size, seq_len, _ = x.size()
        else:
            seq_len, batch_size, _ = x.size()
            x = x.transpose(0, 1)
        
        # Initialize hidden states
        if hidden is None:
            h = [torch.zeros(batch_size, self.hidden_size, device=x.device)
                 for _ in range(self.num_layers * self.num_directions)]
            c = [torch.zeros(batch_size, self.hidden_size, device=x.device)
                 for _ in range(self.num_layers * self.num_directions)]
        else:
            h = [hidden[0][i] for i in range(self.num_layers * self.num_directions)]
            c = [hidden[1][i] for i in range(self.num_layers * self.num_directions)]
        
        # Process each layer
        layer_input = x
        
        for layer in range(self.num_layers):
            outputs_forward = []
            outputs_backward = []
            
            # Forward direction
            cell_idx = layer * self.num_directions
            h_forward, c_forward = h[cell_idx], c[cell_idx]
            
            for t in range(seq_len):
                h_forward, c_forward = self.cells[cell_idx](
                    layer_input[:, t, :], (h_forward, c_forward)
                )
                outputs_forward.append(h_forward)
            
            h[cell_idx], c[cell_idx] = h_forward, c_forward
            
            # Backward direction (if bidirectional)
            if self.bidirectional:
                cell_idx = layer * self.num_directions + 1
                h_backward, c_backward = h[cell_idx], c[cell_idx]
                
                for t in range(seq_len - 1, -1, -1):
                    h_backward, c_backward = self.cells[cell_idx](
                        layer_input[:, t, :], (h_backward, c_backward)
                    )
                    outputs_backward.insert(0, h_backward)
                
                h[cell_idx], c[cell_idx] = h_backward, c_backward
            
            # Combine outputs
            outputs_forward = torch.stack(outputs_forward, dim=1)
            if self.bidirectional:
                outputs_backward = torch.stack(outputs_backward, dim=1)
                layer_output = torch.cat([outputs_forward, outputs_backward], dim=-1)
            else:
                layer_output = outputs_forward
            
            # Apply dropout between layers (not on last layer)
            if self.dropout and layer < self.num_layers - 1:
                layer_output = self.dropout(layer_output)
            
            layer_input = layer_output
        
        # Stack hidden states
        h_n = torch.stack(h, dim=0)
        c_n = torch.stack(c, dim=0)
        
        if not self.batch_first:
            layer_output = layer_output.transpose(0, 1)
        
        return layer_output, (h_n, c_n)


# Compare with PyTorch implementation
our_lstm = LSTM(input_size=32, hidden_size=64, num_layers=2, batch_first=True)
pytorch_lstm = nn.LSTM(input_size=32, hidden_size=64, num_layers=2, batch_first=True)

x = torch.randn(8, 20, 32)  # batch=8, seq_len=20, input=32

our_output, (our_h, our_c) = our_lstm(x)
pt_output, (pt_h, pt_c) = pytorch_lstm(x)

print(f"Our output shape:     {our_output.shape}")
print(f"PyTorch output shape: {pt_output.shape}")
print(f"Our hidden shape:     {our_h.shape}")
print(f"PyTorch hidden shape: {pt_h.shape}")

Understanding the Gates

The Forget Gate: Learning What to Ignore

def visualize_forget_gate():
    """
    The forget gate decides what information from the cell state
    should be thrown away.
    
    f_t = σ(W_f · [h_{t-1}, x_t] + b_f)
    
    - f_t ≈ 1: Keep this information
    - f_t ≈ 0: Forget this information
    """
    
    # Example: Language model processing "John lives in Paris. He speaks ___"
    # When we see "He", the forget gate might:
    # - Keep: Information that subject is male ("John")
    # - Forget: Specific details about Paris that aren't relevant
    
    print("Forget Gate Example:")
    print("-" * 50)
    print("Context: 'John lives in Paris. He speaks ___'")
    print()
    print("Cell state before 'He':")
    print("  [subject=John, location=Paris, job=unknown, ...]")
    print()
    print("Forget gate values (learned):")
    print("  [subject: 0.95, location: 0.3, job: 0.1, ...]")
    print()
    print("Cell state after forget gate:")
    print("  [subject=John(kept), location=faded, job=forgotten]")
    
visualize_forget_gate()


def forget_gate_experiment():
    """
    Demonstrate how the forget gate responds to different inputs.
    """
    lstm = nn.LSTM(10, 20, batch_first=True)
    
    # Create two sequences
    # Sequence 1: Normal input
    x1 = torch.randn(1, 50, 10)
    
    # Sequence 2: Input with a "reset signal" at position 25
    x2 = x1.clone()
    x2[0, 25, :] = 10.0  # Strong signal
    
    # Note: nn.LSTM does not expose its gate activations, so we compare the
    # resulting cell states instead (the sketch below this function recomputes
    # the gates manually from the layer's weights).
    
    # Process and compare cell states
    _, (_, c1) = lstm(x1)
    _, (_, c2) = lstm(x2)
    
    print("Cell state comparison:")
    print(f"Normal sequence - cell state norm: {c1.norm().item():.4f}")
    print(f"With reset signal - cell state norm: {c2.norm().item():.4f}")
    
forget_gate_experiment()
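
nn.LSTM keeps its gate activations internal, but for a single-layer, unidirectional LSTM we can recompute them from the layer’s parameters (weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0, with gate order i, f, g, o). A minimal sketch of that idea follows; since the layer is untrained, the forget-gate values hover around 0.5.

def forget_gate_trace(lstm, seq):
    """Recompute forget-gate activations (seq_len, hidden) for one sequence."""
    W_ih, W_hh = lstm.weight_ih_l0, lstm.weight_hh_l0
    bias = lstm.bias_ih_l0 + lstm.bias_hh_l0
    H = lstm.hidden_size
    h = torch.zeros(1, H)
    c = torch.zeros(1, H)
    forget_values = []
    with torch.no_grad():
        for t in range(seq.size(1)):
            pre = seq[:, t, :] @ W_ih.T + h @ W_hh.T + bias
            i, f, g, o = pre.chunk(4, dim=1)
            i, f, g, o = torch.sigmoid(i), torch.sigmoid(f), torch.tanh(g), torch.sigmoid(o)
            c = f * c + i * g
            h = o * torch.tanh(c)
            forget_values.append(f.squeeze(0))
    return torch.stack(forget_values)

probe_lstm = nn.LSTM(10, 20, batch_first=True)
probe_seq = torch.randn(1, 50, 10)
f_trace = forget_gate_trace(probe_lstm, probe_seq)
print(f"Mean forget-gate value per time step (first 5): {f_trace.mean(dim=1)[:5]}")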

The Input Gate: Learning What to Remember

def visualize_input_gate():
    """
    The input gate decides what new information to store.
    
    i_t = σ(W_i · [h_{t-1}, x_t] + b_i)     # What to update
    g_t = tanh(W_g · [h_{t-1}, x_t] + b_g)  # Candidate values
    
    New information = i_t * g_t
    """
    
    print("Input Gate Example:")
    print("-" * 50)
    print("Processing: 'Marie Curie won the Nobel Prize in ___ and ___'")
    print()
    print("When we see 'Nobel Prize':")
    print("  Candidate values (tanh): [award_type: +0.9, person: -0.1, ...]")
    print("  Input gate (sigmoid):    [award_type: 0.95, person: 0.2, ...]")
    print()
    print("New cell state += [strong award signal, weak person signal]")
    
visualize_input_gate()

The Output Gate: Learning What to Reveal

def visualize_output_gate():
    """
    The output gate decides what parts of the cell state to output.
    
    o_t = σ(W_o · [h_{t-1}, x_t] + b_o)
    h_t = o_t * tanh(C_t)
    
    The cell state might store information that isn't immediately
    relevant but will be needed later.
    """
    
    print("Output Gate Example:")
    print("-" * 50)
    print("Cell state: [subject=John, visited_places=[Paris,London], mood=happy]")
    print()
    print("Current word: 'traveled' → Need location info")
    print("  Output gate: [subject: 0.3, places: 0.9, mood: 0.1]")
    print("  Hidden output emphasizes: places")
    print()
    print("Current word: 'smiled' → Need mood info")  
    print("  Output gate: [subject: 0.2, places: 0.1, mood: 0.95]")
    print("  Hidden output emphasizes: mood")

visualize_output_gate()
LSTM Gates Information Flow

Why LSTM Solves Vanishing Gradients

The Gradient Highway

def explain_gradient_highway():
    """
    The key to LSTM's success: The cell state update
    
    C_t = f_t * C_{t-1} + i_t * g_t
    
    The gradient of C_t w.r.t. C_{t-1} is simply f_t (forget gate).
    
    If f_t ≈ 1, gradients flow through unchanged!
    This is like a "gradient highway" - information and gradients
    can travel long distances without being squashed.
    """
    
    print("Gradient Flow Comparison:")
    print("-" * 60)
    print()
    print("Vanilla RNN gradient through T steps:")
    print("  ∂h_T/∂h_1 = ∏(W_hh * diag(tanh')) → vanishes!")
    print()
    print("LSTM gradient through T steps:")
    print("  ∂C_T/∂C_1 = ∏(f_t) → can stay close to 1!")
    print()
    print("Key insight: Forget gate learns to keep gradients alive")
    print("for important long-range dependencies.")
    
explain_gradient_highway()


def compare_gradient_flow():
    """Compare gradient flow in RNN vs LSTM."""
    
    seq_len = 100
    hidden_size = 64
    
    # Vanilla RNN
    rnn = nn.RNN(32, hidden_size, batch_first=True)
    
    # LSTM
    lstm = nn.LSTM(32, hidden_size, batch_first=True)
    
    x = torch.randn(1, seq_len, 32, requires_grad=True)
    
    # RNN gradient flow
    rnn_out, _ = rnn(x)
    rnn_loss = rnn_out[:, -1, :].sum()
    rnn_loss.backward()
    rnn_grad = x.grad[:, 0, :].norm().item()
    
    x.grad.zero_()
    
    # LSTM gradient flow
    lstm_out, _ = lstm(x)
    lstm_loss = lstm_out[:, -1, :].sum()
    lstm_loss.backward()
    lstm_grad = x.grad[:, 0, :].norm().item()
    
    print(f"\nGradient at first time step (sequence length {seq_len}):")
    print(f"  RNN:  {rnn_grad:.6f}")
    print(f"  LSTM: {lstm_grad:.6f}")
    print(f"  Ratio (LSTM/RNN): {lstm_grad/rnn_grad:.2f}x")

compare_gradient_flow()

Gated Recurrent Unit (GRU)

A Simpler Alternative

GRU simplifies LSTM by:
  1. Combining forget and input gates into an “update gate”
  2. Merging cell state and hidden state
GRU Cell Diagram

GRU Equations

$$
\begin{aligned}
z_t &= \sigma(W_z \cdot [h_{t-1}, x_t]) & \text{(Update gate)} \\
r_t &= \sigma(W_r \cdot [h_{t-1}, x_t]) & \text{(Reset gate)} \\
\tilde{h}_t &= \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t]) & \text{(Candidate)} \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t & \text{(Final state)}
\end{aligned}
$$
class GRUCell(nn.Module):
    """
    GRU Cell implemented from scratch.
    
    Compared to LSTM:
    - 2 gates instead of 3
    - No separate cell state
    - Fewer parameters
    - Often performs similarly
    """
    
    def __init__(self, input_size, hidden_size):
        super().__init__()
        
        self.hidden_size = hidden_size
        
        # Update and reset gates
        self.W_z = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_r = nn.Linear(input_size + hidden_size, hidden_size)
        
        # Candidate hidden state
        self.W_h = nn.Linear(input_size + hidden_size, hidden_size)
    
    def forward(self, x, h=None):
        """
        Args:
            x: Input (batch, input_size)
            h: Hidden state (batch, hidden_size)
        
        Returns:
            h_new: Updated hidden state
        """
        batch_size = x.size(0)
        
        if h is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)
        
        # Concatenate input and hidden
        combined = torch.cat([x, h], dim=1)
        
        # Update gate: how much of new state to use
        z = torch.sigmoid(self.W_z(combined))
        
        # Reset gate: how much of old state to forget when computing candidate
        r = torch.sigmoid(self.W_r(combined))
        
        # Candidate hidden state
        combined_reset = torch.cat([x, r * h], dim=1)
        h_candidate = torch.tanh(self.W_h(combined_reset))
        
        # Final hidden state: interpolate between old and candidate
        h_new = (1 - z) * h + z * h_candidate
        
        return h_new


class GRU(nn.Module):
    """Complete GRU layer."""
    
    def __init__(self, input_size, hidden_size, num_layers=1, batch_first=True):
        super().__init__()
        
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        
        self.cells = nn.ModuleList([
            GRUCell(input_size if i == 0 else hidden_size, hidden_size)
            for i in range(num_layers)
        ])
    
    def forward(self, x, h_0=None):
        if self.batch_first:
            batch_size, seq_len, _ = x.size()
        else:
            seq_len, batch_size, _ = x.size()
            x = x.transpose(0, 1)
        
        # Initialize hidden states
        if h_0 is None:
            h = [torch.zeros(batch_size, self.hidden_size, device=x.device)
                 for _ in range(self.num_layers)]
        else:
            h = [h_0[i] for i in range(self.num_layers)]
        
        outputs = []
        
        for t in range(seq_len):
            x_t = x[:, t, :]
            
            for layer, cell in enumerate(self.cells):
                h[layer] = cell(x_t, h[layer])
                x_t = h[layer]
            
            outputs.append(h[-1])
        
        outputs = torch.stack(outputs, dim=1)
        h_n = torch.stack(h, dim=0)
        
        if not self.batch_first:
            outputs = outputs.transpose(0, 1)
        
        return outputs, h_n


# Test GRU
gru = GRU(input_size=32, hidden_size=64, num_layers=2, batch_first=True)
x = torch.randn(8, 20, 32)
output, h_n = gru(x)

print(f"Input:  {x.shape}")
print(f"Output: {output.shape}")
print(f"Hidden: {h_n.shape}")

LSTM vs GRU Comparison

Aspect | LSTM | GRU
------ | ---- | ---
Gates | 3 (forget, input, output) | 2 (update, reset)
States | 2 (hidden + cell) | 1 (hidden only)
Parameters per layer | ≈4 × hidden × (input + hidden) | ≈3 × hidden × (input + hidden)
Training | Slightly slower | Faster
Performance | Often slightly better on complex tasks | Often comparable
When to use | Long sequences, complex dependencies | When faster training is needed, simpler tasks
def compare_lstm_gru():
    """Compare LSTM and GRU architectures."""
    
    input_size = 128
    hidden_size = 256
    
    lstm = nn.LSTM(input_size, hidden_size, num_layers=2, batch_first=True)
    gru = nn.GRU(input_size, hidden_size, num_layers=2, batch_first=True)
    
    lstm_params = sum(p.numel() for p in lstm.parameters())
    gru_params = sum(p.numel() for p in gru.parameters())
    
    print("Parameter Comparison:")
    print(f"  LSTM: {lstm_params:,} parameters")
    print(f"  GRU:  {gru_params:,} parameters")
    print(f"  GRU is {100*(1 - gru_params/lstm_params):.1f}% smaller")
    
    # Speed comparison
    import time
    
    x = torch.randn(32, 100, input_size)
    
    # Warm up
    _ = lstm(x)
    _ = gru(x)
    
    # Time LSTM
    start = time.time()
    for _ in range(100):
        _ = lstm(x)
    lstm_time = time.time() - start
    
    # Time GRU
    start = time.time()
    for _ in range(100):
        _ = gru(x)
    gru_time = time.time() - start
    
    print(f"\nSpeed Comparison (100 forward passes):")
    print(f"  LSTM: {lstm_time:.3f}s")
    print(f"  GRU:  {gru_time:.3f}s")
    print(f"  GRU is {100*(1 - gru_time/lstm_time):.1f}% faster")

compare_lstm_gru()

Practical Applications

Sentiment Analysis with LSTM

class SentimentLSTM(nn.Module):
    """
    Sentiment analysis using bidirectional LSTM.
    """
    
    def __init__(self, vocab_size, embed_dim, hidden_dim, 
                 num_layers, num_classes, dropout=0.5):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        
        self.lstm = nn.LSTM(
            embed_dim, 
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0
        )
        
        self.dropout = nn.Dropout(dropout)
        
        # Bidirectional → hidden_dim * 2
        self.fc = nn.Linear(hidden_dim * 2, num_classes)
    
    def forward(self, x, lengths=None):
        """
        Args:
            x: Token indices (batch, seq_len)
            lengths: Actual sequence lengths (for packing)
        """
        # Embed tokens
        embedded = self.dropout(self.embedding(x))
        
        # Pack sequences if lengths provided
        if lengths is not None:
            packed = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            output, (hidden, cell) = self.lstm(packed)
        else:
            output, (hidden, cell) = self.lstm(embedded)
        
        # Concatenate final forward and backward hidden states
        # hidden shape: (num_layers * 2, batch, hidden_dim)
        hidden_forward = hidden[-2, :, :]  # Last layer, forward
        hidden_backward = hidden[-1, :, :]  # Last layer, backward
        hidden_cat = torch.cat([hidden_forward, hidden_backward], dim=1)
        
        # Classify
        output = self.fc(self.dropout(hidden_cat))
        
        return output


# Example usage
model = SentimentLSTM(
    vocab_size=10000,
    embed_dim=128,
    hidden_dim=256,
    num_layers=2,
    num_classes=2
)

# Simulated batch
batch = torch.randint(1, 10000, (16, 100))  # 16 sequences, length 100
predictions = model(batch)

print(f"Input: {batch.shape}")
print(f"Predictions: {predictions.shape}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

Language Modeling with LSTM

class LanguageModelLSTM(nn.Module):
    """
    Language model: Predict next word given previous words.
    """
    
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, 
                 dropout=0.5, tie_weights=True):
        super().__init__()
        
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )
        
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        
        # Weight tying: share embedding and output weights
        if tie_weights and embed_dim == hidden_dim:
            self.fc.weight = self.embedding.weight
    
    def forward(self, x, hidden=None):
        """
        Args:
            x: Token indices (batch, seq_len)
            hidden: (h, c) tuple for stateful processing
        
        Returns:
            logits: (batch, seq_len, vocab_size)
            hidden: Updated hidden state
        """
        embedded = self.dropout(self.embedding(x))
        output, hidden = self.lstm(embedded, hidden)
        output = self.dropout(output)
        logits = self.fc(output)
        
        return logits, hidden
    
    def init_hidden(self, batch_size, device='cpu'):
        """Initialize hidden state."""
        h = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device)
        c = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device)
        return (h, c)
    
    def generate(self, start_tokens, max_length=100, temperature=1.0):
        """Generate text continuation."""
        self.eval()
        
        batch_size = start_tokens.size(0)
        device = start_tokens.device
        
        hidden = self.init_hidden(batch_size, device)
        generated = start_tokens.tolist()
        
        current = start_tokens
        
        with torch.no_grad():
            for _ in range(max_length):
                logits, hidden = self.forward(current, hidden)
                
                # Get last token prediction
                logits = logits[:, -1, :] / temperature
                probs = torch.softmax(logits, dim=-1)
                
                # Sample
                next_token = torch.multinomial(probs, 1)
                
                for i in range(batch_size):
                    generated[i].append(next_token[i].item())
                
                current = next_token
        
        return generated


# Create model
lm = LanguageModelLSTM(
    vocab_size=50000,
    embed_dim=512,
    hidden_dim=512,
    num_layers=3,
    dropout=0.3
)

# Training example
x = torch.randint(0, 50000, (8, 35))  # batch=8, seq_len=35
y = torch.randint(0, 50000, (8, 35))  # placeholder targets (in real data: inputs shifted by one)

logits, _ = lm(x)
loss = nn.CrossEntropyLoss()(logits.view(-1, 50000), y.view(-1))

print(f"Input: {x.shape}")
print(f"Output: {logits.shape}")
print(f"Loss: {loss.item():.4f}")
print(f"Perplexity: {torch.exp(loss).item():.2f}")

Sequence-to-Sequence Translation

class Encoder(nn.Module):
    """LSTM Encoder for seq2seq."""
    
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers,
            batch_first=True, dropout=dropout, bidirectional=True
        )
        self.dropout = nn.Dropout(dropout)
        
        # Project bidirectional to unidirectional for decoder
        self.fc_h = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc_c = nn.Linear(hidden_dim * 2, hidden_dim)
    
    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        outputs, (hidden, cell) = self.lstm(embedded)
        
        # Combine bidirectional states
        # hidden: (num_layers * 2, batch, hidden_dim)
        hidden = hidden.view(self.lstm.num_layers, 2, -1, self.lstm.hidden_size)
        hidden = torch.cat([hidden[:, 0], hidden[:, 1]], dim=-1)
        hidden = torch.tanh(self.fc_h(hidden))
        
        cell = cell.view(self.lstm.num_layers, 2, -1, self.lstm.hidden_size)
        cell = torch.cat([cell[:, 0], cell[:, 1]], dim=-1)
        cell = torch.tanh(self.fc_c(cell))
        
        return outputs, (hidden, cell)


class Decoder(nn.Module):
    """LSTM Decoder for seq2seq."""
    
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        
        self.vocab_size = vocab_size
        
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers,
            batch_first=True, dropout=dropout
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, tgt, hidden):
        embedded = self.dropout(self.embedding(tgt))
        output, hidden = self.lstm(embedded, hidden)
        prediction = self.fc(output)
        return prediction, hidden


class Seq2Seq(nn.Module):
    """Complete sequence-to-sequence model."""
    
    def __init__(self, encoder, decoder):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        """
        Args:
            src: Source sequence (batch, src_len)
            tgt: Target sequence (batch, tgt_len)
            teacher_forcing_ratio: Probability of using ground truth
        """
        batch_size = src.size(0)
        tgt_len = tgt.size(1)
        tgt_vocab_size = self.decoder.vocab_size
        
        # Store outputs
        outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size, device=src.device)
        
        # Encode source
        encoder_outputs, hidden = self.encoder(src)
        
        # First decoder input is <sos> token
        decoder_input = tgt[:, 0:1]
        
        for t in range(1, tgt_len):
            output, hidden = self.decoder(decoder_input, hidden)
            outputs[:, t] = output.squeeze(1)
            
            # Teacher forcing
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            
            if teacher_force:
                decoder_input = tgt[:, t:t+1]
            else:
                decoder_input = output.argmax(-1)
        
        return outputs


# Create seq2seq model
encoder = Encoder(vocab_size=10000, embed_dim=256, hidden_dim=512, 
                  num_layers=2, dropout=0.3)
decoder = Decoder(vocab_size=8000, embed_dim=256, hidden_dim=512,
                  num_layers=2, dropout=0.3)

model = Seq2Seq(encoder, decoder)

# Example
src = torch.randint(0, 10000, (16, 30))  # Source: 16 sentences, max 30 tokens
tgt = torch.randint(0, 8000, (16, 25))   # Target: 16 sentences, max 25 tokens

output = model(src, tgt)
print(f"Source: {src.shape}")
print(f"Target: {tgt.shape}")
print(f"Output: {output.shape}")

Advanced LSTM Variants

Peephole Connections

class PeepholeLSTMCell(nn.Module):
    """
    LSTM with peephole connections.
    
    Gates can directly look at the cell state, not just hidden state.
    This can help with precise timing and counting.
    """
    
    def __init__(self, input_size, hidden_size):
        super().__init__()
        
        self.hidden_size = hidden_size
        
        # Standard weights
        self.W_i = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_f = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_o = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_c = nn.Linear(input_size + hidden_size, hidden_size)
        
        # Peephole weights (diagonal, so just vectors)
        self.p_i = nn.Parameter(torch.randn(hidden_size))
        self.p_f = nn.Parameter(torch.randn(hidden_size))
        self.p_o = nn.Parameter(torch.randn(hidden_size))
    
    def forward(self, x, hidden=None):
        batch_size = x.size(0)
        
        if hidden is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)
            c = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h, c = hidden
        
        combined = torch.cat([x, h], dim=1)
        
        # Gates with peephole connections to cell state
        f = torch.sigmoid(self.W_f(combined) + self.p_f * c)
        i = torch.sigmoid(self.W_i(combined) + self.p_i * c)
        
        c_candidate = torch.tanh(self.W_c(combined))
        c_new = f * c + i * c_candidate
        
        o = torch.sigmoid(self.W_o(combined) + self.p_o * c_new)
        h_new = o * torch.tanh(c_new)
        
        return h_new, c_new

Layer Normalization in LSTM

class LayerNormLSTMCell(nn.Module):
    """
    LSTM with layer normalization for better training stability.
    """
    
    def __init__(self, input_size, hidden_size):
        super().__init__()
        
        self.hidden_size = hidden_size
        
        self.W_ih = nn.Linear(input_size, 4 * hidden_size, bias=False)
        self.W_hh = nn.Linear(hidden_size, 4 * hidden_size, bias=False)
        
        # Layer normalization
        self.ln_ih = nn.LayerNorm(4 * hidden_size)
        self.ln_hh = nn.LayerNorm(4 * hidden_size)
        self.ln_cell = nn.LayerNorm(hidden_size)
    
    def forward(self, x, hidden=None):
        batch_size = x.size(0)
        
        if hidden is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)
            c = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h, c = hidden
        
        # Apply layer norm to gate inputs
        gates = self.ln_ih(self.W_ih(x)) + self.ln_hh(self.W_hh(h))
        
        i, f, g, o = gates.chunk(4, dim=1)
        
        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        g = torch.tanh(g)
        o = torch.sigmoid(o)
        
        c_new = f * c + i * g
        
        # Layer norm on cell state before output
        h_new = o * torch.tanh(self.ln_cell(c_new))
        
        return h_new, c_new
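
Neither variant is exercised above; a quick shape check, following the same pattern as the earlier cells:

# Quick shape check for both variant cells defined above.
x_demo = torch.randn(8, 32)

peephole_cell = PeepholeLSTMCell(input_size=32, hidden_size=64)
h, c = peephole_cell(x_demo)
print(f"Peephole LSTM  - hidden: {h.shape}, cell: {c.shape}")

ln_cell = LayerNormLSTMCell(input_size=32, hidden_size=64)
h, c = ln_cell(x_demo)
print(f"LayerNorm LSTM - hidden: {h.shape}, cell: {c.shape}")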

Best Practices and Tips

Training LSTM/GRU Models:
  1. Gradient clipping: Always use torch.nn.utils.clip_grad_norm_ with a max_norm in the 1.0–5.0 range (see the training-step sketch after these lists)
  2. Learning rate: Start with 0.001 for Adam, 0.1-1.0 for SGD
  3. Dropout: Use 0.2-0.5 between layers, not within cells
  4. Initialization: PyTorch defaults are usually fine
  5. Bidirectional: Use for tasks where you have full sequence (classification, tagging)
Common Mistakes:
  1. Not detaching hidden state: For stateful training, detach hidden state between batches:
    hidden = (hidden[0].detach(), hidden[1].detach())
    
  2. Ignoring sequence lengths: Use pack_padded_sequence for variable-length sequences
  3. Wrong hidden state indexing: For bidirectional, hidden[-2:] contains last layer
  4. Too many layers: 2-3 layers usually sufficient; more can hurt
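
A minimal sketch of a truncated-BPTT training step that puts the tips above together (gradient clipping and detaching the hidden state between batches); the model size, the random data, and the hyperparameters are placeholders, not a recommended configuration.

lm_demo = LanguageModelLSTM(vocab_size=10000, embed_dim=256, hidden_dim=256, num_layers=2)
optimizer = torch.optim.Adam(lm_demo.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

hidden = lm_demo.init_hidden(batch_size=8)

for step in range(3):                                   # stand-in for a real DataLoader loop
    x_batch = torch.randint(0, 10000, (8, 35))          # random tokens as placeholder data
    y_batch = torch.randint(0, 10000, (8, 35))

    hidden = tuple(h.detach() for h in hidden)          # stop gradients at the batch boundary
    logits, hidden = lm_demo(x_batch, hidden)
    loss = criterion(logits.reshape(-1, 10000), y_batch.reshape(-1))

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(lm_demo.parameters(), max_norm=1.0)   # gradient clipping
    optimizer.step()

    print(f"step {step}: loss {loss.item():.4f}")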

Exercises

Implement a complete LSTM without using nn.LSTM:
  1. Implement LSTMCell with all gates
  2. Stack cells into LSTM layer
  3. Add bidirectional support
  4. Verify outputs match nn.LSTM
Test on a simple sequence classification task.
Create a visualization of gradient flow through LSTM:
  1. Process sequences of length 10, 50, 100, 200
  2. Track gradient magnitude at each time step
  3. Compare with vanilla RNN
  4. Plot the results
What do you observe about the forget gate values in trained models?
Build an NER tagger using BiLSTM:
  1. Load CoNLL-2003 dataset
  2. Implement BiLSTM-CRF model
  3. Train with proper evaluation (F1 score)
  4. Analyze errors by entity type
Compare BiLSTM with BiGRU.
Train an LSTM to generate music:
  1. Download MIDI files and convert to sequences
  2. Train character-level LSTM on ABC notation
  3. Generate new melodies
  4. Convert back to MIDI and listen
Experiment with different temperatures.
Build a multivariate time series forecaster:
  1. Use a dataset like air quality or stock prices
  2. Implement encoder-decoder with LSTM
  3. Add attention (preview of next chapter!)
  4. Compare with simple baselines
Evaluate with proper time series cross-validation.

Key Takeaways

Concept | Key Insight
------- | -----------
Cell State | Long-term memory highway - gradients flow unimpeded
Forget Gate | Learns what to erase - enables “forgetting” of irrelevant info
Input Gate | Learns what to remember - filters new information
Output Gate | Learns what to reveal - controls hidden state exposure
GRU | Simpler, fewer parameters, often similar performance
Gradient Highway | Cell state update is additive → gradients don’t vanish

What’s Next