Continual Learning

The Catastrophic Forgetting Problem

Imagine you spent a year becoming fluent in French, then spent a year learning Mandarin, and when you tried to speak French again you could barely string a sentence together. That is catastrophic forgetting — and it is exactly what happens to neural networks. Unlike humans, who can (mostly) retain old skills while acquiring new ones, standard neural networks overwrite old knowledge when optimized for new data. The weights that encoded “how to recognize cats” get repurposed for “how to recognize medical X-rays,” and the cat knowledge evaporates. This is not an edge case. It is the fundamental obstacle to deploying ML models in the real world, where data distributions shift, new classes appear, and retraining from scratch every time is prohibitively expensive.
| Scenario | Challenge | Real-world Example |
| --- | --- | --- |
| Sequential tasks | Model forgets task A when learning task B | A spam filter trained on new attack patterns forgets old ones |
| Streaming data | Distribution shift over time | A recommendation model degrades as user preferences evolve |
| Class-incremental | New classes added continuously | An object detector that must recognize new product categories |
| Domain adaptation | Same task, new domain | A medical model moving from one hospital’s imaging equipment to another |
A common misconception is that fine-tuning with a small learning rate avoids catastrophic forgetting. It does not — it merely slows it down. Even with a learning rate 100x smaller, a model fine-tuned on task B will eventually lose task A performance. The learning rate controls the speed of forgetting, not whether it occurs.
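To make the learning-rate point concrete, here is a minimal sketch on a toy problem that is not from the original text: a single weight `w`, a hypothetical task A loss `(w - 1)^2`, and a hypothetical task B loss `(w + 1)^2`. The helper names `finetune_on_task_b` and `loss_a` are illustrative only. Starting from task A's optimum, plain gradient descent on task B drifts to task B's optimum for any positive learning rate; a smaller rate only needs more steps to get there.

```python
def loss_a(w: float) -> float:
    """Toy task A loss, minimized at w = +1."""
    return (w - 1.0) ** 2


def finetune_on_task_b(lr: float, steps: int, w_start: float = 1.0) -> float:
    """Gradient descent on the toy task B loss (w + 1)^2, starting at task A's optimum."""
    w = w_start
    for _ in range(steps):
        grad_b = 2.0 * (w + 1.0)  # derivative of (w + 1)^2
        w -= lr * grad_b
    return w


# Same number of steps: the small learning rate has forgotten less so far...
early_fast = loss_a(finetune_on_task_b(lr=0.1, steps=100))
early_slow = loss_a(finetune_on_task_b(lr=0.001, steps=100))

# ...but given 100x more steps it ends up in the same place.
late_slow = loss_a(finetune_on_task_b(lr=0.001, steps=10_000))

print(f"lr=0.1,   100 steps:    task A loss = {early_fast:.3f}")
print(f"lr=0.001, 100 steps:    task A loss = {early_slow:.3f}")
print(f"lr=0.001, 10000 steps:  task A loss = {late_slow:.3f}")
```

Real networks are not one-dimensional quadratics, so this only illustrates the mechanism: the learning rate rescales the time axis of forgetting rather than removing it.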
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import Optional, Dict, List, Tuple, Callable
from dataclasses import dataclass
from copy import deepcopy
import numpy as np

torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Measuring Forgetting

@dataclass
class ContinualMetrics:
    """
    Metrics for continual learning evaluation.
    """
    accuracy_matrix: np.ndarray  # A[i,j] = acc on task j after learning task i
    
    @property
    def average_accuracy(self) -> float:
        """Average accuracy across all tasks after all training."""
        return np.mean(self.accuracy_matrix[-1])
    
    @property
    def forgetting(self) -> float:
        """
        Average forgetting measure.
        
        For each task j: max accuracy on j - final accuracy on j
        """
        n_tasks = self.accuracy_matrix.shape[1]
        forgetting = []
        
        for j in range(n_tasks - 1):  # Exclude last task (no forgetting possible)
            max_acc = np.max(self.accuracy_matrix[j:, j])
            final_acc = self.accuracy_matrix[-1, j]
            forgetting.append(max_acc - final_acc)
        
        return np.mean(forgetting)
    
    @property
    def backward_transfer(self) -> float:
        """
        How much learning new tasks affects old ones.
        
        Negative = forgetting, Positive = improvement
        """
        n_tasks = self.accuracy_matrix.shape[1]
        bt = []
        
        for j in range(n_tasks - 1):
            # Accuracy right after learning j vs final accuracy
            initial_acc = self.accuracy_matrix[j, j]
            final_acc = self.accuracy_matrix[-1, j]
            bt.append(final_acc - initial_acc)
        
        return np.mean(bt)
    
    @property
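    def forward_transfer(self) -> float:
        """
        How much previous learning helps new tasks.
        
        Returns the mean accuracy on each task j measured just before
        training on it (row j-1). This is the raw zero-shot accuracy; the
        forward-transfer measure of Lopez-Paz & Ranzato (2017) additionally
        subtracts the accuracy of a randomly initialized model.
        
        Requires the entries above the diagonal to be filled in.
        ContinualEvaluator.evaluate_all_tasks only evaluates tasks seen
        so far, so with that evaluator this value stays at zero.
        """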
        n_tasks = self.accuracy_matrix.shape[1]
        ft = []
        
        for j in range(1, n_tasks):
            # Performance on task j before learning it
            before_acc = self.accuracy_matrix[j-1, j]
            ft.append(before_acc)
        
        return np.mean(ft)


class ContinualEvaluator:
    """
    Evaluate continual learning methods.
    """
    
    def __init__(self, task_datasets: List[Tuple]):
        """
        Args:
            task_datasets: List of (train_loader, test_loader) per task
        """
        self.task_datasets = task_datasets
        self.n_tasks = len(task_datasets)
        self.accuracy_matrix = np.zeros((self.n_tasks, self.n_tasks))
    
    def evaluate_all_tasks(self, model: nn.Module, current_task: int):
        """Evaluate model on all tasks seen so far."""
        model.eval()
        
        for task_id in range(current_task + 1):
            _, test_loader = self.task_datasets[task_id]
            
            correct = 0
            total = 0
            
            with torch.no_grad():
                for inputs, labels in test_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs, task_id=task_id)
                    _, predicted = outputs.max(1)
                    
                    correct += (predicted == labels).sum().item()
                    total += labels.size(0)
            
            self.accuracy_matrix[current_task, task_id] = correct / total
    
    def get_metrics(self) -> ContinualMetrics:
        return ContinualMetrics(self.accuracy_matrix.copy())
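
As a worked example of the metrics above, the following sketch recomputes average accuracy, forgetting, and backward transfer directly with NumPy. The accuracy values are made up for illustration; only the formulas mirror `ContinualMetrics`.

```python
import numpy as np

# Hypothetical accuracy matrix for three tasks (illustrative numbers only).
# Row i = after training on task i, column j = accuracy on task j.
acc = np.array([
    [0.95, 0.10, 0.12],
    [0.80, 0.93, 0.15],
    [0.70, 0.85, 0.92],
])
n_tasks = acc.shape[1]

# Mean of the last row: accuracy on every task after all training.
average_accuracy = acc[-1].mean()

# Best accuracy ever reached on task j minus its final accuracy.
forgetting = np.mean(
    [acc[j:, j].max() - acc[-1, j] for j in range(n_tasks - 1)]
)

# Final accuracy on task j minus accuracy right after learning task j.
backward_transfer = np.mean(
    [acc[-1, j] - acc[j, j] for j in range(n_tasks - 1)]
)

print(f"average accuracy:  {average_accuracy:.4f}")
print(f"forgetting:        {forgetting:.4f}")
print(f"backward transfer: {backward_transfer:.4f}")
```

In this matrix accuracy on each old task only ever goes down, so backward transfer is exactly the negative of forgetting. The two differ when a later task temporarily improves an earlier one.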

Regularization-Based Methods

Elastic Weight Consolidation (EWC)

class EWC:
    """
    Elastic Weight Consolidation.
    
    Key idea: Not all weights are equally important for a given task.
    Some weights are critical (the model's output changes dramatically
    if they move), while others are nearly irrelevant. EWC measures
    each weight's importance using the Fisher Information Matrix, then
    adds a penalty during training on new tasks that is proportional
    to how much each weight deviates from its "old task" value, weighted
    by its importance.
    
    Analogy: Imagine you are rearranging furniture (weights) in a room
    to fit a new desk (new task). EWC tells you: "the couch and TV
    (high Fisher) are load-bearing for the old layout -- try not to move
    them. The side table and lamp (low Fisher) can move freely."
    
    Importance measured by Fisher Information Matrix (diagonal approximation).
    
    Reference: "Overcoming catastrophic forgetting in neural networks"
               (Kirkpatrick et al., 2017)
    """
    
    def __init__(
        self,
        model: nn.Module,
        lambda_ewc: float = 5000
    ):
        self.model = model
        self.lambda_ewc = lambda_ewc
        
        # Storage for previous task parameters and importance
        self.saved_params = {}
        self.fisher = {}
    
    def compute_fisher(
        self,
        data_loader,
        n_samples: int = 200
    ):
        """
        Compute Fisher Information Matrix (diagonal approximation).
        
        Fisher[i] = E[gradient[i]^2]
        
        High Fisher = parameter is important for current task
        """
        self.model.eval()
        
        # Initialize Fisher
        fisher = {
            name: torch.zeros_like(param)
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }
        
        n_processed = 0
        
        for inputs, _ in data_loader:
            if n_processed >= n_samples:
                break
            
            inputs = inputs.to(device)
            
            # Process one example at a time: the diagonal Fisher is the mean
            # of squared per-example gradients. Squaring the gradient of a
            # batch-averaged loss would give the square of the mean instead.
            for x in inputs:
                if n_processed >= n_samples:
                    break
                
                self.model.zero_grad()
                output = self.model(x.unsqueeze(0))
                
                # Sample the label from the model's own predictive
                # distribution rather than using the dataset label
                probs = F.softmax(output.detach(), dim=1)
                sampled_label = torch.multinomial(probs, 1).squeeze(1)
                loss = F.cross_entropy(output, sampled_label)
                loss.backward()
                
                # Accumulate squared gradients
                for name, param in self.model.named_parameters():
                    if param.requires_grad and param.grad is not None:
                        fisher[name] += param.grad.detach().pow(2)
                
                n_processed += 1
        
        # Normalize by the number of examples actually used
        for name in fisher:
            fisher[name] /= max(n_processed, 1)
        
        return fisher
    
    def register_task(self, data_loader, task_id: int):
        """
        Register completion of a task.
        
        Saves parameters and computes Fisher information.
        """
        # Compute Fisher for current task
        fisher = self.compute_fisher(data_loader)
        
        # Merge with previous Fisher (online EWC)
        if len(self.fisher) > 0:
            for name in fisher:
                self.fisher[name] = (
                    self.fisher[name] + fisher[name]
                )
        else:
            self.fisher = fisher
        
        # Save current parameters
        self.saved_params[task_id] = {
            name: param.data.clone()
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }
    
    def penalty(self) -> torch.Tensor:
        """
        Compute EWC penalty.
        
        penalty = (lambda / 2) * sum_i F_i * (theta_i - theta_i^*)^2
        
        register_task accumulates the Fisher of all finished tasks into a
        single estimate (online EWC), so the matching anchor theta^* is the
        parameter snapshot taken after the most recent task.
        """
        if len(self.saved_params) == 0:
            return torch.tensor(0.0, device=device)
        
        # Dicts preserve insertion order, so the last entry is the latest task
        latest_params = list(self.saved_params.values())[-1]
        
        loss = torch.tensor(0.0, device=device)
        
        for name, param in self.model.named_parameters():
            if name in self.fisher:
                loss = loss + (
                    self.fisher[name] *
                    (param - latest_params[name]).pow(2)
                ).sum()
        
        return self.lambda_ewc * loss / 2


class EWCTrainer:
    """Training loop with EWC."""
    
    def __init__(
        self,
        model: nn.Module,
        lambda_ewc: float = 5000
    ):
        self.model = model
        self.ewc = EWC(model, lambda_ewc)
    
    def train_task(
        self,
        train_loader,
        task_id: int,
        epochs: int = 10,
        lr: float = 0.001
    ):
        """Train on a single task with EWC regularization."""
        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        
        self.model.train()
        
        for epoch in range(epochs):
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                
                optimizer.zero_grad()
                
                outputs = self.model(inputs)
                
                # Task loss
                task_loss = F.cross_entropy(outputs, labels)
                
                # EWC penalty
                ewc_loss = self.ewc.penalty()
                
                total_loss = task_loss + ewc_loss
                
                total_loss.backward()
                optimizer.step()
        
        # Register task completion
        self.ewc.register_task(train_loader, task_id)
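
The effect of the Fisher weighting is easiest to see on two parameters. The sketch below uses NumPy and made-up values (`fisher`, `theta_star`, and the helper `ewc_penalty` are illustrative, not part of the classes above): moving the high-Fisher parameter by 0.5 costs far more than moving the low-Fisher parameter by the same amount.

```python
import numpy as np


def ewc_penalty(theta, theta_star, fisher, lam):
    """(lam / 2) * sum_i F_i * (theta_i - theta_star_i)^2"""
    return 0.5 * lam * float(np.sum(fisher * (theta - theta_star) ** 2))


theta_star = np.array([1.0, 1.0])   # parameters after the old task
fisher = np.array([10.0, 0.01])     # parameter 0 matters, parameter 1 barely does

# Shift each parameter by the same amount, one at a time.
move_important = ewc_penalty(np.array([1.5, 1.0]), theta_star, fisher, lam=1.0)
move_unimportant = ewc_penalty(np.array([1.0, 1.5]), theta_star, fisher, lam=1.0)

print(f"penalty for moving the important weight:   {move_important:.5f}")
print(f"penalty for moving the unimportant weight: {move_unimportant:.5f}")
```

The ratio between the two penalties equals the ratio of the Fisher values, which is why `lambda_ewc` usually has to be tuned together with the scale of the Fisher estimate.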

Synaptic Intelligence (SI)

class SynapticIntelligence:
    """
    Synaptic Intelligence: Online importance estimation.
    
    Key idea: Track contribution of each parameter to loss reduction
    during training (path integral of gradients).
    
    Reference: "Continual Learning Through Synaptic Intelligence"
               (Zenke et al., 2017)
    """
    
    def __init__(
        self,
        model: nn.Module,
        lambda_si: float = 1.0,
        epsilon: float = 1e-3
    ):
        self.model = model
        self.lambda_si = lambda_si
        self.epsilon = epsilon
        
        # Initialize tracking
        self.omega = {}  # Importance weights
        self.old_params = {}  # Parameters at task start
        self.prev_params = {}  # Parameters before the latest optimizer step
        self.w = {}  # Running importance estimate
        
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.omega[name] = torch.zeros_like(param)
                self.old_params[name] = param.data.clone()
                self.prev_params[name] = param.data.clone()
                self.w[name] = torch.zeros_like(param)
    
    def update_omega(self):
        """Update importance weights after task."""
        for name, param in self.model.named_parameters():
            if name in self.omega:
                delta = param.data - self.old_params[name]
                
                # Omega += path integral / (total parameter change^2 + eps)
                self.omega[name] += self.w[name] / (delta.pow(2) + self.epsilon)
                
                # Reset for next task
                self.old_params[name] = param.data.clone()
                self.prev_params[name] = param.data.clone()
                self.w[name].zero_()
    
    def update_w(self):
        """
        Update running importance during training.
        
        Call after every optimizer.step(), while param.grad still holds
        the gradient that produced the step.
        """
        for name, param in self.model.named_parameters():
            if name in self.w and param.grad is not None:
                # Accumulate -gradient * per-step parameter change: a
                # first-order estimate of how much this step reduced the
                # loss through this parameter. Using the change since the
                # start of the task here would overcount earlier steps.
                step_delta = param.data - self.prev_params[name]
                self.w[name] -= param.grad.data * step_delta
                self.prev_params[name] = param.data.clone()
    
    def penalty(self) -> torch.Tensor:
        """Compute SI penalty."""
        loss = torch.tensor(0.0).to(device)
        
        for name, param in self.model.named_parameters():
            if name in self.omega:
                loss += (
                    self.omega[name] * 
                    (param - self.old_params[name]).pow(2)
                ).sum()
        
        return self.lambda_si * loss
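
The path integral that SI accumulates can be checked on a toy problem. The sketch below is an illustration under assumed values, not part of the class above: it runs plain gradient descent on a hypothetical quadratic loss `0.5 * sum(curvature * w**2)` and sums `-gradient * step` per parameter. With small steps, the total should be close to the actual drop in loss, and the parameter with higher curvature should receive most of the credit.

```python
import numpy as np

curvature = np.array([4.0, 0.1])  # assumed toy loss: 0.5 * sum(curvature * w**2)
w = np.array([1.0, 1.0])
lr = 0.01


def loss(weights: np.ndarray) -> float:
    return 0.5 * float(np.sum(curvature * weights ** 2))


loss_start = loss(w)
path_integral = np.zeros_like(w)

for _ in range(500):
    grad = curvature * w
    step = -lr * grad                # per-step parameter change
    path_integral += -grad * step    # contribution of each parameter to the loss drop
    w = w + step

loss_drop = loss_start - loss(w)

print("per-parameter path integral:", path_integral)
print("sum of path integral:       ", path_integral.sum())
print("actual loss drop:           ", loss_drop)
```

The match is only approximate because the estimate is first order in the step size. It tightens as `lr` shrinks.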

Replay-Based Methods

Replay methods take a completely different approach from regularization: instead of constraining how the model changes, they maintain a small memory of past examples and mix them into training data for new tasks. This is directly inspired by the neuroscience theory of memory consolidation, where the hippocampus “replays” past experiences during sleep to transfer them to long-term cortical storage. The simplest version — experience replay — is often the strongest baseline. Before reaching for sophisticated methods like EWC or PackNet, always compare against replay with a reasonable buffer size. It is embarrassingly effective.
When using experience replay, how you select which examples to store matters more than the buffer size. Random selection is a decent default, but “herding” (storing examples closest to the class mean in feature space) or “k-center coreset” selection (maximizing coverage of the feature space) can improve results by 2-5% with the same buffer size. Also, balancing the buffer equally across classes prevents the model from becoming biased toward recently seen classes.
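
As a rough sketch of the class-balanced, nearest-to-mean selection described above, the following NumPy function keeps, for each class, the examples whose features lie closest to that class's mean. The function name `select_nearest_to_mean` and the toy data are assumptions made for illustration. Note that this is the simple one-shot variant; the herding procedure used in iCaRL picks examples iteratively so that the running mean of the selected set tracks the class mean.

```python
import numpy as np


def select_nearest_to_mean(
    features: np.ndarray, labels: np.ndarray, per_class: int
) -> list:
    """Return indices of the `per_class` examples closest to each class mean."""
    selected = []
    for cls in np.unique(labels):
        cls_idx = np.flatnonzero(labels == cls)
        class_mean = features[cls_idx].mean(axis=0)
        dists = np.linalg.norm(features[cls_idx] - class_mean, axis=1)
        # Stable sort keeps the original order among exact ties.
        nearest = cls_idx[np.argsort(dists, kind="stable")[:per_class]]
        selected.extend(nearest.tolist())
    return selected


# Toy 1-D "features": class 0 has an outlier at 10, class 1 an outlier at 30.
features = np.array([[0.0], [1.0], [2.0], [3.0], [10.0], [20.0], [21.0], [30.0]])
labels = np.array([0, 0, 0, 0, 0, 1, 1, 1])

buffer_indices = select_nearest_to_mean(features, labels, per_class=2)
print(buffer_indices)
```

Because the same number of slots is reserved per class, the resulting buffer is balanced by construction, regardless of how many examples each class contributed to the stream.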

Experience Replay

class ExperienceReplay:
    """
    Store and replay examples from previous tasks.
    
    Simple but surprisingly effective baseline. In many benchmarks,
    experience replay with a buffer of just 500-1000 examples per task
    outperforms more sophisticated regularization methods.
    
    Uses reservoir sampling to maintain a representative subset when
    the total number of observed examples exceeds buffer capacity.
    """
    
    def __init__(
        self,
        buffer_size: int = 5000,
        samples_per_class: Optional[int] = None
    ):
        self.buffer_size = buffer_size
        self.samples_per_class = samples_per_class
        
        self.buffer_x = []
        self.buffer_y = []
        self.buffer_task = []
        self.n_seen = 0  # Total number of examples offered to the buffer
    
    def add(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        task_id: int
    ):
        """Add samples to buffer."""
        for i in range(x.size(0)):
            self.n_seen += 1
            
            if len(self.buffer_x) < self.buffer_size:
                self.buffer_x.append(x[i].cpu())
                self.buffer_y.append(y[i].cpu())
                self.buffer_task.append(task_id)
            else:
                # Reservoir sampling: the n-th example replaces a random
                # slot with probability buffer_size / n, so every example
                # seen so far is equally likely to be in the buffer
                idx = np.random.randint(0, self.n_seen)
                if idx < self.buffer_size:
                    self.buffer_x[idx] = x[i].cpu()
                    self.buffer_y[idx] = y[i].cpu()
                    self.buffer_task[idx] = task_id
    
    def sample(
        self,
        batch_size: int
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """Sample a batch from buffer."""
        if len(self.buffer_x) == 0:
            return None, None, None
        
        indices = np.random.choice(
            len(self.buffer_x),
            min(batch_size, len(self.buffer_x)),
            replace=False
        )
        
        x = torch.stack([self.buffer_x[i] for i in indices])
        y = torch.tensor([self.buffer_y[i] for i in indices])
        tasks = [self.buffer_task[i] for i in indices]
        
        return x.to(device), y.to(device), tasks
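
The reservoir sampling rule used by the buffer can be shown in isolation. This is a minimal standard-library sketch (the function name `reservoir_sample` and the integer stream are illustrative): once the buffer is full, the n-th item is drawn a slot uniformly from `[0, n - 1]` and stored only if that slot falls inside the buffer.

```python
import random


def reservoir_sample(stream, capacity: int, rng: random.Random) -> list:
    """Keep a uniform random subset of `capacity` items from a stream of unknown length."""
    buffer = []
    for n_seen, item in enumerate(stream, start=1):
        if len(buffer) < capacity:
            buffer.append(item)
        else:
            slot = rng.randrange(n_seen)  # uniform in [0, n_seen - 1]
            if slot < capacity:           # happens with probability capacity / n_seen
                buffer[slot] = item
    return buffer


sampled = reservoir_sample(range(10_000), capacity=100, rng=random.Random(0))
print(len(sampled), min(sampled), max(sampled))
```

The key detail is that the random index is drawn against the number of items seen so far, not against the buffer length. Drawing against the buffer length would overwrite a slot on almost every call and bias the buffer toward the most recent task.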


class GradientEpisodicMemory:
    """
    Gradient Episodic Memory (GEM).
    
    Key idea: Project gradients to not increase loss on previous tasks.
    
    Reference: "Gradient Episodic Memory for Continual Learning"
               (Lopez-Paz & Ranzato, 2017)
    """
    
    def __init__(
        self,
        model: nn.Module,
        memory_size: int = 256
    ):
        self.model = model
        self.memory_size = memory_size
        
        self.memory = {}  # {task_id: (x, y)}
        self.memory_gradients = {}
    
    def store_task_memory(
        self,
        data_loader,
        task_id: int
    ):
        """Store examples from completed task."""
        x_mem, y_mem = [], []
        
        for x, y in data_loader:
            x_mem.append(x)
            y_mem.append(y)
            
            if sum(len(batch) for batch in x_mem) >= self.memory_size:
                break
        
        x_mem = torch.cat(x_mem)[:self.memory_size]
        y_mem = torch.cat(y_mem)[:self.memory_size]
        
        self.memory[task_id] = (x_mem.to(device), y_mem.to(device))
    
    def compute_memory_gradients(self):
        """Compute gradients for all memory samples."""
        self.memory_gradients = {}
        
        for task_id, (x, y) in self.memory.items():
            self.model.zero_grad()
            
            outputs = self.model(x)
            loss = F.cross_entropy(outputs, y)
            loss.backward()
            
            # Flatten gradients
            grad = torch.cat([
                p.grad.view(-1) 
                for p in self.model.parameters() 
                if p.grad is not None
            ])
            
            self.memory_gradients[task_id] = grad.clone()
    
    def project_gradient(self, current_grad: torch.Tensor) -> torch.Tensor:
        """
        Project current gradient to satisfy constraints.
        
        Constraint: g · g_memory >= 0 for all memory gradients
        """
        if len(self.memory_gradients) == 0:
            return current_grad
        
        mem_grads = torch.stack(list(self.memory_gradients.values()))
        
        # Check constraint violations
        dotprods = torch.mv(mem_grads, current_grad)
        
        if (dotprods >= 0).all():
            return current_grad
        
        # GEM proper solves a quadratic program over all constraints.
        # Simplified here: project out each violated constraint in turn,
        # which is cheaper but may leave other constraints violated.
        projected = current_grad.clone()
        
        for i, dotprod in enumerate(dotprods):
            if dotprod < 0:
                # Remove component along memory gradient
                mem_grad = mem_grads[i]
                projected -= (dotprod / (mem_grad.norm().pow(2) + 1e-8)) * mem_grad
        
        return projected


class AveragedGEM:
    """
    A-GEM: Efficient version of GEM.
    
    Uses average of memory gradients instead of all constraints.
    """
    
    def __init__(
        self,
        model: nn.Module,
        memory_size: int = 256
    ):
        self.model = model
        self.memory = ExperienceReplay(memory_size)
    
    def project_gradient(self):
        """Project current gradient using A-GEM."""
        if len(self.memory.buffer_x) == 0:
            return
        
        # Get current gradient
        current_grad = torch.cat([
            p.grad.view(-1) 
            for p in self.model.parameters() 
            if p.grad is not None
        ])
        
        # Compute reference gradient on memory
        x_mem, y_mem, _ = self.memory.sample(batch_size=256)
        
        self.model.zero_grad()
        outputs = self.model(x_mem)
        loss = F.cross_entropy(outputs, y_mem)
        loss.backward()
        
        ref_grad = torch.cat([
            p.grad.view(-1) 
            for p in self.model.parameters() 
            if p.grad is not None
        ])
        
        # Check if constraint is violated
        dotprod = torch.dot(current_grad, ref_grad)
        
        if dotprod < 0:
            # Project
            projected = current_grad - (
                dotprod / (ref_grad.norm().pow(2) + 1e-8)
            ) * ref_grad
        else:
            projected = current_grad
        
        # Write the task gradient back in both cases: the backward pass on
        # the memory batch overwrote p.grad with the reference gradient
        offset = 0
        for p in self.model.parameters():
            if p.grad is not None:
                numel = p.numel()
                p.grad.data.copy_(
                    projected[offset:offset + numel].view_as(p)
                )
                offset += numel
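
The A-GEM projection itself is a one-line vector operation. Here is a small NumPy sketch with hand-picked two-dimensional gradients (the values and the helper name `agem_project` are assumptions for illustration): when the task gradient conflicts with the memory gradient, the conflicting component is removed so that the two become orthogonal.

```python
import numpy as np


def agem_project(grad: np.ndarray, ref_grad: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Return grad unchanged if it agrees with ref_grad, otherwise remove the conflict."""
    dot = float(np.dot(grad, ref_grad))
    if dot >= 0:
        return grad.copy()
    return grad - (dot / (float(np.dot(ref_grad, ref_grad)) + eps)) * ref_grad


task_grad = np.array([1.0, 0.0])     # gradient on the current batch
memory_grad = np.array([-1.0, 1.0])  # gradient on the replayed memory batch

projected = agem_project(task_grad, memory_grad)

print("dot before:", float(np.dot(task_grad, memory_grad)))
print("projected: ", projected)
print("dot after: ", float(np.dot(projected, memory_grad)))
```

After projection the update no longer has a component that opposes the memory gradient, which is a first-order guarantee only: with a finite step size the memory loss can still rise slightly.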

Generative Replay

Generative replay is an elegant alternative to storing real examples: instead of maintaining a buffer, train a generative model (VAE, GAN, diffusion model) alongside the classifier. When learning a new task, the generator produces synthetic examples from previous tasks that are mixed into training. The generator itself must also be trained continually, so you are solving two continual learning problems at once — but the payoff is zero storage of real data, which matters in privacy-sensitive domains like healthcare.
Generative replay sounds appealing in theory, but in practice the quality of the generator matters enormously. If the generator produces low-fidelity samples (blurry images, mode collapse), the classifier trained on them will degrade. For simple datasets (MNIST, Fashion-MNIST), generative replay works well. For complex datasets (CIFAR-100, ImageNet), the generator itself becomes the bottleneck. Modern diffusion models have improved this significantly, but the compute cost of maintaining a diffusion model alongside your classifier may exceed the cost of simply storing a replay buffer.
class GenerativeReplay(nn.Module):
    """
    Generate pseudo-examples from previous tasks using a learned generator.
    
    Architecture: VAE (or GAN/diffusion model) that learns to generate
    samples representing the distribution of each task seen so far.
    
    Trade-offs:
    - Pro: No real data storage needed (privacy-friendly)
    - Con: Generator quality limits classifier performance
    - Con: Training two models (generator + classifier) continually is harder
    """
    
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 256,
        latent_dim: int = 32
    ):
        super().__init__()
        
        # VAE for generating replay samples
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)
        
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
        
        self.input_dim = input_dim
        self.latent_dim = latent_dim
    
    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        h = self.encoder(x.view(-1, self.input_dim))
        return self.fc_mu(h), self.fc_var(h)
    
    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)
    
    def generate(self, n_samples: int) -> torch.Tensor:
        """Generate replay samples."""
        z = torch.randn(n_samples, self.latent_dim).to(device)
        return self.decode(z)
    
    def vae_loss(self, x: torch.Tensor) -> torch.Tensor:
        """VAE loss for training generator."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        
        recon_loss = F.mse_loss(recon, x.view(-1, self.input_dim))
        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        
        return recon_loss + 0.1 * kl_loss


class DGRTrainer:
    """
    Deep Generative Replay training.
    
    Train classifier + generator together.
    Generate old task data using generator.
    """
    
    def __init__(
        self,
        classifier: nn.Module,
        generator: GenerativeReplay,
        replay_ratio: float = 0.5
    ):
        self.classifier = classifier
        self.generator = generator
        self.replay_ratio = replay_ratio
        
        # Frozen copies of the previous classifier and generator. The old
        # generator produces replay inputs and the old classifier labels them.
        self.old_classifier = None
        self.old_generator = None
    
    def train_task(
        self,
        train_loader,
        task_id: int,
        epochs: int = 10,
        lr: float = 0.001
    ):
        """Train on task with generative replay."""
        
        clf_optimizer = optim.Adam(self.classifier.parameters(), lr=lr)
        gen_optimizer = optim.Adam(self.generator.parameters(), lr=lr)
        
        for epoch in range(epochs):
            self.classifier.train()
            self.generator.train()
            
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                batch_size = x.size(0)
                x = x.view(batch_size, -1)
                
                # Generate replay samples from the previous generator.
                # Sampling from the generator that is being trained would
                # drift toward the current task and stop covering old ones.
                if task_id > 0 and self.old_generator is not None:
                    n_replay = int(batch_size * self.replay_ratio)
                    
                    if n_replay > 0:
                        with torch.no_grad():
                            x_replay = self.old_generator.generate(n_replay)
                            # Pseudo-labels from old classifier
                            y_replay = self.old_classifier(x_replay).argmax(dim=1)
                        
                        # Combine current and replay data
                        x = torch.cat([x, x_replay])
                        y = torch.cat([y, y_replay])
                
                # Train classifier
                clf_optimizer.zero_grad()
                outputs = self.classifier(x)
                clf_loss = F.cross_entropy(outputs, y)
                clf_loss.backward()
                clf_optimizer.step()
                
                # Train generator on current and replayed data so that it
                # keeps modeling earlier tasks as well
                gen_optimizer.zero_grad()
                gen_loss = self.generator.vae_loss(x)
                gen_loss.backward()
                gen_optimizer.step()
        
        # Freeze copies of both models for the next task
        self.old_classifier = deepcopy(self.classifier)
        self.old_classifier.eval()
        self.old_generator = deepcopy(self.generator)
        self.old_generator.eval()
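
The KL term in `vae_loss` has a closed form that is easy to sanity-check by hand. The sketch below evaluates it per latent dimension with NumPy for a few hand-picked values of `mu` and `logvar` (the helper name `kl_to_standard_normal` is illustrative). It is zero exactly when the encoder output matches the standard normal prior that `generate` samples from.

```python
import numpy as np


def kl_to_standard_normal(mu: float, logvar: float) -> float:
    """KL( N(mu, exp(logvar)) || N(0, 1) ) for a single latent dimension."""
    return float(-0.5 * (1.0 + logvar - mu ** 2 - np.exp(logvar)))


kl_match = kl_to_standard_normal(mu=0.0, logvar=0.0)           # posterior equals prior
kl_shifted = kl_to_standard_normal(mu=1.0, logvar=0.0)         # mean moved by one
kl_wide = kl_to_standard_normal(mu=0.0, logvar=np.log(4.0))    # variance of 4

print(f"matching prior: {kl_match:.4f}")
print(f"shifted mean:   {kl_shifted:.4f}")
print(f"wide variance:  {kl_wide:.4f}")
```

`vae_loss` above averages this quantity over the batch and latent dimensions and scales it by 0.1, so the weight of the KL term relative to the reconstruction loss is a tunable choice rather than a fixed part of the method.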

Architecture-Based Methods

Architecture-based methods take the most direct approach to preventing forgetting: give each task its own parameters. The old parameters are literally frozen — they cannot be modified, so they cannot be forgotten. The challenge shifts from “how to avoid forgetting” to “how to share knowledge across tasks without running out of capacity.” Think of it like a library: regularization methods try to write new books without erasing old ones (hard). Replay methods keep photocopies of old books (memory cost). Architecture methods add new bookshelves for new topics, while allowing readers to reference old shelves (capacity cost). Each approach trades a different resource.

Progressive Neural Networks

class ProgressiveNeuralNetwork(nn.Module):
    """
    Progressive Neural Networks.
    
    Key idea: Add new columns (subnetworks) for new tasks,
    connect laterally to previous columns so new tasks can
    leverage features learned by old tasks.
    
    Zero forgetting (old columns are frozen -- literally impossible
    to forget because old weights never change).
    
    Trade-off: Model size grows with every task. Each new column adds
    its own weights plus lateral connections to all earlier columns, so
    the parameter count grows faster than linearly in the number of tasks.
    
    Reference: "Progressive Neural Networks" (Rusu et al., 2016)
    """
    
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int
    ):
        super().__init__()
        
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        
        # Columns for each task
        self.columns = nn.ModuleList()
        
        # Lateral connections
        self.lateral = nn.ModuleList()
    
    def add_column(self):
        """Add a new column for a new task."""
        n_cols = len(self.columns)
        
        # New column
        column = nn.ModuleDict({
            'layer1': nn.Linear(self.input_dim, self.hidden_dim),
            'layer2': nn.Linear(self.hidden_dim, self.hidden_dim),
            'head': nn.Linear(self.hidden_dim, self.output_dim)
        })
        
        self.columns.append(column)
        
        # Lateral connections from previous columns
        if n_cols > 0:
            lateral = nn.ModuleDict({
                'lateral1': nn.Linear(n_cols * self.hidden_dim, self.hidden_dim),
                'lateral2': nn.Linear(n_cols * self.hidden_dim, self.hidden_dim)
            })
            self.lateral.append(lateral)
        
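        # Freeze all previous columns and their lateral connections.
        # self.lateral[k - 1] feeds column k, so every entry except the
        # last belongs to an older column and must stay fixed as well.
        for module in list(self.columns[:-1]) + list(self.lateral[:-1]):
            for param in module.parameters():
                param.requires_grad = False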
    
    def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        """Forward through column for specific task."""
        
        # Collect all hidden activations
        h1_all = []
        h2_all = []
        
        for col_id in range(task_id + 1):
            col = self.columns[col_id]
            
            # First layer
            h1 = F.relu(col['layer1'](x))
            
            # Add lateral connections
            if col_id > 0:
                lateral_input = torch.cat(h1_all, dim=1)
                h1 = h1 + self.lateral[col_id - 1]['lateral1'](lateral_input)
            
            h1_all.append(h1)
            
            # Second layer
            h2 = F.relu(col['layer2'](h1))
            
            if col_id > 0:
                lateral_input = torch.cat(h2_all, dim=1)
                h2 = h2 + self.lateral[col_id - 1]['lateral2'](lateral_input)
            
            h2_all.append(h2)
        
        # Output from target task column
        output = self.columns[task_id]['head'](h2_all[task_id])
        
        return output
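
To get a feel for the capacity cost, the following sketch counts the parameters of the layer layout used in `add_column` above (three linear layers per column, two lateral layers per non-first column). The dimensions 784/256/10 are assumed example values, and `pnn_param_count` is an illustrative helper, not part of the class.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


def pnn_param_count(n_columns: int, input_dim: int, hidden_dim: int, output_dim: int) -> int:
    """Total parameters after adding `n_columns` columns."""
    column = (
        linear_params(input_dim, hidden_dim)
        + linear_params(hidden_dim, hidden_dim)
        + linear_params(hidden_dim, output_dim)
    )
    total = 0
    for n_previous in range(n_columns):
        total += column
        if n_previous > 0:
            # Two lateral layers, each reading from all earlier columns.
            total += 2 * linear_params(n_previous * hidden_dim, hidden_dim)
    return total


one_task = pnn_param_count(1, input_dim=784, hidden_dim=256, output_dim=10)
two_tasks = pnn_param_count(2, input_dim=784, hidden_dim=256, output_dim=10)
ten_tasks = pnn_param_count(10, input_dim=784, hidden_dim=256, output_dim=10)

print(one_task, two_tasks, ten_tasks)
print(f"growth after 10 tasks: {ten_tasks / one_task:.1f}x")
```

Because each lateral layer widens with the number of earlier columns, ten tasks cost noticeably more than ten times a single column in this layout.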


class PackNet:
    """
    PackNet: Prune and reuse for continual learning.
    
    Key idea: 
    1. Train on task
    2. Prune the least important of the still-free weights
    3. Freeze the weights that survive pruning (they now belong to this task)
    4. Retrain the pruned, freed weights on the next task
    
    Reference: "PackNet: Adding Multiple Tasks to a Single Network"
               (Mallya & Lazebnik, 2018)
    """
    
    def __init__(
        self,
        model: nn.Module,
        prune_ratio: float = 0.75
    ):
        self.model = model
        self.prune_ratio = prune_ratio
        
        # Track which weights are used by which tasks
        self.masks = {}  # {task_id: {param_name: mask}}
        self.available_mask = {}  # Which weights are still available
        
        # Initialize available mask
        for name, param in model.named_parameters():
            self.available_mask[name] = torch.ones_like(param, dtype=torch.bool)
    
    def compute_importance(self) -> Dict[str, torch.Tensor]:
        """Compute weight importance (magnitude-based)."""
        importance = {}
        
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                importance[name] = param.data.abs()
        
        return importance
    
    def prune_for_task(self, task_id: int):
        """Prune weights after training on task."""
        importance = self.compute_importance()
        
        task_mask = {}
        
        for name, imp in importance.items():
            available = self.available_mask[name]
            
            # Only rank weights that no earlier task has claimed
            imp_available = imp[available]
            
            # Keep the top (1 - prune_ratio) fraction of the available weights
            n_keep = int((1 - self.prune_ratio) * imp_available.numel())
            if n_keep > 0:
                # kthvalue returns the k-th smallest value, so this k selects
                # the smallest importance among the n_keep largest
                k = imp_available.numel() - n_keep + 1
                threshold = torch.kthvalue(imp_available, k).values
                
                # Keep available weights at or above the threshold
                mask = (imp >= threshold) & available
            else:
                mask = torch.zeros_like(available)
            
            task_mask[name] = mask
            
            # Update available weights
            self.available_mask[name] = available & ~mask
        
        self.masks[task_id] = task_mask
    
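    def apply_mask(self, task_id: int):
        """Zero every weight not owned by this task or an earlier one.
        
        This modifies the weights in place, so save a copy of the model's
        state_dict first if later tasks still need to be evaluated.
        """
        for name, param in self.model.named_parameters():
            if name not in self.masks[task_id]:
                continue
            
            # A task uses its own weights plus the frozen weights of earlier tasks
            keep = torch.zeros_like(param, dtype=torch.bool)
            for t, task_mask in self.masks.items():
                if t <= task_id and name in task_mask:
                    keep |= task_mask[name]
            
            param.data *= keep.float()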
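
The masks record which weights each task owns, but the frozen weights also have to be protected while the next task trains. One way to do this is to zero their gradients before each optimizer step. The helper below, `mask_frozen_gradients`, is a hypothetical name and not part of the class above. It is a minimal sketch that assumes a dictionary of boolean masks marking weights claimed by earlier tasks, and plain SGD without momentum or weight decay (either of which can still move a weight whose gradient is zero).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict

torch.manual_seed(0)


def mask_frozen_gradients(model: nn.Module, frozen: Dict[str, torch.Tensor]) -> None:
    """Zero the gradient of every weight claimed by an earlier task."""
    for name, param in model.named_parameters():
        if name in frozen and param.grad is not None:
            param.grad[frozen[name]] = 0.0


model = nn.Linear(4, 2)

# Assumption for the demo: an earlier task kept the larger-magnitude weights
w = model.weight.detach().abs()
frozen = {"weight": w >= w.flatten().median()}
before = model.weight.detach().clone()

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))

optimizer.zero_grad()
F.cross_entropy(model(x), y).backward()
mask_frozen_gradients(model, frozen)  # call between backward() and step()
optimizer.step()

after = model.weight.detach()
frozen_unchanged = torch.equal(after[frozen["weight"]], before[frozen["weight"]])
free_changed = not torch.equal(after[~frozen["weight"]], before[~frozen["weight"]])
print(frozen_unchanged, free_changed)
```

Only the weights outside the frozen mask move, which is what lets earlier tasks keep their exact behaviour.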

Advanced Methods

Dark Experience Replay

class DarkExperienceReplay:
    """
    Dark Experience Replay (DER).
    
    Store logits along with inputs for better replay.
    Combines ideas from knowledge distillation with replay.
    
    Reference: "Dark Experience for General Continual Learning"
               (Buzzega et al., 2020)
    """
    
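    def __init__(
        self,
        buffer_size: int = 5000,
        alpha: float = 0.5
    ):
        self.buffer_size = buffer_size
        self.alpha = alpha
        
        self.buffer_x = []
        self.buffer_y = []
        self.buffer_logits = []
        self.n_seen = 0  # Total number of samples offered to the buffer
    
    def add(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        logits: torch.Tensor
    ):
        """Add samples with their logits using reservoir sampling."""
        for i in range(x.size(0)):
            self.n_seen += 1
            if len(self.buffer_x) < self.buffer_size:
                self.buffer_x.append(x[i].detach().cpu())
                self.buffer_y.append(y[i].detach().cpu())
                self.buffer_logits.append(logits[i].detach().cpu())
            else:
                # Reservoir sampling: store the new sample with probability
                # buffer_size / n_seen, so every sample seen so far is
                # equally likely to be in the buffer
                idx = torch.randint(0, self.n_seen, (1,)).item()
                if idx < self.buffer_size:
                    self.buffer_x[idx] = x[i].detach().cpu()
                    self.buffer_y[idx] = y[i].detach().cpu()
                    self.buffer_logits[idx] = logits[i].detach().cpu()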
    
    def compute_loss(
        self,
        model: nn.Module,
        batch_size: int
    ) -> torch.Tensor:
        """Compute DER loss on buffer samples."""
        # Run on whatever device the model lives on
        device = next(model.parameters()).device
        
        if len(self.buffer_x) == 0:
            return torch.tensor(0.0, device=device)
        
        # Sample a replay batch without replacement
        n_samples = min(batch_size, len(self.buffer_x))
        indices = torch.randperm(len(self.buffer_x))[:n_samples].tolist()
        
        x = torch.stack([self.buffer_x[i] for i in indices]).to(device)
        y = torch.stack([self.buffer_y[i] for i in indices]).to(device)
        old_logits = torch.stack([self.buffer_logits[i] for i in indices]).to(device)
        
        # Current predictions
        new_logits = model(x)
        
        # Combined loss: fit the stored labels and stay close to the stored logits
        ce_loss = F.cross_entropy(new_logits, y)
        mse_loss = F.mse_loss(new_logits, old_logits)
        
        return self.alpha * ce_loss + (1 - self.alpha) * mse_loss
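
The logit term is what separates this approach from plain replay: it penalizes any drift of the current outputs away from the outputs recorded when the sample was stored. Here is a minimal sketch of that behaviour. The toy linear model and the hand-made weight perturbation (a stand-in for training on a new task) are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

model = nn.Linear(4, 3)
x = torch.randn(8, 4)

# Logits recorded at the time the samples would enter the buffer
with torch.no_grad():
    stored_logits = model(x)

# Before any further training the model still matches its stored outputs
drift_before = F.mse_loss(model(x), stored_logits).item()

# Stand-in for training on a new task: shift every weight
with torch.no_grad():
    model.weight.add_(0.5)

# The logit term now measures how far the outputs have moved
drift_after = F.mse_loss(model(x), stored_logits).item()
print(drift_before, drift_after)
```

Minimizing this term pulls the network back toward its earlier responses on buffered inputs, including the relative scores of the non-target classes that a hard label does not carry.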


class OnlineMetaLearning:
    """
    Online Meta-Learning for Continual Learning.
    
    Use meta-learning to adapt quickly to new tasks
    while maintaining performance on old ones.
    """
    
    def __init__(
        self,
        model: nn.Module,
        inner_lr: float = 0.01,
        outer_lr: float = 0.001
    ):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        
        self.outer_optimizer = optim.Adam(model.parameters(), lr=outer_lr)
        self.memory = ExperienceReplay(buffer_size=1000)
    
    def train_step(
        self,
        x_current: torch.Tensor,
        y_current: torch.Tensor
    ):
        """MAML-style training step."""
        
        # Inner loop: adapt a copy of the model to the current batch
        adapted_model = deepcopy(self.model)
        
        for _ in range(5):  # Inner steps
            outputs = adapted_model(x_current)
            loss = F.cross_entropy(outputs, y_current)
            
            grads = torch.autograd.grad(
                loss, list(adapted_model.parameters())
            )
            
            with torch.no_grad():
                for param, grad in zip(adapted_model.parameters(), grads):
                    param -= self.inner_lr * grad
        
        # Outer loop: update on both current and memory
        self.outer_optimizer.zero_grad()
        
        # Loss on current task (using adapted model)
        current_loss = F.cross_entropy(
            adapted_model(x_current), y_current
        )
        # The adapted model is a separate copy, so its gradients are computed
        # explicitly and handed to the original parameters below
        adapted_grads = torch.autograd.grad(
            current_loss, list(adapted_model.parameters())
        )
        
        # Loss on memory (using original model)
        x_mem, y_mem, _ = self.memory.sample(x_current.size(0))
        if x_mem is not None:
            memory_loss = F.cross_entropy(self.model(x_mem), y_mem)
            memory_loss.backward()
        
        # First-order approximation: treat the gradient at the adapted
        # weights as the gradient for the original weights
        for param, grad in zip(self.model.parameters(), adapted_grads):
            param.grad = grad if param.grad is None else param.grad + grad
        
        self.outer_optimizer.step()
        
        # Update memory
        self.memory.add(x_current, y_current, 0)
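
One detail is worth checking in any MAML-style step that adapts a deep copy of the model: a loss computed on the copy sends no gradient to the original parameters. The sketch below demonstrates this on a toy linear model (an assumption for illustration) and shows the first-order workaround of reusing the copy's gradients for the original.

```python
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

model = nn.Linear(4, 2)
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))

# A deep copy owns its own parameters, disconnected from the original
adapted = deepcopy(model)
F.cross_entropy(adapted(x), y).backward()

original_has_grad = any(p.grad is not None for p in model.parameters())
adapted_has_grad = all(p.grad is not None for p in adapted.parameters())

# First-order approximation: reuse the copy's gradients for the original
for p, p_adapted in zip(model.parameters(), adapted.parameters()):
    p.grad = p_adapted.grad.clone()

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
before = model.weight.detach().clone()
optimizer.step()
moved = not torch.equal(before, model.weight.detach())
print(original_has_grad, adapted_has_grad, moved)
```

Full second-order MAML would instead keep the inner updates inside the autograd graph, which costs more memory and compute.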

Best Practices

def continual_learning_summary():
    """Print continual learning methods summary."""
    
    summary = """
    ╔════════════════════════════════════════════════════════════════════╗
    ║              CONTINUAL LEARNING METHODS SUMMARY                     ║
    ╠════════════════════════════════════════════════════════════════════╣
    ║                                                                     ║
    ║  REGULARIZATION-BASED                                               ║
    ║  • EWC: Penalize changes to important weights                       ║
    ║  • SI: Online importance estimation                                 ║
    ║  + No memory storage needed                                         ║
    ║  - May not scale to many tasks                                      ║
    ║                                                                     ║
    ║  REPLAY-BASED                                                       ║
    ║  • Experience Replay: Store raw examples                            ║
    ║  • GEM/A-GEM: Constrain gradients                                   ║
    ║  • Generative Replay: Generate pseudo-examples                      ║
    ║  + Very effective                                                   ║
    ║  - Memory or compute overhead                                       ║
    ║                                                                     ║
    ║  ARCHITECTURE-BASED                                                 ║
    ║  • Progressive Nets: Add new columns                                ║
    ║  • PackNet: Prune and reuse                                         ║
    ║  + No forgetting (frozen weights)                                   ║
    ║  - Model growth or fixed capacity                                   ║
    ║                                                                     ║
    ╠════════════════════════════════════════════════════════════════════╣
    ║                     RECOMMENDATIONS                                 ║
    ╠════════════════════════════════════════════════════════════════════╣
    ║                                                                     ║
    ║  Few tasks, limited memory: EWC or SI                               ║
    ║  Many tasks, some memory: Experience Replay + regularization        ║
    ║  Critical no-forgetting: Architecture-based methods                 ║
    ║  Large models: Generative replay or knowledge distillation          ║
    ║                                                                     ║
    ║  GENERAL TIPS:                                                      ║
    ║  • Always evaluate on ALL previous tasks                            ║
    ║  • Measure both accuracy and forgetting                             ║
    ║  • Consider task boundaries (known vs unknown)                      ║
    ║  • Test with different task orderings                               ║
    ║                                                                     ║
    ╚════════════════════════════════════════════════════════════════════╝
    """
    print(summary)

continual_learning_summary()
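
The tips on evaluating all previous tasks and measuring forgetting can be made concrete with an accuracy matrix. The sketch below assumes a matrix `acc` where `acc[i, j]` is the accuracy on task `j` after training on task `i`. The function name `continual_metrics` and the example numbers are made up for illustration. Forgetting is taken as the gap between the best accuracy a task reached before the final stage and its final accuracy, averaged over all tasks except the last.

```python
import torch
from typing import Tuple


def continual_metrics(acc: torch.Tensor) -> Tuple[float, float]:
    """Return (average final accuracy, average forgetting)."""
    n_tasks = acc.size(0)
    final = acc[-1]  # accuracy on every task after the last training stage
    average_accuracy = final.mean().item()

    if n_tasks < 2:
        return average_accuracy, 0.0

    # Best accuracy each earlier task reached before the final stage
    best_earlier = acc[:-1, :-1].max(dim=0).values
    forgetting = (best_earlier - final[:-1]).mean().item()
    return average_accuracy, forgetting


# Made-up accuracy matrix for three tasks (rows: after training task i)
acc = torch.tensor([
    [0.95, 0.00, 0.00],
    [0.80, 0.90, 0.00],
    [0.70, 0.85, 0.92],
])

avg_acc, forgetting = continual_metrics(acc)
print(f"average accuracy: {avg_acc:.3f}, forgetting: {forgetting:.3f}")
```

For this matrix the average final accuracy is about 0.82 and the average forgetting is 0.15: task 0 dropped from 0.95 to 0.70 and task 1 from 0.90 to 0.85. Reporting both numbers matters because a method can score a high average while still losing most of an early task.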

Exercises

Implement and compare on Split-MNIST:
  • Fine-tuning (baseline)
  • EWC
  • Experience Replay
  • A-GEM
Plot the accuracy matrix and compute all metrics.

Test EWC on different task orderings:
  • Easy to hard
  • Hard to easy
  • Random
How does the order affect final performance?

Combine EWC with Experience Replay:
  • What’s the best balance between the EWC penalty and the replay loss?
  • Does it beat either method alone?
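
As a starting point for the Split-MNIST exercises, the tasks can be built by grouping sample indices by class. The helper below is a sketch: `split_by_class` is a hypothetical name, and the synthetic label tensor stands in for real MNIST labels (with torchvision's MNIST dataset these are available as `dataset.targets`).

```python
import torch
from typing import List


def split_by_class(labels: torch.Tensor, classes_per_task: int = 2) -> List[torch.Tensor]:
    """Return, for each task, the indices of the samples whose label belongs to it."""
    classes = torch.unique(labels)  # sorted unique class ids
    tasks = []
    for start in range(0, len(classes), classes_per_task):
        task_classes = classes[start:start + classes_per_task]
        in_task = torch.isin(labels, task_classes)
        tasks.append(torch.nonzero(in_task, as_tuple=True)[0])
    return tasks


# Synthetic stand-in for MNIST labels: 5 samples of each digit 0-9
labels = torch.arange(10).repeat(5)
task_indices = split_by_class(labels, classes_per_task=2)

for task_id, idx in enumerate(task_indices):
    print(task_id, sorted(labels[idx].unique().tolist()), len(idx))
```

Each index tensor can be passed to `torch.utils.data.Subset` to obtain a per-task dataset. Permuting the order of the returned list gives the different task orderings asked for in the second exercise.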

What’s Next?

Model Compression: Quantization and pruning techniques

Capstone Project: Apply all techniques in a complete project