Model Interpretability: Opening the Black Box

Why Interpretability Matters

Deep learning models are often “black boxes”: they work, but we do not know why. A model that classifies skin lesions with 95% accuracy is of little use in a clinic if the dermatologist cannot tell which features it relies on. Is it looking at the lesion’s border irregularity (good) or at a ruler in the image (bad, because rulers appear disproportionately in training photos of confirmed melanomas)? This is not a theoretical concern. Interpretability matters for:
  • Trust: A radiologist will not act on a model’s prediction without understanding its reasoning
  • Debugging: When a model fails on production data, you need to know why to fix it
  • Fairness: You must verify the model is not using protected attributes like race or gender as proxies
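  • Compliance: GDPR’s provisions on automated decision-making (often summarized as a “right to explanation”) and the EU AI Act’s transparency obligations can require you to explain how a model reaches its decisions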
  • Science: Understanding what a model has learned can reveal genuine scientific insights about the data
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional, Callable

torch.manual_seed(42)

Gradient-Based Methods

Vanilla Gradients

The simplest approach: compute the gradient of the output class score with respect to the input pixels. The intuition is direct — pixels with large gradients are pixels where a small change would most affect the model’s prediction. If the gradient at pixel (100, 200) is large for the “cat” class, then changing that pixel significantly affects whether the model thinks the image contains a cat.
class VanillaGradients:
    """Compute input gradients for interpretability."""
    
    def __init__(self, model: nn.Module):
        self.model = model
        self.model.eval()
    
    def compute_gradients(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None
    ) -> torch.Tensor:
        """
        Compute gradient of target class w.r.t. input.
        
        Args:
            input_tensor: [1, C, H, W] input image
            target_class: Class to explain (default: predicted class)
        
        Returns:
            gradients: [C, H, W] gradient map
        """
        input_tensor = input_tensor.clone().requires_grad_(True)
        
        # Forward pass
        output = self.model(input_tensor)
        
        # Get target class
        if target_class is None:
            target_class = output.argmax(dim=1).item()
        
        # Backward pass
        self.model.zero_grad()
        output[0, target_class].backward()
        
        # Get gradients
        gradients = input_tensor.grad.data[0]
        
        return gradients
    
    def visualize(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None
    ) -> np.ndarray:
        """Create visualization of gradients."""
        
        gradients = self.compute_gradients(input_tensor, target_class)
        
        # Take absolute value and sum across channels
        saliency = gradients.abs().sum(dim=0).cpu().numpy()
        
        # Normalize
        saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8)
        
        return saliency


# Example usage
# model = load_pretrained_model()
# explainer = VanillaGradients(model)
# saliency = explainer.visualize(input_image, target_class=281)  # cat

Gradient × Input

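Multiply the gradient elementwise by the input. This turns a sensitivity map (“how much would the score change?”) into an approximate contribution map (“how much did this pixel contribute?”), which is usually sharper: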
class GradientTimesInput(VanillaGradients):
    """Gradient × Input attribution."""
    
    def compute_attribution(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None
    ) -> torch.Tensor:
        
        gradients = self.compute_gradients(input_tensor, target_class)
        
        # Multiply by input
        attribution = gradients * input_tensor[0]
        
        return attribution
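
To make the difference between a raw gradient and Gradient × Input concrete, here is a minimal, dependency-free sketch on a toy linear scorer. The weights, bias, and input are hypothetical values chosen for illustration, and the gradient is taken by finite differences rather than autograd so the snippet runs without a model.

```python
# Toy linear "model": score = sum(w_i * x_i) + b (hypothetical values).
weights = [2.0, -1.0, 0.5]
bias = 0.25
x = [1.0, 3.0, 4.0]

def score(inputs):
    return sum(w * v for w, v in zip(weights, inputs)) + bias

def numerical_gradient(f, inputs, eps=1e-6):
    """Central finite-difference gradient of f at inputs."""
    grads = []
    for i in range(len(inputs)):
        up = list(inputs)
        up[i] += eps
        down = list(inputs)
        down[i] -= eps
        grads.append((f(up) - f(down)) / (2 * eps))
    return grads

# Vanilla gradients: for a linear model these are just the weights.
grads = numerical_gradient(score, x)

# Gradient x Input: each feature's actual contribution to the score.
grad_times_input = [g * v for g, v in zip(grads, x)]

print("gradients:          ", [round(g, 4) for g in grads])
print("gradient x input:   ", [round(a, 4) for a in grad_times_input])
print("sum of attributions:", round(sum(grad_times_input), 4))
print("score minus bias:   ", round(score(x) - bias, 4))
```

The raw gradient ranks feature 0 highest because it has the largest weight, while Gradient × Input ranks feature 1 highest because its large input value makes it the biggest contributor. For a linear model the products sum exactly to the score minus the bias; for deep networks this only holds approximately.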

Integrated Gradients

Vanilla gradients have a fundamental flaw: they only capture local sensitivity, not total attribution. A pixel might have a small gradient (locally flat) but still be critically important for the prediction. Integrated Gradients fixes this by accumulating gradients along a path from a baseline (typically a black image) to the actual input. This satisfies important theoretical properties like “completeness”: the attributions sum to the difference between the model’s output at the input and at the baseline. For an input x, a baseline x′, and model output F, the attribution for feature i accumulates gradients along the path from baseline to input:

$$\text{IG}_i(x) = (x_i - x'_i) \times \int_0^1 \frac{\partial F(x' + \alpha(x - x'))}{\partial x_i} \, d\alpha$$
class IntegratedGradients:
    """Integrated Gradients attribution method."""
    
    def __init__(self, model: nn.Module):
        self.model = model
        self.model.eval()
    
    def compute_integrated_gradients(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        baseline: Optional[torch.Tensor] = None,
        steps: int = 50
    ) -> torch.Tensor:
        """
        Compute Integrated Gradients.
        
        Args:
            input_tensor: [1, C, H, W] input image
            target_class: Class to explain
            baseline: Reference input (default: zeros)
            steps: Number of integration steps
        
        Returns:
            attributions: [C, H, W] attribution map
        """
        if baseline is None:
            baseline = torch.zeros_like(input_tensor)
        
        # Get target class
        if target_class is None:
            with torch.no_grad():
                output = self.model(input_tensor)
                target_class = output.argmax(dim=1).item()
        
        # Generate interpolated inputs
        scaled_inputs = []
        for i in range(steps + 1):
            alpha = i / steps
            scaled_input = baseline + alpha * (input_tensor - baseline)
            scaled_inputs.append(scaled_input)
        
        scaled_inputs = torch.cat(scaled_inputs, dim=0)
        scaled_inputs.requires_grad_(True)
        
        # Forward pass
        outputs = self.model(scaled_inputs)
        
        # Backward pass
        self.model.zero_grad()
        target_outputs = outputs[:, target_class]
        grads = torch.autograd.grad(
            outputs=target_outputs.sum(),
            inputs=scaled_inputs,
            create_graph=False
        )[0]
        
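        # Average gradients along the path (trapezoidal rule over the
        # steps + 1 evenly spaced samples)
        avg_grads = ((grads[:-1] + grads[1:]) / 2).mean(dim=0)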
        
        # Compute integrated gradients
        integrated_grads = (input_tensor[0] - baseline[0]) * avg_grads
        
        return integrated_grads
    
    def visualize(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        steps: int = 50
    ) -> np.ndarray:
        """Create visualization."""
        
        attributions = self.compute_integrated_gradients(
            input_tensor, target_class, steps=steps
        )
        
        # Take absolute value, then sum across channels
        attr_map = attributions.abs().sum(dim=0).cpu().numpy()
        
        # Normalize
        attr_map = (attr_map - attr_map.min()) / (attr_map.max() - attr_map.min() + 1e-8)
        
        return attr_map
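
The completeness property can be checked numerically. The sketch below is a dependency-free illustration, not the implementation above: the scorer `model` and its analytic `gradient` are hypothetical, and the path integral is approximated with a midpoint Riemann sum.

```python
import math

def model(x):
    """Toy differentiable scorer (hypothetical): tanh(x0 * x1) + 0.5 * x2**2."""
    return math.tanh(x[0] * x[1]) + 0.5 * x[2] ** 2

def gradient(x):
    """Analytic gradient of the toy scorer."""
    sech2 = 1.0 - math.tanh(x[0] * x[1]) ** 2
    return [x[1] * sech2, x[0] * sech2, x[2]]

def integrated_gradients(x, baseline, steps=200):
    """Midpoint Riemann sum of gradients along the straight path baseline -> x."""
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(gradient(point)):
            total[i] += g / steps
    return [(xi - b) * t for xi, b, t in zip(x, baseline, total)]

x = [1.5, 2.0, -1.0]
baseline = [0.0, 0.0, 0.0]

attributions = integrated_gradients(x, baseline)
local_grad = gradient(x)
completeness_gap = abs(sum(attributions) - (model(x) - model(baseline)))

print("local gradient:  ", [round(g, 4) for g in local_grad])
print("IG attributions: ", [round(a, 4) for a in attributions])
print("completeness gap:", round(completeness_gap, 6))
```

At this input the tanh term is saturated, so the local gradient for the first feature is close to zero even though that feature accounts for roughly a third of the output. Integrated Gradients recovers that contribution because it also visits the unsaturated part of the path.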

SmoothGrad

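Raw gradient maps tend to look noisy because the gradient can fluctuate sharply under tiny changes to the input. SmoothGrad averages the gradients over several copies of the input with Gaussian noise added, which suppresses those fluctuations: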
class SmoothGrad(VanillaGradients):
    """SmoothGrad: Average gradients over noisy samples."""
    
    def compute_smooth_gradients(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        n_samples: int = 50,
        noise_std: float = 0.1
    ) -> torch.Tensor:
        """
        Args:
            n_samples: Number of noisy samples
            noise_std: Standard deviation of noise
        """
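        # Fix the target class from the clean input so that every noisy
        # sample explains the same class
        if target_class is None:
            with torch.no_grad():
                target_class = self.model(input_tensor).argmax(dim=1).item()
        
        all_gradients = []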
        
        for _ in range(n_samples):
            # Add noise
            noise = torch.randn_like(input_tensor) * noise_std
            noisy_input = input_tensor + noise
            
            # Compute gradients
            grads = self.compute_gradients(noisy_input, target_class)
            all_gradients.append(grads)
        
        # Average
        smooth_grads = torch.stack(all_gradients).mean(dim=0)
        
        return smooth_grads
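
The smoothing effect is easy to see in one dimension. The sketch below is a toy illustration under stated assumptions: the function f(x) = x + 0.1·sin(50x) is hypothetical and stands in for a model whose output has a slow trend plus fast local wiggles.

```python
import math
import random

random.seed(0)

def grad_f(x):
    """Derivative of f(x) = x + 0.1 * sin(50 * x): slow trend plus fast wiggles."""
    return 1.0 + 5.0 * math.cos(50.0 * x)

def smooth_grad(x, n_samples=2000, noise_std=0.1):
    """Average the gradient over Gaussian-perturbed copies of x."""
    total = 0.0
    for _ in range(n_samples):
        total += grad_f(x + random.gauss(0.0, noise_std))
    return total / n_samples

points = [0.00, 0.03, 0.06]
raw = [grad_f(p) for p in points]
smoothed = [smooth_grad(p) for p in points]

for p, r, s in zip(points, raw, smoothed):
    print(f"x={p:.2f}  raw gradient={r:+.2f}  smoothed gradient={s:+.2f}")
```

The raw gradient swings from strongly positive to strongly negative across nearby points, while the averaged gradient stays close to the underlying slope of 1. The noise scale matters: too little noise leaves the wiggles in, and too much blurs away real structure.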

Class Activation Mapping (CAM)

Grad-CAM

Grad-CAM (Gradient-weighted Class Activation Mapping) is one of the most widely used interpretability methods in practice, and for good reason: it produces intuitive, coarse-grained heatmaps that show which regions of the image were most important for a particular class prediction. Unlike the methods above, which operate at the pixel level and often produce noisy, hard-to-interpret saliency maps, Grad-CAM works at the feature-map level. It weights each channel of a convolutional layer by the average gradient of the class score with respect to that channel, then sums the weighted maps. The result is a smooth heatmap that aligns with human-understandable regions.
class GradCAM:
    """
    Grad-CAM: Visual Explanations from Deep Networks.
    
    Uses gradients flowing into the final convolutional layer to produce
    a coarse localization map highlighting important regions.
    """
    
    def __init__(self, model: nn.Module, target_layer: nn.Module):
        self.model = model
        self.target_layer = target_layer
        self.model.eval()
        
        self.gradients = None
        self.activations = None
        
        # Register hooks
        self._register_hooks()
    
    def _register_hooks(self):
        """Register forward and backward hooks."""
        
        def forward_hook(module, input, output):
            self.activations = output.detach()
        
        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()
        
        self.target_layer.register_forward_hook(forward_hook)
        self.target_layer.register_full_backward_hook(backward_hook)
    
    def generate_cam(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None
    ) -> np.ndarray:
        """
        Generate Grad-CAM heatmap.
        
        Args:
            input_tensor: [1, C, H, W] input image
            target_class: Class to visualize
        
        Returns:
            cam: [H, W] heatmap (same size as input)
        """
        # Forward pass
        output = self.model(input_tensor)
        
        # Get target class
        if target_class is None:
            target_class = output.argmax(dim=1).item()
        
        # Backward pass
        self.model.zero_grad()
        output[0, target_class].backward()
        
        # Get weights: global average pooling of gradients
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)  # [1, C, 1, 1]
        
        # Weighted combination of activation maps
        cam = (weights * self.activations).sum(dim=1, keepdim=True)  # [1, 1, H', W']
        
        # ReLU to keep only positive contributions
        cam = F.relu(cam)
        
        # Upsample to input size
        cam = F.interpolate(
            cam,
            size=input_tensor.shape[2:],
            mode='bilinear',
            align_corners=False
        )
        
        # Normalize
        cam = cam[0, 0].cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        
        return cam
    
    def visualize(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        alpha: float = 0.5
    ) -> np.ndarray:
        """Create overlay visualization."""
        
        cam = self.generate_cam(input_tensor, target_class)
        
        # Get original image
        image = input_tensor[0].permute(1, 2, 0).cpu().numpy()
        image = (image - image.min()) / (image.max() - image.min())
        
        # Apply colormap to CAM
        import matplotlib.cm as cm
        heatmap = cm.jet(cam)[:, :, :3]
        
        # Overlay
        overlay = alpha * heatmap + (1 - alpha) * image
        overlay = np.clip(overlay, 0, 1)
        
        return overlay


# Example usage with ResNet
class GradCAMExample:
    """Example of using Grad-CAM with ResNet."""
    
    @staticmethod
    def get_gradcam_for_resnet(model):
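        # For ResNet-18/34 (BasicBlock), conv2 is the last conv layer before
        # average pooling. Bottleneck variants such as ResNet-50 end with
        # conv3 instead, so adjust the attribute for those models.
        target_layer = model.layer4[-1].conv2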
        return GradCAM(model, target_layer)
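
The core Grad-CAM arithmetic (average the gradients per channel, take the weighted sum of activation maps, apply ReLU) can be traced by hand. This is a minimal sketch with hypothetical 2×2 activations and gradients for two channels; upsampling and normalization are omitted.

```python
# Hypothetical 2x2 activations A_k and gradients dY/dA_k for two channels.
activations = [
    [[1.0, 0.0], [2.0, 0.0]],      # channel 0 fires on the left column
    [[0.0, 3.0], [0.0, 1.0]],      # channel 1 fires on the right column
]
gradients = [
    [[0.4, 0.2], [0.6, 0.0]],      # mostly positive: supports the class
    [[-0.2, -0.4], [0.0, -0.2]],   # negative: argues against the class
]

def mean2d(grid):
    values = [v for row in grid for v in row]
    return sum(values) / len(values)

# Step 1: channel weights are the global average of the gradients.
weights = [mean2d(g) for g in gradients]

# Step 2: weighted sum of the activation maps.
height, width = 2, 2
cam = [[0.0] * width for _ in range(height)]
for w, act in zip(weights, activations):
    for i in range(height):
        for j in range(width):
            cam[i][j] += w * act[i][j]

# Step 3: ReLU keeps only regions with a positive influence on the class.
cam = [[max(v, 0.0) for v in row] for row in cam]

print("channel weights:", [round(w, 3) for w in weights])
print("cam:", [[round(v, 3) for v in row] for row in cam])
```

Channel 1 is strongly active on the right column, but its negative weight means that activity counts against the class, so the ReLU zeroes it out and the heatmap highlights only the left column.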

Grad-CAM++

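Grad-CAM++ replaces the uniform gradient average with pixel-wise weights on the positive gradients, which tends to localize better when an image contains several instances of the same class: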
class GradCAMPlusPlus(GradCAM):
    """Grad-CAM++ with improved weighting."""
    
    def generate_cam(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None
    ) -> np.ndarray:
        
        # Forward pass
        output = self.model(input_tensor)
        
        if target_class is None:
            target_class = output.argmax(dim=1).item()
        
        # Backward pass
        self.model.zero_grad()
        output[0, target_class].backward()
        
        # Grad-CAM++ weighting
        gradients = self.gradients[0]  # [C, H, W]
        activations = self.activations[0]  # [C, H, W]
        
        # Compute alpha (importance weights)
        grad_2 = gradients ** 2
        grad_3 = gradients ** 3
        
        sum_activations = activations.sum(dim=(1, 2), keepdim=True)
        alpha_num = grad_2
        alpha_denom = 2 * grad_2 + sum_activations * grad_3 + 1e-8
        alpha = alpha_num / alpha_denom
        
        # Weighted gradients
        weights = (alpha * F.relu(gradients)).sum(dim=(1, 2))  # [C]
        
        # Weighted combination
        cam = (weights.view(-1, 1, 1) * activations).sum(dim=0)
        cam = F.relu(cam)
        
        # Upsample and normalize
        cam = F.interpolate(
            cam.unsqueeze(0).unsqueeze(0),
            size=input_tensor.shape[2:],
            mode='bilinear',
            align_corners=False
        )[0, 0]
        
        cam = cam.cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        
        return cam

Attention Visualization

Visualizing Transformer Attention

class AttentionVisualizer:
    """Visualize attention patterns in transformers."""
    
    def __init__(self, model: nn.Module):
        self.model = model
        self.attention_maps = {}
        
        # Register hooks to capture attention
        self._register_hooks()
    
    def _register_hooks(self):
        """Hook into attention layers."""
        
        def make_hook(name):
            def hook(module, input, output):
                # Assuming attention returns (output, attention_weights)
                if isinstance(output, tuple) and len(output) == 2:
                    self.attention_maps[name] = output[1].detach()
            return hook
        
        for name, module in self.model.named_modules():
            if 'attention' in name.lower() or 'attn' in name.lower():
                module.register_forward_hook(make_hook(name))
    
    def get_attention_maps(
        self,
        input_tensor: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Forward pass and collect attention maps."""
        
        self.attention_maps = {}
        
        with torch.no_grad():
            _ = self.model(input_tensor)
        
        return self.attention_maps
    
    def visualize_attention(
        self,
        attention: torch.Tensor,
        tokens: List[str],
        head: int = 0
    ) -> None:
        """Visualize attention as heatmap."""
        
        # attention: [batch, heads, seq_len, seq_len]
        attn = attention[0, head].cpu().numpy()
        
        plt.figure(figsize=(10, 8))
        plt.imshow(attn, cmap='viridis')
        
        plt.xticks(range(len(tokens)), tokens, rotation=45, ha='right')
        plt.yticks(range(len(tokens)), tokens)
        
        plt.colorbar()
        plt.title(f'Attention Head {head}')
        plt.tight_layout()
        plt.show()
    
    @staticmethod
    def aggregate_attention(
        attention_maps: Dict[str, torch.Tensor],
        method: str = 'mean'
    ) -> torch.Tensor:
        """Aggregate attention across layers and heads."""
        
        all_attention = list(attention_maps.values())
        stacked = torch.stack(all_attention, dim=0)  # [layers, batch, heads, seq, seq]
        
        if method == 'mean':
            return stacked.mean(dim=(0, 2))  # [batch, seq, seq]
        elif method == 'max':
            return stacked.max(dim=0)[0].max(dim=1)[0]
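        elif method == 'rollout':
            # Attention rollout: compose residual-adjusted attention
            # matrices from the first layer to the last
            n = stacked.shape[-1]
            identity = torch.eye(n, device=stacked.device, dtype=stacked.dtype)
            result = identity.repeat(stacked.shape[1], 1, 1)  # [batch, seq, seq]
            for layer_attn in stacked:
                # Average over heads
                attn = layer_attn.mean(dim=1)
                # Add residual connection
                attn = 0.5 * attn + 0.5 * identity
                # Later layers multiply on the left
                result = torch.bmm(attn, result)
            return result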
        else:
            raise ValueError(f"Unknown method: {method}")


class ViTAttentionRollout:
    """Attention rollout for Vision Transformer."""
    
    def __init__(self, model):
        self.model = model
        self.attentions = []
        
        self._register_hooks()
    
    def _register_hooks(self):
        for block in self.model.blocks:
            block.attn.register_forward_hook(self._save_attention)
    
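    def _save_attention(self, module, input, output):
        # Assumes the attention module returns (output, attention_weights).
        # Some ViT implementations return only the output tensor, in which
        # case the module must be adapted to expose its weights.
        self.attentions.append(output[1].detach())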
    
    def compute_rollout(
        self,
        input_tensor: torch.Tensor,
        discard_ratio: float = 0.9
    ) -> np.ndarray:
        """Compute attention rollout visualization."""
        
        self.attentions = []
        
        with torch.no_grad():
            _ = self.model(input_tensor)
        
        # Stack all layer attentions
        attention = torch.stack(self.attentions)  # [layers, batch, heads, tokens, tokens]
        
        # Average over heads
        attention = attention.mean(dim=2)
        
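        # Rollout: multiply the residual-adjusted attention matrices from
        # the first layer to the last
        result = torch.eye(attention.shape[-1], device=attention.device)
        
        for layer_attn in attention[:, 0]:  # First batch element
            # Discard lowest attentions
            flat = layer_attn.flatten()
            threshold = flat.quantile(discard_ratio)
            layer_attn = layer_attn * (layer_attn > threshold)
            
            # Add identity (residual), then renormalize rows. Adding the
            # identity first keeps every row sum positive even when all of
            # a row's attention was discarded.
            identity = torch.eye(layer_attn.shape[-1], device=layer_attn.device)
            layer_attn = 0.5 * layer_attn + 0.5 * identity
            layer_attn = layer_attn / layer_attn.sum(dim=-1, keepdim=True)
            
            # Later layers multiply on the left
            result = torch.mm(layer_attn, result)
        
        # Get attention from CLS token to patches
        cls_attention = result[0, 1:]  # Exclude CLS token itself
        
        # Reshape to 2D (assumes a square patch grid)
        num_patches = int(np.sqrt(len(cls_attention)))
        attention_map = cls_attention.reshape(num_patches, num_patches).cpu().numpy()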
        
        return attention_map
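
Attention rollout is easiest to follow on a tiny example. The sketch below uses hypothetical head-averaged attention matrices for three tokens and two layers, written with plain lists so it runs without PyTorch. It skips the discard step and only shows the residual mixing and layer-by-layer multiplication.

```python
def matmul(a, b):
    """Multiply two square matrices given as nested lists."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]

def add_residual(attn):
    """Mix attention with the identity to account for skip connections."""
    n = len(attn)
    return [[0.5 * attn[i][j] + (0.5 if i == j else 0.0) for j in range(n)]
            for i in range(n)]

# Hypothetical head-averaged attention for 3 tokens (each row sums to 1).
layer1 = [[0.2, 0.8, 0.0],
          [0.0, 0.1, 0.9],
          [0.0, 0.0, 1.0]]
layer2 = [[0.1, 0.9, 0.0],
          [0.0, 1.0, 0.0],
          [0.0, 0.0, 1.0]]

# Start from the identity and left-multiply by each layer, bottom layer first.
rollout = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]
for attn in (layer1, layer2):
    rollout = matmul(add_residual(attn), rollout)

print("top-layer attention from token 0:", layer2[0])
print("rollout from token 0:            ", [round(v, 4) for v in rollout[0]])
```

In the top layer, token 0 places no attention on token 2. Rollout still assigns token 2 a share of about 0.2, because token 0 attends to token 1, which had already absorbed information from token 2 in the first layer. This indirect flow is what raw per-layer attention maps miss.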

SHAP (SHapley Additive exPlanations)

SHAP brings cooperative game theory to model interpretability. If a group of players (features) collectively achieves some payoff (the model’s prediction), how do you fairly distribute the credit? Shapley values are the unique attribution that satisfies four fairness axioms: efficiency (the attributions sum to the difference between the prediction and the baseline expectation), symmetry (interchangeable features get equal credit), the null-player property (a feature that never changes the output gets zero credit), and additivity (attributions for a sum of models are the sum of the per-model attributions). The catch: computing exact Shapley values requires evaluating the model on every subset of features, which is exponential in the number of features. DeepSHAP and GradientSHAP are approximations that exploit the network’s structure and gradients to make this tractable.
class DeepSHAP:
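    """
    Sampling-based Shapley value estimator for neural networks.
    
    Note: despite the class name, this is Monte Carlo permutation
    sampling, not the DeepLIFT-based Deep SHAP algorithm. Each sample
    costs one forward pass per input feature, so it is only practical
    for low-dimensional inputs or grouped features (e.g. superpixels).
    """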
    
    def __init__(self, model: nn.Module, background_data: torch.Tensor):
        self.model = model
        self.model.eval()
        self.background = background_data
    
    def shap_values(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        n_samples: int = 100
    ) -> torch.Tensor:
        """
        Compute SHAP values using sampling approximation.
        
        Args:
            input_tensor: [1, C, H, W] input to explain
            target_class: Class to explain
            n_samples: Number of coalition samples
        """
        if target_class is None:
            with torch.no_grad():
                output = self.model(input_tensor)
                target_class = output.argmax(dim=1).item()
        
        # Flatten input for feature-level attribution
        input_flat = input_tensor.flatten()
        n_features = input_flat.shape[0]
        
        # Sample coalitions and compute marginal contributions
        shap_values = torch.zeros_like(input_flat)
        
        for _ in range(n_samples):
            # Random permutation
            perm = torch.randperm(n_features)
            
            # Sample background
            bg_idx = np.random.randint(len(self.background))
            background = self.background[bg_idx:bg_idx+1].flatten()
            
            # Compute marginal contributions
            current = background.clone()
            prev_output = self._evaluate(current.reshape(input_tensor.shape), target_class)
            
            for idx in perm:
                current[idx] = input_flat[idx]
                new_output = self._evaluate(current.reshape(input_tensor.shape), target_class)
                
                shap_values[idx] += (new_output - prev_output)
                prev_output = new_output
        
        shap_values /= n_samples
        
        return shap_values.reshape(input_tensor.shape)
    
    def _evaluate(self, x: torch.Tensor, target_class: int) -> float:
        with torch.no_grad():
            output = self.model(x)
            return output[0, target_class].item()


class GradientSHAP:
    """
    GradientSHAP: Faster SHAP approximation using gradients.
    """
    
    def __init__(self, model: nn.Module, background_data: torch.Tensor):
        self.model = model
        self.background = background_data
    
    def shap_values(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        n_samples: int = 50
    ) -> torch.Tensor:
        """Compute GradientSHAP values."""
        
        if target_class is None:
            with torch.no_grad():
                output = self.model(input_tensor)
                target_class = output.argmax(dim=1).item()
        
        all_attributions = []
        
        for _ in range(n_samples):
            # Sample background
            bg_idx = np.random.randint(len(self.background))
            baseline = self.background[bg_idx:bg_idx+1]
            
            # Random interpolation point
            alpha = np.random.uniform()
            interpolated = (baseline + alpha * (input_tensor - baseline)).detach()
            interpolated.requires_grad_(True)
            
            # Compute gradient
            output = self.model(interpolated)
            self.model.zero_grad()
            output[0, target_class].backward()
            
            # Pair each gradient with the baseline it was computed against
            grad = interpolated.grad.detach()
            all_attributions.append(grad * (input_tensor - baseline).detach())
        
        # SHAP values = expectation of gradient × (input - baseline)
        shap_values = torch.stack(all_attributions).mean(dim=0)
        
        return shap_values
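
For a handful of features, Shapley values can be computed exactly by averaging marginal contributions over every ordering, which makes the efficiency axiom easy to verify. This is a minimal sketch on a hypothetical three-feature scorer with an interaction term; the sampling estimators above approximate the same quantity when enumeration is infeasible.

```python
from itertools import permutations

def model(x):
    """Toy scorer with an interaction term (hypothetical)."""
    return 2.0 * x[0] + x[1] * x[2]

def exact_shapley(f, x, baseline):
    """Average each feature's marginal contribution over all orderings."""
    n = len(x)
    phi = [0.0] * n
    orderings = list(permutations(range(n)))
    for order in orderings:
        current = list(baseline)
        prev = f(current)
        for idx in order:
            current[idx] = x[idx]   # feature idx joins the coalition
            new = f(current)
            phi[idx] += new - prev
            prev = new
    return [p / len(orderings) for p in phi]

x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
phi = exact_shapley(model, x, baseline)

print("shapley values:", phi)
print("sum:", sum(phi), " f(x) - f(baseline):", model(x) - model(baseline))
```

The interaction term x1·x2 contributes 6 to the output, and the two features that produce it split that credit equally. With n features there are n! orderings, which is why image-scale inputs need sampling or gradient-based approximations.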

Feature Visualization

Activation Maximization

While attribution methods answer “which inputs matter?”, feature visualization answers “what does this neuron detect?” Starting from random noise and iteratively modifying the image to maximize a specific neuron’s activation produces a picture of what that neuron responds to most strongly. Early layers reveal edge and texture detectors; deeper layers reveal detectors for complex patterns such as dog faces, wheels, or building facades. This technique, popularized by feature-visualization articles in the online journal Distill, is both scientifically illuminating and occasionally surreal.
class ActivationMaximization:
    """
    Generate images that maximally activate a neuron.
    Useful for understanding what features a neuron detects.
    
    Regularization is critical: without it, the optimized image
    will be adversarial noise that maximizes activation but is
    visually meaningless to humans.
    """
    
    def __init__(self, model: nn.Module):
        self.model = model
        self.model.eval()
    
    def visualize_filter(
        self,
        layer: nn.Module,
        filter_idx: int,
        image_size: int = 224,
        iterations: int = 200,
        lr: float = 0.1
    ) -> torch.Tensor:
        """
        Generate image that maximizes activation of a filter.
        
        Args:
            layer: Target layer
            filter_idx: Which filter to visualize
            image_size: Size of generated image
            iterations: Optimization steps
            lr: Learning rate
        """
        # Start with random noise
        image = torch.randn(1, 3, image_size, image_size, requires_grad=True)
        
        optimizer = torch.optim.Adam([image], lr=lr)
        
        # Hook to capture activation
        activation = None
        
        def hook(module, input, output):
            nonlocal activation
            activation = output
        
        handle = layer.register_forward_hook(hook)
        
        for i in range(iterations):
            optimizer.zero_grad()
            
            # Forward pass
            _ = self.model(image)
            
            # Loss: negative activation (we want to maximize)
            loss = -activation[0, filter_idx].mean()
            
            # Add regularization
            loss += 0.01 * torch.norm(image)  # L2 regularization
            loss += 0.001 * self._total_variation(image)  # Smoothness
            
            loss.backward()
            optimizer.step()
            
            # Clip to valid range
            with torch.no_grad():
                image.clamp_(-2, 2)
        
        handle.remove()
        
        # Normalize for visualization
        image = image.detach()
        image = (image - image.min()) / (image.max() - image.min())
        
        return image
    
    def _total_variation(self, x: torch.Tensor) -> torch.Tensor:
        """Total variation loss for smoothness."""
        diff_h = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
        diff_w = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
        return diff_h.mean() + diff_w.mean()


class DeepDream:
    """
    DeepDream: Maximize activations to create psychedelic images.
    """
    
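    def __init__(self, model: nn.Module, target_layers: List[nn.Module]):
        self.model = model
        self.model.eval()
        self.target_layers = target_layers
        self.activations = {}
        
        # Register forward hooks so dream() can read each target layer's output
        for layer in self.target_layers:
            layer.register_forward_hook(self._save_activation)
    
    def _save_activation(self, module, input, output):
        self.activations[module] = output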
    
    def dream(
        self,
        image: torch.Tensor,
        iterations: int = 100,
        lr: float = 0.01,
        octave_scale: float = 1.4,
        n_octaves: int = 4
    ) -> torch.Tensor:
        """
        Apply DeepDream to an image.
        
        Args:
            image: [1, 3, H, W] input image
            iterations: Steps per octave
            lr: Learning rate
            octave_scale: Scale factor between octaves
            n_octaves: Number of octaves (scales)
        """
        original_size = image.shape[2:]
        
        # Process at multiple scales (octaves)
        for octave in range(n_octaves):
            if octave > 0:
                # Upscale
                new_size = [int(s * octave_scale) for s in image.shape[2:]]
                image = F.interpolate(image, size=new_size, mode='bilinear')
            
            image = image.detach().requires_grad_(True)
            
            for _ in range(iterations):
                # Forward pass
                _ = self.model(image)
                
                # Maximize activations
                loss = sum(
                    self.activations[layer].norm()
                    for layer in self.target_layers
                )
                
                loss.backward()
                
                # Gradient ascent
                with torch.no_grad():
                    image += lr * image.grad / (image.grad.norm() + 1e-8)
                    image.grad.zero_()
        
        # Resize back
        image = F.interpolate(image, size=original_size, mode='bilinear')
        
        return image.detach()
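
Why regularization matters in activation maximization can be shown with a one-line "neuron". The sketch below is a simplified illustration, not the class above: the neuron is a hypothetical linear unit, and the penalty is a squared L2 norm (chosen because its gradient is simple) rather than the plain norm plus total variation used in `ActivationMaximization`.

```python
def activation(x, w):
    """Toy linear 'neuron': dot product of input and weights."""
    return sum(wi * xi for wi, xi in zip(w, x))

def maximize(w, l2_weight, steps=500, lr=0.1):
    """Gradient ascent on activation(x) - l2_weight * ||x||^2, starting at zero."""
    x = [0.0] * len(w)
    for _ in range(steps):
        # d/dx [w.x - l2_weight * ||x||^2] = w - 2 * l2_weight * x
        x = [xi + lr * (wi - 2.0 * l2_weight * xi) for xi, wi in zip(x, w)]
    return x

w = [1.0, -2.0, 0.5]   # hypothetical neuron weights

regularized = maximize(w, l2_weight=0.5)
unregularized = maximize(w, l2_weight=0.0)

print("regularized input:  ", [round(v, 4) for v in regularized])
print("unregularized input:", [round(v, 4) for v in unregularized])
print("activations:", round(activation(regularized, w), 2),
      round(activation(unregularized, w), 2))
```

With the penalty, the optimized input settles on a bounded pattern that mirrors the neuron's weights. Without it, the input simply grows with every step, so the activation is large but the result says nothing beyond "more is more". Real networks show the same failure as high-frequency noise, which is what the norm and total-variation terms are there to suppress.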

Concept-Based Explanations

TCAV (Testing with Concept Activation Vectors)

TCAV moves beyond pixel-level explanations to concept-level ones, which is what humans actually care about. Instead of asking “which pixels matter?”, TCAV asks “does the concept of stripes influence the model’s prediction of zebra?” You define a concept by providing example images (images with stripes vs. random images), TCAV learns a direction in the model’s activation space that corresponds to that concept, and then measures how much pushing the representation along that direction affects the prediction. This approach speaks the language of domain experts — a doctor can ask “does the presence of calcification influence the model’s diagnosis?” without needing to understand gradients or attention maps.
class TCAV:
    """
    Testing with Concept Activation Vectors.
    Explain models using human-interpretable concepts.
    
    The key advantage over saliency methods: TCAV explanations
    are expressed in terms that domain experts already understand.
    """
    
    def __init__(self, model: nn.Module, bottleneck_layer: nn.Module):
        self.model = model
        self.bottleneck = bottleneck_layer
        self.activations = None
        
        # Hook to capture activations
        self.bottleneck.register_forward_hook(self._save_activation)
    
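    def _save_activation(self, module, input, output):
        # Keep the tensor attached to the graph so compute_tcav_score can
        # read its gradient after backward(); detaching here would leave
        # .grad as None and the score stuck at zero
        self.activations = output
        if output.requires_grad:
            output.retain_grad()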
    
    def get_activations(self, images: torch.Tensor) -> torch.Tensor:
        """Get bottleneck activations for images."""
        with torch.no_grad():
            _ = self.model(images)
        return self.activations
    
    def train_cav(
        self,
        concept_images: torch.Tensor,
        random_images: torch.Tensor
    ) -> torch.Tensor:
        """
        Train Concept Activation Vector.
        
        Args:
            concept_images: Images with the concept
            random_images: Random images without the concept
        
        Returns:
            cav: Concept activation vector
        """
        # Get activations
        concept_acts = self.get_activations(concept_images)
        random_acts = self.get_activations(random_images)
        
        # Flatten activations
        concept_acts = concept_acts.flatten(start_dim=1)
        random_acts = random_acts.flatten(start_dim=1)
        
        # Create labels
        X = torch.cat([concept_acts, random_acts], dim=0)
        y = torch.cat([
            torch.ones(len(concept_acts)),
            torch.zeros(len(random_acts))
        ])
        
        # Train linear classifier
        from sklearn.linear_model import LogisticRegression
        
        clf = LogisticRegression(max_iter=1000)
        clf.fit(X.numpy(), y.numpy())
        
        # CAV is the weight vector
        cav = torch.tensor(clf.coef_[0], dtype=torch.float32)
        cav = cav / cav.norm()
        
        return cav
    
    def compute_tcav_score(
        self,
        test_images: torch.Tensor,
        target_class: int,
        cav: torch.Tensor
    ) -> float:
        """
        Compute TCAV score: fraction of test images where 
        the concept positively influences the prediction.
        """
        positive_count = 0
        
        for img in test_images:
            img = img.unsqueeze(0).requires_grad_(True)
            
            # Forward pass
            output = self.model(img)
            
            # Backward pass
            self.model.zero_grad()
            output[0, target_class].backward()
            
            # Get gradient of activations
            act_grad = self.activations.grad
            
            if act_grad is not None:
                # Directional derivative along CAV
                act_grad_flat = act_grad.flatten()
                directional_deriv = torch.dot(
                    act_grad_flat, cav.to(act_grad_flat.device)
                )
                
                if directional_deriv > 0:
                    positive_count += 1
        
        tcav_score = positive_count / len(test_images)
        return tcav_score
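
Once the per-image gradients at the bottleneck are available, the TCAV score reduces to a sign test on directional derivatives. The sketch below shows that step in isolation with NumPy. The function name `tcav_score_from_gradients` and the toy arrays are hypothetical and stand in for gradients that the class above would compute.

```python
import numpy as np

def tcav_score_from_gradients(act_grads: np.ndarray, cav: np.ndarray) -> float:
    """Fraction of samples whose directional derivative along the CAV is positive.

    act_grads: (n_samples, n_features) gradients of the class logit with respect
               to the flattened bottleneck activations (assumed precomputed).
    cav:       (n_features,) concept activation vector.
    """
    cav = cav / np.linalg.norm(cav)          # unit-length concept direction
    directional_derivs = act_grads @ cav     # one value per sample
    return float(np.mean(directional_derivs > 0))

# Toy data (hypothetical): 4 samples, 3 bottleneck features
cav = np.array([1.0, 0.0, 0.0])
act_grads = np.array([
    [0.5, 0.1, -0.2],
    [0.2, -0.3, 0.4],
    [-0.1, 0.9, 0.0],
    [0.7, 0.0, 0.0],
])

score = tcav_score_from_gradients(act_grads, cav)
print(score)  # 0.75: three of the four samples move toward the concept
```

A score near 0.5 suggests the concept has no consistent influence on the class, so it is usually worth comparing against CAVs trained on several different random image sets before drawing conclusions.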

Practical Considerations

def interpretability_best_practices():
    """Best practices for model interpretability."""
    
    tips = """
    ╔════════════════════════════════════════════════════════════════╗
    ║              INTERPRETABILITY BEST PRACTICES                   ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                ║
    ║  1. CHOOSE THE RIGHT METHOD                                    ║
    ║     • Grad-CAM: Quick visual localization                      ║
    ║     • Integrated Gradients: Theoretically grounded             ║
    ║     • SHAP: Feature importance with consistency                ║
    ║     • TCAV: Concept-level understanding                        ║
    ║                                                                ║
    ║  2. VALIDATE EXPLANATIONS                                      ║
    ║     • Sanity checks: Randomizing weights should change maps    ║
    ║     • Faithfulness: Removing important features hurts accuracy ║
    ║     • Human evaluation: Do explanations make sense?            ║
    ║                                                                ║
    ║  3. BE AWARE OF LIMITATIONS                                    ║
    ║     • Gradient-based methods can be noisy                      ║
    ║     • Explanations may not reflect actual reasoning            ║
    ║     • Different methods can give different explanations        ║
    ║                                                                ║
    ║  4. COMBINE MULTIPLE METHODS                                   ║
    ║     • Cross-validate with different techniques                 ║
    ║     • Use local (per-sample) and global explanations           ║
    ║     • Consider both input and concept-level analysis           ║
    ║                                                                ║
    ║  5. COMMUNICATE CAREFULLY                                      ║
    ║     • Explanations ≠ justifications                            ║
    ║     • Don't over-interpret visualizations                      ║
    ║     • Consider your audience's expertise                       ║
    ║                                                                ║
    ╚════════════════════════════════════════════════════════════════╝
    """
    print(tips)

interpretability_best_practices()
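
The faithfulness check in point 2 can be made concrete with a deletion curve: remove features in order of claimed importance and watch how quickly the model's score falls. The following is a minimal sketch on a toy linear model, not a full evaluation protocol. `deletion_curve`, `predict`, and the weights are hypothetical names and values chosen for illustration.

```python
import numpy as np

def deletion_curve(predict, x, attribution, baseline=0.0):
    """Model score as features are removed from most to least important."""
    order = np.argsort(-attribution)      # highest attribution first
    x_masked = x.astype(float).copy()
    scores = [predict(x_masked)]
    for idx in order:
        x_masked[idx] = baseline          # "remove" the feature
        scores.append(predict(x_masked))
    return np.array(scores)

# Hypothetical toy model: a linear score over four features
weights = np.array([3.0, 1.0, 0.5, 0.0])

def predict(v):
    return float(weights @ v)

x = np.ones(4)
faithful = weights * x        # gradient * input, exact for a linear model
reversed_attr = faithful[::-1]  # deliberately mismatched attribution

curve_faithful = deletion_curve(predict, x, faithful)
curve_reversed = deletion_curve(predict, x, reversed_attr)

# A faithful attribution makes the score drop faster (lower mean of the curve)
print(curve_faithful.mean(), curve_reversed.mean())
```

For image models the same idea applies with pixels or superpixels as the features, although the choice of baseline (black, blurred, or mean image) can change the result noticeably.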

Exercises

Compare Grad-CAM, Integrated Gradients, and SHAP on the same images:
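def compare_attributions(model, image, target_class, background_data):
    # background_data: reference inputs for the SHAP explainer
    gradcam = GradCAM(model, model.layer4[-1])
    ig = IntegratedGradients(model)
    grad_shap = GradientSHAP(model, background_data)
    
    # TODO: compute an attribution map from each explainer
    # TODO: compare the visualizations side by side
    # TODO: compute the correlation between methods
    raise NotImplementedError("Exercise: fill in the comparison")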
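
For the correlation step, one possible starting point is a rank correlation over absolute attribution values, which ignores scale differences between methods. This is a sketch under assumptions: `rank_correlation` is a hypothetical helper, both maps are assumed to be resized to the same shape already, and ties are not averaged as they would be in a full Spearman implementation.

```python
import numpy as np

def rank_correlation(map_a: np.ndarray, map_b: np.ndarray) -> float:
    """Spearman-style rank correlation between two attribution maps.

    Assumes both maps have the same shape. Ties get arbitrary (not averaged)
    ranks, which is adequate for continuous-valued maps.
    """
    a = np.abs(map_a).ravel()
    b = np.abs(map_b).ravel()
    ranks_a = np.argsort(np.argsort(a))   # rank of each pixel within map_a
    ranks_b = np.argsort(np.argsort(b))
    return float(np.corrcoef(ranks_a, ranks_b)[0, 1])

# Synthetic maps (hypothetical stand-ins for real attributions)
rng = np.random.default_rng(0)
saliency = rng.random((8, 8))
same_ranking = saliency ** 2      # monotone transform keeps the pixel ranking
opposite = 1.0 - saliency         # reverses the pixel ranking
unrelated = rng.random((8, 8))

print(rank_correlation(saliency, same_ranking))
print(rank_correlation(saliency, opposite))
print(rank_correlation(saliency, unrelated))
```

Maps that rank pixels identically should score close to 1, reversed rankings close to -1, and unrelated maps somewhere near 0.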
Implement sanity checks for explanation methods:
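def sanity_check(model, explainer, image):
    # 1. Randomize the top layer's weights (work on a copy of the model)
    # 2. Recompute the explanation and check whether it changes significantly
    # If it does not change, the method is insensitive to what the model
    # learned and may not be faithful
    raise NotImplementedError("Exercise: implement the randomization test")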
Create CAVs for your own concepts:
# Collect images with/without your concept
# Train CAV
# Test which classes are sensitive to your concept

What’s Next?

  • Adversarial Robustness: Defend against adversarial attacks
  • Knowledge Distillation: Transfer knowledge between models