Continual Learning

The Catastrophic Forgetting Problem

When a neural network is trained on a sequence of tasks, learning each new task tends to overwrite the weights that encoded earlier ones, a phenomenon known as catastrophic forgetting:

| Scenario | Challenge |
|----------|-----------|
| Sequential tasks | Model forgets task A when learning task B |
| Streaming data | Data distribution shifts over time |
| Class-incremental | New classes are added continuously |
| Domain adaptation | Same task, but a new domain |
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import Optional, Dict, List, Tuple
from dataclasses import dataclass
from copy import deepcopy
import numpy as np

torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Measuring Forgetting

@dataclass
class ContinualMetrics:
    """
    Metrics for continual learning evaluation.
    """
    accuracy_matrix: np.ndarray  # A[i,j] = acc on task j after learning task i
    
    @property
    def average_accuracy(self) -> float:
        """Average accuracy across all tasks after all training."""
        return np.mean(self.accuracy_matrix[-1])
    
    @property
    def forgetting(self) -> float:
        """
        Average forgetting measure.
        
        For each task j: max accuracy on j - final accuracy on j
        """
        n_tasks = self.accuracy_matrix.shape[1]
        forgetting = []
        
        for j in range(n_tasks - 1):  # Exclude last task (no forgetting possible)
            # Max over checkpoints before the final one
            max_acc = np.max(self.accuracy_matrix[j:-1, j])
            final_acc = self.accuracy_matrix[-1, j]
            forgetting.append(max_acc - final_acc)
        
        return np.mean(forgetting)
    
    @property
    def backward_transfer(self) -> float:
        """
        How much learning new tasks affects old ones.
        
        Negative = forgetting, Positive = improvement
        """
        n_tasks = self.accuracy_matrix.shape[1]
        bt = []
        
        for j in range(n_tasks - 1):
            # Accuracy right after learning j vs final accuracy
            initial_acc = self.accuracy_matrix[j, j]
            final_acc = self.accuracy_matrix[-1, j]
            bt.append(final_acc - initial_acc)
        
        return np.mean(bt)
    
    @property
    def forward_transfer(self) -> float:
        """
        How much previous learning helps new tasks.
        """
        n_tasks = self.accuracy_matrix.shape[1]
        ft = []
        
        for j in range(1, n_tasks):
            # Performance on task j before learning it
            before_acc = self.accuracy_matrix[j-1, j]
            ft.append(before_acc)
        
        return np.mean(ft)


class ContinualEvaluator:
    """
    Evaluate continual learning methods.
    """
    
    def __init__(self, task_datasets: List[Tuple]):
        """
        Args:
            task_datasets: List of (train_loader, test_loader) per task
        """
        self.task_datasets = task_datasets
        self.n_tasks = len(task_datasets)
        self.accuracy_matrix = np.zeros((self.n_tasks, self.n_tasks))
    
    def evaluate_all_tasks(self, model: nn.Module, current_task: int):
        """Evaluate model on all tasks seen so far."""
        model.eval()
        
        for task_id in range(current_task + 1):
            _, test_loader = self.task_datasets[task_id]
            
            correct = 0
            total = 0
            
            with torch.no_grad():
                for inputs, labels in test_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs, task_id=task_id)
                    _, predicted = outputs.max(1)
                    
                    correct += (predicted == labels).sum().item()
                    total += labels.size(0)
            
            self.accuracy_matrix[current_task, task_id] = correct / total
    
    def get_metrics(self) -> ContinualMetrics:
        return ContinualMetrics(self.accuracy_matrix.copy())
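
To see how these metrics behave, here is a minimal sanity check on a hypothetical accuracy matrix for three tasks (the numbers below are made up for illustration):

# Row i = accuracies on tasks 0..2 after finishing task i.
# Zeros mark tasks not evaluated yet, so forward_transfer is
# not meaningful for this matrix and is skipped here.
acc = np.array([
    [0.95, 0.00, 0.00],
    [0.80, 0.92, 0.00],
    [0.70, 0.85, 0.90],
])
metrics = ContinualMetrics(acc)
print(f"Average accuracy:  {metrics.average_accuracy:.3f}")   # 0.817
print(f"Forgetting:        {metrics.forgetting:.3f}")         # (0.25 + 0.07) / 2 = 0.160
print(f"Backward transfer: {metrics.backward_transfer:.3f}")  # -0.160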

Regularization-Based Methods

Elastic Weight Consolidation (EWC)

class EWC:
    """
    Elastic Weight Consolidation.
    
    Key idea: Penalize changes to weights that are important
    for previous tasks.
    
    Importance measured by Fisher Information Matrix.
    
    Reference: "Overcoming catastrophic forgetting in neural networks"
               (Kirkpatrick et al., 2017)
    """
    
    def __init__(
        self,
        model: nn.Module,
        lambda_ewc: float = 5000
    ):
        self.model = model
        self.lambda_ewc = lambda_ewc
        
        # Storage for previous task parameters and importance
        self.saved_params = {}
        self.fisher = {}
    
    def compute_fisher(
        self,
        data_loader,
        n_samples: int = 200
    ):
        """
        Compute Fisher Information Matrix (diagonal approximation).
        
        Fisher[i] = E[gradient[i]^2]
        
        High Fisher = parameter is important for current task
        """
        self.model.eval()
        
        # Initialize Fisher
        fisher = {
            name: torch.zeros_like(param)
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }
        
        n_processed = 0
        
        for inputs, _ in data_loader:
            if n_processed >= n_samples:
                break
            
            inputs = inputs.to(device)
            
            self.model.zero_grad()
            outputs = self.model(inputs)
            
            # Sample labels from the model's own output distribution:
            # this yields the true Fisher (using ground-truth labels
            # would give the empirical Fisher instead)
            probs = F.softmax(outputs, dim=1)
            sampled_labels = torch.multinomial(probs, 1).squeeze(-1)
            loss = F.cross_entropy(outputs, sampled_labels)
            loss.backward()
            
            # Accumulate squared gradients
            for name, param in self.model.named_parameters():
                if param.requires_grad and param.grad is not None:
                    fisher[name] += param.grad.data.pow(2)
            
            n_processed += inputs.size(0)
        
        # Normalize
        for name in fisher:
            fisher[name] /= n_processed
        
        return fisher
    
    def register_task(self, data_loader, task_id: int):
        """
        Register completion of a task.
        
        Saves parameters and computes Fisher information.
        """
        # Compute Fisher for current task
        fisher = self.compute_fisher(data_loader)
        
        # Merge with previous Fisher (online EWC)
        if len(self.fisher) > 0:
            for name in fisher:
                self.fisher[name] = (
                    self.fisher[name] + fisher[name]
                )
        else:
            self.fisher = fisher
        
        # Save current parameters
        self.saved_params[task_id] = {
            name: param.data.clone()
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }
    
    def penalty(self) -> torch.Tensor:
        """
        Compute EWC penalty.
        
        penalty = sum_i F_i * (theta_i - theta_i^*)^2
        """
        if len(self.saved_params) == 0:
            return torch.tensor(0.0).to(device)
        
        loss = torch.tensor(0.0).to(device)
        
        for name, param in self.model.named_parameters():
            if name in self.fisher:
                # Sum over all previous tasks (Fisher was merged online above)
                for task_id in self.saved_params:
                    old_param = self.saved_params[task_id][name]
                    
                    loss += (
                        self.fisher[name] * 
                        (param - old_param).pow(2)
                    ).sum()
        
        return self.lambda_ewc * loss / 2


class EWCTrainer:
    """Training loop with EWC."""
    
    def __init__(
        self,
        model: nn.Module,
        lambda_ewc: float = 5000
    ):
        self.model = model
        self.ewc = EWC(model, lambda_ewc)
    
    def train_task(
        self,
        train_loader,
        task_id: int,
        epochs: int = 10,
        lr: float = 0.001
    ):
        """Train on a single task with EWC regularization."""
        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        
        self.model.train()
        
        for epoch in range(epochs):
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                
                optimizer.zero_grad()
                
                outputs = self.model(inputs)
                
                # Task loss
                task_loss = F.cross_entropy(outputs, labels)
                
                # EWC penalty
                ewc_loss = self.ewc.penalty()
                
                total_loss = task_loss + ewc_loss
                
                total_loss.backward()
                optimizer.step()
        
        # Register task completion
        self.ewc.register_task(train_loader, task_id)
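
Putting the pieces together, a hypothetical end-to-end loop might look like the sketch below. It assumes `model` and `task_datasets` (a list of (train_loader, test_loader) pairs) are defined elsewhere, and that the model accepts the task_id keyword used by ContinualEvaluator:

trainer = EWCTrainer(model, lambda_ewc=5000)
evaluator = ContinualEvaluator(task_datasets)

for task_id, (train_loader, _) in enumerate(task_datasets):
    trainer.train_task(train_loader, task_id, epochs=10)
    evaluator.evaluate_all_tasks(trainer.model, current_task=task_id)

metrics = evaluator.get_metrics()
print(f"Average accuracy: {metrics.average_accuracy:.3f}")
print(f"Forgetting:       {metrics.forgetting:.3f}")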

Synaptic Intelligence (SI)

class SynapticIntelligence:
    """
    Synaptic Intelligence: Online importance estimation.
    
    Key idea: Track contribution of each parameter to loss reduction
    during training (path integral of gradients).
    
    Reference: "Continual Learning Through Synaptic Intelligence"
               (Zenke et al., 2017)
    """
    
    def __init__(
        self,
        model: nn.Module,
        lambda_si: float = 1.0,
        epsilon: float = 1e-3
    ):
        self.model = model
        self.lambda_si = lambda_si
        self.epsilon = epsilon
        
        # Initialize tracking
        self.omega = {}  # Importance weights (accumulated across tasks)
        self.old_params = {}  # Parameters at task start
        self.prev_params = {}  # Parameters before the most recent step
        self.w = {}  # Running importance estimate (path integral)
        
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.omega[name] = torch.zeros_like(param)
                self.old_params[name] = param.data.clone()
                self.prev_params[name] = param.data.clone()
                self.w[name] = torch.zeros_like(param)
    
    def update_omega(self):
        """Update importance weights after finishing a task."""
        for name, param in self.model.named_parameters():
            if name in self.omega:
                delta = param.data - self.old_params[name]
                
                # omega += accumulated path integral / (total change^2 + eps)
                self.omega[name] += self.w[name] / (delta.pow(2) + self.epsilon)
                
                # Reset for next task
                self.old_params[name] = param.data.clone()
                self.prev_params[name] = param.data.clone()
                self.w[name].zero_()
    
    def update_w(self):
        """Update running importance during training.
        
        Call after every optimizer step: the gradient buffers still hold
        the gradient used for that step, so -grad * per-step change
        approximates the path integral of the loss along the trajectory.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                # Per-step change, not total change since task start
                delta = param.data - self.prev_params[name]
                self.w[name] -= param.grad.data * delta
                self.prev_params[name] = param.data.clone()
    
    def penalty(self) -> torch.Tensor:
        """Compute SI penalty."""
        loss = torch.tensor(0.0).to(device)
        
        for name, param in self.model.named_parameters():
            if name in self.omega:
                loss += (
                    self.omega[name] * 
                    (param - self.old_params[name]).pow(2)
                ).sum()
        
        return self.lambda_si * loss
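
SI is sensitive to where its hooks are called. A minimal sketch of the intended loop, assuming `model`, `optimizer`, and `train_loader` are defined elsewhere:

si = SynapticIntelligence(model, lambda_si=1.0)

for x, y in train_loader:
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y) + si.penalty()
    loss.backward()
    optimizer.step()
    si.update_w()    # after every optimizer step

si.update_omega()    # once, at the end of each task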

Replay-Based Methods

Experience Replay

class ExperienceReplay:
    """
    Store and replay examples from previous tasks.
    
    Simple but effective baseline.
    """
    
    def __init__(
        self,
        buffer_size: int = 5000,
        samples_per_class: Optional[int] = None
    ):
        self.buffer_size = buffer_size
        self.samples_per_class = samples_per_class
        self.n_seen = 0  # Total samples seen (needed for reservoir sampling)
        
        self.buffer_x = []
        self.buffer_y = []
        self.buffer_task = []
    
    def add(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        task_id: int
    ):
        """Add samples to buffer via reservoir sampling."""
        for i in range(x.size(0)):
            self.n_seen += 1
            if len(self.buffer_x) < self.buffer_size:
                self.buffer_x.append(x[i].cpu())
                self.buffer_y.append(y[i].cpu())
                self.buffer_task.append(task_id)
            else:
                # Reservoir sampling: replace with prob. buffer_size / n_seen
                idx = np.random.randint(0, self.n_seen)
                if idx < self.buffer_size:
                    self.buffer_x[idx] = x[i].cpu()
                    self.buffer_y[idx] = y[i].cpu()
                    self.buffer_task[idx] = task_id
    
    def sample(
        self,
        batch_size: int
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """Sample a batch from buffer."""
        if len(self.buffer_x) == 0:
            return None, None, None
        
        indices = np.random.choice(
            len(self.buffer_x),
            min(batch_size, len(self.buffer_x)),
            replace=False
        )
        
        x = torch.stack([self.buffer_x[i] for i in indices])
        y = torch.stack([self.buffer_y[i] for i in indices])
        tasks = [self.buffer_task[i] for i in indices]
        
        return x.to(device), y.to(device), tasks
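
In a training loop, the buffer is typically interleaved with each current batch. A minimal sketch of one step, assuming `model` and `optimizer` exist:

buffer = ExperienceReplay(buffer_size=5000)

def replay_step(x, y, task_id):
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y)  # current-task loss
    
    x_mem, y_mem, _ = buffer.sample(batch_size=x.size(0))
    if x_mem is not None:
        loss = loss + F.cross_entropy(model(x_mem), y_mem)  # replay loss
    
    loss.backward()
    optimizer.step()
    buffer.add(x, y, task_id)  # update buffer with current batch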


class GradientEpisodicMemory:
    """
    Gradient Episodic Memory (GEM).
    
    Key idea: Project gradients to not increase loss on previous tasks.
    
    Reference: "Gradient Episodic Memory for Continual Learning"
               (Lopez-Paz & Ranzato, 2017)
    """
    
    def __init__(
        self,
        model: nn.Module,
        memory_size: int = 256
    ):
        self.model = model
        self.memory_size = memory_size
        
        self.memory = {}  # {task_id: (x, y)}
        self.memory_gradients = {}
    
    def store_task_memory(
        self,
        data_loader,
        task_id: int
    ):
        """Store examples from completed task."""
        x_mem, y_mem = [], []
        
        for x, y in data_loader:
            x_mem.append(x)
            y_mem.append(y)
            
            if sum(len(batch) for batch in x_mem) >= self.memory_size:
                break
        
        x_mem = torch.cat(x_mem)[:self.memory_size]
        y_mem = torch.cat(y_mem)[:self.memory_size]
        
        self.memory[task_id] = (x_mem.to(device), y_mem.to(device))
    
    def compute_memory_gradients(self):
        """Compute gradients for all memory samples."""
        self.memory_gradients = {}
        
        for task_id, (x, y) in self.memory.items():
            self.model.zero_grad()
            
            outputs = self.model(x)
            loss = F.cross_entropy(outputs, y)
            loss.backward()
            
            # Flatten gradients
            grad = torch.cat([
                p.grad.view(-1) 
                for p in self.model.parameters() 
                if p.grad is not None
            ])
            
            self.memory_gradients[task_id] = grad.clone()
    
    def project_gradient(self, current_grad: torch.Tensor) -> torch.Tensor:
        """
        Project current gradient to satisfy constraints.
        
        Constraint: g · g_memory >= 0 for all memory gradients
        """
        if len(self.memory_gradients) == 0:
            return current_grad
        
        mem_grads = torch.stack(list(self.memory_gradients.values()))
        
        # Check constraint violations
        dotprods = torch.mv(mem_grads, current_grad)
        
        if (dotprods >= 0).all():
            return current_grad
        
        # Project using quadratic programming
        # Simplified: project onto each violated constraint
        projected = current_grad.clone()
        
        for i, dotprod in enumerate(dotprods):
            if dotprod < 0:
                # Remove component along memory gradient
                mem_grad = mem_grads[i]
                projected -= (dotprod / (mem_grad.norm().pow(2) + 1e-8)) * mem_grad
        
        return projected
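
The class above only produces a projected flat gradient. A sketch of wiring it into a training step, with a hypothetical gem_step helper (`model` and `optimizer` assumed):

def gem_step(gem, model, optimizer, x, y):
    # Memory gradients first (this call overwrites the model's .grad buffers)
    gem.compute_memory_gradients()
    
    # Then the current-task gradient
    model.zero_grad()
    F.cross_entropy(model(x), y).backward()
    grad = torch.cat([
        p.grad.view(-1) for p in model.parameters() if p.grad is not None
    ])
    
    # Project and scatter back into the parameters before stepping
    projected = gem.project_gradient(grad)
    offset = 0
    for p in model.parameters():
        if p.grad is not None:
            n = p.numel()
            p.grad.data.copy_(projected[offset:offset + n].view_as(p))
            offset += n
    optimizer.step()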


class AveragedGEM:
    """
    A-GEM: Efficient version of GEM.
    
    Uses average of memory gradients instead of all constraints.
    """
    
    def __init__(
        self,
        model: nn.Module,
        memory_size: int = 256
    ):
        self.model = model
        self.memory = ExperienceReplay(memory_size)
    
    def project_gradient(self):
        """Project current gradient using A-GEM."""
        if len(self.memory.buffer_x) == 0:
            return
        
        # Get current gradient
        current_grad = torch.cat([
            p.grad.view(-1) 
            for p in self.model.parameters() 
            if p.grad is not None
        ])
        
        # Compute reference gradient on memory
        x_mem, y_mem, _ = self.memory.sample(batch_size=256)
        
        self.model.zero_grad()
        outputs = self.model(x_mem)
        loss = F.cross_entropy(outputs, y_mem)
        loss.backward()
        
        ref_grad = torch.cat([
            p.grad.view(-1) 
            for p in self.model.parameters() 
            if p.grad is not None
        ])
        
        # Check if the constraint is violated
        dotprod = torch.dot(current_grad, ref_grad)
        
        if dotprod < 0:
            # Project onto the half-space where g · g_ref >= 0
            new_grad = current_grad - (
                dotprod / (ref_grad.norm().pow(2) + 1e-8)
            ) * ref_grad
        else:
            new_grad = current_grad
        
        # Write the gradient back into the model (the backward pass on
        # memory above overwrote p.grad, so always restore it)
        offset = 0
        for p in self.model.parameters():
            if p.grad is not None:
                numel = p.numel()
                p.grad.data.copy_(
                    new_grad[offset:offset + numel].view_as(p)
                )
                offset += numel

Generative Replay

class GenerativeReplay(nn.Module):
    """
    Generate pseudo-examples from previous tasks.
    
    Uses a generative model (VAE, GAN) to create samples
    that represent previous task distributions.
    
    No memory storage needed!
    """
    
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 256,
        latent_dim: int = 32
    ):
        super().__init__()
        
        # VAE for generating replay samples
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)
        
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
        
        self.input_dim = input_dim
        self.latent_dim = latent_dim
    
    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        h = self.encoder(x.view(-1, self.input_dim))
        return self.fc_mu(h), self.fc_var(h)
    
    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)
    
    def generate(self, n_samples: int) -> torch.Tensor:
        """Generate replay samples."""
        z = torch.randn(n_samples, self.latent_dim).to(device)
        return self.decode(z)
    
    def vae_loss(self, x: torch.Tensor) -> torch.Tensor:
        """VAE loss for training generator."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        
        recon_loss = F.mse_loss(recon, x.view(-1, self.input_dim))
        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        
        return recon_loss + 0.1 * kl_loss


class DGRTrainer:
    """
    Deep Generative Replay training.
    
    Train classifier + generator together.
    Generate old task data using generator.
    """
    
    def __init__(
        self,
        classifier: nn.Module,
        generator: GenerativeReplay,
        replay_ratio: float = 0.5
    ):
        self.classifier = classifier
        self.generator = generator
        self.replay_ratio = replay_ratio
        
        # Store old classifier for pseudo-labeling
        self.old_classifier = None
    
    def train_task(
        self,
        train_loader,
        task_id: int,
        epochs: int = 10,
        lr: float = 0.001
    ):
        """Train on task with generative replay."""
        
        clf_optimizer = optim.Adam(self.classifier.parameters(), lr=lr)
        gen_optimizer = optim.Adam(self.generator.parameters(), lr=lr)
        
        for epoch in range(epochs):
            self.classifier.train()
            self.generator.train()
            
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                batch_size = x.size(0)
                
                # Generate replay samples
                if task_id > 0 and self.old_classifier is not None:
                    n_replay = int(batch_size * self.replay_ratio)
                    
                    with torch.no_grad():
                        x_replay = self.generator.generate(n_replay)
                        # Pseudo-labels from old classifier
                        y_replay = self.old_classifier(x_replay).argmax(dim=1)
                    
                    # Combine current and replay data
                    x = torch.cat([x.view(batch_size, -1), x_replay])
                    y = torch.cat([y, y_replay])
                
                # Train classifier
                clf_optimizer.zero_grad()
                outputs = self.classifier(x.view(x.size(0), -1))
                clf_loss = F.cross_entropy(outputs, y)
                clf_loss.backward()
                clf_optimizer.step()
                
                # Train generator on current-task data (full DGR would also
                # replay generated samples into the generator; omitted here)
                gen_optimizer.zero_grad()
                gen_loss = self.generator.vae_loss(x[:batch_size])
                gen_loss.backward()
                gen_optimizer.step()
        
        # Store classifier for next task
        self.old_classifier = deepcopy(self.classifier)
        self.old_classifier.eval()

Architecture-Based Methods

Progressive Neural Networks

class ProgressiveNeuralNetwork(nn.Module):
    """
    Progressive Neural Networks.
    
    Key idea: Add new columns for new tasks,
    connect laterally to previous columns.
    
    Zero forgetting (old columns frozen).
    
    Reference: "Progressive Neural Networks" (Rusu et al., 2016)
    """
    
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int
    ):
        super().__init__()
        
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        
        # Columns for each task
        self.columns = nn.ModuleList()
        
        # Lateral connections
        self.lateral = nn.ModuleList()
    
    def add_column(self):
        """Add a new column for a new task."""
        n_cols = len(self.columns)
        
        # New column
        column = nn.ModuleDict({
            'layer1': nn.Linear(self.input_dim, self.hidden_dim),
            'layer2': nn.Linear(self.hidden_dim, self.hidden_dim),
            'head': nn.Linear(self.hidden_dim, self.output_dim)
        })
        
        self.columns.append(column)
        
        # Lateral connections from previous columns
        if n_cols > 0:
            lateral = nn.ModuleDict({
                'lateral1': nn.Linear(n_cols * self.hidden_dim, self.hidden_dim),
                'lateral2': nn.Linear(n_cols * self.hidden_dim, self.hidden_dim)
            })
            self.lateral.append(lateral)
        
        # Freeze all previous columns and their lateral connections
        for col in self.columns[:-1]:
            for param in col.parameters():
                param.requires_grad = False
        for lat in self.lateral[:-1]:
            for param in lat.parameters():
                param.requires_grad = False
    
    def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        """Forward through column for specific task."""
        
        # Collect all hidden activations
        h1_all = []
        h2_all = []
        
        for col_id in range(task_id + 1):
            col = self.columns[col_id]
            
            # First layer
            h1 = F.relu(col['layer1'](x))
            
            # Add lateral connections
            if col_id > 0:
                lateral_input = torch.cat(h1_all, dim=1)
                h1 = h1 + self.lateral[col_id - 1]['lateral1'](lateral_input)
            
            h1_all.append(h1)
            
            # Second layer
            h2 = F.relu(col['layer2'](h1))
            
            if col_id > 0:
                lateral_input = torch.cat(h2_all, dim=1)
                h2 = h2 + self.lateral[col_id - 1]['lateral2'](lateral_input)
            
            h2_all.append(h2)
        
        # Output from target task column
        output = self.columns[task_id]['head'](h2_all[task_id])
        
        return output
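
Usage follows an add-then-train pattern. A sketch, assuming n_tasks tasks and omitting the data loading and loss computation:

pnn = ProgressiveNeuralNetwork(input_dim=784, hidden_dim=128, output_dim=10)

for task_id in range(n_tasks):
    pnn.add_column()
    pnn.to(device)  # move the newly added column to the device
    
    # Only the newest column (and its laterals) is trainable
    trainable = [p for p in pnn.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable, lr=1e-3)
    # ... train with outputs = pnn(x, task_id=task_id) ...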


class PackNet:
    """
    PackNet: Prune and reuse for continual learning.
    
    Key idea: 
    1. Train on task
    2. Prune unimportant weights
    3. Freeze pruned network
    4. Retrain remaining weights for new task
    
    Reference: "PackNet: Adding Multiple Tasks to a Single Network"
               (Mallya & Lazebnik, 2018)
    """
    
    def __init__(
        self,
        model: nn.Module,
        prune_ratio: float = 0.75
    ):
        self.model = model
        self.prune_ratio = prune_ratio
        
        # Track which weights are used by which tasks
        self.masks = {}  # {task_id: {param_name: mask}}
        self.available_mask = {}  # Which weights are still available
        
        # Initialize available mask
        for name, param in model.named_parameters():
            self.available_mask[name] = torch.ones_like(param, dtype=torch.bool)
    
    def compute_importance(self) -> Dict[str, torch.Tensor]:
        """Compute weight importance (magnitude-based)."""
        importance = {}
        
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                importance[name] = param.data.abs()
        
        return importance
    
    def prune_for_task(self, task_id: int):
        """Prune weights after training on task."""
        importance = self.compute_importance()
        
        task_mask = {}
        
        for name, imp in importance.items():
            available = self.available_mask[name]
            
            # Only consider available weights
            imp_available = imp * available.float()
            
            # Decide how many still-available weights to keep for this task
            n_available = int(available.sum().item())
            n_keep = int((1 - self.prune_ratio) * n_available)
            
            if 0 < n_keep < n_available:
                # Threshold = k-th smallest available importance, so that
                # roughly n_keep weights lie at or above it
                k = n_available - n_keep
                threshold = torch.kthvalue(imp[available], k)[0]
                mask = (imp_available >= threshold) & available
            elif n_keep >= n_available:
                mask = available.clone()
            else:
                mask = torch.zeros_like(available)
            
            task_mask[name] = mask
            
            # Update available weights
            self.available_mask[name] = available & ~mask
        
        self.masks[task_id] = task_mask
    
    def apply_mask(self, task_id: int):
        """Zero weights that do not belong to tasks <= task_id.
        
        A task reuses the weights of all earlier tasks plus its own.
        Note: this modifies parameters in place, so apply it to a copy
        of the model (or restore from a checkpoint) when switching tasks.
        """
        for name, param in self.model.named_parameters():
            if name in self.masks[task_id]:
                combined = torch.zeros_like(param, dtype=torch.bool)
                for t in range(task_id + 1):
                    if t in self.masks:
                        combined |= self.masks[t][name]
                param.data *= combined.float()
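
One piece PackNet needs that is not shown above: while training a new task, gradients on weights already claimed by earlier tasks must be zeroed so those weights stay frozen. A sketch of such a step (`model` and `optimizer` assumed):

packnet = PackNet(model, prune_ratio=0.75)

def packnet_step(x, y):
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    
    # Freeze weights owned by earlier tasks by masking their gradients
    for name, param in model.named_parameters():
        if param.grad is not None and name in packnet.available_mask:
            param.grad *= packnet.available_mask[name].float()
    
    optimizer.step()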

Advanced Methods

Dark Experience Replay

class DarkExperienceReplay:
    """
    Dark Experience Replay (DER).
    
    Store logits along with inputs for better replay.
    Combines ideas from knowledge distillation with replay.
    
    Reference: "Dark Experience for General Continual Learning"
               (Buzzega et al., 2020)
    """
    
    def __init__(
        self,
        buffer_size: int = 5000,
        alpha: float = 0.5
    ):
        self.buffer_size = buffer_size
        self.alpha = alpha
        
        self.buffer_x = []
        self.buffer_y = []
        self.buffer_logits = []
    
    def add(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        logits: torch.Tensor
    ):
        """Add samples with their logits."""
        for i in range(x.size(0)):
            if len(self.buffer_x) < self.buffer_size:
                self.buffer_x.append(x[i].cpu())
                self.buffer_y.append(y[i].cpu())
                self.buffer_logits.append(logits[i].detach().cpu())
            else:
                idx = np.random.randint(0, len(self.buffer_x))
                self.buffer_x[idx] = x[i].cpu()
                self.buffer_y[idx] = y[i].cpu()
                self.buffer_logits[idx] = logits[i].detach().cpu()
    
    def compute_loss(
        self,
        model: nn.Module,
        batch_size: int
    ) -> torch.Tensor:
        """Compute DER loss on buffer samples."""
        if len(self.buffer_x) == 0:
            return torch.tensor(0.0).to(device)
        
        indices = np.random.choice(
            len(self.buffer_x),
            min(batch_size, len(self.buffer_x)),
            replace=False
        )
        
        x = torch.stack([self.buffer_x[i] for i in indices]).to(device)
        y = torch.stack([self.buffer_y[i] for i in indices]).to(device)
        old_logits = torch.stack([self.buffer_logits[i] for i in indices]).to(device)
        
        # Current predictions
        new_logits = model(x)
        
        # Combined buffer loss (a DER++-style mix of label cross-entropy
        # and logit distillation)
        ce_loss = F.cross_entropy(new_logits, y)
        mse_loss = F.mse_loss(new_logits, old_logits)
        
        return self.alpha * ce_loss + (1 - self.alpha) * mse_loss
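
In use, the buffer term is simply added to the stream loss, and the current logits are stored for future distillation. A sketch of one step (`model` and `optimizer` assumed):

der = DarkExperienceReplay(buffer_size=5000, alpha=0.5)

def der_step(x, y):
    optimizer.zero_grad()
    logits = model(x)
    loss = F.cross_entropy(logits, y) + der.compute_loss(model, batch_size=x.size(0))
    loss.backward()
    optimizer.step()
    der.add(x, y, logits)  # store logits before the model drifts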


class OnlineMetaLearning:
    """
    Online Meta-Learning for Continual Learning.
    
    Use meta-learning to adapt quickly to new tasks
    while maintaining performance on old ones.
    """
    
    def __init__(
        self,
        model: nn.Module,
        inner_lr: float = 0.01,
        outer_lr: float = 0.001
    ):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        
        self.outer_optimizer = optim.Adam(model.parameters(), lr=outer_lr)
        self.memory = ExperienceReplay(buffer_size=1000)
    
    def train_step(
        self,
        x_current: torch.Tensor,
        y_current: torch.Tensor
    ):
        """MAML-style training step."""
        
        # Inner loop: adapt to current batch
        adapted_model = deepcopy(self.model)
        
        for _ in range(5):  # Inner steps
            outputs = adapted_model(x_current)
            loss = F.cross_entropy(outputs, y_current)
            
            grads = torch.autograd.grad(
                loss, adapted_model.parameters()
            )
            
            for param, grad in zip(adapted_model.parameters(), grads):
                param.data -= self.inner_lr * grad
        
        # Outer loop: update on both current and memory
        self.outer_optimizer.zero_grad()
        
        # Loss on current task, evaluated with the adapted copy.
        # The copy's parameters are detached from self.model, so we take
        # a first-order approximation: compute the copy's gradients and
        # accumulate them onto the original parameters below.
        current_loss = F.cross_entropy(
            adapted_model(x_current), y_current
        )
        adapted_grads = torch.autograd.grad(
            current_loss, adapted_model.parameters()
        )
        
        # Loss on memory (using original model)
        x_mem, y_mem, _ = self.memory.sample(x_current.size(0))
        if x_mem is not None:
            memory_loss = F.cross_entropy(self.model(x_mem), y_mem)
            memory_loss.backward()
        
        # Accumulate the first-order meta-gradients onto the original model
        for param, grad in zip(self.model.parameters(), adapted_grads):
            if param.grad is None:
                param.grad = grad.clone()
            else:
                param.grad += grad
        
        self.outer_optimizer.step()
        
        # Update memory
        self.memory.add(x_current, y_current, 0)

Best Practices

def continual_learning_summary():
    """Print continual learning methods summary."""
    
    summary = """
    ╔════════════════════════════════════════════════════════════════════╗
    ║              CONTINUAL LEARNING METHODS SUMMARY                     ║
    ╠════════════════════════════════════════════════════════════════════╣
    ║                                                                     ║
    ║  REGULARIZATION-BASED                                               ║
    ║  • EWC: Penalize changes to important weights                       ║
    ║  • SI: Online importance estimation                                 ║
    ║  + No memory storage needed                                         ║
    ║  - May not scale to many tasks                                      ║
    ║                                                                     ║
    ║  REPLAY-BASED                                                       ║
    ║  • Experience Replay: Store raw examples                            ║
    ║  • GEM/A-GEM: Constrain gradients                                   ║
    ║  • Generative Replay: Generate pseudo-examples                      ║
    ║  + Very effective                                                   ║
    ║  - Memory or compute overhead                                       ║
    ║                                                                     ║
    ║  ARCHITECTURE-BASED                                                 ║
    ║  • Progressive Nets: Add new columns                                ║
    ║  • PackNet: Prune and reuse                                         ║
    ║  + No forgetting (frozen weights)                                   ║
    ║  - Model grows with tasks                                           ║
    ║                                                                     ║
    ╠════════════════════════════════════════════════════════════════════╣
    ║                     RECOMMENDATIONS                                 ║
    ╠════════════════════════════════════════════════════════════════════╣
    ║                                                                     ║
    ║  Few tasks, limited memory: EWC or SI                               ║
    ║  Many tasks, some memory: Experience Replay + regularization        ║
    ║  Critical no-forgetting: Architecture-based methods                 ║
    ║  Large models: Generative replay or knowledge distillation          ║
    ║                                                                     ║
    ║  GENERAL TIPS:                                                      ║
    ║  • Always evaluate on ALL previous tasks                            ║
    ║  • Measure both accuracy and forgetting                             ║
    ║  • Consider task boundaries (known vs unknown)                      ║
    ║  • Test with different task orderings                               ║
    ║                                                                     ║
    ╚════════════════════════════════════════════════════════════════════╝
    """
    print(summary)

continual_learning_summary()

Exercises

1. Implement and compare on Split-MNIST:
   • Fine-tuning (baseline)
   • EWC
   • Experience Replay
   • A-GEM
   Plot the accuracy matrix and compute all the metrics above.

2. Test EWC on different task orderings:
   • Easy to hard
   • Hard to easy
   • Random
   How does the order affect final performance?

3. Combine EWC with Experience Replay:
   • What’s the optimal combination?
   • Does it beat either method alone?

What’s Next?