Continual Learning
The Catastrophic Forgetting Problem
When neural networks learn new tasks, they tend to forget old ones:

| Scenario | Challenge |
|---|---|
| Sequential tasks | Model forgets task A when learning task B |
| Streaming data | Distribution shift over time |
| Class-incremental | New classes added continuously |
| Domain adaptation | Same task, new domain |
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import Optional, Dict, List, Tuple, Callable
from dataclasses import dataclass
from copy import deepcopy
import numpy as np
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Measuring Forgetting
@dataclass
class ContinualMetrics:
"""
Metrics for continual learning evaluation.
"""
accuracy_matrix: np.ndarray # A[i,j] = acc on task j after learning task i
@property
def average_accuracy(self) -> float:
"""Average accuracy across all tasks after all training."""
return np.mean(self.accuracy_matrix[-1])
@property
def forgetting(self) -> float:
"""
Average forgetting measure.
For each task j: max accuracy on j - final accuracy on j
"""
n_tasks = self.accuracy_matrix.shape[1]
forgetting = []
for j in range(n_tasks - 1): # Exclude last task (no forgetting possible)
max_acc = np.max(self.accuracy_matrix[j:, j])
final_acc = self.accuracy_matrix[-1, j]
forgetting.append(max_acc - final_acc)
return np.mean(forgetting)
@property
def backward_transfer(self) -> float:
"""
How much learning new tasks affects old ones.
Negative = forgetting, Positive = improvement
"""
n_tasks = self.accuracy_matrix.shape[1]
bt = []
for j in range(n_tasks - 1):
# Accuracy right after learning j vs final accuracy
initial_acc = self.accuracy_matrix[j, j]
final_acc = self.accuracy_matrix[-1, j]
bt.append(final_acc - initial_acc)
return np.mean(bt)
@property
def forward_transfer(self) -> float:
"""
How much previous learning helps new tasks.
"""
n_tasks = self.accuracy_matrix.shape[1]
ft = []
for j in range(1, n_tasks):
# Performance on task j before learning it
before_acc = self.accuracy_matrix[j-1, j]
ft.append(before_acc)
return np.mean(ft)
class ContinualEvaluator:
"""
Evaluate continual learning methods.
"""
def __init__(self, task_datasets: List[Tuple]):
"""
Args:
task_datasets: List of (train_loader, test_loader) per task
"""
self.task_datasets = task_datasets
self.n_tasks = len(task_datasets)
self.accuracy_matrix = np.zeros((self.n_tasks, self.n_tasks))
def evaluate_all_tasks(self, model: nn.Module, current_task: int):
"""Evaluate model on all tasks seen so far."""
model.eval()
for task_id in range(current_task + 1):
_, test_loader = self.task_datasets[task_id]
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs, task_id=task_id)
_, predicted = outputs.max(1)
correct += (predicted == labels).sum().item()
total += labels.size(0)
self.accuracy_matrix[current_task, task_id] = correct / total
def get_metrics(self) -> ContinualMetrics:
return ContinualMetrics(self.accuracy_matrix.copy())
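A minimal sketch of how the evaluator ties these pieces together, assuming a model whose forward accepts a task_id keyword (as evaluate_all_tasks expects) and a per-task training routine train_on_task supplied elsewhere; both names are placeholders, not defined in this section:

def run_continual_benchmark(model, task_datasets, train_on_task):
    """Hypothetical driver: train task by task, fill one row of the accuracy
    matrix after each task, then report the summary metrics."""
    evaluator = ContinualEvaluator(task_datasets)
    for task_id, (train_loader, _) in enumerate(task_datasets):
        train_on_task(model, train_loader, task_id)
        evaluator.evaluate_all_tasks(model, task_id)
    metrics = evaluator.get_metrics()
    print(f"Average accuracy:  {metrics.average_accuracy:.3f}")
    print(f"Forgetting:        {metrics.forgetting:.3f}")
    print(f"Backward transfer: {metrics.backward_transfer:.3f}")
    return metrics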
Regularization-Based Methods
Elastic Weight Consolidation (EWC)
class EWC:
"""
Elastic Weight Consolidation.
Key idea: Penalize changes to weights that are important
for previous tasks.
Importance measured by Fisher Information Matrix.
Reference: "Overcoming catastrophic forgetting in neural networks"
(Kirkpatrick et al., 2017)
"""
def __init__(
self,
model: nn.Module,
lambda_ewc: float = 5000
):
self.model = model
self.lambda_ewc = lambda_ewc
# Storage for previous task parameters and importance
self.saved_params = {}
self.fisher = {}
def compute_fisher(
self,
data_loader,
n_samples: int = 200
):
"""
Compute Fisher Information Matrix (diagonal approximation).
Fisher[i] = E[gradient[i]^2]
High Fisher = parameter is important for current task
"""
self.model.eval()
# Initialize Fisher
fisher = {
name: torch.zeros_like(param)
for name, param in self.model.named_parameters()
if param.requires_grad
}
n_processed = 0
for inputs, labels in data_loader:
if n_processed >= n_samples:
break
inputs, labels = inputs.to(device), labels.to(device)
self.model.zero_grad()
outputs = self.model(inputs)
            # Fisher approximation: sample labels from the model's own output
            # distribution and use the negative log-likelihood of those samples
            probs = F.softmax(outputs, dim=1)
            sampled_labels = torch.multinomial(probs, 1).squeeze(1)
            loss = F.cross_entropy(outputs, sampled_labels)
loss.backward()
# Accumulate squared gradients
for name, param in self.model.named_parameters():
if param.requires_grad and param.grad is not None:
fisher[name] += param.grad.data.pow(2)
n_processed += inputs.size(0)
# Normalize
for name in fisher:
fisher[name] /= n_processed
return fisher
def register_task(self, data_loader, task_id: int):
"""
Register completion of a task.
Saves parameters and computes Fisher information.
"""
# Compute Fisher for current task
fisher = self.compute_fisher(data_loader)
# Merge with previous Fisher (online EWC)
if len(self.fisher) > 0:
for name in fisher:
self.fisher[name] = (
self.fisher[name] + fisher[name]
)
else:
self.fisher = fisher
# Save current parameters
self.saved_params[task_id] = {
name: param.data.clone()
for name, param in self.model.named_parameters()
if param.requires_grad
}
def penalty(self) -> torch.Tensor:
"""
Compute EWC penalty.
penalty = sum_i F_i * (theta_i - theta_i^*)^2
"""
if len(self.saved_params) == 0:
return torch.tensor(0.0).to(device)
        loss = torch.tensor(0.0).to(device)
        # Online EWC: anchor at the parameters saved after the most recent
        # task, weighted by the accumulated Fisher information
        latest_task = max(self.saved_params.keys())
        for name, param in self.model.named_parameters():
            if name in self.fisher:
                old_param = self.saved_params[latest_task][name]
                loss += (
                    self.fisher[name] *
                    (param - old_param).pow(2)
                ).sum()
return self.lambda_ewc * loss / 2
class EWCTrainer:
"""Training loop with EWC."""
def __init__(
self,
model: nn.Module,
lambda_ewc: float = 5000
):
self.model = model
self.ewc = EWC(model, lambda_ewc)
def train_task(
self,
train_loader,
task_id: int,
epochs: int = 10,
lr: float = 0.001
):
"""Train on a single task with EWC regularization."""
optimizer = optim.Adam(self.model.parameters(), lr=lr)
self.model.train()
for epoch in range(epochs):
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = self.model(inputs)
# Task loss
task_loss = F.cross_entropy(outputs, labels)
# EWC penalty
ewc_loss = self.ewc.penalty()
total_loss = task_loss + ewc_loss
total_loss.backward()
optimizer.step()
# Register task completion
self.ewc.register_task(train_loader, task_id)
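A short usage sketch for the trainer above; model and task_datasets (a list of (train_loader, test_loader) pairs) are assumed to be defined elsewhere:

def train_sequence_with_ewc(model, task_datasets, lambda_ewc=5000, epochs=5):
    """Hypothetical task sequence: train each task with the EWC penalty
    anchored at the tasks already learned, then consolidate."""
    trainer = EWCTrainer(model, lambda_ewc=lambda_ewc)
    for task_id, (train_loader, _) in enumerate(task_datasets):
        trainer.train_task(train_loader, task_id, epochs=epochs)
    return trainer

In practice lambda_ewc needs tuning per architecture and task sequence; larger values preserve old tasks at the cost of plasticity on new ones.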
Synaptic Intelligence (SI)
class SynapticIntelligence:
"""
Synaptic Intelligence: Online importance estimation.
Key idea: Track contribution of each parameter to loss reduction
during training (path integral of gradients).
Reference: "Continual Learning Through Synaptic Intelligence"
(Zenke et al., 2017)
"""
def __init__(
self,
model: nn.Module,
lambda_si: float = 1.0,
epsilon: float = 1e-3
):
self.model = model
self.lambda_si = lambda_si
self.epsilon = epsilon
        # Initialize tracking
        self.omega = {}        # Accumulated importance weights
        self.old_params = {}   # Parameters at task start (penalty anchor)
        self.prev_params = {}  # Parameters after the previous optimizer step
        self.w = {}            # Running path-integral estimate
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.omega[name] = torch.zeros_like(param)
                self.old_params[name] = param.data.clone()
                self.prev_params[name] = param.data.clone()
                self.w[name] = torch.zeros_like(param)
    def update_omega(self):
        """Consolidate importance weights at the end of a task."""
        for name, param in self.model.named_parameters():
            if name in self.omega:
                delta = param.data - self.old_params[name]
                # omega += path integral / (total parameter change^2 + eps)
                self.omega[name] += self.w[name] / (delta.pow(2) + self.epsilon)
                # Reset anchors and running estimates for the next task
                self.old_params[name] = param.data.clone()
                self.prev_params[name] = param.data.clone()
                self.w[name].zero_()
    def update_w(self):
        """Accumulate the path integral; call after every optimizer step."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                # -grad * (parameter change over this step): positive when the
                # step moved the parameter in a loss-reducing direction
                step_delta = param.data - self.prev_params[name]
                self.w[name] -= param.grad.data * step_delta
                self.prev_params[name] = param.data.clone()
def penalty(self) -> torch.Tensor:
"""Compute SI penalty."""
loss = torch.tensor(0.0).to(device)
for name, param in self.model.named_parameters():
if name in self.omega:
loss += (
self.omega[name] *
(param - self.old_params[name]).pow(2)
).sum()
return self.lambda_si * loss
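SynapticIntelligence only defines the bookkeeping, so the call sites matter: update_w() runs after every optimizer step (while that step's gradients are still stored), penalty() joins the task loss, and update_omega() consolidates importance once the task ends. A hedged per-task loop sketch, assuming model and train_loader exist and si = SynapticIntelligence(model) was created once before the first task:

def train_task_with_si(model, si, train_loader, epochs=5, lr=1e-3):
    """Hypothetical SI training loop for a single task."""
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()
    for _ in range(epochs):
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = F.cross_entropy(model(inputs), labels) + si.penalty()
            loss.backward()
            optimizer.step()
            si.update_w()   # accumulate path integral with this step's gradients
    si.update_omega()       # consolidate importance at the end of the task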
Replay-Based Methods
Experience Replay
class ExperienceReplay:
"""
Store and replay examples from previous tasks.
Simple but effective baseline.
"""
def __init__(
self,
buffer_size: int = 5000,
samples_per_class: Optional[int] = None
):
self.buffer_size = buffer_size
self.samples_per_class = samples_per_class
        self.buffer_x = []
        self.buffer_y = []
        self.buffer_task = []
        self.n_seen = 0  # Total samples seen so far, needed for reservoir sampling
def add(
self,
x: torch.Tensor,
y: torch.Tensor,
task_id: int
):
"""Add samples to buffer."""
for i in range(x.size(0)):
if len(self.buffer_x) < self.buffer_size:
self.buffer_x.append(x[i].cpu())
self.buffer_y.append(y[i].cpu())
self.buffer_task.append(task_id)
else:
# Reservoir sampling
idx = np.random.randint(0, len(self.buffer_x) + 1)
if idx < self.buffer_size:
self.buffer_x[idx] = x[i].cpu()
self.buffer_y[idx] = y[i].cpu()
self.buffer_task[idx] = task_id
def sample(
self,
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
"""Sample a batch from buffer."""
if len(self.buffer_x) == 0:
return None, None, None
indices = np.random.choice(
len(self.buffer_x),
min(batch_size, len(self.buffer_x)),
replace=False
)
x = torch.stack([self.buffer_x[i] for i in indices])
        y = torch.stack([self.buffer_y[i] for i in indices])
tasks = [self.buffer_task[i] for i in indices]
return x.to(device), y.to(device), tasks
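# --- Usage sketch (hypothetical) ---------------------------------------------
# One replay training step: mix the incoming batch with a batch drawn from the
# buffer (an ExperienceReplay instance), update, then add the new samples to
# the buffer. `model` and `optimizer` are assumed to be defined elsewhere.
def replay_step(model, optimizer, buffer, x, y, task_id, replay_batch_size=32):
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y)
    x_mem, y_mem, _ = buffer.sample(replay_batch_size)
    if x_mem is not None:
        # Rehearse stored examples from earlier tasks alongside the new batch
        loss = loss + F.cross_entropy(model(x_mem), y_mem)
    loss.backward()
    optimizer.step()
    buffer.add(x, y, task_id)
    return loss.item()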
class GradientEpisodicMemory:
"""
Gradient Episodic Memory (GEM).
Key idea: Project gradients to not increase loss on previous tasks.
Reference: "Gradient Episodic Memory for Continual Learning"
(Lopez-Paz & Ranzato, 2017)
"""
def __init__(
self,
model: nn.Module,
memory_size: int = 256
):
self.model = model
self.memory_size = memory_size
self.memory = {} # {task_id: (x, y)}
self.memory_gradients = {}
def store_task_memory(
self,
data_loader,
task_id: int
):
"""Store examples from completed task."""
x_mem, y_mem = [], []
for x, y in data_loader:
x_mem.append(x)
y_mem.append(y)
if sum(len(batch) for batch in x_mem) >= self.memory_size:
break
x_mem = torch.cat(x_mem)[:self.memory_size]
y_mem = torch.cat(y_mem)[:self.memory_size]
self.memory[task_id] = (x_mem.to(device), y_mem.to(device))
def compute_memory_gradients(self):
"""Compute gradients for all memory samples."""
self.memory_gradients = {}
for task_id, (x, y) in self.memory.items():
self.model.zero_grad()
outputs = self.model(x)
loss = F.cross_entropy(outputs, y)
loss.backward()
# Flatten gradients
grad = torch.cat([
p.grad.view(-1)
for p in self.model.parameters()
if p.grad is not None
])
self.memory_gradients[task_id] = grad.clone()
def project_gradient(self, current_grad: torch.Tensor) -> torch.Tensor:
"""
Project current gradient to satisfy constraints.
Constraint: g · g_memory >= 0 for all memory gradients
"""
if len(self.memory_gradients) == 0:
return current_grad
mem_grads = torch.stack(list(self.memory_gradients.values()))
# Check constraint violations
dotprods = torch.mv(mem_grads, current_grad)
if (dotprods >= 0).all():
return current_grad
# Project using quadratic programming
# Simplified: project onto each violated constraint
projected = current_grad.clone()
for i, dotprod in enumerate(dotprods):
if dotprod < 0:
# Remove component along memory gradient
mem_grad = mem_grads[i]
projected -= (dotprod / (mem_grad.norm().pow(2) + 1e-8)) * mem_grad
return projected
class AveragedGEM:
"""
A-GEM: Efficient version of GEM.
Uses average of memory gradients instead of all constraints.
"""
def __init__(
self,
model: nn.Module,
memory_size: int = 256
):
self.model = model
self.memory = ExperienceReplay(memory_size)
def project_gradient(self):
"""Project current gradient using A-GEM."""
if len(self.memory.buffer_x) == 0:
return
# Get current gradient
current_grad = torch.cat([
p.grad.view(-1)
for p in self.model.parameters()
if p.grad is not None
])
# Compute reference gradient on memory
x_mem, y_mem, _ = self.memory.sample(batch_size=256)
self.model.zero_grad()
outputs = self.model(x_mem)
loss = F.cross_entropy(outputs, y_mem)
loss.backward()
ref_grad = torch.cat([
p.grad.view(-1)
for p in self.model.parameters()
if p.grad is not None
])
        # Check whether the current gradient conflicts with the memory gradient
        dotprod = torch.dot(current_grad, ref_grad)
        if dotprod < 0:
            # Project out the conflicting component
            new_grad = current_grad - (
                dotprod / (ref_grad.norm().pow(2) + 1e-8)
            ) * ref_grad
        else:
            # No conflict, but the backward pass on memory above overwrote
            # .grad with the reference gradient, so restore the original
            new_grad = current_grad
        # Write the chosen gradient back into the parameter .grad buffers
        offset = 0
        for p in self.model.parameters():
            if p.grad is not None:
                numel = p.numel()
                p.grad.data.copy_(
                    new_grad[offset:offset + numel].view_as(p)
                )
                offset += numel
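Call order matters for A-GEM: the backward pass on the current batch fills the gradients, project_gradient() then rewrites them in place if they conflict with the memory gradient, and only afterwards does the optimizer step. A hedged sketch of one update, with model and optimizer assumed to exist:

def agem_step(model, optimizer, agem, x, y, task_id):
    """Hypothetical single A-GEM update; agem = AveragedGEM(model)."""
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    F.cross_entropy(model(x), y).backward()  # gradients for the current batch
    agem.project_gradient()                  # rewrites .grad in place if needed
    optimizer.step()
    agem.memory.add(x, y, task_id)           # keep the episodic buffer current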
Generative Replay
class GenerativeReplay(nn.Module):
"""
Generate pseudo-examples from previous tasks.
Uses a generative model (VAE, GAN) to create samples
that represent previous task distributions.
No memory storage needed!
"""
def __init__(
self,
input_dim: int,
hidden_dim: int = 256,
latent_dim: int = 32
):
super().__init__()
# VAE for generating replay samples
self.encoder = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
self.fc_mu = nn.Linear(hidden_dim, latent_dim)
self.fc_var = nn.Linear(hidden_dim, latent_dim)
self.decoder = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, input_dim)
)
self.input_dim = input_dim
self.latent_dim = latent_dim
def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
h = self.encoder(x.view(-1, self.input_dim))
return self.fc_mu(h), self.fc_var(h)
def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z: torch.Tensor) -> torch.Tensor:
return self.decoder(z)
def generate(self, n_samples: int) -> torch.Tensor:
"""Generate replay samples."""
z = torch.randn(n_samples, self.latent_dim).to(device)
return self.decode(z)
def vae_loss(self, x: torch.Tensor) -> torch.Tensor:
"""VAE loss for training generator."""
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
recon = self.decode(z)
recon_loss = F.mse_loss(recon, x.view(-1, self.input_dim))
kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
return recon_loss + 0.1 * kl_loss
class DGRTrainer:
"""
Deep Generative Replay training.
Train classifier + generator together.
Generate old task data using generator.
"""
def __init__(
self,
classifier: nn.Module,
generator: GenerativeReplay,
replay_ratio: float = 0.5
):
self.classifier = classifier
self.generator = generator
self.replay_ratio = replay_ratio
# Store old classifier for pseudo-labeling
self.old_classifier = None
def train_task(
self,
train_loader,
task_id: int,
epochs: int = 10,
lr: float = 0.001
):
"""Train on task with generative replay."""
clf_optimizer = optim.Adam(self.classifier.parameters(), lr=lr)
gen_optimizer = optim.Adam(self.generator.parameters(), lr=lr)
for epoch in range(epochs):
self.classifier.train()
self.generator.train()
for x, y in train_loader:
x, y = x.to(device), y.to(device)
batch_size = x.size(0)
# Generate replay samples
if task_id > 0 and self.old_classifier is not None:
n_replay = int(batch_size * self.replay_ratio)
with torch.no_grad():
x_replay = self.generator.generate(n_replay)
# Pseudo-labels from old classifier
y_replay = self.old_classifier(x_replay).argmax(dim=1)
# Combine current and replay data
x = torch.cat([x.view(batch_size, -1), x_replay])
y = torch.cat([y, y_replay])
# Train classifier
clf_optimizer.zero_grad()
outputs = self.classifier(x.view(x.size(0), -1))
clf_loss = F.cross_entropy(outputs, y)
clf_loss.backward()
clf_optimizer.step()
# Train generator on current task data
gen_optimizer.zero_grad()
gen_loss = self.generator.vae_loss(x[:batch_size])
gen_loss.backward()
gen_optimizer.step()
# Store classifier for next task
self.old_classifier = deepcopy(self.classifier)
self.old_classifier.eval()
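A usage sketch for DGRTrainer, assuming a classifier over flattened inputs (e.g. 784-dimensional for MNIST-like data) and the same task_datasets structure as above; all of these names are placeholders:

def train_sequence_with_dgr(classifier, task_datasets, input_dim=784):
    """Hypothetical DGR run: the VAE is trained alongside the classifier and
    supplies pseudo-examples of earlier tasks from task 1 onward."""
    generator = GenerativeReplay(input_dim=input_dim).to(device)
    trainer = DGRTrainer(classifier, generator, replay_ratio=0.5)
    for task_id, (train_loader, _) in enumerate(task_datasets):
        trainer.train_task(train_loader, task_id, epochs=5)
    return trainer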
Architecture-Based Methods
Progressive Neural Networks
class ProgressiveNeuralNetwork(nn.Module):
"""
Progressive Neural Networks.
Key idea: Add new columns for new tasks,
connect laterally to previous columns.
Zero forgetting (old columns frozen).
Reference: "Progressive Neural Networks" (Rusu et al., 2016)
"""
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int
):
super().__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.output_dim = output_dim
# Columns for each task
self.columns = nn.ModuleList()
# Lateral connections
self.lateral = nn.ModuleList()
def add_column(self):
"""Add a new column for a new task."""
n_cols = len(self.columns)
# New column
column = nn.ModuleDict({
'layer1': nn.Linear(self.input_dim, self.hidden_dim),
'layer2': nn.Linear(self.hidden_dim, self.hidden_dim),
'head': nn.Linear(self.hidden_dim, self.output_dim)
})
self.columns.append(column)
# Lateral connections from previous columns
if n_cols > 0:
lateral = nn.ModuleDict({
'lateral1': nn.Linear(n_cols * self.hidden_dim, self.hidden_dim),
'lateral2': nn.Linear(n_cols * self.hidden_dim, self.hidden_dim)
})
self.lateral.append(lateral)
        # Freeze all previous columns and their lateral adapters
        for col in self.columns[:-1]:
            for param in col.parameters():
                param.requires_grad = False
        for lat in self.lateral[:-1]:
            for param in lat.parameters():
                param.requires_grad = False
def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
"""Forward through column for specific task."""
# Collect all hidden activations
h1_all = []
h2_all = []
for col_id in range(task_id + 1):
col = self.columns[col_id]
# First layer
h1 = F.relu(col['layer1'](x))
# Add lateral connections
if col_id > 0:
lateral_input = torch.cat(h1_all, dim=1)
h1 = h1 + self.lateral[col_id - 1]['lateral1'](lateral_input)
h1_all.append(h1)
# Second layer
h2 = F.relu(col['layer2'](h1))
if col_id > 0:
lateral_input = torch.cat(h2_all, dim=1)
h2 = h2 + self.lateral[col_id - 1]['lateral2'](lateral_input)
h2_all.append(h2)
# Output from target task column
output = self.columns[task_id]['head'](h2_all[task_id])
return output
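# --- Usage sketch (hypothetical) ---------------------------------------------
# One column per task: call add_column() before training task t, then optimize
# only the parameters that still require gradients (the new column plus its
# lateral adapters). `train_loader` is assumed; inputs are flattened to input_dim.
def train_task_progressive(pnn, train_loader, task_id, epochs=5, lr=1e-3):
    pnn.add_column()  # freezes all earlier columns and their laterals
    optimizer = optim.Adam(
        [p for p in pnn.parameters() if p.requires_grad], lr=lr
    )
    pnn.train()
    for _ in range(epochs):
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = pnn(x.view(x.size(0), -1), task_id=task_id)
            F.cross_entropy(logits, y).backward()
            optimizer.step()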
class PackNet:
"""
PackNet: Prune and reuse for continual learning.
Key idea:
1. Train on task
2. Prune unimportant weights
3. Freeze pruned network
4. Retrain remaining weights for new task
Reference: "PackNet: Adding Multiple Tasks to a Single Network"
(Mallya & Lazebnik, 2018)
"""
def __init__(
self,
model: nn.Module,
prune_ratio: float = 0.75
):
self.model = model
self.prune_ratio = prune_ratio
# Track which weights are used by which tasks
self.masks = {} # {task_id: {param_name: mask}}
self.available_mask = {} # Which weights are still available
# Initialize available mask
for name, param in model.named_parameters():
self.available_mask[name] = torch.ones_like(param, dtype=torch.bool)
def compute_importance(self) -> Dict[str, torch.Tensor]:
"""Compute weight importance (magnitude-based)."""
importance = {}
for name, param in self.model.named_parameters():
if param.requires_grad:
importance[name] = param.data.abs()
return importance
def prune_for_task(self, task_id: int):
"""Prune weights after training on task."""
importance = self.compute_importance()
task_mask = {}
for name, imp in importance.items():
available = self.available_mask[name]
# Only consider available weights
imp_available = imp * available.float()
# Find threshold for pruning
n_keep = int((1 - self.prune_ratio) * available.sum().item())
if n_keep > 0:
threshold = torch.kthvalue(
imp_available[available].flatten(),
len(imp_available[available].flatten()) - n_keep
)[0]
                # Keep the highest-magnitude weights among the available ones
                mask = (imp_available >= threshold) & available
            else:
                mask = torch.zeros_like(available)
task_mask[name] = mask
# Update available weights
self.available_mask[name] = available & ~mask
self.masks[task_id] = task_mask
    def apply_mask(self, task_id: int):
        """Zero out weights not owned by this task or by the earlier tasks it
        was trained on top of (inference uses the union of their masks)."""
        for name, param in self.model.named_parameters():
            if name in self.masks.get(task_id, {}):
                combined = torch.zeros_like(param, dtype=torch.bool)
                for tid in range(task_id + 1):
                    if tid in self.masks and name in self.masks[tid]:
                        combined |= self.masks[tid][name]
                param.data *= combined.float()
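The PackNet life cycle per task is train, prune, re-apply the mask, then briefly retrain the surviving weights. A hedged sketch of that cycle, with the actual training routine train_fn left abstract and the gradient masking that protects earlier tasks' weights during training omitted for brevity:

def packnet_task_cycle(packnet, train_fn, train_loader, task_id):
    """Hypothetical per-task PackNet cycle."""
    train_fn(packnet.model, train_loader)  # 1. train on the new task
    packnet.prune_for_task(task_id)        # 2. claim the most important free weights
    packnet.apply_mask(task_id)            # 3. zero everything not owned so far
    train_fn(packnet.model, train_loader)  # 4. short retraining of surviving weights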
Advanced Methods
Dark Experience Replay
class DarkExperienceReplay:
"""
Dark Experience Replay (DER).
Store logits along with inputs for better replay.
Combines ideas from knowledge distillation with replay.
Reference: "Dark Experience for General Continual Learning"
(Buzzega et al., 2020)
"""
def __init__(
self,
buffer_size: int = 5000,
alpha: float = 0.5
):
self.buffer_size = buffer_size
self.alpha = alpha
self.buffer_x = []
self.buffer_y = []
self.buffer_logits = []
def add(
self,
x: torch.Tensor,
y: torch.Tensor,
logits: torch.Tensor
):
"""Add samples with their logits."""
for i in range(x.size(0)):
if len(self.buffer_x) < self.buffer_size:
self.buffer_x.append(x[i].cpu())
self.buffer_y.append(y[i].cpu())
self.buffer_logits.append(logits[i].detach().cpu())
else:
idx = np.random.randint(0, len(self.buffer_x))
self.buffer_x[idx] = x[i].cpu()
self.buffer_y[idx] = y[i].cpu()
self.buffer_logits[idx] = logits[i].detach().cpu()
def compute_loss(
self,
model: nn.Module,
batch_size: int
) -> torch.Tensor:
"""Compute DER loss on buffer samples."""
if len(self.buffer_x) == 0:
return torch.tensor(0.0).to(device)
indices = np.random.choice(
len(self.buffer_x),
min(batch_size, len(self.buffer_x)),
replace=False
)
x = torch.stack([self.buffer_x[i] for i in indices]).to(device)
        y = torch.stack([self.buffer_y[i] for i in indices]).to(device)
old_logits = torch.stack([self.buffer_logits[i] for i in indices]).to(device)
# Current predictions
new_logits = model(x)
# Combined loss
ce_loss = F.cross_entropy(new_logits, y)
mse_loss = F.mse_loss(new_logits, old_logits)
return self.alpha * ce_loss + (1 - self.alpha) * mse_loss
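# --- Usage sketch (hypothetical) ---------------------------------------------
# One DER training step: cross-entropy on the incoming batch, plus the buffer
# loss that matches stored logits, then the batch (with its logits) enters the
# buffer. `model` and `optimizer` are assumed to be defined elsewhere.
def der_step(model, optimizer, der_buffer, x, y, replay_batch_size=32):
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    logits = model(x)
    loss = F.cross_entropy(logits, y)
    loss = loss + der_buffer.compute_loss(model, replay_batch_size)
    loss.backward()
    optimizer.step()
    der_buffer.add(x, y, logits)  # store inputs, labels, and current logits
    return loss.item()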
class OnlineMetaLearning:
"""
Online Meta-Learning for Continual Learning.
Use meta-learning to adapt quickly to new tasks
while maintaining performance on old ones.
"""
def __init__(
self,
model: nn.Module,
inner_lr: float = 0.01,
outer_lr: float = 0.001
):
self.model = model
self.inner_lr = inner_lr
self.outer_lr = outer_lr
self.outer_optimizer = optim.Adam(model.parameters(), lr=outer_lr)
self.memory = ExperienceReplay(buffer_size=1000)
def train_step(
self,
x_current: torch.Tensor,
y_current: torch.Tensor
):
"""MAML-style training step."""
# Inner loop: adapt to current batch
adapted_model = deepcopy(self.model)
for _ in range(5): # Inner steps
outputs = adapted_model(x_current)
loss = F.cross_entropy(outputs, y_current)
grads = torch.autograd.grad(
loss, adapted_model.parameters()
)
for param, grad in zip(adapted_model.parameters(), grads):
param.data -= self.inner_lr * grad
        # Outer loop: update the original model on both current and memory data
        self.outer_optimizer.zero_grad()
        # Loss on current task, evaluated with the adapted copy
        current_loss = F.cross_entropy(
            adapted_model(x_current), y_current
        )
        # Loss on memory, evaluated with the original model
        x_mem, y_mem, _ = self.memory.sample(x_current.size(0))
        if x_mem is not None:
            memory_loss = F.cross_entropy(self.model(x_mem), y_mem)
            total_loss = current_loss + memory_loss
        else:
            total_loss = current_loss
        total_loss.backward()
        # First-order meta-update: the current-task gradients live on the
        # adapted copy, so fold them back into the original parameters
        for p, p_adapted in zip(self.model.parameters(), adapted_model.parameters()):
            if p_adapted.grad is not None:
                if p.grad is None:
                    p.grad = p_adapted.grad.clone()
                else:
                    p.grad += p_adapted.grad
        self.outer_optimizer.step()
# Update memory
self.memory.add(x_current, y_current, 0)
Best Practices
def continual_learning_summary():
"""Print continual learning methods summary."""
summary = """
╔════════════════════════════════════════════════════════════════════╗
║ CONTINUAL LEARNING METHODS SUMMARY ║
╠════════════════════════════════════════════════════════════════════╣
║ ║
║ REGULARIZATION-BASED ║
║ • EWC: Penalize changes to important weights ║
║ • SI: Online importance estimation ║
║ + No memory storage needed ║
║ - May not scale to many tasks ║
║ ║
║ REPLAY-BASED ║
║ • Experience Replay: Store raw examples ║
║ • GEM/A-GEM: Constrain gradients ║
║ • Generative Replay: Generate pseudo-examples ║
║ + Very effective ║
║ - Memory or compute overhead ║
║ ║
║ ARCHITECTURE-BASED ║
║ • Progressive Nets: Add new columns ║
║ • PackNet: Prune and reuse ║
║ + No forgetting (frozen weights) ║
║ - Model grows with tasks ║
║ ║
╠════════════════════════════════════════════════════════════════════╣
║ RECOMMENDATIONS ║
╠════════════════════════════════════════════════════════════════════╣
║ ║
║ Few tasks, limited memory: EWC or SI ║
║ Many tasks, some memory: Experience Replay + regularization ║
║ Critical no-forgetting: Architecture-based methods ║
║ Large models: Generative replay or knowledge distillation ║
║ ║
║ GENERAL TIPS: ║
║ • Always evaluate on ALL previous tasks ║
║ • Measure both accuracy and forgetting ║
║ • Consider task boundaries (known vs unknown) ║
║ • Test with different task orderings ║
║ ║
╚════════════════════════════════════════════════════════════════════╝
"""
print(summary)
continual_learning_summary()
Exercises
Exercise 1: Compare Methods
Implement and compare on Split-MNIST (see the data-loading sketch after this list):
- Fine-tuning (baseline)
- EWC
- Experience Replay
- A-GEM
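A hedged starting point for the Split-MNIST task loaders, assuming torchvision is available; splitting the ten digits into five two-class tasks is one common convention, not the only one:

from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

def make_split_mnist(data_root="./data", batch_size=128):
    """Build 5 two-class tasks (digits 0/1, 2/3, ..., 8/9). Labels keep their
    original digit ids; remap them per task if your model uses 2-way heads."""
    tfm = transforms.ToTensor()
    train = datasets.MNIST(data_root, train=True, download=True, transform=tfm)
    test = datasets.MNIST(data_root, train=False, download=True, transform=tfm)

    def task_loader(ds, classes, shuffle):
        idx = [i for i, lbl in enumerate(ds.targets.tolist()) if lbl in classes]
        return DataLoader(Subset(ds, idx), batch_size=batch_size, shuffle=shuffle)

    task_datasets = []
    for a in range(0, 10, 2):
        classes = (a, a + 1)
        task_datasets.append(
            (task_loader(train, classes, True), task_loader(test, classes, False))
        )
    return task_datasets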
Exercise 2: Task Order Sensitivity
Test EWC on different task orderings:
- Easy to hard
- Hard to easy
- Random
Exercise 3: Hybrid Approach
Combine EWC with Experience Replay:
- What’s the optimal combination?
- Does it beat either method alone?