Knowledge Distillation
The Teacher-Student Framework
Knowledge distillation is based on a surprisingly human analogy: a master chef (the teacher) has spent years developing intuition about flavor combinations, but they can train an apprentice (the student) far faster than the apprentice could learn from raw ingredients alone. The apprentice does not need to replicate every experience the master had; they just need to absorb the master’s refined judgment.

In ML terms: a large, expensive model has learned a rich understanding of the data, and we can transfer that understanding to a smaller, deployable model more efficiently than training the small model from scratch on raw labels.

The key insight from Hinton’s original 2015 paper: when a teacher classifies a cat image, its “soft” output might say “90% cat, 8% dog, 2% tiger.” Those secondary probabilities (the “dark knowledge”) carry information that hard labels (“cat”) miss entirely. They tell the student that cats and dogs are more similar than cats and cars, relationships that would take the student many more examples to discover on its own.

| Aspect | Teacher | Student |
|---|---|---|
| Size | Large (billions of params) | Small (millions) |
| Accuracy | High | Approaches teacher |
| Speed | Slow | Fast |
| Deployment | Expensive | Cheap |
Distillation is not just for model compression. It is also used to: (1) transfer knowledge from an ensemble into a single model, (2) train a model without direct access to private data by distilling from a teacher that was trained on that data, and (3) improve training stability, since soft targets provide smoother gradient signals than hard labels.
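The soft outputs described above become more informative once the teacher's logits are divided by a temperature before the softmax, which is what the code on this page does. The sketch below is a dependency-free illustration of that effect; the logits and the three class names (cat, dog, car) are made-up values, not outputs of a real model.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Return softmax(logits / temperature) as a list of floats."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical teacher logits for the classes (cat, dog, car).
teacher_logits = [5.0, 2.5, -1.0]

for T in (1.0, 4.0, 10.0):
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"T={T}: {[round(p, 3) for p in probs]}")
```

Raising the temperature keeps the class ranking unchanged but moves probability mass onto the secondary classes, so the student sees more of the "cat is closer to dog than to car" structure.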
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import Optional, Tuple, Dict, List, Callable
import numpy as np
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Core Distillation Methods
Soft Target Distillation
class SoftTargetDistillation:
"""
Original Knowledge Distillation (Hinton et al., 2015).
Key insight: Soft targets from teacher contain more information
than hard labels. The "dark knowledge" in probability distribution
reveals class similarities and uncertainty.
Loss = α * T² * KL(teacher || student) + (1-α) * CE(student, labels)
"""
def __init__(
self,
temperature: float = 4.0,
alpha: float = 0.7
):
"""
Args:
temperature: Higher T -> softer probabilities
alpha: Weight for distillation loss vs hard label loss
"""
self.temperature = temperature
self.alpha = alpha
def distillation_loss(
self,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Compute distillation loss.
Args:
student_logits: [batch, num_classes]
teacher_logits: [batch, num_classes]
labels: [batch] optional hard labels
"""
T = self.temperature
# Soft targets from teacher
soft_teacher = F.softmax(teacher_logits / T, dim=-1)
# Soft predictions from student
soft_student = F.log_softmax(student_logits / T, dim=-1)
# KL divergence (scaled by T^2 for gradient magnitude)
distill_loss = F.kl_div(
soft_student,
soft_teacher,
reduction='batchmean'
) * (T ** 2)
if labels is not None:
# Hard label loss
hard_loss = F.cross_entropy(student_logits, labels)
return self.alpha * distill_loss + (1 - self.alpha) * hard_loss
return distill_loss
class DistillationTrainer:
"""
Complete training loop for knowledge distillation.
"""
def __init__(
self,
teacher: nn.Module,
student: nn.Module,
temperature: float = 4.0,
alpha: float = 0.7
):
self.teacher = teacher.eval()
self.student = student
self.distillation = SoftTargetDistillation(temperature, alpha)
# Freeze teacher
for param in self.teacher.parameters():
param.requires_grad = False
def train_step(
self,
inputs: torch.Tensor,
labels: torch.Tensor,
optimizer: optim.Optimizer
) -> Dict[str, float]:
"""Single training step."""
self.student.train()
optimizer.zero_grad()
# Get teacher predictions
with torch.no_grad():
teacher_logits = self.teacher(inputs)
# Get student predictions
student_logits = self.student(inputs)
# Compute loss
loss = self.distillation.distillation_loss(
student_logits, teacher_logits, labels
)
loss.backward()
optimizer.step()
# Metrics
with torch.no_grad():
student_acc = (student_logits.argmax(dim=-1) == labels).float().mean()
teacher_acc = (teacher_logits.argmax(dim=-1) == labels).float().mean()
return {
'loss': loss.item(),
'student_acc': student_acc.item(),
'teacher_acc': teacher_acc.item()
}
def train(
self,
train_loader,
optimizer: optim.Optimizer,
epochs: int = 10
):
"""Full training loop."""
for epoch in range(epochs):
total_loss = 0
total_acc = 0
for batch_idx, (inputs, labels) in enumerate(train_loader):
inputs, labels = inputs.to(device), labels.to(device)
metrics = self.train_step(inputs, labels, optimizer)
total_loss += metrics['loss']
total_acc += metrics['student_acc']
avg_loss = total_loss / len(train_loader)
avg_acc = total_acc / len(train_loader)
print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Acc={avg_acc:.4f}")
# Example usage
def create_example_models():
"""Create example teacher and student models."""
# Large teacher
teacher = nn.Sequential(
nn.Linear(784, 1024),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 10)
)
# Small student
student = nn.Sequential(
nn.Linear(784, 128),
nn.ReLU(),
nn.Linear(128, 10)
)
print(f"Teacher params: {sum(p.numel() for p in teacher.parameters()):,}")
print(f"Student params: {sum(p.numel() for p in student.parameters()):,}")
return teacher, student
create_example_models()
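To make the arithmetic inside `distillation_loss` concrete, here is a single-sample version in plain Python. It is a sketch under the assumption of batch size 1 (where `reduction='batchmean'` reduces to a plain sum); `kd_loss_single` and the example logits are hypothetical names and values used only for illustration.

```python
import math


def softmax(logits, T=1.0):
    peak = max(z / T for z in logits)
    exps = [math.exp(z / T - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss_single(student_logits, teacher_logits, label, T=4.0, alpha=0.7):
    """alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE for one sample."""
    p_teacher = softmax(teacher_logits, T)
    p_student = softmax(student_logits, T)
    # KL(teacher || student): the direction F.kl_div(log_student, teacher) computes
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student))
    # The hard-label term uses the unscaled (T = 1) student distribution
    ce = -math.log(softmax(student_logits)[label])
    return alpha * (T ** 2) * kl + (1 - alpha) * ce, kl, ce


loss, kl, ce = kd_loss_single([2.0, 1.0, 0.1], [5.0, 2.5, -1.0], label=0)
print(f"loss={loss:.4f}  kl={kl:.4f}  ce={ce:.4f}")
```

When the student's logits equal the teacher's, the KL term vanishes and only the weighted hard-label term remains, which is a quick sanity check for any reimplementation.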
Feature-Based Distillation
Soft-target distillation only uses the final output layer. But the teacher’s intermediate representations also carry valuable information: the way the teacher structures its internal representations (which features activate together, which spatial regions get attention) is knowledge the student can absorb. Feature-based distillation forces the student’s hidden layers to mimic the teacher’s, providing a much richer training signal.

Think of it this way: soft-target distillation is like learning from a teacher’s final answers on an exam. Feature-based distillation is like also seeing the teacher’s scratch work, the intermediate reasoning steps that led to those answers.

class FeatureDistillation(nn.Module):
"""
Distill intermediate feature representations.
Methods:
- FitNets: Match intermediate features directly (L2 loss between layers)
- Attention Transfer: Match spatial attention maps (which regions matter)
- Contrastive: Learn similar representation geometry (pull same-sample
student/teacher pairs together, push different-sample pairs apart)
Practical tip: Start with attention transfer -- it is the most forgiving
because it does not require teacher and student layers to have the same
channel dimension (the channel axis is summed out, so only the spatial
size has to match). FitNets requires a projection layer whenever channel
dimensions mismatch.
"""
def __init__(
self,
teacher_channels: List[int],
student_channels: List[int],
method: str = 'fitnet' # 'fitnet', 'attention', 'contrastive'
):
super().__init__()
self.method = method
# Projection layers to match dimensions
self.projections = nn.ModuleList([
nn.Conv2d(s_ch, t_ch, 1) if s_ch != t_ch else nn.Identity()
for s_ch, t_ch in zip(student_channels, teacher_channels)
])
def fitnet_loss(
self,
student_features: List[torch.Tensor],
teacher_features: List[torch.Tensor]
) -> torch.Tensor:
"""
FitNets: Train student hidden layers to mimic teacher.
Uses L2 loss between intermediate representations.
"""
total_loss = 0
for i, (s_feat, t_feat) in enumerate(zip(student_features, teacher_features)):
# Project student features to teacher dimension
s_proj = self.projections[i](s_feat)
# L2 loss (normalized)
loss = F.mse_loss(
F.normalize(s_proj, dim=1),
F.normalize(t_feat, dim=1)
)
total_loss += loss
return total_loss / len(student_features)
def attention_transfer_loss(
self,
student_features: List[torch.Tensor],
teacher_features: List[torch.Tensor]
) -> torch.Tensor:
"""
Attention Transfer: Match spatial attention maps.
Attention map = sum of squared activations across channels.
"""
total_loss = 0
for s_feat, t_feat in zip(student_features, teacher_features):
# Compute attention maps
s_attn = self._compute_attention_map(s_feat)
t_attn = self._compute_attention_map(t_feat)
# Match attention patterns
loss = F.mse_loss(s_attn, t_attn)
total_loss += loss
return total_loss / len(student_features)
def _compute_attention_map(self, feature: torch.Tensor) -> torch.Tensor:
"""
Compute spatial attention map.
Sum of squared activations, normalized.
"""
# feature: [B, C, H, W]
attn = feature.pow(2).sum(dim=1) # [B, H, W]
# Normalize
attn = attn.view(attn.size(0), -1)
attn = F.normalize(attn, dim=1)
return attn
def contrastive_loss(
self,
student_features: torch.Tensor,
teacher_features: torch.Tensor,
temperature: float = 0.5
) -> torch.Tensor:
"""
Contrastive distillation: Learn similar representations.
Pull together (student, teacher) pairs from same sample,
push apart pairs from different samples.
"""
# Global average pooling if spatial
if student_features.dim() == 4:
student_features = student_features.mean(dim=[2, 3])
if teacher_features.dim() == 4:
teacher_features = teacher_features.mean(dim=[2, 3])
# Normalize
s_norm = F.normalize(student_features, dim=1)
t_norm = F.normalize(teacher_features, dim=1)
# Similarity matrix
batch_size = s_norm.size(0)
sim_matrix = torch.matmul(s_norm, t_norm.T) / temperature
# Labels: diagonal elements should be similar
labels = torch.arange(batch_size).to(sim_matrix.device)
# Cross entropy loss
loss = F.cross_entropy(sim_matrix, labels)
return loss
def forward(
self,
student_features: List[torch.Tensor],
teacher_features: List[torch.Tensor]
) -> torch.Tensor:
"""Compute feature distillation loss."""
if self.method == 'fitnet':
return self.fitnet_loss(student_features, teacher_features)
elif self.method == 'attention':
return self.attention_transfer_loss(student_features, teacher_features)
elif self.method == 'contrastive':
return self.contrastive_loss(student_features[-1], teacher_features[-1])
class FeatureExtractingModel(nn.Module):
"""
Wrapper to extract intermediate features.
"""
def __init__(self, model: nn.Module, layer_names: List[str]):
super().__init__()
self.model = model
self.layer_names = layer_names
self.features = {}
# Register hooks
for name, module in model.named_modules():
if name in layer_names:
module.register_forward_hook(self._make_hook(name))
def _make_hook(self, name: str):
def hook(module, input, output):
self.features[name] = output
return hook
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
output = self.model(x)
return output, self.features.copy()
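The claim that attention transfer tolerates different channel counts can be checked by hand. The following sketch mirrors `_compute_attention_map` on nested lists instead of tensors; the activations are invented, with a 4-channel teacher and a 2-channel student on the same 2x2 grid.

```python
import math


def attention_map(feature):
    """feature: nested list shaped [C][H][W]; returns a flat, L2-normalized list of length H*W."""
    channels = len(feature)
    height, width = len(feature[0]), len(feature[0][0])
    # Sum of squared activations over channels: one value per spatial location
    flat = [
        sum(feature[c][h][w] ** 2 for c in range(channels))
        for h in range(height)
        for w in range(width)
    ]
    norm = math.sqrt(sum(v * v for v in flat)) or 1.0  # guard against an all-zero map
    return [v / norm for v in flat]


# Hypothetical activations: different channel counts, same spatial layout.
teacher_feat = [[[1.0, 0.0], [0.0, 3.0]] for _ in range(4)]
student_feat = [[[0.5, 0.0], [0.0, 1.5]] for _ in range(2)]

t_map = attention_map(teacher_feat)
s_map = attention_map(student_feat)
mse = sum((a - b) ** 2 for a, b in zip(s_map, t_map)) / len(t_map)
print(f"attention-transfer MSE: {mse:.6f}")
```

Both maps have length H*W regardless of channel count, so no projection layer is needed. If the spatial sizes differed, one of the maps would have to be resized first.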
Relation-Based Distillation
class RelationDistillation:
"""
Distill relationships between samples, not just outputs.
Key insight: Structure of representation space matters.
"""
@staticmethod
def distance_wise_loss(
student_features: torch.Tensor,
teacher_features: torch.Tensor
) -> torch.Tensor:
"""
Distance-wise distillation.
Preserve pairwise distances between samples.
"""
# Compute pairwise distances
s_dist = torch.cdist(student_features, student_features)
t_dist = torch.cdist(teacher_features, teacher_features)
# Normalize
s_dist = s_dist / s_dist.max()
t_dist = t_dist / t_dist.max()
# Huber loss for robustness
loss = F.smooth_l1_loss(s_dist, t_dist)
return loss
@staticmethod
def angle_wise_loss(
student_features: torch.Tensor,
teacher_features: torch.Tensor
) -> torch.Tensor:
"""
Angle-wise distillation.
Preserve angular relationships (correlations).
"""
# Compute Gram matrices (correlations)
s_gram = student_features @ student_features.T
t_gram = teacher_features @ teacher_features.T
# Normalize
s_gram = s_gram / (s_gram.norm() + 1e-8)
t_gram = t_gram / (t_gram.norm() + 1e-8)
loss = F.mse_loss(s_gram, t_gram)
return loss
@staticmethod
def similarity_preserving_loss(
student_features: torch.Tensor,
teacher_features: torch.Tensor
) -> torch.Tensor:
"""
Similarity-Preserving distillation.
Preserve which samples are similar/dissimilar.
"""
# Cosine similarity matrices
s_sim = F.cosine_similarity(
student_features.unsqueeze(1),
student_features.unsqueeze(0),
dim=2
)
t_sim = F.cosine_similarity(
teacher_features.unsqueeze(1),
teacher_features.unsqueeze(0),
dim=2
)
loss = F.mse_loss(s_sim, t_sim)
return loss
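A small worked example may help show what "preserving pairwise distances" buys. This sketch follows the same steps as `distance_wise_loss` (distance matrix, divide by the maximum) on plain lists, but uses a mean absolute difference in place of the Huber loss to keep it short. The embeddings are hypothetical.

```python
import math


def pairwise_distances(points):
    """Euclidean distance matrix for a list of equal-length vectors."""
    return [[math.dist(a, b) for b in points] for a in points]


def normalized(matrix):
    """Divide by the largest entry, as distance_wise_loss does with .max()."""
    peak = max(max(row) for row in matrix) or 1.0
    return [[v / peak for v in row] for row in matrix]


def mean_abs_diff(a, b):
    count = len(a) * len(a[0])
    return sum(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb)) / count


# Three samples: the teacher embeds them in 3-D, the student in 2-D.
teacher = [[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [6.0, 8.0, 0.0]]
student_good = [[0.0, 0.0], [0.3, 0.4], [0.6, 0.8]]  # same geometry, smaller scale
student_bad = [[0.0, 0.0], [0.6, 0.8], [0.3, 0.4]]   # samples 1 and 2 swapped

t = normalized(pairwise_distances(teacher))
gap_good = mean_abs_diff(normalized(pairwise_distances(student_good)), t)
gap_bad = mean_abs_diff(normalized(pairwise_distances(student_bad)), t)
print(f"good student: {gap_good:.4f}, bad student: {gap_bad:.4f}")
```

Because only normalized distances are compared, the student is free to use a different embedding dimension and scale; it is penalized only when the relative layout of the samples changes.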
class CRDLoss(nn.Module):
"""
Contrastive Representation Distillation.
Use contrastive learning objective to transfer knowledge.
More effective than simple L2 matching.
Reference: "Contrastive Representation Distillation" (Tian et al., 2020)
"""
def __init__(
self,
student_dim: int,
teacher_dim: int,
feature_dim: int = 128,
num_negatives: int = 16384,
temperature: float = 0.07
):
super().__init__()
self.temperature = temperature
self.num_negatives = num_negatives
# Projection heads
self.student_proj = nn.Sequential(
nn.Linear(student_dim, feature_dim),
nn.ReLU(),
nn.Linear(feature_dim, feature_dim)
)
self.teacher_proj = nn.Sequential(
nn.Linear(teacher_dim, feature_dim),
nn.ReLU(),
nn.Linear(feature_dim, feature_dim)
)
# Memory bank for negatives
self.register_buffer(
'memory_bank',
F.normalize(torch.randn(num_negatives, feature_dim), dim=1)  # unit-norm, like the projected features stored later
)
self.register_buffer('memory_ptr', torch.zeros(1, dtype=torch.long))
def update_memory(self, features: torch.Tensor):
"""Update memory bank with new features."""
batch_size = features.size(0)
ptr = int(self.memory_ptr)
# Replace old features
if ptr + batch_size > self.num_negatives:
ptr = 0
self.memory_bank[ptr:ptr + batch_size] = features.detach()
self.memory_ptr[0] = (ptr + batch_size) % self.num_negatives
def forward(
self,
student_features: torch.Tensor,
teacher_features: torch.Tensor
) -> torch.Tensor:
"""Compute CRD loss."""
# Project features
s_proj = F.normalize(self.student_proj(student_features), dim=1)
t_proj = F.normalize(self.teacher_proj(teacher_features), dim=1)
batch_size = s_proj.size(0)
# Positive pairs: student-teacher from same sample
pos_sim = torch.sum(s_proj * t_proj, dim=1, keepdim=True)
# Negative pairs: student against memory bank
neg_sim = s_proj @ self.memory_bank.T
# InfoNCE loss
logits = torch.cat([pos_sim, neg_sim], dim=1) / self.temperature
labels = torch.zeros(batch_size, dtype=torch.long, device=logits.device)
loss = F.cross_entropy(logits, labels)
# Update memory bank
self.update_memory(t_proj)
return loss
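The InfoNCE step at the end of `CRDLoss.forward` is a cross-entropy in which the positive pair always sits at index 0. A scalar sketch of that computation is shown below for one student embedding; the similarity values are hypothetical cosine similarities, not taken from a trained model.

```python
import math


def info_nce(pos_sim, neg_sims, temperature=0.07):
    """Cross-entropy over [positive, negatives...] with the positive at index 0."""
    logits = [pos_sim / temperature] + [s / temperature for s in neg_sims]
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[0]


# Student agrees with its own teacher embedding and not with the negatives.
aligned = info_nce(0.9, [0.1, -0.2, 0.0])
# Student is closer to a negative than to its own teacher embedding.
confused = info_nce(0.1, [0.9, -0.2, 0.0])
print(f"aligned={aligned:.4f}  confused={confused:.4f}")
```

With a low temperature such as 0.07, small differences in cosine similarity turn into large differences in loss, which is why the temperature is usually tuned together with the number of negatives.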
Self-Distillation
Self-distillation is one of the most surprising results in knowledge distillation: you can improve a model by distilling it from itself. There is no separate, larger teacher. The model’s own deeper layers teach its shallower layers, or a previously trained copy of the same architecture serves as the teacher. The fact that this works at all suggests that the “dark knowledge” in soft targets provides a regularization benefit that goes beyond the information in hard labels.

Born-Again Networks (self-distillation through iterative retraining) typically give diminishing returns after 2-3 generations. The first self-distillation step usually provides the most improvement (0.5-1% accuracy). Beyond generation 3, the student often plateaus or slightly degrades. Do not assume more generations equals more improvement.
class SelfDistillation(nn.Module):
"""
Distill knowledge within the same network.
Methods:
- Deep supervision: Attach auxiliary classifiers at intermediate layers;
the deepest layer's soft targets train the shallower classifiers
(the setup used in "Be Your Own Teacher")
- Born-again networks: Train the same architecture iteratively, using
the previous generation as teacher
- Augmented views: Use an ensemble of augmented views of the same
input as a form of implicit distillation
"""
def __init__(
self,
backbone: nn.Module,
num_classes: int,
hidden_dims: List[int]
):
super().__init__()
self.backbone = backbone
# Auxiliary classifiers at intermediate layers
self.aux_classifiers = nn.ModuleList([
nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(dim, num_classes)
)
for dim in hidden_dims
])
# Final classifier
self.classifier = nn.Linear(hidden_dims[-1], num_classes)
def forward(
self,
x: torch.Tensor,
return_features: bool = False
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Forward with auxiliary predictions.
"""
features = []
aux_logits = []
# Get features at each layer (implement based on backbone)
for layer, aux_clf in zip(self.backbone, self.aux_classifiers):
x = layer(x)
features.append(x)
# Auxiliary prediction
aux_out = aux_clf(x)
aux_logits.append(aux_out)
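# Final prediction: global-average-pool [B, C, H, W] -> [B, C] so the width matches nn.Linear(hidden_dims[-1], ...)
final_logits = self.classifier(x.mean(dim=(2, 3)))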
if return_features:
return final_logits, aux_logits, features
return final_logits, aux_logits
def self_distillation_loss(
self,
final_logits: torch.Tensor,
aux_logits: List[torch.Tensor],
labels: torch.Tensor,
temperature: float = 4.0
) -> torch.Tensor:
"""
Compute self-distillation loss.
Deeper layers teach shallower layers.
"""
# Final layer loss
final_loss = F.cross_entropy(final_logits, labels)
# Distillation losses
distill_loss = 0
soft_final = F.softmax(final_logits / temperature, dim=-1)
for aux_logit in aux_logits:
# Distill from final to aux
soft_aux = F.log_softmax(aux_logit / temperature, dim=-1)
kl_loss = F.kl_div(soft_aux, soft_final.detach(), reduction='batchmean')
distill_loss += kl_loss * (temperature ** 2)
# Also use hard labels for aux
distill_loss += 0.5 * F.cross_entropy(aux_logit, labels)
distill_loss /= len(aux_logits)
return final_loss + 0.5 * distill_loss
class BornAgainNetworks:
"""
Born-Again Networks: Iterative self-distillation.
1. Train model from scratch (generation 1)
2. Use model as teacher to train same architecture (generation 2)
3. Repeat...
Often improves over vanilla training!
"""
def __init__(
self,
model_fn: Callable[[], nn.Module],
temperature: float = 4.0
):
self.model_fn = model_fn
self.temperature = temperature
self.generations = []
def train_generation(
self,
train_loader,
epochs: int,
lr: float = 0.001,
teacher: Optional[nn.Module] = None
) -> nn.Module:
"""Train one generation."""
student = self.model_fn().to(device)
optimizer = optim.Adam(student.parameters(), lr=lr)
distiller = SoftTargetDistillation(self.temperature) if teacher is not None else None
if teacher is not None:
teacher.eval()
for param in teacher.parameters():
param.requires_grad = False
for epoch in range(epochs):
student.train()
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
student_logits = student(inputs)
if teacher is not None:
with torch.no_grad():
teacher_logits = teacher(inputs)
loss = distiller.distillation_loss(
student_logits, teacher_logits, labels
)
else:
loss = F.cross_entropy(student_logits, labels)
loss.backward()
optimizer.step()
self.generations.append(student)
return student
def train(
self,
train_loader,
num_generations: int = 3,
epochs_per_gen: int = 10
) -> nn.Module:
"""Train multiple generations."""
teacher = None
for gen in range(num_generations):
print(f"Training generation {gen + 1}...")
teacher = self.train_generation(
train_loader, epochs_per_gen, teacher=teacher
)
return teacher
Task-Specific Distillation
The methods above are general-purpose, but real-world distillation often requires task-specific adaptations. Object detection models have multiple heads (classification + regression) and multi-scale feature pyramids. Language models have attention patterns and layered hidden states. Simply applying vanilla KD to the final output of these models leaves significant knowledge on the table. The following recipes show how to distill knowledge from each component of complex architectures.

For NLP model distillation, the most impactful technique is usually attention pattern matching (used in TinyBERT and MiniLM). The attention matrices encode which tokens the model considers related, and this structural knowledge transfers very effectively. For vision models, FPN feature distillation at all scales tends to give the biggest gains because different scales capture different-sized objects.
class DetectionDistillation:
"""
Distillation for object detection models.
Challenges unique to detection:
- Multi-task output (classification + bounding box regression)
- Variable number of objects per image (cannot simply softmax over all)
- Feature Pyramid Networks require distillation at multiple spatial scales
- Foreground/background imbalance means most predictions are "no object"
"""
@staticmethod
def classification_distillation(
student_cls: torch.Tensor,
teacher_cls: torch.Tensor,
temperature: float = 2.0
) -> torch.Tensor:
"""Distill classification heads."""
soft_teacher = torch.sigmoid(teacher_cls / temperature)
# The logits variant fuses sigmoid + BCE and is numerically more stable
loss = F.binary_cross_entropy_with_logits(
student_cls / temperature, soft_teacher.detach()
) * (temperature ** 2)
return loss
@staticmethod
def regression_distillation(
student_box: torch.Tensor,
teacher_box: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Distill bounding box regression."""
loss = F.smooth_l1_loss(student_box, teacher_box, reduction='none')
if mask is not None:
loss = loss * mask.unsqueeze(-1)
return loss.mean()
@staticmethod
def fpn_feature_distillation(
student_features: Dict[str, torch.Tensor],
teacher_features: Dict[str, torch.Tensor]
) -> torch.Tensor:
"""Distill FPN features at each scale."""
total_loss = 0
for level in student_features.keys():
s_feat = student_features[level]
t_feat = teacher_features[level]
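# Shapes must already match; project or resize the student features beforehand
if s_feat.shape != t_feat.shape:
raise ValueError(
f"Shape mismatch at FPN level {level}: {tuple(s_feat.shape)} vs {tuple(t_feat.shape)}"
)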
# L2 loss with attention masking
importance = t_feat.abs().mean(dim=1, keepdim=True)
importance = importance / importance.max()
loss = F.mse_loss(s_feat * importance, t_feat * importance)
total_loss += loss
return total_loss / len(student_features)
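One property of the sigmoid-based classification distillation above is easy to miss: binary cross-entropy against a soft target does not reach zero even when the student matches the teacher exactly. The scalar sketch below illustrates this for a single anchor/class score; `soft_bce` and the logit values are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def soft_bce(student_logit, teacher_logit, temperature=2.0):
    """Temperature-scaled binary cross-entropy against a soft teacher target."""
    p_teacher = sigmoid(teacher_logit / temperature)
    p_student = sigmoid(student_logit / temperature)
    bce = -(p_teacher * math.log(p_student) + (1.0 - p_teacher) * math.log(1.0 - p_student))
    return bce * temperature ** 2


teacher_logit = 1.0  # hypothetical teacher score for one anchor/class pair
matched = soft_bce(1.0, teacher_logit)
too_low = soft_bce(-1.0, teacher_logit)
too_high = soft_bce(3.0, teacher_logit)
print(f"matched={matched:.3f}  too_low={too_low:.3f}  too_high={too_high:.3f}")
```

The matched case gives the smallest value, but that minimum equals the entropy of the teacher's prediction (times T²) rather than zero, so the raw loss curve should be read relative to that floor.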
class NLPDistillation:
"""
Distillation for language models.
Used in: DistilBERT, TinyBERT, MiniLM
"""
@staticmethod
def embedding_distillation(
student_emb: torch.Tensor,
teacher_emb: torch.Tensor
) -> torch.Tensor:
"""Distill word embeddings."""
return F.mse_loss(student_emb, teacher_emb)
@staticmethod
def attention_distillation(
student_attn: torch.Tensor, # [B, H, L, L]
teacher_attn: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Distill attention patterns.
Student learns to attend like teacher.
"""
# KL divergence on attention distributions (inputs are treated as pre-softmax attention scores)
s_attn = student_attn.log_softmax(dim=-1)
t_attn = teacher_attn.softmax(dim=-1)
loss = F.kl_div(s_attn, t_attn, reduction='none')
if attention_mask is not None:
# Mask padded positions
mask = attention_mask.unsqueeze(1).unsqueeze(2)
loss = loss * mask
return loss.mean()
@staticmethod
def hidden_state_distillation(
student_hidden: List[torch.Tensor],
teacher_hidden: List[torch.Tensor],
layer_mapping: Optional[Dict[int, int]] = None
) -> torch.Tensor:
"""
Distill hidden states.
Common mappings for unequal depths:
- Uniform: student[i] <- teacher[i * T/S]
- Skip: student[i] <- teacher[2*i] for half-depth
"""
if layer_mapping is None:
# Default: uniform mapping
ratio = len(teacher_hidden) // len(student_hidden)
layer_mapping = {i: i * ratio for i in range(len(student_hidden))}
total_loss = 0
for s_idx, t_idx in layer_mapping.items():
s_h = student_hidden[s_idx]
t_h = teacher_hidden[t_idx]
loss = F.mse_loss(s_h, t_h)
total_loss += loss
return total_loss / len(layer_mapping)
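The default uniform mapping in `hidden_state_distillation` is easiest to understand by printing it. The sketch below reproduces that default and, for comparison, adds a hypothetical alternative (`last_layer_aligned_mapping`, not part of the code above) that lines up the final layers of student and teacher.

```python
def uniform_mapping(num_student, num_teacher):
    """Default used when layer_mapping is None: student i <- teacher i * ratio."""
    ratio = num_teacher // num_student  # assumes num_student <= num_teacher
    return {i: i * ratio for i in range(num_student)}


def last_layer_aligned_mapping(num_student, num_teacher):
    """Hypothetical alternative: student i <- teacher (i + 1) * ratio - 1."""
    ratio = num_teacher // num_student
    return {i: (i + 1) * ratio - 1 for i in range(num_student)}


print(uniform_mapping(6, 12))
print(last_layer_aligned_mapping(6, 12))
print(uniform_mapping(4, 12))
```

With the default mapping, the student's last layer is paired with teacher layer 10 of 0..11, so the teacher's final layer is never used as a target. Whether that matters depends on the task; passing an explicit `layer_mapping` is the way to change it.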
class DistilBERTTrainer:
"""
Complete DistilBERT-style training.
Triple loss:
1. Masked LM loss (task)
2. Cosine embedding loss (hidden states)
3. Distillation loss (soft targets)
"""
def __init__(
self,
teacher: nn.Module,
student: nn.Module,
alpha_ce: float = 0.5,
alpha_mlm: float = 0.5,
alpha_cos: float = 0.0,  # 0.0 disables the cosine term by default
temperature: float = 2.0
):
self.teacher = teacher.eval()
self.student = student
self.alpha_ce = alpha_ce
self.alpha_mlm = alpha_mlm
self.alpha_cos = alpha_cos
self.temperature = temperature
def compute_loss(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: torch.Tensor # MLM labels
) -> torch.Tensor:
"""Compute DistilBERT loss."""
# Teacher forward
with torch.no_grad():
teacher_out = self.teacher(
input_ids,
attention_mask=attention_mask,
output_hidden_states=True
)
# Student forward
student_out = self.student(
input_ids,
attention_mask=attention_mask,
output_hidden_states=True
)
# 1. Soft target distillation
soft_teacher = F.softmax(teacher_out.logits / self.temperature, dim=-1)
soft_student = F.log_softmax(student_out.logits / self.temperature, dim=-1)
ce_loss = F.kl_div(
soft_student, soft_teacher, reduction='batchmean'
) * (self.temperature ** 2)
# 2. MLM loss
mlm_loss = F.cross_entropy(
student_out.logits.view(-1, student_out.logits.size(-1)),
labels.view(-1),
ignore_index=-100
)
# 3. Cosine embedding loss on last hidden
cos_loss = 1 - F.cosine_similarity(
student_out.hidden_states[-1],
teacher_out.hidden_states[-1],
dim=-1
).mean()
total_loss = (
self.alpha_ce * ce_loss +
self.alpha_mlm * mlm_loss +
self.alpha_cos * cos_loss
)
return total_loss
Advanced Distillation Techniques
class ProgressiveDistillation:
"""
Gradually increase difficulty during distillation.
Start with soft, easy-to-match targets (high temperature) and progress to
sharper ones (low temperature) while raising the distillation weight alpha.
"""
def __init__(self, num_stages: int = 5):
self.num_stages = num_stages
self.current_stage = 0
def get_temperature(self) -> float:
"""Temperature decreases over stages."""
max_temp = 10.0
min_temp = 1.0
return max_temp - (max_temp - min_temp) * (self.current_stage / self.num_stages)
def get_alpha(self) -> float:
"""Distillation weight increases over stages."""
return 0.5 + 0.5 * (self.current_stage / self.num_stages)
def advance_stage(self):
"""Move to next stage."""
if self.current_stage < self.num_stages:
self.current_stage += 1
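To see what the schedule in `ProgressiveDistillation` produces, the two formulas can be tabulated on their own. This sketch copies them into standalone functions (the function names are illustrative) with the same defaults: five stages, temperature from 10.0 down to 1.0.

```python
def temperature_at(stage, num_stages=5, max_temp=10.0, min_temp=1.0):
    """Same linear decay as ProgressiveDistillation.get_temperature."""
    return max_temp - (max_temp - min_temp) * (stage / num_stages)


def alpha_at(stage, num_stages=5):
    """Same linear ramp as ProgressiveDistillation.get_alpha."""
    return 0.5 + 0.5 * (stage / num_stages)


schedule = [(s, temperature_at(s), alpha_at(s)) for s in range(6)]
for stage, temp, alpha in schedule:
    print(f"stage {stage}: T={temp:.1f}, alpha={alpha:.2f}")
```

Note that alpha reaches 1.0 at the last stage, so if these values are fed into `SoftTargetDistillation`, the hard-label term receives zero weight by the end of training.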
class OnlineDistillation(nn.Module):
"""
Online mutual distillation.
Multiple students teach each other during training.
No pre-trained teacher needed!
Reference: "Deep Mutual Learning" (Zhang et al., 2018)
"""
def __init__(
self,
students: List[nn.Module],
temperature: float = 4.0
):
super().__init__()
self.students = nn.ModuleList(students)
self.temperature = temperature
def forward(
self,
x: torch.Tensor,
labels: torch.Tensor
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Forward pass with mutual learning.
Each student learns from:
1. Ground truth labels
2. Average of other students' predictions
"""
# Get all predictions
all_logits = [student(x) for student in self.students]
losses = []
for i, logits in enumerate(all_logits):
# Task loss
task_loss = F.cross_entropy(logits, labels)
# Distillation from peers
peer_logits = [all_logits[j] for j in range(len(all_logits)) if j != i]
peer_avg = torch.stack(peer_logits).mean(dim=0)
soft_peer = F.softmax(peer_avg / self.temperature, dim=-1)
soft_self = F.log_softmax(logits / self.temperature, dim=-1)
kl_loss = F.kl_div(
soft_self, soft_peer.detach(), reduction='batchmean'
) * (self.temperature ** 2)
loss = task_loss + kl_loss
losses.append(loss)
total_loss = sum(losses)
return total_loss, all_logits
# Best practices summary
def distillation_best_practices():
"""Print distillation best practices."""
tips = """
╔════════════════════════════════════════════════════════════════════╗
║ KNOWLEDGE DISTILLATION BEST PRACTICES ║
╠════════════════════════════════════════════════════════════════════╣
║ ║
║ TEMPERATURE SELECTION ║
║ • T=1: Sharp distributions (less knowledge transfer) ║
║ • T=4-8: Good balance for most tasks ║
║ • T>10: Very soft distributions (may lose discrimination) ║
║ ║
║ LOSS BALANCING (α) ║
║ • α=0.5-0.7: Typical starting point ║
║ • Higher α: Rely more on teacher (good teacher) ║
║ • Lower α: Rely more on labels (noisy teacher) ║
║ ║
║ STUDENT SIZE ║
║ • 1/3 to 1/2 of teacher: Good compression with minimal loss ║
║ • <1/10 of teacher: Significant accuracy drop expected ║
║ ║
║ WHAT TO DISTILL ║
║ • Always: Output logits (soft targets) ║
║ • Often helpful: Intermediate features ║
║ • Sometimes helpful: Attention patterns ║
║ ║
║ TRAINING TIPS ║
║ • Use longer training for student than normal ║
║ • Data augmentation helps student generalize ║
║ • Try unlabeled data with teacher pseudo-labels ║
║ ║
╚════════════════════════════════════════════════════════════════════╝
"""
print(tips)
distillation_best_practices()
Exercises
Exercise 1: Temperature Sweep
Train students with different temperatures (1, 2, 4, 8, 16).
Plot accuracy vs temperature and find optimal value.
Exercise 2: Feature + Logit Distillation
Combine feature distillation with soft target distillation.
Compare to using each alone.
Exercise 3: Self-Distillation Experiment
Implement Born-Again Networks:
- Train for 3 generations
- Compare each generation’s accuracy
- Try different temperatures per generation
What’s Next?
Continual Learning
Learn new tasks without forgetting old ones
Quantization
Model compression with quantization