Model Interpretability: Opening the Black Box
Why Interpretability Matters
Deep learning models are often “black boxes”: they work, but we don’t know why. This is problematic for:
- Trust: Should we trust this diagnosis?
- Debugging: Why did the model fail?
- Fairness: Is the model using protected attributes?
- Compliance: Regulations require explanations (GDPR)
- Science: What has the model learned?
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional, Callable
torch.manual_seed(42)
Gradient-Based Methods
Vanilla Gradients
The simplest approach: compute the gradient of the output w.r.t. the input.
class VanillaGradients:
"""Compute input gradients for interpretability."""
def __init__(self, model: nn.Module):
self.model = model
self.model.eval()
def compute_gradients(
self,
input_tensor: torch.Tensor,
target_class: Optional[int] = None
) -> torch.Tensor:
"""
Compute gradient of target class w.r.t. input.
Args:
input_tensor: [1, C, H, W] input image
target_class: Class to explain (default: predicted class)
Returns:
gradients: [C, H, W] gradient map
"""
input_tensor = input_tensor.clone().requires_grad_(True)
# Forward pass
output = self.model(input_tensor)
# Get target class
if target_class is None:
target_class = output.argmax(dim=1).item()
# Backward pass
self.model.zero_grad()
output[0, target_class].backward()
# Get gradients
gradients = input_tensor.grad.data[0]
return gradients
def visualize(
self,
input_tensor: torch.Tensor,
target_class: Optional[int] = None
) -> np.ndarray:
"""Create visualization of gradients."""
gradients = self.compute_gradients(input_tensor, target_class)
# Take absolute value and sum across channels
saliency = gradients.abs().sum(dim=0).cpu().numpy()
# Normalize
saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8)
return saliency
# Example usage
# model = load_pretrained_model()
# explainer = VanillaGradients(model)
# saliency = explainer.visualize(input_image, target_class=281) # cat
Gradient × Input
Multiplying gradients element-wise by the input often yields sharper attributions:
class GradientTimesInput(VanillaGradients):
"""Gradient × Input attribution."""
def compute_attribution(
self,
input_tensor: torch.Tensor,
target_class: Optional[int] = None
) -> torch.Tensor:
gradients = self.compute_gradients(input_tensor, target_class)
# Multiply by input
attribution = gradients * input_tensor[0]
return attribution
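The result is visualized the same way as vanilla gradients; a usage sketch, assuming `model` and `input_image` from the example above:
# explainer = GradientTimesInput(model)
# attribution = explainer.compute_attribution(input_image)
# attr_map = attribution.abs().sum(dim=0).cpu().numpy()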
Integrated Gradients
Integrated Gradients accumulates gradients along the straight-line path from a baseline x′ to the input x:

$$\mathrm{IG}_i(x) = (x_i - x_i') \times \int_0^1 \frac{\partial F\big(x' + \alpha\,(x - x')\big)}{\partial x_i}\, d\alpha$$

In practice the integral is approximated by averaging gradients at evenly spaced interpolation points, as the code below does.
class IntegratedGradients:
"""Integrated Gradients attribution method."""
def __init__(self, model: nn.Module):
self.model = model
self.model.eval()
def compute_integrated_gradients(
self,
input_tensor: torch.Tensor,
target_class: Optional[int] = None,
baseline: Optional[torch.Tensor] = None,
steps: int = 50
) -> torch.Tensor:
"""
Compute Integrated Gradients.
Args:
input_tensor: [1, C, H, W] input image
target_class: Class to explain
baseline: Reference input (default: zeros)
steps: Number of integration steps
Returns:
attributions: [C, H, W] attribution map
"""
if baseline is None:
baseline = torch.zeros_like(input_tensor)
# Get target class
if target_class is None:
with torch.no_grad():
output = self.model(input_tensor)
target_class = output.argmax(dim=1).item()
# Generate interpolated inputs
scaled_inputs = []
for i in range(steps + 1):
alpha = i / steps
scaled_input = baseline + alpha * (input_tensor - baseline)
scaled_inputs.append(scaled_input)
scaled_inputs = torch.cat(scaled_inputs, dim=0)
scaled_inputs.requires_grad_(True)
# Forward pass
outputs = self.model(scaled_inputs)
# Backward pass
self.model.zero_grad()
target_outputs = outputs[:, target_class]
grads = torch.autograd.grad(
outputs=target_outputs.sum(),
inputs=scaled_inputs,
create_graph=False
)[0]
# Average gradients
avg_grads = grads.mean(dim=0)
# Compute integrated gradients
integrated_grads = (input_tensor[0] - baseline[0]) * avg_grads
return integrated_grads
def visualize(
self,
input_tensor: torch.Tensor,
target_class: Optional[int] = None,
steps: int = 50
) -> np.ndarray:
"""Create visualization."""
attributions = self.compute_integrated_gradients(
input_tensor, target_class, steps=steps
)
        # Take absolute value and sum across channels
attr_map = attributions.abs().sum(dim=0).cpu().numpy()
# Normalize
attr_map = (attr_map - attr_map.min()) / (attr_map.max() - attr_map.min() + 1e-8)
return attr_map
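A quick sanity check is the completeness axiom from the original paper: the attributions should sum, up to discretization error, to F(x) − F(baseline) for the target class. A minimal sketch, assuming `model` and `input_image` from the earlier examples:
# ig = IntegratedGradients(model)
# attributions = ig.compute_integrated_gradients(input_image, target_class=281)
# with torch.no_grad():
#     delta = (model(input_image) - model(torch.zeros_like(input_image)))[0, 281]
# print(attributions.sum().item(), delta.item())  # should be close; increase steps if not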
SmoothGrad
Average gradients over noisy copies of the input to reduce visual noise in the saliency map:
class SmoothGrad(VanillaGradients):
"""SmoothGrad: Average gradients over noisy samples."""
def compute_smooth_gradients(
self,
input_tensor: torch.Tensor,
target_class: Optional[int] = None,
n_samples: int = 50,
noise_std: float = 0.1
) -> torch.Tensor:
"""
Args:
n_samples: Number of noisy samples
noise_std: Standard deviation of noise
"""
all_gradients = []
for _ in range(n_samples):
# Add noise
noise = torch.randn_like(input_tensor) * noise_std
noisy_input = input_tensor + noise
# Compute gradients
grads = self.compute_gradients(noisy_input, target_class)
all_gradients.append(grads)
# Average
smooth_grads = torch.stack(all_gradients).mean(dim=0)
return smooth_grads
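A usage sketch; the SmoothGrad paper suggests a noise level of roughly 10–20% of the input's value range, so `noise_std` should scale with your normalization:
# explainer = SmoothGrad(model)
# smooth = explainer.compute_smooth_gradients(input_image, n_samples=25, noise_std=0.1)
# saliency = smooth.abs().sum(dim=0).cpu().numpy()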
Class Activation Mapping (CAM)
Grad-CAM
Gradient-weighted Class Activation Mapping:
class GradCAM:
"""
Grad-CAM: Visual Explanations from Deep Networks.
Uses gradients flowing into the final convolutional layer to produce
a coarse localization map highlighting important regions.
"""
def __init__(self, model: nn.Module, target_layer: nn.Module):
self.model = model
self.target_layer = target_layer
self.model.eval()
self.gradients = None
self.activations = None
# Register hooks
self._register_hooks()
def _register_hooks(self):
"""Register forward and backward hooks."""
def forward_hook(module, input, output):
self.activations = output.detach()
def backward_hook(module, grad_input, grad_output):
self.gradients = grad_output[0].detach()
self.target_layer.register_forward_hook(forward_hook)
self.target_layer.register_full_backward_hook(backward_hook)
def generate_cam(
self,
input_tensor: torch.Tensor,
target_class: Optional[int] = None
) -> np.ndarray:
"""
Generate Grad-CAM heatmap.
Args:
input_tensor: [1, C, H, W] input image
target_class: Class to visualize
Returns:
cam: [H, W] heatmap (same size as input)
"""
# Forward pass
output = self.model(input_tensor)
# Get target class
if target_class is None:
target_class = output.argmax(dim=1).item()
# Backward pass
self.model.zero_grad()
output[0, target_class].backward()
# Get weights: global average pooling of gradients
weights = self.gradients.mean(dim=(2, 3), keepdim=True) # [1, C, 1, 1]
# Weighted combination of activation maps
cam = (weights * self.activations).sum(dim=1, keepdim=True) # [1, 1, H', W']
# ReLU to keep only positive contributions
cam = F.relu(cam)
# Upsample to input size
cam = F.interpolate(
cam,
size=input_tensor.shape[2:],
mode='bilinear',
align_corners=False
)
# Normalize
cam = cam[0, 0].cpu().numpy()
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
return cam
def visualize(
self,
input_tensor: torch.Tensor,
target_class: Optional[int] = None,
alpha: float = 0.5
) -> np.ndarray:
"""Create overlay visualization."""
cam = self.generate_cam(input_tensor, target_class)
# Get original image
image = input_tensor[0].permute(1, 2, 0).cpu().numpy()
image = (image - image.min()) / (image.max() - image.min())
# Apply colormap to CAM
import matplotlib.cm as cm
heatmap = cm.jet(cam)[:, :, :3]
# Overlay
overlay = alpha * heatmap + (1 - alpha) * image
overlay = np.clip(overlay, 0, 1)
return overlay
# Example usage with ResNet
class GradCAMExample:
"""Example of using Grad-CAM with ResNet."""
@staticmethod
def get_gradcam_for_resnet(model):
        # Use the last residual block of layer4; in Bottleneck ResNets
        # (ResNet-50 and up) the last conv is conv3, not conv2, so hooking
        # the whole block is safer across variants
        target_layer = model.layer4[-1]
        return GradCAM(model, target_layer)
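A usage sketch with a pretrained torchvision ResNet (assumes torchvision ≥ 0.13 for the weights API; `input_image` is a normalized [1, 3, 224, 224] tensor):
# from torchvision.models import resnet50, ResNet50_Weights
# model = resnet50(weights=ResNet50_Weights.DEFAULT)
# gradcam = GradCAMExample.get_gradcam_for_resnet(model)
# overlay = gradcam.visualize(input_image, target_class=281)  # 281 = tabby cat
# plt.imshow(overlay); plt.axis('off'); plt.show()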
Grad-CAM++
An improved weighting scheme with better localization, particularly when multiple instances of the target class appear in the image:
class GradCAMPlusPlus(GradCAM):
"""Grad-CAM++ with improved weighting."""
def generate_cam(
self,
input_tensor: torch.Tensor,
target_class: Optional[int] = None
) -> np.ndarray:
# Forward pass
output = self.model(input_tensor)
if target_class is None:
target_class = output.argmax(dim=1).item()
# Backward pass
self.model.zero_grad()
output[0, target_class].backward()
# Grad-CAM++ weighting
gradients = self.gradients[0] # [C, H, W]
activations = self.activations[0] # [C, H, W]
# Compute alpha (importance weights)
grad_2 = gradients ** 2
grad_3 = gradients ** 3
sum_activations = activations.sum(dim=(1, 2), keepdim=True)
alpha_num = grad_2
alpha_denom = 2 * grad_2 + sum_activations * grad_3 + 1e-8
alpha = alpha_num / alpha_denom
# Weighted gradients
weights = (alpha * F.relu(gradients)).sum(dim=(1, 2)) # [C]
# Weighted combination
cam = (weights.view(-1, 1, 1) * activations).sum(dim=0)
cam = F.relu(cam)
# Upsample and normalize
cam = F.interpolate(
cam.unsqueeze(0).unsqueeze(0),
size=input_tensor.shape[2:],
mode='bilinear',
align_corners=False
)[0, 0]
cam = cam.cpu().numpy()
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
return cam
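Grad-CAM++ is a drop-in replacement for GradCAM, reusing the same hooks; a usage sketch:
# gradcam_pp = GradCAMPlusPlus(model, model.layer4[-1])
# cam = gradcam_pp.generate_cam(input_image)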
Attention Visualization
Visualizing Transformer Attention
class AttentionVisualizer:
"""Visualize attention patterns in transformers."""
def __init__(self, model: nn.Module):
self.model = model
self.attention_maps = {}
# Register hooks to capture attention
self._register_hooks()
def _register_hooks(self):
"""Hook into attention layers."""
        def make_hook(name):
            def hook(module, input, output):
                # Assumes the attention module returns (output, attention_weights);
                # some layers return weights=None unless called with need_weights=True
                if isinstance(output, tuple) and len(output) == 2 and output[1] is not None:
                    self.attention_maps[name] = output[1].detach()
            return hook
for name, module in self.model.named_modules():
if 'attention' in name.lower() or 'attn' in name.lower():
module.register_forward_hook(make_hook(name))
def get_attention_maps(
self,
input_tensor: torch.Tensor
) -> Dict[str, torch.Tensor]:
"""Forward pass and collect attention maps."""
self.attention_maps = {}
with torch.no_grad():
_ = self.model(input_tensor)
return self.attention_maps
def visualize_attention(
self,
attention: torch.Tensor,
tokens: List[str],
head: int = 0
) -> None:
"""Visualize attention as heatmap."""
# attention: [batch, heads, seq_len, seq_len]
attn = attention[0, head].cpu().numpy()
plt.figure(figsize=(10, 8))
plt.imshow(attn, cmap='viridis')
plt.xticks(range(len(tokens)), tokens, rotation=45, ha='right')
plt.yticks(range(len(tokens)), tokens)
plt.colorbar()
plt.title(f'Attention Head {head}')
plt.tight_layout()
plt.show()
@staticmethod
def aggregate_attention(
attention_maps: Dict[str, torch.Tensor],
method: str = 'mean'
) -> torch.Tensor:
"""Aggregate attention across layers and heads."""
all_attention = list(attention_maps.values())
stacked = torch.stack(all_attention, dim=0) # [layers, batch, heads, seq, seq]
if method == 'mean':
return stacked.mean(dim=(0, 2)) # [batch, seq, seq]
elif method == 'max':
return stacked.max(dim=0)[0].max(dim=1)[0]
elif method == 'rollout':
# Attention rollout
result = torch.eye(stacked.shape[-1]).unsqueeze(0)
for layer_attn in stacked:
# Average over heads
attn = layer_attn.mean(dim=1)
# Add residual connection
attn = 0.5 * attn + 0.5 * torch.eye(attn.shape[-1])
# Matrix multiply
result = torch.bmm(result, attn)
return result
else:
raise ValueError(f"Unknown method: {method}")
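A usage sketch: the hook only fires for modules that actually return attention weights, e.g. nn.MultiheadAttention called with need_weights=True (recent versions of nn.TransformerEncoderLayer disable weight output internally, so a custom block may be needed). `my_transformer` and `token_ids` are assumed:
# visualizer = AttentionVisualizer(my_transformer)
# maps = visualizer.get_attention_maps(token_ids)
# name, attn = next(iter(maps.items()))
# visualizer.visualize_attention(attn, tokens=['[CLS]', 'the', 'cat', 'sat', '[SEP]'], head=0)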
class ViTAttentionRollout:
"""Attention rollout for Vision Transformer."""
def __init__(self, model):
self.model = model
self.attentions = []
self._register_hooks()
    def _register_hooks(self):
        # Assumes a timm-style ViT exposing model.blocks[i].attn;
        # adjust the attribute path for other implementations
        for block in self.model.blocks:
            block.attn.register_forward_hook(self._save_attention)
    def _save_attention(self, module, input, output):
        # Assumes the attn module returns (out, attention_weights); stock timm
        # attention returns only `out`, so it may need a small patch to also
        # return the weights
        self.attentions.append(output[1])
def compute_rollout(
self,
input_tensor: torch.Tensor,
discard_ratio: float = 0.9
) -> np.ndarray:
"""Compute attention rollout visualization."""
self.attentions = []
with torch.no_grad():
_ = self.model(input_tensor)
# Stack all layer attentions
attention = torch.stack(self.attentions) # [layers, batch, heads, tokens, tokens]
# Average over heads
attention = attention.mean(dim=2)
# Rollout
result = torch.eye(attention.shape[-1])
for layer_attn in attention[:, 0]: # First batch
# Discard lowest attentions
flat = layer_attn.flatten()
threshold = flat.quantile(discard_ratio)
layer_attn = layer_attn * (layer_attn > threshold)
# Renormalize
layer_attn = layer_attn / layer_attn.sum(dim=-1, keepdim=True)
# Add identity (residual)
layer_attn = 0.5 * layer_attn + 0.5 * torch.eye(layer_attn.shape[-1])
result = torch.mm(result, layer_attn)
# Get attention from CLS token to patches
cls_attention = result[0, 1:] # Exclude CLS token itself
# Reshape to 2D
num_patches = int(np.sqrt(len(cls_attention)))
attention_map = cls_attention.reshape(num_patches, num_patches).numpy()
return attention_map
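A usage sketch, assuming a timm-style ViT patched so each block's attn returns (out, attention_weights), and a preprocessed [1, 3, 224, 224] input:
# rollout = ViTAttentionRollout(vit_model)
# attn_map = rollout.compute_rollout(input_image, discard_ratio=0.9)
# plt.imshow(attn_map, cmap='viridis'); plt.show()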
SHAP (SHapley Additive exPlanations)
class DeepSHAP:
"""
Deep SHAP for neural network explanations.
Based on Shapley values from game theory.
"""
def __init__(self, model: nn.Module, background_data: torch.Tensor):
self.model = model
self.model.eval()
self.background = background_data
def shap_values(
self,
input_tensor: torch.Tensor,
target_class: Optional[int] = None,
n_samples: int = 100
) -> torch.Tensor:
"""
Compute SHAP values using sampling approximation.
Args:
input_tensor: [1, C, H, W] input to explain
target_class: Class to explain
n_samples: Number of coalition samples
"""
if target_class is None:
with torch.no_grad():
output = self.model(input_tensor)
target_class = output.argmax(dim=1).item()
# Flatten input for feature-level attribution
input_flat = input_tensor.flatten()
n_features = input_flat.shape[0]
# Sample coalitions and compute marginal contributions
shap_values = torch.zeros_like(input_flat)
for _ in range(n_samples):
# Random permutation
perm = torch.randperm(n_features)
# Sample background
bg_idx = np.random.randint(len(self.background))
background = self.background[bg_idx:bg_idx+1].flatten()
# Compute marginal contributions
current = background.clone()
prev_output = self._evaluate(current.reshape(input_tensor.shape), target_class)
for idx in perm:
current[idx] = input_flat[idx]
new_output = self._evaluate(current.reshape(input_tensor.shape), target_class)
shap_values[idx] += (new_output - prev_output)
prev_output = new_output
shap_values /= n_samples
return shap_values.reshape(input_tensor.shape)
def _evaluate(self, x: torch.Tensor, target_class: int) -> float:
with torch.no_grad():
output = self.model(x)
return output[0, target_class].item()
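Note the cost: this sampling approximation runs n_samples × n_features forward passes, which is prohibitive for images (a 3×224×224 input has ~150k features) but fine for tabular-scale inputs. A toy sketch on a hypothetical 8-feature MLP:
# mlp = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
# background = torch.randn(20, 8)
# explainer = DeepSHAP(mlp, background)
# values = explainer.shap_values(torch.randn(1, 8), n_samples=50)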
class GradientSHAP:
"""
GradientSHAP: Faster SHAP approximation using gradients.
"""
    def __init__(self, model: nn.Module, background_data: torch.Tensor):
        self.model = model
        self.model.eval()  # disable dropout / batch-norm updates during attribution
        self.background = background_data
def shap_values(
self,
input_tensor: torch.Tensor,
target_class: Optional[int] = None,
n_samples: int = 50
) -> torch.Tensor:
"""Compute GradientSHAP values."""
if target_class is None:
with torch.no_grad():
output = self.model(input_tensor)
target_class = output.argmax(dim=1).item()
all_grads = []
for _ in range(n_samples):
# Sample background
bg_idx = np.random.randint(len(self.background))
baseline = self.background[bg_idx:bg_idx+1]
# Random interpolation point
alpha = np.random.uniform()
interpolated = baseline + alpha * (input_tensor - baseline)
interpolated.requires_grad_(True)
# Compute gradient
output = self.model(interpolated)
self.model.zero_grad()
output[0, target_class].backward()
grad = interpolated.grad.data
all_grads.append(grad)
# Average gradients
avg_grad = torch.stack(all_grads).mean(dim=0)
# SHAP values = gradient × (input - expected_baseline)
expected_baseline = self.background.mean(dim=0)
shap_values = avg_grad * (input_tensor - expected_baseline)
return shap_values
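A usage sketch; `background` is a batch of reference inputs you supply, e.g. a random sample of training images (hypothetical `train_dataset`):
# background = torch.stack([train_dataset[i][0] for i in range(100)])
# explainer = GradientSHAP(model, background)
# shap_vals = explainer.shap_values(input_image, n_samples=50)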
Feature Visualization
Activation Maximization
class ActivationMaximization:
"""
Generate images that maximally activate a neuron.
Useful for understanding what features a neuron detects.
"""
def __init__(self, model: nn.Module):
self.model = model
self.model.eval()
def visualize_filter(
self,
layer: nn.Module,
filter_idx: int,
image_size: int = 224,
iterations: int = 200,
lr: float = 0.1
) -> torch.Tensor:
"""
Generate image that maximizes activation of a filter.
Args:
layer: Target layer
filter_idx: Which filter to visualize
image_size: Size of generated image
iterations: Optimization steps
lr: Learning rate
"""
# Start with random noise
image = torch.randn(1, 3, image_size, image_size, requires_grad=True)
optimizer = torch.optim.Adam([image], lr=lr)
# Hook to capture activation
activation = None
def hook(module, input, output):
nonlocal activation
activation = output
handle = layer.register_forward_hook(hook)
for i in range(iterations):
optimizer.zero_grad()
# Forward pass
_ = self.model(image)
# Loss: negative activation (we want to maximize)
loss = -activation[0, filter_idx].mean()
# Add regularization
loss += 0.01 * torch.norm(image) # L2 regularization
loss += 0.001 * self._total_variation(image) # Smoothness
loss.backward()
optimizer.step()
# Clip to valid range
with torch.no_grad():
image.clamp_(-2, 2)
handle.remove()
# Normalize for visualization
image = image.detach()
image = (image - image.min()) / (image.max() - image.min())
return image
def _total_variation(self, x: torch.Tensor) -> torch.Tensor:
"""Total variation loss for smoothness."""
diff_h = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
diff_w = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
return diff_h.mean() + diff_w.mean()
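A usage sketch; the target layer choice is model-specific (here a hypothetical ResNet conv layer):
# am = ActivationMaximization(model)
# img = am.visualize_filter(model.layer3[0].conv1, filter_idx=7, iterations=200)
# plt.imshow(img[0].permute(1, 2, 0).numpy()); plt.axis('off'); plt.show()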
class DeepDream:
"""
DeepDream: Maximize activations to create psychedelic images.
"""
    def __init__(self, model: nn.Module, target_layers: List[nn.Module]):
        self.model = model
        self.model.eval()
        self.target_layers = target_layers
        self.activations = {}
        # Capture each target layer's output during the forward pass;
        # without these hooks self.activations would stay empty
        for layer in target_layers:
            layer.register_forward_hook(self._make_hook(layer))
    def _make_hook(self, layer: nn.Module):
        def hook(module, input, output):
            self.activations[layer] = output
        return hook
def dream(
self,
image: torch.Tensor,
iterations: int = 100,
lr: float = 0.01,
octave_scale: float = 1.4,
n_octaves: int = 4
) -> torch.Tensor:
"""
Apply DeepDream to an image.
Args:
image: [1, 3, H, W] input image
iterations: Steps per octave
lr: Learning rate
octave_scale: Scale factor between octaves
n_octaves: Number of octaves (scales)
"""
original_size = image.shape[2:]
# Process at multiple scales (octaves)
for octave in range(n_octaves):
if octave > 0:
# Upscale
new_size = [int(s * octave_scale) for s in image.shape[2:]]
image = F.interpolate(image, size=new_size, mode='bilinear')
image = image.detach().requires_grad_(True)
for _ in range(iterations):
# Forward pass
_ = self.model(image)
# Maximize activations
loss = sum(
self.activations[layer].norm()
for layer in self.target_layers
)
loss.backward()
# Gradient ascent
with torch.no_grad():
image += lr * image.grad / (image.grad.norm() + 1e-8)
image.grad.zero_()
# Resize back
image = F.interpolate(image, size=original_size, mode='bilinear')
return image.detach()
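A usage sketch with hypothetical target layers; mid-level layers tend to give the most interesting textures:
# dreamer = DeepDream(model, target_layers=[model.layer3, model.layer4])
# dreamed = dreamer.dream(input_image.clone(), iterations=20, lr=0.01, n_octaves=3)
# plt.imshow(dreamed[0].permute(1, 2, 0).clamp(0, 1).numpy()); plt.show()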
Concept-Based Explanations
TCAV (Testing with Concept Activation Vectors)
class TCAV:
"""
Testing with Concept Activation Vectors.
Explain models using human-interpretable concepts.
"""
def __init__(self, model: nn.Module, bottleneck_layer: nn.Module):
self.model = model
self.bottleneck = bottleneck_layer
self.activations = None
# Hook to capture activations
self.bottleneck.register_forward_hook(self._save_activation)
    def _save_activation(self, module, input, output):
        # Keep the graph-attached output and retain its gradient so that
        # compute_tcav_score can read self.activations.grad after backward
        # (a detached copy would always have .grad == None)
        self.activations = output
        if output.requires_grad:
            output.retain_grad()
    def get_activations(self, images: torch.Tensor) -> torch.Tensor:
        """Get bottleneck activations for images."""
        with torch.no_grad():
            _ = self.model(images)
        return self.activations.detach()
def train_cav(
self,
concept_images: torch.Tensor,
random_images: torch.Tensor
) -> torch.Tensor:
"""
Train Concept Activation Vector.
Args:
concept_images: Images with the concept
random_images: Random images without the concept
Returns:
cav: Concept activation vector
"""
# Get activations
concept_acts = self.get_activations(concept_images)
random_acts = self.get_activations(random_images)
# Flatten activations
concept_acts = concept_acts.flatten(start_dim=1)
random_acts = random_acts.flatten(start_dim=1)
# Create labels
X = torch.cat([concept_acts, random_acts], dim=0)
y = torch.cat([
torch.ones(len(concept_acts)),
torch.zeros(len(random_acts))
])
# Train linear classifier
from sklearn.linear_model import LogisticRegression
clf = LogisticRegression(max_iter=1000)
clf.fit(X.numpy(), y.numpy())
# CAV is the weight vector
cav = torch.tensor(clf.coef_[0], dtype=torch.float32)
cav = cav / cav.norm()
return cav
def compute_tcav_score(
self,
test_images: torch.Tensor,
target_class: int,
cav: torch.Tensor
) -> float:
"""
Compute TCAV score: fraction of test images where
the concept positively influences the prediction.
"""
positive_count = 0
for img in test_images:
img = img.unsqueeze(0).requires_grad_(True)
# Forward pass
output = self.model(img)
# Backward pass
self.model.zero_grad()
output[0, target_class].backward()
# Get gradient of activations
act_grad = self.activations.grad
if act_grad is not None:
# Directional derivative along CAV
act_grad_flat = act_grad.flatten()
directional_deriv = torch.dot(act_grad_flat, cav)
if directional_deriv > 0:
positive_count += 1
tcav_score = positive_count / len(test_images)
return tcav_score
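A single CAV is noisy; the original paper trains CAVs against several different random image sets and keeps a concept only if the TCAV scores differ significantly from 0.5 (e.g. via a t-test). A usage sketch with hypothetical image batches:
# tcav = TCAV(model, bottleneck_layer=model.layer4)
# scores = []
# for random_batch in random_image_batches:  # several independent random sets
#     cav = tcav.train_cav(striped_images, random_batch)
#     scores.append(tcav.compute_tcav_score(zebra_images, target_class=340, cav=cav))
# print(np.mean(scores))  # 340 = zebra; compare against 0.5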
Practical Considerations
def interpretability_best_practices():
"""Best practices for model interpretability."""
tips = """
╔════════════════════════════════════════════════════════════════╗
║ INTERPRETABILITY BEST PRACTICES ║
╠════════════════════════════════════════════════════════════════╣
║ ║
║ 1. CHOOSE THE RIGHT METHOD ║
║ • Grad-CAM: Quick visual localization ║
║ • Integrated Gradients: Theoretically grounded ║
║ • SHAP: Feature importance with consistency ║
║ • TCAV: Concept-level understanding ║
║ ║
║ 2. VALIDATE EXPLANATIONS ║
║ • Sanity checks: Random model should give random explain ║
║ • Faithfulness: Removing important features hurts perf ║
║ • Human evaluation: Do explanations make sense? ║
║ ║
║ 3. BE AWARE OF LIMITATIONS ║
║ • Gradient-based methods can be noisy ║
║ • Explanations may not reflect actual reasoning ║
║ • Different methods can give different explanations ║
║ ║
║ 4. COMBINE MULTIPLE METHODS ║
║ • Cross-validate with different techniques ║
║ • Use local (per-sample) and global explanations ║
║ • Consider both input and concept-level analysis ║
║ ║
║ 5. COMMUNICATE CAREFULLY ║
║ • Explanations ≠ justifications ║
║ • Don't over-interpret visualizations ║
║ • Consider your audience's expertise ║
║ ║
╚════════════════════════════════════════════════════════════════╝
"""
print(tips)
interpretability_best_practices()
Exercises
Exercise 1: Compare Attribution Methods
Compare Grad-CAM, Integrated Gradients, and SHAP on the same images:
def compare_attributions(model, image, target_class, background_data):
    gradcam = GradCAM(model, model.layer4[-1])
    ig = IntegratedGradients(model)
    shap = GradientSHAP(model, background_data)
    # TODO: compute all three attribution maps for `image`, visualize them
    # side by side, and measure rank correlation between the flattened maps
Exercise 2: Sanity Checks
Implement sanity checks for explanation methods:
def sanity_check(model, explainer, image):
    # 1. Randomize the top layer's weights
    # 2. Recompute explanations and compare with the originals
    # 3. If they barely change, the method may not be faithful
    #    (cf. Adebayo et al., "Sanity Checks for Saliency Maps")
    pass
Exercise 3: Build TCAV for Custom Concepts
Create CAVs for your own concepts:
# Collect images with/without your concept
# Train CAV
# Test which classes are sensitive to your concept