
Model Interpretability: Opening the Black Box

Why Interpretability Matters

Deep learning models are often “black boxes”: they work, but we don’t know why. This is problematic for:
  • Trust: Should we trust this diagnosis?
  • Debugging: Why did the model fail?
  • Fairness: Is the model using protected attributes?
  • Compliance: Regulations such as the GDPR may require explanations
  • Science: What has the model learned?
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional, Callable

torch.manual_seed(42)

Gradient-Based Methods

Vanilla Gradients

The simplest approach: compute gradient of output w.r.t. input.
class VanillaGradients:
    """Compute input gradients for interpretability."""
    
    def __init__(self, model: nn.Module):
        self.model = model
        self.model.eval()
    
    def compute_gradients(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None
    ) -> torch.Tensor:
        """
        Compute gradient of target class w.r.t. input.
        
        Args:
            input_tensor: [1, C, H, W] input image
            target_class: Class to explain (default: predicted class)
        
        Returns:
            gradients: [C, H, W] gradient map
        """
        input_tensor = input_tensor.clone().requires_grad_(True)
        
        # Forward pass
        output = self.model(input_tensor)
        
        # Get target class
        if target_class is None:
            target_class = output.argmax(dim=1).item()
        
        # Backward pass
        self.model.zero_grad()
        output[0, target_class].backward()
        
        # Get gradients
        gradients = input_tensor.grad.data[0]
        
        return gradients
    
    def visualize(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None
    ) -> np.ndarray:
        """Create visualization of gradients."""
        
        gradients = self.compute_gradients(input_tensor, target_class)
        
        # Take absolute value and sum across channels
        saliency = gradients.abs().sum(dim=0).cpu().numpy()
        
        # Normalize
        saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8)
        
        return saliency


# Example usage
# model = load_pretrained_model()
# explainer = VanillaGradients(model)
# saliency = explainer.visualize(input_image, target_class=281)  # cat
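
To make the pipeline concrete, here is a minimal runnable sketch with a hypothetical toy CNN standing in for a pretrained classifier (the toy model and dummy input are reused in later sketches):

# A tiny CNN standing in for a real classifier
toy_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10)
)

explainer = VanillaGradients(toy_model)
dummy_image = torch.randn(1, 3, 32, 32)
saliency = explainer.visualize(dummy_image)
print(saliency.shape)  # (32, 32)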

Gradient × Input

Multiply gradients by input for sharper attributions:
class GradientTimesInput(VanillaGradients):
    """Gradient × Input attribution."""
    
    def compute_attribution(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None
    ) -> torch.Tensor:
        
        gradients = self.compute_gradients(input_tensor, target_class)
        
        # Multiply by input
        attribution = gradients * input_tensor[0]
        
        return attribution

Integrated Gradients

Accumulate gradients along the straight-line path from a baseline $x'$ to the input $x$:

$$\text{IG}_i(x) = (x_i - x'_i) \times \int_0^1 \frac{\partial F\big(x' + \alpha(x - x')\big)}{\partial x_i}\, d\alpha$$
class IntegratedGradients:
    """Integrated Gradients attribution method."""
    
    def __init__(self, model: nn.Module):
        self.model = model
        self.model.eval()
    
    def compute_integrated_gradients(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        baseline: Optional[torch.Tensor] = None,
        steps: int = 50
    ) -> torch.Tensor:
        """
        Compute Integrated Gradients.
        
        Args:
            input_tensor: [1, C, H, W] input image
            target_class: Class to explain
            baseline: Reference input (default: zeros)
            steps: Number of integration steps
        
        Returns:
            attributions: [C, H, W] attribution map
        """
        if baseline is None:
            baseline = torch.zeros_like(input_tensor)
        
        # Get target class
        if target_class is None:
            with torch.no_grad():
                output = self.model(input_tensor)
                target_class = output.argmax(dim=1).item()
        
        # Generate interpolated inputs
        scaled_inputs = []
        for i in range(steps + 1):
            alpha = i / steps
            scaled_input = baseline + alpha * (input_tensor - baseline)
            scaled_inputs.append(scaled_input)
        
        scaled_inputs = torch.cat(scaled_inputs, dim=0)
        scaled_inputs.requires_grad_(True)
        
        # Forward pass
        outputs = self.model(scaled_inputs)
        
        # Backward pass
        self.model.zero_grad()
        target_outputs = outputs[:, target_class]
        grads = torch.autograd.grad(
            outputs=target_outputs.sum(),
            inputs=scaled_inputs,
            create_graph=False
        )[0]
        
        # Average gradients
        avg_grads = grads.mean(dim=0)
        
        # Compute integrated gradients
        integrated_grads = (input_tensor[0] - baseline[0]) * avg_grads
        
        return integrated_grads
    
    def visualize(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        steps: int = 50
    ) -> np.ndarray:
        """Create visualization."""
        
        attributions = self.compute_integrated_gradients(
            input_tensor, target_class, steps=steps
        )
        
        # Sum across channels and take absolute value
        attr_map = attributions.abs().sum(dim=0).cpu().numpy()
        
        # Normalize
        attr_map = (attr_map - attr_map.min()) / (attr_map.max() - attr_map.min() + 1e-8)
        
        return attr_map
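
A useful property of Integrated Gradients is completeness: the attributions sum (approximately, due to the discrete integration) to F(x) − F(x'). A quick sanity check, sketched with the toy model from earlier:

# Verify the completeness axiom: sum of attributions ≈ F(x) - F(baseline)
ig = IntegratedGradients(toy_model)
baseline = torch.zeros_like(dummy_image)
attributions = ig.compute_integrated_gradients(dummy_image, steps=200)

with torch.no_grad():
    target = toy_model(dummy_image).argmax(dim=1).item()
    f_x = toy_model(dummy_image)[0, target].item()
    f_baseline = toy_model(baseline)[0, target].item()

print(attributions.sum().item(), "vs", f_x - f_baseline)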

SmoothGrad

Average gradients over noisy versions of input:
class SmoothGrad(VanillaGradients):
    """SmoothGrad: Average gradients over noisy samples."""
    
    def compute_smooth_gradients(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        n_samples: int = 50,
        noise_std: float = 0.1
    ) -> torch.Tensor:
        """
        Args:
            n_samples: Number of noisy samples
            noise_std: Standard deviation of noise
        """
        all_gradients = []
        
        for _ in range(n_samples):
            # Add noise
            noise = torch.randn_like(input_tensor) * noise_std
            noisy_input = input_tensor + noise
            
            # Compute gradients
            grads = self.compute_gradients(noisy_input, target_class)
            all_gradients.append(grads)
        
        # Average
        smooth_grads = torch.stack(all_gradients).mean(dim=0)
        
        return smooth_grads
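
In the SmoothGrad paper the noise level is specified as a fraction of the input's value range (roughly 10-20% works well) rather than an absolute standard deviation; a usage sketch with the toy model:

# Scale the noise to the input's value range
sg = SmoothGrad(toy_model)
value_range = (dummy_image.max() - dummy_image.min()).item()
smooth = sg.compute_smooth_gradients(
    dummy_image,
    n_samples=25,
    noise_std=0.15 * value_range
)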

Class Activation Mapping (CAM)

Grad-CAM

Gradient-weighted Class Activation Mapping:
class GradCAM:
    """
    Grad-CAM: Visual Explanations from Deep Networks.
    
    Uses gradients flowing into the final convolutional layer to produce
    a coarse localization map highlighting important regions.
    """
    
    def __init__(self, model: nn.Module, target_layer: nn.Module):
        self.model = model
        self.target_layer = target_layer
        self.model.eval()
        
        self.gradients = None
        self.activations = None
        
        # Register hooks
        self._register_hooks()
    
    def _register_hooks(self):
        """Register forward and backward hooks."""
        
        def forward_hook(module, input, output):
            self.activations = output.detach()
        
        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()
        
        self.target_layer.register_forward_hook(forward_hook)
        self.target_layer.register_full_backward_hook(backward_hook)
    
    def generate_cam(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None
    ) -> np.ndarray:
        """
        Generate Grad-CAM heatmap.
        
        Args:
            input_tensor: [1, C, H, W] input image
            target_class: Class to visualize
        
        Returns:
            cam: [H, W] heatmap (same size as input)
        """
        # Forward pass
        output = self.model(input_tensor)
        
        # Get target class
        if target_class is None:
            target_class = output.argmax(dim=1).item()
        
        # Backward pass
        self.model.zero_grad()
        output[0, target_class].backward()
        
        # Get weights: global average pooling of gradients
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)  # [1, C, 1, 1]
        
        # Weighted combination of activation maps
        cam = (weights * self.activations).sum(dim=1, keepdim=True)  # [1, 1, H', W']
        
        # ReLU to keep only positive contributions
        cam = F.relu(cam)
        
        # Upsample to input size
        cam = F.interpolate(
            cam,
            size=input_tensor.shape[2:],
            mode='bilinear',
            align_corners=False
        )
        
        # Normalize
        cam = cam[0, 0].cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        
        return cam
    
    def visualize(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        alpha: float = 0.5
    ) -> np.ndarray:
        """Create overlay visualization."""
        
        cam = self.generate_cam(input_tensor, target_class)
        
        # Get original image
        image = input_tensor[0].permute(1, 2, 0).cpu().numpy()
        image = (image - image.min()) / (image.max() - image.min() + 1e-8)
        
        # Apply colormap to CAM
        import matplotlib.cm as cm
        heatmap = cm.jet(cam)[:, :, :3]
        
        # Overlay
        overlay = alpha * heatmap + (1 - alpha) * image
        overlay = np.clip(overlay, 0, 1)
        
        return overlay


# Example usage with ResNet
class GradCAMExample:
    """Example of using Grad-CAM with ResNet."""
    
    @staticmethod
    def get_gradcam_for_resnet(model):
        # For torchvision ResNets, target the last residual block of layer4;
        # its output is the final conv feature map before average pooling
        # (in Bottleneck variants like ResNet-50 the last conv is conv3, not conv2)
        target_layer = model.layer4[-1]
        return GradCAM(model, target_layer)
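
A usage sketch with a torchvision ResNet (assuming torchvision is installed and `image` is an ImageNet-normalized tensor):

from torchvision import models

resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet.eval()
gradcam = GradCAMExample.get_gradcam_for_resnet(resnet)

# image = ...  # [1, 3, 224, 224], normalized with ImageNet statistics
# overlay = gradcam.visualize(image, target_class=281)  # tabby cat
# plt.imshow(overlay); plt.axis('off'); plt.show()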

Grad-CAM++

Improved version with better localization:
class GradCAMPlusPlus(GradCAM):
    """Grad-CAM++ with improved weighting."""
    
    def generate_cam(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None
    ) -> np.ndarray:
        
        # Forward pass
        output = self.model(input_tensor)
        
        if target_class is None:
            target_class = output.argmax(dim=1).item()
        
        # Backward pass
        self.model.zero_grad()
        output[0, target_class].backward()
        
        # Grad-CAM++ weighting
        gradients = self.gradients[0]  # [C, H, W]
        activations = self.activations[0]  # [C, H, W]
        
        # Compute alpha (importance weights)
        grad_2 = gradients ** 2
        grad_3 = gradients ** 3
        
        sum_activations = activations.sum(dim=(1, 2), keepdim=True)
        alpha_num = grad_2
        alpha_denom = 2 * grad_2 + sum_activations * grad_3 + 1e-8
        alpha = alpha_num / alpha_denom
        
        # Weighted gradients
        weights = (alpha * F.relu(gradients)).sum(dim=(1, 2))  # [C]
        
        # Weighted combination
        cam = (weights.view(-1, 1, 1) * activations).sum(dim=0)
        cam = F.relu(cam)
        
        # Upsample and normalize
        cam = F.interpolate(
            cam.unsqueeze(0).unsqueeze(0),
            size=input_tensor.shape[2:],
            mode='bilinear',
            align_corners=False
        )[0, 0]
        
        cam = cam.cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        
        return cam

Attention Visualization

Visualizing Transformer Attention

class AttentionVisualizer:
    """Visualize attention patterns in transformers."""
    
    def __init__(self, model: nn.Module):
        self.model = model
        self.attention_maps = {}
        
        # Register hooks to capture attention
        self._register_hooks()
    
    def _register_hooks(self):
        """Hook into attention layers."""
        
        def make_hook(name):
            def hook(module, input, output):
                # Assuming attention returns (output, attention_weights)
                if isinstance(output, tuple) and len(output) == 2:
                    self.attention_maps[name] = output[1].detach()
            return hook
        
        for name, module in self.model.named_modules():
            if 'attention' in name.lower() or 'attn' in name.lower():
                module.register_forward_hook(make_hook(name))
    
    def get_attention_maps(
        self,
        input_tensor: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Forward pass and collect attention maps."""
        
        self.attention_maps = {}
        
        with torch.no_grad():
            _ = self.model(input_tensor)
        
        return self.attention_maps
    
    def visualize_attention(
        self,
        attention: torch.Tensor,
        tokens: List[str],
        head: int = 0
    ) -> None:
        """Visualize attention as heatmap."""
        
        # attention: [batch, heads, seq_len, seq_len]
        attn = attention[0, head].cpu().numpy()
        
        plt.figure(figsize=(10, 8))
        plt.imshow(attn, cmap='viridis')
        
        plt.xticks(range(len(tokens)), tokens, rotation=45, ha='right')
        plt.yticks(range(len(tokens)), tokens)
        
        plt.colorbar()
        plt.title(f'Attention Head {head}')
        plt.tight_layout()
        plt.show()
    
    @staticmethod
    def aggregate_attention(
        attention_maps: Dict[str, torch.Tensor],
        method: str = 'mean'
    ) -> torch.Tensor:
        """Aggregate attention across layers and heads."""
        
        all_attention = list(attention_maps.values())
        stacked = torch.stack(all_attention, dim=0)  # [layers, batch, heads, seq, seq]
        
        if method == 'mean':
            return stacked.mean(dim=(0, 2))  # [batch, seq, seq]
        elif method == 'max':
            return stacked.max(dim=0)[0].max(dim=1)[0]
        elif method == 'rollout':
            # Attention rollout: accumulate attention flow across layers
            batch, seq = stacked.shape[1], stacked.shape[-1]
            result = torch.eye(seq).expand(batch, -1, -1).clone()
            for layer_attn in stacked:
                # Average over heads
                attn = layer_attn.mean(dim=1)
                # Add residual connection
                attn = 0.5 * attn + 0.5 * torch.eye(seq)
                # Matrix multiply to propagate attention through the layer
                result = torch.bmm(result, attn)
            return result
        else:
            raise ValueError(f"Unknown method: {method}")


class ViTAttentionRollout:
    """Attention rollout for Vision Transformer."""
    
    def __init__(self, model):
        self.model = model
        self.attentions = []
        
        self._register_hooks()
    
    def _register_hooks(self):
        for block in self.model.blocks:
            block.attn.register_forward_hook(self._save_attention)
    
    def _save_attention(self, module, input, output):
        # Assumes the attention module returns (output, attention_weights);
        # stock ViT implementations (timm included) often return only the
        # output and must be adapted to expose the weights
        self.attentions.append(output[1].detach())
    
    def compute_rollout(
        self,
        input_tensor: torch.Tensor,
        discard_ratio: float = 0.9
    ) -> np.ndarray:
        """Compute attention rollout visualization."""
        
        self.attentions = []
        
        with torch.no_grad():
            _ = self.model(input_tensor)
        
        # Stack all layer attentions
        attention = torch.stack(self.attentions)  # [layers, batch, heads, tokens, tokens]
        
        # Average over heads
        attention = attention.mean(dim=2)
        
        # Rollout
        result = torch.eye(attention.shape[-1])
        
        for layer_attn in attention[:, 0]:  # First batch
            # Discard lowest attentions
            flat = layer_attn.flatten()
            threshold = flat.quantile(discard_ratio)
            layer_attn = layer_attn * (layer_attn > threshold)
            
            # Renormalize (epsilon guards rows that were fully discarded)
            layer_attn = layer_attn / (layer_attn.sum(dim=-1, keepdim=True) + 1e-8)
            
            # Add identity (residual)
            layer_attn = 0.5 * layer_attn + 0.5 * torch.eye(layer_attn.shape[-1])
            
            result = torch.mm(result, layer_attn)
        
        # Get attention from CLS token to patches
        cls_attention = result[0, 1:]  # Exclude CLS token itself
        
        # Reshape to 2D
        num_patches = int(np.sqrt(len(cls_attention)))
        attention_map = cls_attention.reshape(num_patches, num_patches).numpy()
        
        return attention_map
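
The rollout map is patch-level, so it is usually upsampled and blended over the input image. A sketch of an overlay helper (overlay_rollout is a hypothetical name, not part of any library):

def overlay_rollout(image_tensor: torch.Tensor, attention_map: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Upsample a patch-level rollout map and blend it over the image."""
    h, w = image_tensor.shape[2:]
    attn = torch.from_numpy(attention_map).float()[None, None]  # [1, 1, P, P]
    attn = F.interpolate(attn, size=(h, w), mode='bilinear', align_corners=False)[0, 0]
    attn = attn.numpy()
    attn = (attn - attn.min()) / (attn.max() - attn.min() + 1e-8)
    
    img = image_tensor[0].permute(1, 2, 0).cpu().numpy()
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    
    import matplotlib.cm as cm
    heat = cm.jet(attn)[:, :, :3]
    return np.clip(alpha * heat + (1 - alpha) * img, 0, 1)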

SHAP (SHapley Additive exPlanations)

class DeepSHAP:
    """
    Deep SHAP for neural network explanations.
    Based on Shapley values from game theory.
    """
    
    def __init__(self, model: nn.Module, background_data: torch.Tensor):
        self.model = model
        self.model.eval()
        self.background = background_data
    
    def shap_values(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        n_samples: int = 100
    ) -> torch.Tensor:
        """
        Compute SHAP values using sampling approximation.
        
        Args:
            input_tensor: [1, C, H, W] input to explain
            target_class: Class to explain
            n_samples: Number of coalition samples
        """
        if target_class is None:
            with torch.no_grad():
                output = self.model(input_tensor)
                target_class = output.argmax(dim=1).item()
        
        # Flatten input for feature-level attribution
        input_flat = input_tensor.flatten()
        n_features = input_flat.shape[0]
        
        # Sample coalitions and compute marginal contributions
        shap_values = torch.zeros_like(input_flat)
        
        for _ in range(n_samples):
            # Random permutation
            perm = torch.randperm(n_features)
            
            # Sample background
            bg_idx = np.random.randint(len(self.background))
            background = self.background[bg_idx:bg_idx+1].flatten()
            
            # Compute marginal contributions
            current = background.clone()
            prev_output = self._evaluate(current.reshape(input_tensor.shape), target_class)
            
            for idx in perm:
                current[idx] = input_flat[idx]
                new_output = self._evaluate(current.reshape(input_tensor.shape), target_class)
                
                shap_values[idx] += (new_output - prev_output)
                prev_output = new_output
        
        shap_values /= n_samples
        
        return shap_values.reshape(input_tensor.shape)
    
    def _evaluate(self, x: torch.Tensor, target_class: int) -> float:
        with torch.no_grad():
            output = self.model(x)
            return output[0, target_class].item()
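
Because every permutation costs one forward pass per feature, this estimator is only realistic for low-dimensional inputs (for a 224×224 RGB image that would be ~150k passes per permutation). A sketch on a small hypothetical tabular model where it is tractable:

# 8 features × 20 permutations → a few hundred forward passes
tabular_model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
background = torch.randn(32, 8)

explainer = DeepSHAP(tabular_model, background)
x = torch.randn(1, 8)
values = explainer.shap_values(x, n_samples=20)
print(values.shape)  # torch.Size([1, 8])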


class GradientSHAP:
    """
    GradientSHAP: Faster SHAP approximation using gradients.
    """
    
    def __init__(self, model: nn.Module, background_data: torch.Tensor):
        self.model = model
        self.model.eval()
        self.background = background_data
    
    def shap_values(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        n_samples: int = 50
    ) -> torch.Tensor:
        """Compute GradientSHAP values."""
        
        if target_class is None:
            with torch.no_grad():
                output = self.model(input_tensor)
                target_class = output.argmax(dim=1).item()
        
        all_grads = []
        
        for _ in range(n_samples):
            # Sample background
            bg_idx = np.random.randint(len(self.background))
            baseline = self.background[bg_idx:bg_idx+1]
            
            # Random interpolation point
            alpha = np.random.uniform()
            interpolated = baseline + alpha * (input_tensor - baseline)
            interpolated.requires_grad_(True)
            
            # Compute gradient
            output = self.model(interpolated)
            self.model.zero_grad()
            output[0, target_class].backward()
            
            grad = interpolated.grad.data
            all_grads.append(grad)
        
        # Average gradients
        avg_grad = torch.stack(all_grads).mean(dim=0)
        
        # SHAP values = gradient × (input - expected_baseline)
        expected_baseline = self.background.mean(dim=0)
        shap_values = avg_grad * (input_tensor - expected_baseline)
        
        return shap_values
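
GradientSHAP can be read as Integrated Gradients averaged over random baselines and interpolation points, which is what makes it fast. A usage sketch with the toy model and a random stand-in background:

# Random background images stand in for a real reference distribution
background_images = torch.randn(16, 3, 32, 32)
gshap = GradientSHAP(toy_model, background_images)
values = gshap.shap_values(dummy_image, n_samples=25)
print(values.shape)  # torch.Size([1, 3, 32, 32])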

Feature Visualization

Activation Maximization

class ActivationMaximization:
    """
    Generate images that maximally activate a neuron.
    Useful for understanding what features a neuron detects.
    """
    
    def __init__(self, model: nn.Module):
        self.model = model
        self.model.eval()
    
    def visualize_filter(
        self,
        layer: nn.Module,
        filter_idx: int,
        image_size: int = 224,
        iterations: int = 200,
        lr: float = 0.1
    ) -> torch.Tensor:
        """
        Generate image that maximizes activation of a filter.
        
        Args:
            layer: Target layer
            filter_idx: Which filter to visualize
            image_size: Size of generated image
            iterations: Optimization steps
            lr: Learning rate
        """
        # Start with random noise
        image = torch.randn(1, 3, image_size, image_size, requires_grad=True)
        
        optimizer = torch.optim.Adam([image], lr=lr)
        
        # Hook to capture activation
        activation = None
        
        def hook(module, input, output):
            nonlocal activation
            activation = output
        
        handle = layer.register_forward_hook(hook)
        
        for i in range(iterations):
            optimizer.zero_grad()
            
            # Forward pass
            _ = self.model(image)
            
            # Loss: negative activation (we want to maximize)
            loss = -activation[0, filter_idx].mean()
            
            # Add regularization
            loss += 0.01 * torch.norm(image)  # L2 regularization
            loss += 0.001 * self._total_variation(image)  # Smoothness
            
            loss.backward()
            optimizer.step()
            
            # Clip to valid range
            with torch.no_grad():
                image.clamp_(-2, 2)
        
        handle.remove()
        
        # Normalize for visualization
        image = image.detach()
        image = (image - image.min()) / (image.max() - image.min())
        
        return image
    
    def _total_variation(self, x: torch.Tensor) -> torch.Tensor:
        """Total variation loss for smoothness."""
        diff_h = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
        diff_w = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
        return diff_h.mean() + diff_w.mean()
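
A usage sketch on the toy CNN's first conv layer (with a real network you would target deeper filters, which respond to richer patterns):

am = ActivationMaximization(toy_model)
pattern = am.visualize_filter(
    layer=toy_model[0],  # first conv layer of the toy model
    filter_idx=0,
    image_size=32,
    iterations=100
)
# plt.imshow(pattern[0].permute(1, 2, 0)); plt.axis('off'); plt.show()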


class DeepDream:
    """
    DeepDream: Maximize activations to create psychedelic images.
    """
    
    def __init__(self, model: nn.Module, target_layers: List[nn.Module]):
        self.model = model
        self.model.eval()
        self.target_layers = target_layers
        self.activations = {}
        
        # Register forward hooks so self.activations is populated on each pass
        for layer in target_layers:
            layer.register_forward_hook(self._make_hook(layer))
    
    def _make_hook(self, layer: nn.Module):
        def hook(module, input, output):
            self.activations[layer] = output
        return hook
    
    def dream(
        self,
        image: torch.Tensor,
        iterations: int = 100,
        lr: float = 0.01,
        octave_scale: float = 1.4,
        n_octaves: int = 4
    ) -> torch.Tensor:
        """
        Apply DeepDream to an image.
        
        Args:
            image: [1, 3, H, W] input image
            iterations: Steps per octave
            lr: Learning rate
            octave_scale: Scale factor between octaves
            n_octaves: Number of octaves (scales)
        """
        original_size = image.shape[2:]
        
        # Process at multiple scales (octaves)
        for octave in range(n_octaves):
            if octave > 0:
                # Upscale
                new_size = [int(s * octave_scale) for s in image.shape[2:]]
                image = F.interpolate(image, size=new_size, mode='bilinear', align_corners=False)
            
            image = image.detach().requires_grad_(True)
            
            for _ in range(iterations):
                # Forward pass
                _ = self.model(image)
                
                # Maximize activations
                loss = sum(
                    self.activations[layer].norm()
                    for layer in self.target_layers
                )
                
                loss.backward()
                
                # Gradient ascent
                with torch.no_grad():
                    image += lr * image.grad / (image.grad.norm() + 1e-8)
                    image.grad.zero_()
        
        # Resize back
        image = F.interpolate(image, size=original_size, mode='bilinear', align_corners=False)
        
        return image.detach()
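
A usage sketch, assuming torchvision and a mid-level VGG conv layer as the dream target (results are very sensitive to the layer choice):

from torchvision import models

vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).eval()
dreamer = DeepDream(vgg, target_layers=[vgg.features[19]])

# image = ...  # [1, 3, H, W] photograph, normalized
# dreamed = dreamer.dream(image, iterations=20, n_octaves=3)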

Concept-Based Explanations

TCAV (Testing with Concept Activation Vectors)

class TCAV:
    """
    Testing with Concept Activation Vectors.
    Explain models using human-interpretable concepts.
    """
    
    def __init__(self, model: nn.Module, bottleneck_layer: nn.Module):
        self.model = model
        self.bottleneck = bottleneck_layer
        self.activations = None
        
        # Hook to capture activations
        self.bottleneck.register_forward_hook(self._save_activation)
    
    def _save_activation(self, module, input, output):
        # Keep the live tensor and retain its grad so compute_tcav_score
        # can read self.activations.grad after a backward pass
        if output.requires_grad:
            output.retain_grad()
        self.activations = output
    
    def get_activations(self, images: torch.Tensor) -> torch.Tensor:
        """Get bottleneck activations for images."""
        with torch.no_grad():
            _ = self.model(images)
        return self.activations.detach()
    
    def train_cav(
        self,
        concept_images: torch.Tensor,
        random_images: torch.Tensor
    ) -> torch.Tensor:
        """
        Train Concept Activation Vector.
        
        Args:
            concept_images: Images with the concept
            random_images: Random images without the concept
        
        Returns:
            cav: Concept activation vector
        """
        # Get activations
        concept_acts = self.get_activations(concept_images)
        random_acts = self.get_activations(random_images)
        
        # Flatten activations
        concept_acts = concept_acts.flatten(start_dim=1)
        random_acts = random_acts.flatten(start_dim=1)
        
        # Create labels
        X = torch.cat([concept_acts, random_acts], dim=0)
        y = torch.cat([
            torch.ones(len(concept_acts)),
            torch.zeros(len(random_acts))
        ])
        
        # Train linear classifier
        from sklearn.linear_model import LogisticRegression
        
        clf = LogisticRegression(max_iter=1000)
        clf.fit(X.cpu().numpy(), y.numpy())
        
        # CAV is the weight vector
        cav = torch.tensor(clf.coef_[0], dtype=torch.float32)
        cav = cav / cav.norm()
        
        return cav
    
    def compute_tcav_score(
        self,
        test_images: torch.Tensor,
        target_class: int,
        cav: torch.Tensor
    ) -> float:
        """
        Compute TCAV score: fraction of test images where 
        the concept positively influences the prediction.
        """
        positive_count = 0
        
        for img in test_images:
            img = img.unsqueeze(0).requires_grad_(True)
            
            # Forward pass
            output = self.model(img)
            
            # Backward pass
            self.model.zero_grad()
            output[0, target_class].backward()
            
            # Get gradient of activations
            act_grad = self.activations.grad
            
            if act_grad is not None:
                # Directional derivative along CAV
                act_grad_flat = act_grad.flatten()
                directional_deriv = torch.dot(act_grad_flat, cav)
                
                if directional_deriv > 0:
                    positive_count += 1
        
        tcav_score = positive_count / len(test_images)
        return tcav_score
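
An end-to-end sketch (concept_images, random_images, and zebra_images are hypothetical tensors you would collect yourself, e.g. striped textures versus random photos; 340 is the ImageNet zebra class):

# Hypothetical data; resnet is the torchvision model from the Grad-CAM sketch
# tcav = TCAV(resnet, bottleneck_layer=resnet.layer3)
# cav = tcav.train_cav(concept_images, random_images)
# score = tcav.compute_tcav_score(zebra_images, target_class=340, cav=cav)
# print(f"TCAV score for 'striped' on class zebra: {score:.2f}")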

Practical Considerations

def interpretability_best_practices():
    """Best practices for model interpretability."""
    
    tips = """
    ╔════════════════════════════════════════════════════════════════╗
    ║              INTERPRETABILITY BEST PRACTICES                   ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                ║
    ║  1. CHOOSE THE RIGHT METHOD                                    ║
    ║     • Grad-CAM: Quick visual localization                      ║
    ║     • Integrated Gradients: Theoretically grounded             ║
    ║     • SHAP: Feature importance with consistency                ║
    ║     • TCAV: Concept-level understanding                        ║
    ║                                                                ║
    ║  2. VALIDATE EXPLANATIONS                                      ║
    ║     • Sanity checks: Random model should give random explain   ║
    ║     • Faithfulness: Removing important features hurts perf     ║
    ║     • Human evaluation: Do explanations make sense?            ║
    ║                                                                ║
    ║  3. BE AWARE OF LIMITATIONS                                    ║
    ║     • Gradient-based methods can be noisy                      ║
    ║     • Explanations may not reflect actual reasoning            ║
    ║     • Different methods can give different explanations        ║
    ║                                                                ║
    ║  4. COMBINE MULTIPLE METHODS                                   ║
    ║     • Cross-validate with different techniques                 ║
    ║     • Use local (per-sample) and global explanations           ║
    ║     • Consider both input and concept-level analysis           ║
    ║                                                                ║
    ║  5. COMMUNICATE CAREFULLY                                      ║
    ║     • Explanations ≠ justifications                            ║
    ║     • Don't over-interpret visualizations                      ║
    ║     • Consider your audience's expertise                       ║
    ║                                                                ║
    ╚════════════════════════════════════════════════════════════════╝
    """
    print(tips)

interpretability_best_practices()

Exercises

Compare Grad-CAM, Integrated Gradients, and SHAP on the same images:
def compare_attributions(model, image, target_class):
    gradcam = GradCAM(model, model.layer4[-1])
    ig = IntegratedGradients(model)
    shap = GradientSHAP(model, background_data)  # background_data: a reference batch
    
    # Compare visualizations
    # Compute correlation between methods
    ...
Implement sanity checks for explanation methods:
def sanity_check(model, explainer, image):
    # 1. Randomize top layer weights
    # 2. Check if explanations change significantly
    # If they don't change, the method may not be faithful
    ...
Create CAVs for your own concepts:
# Collect images with/without your concept
# Train CAV
# Test which classes are sensitive to your concept

What’s Next?