Knowledge Distillation
The Teacher-Student Framework
Train a small “student” model to mimic a large “teacher” model:

| Aspect | Teacher | Student |
|---|---|---|
| Size | Large (billions of params) | Small (millions) |
| Accuracy | High | Approaches teacher |
| Speed | Slow | Fast |
| Deployment | Expensive | Cheap |
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import Optional, Tuple, Dict, List, Callable
import numpy as np
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Core Distillation Methods
Soft Target Distillation
class SoftTargetDistillation:
"""
Original Knowledge Distillation (Hinton et al., 2015).
Key insight: Soft targets from teacher contain more information
    than hard labels. The "dark knowledge" in the probability distribution
    reveals class similarities and uncertainty.
    Loss = α * T^2 * KL(teacher || student) + (1 - α) * CE(student, labels)
"""
def __init__(
self,
temperature: float = 4.0,
alpha: float = 0.7
):
"""
Args:
temperature: Higher T -> softer probabilities
alpha: Weight for distillation loss vs hard label loss
"""
self.temperature = temperature
self.alpha = alpha
def distillation_loss(
self,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Compute distillation loss.
Args:
student_logits: [batch, num_classes]
teacher_logits: [batch, num_classes]
labels: [batch] optional hard labels
"""
T = self.temperature
# Soft targets from teacher
soft_teacher = F.softmax(teacher_logits / T, dim=-1)
# Soft predictions from student
soft_student = F.log_softmax(student_logits / T, dim=-1)
# KL divergence (scaled by T^2 for gradient magnitude)
distill_loss = F.kl_div(
soft_student,
soft_teacher,
reduction='batchmean'
) * (T ** 2)
if labels is not None:
# Hard label loss
hard_loss = F.cross_entropy(student_logits, labels)
return self.alpha * distill_loss + (1 - self.alpha) * hard_loss
return distill_loss
class DistillationTrainer:
"""
Complete training loop for knowledge distillation.
"""
def __init__(
self,
teacher: nn.Module,
student: nn.Module,
temperature: float = 4.0,
alpha: float = 0.7
):
self.teacher = teacher.eval()
self.student = student
self.distillation = SoftTargetDistillation(temperature, alpha)
# Freeze teacher
for param in self.teacher.parameters():
param.requires_grad = False
def train_step(
self,
inputs: torch.Tensor,
labels: torch.Tensor,
optimizer: optim.Optimizer
) -> Dict[str, float]:
"""Single training step."""
self.student.train()
optimizer.zero_grad()
# Get teacher predictions
with torch.no_grad():
teacher_logits = self.teacher(inputs)
# Get student predictions
student_logits = self.student(inputs)
# Compute loss
loss = self.distillation.distillation_loss(
student_logits, teacher_logits, labels
)
loss.backward()
optimizer.step()
# Metrics
with torch.no_grad():
student_acc = (student_logits.argmax(dim=-1) == labels).float().mean()
teacher_acc = (teacher_logits.argmax(dim=-1) == labels).float().mean()
return {
'loss': loss.item(),
'student_acc': student_acc.item(),
'teacher_acc': teacher_acc.item()
}
def train(
self,
train_loader,
optimizer: optim.Optimizer,
epochs: int = 10
):
"""Full training loop."""
for epoch in range(epochs):
total_loss = 0
total_acc = 0
for batch_idx, (inputs, labels) in enumerate(train_loader):
inputs, labels = inputs.to(device), labels.to(device)
metrics = self.train_step(inputs, labels, optimizer)
total_loss += metrics['loss']
total_acc += metrics['student_acc']
avg_loss = total_loss / len(train_loader)
avg_acc = total_acc / len(train_loader)
print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Acc={avg_acc:.4f}")
# Example usage
def create_example_models():
"""Create example teacher and student models."""
# Large teacher
teacher = nn.Sequential(
nn.Linear(784, 1024),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 10)
)
# Small student
student = nn.Sequential(
nn.Linear(784, 128),
nn.ReLU(),
nn.Linear(128, 10)
)
print(f"Teacher params: {sum(p.numel() for p in teacher.parameters()):,}")
print(f"Student params: {sum(p.numel() for p in student.parameters()):,}")
return teacher, student
create_example_models()
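A minimal end-to-end run of the trainer above, using synthetic random inputs as a stand-in for a real dataset (e.g., flattened MNIST images); in practice the teacher would be pre-trained before distillation begins.
from torch.utils.data import DataLoader, TensorDataset

# Toy data: 512 random "images" of 784 features with 10 classes
toy_data = TensorDataset(torch.randn(512, 784), torch.randint(0, 10, (512,)))
toy_loader = DataLoader(toy_data, batch_size=64, shuffle=True)

teacher, student = create_example_models()
teacher, student = teacher.to(device), student.to(device)

trainer = DistillationTrainer(teacher, student, temperature=4.0, alpha=0.7)
optimizer = optim.Adam(student.parameters(), lr=1e-3)
trainer.train(toy_loader, optimizer, epochs=2)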
Feature-Based Distillation
class FeatureDistillation(nn.Module):
"""
Distill intermediate feature representations.
Methods:
- FitNets: Match intermediate features directly
- Attention Transfer: Match attention maps
- Contrastive: Learn similar representations
"""
def __init__(
self,
teacher_channels: List[int],
student_channels: List[int],
method: str = 'fitnet' # 'fitnet', 'attention', 'contrastive'
):
super().__init__()
self.method = method
# Projection layers to match dimensions
self.projections = nn.ModuleList([
nn.Conv2d(s_ch, t_ch, 1) if s_ch != t_ch else nn.Identity()
for s_ch, t_ch in zip(student_channels, teacher_channels)
])
def fitnet_loss(
self,
student_features: List[torch.Tensor],
teacher_features: List[torch.Tensor]
) -> torch.Tensor:
"""
FitNets: Train student hidden layers to mimic teacher.
Uses L2 loss between intermediate representations.
"""
total_loss = 0
for i, (s_feat, t_feat) in enumerate(zip(student_features, teacher_features)):
# Project student features to teacher dimension
s_proj = self.projections[i](s_feat)
# L2 loss (normalized)
loss = F.mse_loss(
F.normalize(s_proj, dim=1),
F.normalize(t_feat, dim=1)
)
total_loss += loss
return total_loss / len(student_features)
def attention_transfer_loss(
self,
student_features: List[torch.Tensor],
teacher_features: List[torch.Tensor]
) -> torch.Tensor:
"""
Attention Transfer: Match spatial attention maps.
Attention map = sum of squared activations across channels.
"""
total_loss = 0
for s_feat, t_feat in zip(student_features, teacher_features):
# Compute attention maps
s_attn = self._compute_attention_map(s_feat)
t_attn = self._compute_attention_map(t_feat)
# Match attention patterns
loss = F.mse_loss(s_attn, t_attn)
total_loss += loss
return total_loss / len(student_features)
def _compute_attention_map(self, feature: torch.Tensor) -> torch.Tensor:
"""
Compute spatial attention map.
Sum of squared activations, normalized.
"""
# feature: [B, C, H, W]
attn = feature.pow(2).sum(dim=1) # [B, H, W]
# Normalize
attn = attn.view(attn.size(0), -1)
attn = F.normalize(attn, dim=1)
return attn
def contrastive_loss(
self,
student_features: torch.Tensor,
teacher_features: torch.Tensor,
temperature: float = 0.5
) -> torch.Tensor:
"""
Contrastive distillation: Learn similar representations.
Pull together (student, teacher) pairs from same sample,
push apart pairs from different samples.
"""
# Global average pooling if spatial
if student_features.dim() == 4:
student_features = student_features.mean(dim=[2, 3])
if teacher_features.dim() == 4:
teacher_features = teacher_features.mean(dim=[2, 3])
# Normalize
s_norm = F.normalize(student_features, dim=1)
t_norm = F.normalize(teacher_features, dim=1)
# Similarity matrix
batch_size = s_norm.size(0)
sim_matrix = torch.matmul(s_norm, t_norm.T) / temperature
# Labels: diagonal elements should be similar
labels = torch.arange(batch_size).to(sim_matrix.device)
# Cross entropy loss
loss = F.cross_entropy(sim_matrix, labels)
return loss
def forward(
self,
student_features: List[torch.Tensor],
teacher_features: List[torch.Tensor]
) -> torch.Tensor:
"""Compute feature distillation loss."""
if self.method == 'fitnet':
return self.fitnet_loss(student_features, teacher_features)
elif self.method == 'attention':
return self.attention_transfer_loss(student_features, teacher_features)
elif self.method == 'contrastive':
            return self.contrastive_loss(student_features[-1], teacher_features[-1])
        raise ValueError(f"Unknown distillation method: {self.method}")
class FeatureExtractingModel(nn.Module):
"""
Wrapper to extract intermediate features.
"""
def __init__(self, model: nn.Module, layer_names: List[str]):
super().__init__()
self.model = model
self.layer_names = layer_names
self.features = {}
# Register hooks
for name, module in model.named_modules():
if name in layer_names:
module.register_forward_hook(self._make_hook(name))
def _make_hook(self, name: str):
def hook(module, input, output):
self.features[name] = output
return hook
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
output = self.model(x)
return output, self.features.copy()
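The two helpers above compose naturally: wrap each network in FeatureExtractingModel to capture intermediate activations via hooks, then feed the matched feature lists to FeatureDistillation. A short sketch with toy CNNs; the layer names and channel counts are assumptions made only for this example.
# Toy CNNs; hook the post-ReLU activations (modules '1' and '3' in each Sequential)
teacher_cnn = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
    nn.Conv2d(64, 128, 3, padding=1), nn.ReLU()
)
student_cnn = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, 3, padding=1), nn.ReLU()
)
layer_names = ['1', '3']
teacher_fx = FeatureExtractingModel(teacher_cnn, layer_names)
student_fx = FeatureExtractingModel(student_cnn, layer_names)

feature_kd = FeatureDistillation(
    teacher_channels=[64, 128], student_channels=[16, 32], method='fitnet'
)

x = torch.randn(4, 3, 32, 32)
with torch.no_grad():
    _, t_feats = teacher_fx(x)
_, s_feats = student_fx(x)

loss = feature_kd(
    [s_feats[name] for name in layer_names],
    [t_feats[name] for name in layer_names]
)
print(f"FitNet feature loss: {loss.item():.4f}")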
Relation-Based Distillation
class RelationDistillation:
"""
Distill relationships between samples, not just outputs.
Key insight: Structure of representation space matters.
"""
@staticmethod
def distance_wise_loss(
student_features: torch.Tensor,
teacher_features: torch.Tensor
) -> torch.Tensor:
"""
Distance-wise distillation.
Preserve pairwise distances between samples.
"""
# Compute pairwise distances
s_dist = torch.cdist(student_features, student_features)
t_dist = torch.cdist(teacher_features, teacher_features)
# Normalize
s_dist = s_dist / s_dist.max()
t_dist = t_dist / t_dist.max()
# Huber loss for robustness
loss = F.smooth_l1_loss(s_dist, t_dist)
return loss
@staticmethod
def angle_wise_loss(
student_features: torch.Tensor,
teacher_features: torch.Tensor
) -> torch.Tensor:
"""
Angle-wise distillation.
Preserve angular relationships (correlations).
"""
# Compute Gram matrices (correlations)
s_gram = student_features @ student_features.T
t_gram = teacher_features @ teacher_features.T
# Normalize
s_gram = s_gram / (s_gram.norm() + 1e-8)
t_gram = t_gram / (t_gram.norm() + 1e-8)
loss = F.mse_loss(s_gram, t_gram)
return loss
@staticmethod
def similarity_preserving_loss(
student_features: torch.Tensor,
teacher_features: torch.Tensor
) -> torch.Tensor:
"""
Similarity-Preserving distillation.
Preserve which samples are similar/dissimilar.
"""
# Cosine similarity matrices
s_sim = F.cosine_similarity(
student_features.unsqueeze(1),
student_features.unsqueeze(0),
dim=2
)
t_sim = F.cosine_similarity(
teacher_features.unsqueeze(1),
teacher_features.unsqueeze(0),
dim=2
)
loss = F.mse_loss(s_sim, t_sim)
return loss
class CRDLoss(nn.Module):
"""
Contrastive Representation Distillation.
Use contrastive learning objective to transfer knowledge.
More effective than simple L2 matching.
Reference: "Contrastive Representation Distillation" (Tian et al., 2020)
"""
def __init__(
self,
student_dim: int,
teacher_dim: int,
feature_dim: int = 128,
num_negatives: int = 16384,
temperature: float = 0.07
):
super().__init__()
self.temperature = temperature
self.num_negatives = num_negatives
# Projection heads
self.student_proj = nn.Sequential(
nn.Linear(student_dim, feature_dim),
nn.ReLU(),
nn.Linear(feature_dim, feature_dim)
)
self.teacher_proj = nn.Sequential(
nn.Linear(teacher_dim, feature_dim),
nn.ReLU(),
nn.Linear(feature_dim, feature_dim)
)
# Memory bank for negatives
self.register_buffer(
'memory_bank',
torch.randn(num_negatives, feature_dim)
)
self.register_buffer('memory_ptr', torch.zeros(1, dtype=torch.long))
def update_memory(self, features: torch.Tensor):
"""Update memory bank with new features."""
batch_size = features.size(0)
ptr = int(self.memory_ptr)
# Replace old features
if ptr + batch_size > self.num_negatives:
ptr = 0
self.memory_bank[ptr:ptr + batch_size] = features.detach()
self.memory_ptr[0] = (ptr + batch_size) % self.num_negatives
def forward(
self,
student_features: torch.Tensor,
teacher_features: torch.Tensor
) -> torch.Tensor:
"""Compute CRD loss."""
# Project features
s_proj = F.normalize(self.student_proj(student_features), dim=1)
t_proj = F.normalize(self.teacher_proj(teacher_features), dim=1)
batch_size = s_proj.size(0)
# Positive pairs: student-teacher from same sample
pos_sim = torch.sum(s_proj * t_proj, dim=1, keepdim=True)
# Negative pairs: student against memory bank
neg_sim = s_proj @ self.memory_bank.T
# InfoNCE loss
logits = torch.cat([pos_sim, neg_sim], dim=1) / self.temperature
labels = torch.zeros(batch_size, dtype=torch.long, device=logits.device)
loss = F.cross_entropy(logits, labels)
# Update memory bank
self.update_memory(t_proj)
return loss
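A quick sketch of the relation-based losses and CRD on random features. Because these losses compare relationships between samples rather than raw activations, the student and teacher feature widths do not need to match; the dimensions and memory-bank size below are arbitrary choices for illustration.
s_feat = torch.randn(8, 64)    # student features for a batch of 8
t_feat = torch.randn(8, 256)   # teacher features for the same batch

print("distance-wise :", RelationDistillation.distance_wise_loss(s_feat, t_feat).item())
print("angle-wise    :", RelationDistillation.angle_wise_loss(s_feat, t_feat).item())
print("similarity    :", RelationDistillation.similarity_preserving_loss(s_feat, t_feat).item())

# CRD adds projection heads and a memory bank of negatives (sized to the dataset in practice)
crd = CRDLoss(student_dim=64, teacher_dim=256, feature_dim=128, num_negatives=1024)
print("CRD           :", crd(s_feat, t_feat).item())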
Self-Distillation
class SelfDistillation(nn.Module):
"""
Distill knowledge within the same network.
Methods:
- Deep supervision: Intermediate classifiers
- Born-again networks: Train same architecture iteratively
- Be Your Own Teacher: Ensemble of augmented views
"""
def __init__(
self,
backbone: nn.Module,
num_classes: int,
hidden_dims: List[int]
):
super().__init__()
self.backbone = backbone
# Auxiliary classifiers at intermediate layers
self.aux_classifiers = nn.ModuleList([
nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(dim, num_classes)
)
for dim in hidden_dims
])
# Final classifier
self.classifier = nn.Linear(hidden_dims[-1], num_classes)
def forward(
self,
x: torch.Tensor,
return_features: bool = False
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Forward with auxiliary predictions.
"""
features = []
aux_logits = []
# Get features at each layer (implement based on backbone)
for layer, aux_clf in zip(self.backbone, self.aux_classifiers):
x = layer(x)
features.append(x)
# Auxiliary prediction
aux_out = aux_clf(x)
aux_logits.append(aux_out)
# Final prediction
        # Pool spatially so the flattened feature width matches hidden_dims[-1]
        final_logits = self.classifier(F.adaptive_avg_pool2d(x, 1).flatten(1))
if return_features:
return final_logits, aux_logits, features
return final_logits, aux_logits
def self_distillation_loss(
self,
final_logits: torch.Tensor,
aux_logits: List[torch.Tensor],
labels: torch.Tensor,
temperature: float = 4.0
) -> torch.Tensor:
"""
Compute self-distillation loss.
Deeper layers teach shallower layers.
"""
# Final layer loss
final_loss = F.cross_entropy(final_logits, labels)
# Distillation losses
distill_loss = 0
soft_final = F.softmax(final_logits / temperature, dim=-1)
for aux_logit in aux_logits:
# Distill from final to aux
soft_aux = F.log_softmax(aux_logit / temperature, dim=-1)
kl_loss = F.kl_div(soft_aux, soft_final.detach(), reduction='batchmean')
distill_loss += kl_loss * (temperature ** 2)
# Also use hard labels for aux
distill_loss += 0.5 * F.cross_entropy(aux_logit, labels)
distill_loss /= len(aux_logits)
return final_loss + 0.5 * distill_loss
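# Usage sketch (illustrative assumptions only): SelfDistillation expects an
# iterable backbone whose stages align with hidden_dims, so a toy two-stage
# CNN is used here.
sd_backbone = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)),
    nn.Sequential(nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)),
)
sd_model = SelfDistillation(sd_backbone, num_classes=10, hidden_dims=[32, 64])
sd_x = torch.randn(4, 3, 32, 32)
sd_final, sd_aux = sd_model(sd_x)
sd_loss = sd_model.self_distillation_loss(sd_final, sd_aux, torch.randint(0, 10, (4,)))
print(f"Self-distillation loss: {sd_loss.item():.4f}")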
class BornAgainNetworks:
"""
Born-Again Networks: Iterative self-distillation.
1. Train model from scratch (generation 1)
2. Use model as teacher to train same architecture (generation 2)
3. Repeat...
Often improves over vanilla training!
"""
def __init__(
self,
model_fn: Callable[[], nn.Module],
temperature: float = 4.0
):
self.model_fn = model_fn
self.temperature = temperature
self.generations = []
def train_generation(
self,
train_loader,
epochs: int,
lr: float = 0.001,
teacher: Optional[nn.Module] = None
) -> nn.Module:
"""Train one generation."""
student = self.model_fn().to(device)
optimizer = optim.Adam(student.parameters(), lr=lr)
distiller = SoftTargetDistillation(self.temperature) if teacher else None
if teacher:
teacher.eval()
for param in teacher.parameters():
param.requires_grad = False
for epoch in range(epochs):
student.train()
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
student_logits = student(inputs)
if teacher:
with torch.no_grad():
teacher_logits = teacher(inputs)
loss = distiller.distillation_loss(
student_logits, teacher_logits, labels
)
else:
loss = F.cross_entropy(student_logits, labels)
loss.backward()
optimizer.step()
self.generations.append(student)
return student
def train(
self,
train_loader,
num_generations: int = 3,
epochs_per_gen: int = 10
) -> nn.Module:
"""Train multiple generations."""
teacher = None
for gen in range(num_generations):
print(f"Training generation {gen + 1}...")
teacher = self.train_generation(
train_loader, epochs_per_gen, teacher=teacher
)
return teacher
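A brief usage sketch of BornAgainNetworks, again on synthetic data purely so the generational loop runs end to end; real experiments would use a full dataset and many more epochs per generation.
from torch.utils.data import DataLoader, TensorDataset

ban_loader = DataLoader(
    TensorDataset(torch.randn(256, 784), torch.randint(0, 10, (256,))),
    batch_size=32, shuffle=True
)
ban = BornAgainNetworks(
    model_fn=lambda: nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)),
    temperature=4.0
)
final_generation = ban.train(ban_loader, num_generations=2, epochs_per_gen=1)
print(f"Trained {len(ban.generations)} generations")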
Task-Specific Distillation
class DetectionDistillation:
"""
Distillation for object detection models.
Challenges:
- Multi-task (classification + localization)
- Variable number of objects
- Feature pyramid networks
"""
@staticmethod
def classification_distillation(
student_cls: torch.Tensor,
teacher_cls: torch.Tensor,
temperature: float = 2.0
) -> torch.Tensor:
"""Distill classification heads."""
        soft_teacher = torch.sigmoid(teacher_cls / temperature)
        soft_student = torch.sigmoid(student_cls / temperature)
loss = F.binary_cross_entropy(
soft_student, soft_teacher.detach()
) * (temperature ** 2)
return loss
@staticmethod
def regression_distillation(
student_box: torch.Tensor,
teacher_box: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Distill bounding box regression."""
loss = F.smooth_l1_loss(student_box, teacher_box, reduction='none')
if mask is not None:
loss = loss * mask.unsqueeze(-1)
return loss.mean()
@staticmethod
def fpn_feature_distillation(
student_features: Dict[str, torch.Tensor],
teacher_features: Dict[str, torch.Tensor]
) -> torch.Tensor:
"""Distill FPN features at each scale."""
total_loss = 0
for level in student_features.keys():
s_feat = student_features[level]
t_feat = teacher_features[level]
            # Resize spatially if needed; a channel mismatch would additionally
            # require a learned 1x1 conv adapter, so such levels are skipped here.
            if s_feat.shape[-2:] != t_feat.shape[-2:]:
                s_feat = F.interpolate(s_feat, size=t_feat.shape[-2:], mode='bilinear', align_corners=False)
            if s_feat.shape != t_feat.shape:
                continue
# L2 loss with attention masking
importance = t_feat.abs().mean(dim=1, keepdim=True)
importance = importance / importance.max()
loss = F.mse_loss(s_feat * importance, t_feat * importance)
total_loss += loss
return total_loss / len(student_features)
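# Usage sketch with random tensors (shapes are illustrative assumptions:
# 100 anchors, 80 classes, 4 box coordinates, two FPN levels).
det_s_cls, det_t_cls = torch.randn(100, 80), torch.randn(100, 80)
det_s_box, det_t_box = torch.randn(100, 4), torch.randn(100, 4)
fg_mask = (torch.rand(100) > 0.5).float()  # distill boxes only at foreground anchors

cls_kd = DetectionDistillation.classification_distillation(det_s_cls, det_t_cls)
box_kd = DetectionDistillation.regression_distillation(det_s_box, det_t_box, mask=fg_mask)
fpn_kd = DetectionDistillation.fpn_feature_distillation(
    {'p3': torch.randn(2, 256, 32, 32), 'p4': torch.randn(2, 256, 16, 16)},
    {'p3': torch.randn(2, 256, 32, 32), 'p4': torch.randn(2, 256, 16, 16)},
)
print(cls_kd.item(), box_kd.item(), fpn_kd.item())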
class NLPDistillation:
"""
Distillation for language models.
Used in: DistilBERT, TinyBERT, MiniLM
"""
@staticmethod
def embedding_distillation(
student_emb: torch.Tensor,
teacher_emb: torch.Tensor
) -> torch.Tensor:
"""Distill word embeddings."""
return F.mse_loss(student_emb, teacher_emb)
@staticmethod
def attention_distillation(
student_attn: torch.Tensor, # [B, H, L, L]
teacher_attn: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Distill attention patterns.
Student learns to attend like teacher.
"""
# KL divergence on attention distributions
s_attn = student_attn.log_softmax(dim=-1)
t_attn = teacher_attn.softmax(dim=-1)
loss = F.kl_div(s_attn, t_attn, reduction='none')
if attention_mask is not None:
# Mask padded positions
mask = attention_mask.unsqueeze(1).unsqueeze(2)
loss = loss * mask
return loss.mean()
@staticmethod
def hidden_state_distillation(
student_hidden: List[torch.Tensor],
teacher_hidden: List[torch.Tensor],
layer_mapping: Optional[Dict[int, int]] = None
) -> torch.Tensor:
"""
Distill hidden states.
Common mappings for unequal depths:
- Uniform: student[i] <- teacher[i * T/S]
- Skip: student[i] <- teacher[2*i] for half-depth
"""
if layer_mapping is None:
# Default: uniform mapping
ratio = len(teacher_hidden) // len(student_hidden)
layer_mapping = {i: i * ratio for i in range(len(student_hidden))}
total_loss = 0
for s_idx, t_idx in layer_mapping.items():
s_h = student_hidden[s_idx]
t_h = teacher_hidden[t_idx]
loss = F.mse_loss(s_h, t_h)
total_loss += loss
return total_loss / len(layer_mapping)
class DistilBERTTrainer:
"""
Complete DistilBERT-style training.
Triple loss:
1. Masked LM loss (task)
2. Cosine embedding loss (hidden states)
3. Distillation loss (soft targets)
"""
def __init__(
self,
teacher: nn.Module,
student: nn.Module,
alpha_ce: float = 0.5,
alpha_mlm: float = 0.5,
alpha_cos: float = 0.0,
temperature: float = 2.0
):
self.teacher = teacher.eval()
self.student = student
self.alpha_ce = alpha_ce
self.alpha_mlm = alpha_mlm
self.alpha_cos = alpha_cos
self.temperature = temperature
def compute_loss(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: torch.Tensor # MLM labels
) -> torch.Tensor:
"""Compute DistilBERT loss."""
# Teacher forward
with torch.no_grad():
teacher_out = self.teacher(
input_ids,
attention_mask=attention_mask,
output_hidden_states=True
)
# Student forward
student_out = self.student(
input_ids,
attention_mask=attention_mask,
output_hidden_states=True
)
# 1. Soft target distillation
soft_teacher = F.softmax(teacher_out.logits / self.temperature, dim=-1)
soft_student = F.log_softmax(student_out.logits / self.temperature, dim=-1)
ce_loss = F.kl_div(
soft_student, soft_teacher, reduction='batchmean'
) * (self.temperature ** 2)
# 2. MLM loss
mlm_loss = F.cross_entropy(
student_out.logits.view(-1, student_out.logits.size(-1)),
labels.view(-1),
ignore_index=-100
)
# 3. Cosine embedding loss on last hidden
cos_loss = 1 - F.cosine_similarity(
student_out.hidden_states[-1],
teacher_out.hidden_states[-1],
dim=-1
).mean()
total_loss = (
self.alpha_ce * ce_loss +
self.alpha_mlm * mlm_loss +
self.alpha_cos * cos_loss
)
return total_loss
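The layer mapping is the detail that most often causes confusion when teacher and student depths differ. The sketch below exercises NLPDistillation.hidden_state_distillation on random BERT-base-sized tensors (12 teacher layers, 6 student layers, hidden size 768 are assumed); DistilBERTTrainer itself expects Hugging Face-style outputs exposing .logits and .hidden_states.
# Random [batch, seq_len, hidden] tensors stand in for real hidden states.
teacher_hidden = [torch.randn(2, 16, 768) for _ in range(12)]
student_hidden = [torch.randn(2, 16, 768) for _ in range(6)]

# Default uniform mapping: student layer i <- teacher layer 2*i
uniform_loss = NLPDistillation.hidden_state_distillation(student_hidden, teacher_hidden)

# Explicit mapping distilling only the upper half of the teacher
upper_map = {i: 6 + i for i in range(6)}
upper_loss = NLPDistillation.hidden_state_distillation(
    student_hidden, teacher_hidden, layer_mapping=upper_map
)
print(uniform_loss.item(), upper_loss.item())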
Advanced Distillation Techniques
class ProgressiveDistillation:
"""
Gradually increase difficulty during distillation.
Start with easy examples, progress to harder ones.
"""
def __init__(self, num_stages: int = 5):
self.num_stages = num_stages
self.current_stage = 0
def get_temperature(self) -> float:
"""Temperature decreases over stages."""
max_temp = 10.0
min_temp = 1.0
return max_temp - (max_temp - min_temp) * (self.current_stage / self.num_stages)
def get_alpha(self) -> float:
"""Distillation weight increases over stages."""
return 0.5 + 0.5 * (self.current_stage / self.num_stages)
def advance_stage(self):
"""Move to next stage."""
if self.current_stage < self.num_stages:
self.current_stage += 1
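# Illustration of the schedule: over 5 stages the temperature anneals from 10
# toward 1 while the distillation weight alpha ramps from 0.5 toward 1.0.
schedule = ProgressiveDistillation(num_stages=5)
for stage in range(schedule.num_stages + 1):
    print(f"stage {stage}: T={schedule.get_temperature():.1f}, alpha={schedule.get_alpha():.2f}")
    schedule.advance_stage()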
class OnlineDistillation(nn.Module):
"""
Online mutual distillation.
Multiple students teach each other during training.
No pre-trained teacher needed!
Reference: "Deep Mutual Learning" (Zhang et al., 2018)
"""
def __init__(
self,
students: List[nn.Module],
temperature: float = 4.0
):
super().__init__()
self.students = nn.ModuleList(students)
self.temperature = temperature
def forward(
self,
x: torch.Tensor,
labels: torch.Tensor
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Forward pass with mutual learning.
Each student learns from:
1. Ground truth labels
2. Average of other students' predictions
"""
# Get all predictions
all_logits = [student(x) for student in self.students]
losses = []
for i, logits in enumerate(all_logits):
# Task loss
task_loss = F.cross_entropy(logits, labels)
# Distillation from peers
peer_logits = [all_logits[j] for j in range(len(all_logits)) if j != i]
peer_avg = torch.stack(peer_logits).mean(dim=0)
soft_peer = F.softmax(peer_avg / self.temperature, dim=-1)
soft_self = F.log_softmax(logits / self.temperature, dim=-1)
kl_loss = F.kl_div(
soft_self, soft_peer.detach(), reduction='batchmean'
) * (self.temperature ** 2)
loss = task_loss + kl_loss
losses.append(loss)
total_loss = sum(losses)
return total_loss, all_logits
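# Usage sketch of mutual (online) distillation: two peer students of different
# widths teach each other on a random batch; no pre-trained teacher is needed.
peers = OnlineDistillation([
    nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)),
    nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Linear(64, 10)),
])
peer_x = torch.randn(32, 784)
peer_y = torch.randint(0, 10, (32,))
mutual_loss, peer_logits = peers(peer_x, peer_y)
mutual_loss.backward()  # gradients flow into both students
print(f"Mutual learning loss: {mutual_loss.item():.4f}")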
# Best practices summary
def distillation_best_practices():
"""Print distillation best practices."""
tips = """
╔════════════════════════════════════════════════════════════════════╗
║ KNOWLEDGE DISTILLATION BEST PRACTICES ║
╠════════════════════════════════════════════════════════════════════╣
║ ║
║ TEMPERATURE SELECTION ║
║ • T=1: Sharp distributions (less knowledge transfer) ║
║ • T=4-8: Good balance for most tasks ║
║ • T>10: Very soft distributions (may lose discrimination) ║
║ ║
║ LOSS BALANCING (α) ║
║ • α=0.5-0.7: Typical starting point ║
║ • Higher α: Rely more on teacher (good teacher) ║
║ • Lower α: Rely more on labels (noisy teacher) ║
║ ║
║ STUDENT SIZE ║
║ • 1/3 to 1/2 of teacher: Good compression with minimal loss ║
║ • <1/10 of teacher: Significant accuracy drop expected ║
║ ║
║ WHAT TO DISTILL ║
║ • Always: Output logits (soft targets) ║
║ • Often helpful: Intermediate features ║
║ • Sometimes helpful: Attention patterns ║
║ ║
║ TRAINING TIPS ║
║ • Use longer training for student than normal ║
║ • Data augmentation helps student generalize ║
║ • Try unlabeled data with teacher pseudo-labels ║
║ ║
╚════════════════════════════════════════════════════════════════════╝
"""
print(tips)
distillation_best_practices()
Exercises
Exercise 1: Temperature Sweep
Train students with different temperatures (1, 2, 4, 8, 16).
Plot accuracy vs. temperature and identify the optimal value.
Exercise 2: Feature + Logit Distillation
Combine feature distillation with soft target distillation.
Compare to using each alone.
Exercise 3: Self-Distillation Experiment
Implement Born-Again Networks:
- Train for 3 generations
- Compare each generation’s accuracy
- Try different temperatures per generation