Continual Learning
The Catastrophic Forgetting Problem
Imagine you spent a year becoming fluent in French, then spent a year learning Mandarin, and when you tried to speak French again you could barely string a sentence together. That is catastrophic forgetting, and it is exactly what happens to neural networks. Unlike humans, who can (mostly) retain old skills while acquiring new ones, standard neural networks overwrite old knowledge when optimized for new data. The weights that encoded “how to recognize cats” get repurposed for “how to recognize medical X-rays,” and the cat knowledge evaporates.

This is not an edge case. It is a fundamental obstacle to deploying ML models in the real world, where data distributions shift, new classes appear, and retraining from scratch every time is prohibitively expensive.

| Scenario | Challenge | Real-world Example |
|---|---|---|
| Sequential tasks | Model forgets task A when learning task B | A spam filter trained on new attack patterns forgets old ones |
| Streaming data | Distribution shift over time | A recommendation model degrades as user preferences evolve |
| Class-incremental | New classes added continuously | An object detector that must recognize new product categories |
| Domain adaptation | Same task, new domain | A medical model moving from one hospital’s imaging equipment to another |
A common misconception is that fine-tuning with a small learning rate avoids catastrophic forgetting. It does not — it merely slows it down. Even with a learning rate 100x smaller, a model fine-tuned on task B will eventually lose task A performance. The learning rate controls the speed of forgetting, not whether it occurs.
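The point about learning rates can be checked on a toy problem. The following is a minimal sketch, not taken from any benchmark: a one-parameter model `y = w * x` is trained on a hypothetical task A (`y = 2x`) and then fine-tuned on a hypothetical task B (`y = -x`) with a 100x smaller learning rate. The function names, data, and step counts are illustrative assumptions.

```python
def train(w, xs, ys, lr, steps):
    """Plain gradient descent on mean squared error for the model y = w * x."""
    for _ in range(steps):
        grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        w -= lr * grad
    return w


def mse(w, xs, ys):
    """Mean squared error of the model y = w * x."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


xs = [0.5, 1.0, 1.5, 2.0]
task_a = [2.0 * x for x in xs]   # task A: y = 2x
task_b = [-1.0 * x for x in xs]  # task B: y = -x

# Learn task A with a "normal" learning rate.
w_after_a = train(0.0, xs, task_a, lr=0.1, steps=200)

# Fine-tune on task B with a 100x smaller learning rate and
# record the task A loss after increasingly long fine-tuning.
small_lr_losses = []
for steps in (100, 1000, 10000):
    w = train(w_after_a, xs, task_b, lr=0.001, steps=steps)
    small_lr_losses.append(mse(w, xs, task_a))
    print(f"steps={steps:>5}  w={w:+.3f}  task A loss={small_lr_losses[-1]:.3f}")
```

The task A loss keeps rising as fine-tuning continues: the small learning rate delays the damage but does not prevent it.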
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import Optional, Dict, List, Tuple, Callable
from dataclasses import dataclass
from copy import deepcopy
import numpy as np
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Measuring Forgetting
@dataclass
class ContinualMetrics:
"""
Metrics for continual learning evaluation.
"""
accuracy_matrix: np.ndarray # A[i,j] = acc on task j after learning task i
@property
def average_accuracy(self) -> float:
"""Average accuracy across all tasks after all training."""
return np.mean(self.accuracy_matrix[-1])
@property
def forgetting(self) -> float:
"""
Average forgetting measure.
For each task j: max accuracy on j - final accuracy on j
"""
n_tasks = self.accuracy_matrix.shape[1]
forgetting = []
for j in range(n_tasks - 1): # Exclude last task (no forgetting possible)
max_acc = np.max(self.accuracy_matrix[j:, j])
final_acc = self.accuracy_matrix[-1, j]
forgetting.append(max_acc - final_acc)
return np.mean(forgetting)
@property
def backward_transfer(self) -> float:
"""
How much learning new tasks affects old ones.
Negative = forgetting, Positive = improvement
"""
n_tasks = self.accuracy_matrix.shape[1]
bt = []
for j in range(n_tasks - 1):
# Accuracy right after learning j vs final accuracy
initial_acc = self.accuracy_matrix[j, j]
final_acc = self.accuracy_matrix[-1, j]
bt.append(final_acc - initial_acc)
return np.mean(bt)
@property
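def forward_transfer(self) -> float:
"""
How much previous learning helps new tasks: the mean accuracy on
task j measured just before training on it. This needs
accuracy_matrix[j-1, j] to be filled in (so unseen tasks must be
evaluated too), and no random-initialization baseline is subtracted here.
"""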
n_tasks = self.accuracy_matrix.shape[1]
ft = []
for j in range(1, n_tasks):
# Performance on task j before learning it
before_acc = self.accuracy_matrix[j-1, j]
ft.append(before_acc)
return np.mean(ft)
class ContinualEvaluator:
"""
Evaluate continual learning methods.
"""
def __init__(self, task_datasets: List[Tuple]):
"""
Args:
task_datasets: List of (train_loader, test_loader) per task
"""
self.task_datasets = task_datasets
self.n_tasks = len(task_datasets)
self.accuracy_matrix = np.zeros((self.n_tasks, self.n_tasks))
def evaluate_all_tasks(self, model: nn.Module, current_task: int):
"""Evaluate model on all tasks seen so far."""
model.eval()
for task_id in range(current_task + 1):
_, test_loader = self.task_datasets[task_id]
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs, task_id=task_id)
_, predicted = outputs.max(1)
correct += (predicted == labels).sum().item()
total += labels.size(0)
self.accuracy_matrix[current_task, task_id] = correct / total
def get_metrics(self) -> ContinualMetrics:
return ContinualMetrics(self.accuracy_matrix.copy())
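To make the metric definitions concrete, here is a small worked example in plain Python. It mirrors the formulas in `ContinualMetrics` on a made-up 3-task accuracy matrix; the numbers are illustrative, not measured results, and the helper names are assumptions for this sketch.

```python
def average_accuracy(acc):
    """Mean of the last row: accuracy on every task after all training."""
    return sum(acc[-1]) / len(acc[-1])


def forgetting(acc):
    """Mean over old tasks of (best accuracy since learning it - final accuracy)."""
    n = len(acc)
    drops = [max(acc[i][j] for i in range(j, n)) - acc[-1][j] for j in range(n - 1)]
    return sum(drops) / len(drops)


def backward_transfer(acc):
    """Mean over old tasks of (final accuracy - accuracy right after learning it)."""
    n = len(acc)
    diffs = [acc[-1][j] - acc[j][j] for j in range(n - 1)]
    return sum(diffs) / len(diffs)


# acc[i][j] = accuracy on task j after training on task i (hypothetical values)
acc = [
    [0.95, 0.10, 0.12],
    [0.70, 0.93, 0.15],
    [0.55, 0.80, 0.94],
]

print(f"average accuracy:  {average_accuracy(acc):.3f}")
print(f"forgetting:        {forgetting(acc):.3f}")
print(f"backward transfer: {backward_transfer(acc):.3f}")
```

In this example accuracy never improves on an old task after it is learned, so backward transfer is exactly the negative of forgetting.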
Regularization-Based Methods
Elastic Weight Consolidation (EWC)
class EWC:
"""
Elastic Weight Consolidation.
Key idea: Not all weights are equally important for a given task.
Some weights are critical (the model's output changes dramatically
if they move), while others are nearly irrelevant. EWC measures
each weight's importance using the Fisher Information Matrix, then
adds a penalty during training on new tasks that is proportional
to how much each weight deviates from its "old task" value, weighted
by its importance.
Analogy: Imagine you are rearranging furniture (weights) in a room
to fit a new desk (new task). EWC tells you: "the couch and TV
(high Fisher) are load-bearing for the old layout -- try not to move
them. The side table and lamp (low Fisher) can move freely."
Importance measured by Fisher Information Matrix (diagonal approximation).
Reference: "Overcoming catastrophic forgetting in neural networks"
(Kirkpatrick et al., 2017)
"""
def __init__(
self,
model: nn.Module,
lambda_ewc: float = 5000
):
self.model = model
self.lambda_ewc = lambda_ewc
# Storage for previous task parameters and importance
self.saved_params = {}
self.fisher = {}
def compute_fisher(
self,
data_loader,
n_samples: int = 200
):
"""
Compute Fisher Information Matrix (diagonal approximation).
Fisher[i] = E[gradient[i]^2]
High Fisher = parameter is important for current task
"""
self.model.eval()
# Initialize Fisher
fisher = {
name: torch.zeros_like(param)
for name, param in self.model.named_parameters()
if param.requires_grad
}
n_processed = 0
for inputs, labels in data_loader:
if n_processed >= n_samples:
break
inputs, labels = inputs.to(device), labels.to(device)
self.model.zero_grad()
outputs = self.model(inputs)
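# Use log-likelihood (negative cross entropy)
# Sample labels from the model's own output distribution for the true Fisher
probs = F.softmax(outputs, dim=1).detach()
# squeeze(1) keeps the batch dimension even when the batch has one sample
sampled_labels = torch.multinomial(probs, 1).squeeze(1)
loss = F.cross_entropy(outputs, sampled_labels)
loss.backward()
# Accumulate squared gradients, weighted by batch size so that the
# per-sample normalization below is consistent. Squaring the gradient
# of the batch-mean loss only approximates the mean of per-sample
# squared gradients; use a batch size of 1 for the exact diagonal.
batch_n = inputs.size(0)
for name, param in self.model.named_parameters():
if param.requires_grad and param.grad is not None:
fisher[name] += param.grad.data.pow(2) * batch_n
n_processed += batch_n
# Normalize by the number of samples processed
for name in fisher:
fisher[name] /= max(n_processed, 1)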
return fisher
def register_task(self, data_loader, task_id: int):
"""
Register completion of a task.
Saves parameters and computes Fisher information.
"""
# Compute Fisher for current task
fisher = self.compute_fisher(data_loader)
# Merge with previous Fisher (online EWC)
if len(self.fisher) > 0:
for name in fisher:
self.fisher[name] = (
self.fisher[name] + fisher[name]
)
else:
self.fisher = fisher
# Save current parameters
self.saved_params[task_id] = {
name: param.data.clone()
for name, param in self.model.named_parameters()
if param.requires_grad
}
def penalty(self) -> torch.Tensor:
"""
Compute EWC penalty.
penalty = sum_i F_i * (theta_i - theta_i^*)^2
"""
if len(self.saved_params) == 0:
return torch.tensor(0.0).to(device)
loss = torch.tensor(0.0).to(device)
for name, param in self.model.named_parameters():
if name in self.fisher:
# Sum over the saved parameters of every previous task
for task_id in self.saved_params:
old_param = self.saved_params[task_id][name]
loss += (
self.fisher[name] *
(param - old_param).pow(2)
).sum()
return self.lambda_ewc * loss / 2
class EWCTrainer:
"""Training loop with EWC."""
def __init__(
self,
model: nn.Module,
lambda_ewc: float = 5000
):
self.model = model
self.ewc = EWC(model, lambda_ewc)
def train_task(
self,
train_loader,
task_id: int,
epochs: int = 10,
lr: float = 0.001
):
"""Train on a single task with EWC regularization."""
optimizer = optim.Adam(self.model.parameters(), lr=lr)
self.model.train()
for epoch in range(epochs):
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = self.model(inputs)
# Task loss
task_loss = F.cross_entropy(outputs, labels)
# EWC penalty
ewc_loss = self.ewc.penalty()
total_loss = task_loss + ewc_loss
total_loss.backward()
optimizer.step()
# Register task completion
self.ewc.register_task(train_loader, task_id)
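A small numeric example may help show what the EWC penalty does. This sketch evaluates the same formula as `EWC.penalty` (`lambda / 2 * sum_i F_i * (theta_i - theta_i^*)^2`) on two hand-picked parameters; the Fisher values and parameter values are made up for illustration.

```python
def ewc_penalty(theta, theta_star, fisher, lambda_ewc):
    """lambda/2 * sum_i F_i * (theta_i - theta_i^*)^2 for plain Python lists."""
    return 0.5 * lambda_ewc * sum(
        f * (t - s) ** 2 for t, s, f in zip(theta, theta_star, fisher)
    )


theta_star = [1.0, 1.0]  # parameters saved after the old task
fisher = [10.0, 0.01]    # first parameter is important, second is not

# Small move of the important parameter vs. large move of the unimportant one
move_important = ewc_penalty([1.5, 1.0], theta_star, fisher, lambda_ewc=1.0)
move_unimportant = ewc_penalty([1.0, 3.0], theta_star, fisher, lambda_ewc=1.0)

print(f"move important parameter by 0.5:   penalty = {move_important:.3f}")
print(f"move unimportant parameter by 2.0: penalty = {move_unimportant:.3f}")
```

Moving the high-Fisher parameter a short distance costs far more than moving the low-Fisher parameter four times as far, which is the behavior the furniture analogy describes.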
Synaptic Intelligence (SI)
class SynapticIntelligence:
"""
Synaptic Intelligence: Online importance estimation.
Key idea: Track contribution of each parameter to loss reduction
during training (path integral of gradients).
Reference: "Continual Learning Through Synaptic Intelligence"
(Zenke et al., 2017)
"""
def __init__(
self,
model: nn.Module,
lambda_si: float = 1.0,
epsilon: float = 1e-3
):
self.model = model
self.lambda_si = lambda_si
self.epsilon = epsilon
# Initialize tracking
self.omega = {} # Importance weights
self.old_params = {} # Parameters at task start
self.prev_params = {} # Parameters after the previous optimizer step
self.w = {} # Running importance estimate
for name, param in model.named_parameters():
if param.requires_grad:
self.omega[name] = torch.zeros_like(param)
self.old_params[name] = param.data.clone()
self.prev_params[name] = param.data.clone()
self.w[name] = torch.zeros_like(param)
def update_omega(self):
"""Update importance weights after task."""
for name, param in self.model.named_parameters():
if name in self.omega:
delta = param.data - self.old_params[name]
# Omega += path integral / (total parameter change^2 + eps)
self.omega[name] += self.w[name] / (delta.pow(2) + self.epsilon)
# Reset for next task
self.old_params[name] = param.data.clone()
self.prev_params[name] = param.data.clone()
self.w[name].zero_()
def update_w(self):
"""Update running importance. Call once per step, right after optimizer.step()."""
for name, param in self.model.named_parameters():
if name in self.w and param.grad is not None:
# Accumulate -gradient * parameter change for this step only.
# param.grad still holds the gradient used for the step; using the
# per-step change (not the change since task start) is what makes
# the running sum a path integral of the loss decrease.
self.w[name] -= param.grad.data * (
param.data - self.prev_params[name]
)
self.prev_params[name] = param.data.clone()
def penalty(self) -> torch.Tensor:
"""Compute SI penalty."""
loss = torch.tensor(0.0).to(device)
for name, param in self.model.named_parameters():
if name in self.omega:
loss += (
self.omega[name] *
(param - self.old_params[name]).pow(2)
).sum()
return self.lambda_si * loss
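The "path integral of gradients" idea can be checked on a one-dimensional problem. The sketch below is an illustrative assumption, not the authors' code: it runs gradient descent on the toy loss `(w - 3)^2`, accumulates `-gradient * per-step parameter change`, and compares the total with the actual drop in loss.

```python
def loss(w):
    return (w - 3.0) ** 2


def grad(w):
    return 2.0 * (w - 3.0)


w_start = 0.0
w = w_start
lr = 0.01
path_integral = 0.0  # plays the role of self.w in SynapticIntelligence

for _ in range(2000):
    g = grad(w)
    w_new = w - lr * g
    # Contribution of this step: -gradient * parameter change
    path_integral += -g * (w_new - w)
    w = w_new

loss_drop = loss(w_start) - loss(w)
epsilon = 1e-3
# Importance as in update_omega: path integral / (total change^2 + eps)
omega = path_integral / ((w - w_start) ** 2 + epsilon)

print(f"path integral: {path_integral:.4f}")
print(f"loss drop:     {loss_drop:.4f}")
print(f"omega:         {omega:.4f}")
```

With a small learning rate the accumulated sum closely tracks the loss reduction, which is why it can serve as a per-parameter measure of how much each weight contributed to learning the task.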
Replay-Based Methods
Replay methods take a completely different approach from regularization: instead of constraining how the model changes, they maintain a small memory of past examples and mix them into training data for new tasks. This is directly inspired by the neuroscience theory of memory consolidation, where the hippocampus “replays” past experiences during sleep to transfer them to long-term cortical storage.

The simplest version, experience replay, is often the strongest baseline. Before reaching for sophisticated methods like EWC or PackNet, always compare against replay with a reasonable buffer size. It is embarrassingly effective.

When using experience replay, how you select which examples to store matters more than the buffer size. Random selection is a decent default, but “herding” (storing examples closest to the class mean in feature space) or “k-center coreset” selection (maximizing coverage of the feature space) can improve results by 2-5% with the same buffer size. Also, balancing the buffer equally across classes prevents the model from becoming biased toward recently seen classes.
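The selection strategy described above can be sketched in a few lines. This is a simplified, hypothetical version: it ranks examples by distance to the class mean and splits the buffer equally across classes. Iterative herding variants are more involved. The feature vectors are assumed to come from some feature extractor, and the function names are illustrative.

```python
def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def closest_to_mean(features, k):
    """Return indices of the k feature vectors nearest to the class mean."""
    dim = len(features[0])
    mean = [sum(f[d] for f in features) / len(features) for d in range(dim)]
    ranked = sorted(
        range(len(features)),
        key=lambda i: squared_distance(features[i], mean),
    )
    return ranked[:k]


def build_balanced_buffer(features_by_class, buffer_size):
    """Give every class the same share of the buffer."""
    per_class = buffer_size // len(features_by_class)
    return {
        label: closest_to_mean(feats, per_class)
        for label, feats in features_by_class.items()
    }


# Hypothetical 2-D features for two classes
features_by_class = {
    "cat": [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [10.0, 10.0]],
    "dog": [[5.0, 5.0], [5.0, 6.0], [9.0, 9.0]],
}
buffer = build_balanced_buffer(features_by_class, buffer_size=4)
print(buffer)
```

The returned indices point back into each class's example list, so the same selection can be applied to the raw inputs that are actually stored.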
Experience Replay
class ExperienceReplay:
"""
Store and replay examples from previous tasks.
Simple but surprisingly effective baseline. In many benchmarks,
experience replay with a buffer of just 500-1000 examples per task
outperforms more sophisticated regularization methods.
Uses reservoir sampling to maintain a representative subset when
the total number of observed examples exceeds buffer capacity.
"""
def __init__(
self,
buffer_size: int = 5000,
samples_per_class: Optional[int] = None
):
self.buffer_size = buffer_size
self.samples_per_class = samples_per_class
self.buffer_x = []
self.buffer_y = []
self.buffer_task = []
self.n_seen = 0 # Total number of examples offered to the buffer
def add(
self,
x: torch.Tensor,
y: torch.Tensor,
task_id: int
):
"""Add samples to buffer."""
for i in range(x.size(0)):
self.n_seen += 1
if len(self.buffer_x) < self.buffer_size:
self.buffer_x.append(x[i].cpu())
self.buffer_y.append(y[i].cpu())
self.buffer_task.append(task_id)
else:
# Reservoir sampling: draw a slot among all examples seen so far,
# so every example ends up in the buffer with equal probability
idx = np.random.randint(0, self.n_seen)
if idx < self.buffer_size:
self.buffer_x[idx] = x[i].cpu()
self.buffer_y[idx] = y[i].cpu()
self.buffer_task[idx] = task_id
def sample(
self,
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
"""Sample a batch from buffer."""
if len(self.buffer_x) == 0:
return None, None, None
indices = np.random.choice(
len(self.buffer_x),
min(batch_size, len(self.buffer_x)),
replace=False
)
x = torch.stack([self.buffer_x[i] for i in indices])
y = torch.tensor([self.buffer_y[i] for i in indices])
tasks = [self.buffer_task[i] for i in indices]
return x.to(device), y.to(device), tasks
class GradientEpisodicMemory:
"""
Gradient Episodic Memory (GEM).
Key idea: Project gradients to not increase loss on previous tasks.
Reference: "Gradient Episodic Memory for Continual Learning"
(Lopez-Paz & Ranzato, 2017)
"""
def __init__(
self,
model: nn.Module,
memory_size: int = 256
):
self.model = model
self.memory_size = memory_size
self.memory = {} # {task_id: (x, y)}
self.memory_gradients = {}
def store_task_memory(
self,
data_loader,
task_id: int
):
"""Store examples from completed task."""
x_mem, y_mem = [], []
for x, y in data_loader:
x_mem.append(x)
y_mem.append(y)
if sum(len(batch) for batch in x_mem) >= self.memory_size:
break
x_mem = torch.cat(x_mem)[:self.memory_size]
y_mem = torch.cat(y_mem)[:self.memory_size]
self.memory[task_id] = (x_mem.to(device), y_mem.to(device))
def compute_memory_gradients(self):
"""Compute gradients for all memory samples."""
self.memory_gradients = {}
for task_id, (x, y) in self.memory.items():
self.model.zero_grad()
outputs = self.model(x)
loss = F.cross_entropy(outputs, y)
loss.backward()
# Flatten gradients
grad = torch.cat([
p.grad.view(-1)
for p in self.model.parameters()
if p.grad is not None
])
self.memory_gradients[task_id] = grad.clone()
def project_gradient(self, current_grad: torch.Tensor) -> torch.Tensor:
"""
Project current gradient to satisfy constraints.
Constraint: g · g_memory >= 0 for all memory gradients
"""
if len(self.memory_gradients) == 0:
return current_grad
mem_grads = torch.stack(list(self.memory_gradients.values()))
# Check constraint violations
dotprods = torch.mv(mem_grads, current_grad)
if (dotprods >= 0).all():
return current_grad
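# GEM proper solves a quadratic program over all constraints at once.
# Simplified here: subtract the component along each violated memory
# gradient (the result is not guaranteed to satisfy every constraint)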
projected = current_grad.clone()
for i, dotprod in enumerate(dotprods):
if dotprod < 0:
# Remove component along memory gradient
mem_grad = mem_grads[i]
projected -= (dotprod / (mem_grad.norm().pow(2) + 1e-8)) * mem_grad
return projected
class AveragedGEM:
"""
A-GEM: Efficient version of GEM.
Uses average of memory gradients instead of all constraints.
"""
def __init__(
self,
model: nn.Module,
memory_size: int = 256
):
self.model = model
self.memory = ExperienceReplay(memory_size)
def project_gradient(self):
"""Project current gradient using A-GEM."""
if len(self.memory.buffer_x) == 0:
return
# Get current gradient
current_grad = torch.cat([
p.grad.view(-1)
for p in self.model.parameters()
if p.grad is not None
])
# Compute reference gradient on memory
x_mem, y_mem, _ = self.memory.sample(batch_size=256)
self.model.zero_grad()
outputs = self.model(x_mem)
loss = F.cross_entropy(outputs, y_mem)
loss.backward()
ref_grad = torch.cat([
p.grad.view(-1)
for p in self.model.parameters()
if p.grad is not None
])
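# The backward pass above overwrote p.grad with the memory gradient,
# so the current-task gradient must be written back in every case
dotprod = torch.dot(current_grad, ref_grad)
projected = current_grad
if dotprod < 0:
# Constraint violated: remove the component that conflicts with memory
projected = current_grad - (
dotprod / (ref_grad.norm().pow(2) + 1e-8)
) * ref_grad
# Write the (possibly projected) gradient back; this loop runs
# whether or not the projection was applied
offset = 0
for p in self.model.parameters():
if p.grad is not None:
numel = p.numel()
p.grad.data.copy_(
projected[offset:offset + numel].view_as(p)
)
offset += numel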
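The projection used by A-GEM is easy to verify by hand on two-dimensional vectors. The following plain-Python sketch applies the same formula as `project_gradient` above; the vectors are made-up examples and the helper names are assumptions.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def project(g, g_ref, eps=1e-8):
    """Return g unchanged if it agrees with g_ref, else remove the conflicting part."""
    d = dot(g, g_ref)
    if d >= 0:
        return list(g)
    scale = d / (dot(g_ref, g_ref) + eps)
    return [gi - scale * ri for gi, ri in zip(g, g_ref)]


g_ref = [0.0, 1.0]        # gradient on the memory batch
conflicting = [1.0, -1.0]  # would increase the memory loss
agreeing = [1.0, 1.0]      # already compatible with memory

projected = project(conflicting, g_ref)
print("projected:", projected)
print("dot with memory gradient:", dot(projected, g_ref))
print("agreeing gradient unchanged:", project(agreeing, g_ref))
```

After projection the update is (numerically) orthogonal to the memory gradient, so to first order it no longer increases the loss on the memory batch while keeping the rest of the current-task direction.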
Generative Replay
Generative replay is an elegant alternative to storing real examples: instead of maintaining a buffer, train a generative model (VAE, GAN, diffusion model) alongside the classifier. When learning a new task, the generator produces synthetic examples from previous tasks that are mixed into training. The generator itself must also be trained continually, so you are solving two continual learning problems at once, but the payoff is zero storage of real data, which matters in privacy-sensitive domains like healthcare.

Generative replay sounds appealing in theory, but in practice the quality of the generator matters enormously. If the generator produces low-fidelity samples (blurry images, mode collapse), the classifier trained on them will degrade. For simple datasets (MNIST, Fashion-MNIST), generative replay works well. For complex datasets (CIFAR-100, ImageNet), the generator itself becomes the bottleneck. Modern diffusion models have improved this significantly, but the compute cost of maintaining a diffusion model alongside your classifier may exceed the cost of simply storing a replay buffer.
class GenerativeReplay(nn.Module):
"""
Generate pseudo-examples from previous tasks using a learned generator.
Architecture: VAE (or GAN/diffusion model) that learns to generate
samples representing the distribution of each task seen so far.
Trade-offs:
- Pro: No real data storage needed (privacy-friendly)
- Con: Generator quality limits classifier performance
- Con: Training two models (generator + classifier) continually is harder
"""
def __init__(
self,
input_dim: int,
hidden_dim: int = 256,
latent_dim: int = 32
):
super().__init__()
# VAE for generating replay samples
self.encoder = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
self.fc_mu = nn.Linear(hidden_dim, latent_dim)
self.fc_var = nn.Linear(hidden_dim, latent_dim)
self.decoder = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, input_dim)
)
self.input_dim = input_dim
self.latent_dim = latent_dim
def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
h = self.encoder(x.view(-1, self.input_dim))
return self.fc_mu(h), self.fc_var(h)
def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z: torch.Tensor) -> torch.Tensor:
return self.decoder(z)
def generate(self, n_samples: int) -> torch.Tensor:
"""Generate replay samples."""
z = torch.randn(n_samples, self.latent_dim).to(device)
return self.decode(z)
def vae_loss(self, x: torch.Tensor) -> torch.Tensor:
"""VAE loss for training generator."""
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
recon = self.decode(z)
recon_loss = F.mse_loss(recon, x.view(-1, self.input_dim))
kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
return recon_loss + 0.1 * kl_loss
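The KL term in `vae_loss` is the closed-form divergence between the encoder's Gaussian and a standard normal. As a quick sanity check, here is the per-dimension formula evaluated on a few hand-picked values in plain Python; the function name is an assumption for this sketch.

```python
import math


def kl_to_standard_normal(mu, logvar):
    """Per-dimension KL( N(mu, exp(logvar)) || N(0, 1) )."""
    return -0.5 * (1 + logvar - mu ** 2 - math.exp(logvar))


kl_matched = kl_to_standard_normal(mu=0.0, logvar=0.0)          # posterior equals prior
kl_shifted = kl_to_standard_normal(mu=1.0, logvar=0.0)          # mean moved by 1
kl_wide = kl_to_standard_normal(mu=0.0, logvar=math.log(4.0))   # variance 4

print(f"matched: {kl_matched:.4f}")
print(f"shifted: {kl_shifted:.4f}")
print(f"wide:    {kl_wide:.4f}")
```

The term is zero only when the encoder output matches the prior. That is what keeps latent codes close to the `torch.randn` samples that `generate` later draws from.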
class DGRTrainer:
"""
Deep Generative Replay training.
Train classifier + generator together.
Generate old task data using generator.
"""
def __init__(
self,
classifier: nn.Module,
generator: GenerativeReplay,
replay_ratio: float = 0.5
):
self.classifier = classifier
self.generator = generator
self.replay_ratio = replay_ratio
# Store old classifier for pseudo-labeling
self.old_classifier = None
def train_task(
self,
train_loader,
task_id: int,
epochs: int = 10,
lr: float = 0.001
):
"""Train on task with generative replay."""
clf_optimizer = optim.Adam(self.classifier.parameters(), lr=lr)
gen_optimizer = optim.Adam(self.generator.parameters(), lr=lr)
for epoch in range(epochs):
self.classifier.train()
self.generator.train()
for x, y in train_loader:
x, y = x.to(device), y.to(device)
batch_size = x.size(0)
# Generate replay samples
if task_id > 0 and self.old_classifier is not None:
n_replay = int(batch_size * self.replay_ratio)
with torch.no_grad():
x_replay = self.generator.generate(n_replay)
# Pseudo-labels from old classifier
y_replay = self.old_classifier(x_replay).argmax(dim=1)
# Combine current and replay data
x = torch.cat([x.view(batch_size, -1), x_replay])
y = torch.cat([y, y_replay])
# Train classifier
clf_optimizer.zero_grad()
outputs = self.classifier(x.view(x.size(0), -1))
clf_loss = F.cross_entropy(outputs, y)
clf_loss.backward()
clf_optimizer.step()
# Train generator on the combined batch (current data plus replayed
# samples) so it keeps modeling earlier tasks instead of only the new one
gen_optimizer.zero_grad()
gen_loss = self.generator.vae_loss(x)
gen_loss.backward()
gen_optimizer.step()
# Store classifier for next task
self.old_classifier = deepcopy(self.classifier)
self.old_classifier.eval()
Architecture-Based Methods
Architecture-based methods take the most direct approach to preventing forgetting: give each task its own parameters. The old parameters are literally frozen; they cannot be modified, so they cannot be forgotten. The challenge shifts from “how to avoid forgetting” to “how to share knowledge across tasks without running out of capacity.”

Think of it like a library: regularization methods try to write new books without erasing old ones (hard). Replay methods keep photocopies of old books (memory cost). Architecture methods add new bookshelves for new topics, while allowing readers to reference old shelves (capacity cost). Each approach trades a different resource.

Progressive Neural Networks
class ProgressiveNeuralNetwork(nn.Module):
"""
Progressive Neural Networks.
Key idea: Add new columns (subnetworks) for new tasks,
connect laterally to previous columns so new tasks can
leverage features learned by old tasks.
Zero forgetting (old columns are frozen -- literally impossible
to forget because old weights never change).
Trade-off: Model size grows at least linearly with the number of tasks
(faster once the lateral connections are counted).
After 100 tasks, the model is at least 100x its original size.
Reference: "Progressive Neural Networks" (Rusu et al., 2016)
"""
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int
):
super().__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.output_dim = output_dim
# Columns for each task
self.columns = nn.ModuleList()
# Lateral connections
self.lateral = nn.ModuleList()
def add_column(self):
"""Add a new column for a new task."""
n_cols = len(self.columns)
# New column
column = nn.ModuleDict({
'layer1': nn.Linear(self.input_dim, self.hidden_dim),
'layer2': nn.Linear(self.hidden_dim, self.hidden_dim),
'head': nn.Linear(self.hidden_dim, self.output_dim)
})
self.columns.append(column)
# Lateral connections from previous columns
if n_cols > 0:
lateral = nn.ModuleDict({
'lateral1': nn.Linear(n_cols * self.hidden_dim, self.hidden_dim),
'lateral2': nn.Linear(n_cols * self.hidden_dim, self.hidden_dim)
})
self.lateral.append(lateral)
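# Freeze all previous columns and their lateral connections
for col in self.columns[:-1]:
for param in col.parameters():
param.requires_grad = False
# The newest lateral block (if any) is last, so everything before it is old
for lat in self.lateral[:-1]:
for param in lat.parameters():
param.requires_grad = False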
def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
"""Forward through column for specific task."""
# Collect all hidden activations
h1_all = []
h2_all = []
for col_id in range(task_id + 1):
col = self.columns[col_id]
# First layer
h1 = F.relu(col['layer1'](x))
# Add lateral connections
if col_id > 0:
lateral_input = torch.cat(h1_all, dim=1)
h1 = h1 + self.lateral[col_id - 1]['lateral1'](lateral_input)
h1_all.append(h1)
# Second layer
h2 = F.relu(col['layer2'](h1))
if col_id > 0:
lateral_input = torch.cat(h2_all, dim=1)
h2 = h2 + self.lateral[col_id - 1]['lateral2'](lateral_input)
h2_all.append(h2)
# Output from target task column
output = self.columns[task_id]['head'](h2_all[task_id])
return output
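To get a feel for the capacity cost, the parameter count of the `ProgressiveNeuralNetwork` layout above can be computed directly. This sketch assumes the same three linear layers per column and two lateral layers per added column as in `add_column`; the dimensions are example values.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of one nn.Linear(n_in, n_out)."""
    return n_in * n_out + n_out


def progressive_net_params(n_tasks, input_dim, hidden_dim, output_dim):
    """Total parameters after adding one column per task."""
    column = (
        linear_params(input_dim, hidden_dim)
        + linear_params(hidden_dim, hidden_dim)
        + linear_params(hidden_dim, output_dim)
    )
    total = 0
    for k in range(n_tasks):  # k = number of earlier columns
        total += column
        if k > 0:
            # lateral1 and lateral2 each read k * hidden_dim activations
            total += 2 * linear_params(k * hidden_dim, hidden_dim)
    return total


one_task = progressive_net_params(1, input_dim=784, hidden_dim=256, output_dim=10)
two_tasks = progressive_net_params(2, input_dim=784, hidden_dim=256, output_dim=10)
ten_tasks = progressive_net_params(10, input_dim=784, hidden_dim=256, output_dim=10)

print(f"1 task:   {one_task:,} parameters")
print(f"2 tasks:  {two_tasks:,} parameters")
print(f"10 tasks: {ten_tasks:,} parameters ({ten_tasks / one_task:.1f}x)")
```

Because each new column's lateral layers read from all earlier columns, the total grows faster than the number of columns alone would suggest.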
class PackNet:
"""
PackNet: Prune and reuse for continual learning.
Key idea:
1. Train on task
2. Prune unimportant weights
3. Freeze pruned network
4. Retrain remaining weights for new task
Reference: "PackNet: Adding Multiple Tasks to a Single Network"
(Mallya & Lazebnik, 2018)
"""
def __init__(
self,
model: nn.Module,
prune_ratio: float = 0.75
):
self.model = model
self.prune_ratio = prune_ratio
# Track which weights are used by which tasks
self.masks = {} # {task_id: {param_name: mask}}
self.available_mask = {} # Which weights are still available
# Initialize available mask
for name, param in model.named_parameters():
self.available_mask[name] = torch.ones_like(param, dtype=torch.bool)
def compute_importance(self) -> Dict[str, torch.Tensor]:
"""Compute weight importance (magnitude-based)."""
importance = {}
for name, param in self.model.named_parameters():
if param.requires_grad:
importance[name] = param.data.abs()
return importance
def prune_for_task(self, task_id: int):
"""Prune weights after training on task."""
importance = self.compute_importance()
task_mask = {}
for name, imp in importance.items():
available = self.available_mask[name]
# Only consider available weights
imp_available = imp * available.float()
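# Find threshold for pruning
n_available = int(available.sum().item())
n_keep = int((1 - self.prune_ratio) * n_available)
if n_keep > 0:
# The k-th smallest value with k = n_available - n_keep + 1 is the
# smallest importance that is kept, so exactly n_keep weights survive
# (barring ties); k stays >= 1 even when prune_ratio is 0
threshold = torch.kthvalue(
imp[available].flatten(),
n_available - n_keep + 1
)[0]
# Keep available weights at or above the threshold
mask = (imp >= threshold) & available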
else:
mask = torch.zeros_like(available)
task_mask[name] = mask
# Update available weights
self.available_mask[name] = available & ~mask
self.masks[task_id] = task_mask
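def apply_mask(self, task_id: int):
"""
Apply task mask to model: keep weights owned by this task and by
earlier tasks (assumes integer task ids in training order).
This zeroes weights in place, so keep a copy of the full weights
if later tasks still need to be evaluated.
"""
for name, param in self.model.named_parameters():
if name in self.masks[task_id]:
keep = torch.zeros_like(param, dtype=torch.bool)
for t, task_mask in self.masks.items():
if t <= task_id:
keep |= task_mask[name]
param.data *= keep.float()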
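PackNet has a fixed capacity budget, and the arithmetic is worth seeing once. The sketch below follows the rule used in `prune_for_task` above (each task keeps `1 - prune_ratio` of the weights that are still free); the function name is an assumption for this illustration.

```python
def packnet_allocation(n_tasks, prune_ratio=0.75):
    """Fraction of all weights claimed by each task, plus what remains free."""
    free = 1.0
    shares = []
    for _ in range(n_tasks):
        kept = (1 - prune_ratio) * free  # this task's permanent share
        shares.append(kept)
        free -= kept                     # the rest stays available
    return shares, free


shares, free = packnet_allocation(n_tasks=3, prune_ratio=0.75)
for task_id, share in enumerate(shares):
    print(f"task {task_id}: {share:.4f} of all weights")
print(f"still free after 3 tasks: {free:.4f}")
```

Each later task gets a smaller slice than the one before it, so the pruning ratio is effectively a choice about how to divide a fixed network among the tasks you expect to see.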
Advanced Methods
Dark Experience Replay
class DarkExperienceReplay:
"""
Dark Experience Replay (DER).
Store logits along with inputs for better replay.
Combines ideas from knowledge distillation with replay.
Reference: "Dark Experience for General Continual Learning"
(Buzzega et al., 2020)
"""
def __init__(
self,
buffer_size: int = 5000,
alpha: float = 0.5
):
self.buffer_size = buffer_size
self.alpha = alpha
self.buffer_x = []
self.buffer_y = []
self.buffer_logits = []
self.n_seen = 0 # Total number of examples offered to the buffer
def add(
self,
x: torch.Tensor,
y: torch.Tensor,
logits: torch.Tensor
):
"""Add samples with their logits."""
for i in range(x.size(0)):
self.n_seen += 1
if len(self.buffer_x) < self.buffer_size:
self.buffer_x.append(x[i].cpu())
self.buffer_y.append(y[i].cpu())
self.buffer_logits.append(logits[i].detach().cpu())
else:
# Reservoir sampling: replace a slot only with probability
# buffer_size / n_seen, so old examples are not flushed out
idx = np.random.randint(0, self.n_seen)
if idx < self.buffer_size:
self.buffer_x[idx] = x[i].cpu()
self.buffer_y[idx] = y[i].cpu()
self.buffer_logits[idx] = logits[i].detach().cpu()
def compute_loss(
self,
model: nn.Module,
batch_size: int
) -> torch.Tensor:
"""Compute DER loss on buffer samples."""
if len(self.buffer_x) == 0:
return torch.tensor(0.0).to(device)
indices = np.random.choice(
len(self.buffer_x),
min(batch_size, len(self.buffer_x)),
replace=False
)
x = torch.stack([self.buffer_x[i] for i in indices]).to(device)
y = torch.tensor([self.buffer_y[i] for i in indices]).to(device)
old_logits = torch.stack([self.buffer_logits[i] for i in indices]).to(device)
# Current predictions
new_logits = model(x)
# Combined loss
ce_loss = F.cross_entropy(new_logits, y)
mse_loss = F.mse_loss(new_logits, old_logits)
return self.alpha * ce_loss + (1 - self.alpha) * mse_loss
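The combined loss in `compute_loss` can be worked through by hand for a single buffered example. This plain-Python sketch uses the same weighting (`alpha * cross-entropy + (1 - alpha) * MSE on logits`) with made-up logits; the helper names are assumptions.

```python
import math


def cross_entropy(logits, label):
    """Negative log-softmax probability of the true label."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[label]


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def der_loss(new_logits, old_logits, label, alpha=0.5):
    """alpha * CE(new, label) + (1 - alpha) * MSE(new, stored logits)."""
    return alpha * cross_entropy(new_logits, label) + (1 - alpha) * mse(new_logits, old_logits)


old_logits = [1.0, 1.0]  # logits stored in the buffer when the example was seen
drifted = der_loss([2.0, 0.0], old_logits, label=0)
unchanged = der_loss([1.0, 1.0], old_logits, label=0)

print(f"model drifted from stored logits: {drifted:.4f}")
print(f"model matches stored logits:      {unchanged:.4f}")
```

The drifted model is more confident on the correct label, yet its total loss is higher because the MSE term penalizes moving away from the stored logits. That is the distillation pressure that makes the method more than plain replay.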
class OnlineMetaLearning:
"""
Online Meta-Learning for Continual Learning.
Use meta-learning to adapt quickly to new tasks
while maintaining performance on old ones.
"""
def __init__(
self,
model: nn.Module,
inner_lr: float = 0.01,
outer_lr: float = 0.001
):
self.model = model
self.inner_lr = inner_lr
self.outer_lr = outer_lr
self.outer_optimizer = optim.Adam(model.parameters(), lr=outer_lr)
self.memory = ExperienceReplay(buffer_size=1000)
def train_step(
self,
x_current: torch.Tensor,
y_current: torch.Tensor
):
"""MAML-style training step."""
# Inner loop: adapt to current batch
adapted_model = deepcopy(self.model)
for _ in range(5): # Inner steps
outputs = adapted_model(x_current)
loss = F.cross_entropy(outputs, y_current)
grads = torch.autograd.grad(
loss, adapted_model.parameters()
)
for param, grad in zip(adapted_model.parameters(), grads):
param.data -= self.inner_lr * grad
# Outer loop: update on both current and memory
self.outer_optimizer.zero_grad()
# Loss on current task (using adapted model)
current_loss = F.cross_entropy(
adapted_model(x_current), y_current
)
# Loss on memory (using original model)
x_mem, y_mem, _ = self.memory.sample(x_current.size(0))
if x_mem is not None:
memory_loss = F.cross_entropy(self.model(x_mem), y_mem)
total_loss = current_loss + memory_loss
else:
total_loss = current_loss
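total_loss.backward()
# First-order approximation: adapted_model is a separate deep copy, so
# its gradients never reach self.model on their own; add them by hand
for p, p_adapted in zip(self.model.parameters(), adapted_model.parameters()):
if p_adapted.grad is not None:
p.grad = p_adapted.grad.clone() if p.grad is None else p.grad + p_adapted.grad
self.outer_optimizer.step()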
# Update memory
self.memory.add(x_current, y_current, 0)
Best Practices
def continual_learning_summary():
"""Print continual learning methods summary."""
summary = """
╔════════════════════════════════════════════════════════════════════╗
║ CONTINUAL LEARNING METHODS SUMMARY ║
╠════════════════════════════════════════════════════════════════════╣
║ ║
║ REGULARIZATION-BASED ║
║ • EWC: Penalize changes to important weights ║
║ • SI: Online importance estimation ║
║ + No memory storage needed ║
║ - May not scale to many tasks ║
║ ║
║ REPLAY-BASED ║
║ • Experience Replay: Store raw examples ║
║ • GEM/A-GEM: Constrain gradients ║
║ • Generative Replay: Generate pseudo-examples ║
║ + Very effective ║
║ - Memory or compute overhead ║
║ ║
║ ARCHITECTURE-BASED ║
║ • Progressive Nets: Add new columns ║
║ • PackNet: Prune and reuse ║
║ + No forgetting (frozen weights) ║
║ - Model grows with tasks ║
║ ║
╠════════════════════════════════════════════════════════════════════╣
║ RECOMMENDATIONS ║
╠════════════════════════════════════════════════════════════════════╣
║ ║
║ Few tasks, limited memory: EWC or SI ║
║ Many tasks, some memory: Experience Replay + regularization ║
║ Critical no-forgetting: Architecture-based methods ║
║ Large models: Generative replay or knowledge distillation ║
║ ║
║ GENERAL TIPS: ║
║ • Always evaluate on ALL previous tasks ║
║ • Measure both accuracy and forgetting ║
║ • Consider task boundaries (known vs unknown) ║
║ • Test with different task orderings ║
║ ║
╚════════════════════════════════════════════════════════════════════╝
"""
print(summary)
continual_learning_summary()
Exercises
Exercise 1: Compare Methods
Implement and compare on Split-MNIST:
- Fine-tuning (baseline)
- EWC
- Experience Replay
- A-GEM
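As a starting point for this exercise, a class-based split can be built from the label list alone. This is a minimal sketch under the usual Split-MNIST convention of two digits per task; the function name and the toy label list are assumptions, and loading the actual dataset is left to the reader.

```python
def split_into_tasks(labels, classes_per_task=2):
    """Group example indices into tasks of consecutive classes."""
    classes = sorted(set(labels))
    tasks = []
    for start in range(0, len(classes), classes_per_task):
        task_classes = classes[start:start + classes_per_task]
        indices = [i for i, y in enumerate(labels) if y in task_classes]
        tasks.append((task_classes, indices))
    return tasks


# Toy label list standing in for the MNIST training labels
labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 3]
tasks = split_into_tasks(labels, classes_per_task=2)

for task_id, (task_classes, indices) in enumerate(tasks):
    print(f"task {task_id}: classes {task_classes}, examples {indices}")
```

The index lists can then be used to build one train loader and one test loader per task, in the `(train_loader, test_loader)` format that `ContinualEvaluator` expects.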
Exercise 2: Task Order Sensitivity
Test EWC on different task orderings:
- Easy to hard
- Hard to easy
- Random
Exercise 3: Hybrid Approach
Combine EWC with Experience Replay:
- What’s the optimal combination?
- Does it beat either method alone?
What’s Next?
Model Compression
Quantization and pruning techniques
Capstone Project
Apply all techniques in a complete project