Knowledge Distillation

The Teacher-Student Framework

Train a small “student” model to mimic a large “teacher” model:
| Aspect | Teacher | Student |
|---|---|---|
| Size | Large (billions of params) | Small (millions) |
| Accuracy | High | Approaches teacher |
| Speed | Slow | Fast |
| Deployment | Expensive | Cheap |
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import Optional, Tuple, Dict, List, Callable
import numpy as np

torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Core Distillation Methods

Soft Target Distillation

class SoftTargetDistillation:
    """
    Original Knowledge Distillation (Hinton et al., 2015).
    
    Key insight: Soft targets from teacher contain more information
    than hard labels. The "dark knowledge" in probability distribution
    reveals class similarities and uncertainty.
    
    Loss = α * T^2 * KL(teacher || student) + (1-α) * CE(student, labels)
    """
    
    def __init__(
        self,
        temperature: float = 4.0,
        alpha: float = 0.7
    ):
        """
        Args:
            temperature: Higher T -> softer probabilities
            alpha: Weight for distillation loss vs hard label loss
        """
        self.temperature = temperature
        self.alpha = alpha
    
    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute distillation loss.
        
        Args:
            student_logits: [batch, num_classes]
            teacher_logits: [batch, num_classes]
            labels: [batch] optional hard labels
        """
        T = self.temperature
        
        # Soft targets from teacher
        soft_teacher = F.softmax(teacher_logits / T, dim=-1)
        
        # Soft predictions from student
        soft_student = F.log_softmax(student_logits / T, dim=-1)
        
        # KL divergence (scaled by T^2 for gradient magnitude)
        distill_loss = F.kl_div(
            soft_student,
            soft_teacher,
            reduction='batchmean'
        ) * (T ** 2)
        
        if labels is not None:
            # Hard label loss
            hard_loss = F.cross_entropy(student_logits, labels)
            
            return self.alpha * distill_loss + (1 - self.alpha) * hard_loss
        
        return distill_loss
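
Before wiring this into a trainer, a quick toy check on random logits (shapes and values invented for illustration) shows the loss in action and how temperature controls the softness of the teacher's targets:

# Toy sanity check of SoftTargetDistillation (illustrative only)
kd = SoftTargetDistillation(temperature=4.0, alpha=0.7)
demo_student_logits = torch.randn(8, 10)
demo_teacher_logits = torch.randn(8, 10)
demo_labels = torch.randint(0, 10, (8,))

print("KD loss:", kd.distillation_loss(demo_student_logits, demo_teacher_logits, demo_labels).item())

# Higher temperature -> softer (higher-entropy) teacher distribution
for T in (1.0, 4.0, 8.0):
    probs = F.softmax(demo_teacher_logits / T, dim=-1)
    entropy = -(probs * probs.clamp_min(1e-12).log()).sum(dim=-1).mean()
    print(f"T={T}: mean teacher entropy = {entropy:.3f}")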


class DistillationTrainer:
    """
    Complete training loop for knowledge distillation.
    """
    
    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        temperature: float = 4.0,
        alpha: float = 0.7
    ):
        self.teacher = teacher.eval()
        self.student = student
        self.distillation = SoftTargetDistillation(temperature, alpha)
        
        # Freeze teacher
        for param in self.teacher.parameters():
            param.requires_grad = False
    
    def train_step(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor,
        optimizer: optim.Optimizer
    ) -> Dict[str, float]:
        """Single training step."""
        self.student.train()
        
        optimizer.zero_grad()
        
        # Get teacher predictions
        with torch.no_grad():
            teacher_logits = self.teacher(inputs)
        
        # Get student predictions
        student_logits = self.student(inputs)
        
        # Compute loss
        loss = self.distillation.distillation_loss(
            student_logits, teacher_logits, labels
        )
        
        loss.backward()
        optimizer.step()
        
        # Metrics
        with torch.no_grad():
            student_acc = (student_logits.argmax(dim=-1) == labels).float().mean()
            teacher_acc = (teacher_logits.argmax(dim=-1) == labels).float().mean()
        
        return {
            'loss': loss.item(),
            'student_acc': student_acc.item(),
            'teacher_acc': teacher_acc.item()
        }
    
    def train(
        self,
        train_loader,
        optimizer: optim.Optimizer,
        epochs: int = 10
    ):
        """Full training loop."""
        for epoch in range(epochs):
            total_loss = 0
            total_acc = 0
            
            for batch_idx, (inputs, labels) in enumerate(train_loader):
                inputs, labels = inputs.to(device), labels.to(device)
                
                metrics = self.train_step(inputs, labels, optimizer)
                total_loss += metrics['loss']
                total_acc += metrics['student_acc']
            
            avg_loss = total_loss / len(train_loader)
            avg_acc = total_acc / len(train_loader)
            print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Acc={avg_acc:.4f}")


# Example usage
def create_example_models():
    """Create example teacher and student models."""
    
    # Large teacher
    teacher = nn.Sequential(
        nn.Linear(784, 1024),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, 256),
        nn.ReLU(),
        nn.Linear(256, 10)
    )
    
    # Small student
    student = nn.Sequential(
        nn.Linear(784, 128),
        nn.ReLU(),
        nn.Linear(128, 10)
    )
    
    print(f"Teacher params: {sum(p.numel() for p in teacher.parameters()):,}")
    print(f"Student params: {sum(p.numel() for p in student.parameters()):,}")
    
    return teacher, student

teacher, student = create_example_models()
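
To see the trainer end to end, here is a minimal sketch that distills the example teacher into the example student on synthetic 784-dimensional inputs (random tensors standing in for a real dataset; in practice the teacher would be trained, or loaded pre-trained, before distillation):

from torch.utils.data import DataLoader, TensorDataset

# Synthetic classification data (illustrative only)
synthetic_loader = DataLoader(
    TensorDataset(torch.randn(256, 784), torch.randint(0, 10, (256,))),
    batch_size=32,
    shuffle=True
)

trainer = DistillationTrainer(teacher.to(device), student.to(device), temperature=4.0, alpha=0.7)
optimizer = optim.Adam(student.parameters(), lr=1e-3)
trainer.train(synthetic_loader, optimizer, epochs=1)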

Feature-Based Distillation

class FeatureDistillation(nn.Module):
    """
    Distill intermediate feature representations.
    
    Methods:
    - FitNets: Match intermediate features directly
    - Attention Transfer: Match attention maps
    - Contrastive: Learn similar representations
    """
    
    def __init__(
        self,
        teacher_channels: List[int],
        student_channels: List[int],
        method: str = 'fitnet'  # 'fitnet', 'attention', 'contrastive'
    ):
        super().__init__()
        
        self.method = method
        
        # Projection layers to match dimensions
        self.projections = nn.ModuleList([
            nn.Conv2d(s_ch, t_ch, 1) if s_ch != t_ch else nn.Identity()
            for s_ch, t_ch in zip(student_channels, teacher_channels)
        ])
    
    def fitnet_loss(
        self,
        student_features: List[torch.Tensor],
        teacher_features: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        FitNets: Train student hidden layers to mimic teacher.
        
        Uses L2 loss between intermediate representations.
        """
        total_loss = 0
        
        for i, (s_feat, t_feat) in enumerate(zip(student_features, teacher_features)):
            # Project student features to teacher dimension
            s_proj = self.projections[i](s_feat)
            
            # L2 loss (normalized)
            loss = F.mse_loss(
                F.normalize(s_proj, dim=1),
                F.normalize(t_feat, dim=1)
            )
            total_loss += loss
        
        return total_loss / len(student_features)
    
    def attention_transfer_loss(
        self,
        student_features: List[torch.Tensor],
        teacher_features: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Attention Transfer: Match spatial attention maps.
        
        Attention map = sum of squared activations across channels.
        """
        total_loss = 0
        
        for s_feat, t_feat in zip(student_features, teacher_features):
            # Compute attention maps
            s_attn = self._compute_attention_map(s_feat)
            t_attn = self._compute_attention_map(t_feat)
            
            # Match attention patterns
            loss = F.mse_loss(s_attn, t_attn)
            total_loss += loss
        
        return total_loss / len(student_features)
    
    def _compute_attention_map(self, feature: torch.Tensor) -> torch.Tensor:
        """
        Compute spatial attention map.
        
        Sum of squared activations, normalized.
        """
        # feature: [B, C, H, W]
        attn = feature.pow(2).sum(dim=1)  # [B, H, W]
        
        # Normalize
        attn = attn.view(attn.size(0), -1)
        attn = F.normalize(attn, dim=1)
        
        return attn
    
    def contrastive_loss(
        self,
        student_features: torch.Tensor,
        teacher_features: torch.Tensor,
        temperature: float = 0.5
    ) -> torch.Tensor:
        """
        Contrastive distillation: Learn similar representations.
        
        Pull together (student, teacher) pairs from same sample,
        push apart pairs from different samples.
        """
        # Global average pooling if spatial
        if student_features.dim() == 4:
            student_features = student_features.mean(dim=[2, 3])
        if teacher_features.dim() == 4:
            teacher_features = teacher_features.mean(dim=[2, 3])
        
        # Normalize
        s_norm = F.normalize(student_features, dim=1)
        t_norm = F.normalize(teacher_features, dim=1)
        
        # Similarity matrix
        batch_size = s_norm.size(0)
        sim_matrix = torch.matmul(s_norm, t_norm.T) / temperature
        
        # Labels: diagonal elements should be similar
        labels = torch.arange(batch_size).to(sim_matrix.device)
        
        # Cross entropy loss
        loss = F.cross_entropy(sim_matrix, labels)
        
        return loss
    
    def forward(
        self,
        student_features: List[torch.Tensor],
        teacher_features: List[torch.Tensor]
    ) -> torch.Tensor:
        """Compute feature distillation loss."""
        if self.method == 'fitnet':
            return self.fitnet_loss(student_features, teacher_features)
        elif self.method == 'attention':
            return self.attention_transfer_loss(student_features, teacher_features)
        elif self.method == 'contrastive':
            return self.contrastive_loss(student_features[-1], teacher_features[-1])
        else:
            raise ValueError(f"Unknown distillation method: {self.method}")


class FeatureExtractingModel(nn.Module):
    """
    Wrapper to extract intermediate features.
    """
    
    def __init__(self, model: nn.Module, layer_names: List[str]):
        super().__init__()
        self.model = model
        self.layer_names = layer_names
        self.features = {}
        
        # Register hooks
        for name, module in model.named_modules():
            if name in layer_names:
                module.register_forward_hook(self._make_hook(name))
    
    def _make_hook(self, name: str):
        def hook(module, input, output):
            self.features[name] = output
        return hook
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        output = self.model(x)
        return output, self.features.copy()
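
A short sketch of the two pieces together: hook features out of two small CNNs (the architectures and channel counts are invented for the example) and match them with the FitNet loss. Note that the projection layers inside FeatureDistillation are trainable and would be optimized jointly with the student:

# Toy teacher/student CNNs (illustrative only)
cnn_teacher = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
    nn.Conv2d(32, 64, 3, padding=1), nn.ReLU()
)
cnn_student = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, 3, padding=1), nn.ReLU()
)

# Hook the two conv layers of each model (nn.Sequential names them '0' and '2')
teacher_extractor = FeatureExtractingModel(cnn_teacher, layer_names=['0', '2'])
student_extractor = FeatureExtractingModel(cnn_student, layer_names=['0', '2'])

feat_distill = FeatureDistillation(
    teacher_channels=[32, 64], student_channels=[16, 32], method='fitnet'
)

x = torch.randn(4, 3, 32, 32)
_, t_feats = teacher_extractor(x)
_, s_feats = student_extractor(x)

loss = feat_distill([s_feats['0'], s_feats['2']], [t_feats['0'], t_feats['2']])
print("FitNet feature loss:", loss.item())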

Relation-Based Distillation

class RelationDistillation:
    """
    Distill relationships between samples, not just outputs.
    
    Key insight: Structure of representation space matters.
    """
    
    @staticmethod
    def distance_wise_loss(
        student_features: torch.Tensor,
        teacher_features: torch.Tensor
    ) -> torch.Tensor:
        """
        Distance-wise distillation.
        
        Preserve pairwise distances between samples.
        """
        # Compute pairwise distances
        s_dist = torch.cdist(student_features, student_features)
        t_dist = torch.cdist(teacher_features, teacher_features)
        
        # Normalize
        s_dist = s_dist / s_dist.max()
        t_dist = t_dist / t_dist.max()
        
        # Huber loss for robustness
        loss = F.smooth_l1_loss(s_dist, t_dist)
        
        return loss
    
    @staticmethod
    def angle_wise_loss(
        student_features: torch.Tensor,
        teacher_features: torch.Tensor
    ) -> torch.Tensor:
        """
        Angle-wise distillation.
        
        Preserve angular relationships (correlations).
        """
        # Compute Gram matrices (correlations)
        s_gram = student_features @ student_features.T
        t_gram = teacher_features @ teacher_features.T
        
        # Normalize
        s_gram = s_gram / (s_gram.norm() + 1e-8)
        t_gram = t_gram / (t_gram.norm() + 1e-8)
        
        loss = F.mse_loss(s_gram, t_gram)
        
        return loss
    
    @staticmethod
    def similarity_preserving_loss(
        student_features: torch.Tensor,
        teacher_features: torch.Tensor
    ) -> torch.Tensor:
        """
        Similarity-Preserving distillation.
        
        Preserve which samples are similar/dissimilar.
        """
        # Cosine similarity matrices
        s_sim = F.cosine_similarity(
            student_features.unsqueeze(1),
            student_features.unsqueeze(0),
            dim=2
        )
        t_sim = F.cosine_similarity(
            teacher_features.unsqueeze(1),
            teacher_features.unsqueeze(0),
            dim=2
        )
        
        loss = F.mse_loss(s_sim, t_sim)
        
        return loss
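
Because these losses compare batch-level structure (B × B distance, Gram, and similarity matrices), the student and teacher embedding dimensions do not need to match. A toy check on random embeddings (sizes invented for illustration):

# Random embeddings: 16 samples, different feature widths (illustrative only)
student_emb = torch.randn(16, 64)
teacher_emb = torch.randn(16, 128)

print("distance-wise: ", RelationDistillation.distance_wise_loss(student_emb, teacher_emb).item())
print("angle-wise:    ", RelationDistillation.angle_wise_loss(student_emb, teacher_emb).item())
print("similarity:    ", RelationDistillation.similarity_preserving_loss(student_emb, teacher_emb).item())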


class CRDLoss(nn.Module):
    """
    Contrastive Representation Distillation.
    
    Use contrastive learning objective to transfer knowledge.
    More effective than simple L2 matching.
    
    Reference: "Contrastive Representation Distillation" (Tian et al., 2020)
    """
    
    def __init__(
        self,
        student_dim: int,
        teacher_dim: int,
        feature_dim: int = 128,
        num_negatives: int = 16384,
        temperature: float = 0.07
    ):
        super().__init__()
        
        self.temperature = temperature
        self.num_negatives = num_negatives
        
        # Projection heads
        self.student_proj = nn.Sequential(
            nn.Linear(student_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, feature_dim)
        )
        
        self.teacher_proj = nn.Sequential(
            nn.Linear(teacher_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, feature_dim)
        )
        
        # Memory bank for negatives
        self.register_buffer(
            'memory_bank',
            F.normalize(torch.randn(num_negatives, feature_dim), dim=1)
        )
        self.register_buffer('memory_ptr', torch.zeros(1, dtype=torch.long))
    
    def update_memory(self, features: torch.Tensor):
        """Update memory bank with new features."""
        batch_size = features.size(0)
        ptr = int(self.memory_ptr)
        
        # Replace old features
        if ptr + batch_size > self.num_negatives:
            ptr = 0
        
        self.memory_bank[ptr:ptr + batch_size] = features.detach()
        self.memory_ptr[0] = (ptr + batch_size) % self.num_negatives
    
    def forward(
        self,
        student_features: torch.Tensor,
        teacher_features: torch.Tensor
    ) -> torch.Tensor:
        """Compute CRD loss."""
        # Project features
        s_proj = F.normalize(self.student_proj(student_features), dim=1)
        t_proj = F.normalize(self.teacher_proj(teacher_features), dim=1)
        
        batch_size = s_proj.size(0)
        
        # Positive pairs: student-teacher from same sample
        pos_sim = torch.sum(s_proj * t_proj, dim=1, keepdim=True)
        
        # Negative pairs: student against memory bank
        neg_sim = s_proj @ self.memory_bank.T
        
        # InfoNCE loss
        logits = torch.cat([pos_sim, neg_sim], dim=1) / self.temperature
        labels = torch.zeros(batch_size, dtype=torch.long, device=logits.device)
        
        loss = F.cross_entropy(logits, labels)
        
        # Update memory bank
        self.update_memory(t_proj)
        
        return loss
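
A minimal sketch of the memory-bank mechanics with made-up feature dimensions and a deliberately tiny bank (a real run keeps the default of 16k negatives and updates the bank every step):

# Toy CRD setup (illustrative only)
crd = CRDLoss(student_dim=128, teacher_dim=512, feature_dim=64, num_negatives=1024).to(device)

student_feat = torch.randn(32, 128, device=device)
teacher_feat = torch.randn(32, 512, device=device)

print("CRD loss:", crd(student_feat, teacher_feat).item())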

Self-Distillation

class SelfDistillation(nn.Module):
    """
    Distill knowledge within the same network.
    
    Methods:
    - Deep supervision: Intermediate classifiers
    - Born-again networks: Train same architecture iteratively
    - Be Your Own Teacher: Ensemble of augmented views
    """
    
    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        hidden_dims: List[int]
    ):
        super().__init__()
        
        self.backbone = backbone
        
        # Auxiliary classifiers at intermediate layers
        self.aux_classifiers = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(dim, num_classes)
            )
            for dim in hidden_dims
        ])
        
        # Final classifier
        self.classifier = nn.Linear(hidden_dims[-1], num_classes)
    
    def forward(
        self,
        x: torch.Tensor,
        return_features: bool = False
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Forward with auxiliary predictions.
        """
        features = []
        aux_logits = []
        
        # Get features at each layer (implement based on backbone)
        for layer, aux_clf in zip(self.backbone, self.aux_classifiers):
            x = layer(x)
            features.append(x)
            
            # Auxiliary prediction
            aux_out = aux_clf(x)
            aux_logits.append(aux_out)
        
        # Final prediction: global-average-pool the last feature map, then classify
        final_logits = self.classifier(F.adaptive_avg_pool2d(x, 1).flatten(1))
        
        if return_features:
            return final_logits, aux_logits, features
        
        return final_logits, aux_logits
    
    def self_distillation_loss(
        self,
        final_logits: torch.Tensor,
        aux_logits: List[torch.Tensor],
        labels: torch.Tensor,
        temperature: float = 4.0
    ) -> torch.Tensor:
        """
        Compute self-distillation loss.
        
        Deeper layers teach shallower layers.
        """
        # Final layer loss
        final_loss = F.cross_entropy(final_logits, labels)
        
        # Distillation losses
        distill_loss = 0
        soft_final = F.softmax(final_logits / temperature, dim=-1)
        
        for aux_logit in aux_logits:
            # Distill from final to aux
            soft_aux = F.log_softmax(aux_logit / temperature, dim=-1)
            kl_loss = F.kl_div(soft_aux, soft_final.detach(), reduction='batchmean')
            distill_loss += kl_loss * (temperature ** 2)
            
            # Also use hard labels for aux
            distill_loss += 0.5 * F.cross_entropy(aux_logit, labels)
        
        distill_loss /= len(aux_logits)
        
        return final_loss + 0.5 * distill_loss
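
A sketch of wiring the class to a toy two-stage CNN backbone (block structure and channel counts invented for the example); each nn.Sequential block is treated as one stage with its own auxiliary head:

# Toy backbone: two conv stages with 16 and 32 output channels (illustrative only)
toy_backbone = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)),
    nn.Sequential(nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
)

self_distill_model = SelfDistillation(toy_backbone, num_classes=10, hidden_dims=[16, 32])

images = torch.randn(4, 3, 32, 32)
targets = torch.randint(0, 10, (4,))

final_logits, aux_logits = self_distill_model(images)
loss = self_distill_model.self_distillation_loss(final_logits, aux_logits, targets)
print("self-distillation loss:", loss.item())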


class BornAgainNetworks:
    """
    Born-Again Networks: Iterative self-distillation.
    
    1. Train model from scratch (generation 1)
    2. Use model as teacher to train same architecture (generation 2)
    3. Repeat...
    
    Often improves over vanilla training!
    """
    
    def __init__(
        self,
        model_fn: Callable[[], nn.Module],
        temperature: float = 4.0
    ):
        self.model_fn = model_fn
        self.temperature = temperature
        self.generations = []
    
    def train_generation(
        self,
        train_loader,
        epochs: int,
        lr: float = 0.001,
        teacher: Optional[nn.Module] = None
    ) -> nn.Module:
        """Train one generation."""
        
        student = self.model_fn().to(device)
        optimizer = optim.Adam(student.parameters(), lr=lr)
        
        distiller = SoftTargetDistillation(self.temperature) if teacher else None
        
        if teacher:
            teacher.eval()
            for param in teacher.parameters():
                param.requires_grad = False
        
        for epoch in range(epochs):
            student.train()
            
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                
                optimizer.zero_grad()
                
                student_logits = student(inputs)
                
                if teacher:
                    with torch.no_grad():
                        teacher_logits = teacher(inputs)
                    loss = distiller.distillation_loss(
                        student_logits, teacher_logits, labels
                    )
                else:
                    loss = F.cross_entropy(student_logits, labels)
                
                loss.backward()
                optimizer.step()
        
        self.generations.append(student)
        return student
    
    def train(
        self,
        train_loader,
        num_generations: int = 3,
        epochs_per_gen: int = 10
    ) -> nn.Module:
        """Train multiple generations."""
        
        teacher = None
        
        for gen in range(num_generations):
            print(f"Training generation {gen + 1}...")
            teacher = self.train_generation(
                train_loader, epochs_per_gen, teacher=teacher
            )
        
        return teacher
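
A toy run on random data, two generations with one epoch each, just to show the control flow (real gains require full-length training on an actual dataset):

from torch.utils.data import DataLoader, TensorDataset

ban_loader = DataLoader(
    TensorDataset(torch.randn(256, 784), torch.randint(0, 10, (256,))),
    batch_size=32,
    shuffle=True
)

ban = BornAgainNetworks(
    model_fn=lambda: nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)),
    temperature=4.0
)
last_generation = ban.train(ban_loader, num_generations=2, epochs_per_gen=1)
print("Trained generations:", len(ban.generations))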

Task-Specific Distillation

class DetectionDistillation:
    """
    Distillation for object detection models.
    
    Challenges:
    - Multi-task (classification + localization)
    - Variable number of objects
    - Feature pyramid networks
    """
    
    @staticmethod
    def classification_distillation(
        student_cls: torch.Tensor,
        teacher_cls: torch.Tensor,
        temperature: float = 2.0
    ) -> torch.Tensor:
        """Distill classification heads."""
        soft_teacher = torch.sigmoid(teacher_cls / temperature)
        soft_student = torch.sigmoid(student_cls / temperature)
        
        loss = F.binary_cross_entropy(
            soft_student, soft_teacher.detach()
        ) * (temperature ** 2)
        
        return loss
    
    @staticmethod
    def regression_distillation(
        student_box: torch.Tensor,
        teacher_box: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Distill bounding box regression."""
        loss = F.smooth_l1_loss(student_box, teacher_box, reduction='none')
        
        if mask is not None:
            loss = loss * mask.unsqueeze(-1)
        
        return loss.mean()
    
    @staticmethod
    def fpn_feature_distillation(
        student_features: Dict[str, torch.Tensor],
        teacher_features: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """Distill FPN features at each scale."""
        total_loss = 0
        
        for level in student_features.keys():
            s_feat = student_features[level]
            t_feat = teacher_features[level]
            
            # Resize student features if spatial sizes differ; a channel
            # mismatch would additionally need a learned 1x1 projection
            if s_feat.shape[-2:] != t_feat.shape[-2:]:
                s_feat = F.interpolate(
                    s_feat, size=t_feat.shape[-2:],
                    mode='bilinear', align_corners=False
                )
            
            # L2 loss with attention masking
            importance = t_feat.abs().mean(dim=1, keepdim=True)
            importance = importance / importance.max()
            
            loss = F.mse_loss(s_feat * importance, t_feat * importance)
            total_loss += loss
        
        return total_loss / len(student_features)
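
A sketch with random FPN-style feature maps and class logits; the level names, channel count, and the [batch, anchors, classes] logit layout are assumptions made only for this example:

# Three pyramid levels with matching shapes (illustrative only)
student_fpn = {f"p{l}": torch.randn(2, 256, s, s) for l, s in zip([3, 4, 5], [32, 16, 8])}
teacher_fpn = {k: torch.randn_like(v) for k, v in student_fpn.items()}
print("FPN feature loss:", DetectionDistillation.fpn_feature_distillation(student_fpn, teacher_fpn).item())

# Classification-head logits: [batch, anchors, classes] (hypothetical layout)
student_cls = torch.randn(2, 100, 80)
teacher_cls = torch.randn(2, 100, 80)
print("cls distill loss:", DetectionDistillation.classification_distillation(student_cls, teacher_cls).item())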


class NLPDistillation:
    """
    Distillation for language models.
    
    Used in: DistilBERT, TinyBERT, MiniLM
    """
    
    @staticmethod
    def embedding_distillation(
        student_emb: torch.Tensor,
        teacher_emb: torch.Tensor
    ) -> torch.Tensor:
        """Distill word embeddings."""
        return F.mse_loss(student_emb, teacher_emb)
    
    @staticmethod
    def attention_distillation(
        student_attn: torch.Tensor,  # [B, H, L, L]
        teacher_attn: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Distill attention patterns.
        
        Student learns to attend like teacher.
        """
        # KL divergence on attention distributions
        s_attn = student_attn.log_softmax(dim=-1)
        t_attn = teacher_attn.softmax(dim=-1)
        
        loss = F.kl_div(s_attn, t_attn, reduction='none')
        
        if attention_mask is not None:
            # Mask padded positions
            mask = attention_mask.unsqueeze(1).unsqueeze(2)
            loss = loss * mask
        
        return loss.mean()
    
    @staticmethod
    def hidden_state_distillation(
        student_hidden: List[torch.Tensor],
        teacher_hidden: List[torch.Tensor],
        layer_mapping: Optional[Dict[int, int]] = None
    ) -> torch.Tensor:
        """
        Distill hidden states.
        
        Common mappings for unequal depths:
        - Uniform: student[i] <- teacher[i * T/S]
        - Skip: student[i] <- teacher[2*i] for half-depth
        """
        if layer_mapping is None:
            # Default: uniform mapping
            ratio = len(teacher_hidden) // len(student_hidden)
            layer_mapping = {i: i * ratio for i in range(len(student_hidden))}
        
        total_loss = 0
        
        for s_idx, t_idx in layer_mapping.items():
            s_h = student_hidden[s_idx]
            t_h = teacher_hidden[t_idx]
            
            loss = F.mse_loss(s_h, t_h)
            total_loss += loss
        
        return total_loss / len(layer_mapping)
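
A sketch of the layer-mapping logic on random hidden states (shapes invented; both models are assumed to share the hidden size, otherwise a learned projection is required before the MSE):

# 12-layer teacher vs. 6-layer student, hidden size 256 (illustrative only)
teacher_hidden = [torch.randn(2, 16, 256) for _ in range(12)]
student_hidden = [torch.randn(2, 16, 256) for _ in range(6)]

# Default: uniform mapping (student layer i <- teacher layer 2*i)
print("uniform:", NLPDistillation.hidden_state_distillation(student_hidden, teacher_hidden).item())

# Explicit mapping: take every second teacher layer, ending at the top layer
skip_map = {i: 2 * i + 1 for i in range(6)}
print("skip:   ", NLPDistillation.hidden_state_distillation(student_hidden, teacher_hidden, skip_map).item())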


class DistilBERTTrainer:
    """
    Complete DistilBERT-style training.
    
    Triple loss:
    1. Masked LM loss (task)
    2. Cosine embedding loss (hidden states)
    3. Distillation loss (soft targets)
    """
    
    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        alpha_ce: float = 0.5,
        alpha_mlm: float = 0.5,
        alpha_cos: float = 0.0,
        temperature: float = 2.0
    ):
        self.teacher = teacher.eval()
        self.student = student
        
        self.alpha_ce = alpha_ce
        self.alpha_mlm = alpha_mlm
        self.alpha_cos = alpha_cos
        self.temperature = temperature
    
    def compute_loss(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor  # MLM labels
    ) -> torch.Tensor:
        """Compute DistilBERT loss."""
        
        # Teacher forward
        with torch.no_grad():
            teacher_out = self.teacher(
                input_ids, 
                attention_mask=attention_mask,
                output_hidden_states=True
            )
        
        # Student forward
        student_out = self.student(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        
        # 1. Soft target distillation
        soft_teacher = F.softmax(teacher_out.logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_out.logits / self.temperature, dim=-1)
        
        ce_loss = F.kl_div(
            soft_student, soft_teacher, reduction='batchmean'
        ) * (self.temperature ** 2)
        
        # 2. MLM loss
        mlm_loss = F.cross_entropy(
            student_out.logits.view(-1, student_out.logits.size(-1)),
            labels.view(-1),
            ignore_index=-100
        )
        
        # 3. Cosine embedding loss on last hidden
        cos_loss = 1 - F.cosine_similarity(
            student_out.hidden_states[-1],
            teacher_out.hidden_states[-1],
            dim=-1
        ).mean()
        
        total_loss = (
            self.alpha_ce * ce_loss +
            self.alpha_mlm * mlm_loss +
            self.alpha_cos * cos_loss
        )
        
        return total_loss

Advanced Distillation Techniques

class ProgressiveDistillation:
    """
    Gradually increase difficulty during distillation.
    
    Start with easy examples, progress to harder ones.
    """
    
    def __init__(self, num_stages: int = 5):
        self.num_stages = num_stages
        self.current_stage = 0
    
    def get_temperature(self) -> float:
        """Temperature decreases over stages."""
        max_temp = 10.0
        min_temp = 1.0
        
        return max_temp - (max_temp - min_temp) * (self.current_stage / self.num_stages)
    
    def get_alpha(self) -> float:
        """Distillation weight increases over stages."""
        return 0.5 + 0.5 * (self.current_stage / self.num_stages)
    
    def advance_stage(self):
        """Move to next stage."""
        if self.current_stage < self.num_stages:
            self.current_stage += 1
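
Printing the schedule makes the intended behaviour concrete: over the stages the temperature anneals from 10 down to 1 while the distillation weight α ramps from 0.5 up to 1.0:

schedule = ProgressiveDistillation(num_stages=5)
for stage in range(schedule.num_stages + 1):
    print(f"stage {stage}: T={schedule.get_temperature():.1f}, alpha={schedule.get_alpha():.2f}")
    schedule.advance_stage()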


class OnlineDistillation(nn.Module):
    """
    Online mutual distillation.
    
    Multiple students teach each other during training.
    No pre-trained teacher needed!
    
    Reference: "Deep Mutual Learning" (Zhang et al., 2018)
    """
    
    def __init__(
        self,
        students: List[nn.Module],
        temperature: float = 4.0
    ):
        super().__init__()
        self.students = nn.ModuleList(students)
        self.temperature = temperature
    
    def forward(
        self,
        x: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Forward pass with mutual learning.
        
        Each student learns from:
        1. Ground truth labels
        2. Average of other students' predictions
        """
        # Get all predictions
        all_logits = [student(x) for student in self.students]
        
        losses = []
        
        for i, logits in enumerate(all_logits):
            # Task loss
            task_loss = F.cross_entropy(logits, labels)
            
            # Distillation from peers
            peer_logits = [all_logits[j] for j in range(len(all_logits)) if j != i]
            peer_avg = torch.stack(peer_logits).mean(dim=0)
            
            soft_peer = F.softmax(peer_avg / self.temperature, dim=-1)
            soft_self = F.log_softmax(logits / self.temperature, dim=-1)
            
            kl_loss = F.kl_div(
                soft_self, soft_peer.detach(), reduction='batchmean'
            ) * (self.temperature ** 2)
            
            loss = task_loss + kl_loss
            losses.append(loss)
        
        total_loss = sum(losses)
        
        return total_loss, all_logits
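
A sketch with two peer MLPs of different widths (architectures invented for the example). In training, total_loss is backpropagated and a single optimizer steps over the parameters of all peers:

# Two peers of different capacity (illustrative only)
peers = [
    nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)),
    nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
]
mutual = OnlineDistillation(peers, temperature=4.0)

x_batch = torch.randn(32, 784)
y_batch = torch.randint(0, 10, (32,))
total_loss, peer_logits = mutual(x_batch, y_batch)
print("mutual loss:", total_loss.item(), "| peers:", len(peer_logits))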


# Best practices summary
def distillation_best_practices():
    """Print distillation best practices."""
    
    tips = """
    ╔════════════════════════════════════════════════════════════════════╗
    ║             KNOWLEDGE DISTILLATION BEST PRACTICES                   ║
    ╠════════════════════════════════════════════════════════════════════╣
    ║                                                                     ║
    ║  TEMPERATURE SELECTION                                              ║
    ║  • T=1: Sharp distributions (less knowledge transfer)              ║
    ║  • T=4-8: Good balance for most tasks                              ║
    ║  • T>10: Very soft distributions (may lose discrimination)          ║
    ║                                                                     ║
    ║  LOSS BALANCING (α)                                                 ║
    ║  • α=0.5-0.7: Typical starting point                               ║
    ║  • Higher α: Rely more on teacher (good teacher)                   ║
    ║  • Lower α: Rely more on labels (noisy teacher)                    ║
    ║                                                                     ║
    ║  STUDENT SIZE                                                       ║
    ║  • 1/3 to 1/2 of teacher: Good compression with minimal loss       ║
    ║  • <1/10 of teacher: Significant accuracy drop expected            ║
    ║                                                                     ║
    ║  WHAT TO DISTILL                                                    ║
    ║  • Always: Output logits (soft targets)                            ║
    ║  • Often helpful: Intermediate features                            ║
    ║  • Sometimes helpful: Attention patterns                           ║
    ║                                                                     ║
    ║  TRAINING TIPS                                                      ║
    ║  • Use longer training for student than normal                     ║
    ║  • Data augmentation helps student generalize                      ║
    ║  • Try unlabeled data with teacher pseudo-labels                   ║
    ║                                                                     ║
    ╚════════════════════════════════════════════════════════════════════╝
    """
    print(tips)

distillation_best_practices()

Exercises

1. Train students with different temperatures (1, 2, 4, 8, 16). Plot accuracy vs. temperature and find the optimal value.
2. Combine feature distillation with soft target distillation. Compare to using each alone.
3. Implement Born-Again Networks:
   • Train for 3 generations
   • Compare each generation’s accuracy
   • Try different temperatures per generation

What’s Next?