Deep learning models are often “black boxes” — they work, but we do not know why. A model that classifies skin lesions with 95% accuracy is useless in a clinic if the dermatologist cannot understand what features it is using. Is it looking at the lesion’s border irregularity (good) or the presence of a ruler in the image (bad — common in training datasets of confirmed melanomas)?

This is not a theoretical concern. A lack of interpretability is a problem for:
Trust: A radiologist will not act on a model’s prediction without understanding its reasoning
Debugging: When a model fails on production data, you need to know why to fix it
Fairness: You must verify the model is not using protected attributes like race or gender as proxies
Compliance: GDPR’s rules on automated decision-making (often called the “right to explanation”) and the EU AI Act impose transparency requirements on algorithmic systems
Science: Understanding what a model has learned can reveal genuine scientific insights about the data
from typing import List, Dict, Tuple, Optional, Callable

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)
The simplest approach: compute the gradient of the output class score with respect to the input pixels. The intuition is direct — pixels with large gradients are pixels where a small change would most affect the model’s prediction. If the gradient at pixel (100, 200) is large for the “cat” class, then changing that pixel significantly affects whether the model thinks the image contains a cat.
class VanillaGradients:
    """Compute input gradients (saliency maps) for interpretability.

    The gradient of a class score with respect to each input pixel measures
    how sensitive the prediction is to a small change in that pixel.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Dropout / batch-norm must behave deterministically while explaining.
        self.model.eval()

    def compute_gradients(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None
    ) -> torch.Tensor:
        """
        Compute gradient of target class w.r.t. input.

        Args:
            input_tensor: [1, C, H, W] input image
            target_class: Class to explain (default: predicted class)

        Returns:
            gradients: [C, H, W] gradient map
        """
        # detach() first so we always get a fresh leaf tensor; cloning a
        # tensor that already requires grad would yield a non-leaf whose
        # .grad is never populated. The caller's tensor is left untouched.
        inputs = input_tensor.detach().clone().requires_grad_(True)

        # Forward pass
        output = self.model(inputs)

        # Get target class
        if target_class is None:
            target_class = output.argmax(dim=1).item()

        # Differentiate w.r.t. the input only, so the model's parameter
        # gradients are neither accumulated into nor wiped.
        (gradients,) = torch.autograd.grad(output[0, target_class], inputs)

        return gradients[0]

    def visualize(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None
    ) -> np.ndarray:
        """Create a [H, W] saliency map scaled to [0, 1]."""
        gradients = self.compute_gradients(input_tensor, target_class)

        # Take absolute value and sum across channels
        saliency = gradients.abs().sum(dim=0).cpu().numpy()

        # Normalize (epsilon guards against a constant map)
        saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8)

        return saliency


# Example usage
# model = load_pretrained_model()
# explainer = VanillaGradients(model)
# saliency = explainer.visualize(input_image, target_class=281)  # cat
Vanilla gradients have a fundamental flaw: they only capture local sensitivity, not total attribution. A pixel might have a small gradient (locally flat) but still be critically important for the prediction. Integrated Gradients fixes this by accumulating gradients along a path from a baseline (typically a black image) to the actual input. This satisfies important theoretical properties like “completeness” — the attributions sum to the difference between the model’s output at the input and at the baseline.

Accumulate gradients along the straight-line path from the baseline $x'$ to the input $x$:

$$\mathrm{IG}_i(x) = (x_i - x'_i) \times \int_0^1 \frac{\partial F\big(x' + \alpha\,(x - x')\big)}{\partial x_i}\, d\alpha$$
class IntegratedGradients:
    """Integrated Gradients attribution method.

    Attributions are (input - baseline) times the average gradient along the
    straight-line path from the baseline to the input.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self.model.eval()

    def compute_integrated_gradients(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        baseline: Optional[torch.Tensor] = None,
        steps: int = 50
    ) -> torch.Tensor:
        """
        Compute Integrated Gradients.

        Args:
            input_tensor: [1, C, H, W] input image
            target_class: Class to explain (default: predicted class)
            baseline: Reference input (default: zeros)
            steps: Number of integration intervals (must be >= 1)

        Returns:
            attributions: [C, H, W] attribution map

        Raises:
            ValueError: If ``steps`` is smaller than 1.
        """
        if steps < 1:
            raise ValueError("steps must be at least 1")

        # Work on detached copies so no graph from the caller leaks in.
        input_tensor = input_tensor.detach()
        if baseline is None:
            baseline = torch.zeros_like(input_tensor)
        else:
            baseline = baseline.detach()

        # Get target class
        if target_class is None:
            with torch.no_grad():
                output = self.model(input_tensor)
                target_class = output.argmax(dim=1).item()

        # Interpolated inputs at alpha = 0, 1/steps, ..., 1 -> [steps + 1, C, H, W]
        delta = input_tensor - baseline
        scaled_inputs = torch.cat(
            [baseline + (i / steps) * delta for i in range(steps + 1)],
            dim=0
        ).requires_grad_(True)

        # Forward pass; rows of the batch are independent, so the gradient of
        # the summed target scores gives the per-row gradients in one call.
        outputs = self.model(scaled_inputs)
        grads = torch.autograd.grad(
            outputs=outputs[:, target_class].sum(),
            inputs=scaled_inputs,
            create_graph=False
        )[0]

        # Trapezoidal rule over the path. A plain mean over all steps + 1
        # points over-weights both endpoints and only converges as O(1/steps);
        # the trapezoid converges as O(1/steps^2), so completeness
        # (sum of attributions == F(input) - F(baseline)) holds more tightly.
        avg_grads = ((grads[:-1] + grads[1:]) / 2.0).mean(dim=0)

        # Compute integrated gradients
        integrated_grads = (input_tensor[0] - baseline[0]) * avg_grads

        return integrated_grads

    def visualize(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        steps: int = 50,
        baseline: Optional[torch.Tensor] = None
    ) -> np.ndarray:
        """Create a [H, W] attribution map scaled to [0, 1].

        Args:
            input_tensor: [1, C, H, W] input image
            target_class: Class to explain (default: predicted class)
            steps: Number of integration intervals
            baseline: Reference input (default: zeros)
        """
        attributions = self.compute_integrated_gradients(
            input_tensor, target_class, baseline=baseline, steps=steps
        )

        # Take absolute value, then sum across channels
        attr_map = attributions.abs().sum(dim=0).cpu().numpy()

        # Normalize (epsilon guards against a constant map)
        attr_map = (attr_map - attr_map.min()) / (attr_map.max() - attr_map.min() + 1e-8)

        return attr_map
Grad-CAM is the most widely-used interpretability method in practice, and for good reason: it produces intuitive, coarse-grained heatmaps that show which regions of the image were most important for a particular class prediction. Unlike the pixel-level gradient methods above (which produce noisy, hard-to-interpret saliency maps), Grad-CAM works at the feature map level, producing smooth heatmaps that align with human-understandable regions.

Grad-CAM stands for Gradient-weighted Class Activation Mapping:
class GradCAM:
    """
    Grad-CAM: Visual Explanations from Deep Networks.

    Uses gradients flowing into the final convolutional layer
    to produce a coarse localization map highlighting important regions.

    Hooks are registered on ``target_layer`` at construction time; call
    ``remove_hooks()`` when the explainer is no longer needed.
    """

    def __init__(self, model: nn.Module, target_layer: nn.Module):
        self.model = model
        self.target_layer = target_layer
        self.model.eval()

        self.gradients = None
        self.activations = None
        self._hook_handles = []

        # Register hooks
        self._register_hooks()

    def _register_hooks(self):
        """Register forward and backward hooks and keep their handles."""
        def forward_hook(module, inputs, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        self._hook_handles = [
            self.target_layer.register_forward_hook(forward_hook),
            self.target_layer.register_full_backward_hook(backward_hook),
        ]

    def remove_hooks(self):
        """Detach the hooks from the target layer (safe to call repeatedly)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def generate_cam(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None
    ) -> np.ndarray:
        """
        Generate Grad-CAM heatmap.

        Args:
            input_tensor: [1, C, H, W] input image
            target_class: Class to visualize (default: predicted class)

        Returns:
            cam: [H, W] heatmap (same size as input), scaled to [0, 1]

        Raises:
            RuntimeError: If the target layer received no gradient, e.g.
                because it is not part of the model's forward pass or the
                hooks were removed.
        """
        # Make the input a grad-requiring leaf so a backward graph exists even
        # when every model parameter is frozen. The caller's tensor is untouched.
        inputs = input_tensor.detach().requires_grad_(True)

        # Drop results from any earlier call so stale values are never reused.
        self.gradients = None
        self.activations = None

        # Forward pass
        output = self.model(inputs)

        # Get target class
        if target_class is None:
            target_class = output.argmax(dim=1).item()

        # Backward pass (triggers the backward hook on the target layer)
        self.model.zero_grad()
        output[0, target_class].backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "target_layer produced no activations/gradients; is it part of "
                "the model's forward pass and are the hooks still registered?"
            )

        # Get weights: global average pooling of gradients
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)  # [1, C, 1, 1]

        # Weighted combination of activation maps
        cam = (weights * self.activations).sum(dim=1, keepdim=True)  # [1, 1, H', W']

        # ReLU to keep only positive contributions
        cam = F.relu(cam)

        # Upsample to input size
        cam = F.interpolate(
            cam,
            size=input_tensor.shape[2:],
            mode='bilinear',
            align_corners=False
        )

        # Normalize
        cam = cam[0, 0].cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

        return cam

    def visualize(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        alpha: float = 0.5
    ) -> np.ndarray:
        """Create overlay visualization.

        Args:
            input_tensor: [1, C, H, W] input image
            target_class: Class to visualize (default: predicted class)
            alpha: Blend weight of the heatmap (1 - alpha goes to the image)

        Returns:
            overlay: [H, W, 3] array in [0, 1]
        """
        cam = self.generate_cam(input_tensor, target_class)

        # Get original image; detach() so inputs that require grad convert too
        image = input_tensor[0].detach().permute(1, 2, 0).cpu().numpy()
        # Epsilon avoids a division by zero on a constant image
        image = (image - image.min()) / (image.max() - image.min() + 1e-8)

        # Apply colormap to CAM
        import matplotlib.cm as cm
        heatmap = cm.jet(cam)[:, :, :3]

        # Overlay
        overlay = alpha * heatmap + (1 - alpha) * image
        overlay = np.clip(overlay, 0, 1)

        return overlay


# Example usage with ResNet
class GradCAMExample:
    """Example of using Grad-CAM with ResNet."""

    @staticmethod
    def get_gradcam_for_resnet(model):
        """Build a GradCAM explainer on the last conv layer of ``model.layer4``."""
        # For ResNet, use the last conv layer before average pooling.
        # Bottleneck blocks (ResNet-50 and deeper) end in conv3; BasicBlock
        # (ResNet-18/34) ends in conv2.
        last_block = model.layer4[-1]
        target_layer = last_block.conv3 if hasattr(last_block, "conv3") else last_block.conv2
        return GradCAM(model, target_layer)
SHAP brings rigorous game theory to model interpretability. The idea comes from cooperative game theory: if a group of players (features) collectively achieve some payoff (the model’s prediction), how do you fairly distribute credit? Shapley values provide the unique attribution method that satisfies several desirable fairness axioms: symmetry (equal features get equal credit), efficiency (attributions sum to the total prediction), and monotonicity (a feature that always helps never gets negative credit).

The catch: computing exact Shapley values requires evaluating the model on all possible subsets of features, which is exponential in the number of features. DeepSHAP and GradientSHAP provide efficient approximations that leverage the network’s structure.
class DeepSHAP:
    """
    Deep SHAP for neural network explanations.

    Based on Shapley values from game theory.
    Provides theoretically grounded feature attributions.

    This implementation estimates Shapley values by permutation sampling:
    features are switched from a background sample to the input one at a
    time, so each sample costs one model evaluation per feature.
    """

    def __init__(self, model: nn.Module, background_data: torch.Tensor):
        self.model = model
        self.model.eval()
        self.background = background_data

    def shap_values(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        n_samples: int = 100
    ) -> torch.Tensor:
        """
        Compute SHAP values using sampling approximation.

        Args:
            input_tensor: [1, C, H, W] input to explain
            target_class: Class to explain (default: predicted class)
            n_samples: Number of sampled (permutation, background) pairs

        Returns:
            Attributions with the same shape as ``input_tensor``.

        Raises:
            ValueError: If ``n_samples`` < 1 or the background set is empty.
        """
        if n_samples < 1:
            raise ValueError("n_samples must be at least 1")
        n_background = len(self.background)
        if n_background == 0:
            raise ValueError("background_data must contain at least one sample")

        if target_class is None:
            with torch.no_grad():
                output = self.model(input_tensor)
                target_class = output.argmax(dim=1).item()

        # Flatten input for feature-level attribution
        input_flat = input_tensor.detach().flatten()
        n_features = input_flat.shape[0]

        # Sample coalitions and compute marginal contributions
        shap_values = torch.zeros_like(input_flat)

        for _ in range(n_samples):
            # Random permutation
            perm = torch.randperm(n_features).tolist()

            # Sample background with torch's RNG so torch.manual_seed makes
            # the whole estimate reproducible.
            bg_idx = int(torch.randint(n_background, (1,)).item())
            # clone() so the stored background is never written to
            current = self.background[bg_idx].detach().flatten().to(input_flat).clone()

            # Compute marginal contributions
            prev_output = self._evaluate(current.reshape(input_tensor.shape), target_class)

            for idx in perm:
                current[idx] = input_flat[idx]
                new_output = self._evaluate(current.reshape(input_tensor.shape), target_class)
                shap_values[idx] += (new_output - prev_output)
                prev_output = new_output

        shap_values /= n_samples

        return shap_values.reshape(input_tensor.shape)

    def _evaluate(self, x: torch.Tensor, target_class: int) -> float:
        """Return the model's score for ``target_class`` on ``x`` as a float."""
        with torch.no_grad():
            output = self.model(x)
            return output[0, target_class].item()


class GradientSHAP:
    """
    GradientSHAP: Faster SHAP approximation using gradients.

    Estimates expected gradients: for a random background sample and a
    random point on the line between it and the input, accumulate
    gradient x (input - background) and average over samples.
    """

    def __init__(self, model: nn.Module, background_data: torch.Tensor):
        self.model = model
        # Same as the other explainers: disable dropout / batch-norm updates.
        self.model.eval()
        self.background = background_data

    def shap_values(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        n_samples: int = 50
    ) -> torch.Tensor:
        """Compute GradientSHAP values.

        Args:
            input_tensor: [1, C, H, W] input to explain
            target_class: Class to explain (default: predicted class)
            n_samples: Number of (background, interpolation point) samples

        Returns:
            Attributions with the same shape as ``input_tensor``.

        Raises:
            ValueError: If ``n_samples`` < 1 or the background set is empty.
        """
        if n_samples < 1:
            raise ValueError("n_samples must be at least 1")
        n_background = len(self.background)
        if n_background == 0:
            raise ValueError("background_data must contain at least one sample")

        # Detach so interpolated points are leaves even if the caller's
        # tensor already requires grad.
        inputs = input_tensor.detach()

        if target_class is None:
            with torch.no_grad():
                output = self.model(inputs)
                target_class = output.argmax(dim=1).item()

        total = torch.zeros_like(inputs)

        for _ in range(n_samples):
            # Sample background (torch RNG -> reproducible with manual_seed)
            bg_idx = int(torch.randint(n_background, (1,)).item())
            baseline = self.background[bg_idx:bg_idx + 1].detach().to(inputs)

            # Random interpolation point
            alpha = float(torch.rand(1).item())
            diff = inputs - baseline
            interpolated = (baseline + alpha * diff).requires_grad_(True)

            # Gradient w.r.t. the interpolated input only; parameter
            # gradients of the model are left untouched.
            output = self.model(interpolated)
            (grad,) = torch.autograd.grad(output[0, target_class], interpolated)

            # Pair each gradient with the baseline it was sampled with;
            # multiplying the averaged gradient by the mean baseline instead
            # drops the gradient/baseline correlation.
            total += grad * diff

        return total / n_samples
While attribution methods answer “which inputs matter?”, feature visualization answers “what does this neuron detect?” By starting with random noise and iteratively modifying the image to maximize a specific neuron’s activation, we can literally see the platonic ideal of what that neuron is looking for. Early layers reveal edge and texture detectors; deeper layers reveal complex pattern detectors like dog faces, wheels, or building facades. This technique, popularized by Google’s Distill publication, is both scientifically illuminating and occasionally surreal.
class ActivationMaximization:
    """
    Generate images that maximally activate a neuron.

    Useful for understanding what features a neuron detects.
    Regularization is critical: without it, the optimized image
    will be adversarial noise that maximizes activation but is
    visually meaningless to humans.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self.model.eval()

    def visualize_filter(
        self,
        layer: nn.Module,
        filter_idx: int,
        image_size: int = 224,
        iterations: int = 200,
        lr: float = 0.1
    ) -> torch.Tensor:
        """
        Generate image that maximizes activation of a filter.

        Args:
            layer: Target layer (must be executed by the model's forward pass)
            filter_idx: Which filter to visualize
            image_size: Size of generated image
            iterations: Optimization steps
            lr: Learning rate

        Returns:
            [1, 3, image_size, image_size] image scaled to [0, 1]

        Raises:
            RuntimeError: If ``layer`` is not reached during the forward pass.
        """
        # Create the image on the model's device so GPU models work too.
        first_param = next(self.model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

        # Start with random noise
        image = torch.randn(
            1, 3, image_size, image_size, device=device, requires_grad=True
        )

        optimizer = torch.optim.Adam([image], lr=lr)

        # Hook to capture activation
        captured = {}

        def hook(module, inputs, output):
            captured["activation"] = output

        handle = layer.register_forward_hook(hook)

        # try/finally: the hook must not outlive this call, even on error.
        try:
            for _ in range(iterations):
                optimizer.zero_grad()
                captured.clear()

                # Forward pass
                self.model(image)

                if "activation" not in captured:
                    raise RuntimeError(
                        "layer was not executed during the model's forward pass"
                    )

                # Loss: negative activation (we want to maximize)
                loss = -captured["activation"][0, filter_idx].mean()

                # Add regularization
                loss = loss + 0.01 * torch.norm(image)  # L2 regularization
                loss = loss + 0.001 * self._total_variation(image)  # Smoothness

                # Differentiate w.r.t. the image only so the model's own
                # parameter gradients are not polluted.
                (image.grad,) = torch.autograd.grad(loss, image)
                optimizer.step()

                # Clip to valid range
                with torch.no_grad():
                    image.clamp_(-2, 2)
        finally:
            handle.remove()

        # Normalize for visualization (epsilon guards a constant image)
        image = image.detach()
        image = (image - image.min()) / (image.max() - image.min() + 1e-8)

        return image

    def _total_variation(self, x: torch.Tensor) -> torch.Tensor:
        """Total variation loss for smoothness."""
        diff_h = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
        diff_w = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
        return diff_h.mean() + diff_w.mean()


class DeepDream:
    """
    DeepDream: Maximize activations to create psychedelic images.

    Forward hooks on the target layers are installed only for the duration
    of ``dream()`` and removed afterwards.
    """

    def __init__(self, model: nn.Module, target_layers: List[nn.Module]):
        self.model = model
        # Same as the other classes: keep dropout / batch-norm deterministic.
        self.model.eval()
        self.target_layers = target_layers
        # Maps each target layer to the output of its latest forward pass.
        self.activations = {}

    def _save_activation(self, module, inputs, output):
        """Forward hook: remember the (non-detached) output of ``module``."""
        self.activations[module] = output

    def dream(
        self,
        image: torch.Tensor,
        iterations: int = 100,
        lr: float = 0.01,
        octave_scale: float = 1.4,
        n_octaves: int = 4
    ) -> torch.Tensor:
        """
        Apply DeepDream to an image.

        Args:
            image: [1, 3, H, W] input image (not modified)
            iterations: Steps per octave
            lr: Learning rate
            octave_scale: Scale factor between octaves
            n_octaves: Number of octaves (scales)

        Returns:
            Dreamed image with the same spatial size as the input.

        Raises:
            ValueError: If no target layers were given.
        """
        if not self.target_layers:
            raise ValueError("target_layers must contain at least one layer")

        original_size = image.shape[2:]
        # clone(): detach() alone shares storage, so the in-place ascent step
        # below would otherwise overwrite the caller's tensor.
        image = image.detach().clone()

        handles = [
            layer.register_forward_hook(self._save_activation)
            for layer in self.target_layers
        ]

        try:
            # Process at multiple scales (octaves)
            for octave in range(n_octaves):
                if octave > 0:
                    # Upscale
                    new_size = [int(s * octave_scale) for s in image.shape[2:]]
                    image = F.interpolate(
                        image, size=new_size, mode='bilinear', align_corners=False
                    )

                image = image.detach().requires_grad_(True)

                for _ in range(iterations):
                    self.activations.clear()

                    # Forward pass (hooks fill self.activations)
                    self.model(image)

                    # Maximize activations
                    loss = sum(
                        self.activations[layer].norm()
                        for layer in self.target_layers
                    )

                    # Gradient w.r.t. the image only; model parameter
                    # gradients stay untouched.
                    (grad,) = torch.autograd.grad(loss, image)

                    # Gradient ascent with a normalized step
                    with torch.no_grad():
                        image += lr * grad / (grad.norm() + 1e-8)
        finally:
            for handle in handles:
                handle.remove()
            # Drop references to graph-attached outputs.
            self.activations.clear()

        # Resize back
        image = F.interpolate(
            image.detach(), size=original_size, mode='bilinear', align_corners=False
        )

        return image.detach()
TCAV moves beyond pixel-level explanations to concept-level ones, which is what humans actually care about. Instead of asking “which pixels matter?”, TCAV asks “does the concept of stripes influence the model’s prediction of zebra?” You define a concept by providing example images (images with stripes vs. random images), TCAV learns a direction in the model’s activation space that corresponds to that concept, and then measures how much pushing the representation along that direction affects the prediction. This approach speaks the language of domain experts — a doctor can ask “does the presence of calcification influence the model’s diagnosis?” without needing to understand gradients or attention maps.
class TCAV:
    """
    Testing with Concept Activation Vectors.

    Explain models using human-interpretable concepts.
    The key advantage over saliency methods: TCAV explanations
    are expressed in terms that domain experts already understand.
    """

    def __init__(self, model: nn.Module, bottleneck_layer: nn.Module):
        self.model = model
        # Keep dropout / batch-norm deterministic, as the other explainers do.
        self.model.eval()
        self.bottleneck = bottleneck_layer
        # Detached copy of the latest bottleneck output (safe to hand out).
        self.activations = None
        # Graph-attached output of the latest forward pass; needed to
        # differentiate the class score w.r.t. the bottleneck.
        self._live_activations = None

        # Hook to capture activations
        self._hook_handle = self.bottleneck.register_forward_hook(self._save_activation)

    def _save_activation(self, module, inputs, output):
        self._live_activations = output
        self.activations = output.detach()

    def get_activations(self, images: torch.Tensor) -> torch.Tensor:
        """Get (detached) bottleneck activations for images."""
        with torch.no_grad():
            _ = self.model(images)
        self._live_activations = None
        return self.activations

    def train_cav(
        self,
        concept_images: torch.Tensor,
        random_images: torch.Tensor
    ) -> torch.Tensor:
        """
        Train Concept Activation Vector.

        Args:
            concept_images: Images with the concept
            random_images: Random images without the concept

        Returns:
            cav: Unit-norm concept activation vector over the flattened
                bottleneck activations
        """
        # Get activations
        concept_acts = self.get_activations(concept_images)
        random_acts = self.get_activations(random_images)

        # Flatten activations
        concept_acts = concept_acts.flatten(start_dim=1)
        random_acts = random_acts.flatten(start_dim=1)

        # Create labels
        X = torch.cat([concept_acts, random_acts], dim=0)
        y = torch.cat([
            torch.ones(len(concept_acts)),
            torch.zeros(len(random_acts))
        ])

        # Train linear classifier (.cpu() so GPU activations convert too)
        from sklearn.linear_model import LogisticRegression
        clf = LogisticRegression(max_iter=1000)
        clf.fit(X.cpu().numpy(), y.numpy())

        # CAV is the weight vector; clamp avoids dividing by a zero norm
        cav = torch.tensor(clf.coef_[0], dtype=torch.float32)
        cav = cav / cav.norm().clamp_min(1e-12)

        return cav

    def compute_tcav_score(
        self,
        test_images: torch.Tensor,
        target_class: int,
        cav: torch.Tensor
    ) -> float:
        """
        Compute TCAV score: fraction of test images where
        the concept positively influences the prediction.

        Args:
            test_images: [N, C, H, W] images of the class under test
            target_class: Class whose score is differentiated
            cav: Concept activation vector from ``train_cav``

        Returns:
            Fraction in [0, 1] of images whose class score increases along
            the CAV direction in bottleneck space.

        Raises:
            ValueError: If ``test_images`` is empty.
        """
        n_images = len(test_images)
        if n_images == 0:
            raise ValueError("test_images must contain at least one image")

        positive_count = 0

        for img in test_images:
            # Grad-requiring input guarantees a graph through the bottleneck
            # even when all model parameters are frozen.
            img = img.detach().unsqueeze(0).requires_grad_(True)

            # Forward pass (the hook stores the graph-attached activations)
            output = self.model(img)
            acts = self._live_activations
            self._live_activations = None

            if acts is None or not acts.requires_grad:
                # Bottleneck is not on the path from the input; no influence.
                continue

            # Gradient of the class score w.r.t. the bottleneck activations.
            # The detached copy in self.activations never carries a gradient,
            # so the graph-attached tensor must be differentiated directly.
            act_grad = torch.autograd.grad(
                output[0, target_class], acts, allow_unused=True
            )[0]

            if act_grad is not None:
                # Directional derivative along CAV
                act_grad_flat = act_grad.flatten()
                directional_deriv = torch.dot(act_grad_flat, cav.to(act_grad_flat).flatten())

                if directional_deriv > 0:
                    positive_count += 1

        tcav_score = positive_count / n_images

        return tcav_score
def sanity_check(model, explainer, image, max_similarity: float = 0.9) -> bool:
    """Model-randomization sanity check for an attribution method.

    A faithful explanation must depend on what the model has learned. This
    check randomizes the weights of the top layer, recomputes the
    explanation for the same class, and compares it with the original one.

    Args:
        model: The network being explained. ``explainer`` must wrap this
            same object, otherwise the randomization has no effect on it.
        explainer: Object exposing ``visualize(image, target_class)`` that
            returns an array-like map (as the explainer classes in this
            file do).
        image: [1, C, H, W] input image
        max_similarity: Absolute Pearson correlation at or above which the
            two maps are considered "unchanged".

    Returns:
        True if the explanation changed significantly (check passed),
        False if it did not (the method may not be faithful).

    Raises:
        ValueError: If the model has no parameters to randomize.
    """
    def _as_array(explanation) -> np.ndarray:
        if torch.is_tensor(explanation):
            explanation = explanation.detach().cpu().numpy()
        # copy=True: never alias memory that is about to be overwritten
        return np.array(explanation, dtype=np.float64, copy=True)

    # "Top layer" = last module (in registration order) that owns parameters.
    # NOTE(review): assumes registration order matches execution order.
    top_layer = None
    for module in model.modules():
        if next(module.parameters(recurse=False), None) is not None:
            top_layer = module
    if top_layer is None:
        raise ValueError("model has no parameters to randomize")

    # Fix the explained class up front; after randomization the predicted
    # class would change and the two maps would not be comparable.
    with torch.no_grad():
        target_class = model(image).argmax(dim=1).item()

    before = _as_array(explainer.visualize(image, target_class))

    # 1. Randomize top layer weights (and restore them afterwards)
    params = list(top_layer.parameters(recurse=False))
    saved = [p.detach().clone() for p in params]
    try:
        with torch.no_grad():
            for p in params:
                p.normal_(mean=0.0, std=0.02)
        after = _as_array(explainer.visualize(image, target_class))
    finally:
        with torch.no_grad():
            for p, original in zip(params, saved):
                p.copy_(original)

    # 2. Check if explanations change significantly
    if before.shape != after.shape:
        return True
    if np.allclose(before, after):
        # If they don't change, the method may not be faithful
        return False
    if before.std() == 0 or after.std() == 0:
        # One map is constant and the other is not: clearly different.
        return True

    similarity = abs(np.corrcoef(before.ravel(), after.ravel())[0, 1])
    return bool(similarity < max_similarity)
Exercise 3: Build TCAV for Custom Concepts
Create CAVs for your own concepts:
# Collect images with/without your concept
# Train CAV
# Test which classes are sensitive to your concept