Gradient Flow: The Lifeblood of Deep Learning
Understanding Gradient Dynamics
Gradients are how neural networks learn. They flow backward through the network, telling each parameter how to change. When this flow is disrupted, learning stops.

$$\frac{\partial L}{\partial W^{(1)}} = \frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial a^{(L)}} \cdot \frac{\partial a^{(L)}}{\partial a^{(L-1)}} \cdots \frac{\partial a^{(2)}}{\partial W^{(1)}}$$

This chain of multiplications is the crux of all gradient problems.
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from collections import defaultdict
torch.manual_seed(42)
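Before diving into the failure modes, here is a quick back-of-the-envelope illustration of why the chain of multiplications above matters. This is a minimal sketch with purely illustrative per-layer factors: if each factor in the product sits consistently below 1 the gradient shrinks exponentially with depth, and if it sits consistently above 1 it blows up.

# Illustrative only: treat each layer as contributing a single scalar factor
# to the chain of multiplications and see how the product scales with depth.
for factor in [0.25, 0.9, 1.0, 1.1, 1.5]:
    product_50_layers = factor ** 50
    print(f"per-layer factor {factor:>4}: product over 50 layers ≈ {product_50_layers:.2e}")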
The Vanishing Gradient Problem
Mathematical Analysis
For a sigmoid activation $\sigma(x) = \frac{1}{1 + e^{-x}}$:

$$\sigma'(x) = \sigma(x)\,(1 - \sigma(x)) \le 0.25$$

Through $L$ layers:

$$\left\|\frac{\partial L}{\partial W^{(1)}}\right\| \le 0.25^{\,L} \cdot \left\|\frac{\partial L}{\partial W^{(L)}}\right\|$$
def analyze_vanishing_gradients():
    """Demonstrate vanishing gradients mathematically."""
    # Sigmoid gradient analysis
    x = np.linspace(-10, 10, 1000)
    sigmoid = 1 / (1 + np.exp(-x))
    sigmoid_grad = sigmoid * (1 - sigmoid)
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Sigmoid function
    axes[0].plot(x, sigmoid, 'b-', linewidth=2)
    axes[0].set_title('Sigmoid Activation')
    axes[0].set_xlabel('x')
    axes[0].set_ylabel('σ(x)')
    axes[0].grid(True, alpha=0.3)
    axes[0].axhline(y=0.5, color='r', linestyle='--', alpha=0.5)

    # Sigmoid gradient
    axes[1].plot(x, sigmoid_grad, 'g-', linewidth=2)
    axes[1].set_title("Sigmoid Gradient (max=0.25)")
    axes[1].set_xlabel('x')
    axes[1].set_ylabel("σ'(x)")
    axes[1].grid(True, alpha=0.3)
    axes[1].axhline(y=0.25, color='r', linestyle='--', alpha=0.5, label='max=0.25')
    axes[1].legend()

    # Gradient decay through layers
    layers = np.arange(1, 51)
    max_grad_factor = 0.25 ** layers
    axes[2].semilogy(layers, max_grad_factor, 'r-', linewidth=2)
    axes[2].set_title('Gradient Decay Through Layers')
    axes[2].set_xlabel('Number of Layers')
    axes[2].set_ylabel('Max Gradient Factor')
    axes[2].grid(True, alpha=0.3)

    # Annotate specific points
    for l in [10, 20, 30, 40, 50]:
        axes[2].annotate(f'{0.25**l:.2e}', (l, 0.25**l),
                         textcoords="offset points", xytext=(0, 10), ha='center')

    plt.tight_layout()
    plt.show()

    print("Gradient Decay Analysis")
    print("=" * 50)
    for depth in [10, 20, 50, 100]:
        factor = 0.25 ** depth
        print(f"After {depth} layers: gradient ≤ {factor:.2e}")

analyze_vanishing_gradients()
Live Demonstration
def vanishing_gradient_demo():
    """Watch gradients vanish in a deep sigmoid network."""
    class DeepSigmoid(nn.Module):
        def __init__(self, depth, width=100):
            super().__init__()
            layers = []
            for _ in range(depth):
                layers.append(nn.Linear(width, width))
                layers.append(nn.Sigmoid())
            layers.append(nn.Linear(width, 1))
            self.net = nn.Sequential(*layers)

        def forward(self, x):
            return self.net(x)

    depths = [5, 10, 20, 50]
    print("Vanishing Gradient Demonstration")
    print("=" * 60)
    results = {}
    for depth in depths:
        model = DeepSigmoid(depth)
        x = torch.randn(32, 100)
        y = torch.randn(32, 1)

        # Forward and backward
        output = model(x)
        loss = nn.MSELoss()(output, y)
        loss.backward()

        # Collect gradient norms per layer
        grad_norms = []
        for name, param in model.named_parameters():
            if 'weight' in name and param.grad is not None:
                grad_norms.append(param.grad.norm().item())
        results[depth] = grad_norms

        print(f"\nDepth {depth}:")
        print(f"  First layer grad norm: {grad_norms[0]:.2e}")
        print(f"  Last layer grad norm:  {grad_norms[-1]:.2e}")
        print(f"  Ratio (first/last):    {grad_norms[0]/grad_norms[-1]:.2e}")

    # Plot
    plt.figure(figsize=(12, 5))
    for depth, grads in results.items():
        plt.semilogy(range(len(grads)), grads, 'o-', label=f'Depth={depth}')
    plt.xlabel('Layer Index')
    plt.ylabel('Gradient Norm (log scale)')
    plt.title('Gradient Norms Through Network Depth')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

vanishing_gradient_demo()
The Exploding Gradient Problem
When Gradients Explode
If the weight matrices have eigenvalues $|\lambda| > 1$:

$$\left(\prod_{i=1}^{L} W^{(i)}\right) v \approx \lambda^{L}\, v$$

where $v$ is an eigenvector, and gradients grow exponentially with depth.
def exploding_gradient_demo():
    """Demonstrate exploding gradients."""
    class DeepLinear(nn.Module):
        def __init__(self, depth, width=100):
            super().__init__()
            layers = []
            for _ in range(depth):
                layer = nn.Linear(width, width, bias=False)
                # Initialize with slightly too large weights
                nn.init.normal_(layer.weight, std=1.5 / np.sqrt(width))
                layers.append(layer)
            self.net = nn.Sequential(*layers)

        def forward(self, x):
            return self.net(x)

    print("Exploding Gradient Demonstration")
    print("=" * 60)
    for depth in [10, 20, 30]:
        model = DeepLinear(depth)
        x = torch.randn(32, 100)
        try:
            # Forward pass
            with torch.no_grad():
                activations = [x]
                current = x
                for layer in model.net:
                    current = layer(current)
                    activations.append(current)

            # Check for explosion
            output = model(x)
            print(f"\nDepth {depth}:")
            print(f"  Input norm:  {x.norm().item():.2e}")
            print(f"  Output norm: {output.norm().item():.2e}")

            # Track activation growth
            norms = [a.norm().item() for a in activations]
            growth_rate = norms[-1] / norms[0]
            print(f"  Growth rate: {growth_rate:.2e}")

            if np.isnan(output.norm().item()) or np.isinf(output.norm().item()):
                print("  ⚠ EXPLODED to NaN/Inf!")
            elif growth_rate > 1e6:
                print("  ⚠ Severe explosion detected!")
        except Exception as e:
            print(f"\nDepth {depth}: Failed - {str(e)[:50]}")

exploding_gradient_demo()
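To connect this back to the $|\lambda| > 1$ condition above, one quick check is the largest singular value of each weight matrix: the product of those values upper-bounds how much a signal can be amplified across the stack. A minimal sketch follows; the depth, width, and std_scale values are illustrative and simply mirror the initialization used in DeepLinear above.

def spectral_norm_check(depth=20, width=100, std_scale=1.5):
    """Largest singular value per randomly initialized layer, plus the
    implied worst-case amplification over the whole stack."""
    top_singular_values = []
    for _ in range(depth):
        W = torch.randn(width, width) * (std_scale / np.sqrt(width))
        top_singular_values.append(torch.linalg.svdvals(W)[0].item())
    print(f"Mean top singular value per layer: {np.mean(top_singular_values):.2f}")
    print(f"Upper bound on amplification over {depth} layers: {np.prod(top_singular_values):.2e}")

spectral_norm_check()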
Gradient Clipping Solutions
class GradientClipper:
    """Various gradient clipping strategies."""

    @staticmethod
    def clip_by_norm(parameters, max_norm):
        """Clip gradients by global norm (most common)."""
        total_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm)
        return total_norm

    @staticmethod
    def clip_by_value(parameters, clip_value):
        """Clip each gradient element to [-clip_value, clip_value]."""
        torch.nn.utils.clip_grad_value_(parameters, clip_value)

    @staticmethod
    def clip_by_global_norm_manual(parameters, max_norm):
        """Manual implementation of global norm clipping."""
        parameters = list(filter(lambda p: p.grad is not None, parameters))
        # Compute global norm
        total_norm = 0.0
        for p in parameters:
            total_norm += p.grad.data.norm(2).item() ** 2
        total_norm = np.sqrt(total_norm)
        # Compute clipping coefficient
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1:
            for p in parameters:
                p.grad.data.mul_(clip_coef)
        return total_norm, clip_coef

    @staticmethod
    def adaptive_clipping(parameters, percentile=10):
        """Clip based on gradient distribution (AdaClip style)."""
        all_grads = []
        for p in parameters:
            if p.grad is not None:
                all_grads.append(p.grad.data.abs().flatten())
        all_grads = torch.cat(all_grads)
        threshold = torch.quantile(all_grads, 1 - percentile / 100)
        for p in parameters:
            if p.grad is not None:
                p.grad.data.clamp_(-threshold, threshold)
        return threshold.item()

# Example usage
def gradient_clipping_example():
    """Demonstrate gradient clipping strategies."""
    model = nn.Linear(100, 100)
    x = torch.randn(32, 100)
    y = torch.randn(32, 100)

    def compute_large_gradients():
        # Recompute the loss and backward pass for each strategy: a graph can
        # only be backpropagated once, and gradients would otherwise
        # accumulate across calls.
        model.zero_grad()
        loss = nn.MSELoss()(model(x), y) * 1000  # simulate large gradients
        loss.backward()

    compute_large_gradients()
    original_norm = model.weight.grad.norm().item()
    print(f"Original gradient norm: {original_norm:.2f}")

    # Clip by norm
    compute_large_gradients()
    clipped_norm = GradientClipper.clip_by_norm(model.parameters(), max_norm=1.0)
    print(f"After clip_by_norm(1.0): {model.weight.grad.norm().item():.2f}")

    # Clip by value
    compute_large_gradients()
    GradientClipper.clip_by_value(model.parameters(), clip_value=0.1)
    print(f"After clip_by_value(0.1): {model.weight.grad.norm().item():.2f}")

gradient_clipping_example()
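In practice, the most common pattern is simply torch.nn.utils.clip_grad_norm_ placed between backward() and the optimizer step. Here is a minimal runnable sketch with synthetic data; the max_norm of 1.0 and the toy model are illustrative choices, not recommendations.

model = nn.Sequential(nn.Linear(20, 20), nn.ReLU(), nn.Linear(20, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

for step in range(10):
    x, y = torch.randn(32, 20), torch.randn(32, 1)
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    # Clip after backward(), before step(); the returned value is the
    # pre-clipping norm, which is worth logging to see how often clipping fires.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()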
Gradient Flow Visualization
Building a Gradient Monitor
class GradientMonitor:
    """Comprehensive gradient monitoring toolkit."""

    def __init__(self, model):
        self.model = model
        self.gradient_history = defaultdict(list)
        self.activation_history = defaultdict(list)
        self.hooks = []

    def register_hooks(self):
        """Register forward and backward hooks."""
        def forward_hook(name):
            def hook(module, input, output):
                if isinstance(output, torch.Tensor):
                    self.activation_history[name].append({
                        'mean': output.mean().item(),
                        'std': output.std().item(),
                        'min': output.min().item(),
                        'max': output.max().item(),
                        'dead_fraction': (output == 0).float().mean().item()
                    })
            return hook

        def backward_hook(name):
            def hook(module, grad_input, grad_output):
                if grad_output[0] is not None:
                    grad = grad_output[0]
                    self.gradient_history[name].append({
                        'mean': grad.mean().item(),
                        'std': grad.std().item(),
                        'norm': grad.norm().item(),
                        'max_abs': grad.abs().max().item()
                    })
            return hook

        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                self.hooks.append(
                    module.register_forward_hook(forward_hook(name))
                )
                self.hooks.append(
                    module.register_full_backward_hook(backward_hook(name))
                )

    def remove_hooks(self):
        """Clean up hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def plot_gradients(self, title="Gradient Flow"):
        """Visualize gradient statistics."""
        if not self.gradient_history:
            print("No gradients recorded. Did you run a backward pass?")
            return

        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        layer_names = list(self.gradient_history.keys())

        # Get latest statistics
        means = [self.gradient_history[n][-1]['mean'] for n in layer_names]
        stds = [self.gradient_history[n][-1]['std'] for n in layer_names]
        norms = [self.gradient_history[n][-1]['norm'] for n in layer_names]
        max_abs = [self.gradient_history[n][-1]['max_abs'] for n in layer_names]
        x = range(len(layer_names))

        # Gradient means
        axes[0, 0].bar(x, means, color='blue', alpha=0.7)
        axes[0, 0].set_xlabel('Layer')
        axes[0, 0].set_ylabel('Gradient Mean')
        axes[0, 0].set_title('Gradient Means')
        axes[0, 0].axhline(y=0, color='r', linestyle='--', alpha=0.5)

        # Gradient stds
        axes[0, 1].bar(x, stds, color='green', alpha=0.7)
        axes[0, 1].set_xlabel('Layer')
        axes[0, 1].set_ylabel('Gradient Std')
        axes[0, 1].set_title('Gradient Standard Deviations')
        axes[0, 1].set_yscale('log')

        # Gradient norms
        axes[1, 0].bar(x, norms, color='orange', alpha=0.7)
        axes[1, 0].set_xlabel('Layer')
        axes[1, 0].set_ylabel('Gradient Norm')
        axes[1, 0].set_title('Gradient Norms per Layer')
        axes[1, 0].set_yscale('log')

        # Max absolute gradient
        axes[1, 1].bar(x, max_abs, color='red', alpha=0.7)
        axes[1, 1].set_xlabel('Layer')
        axes[1, 1].set_ylabel('Max |Gradient|')
        axes[1, 1].set_title('Maximum Absolute Gradient')
        axes[1, 1].set_yscale('log')

        plt.suptitle(title, fontsize=14)
        plt.tight_layout()
        plt.show()

    def plot_gradient_evolution(self, layer_name=None):
        """Plot how gradients evolve over training."""
        if layer_name is None:
            layer_name = list(self.gradient_history.keys())[0]

        history = self.gradient_history[layer_name]
        steps = range(len(history))
        norms = [h['norm'] for h in history]
        stds = [h['std'] for h in history]

        fig, axes = plt.subplots(1, 2, figsize=(12, 4))
        axes[0].plot(steps, norms, 'b-', linewidth=2)
        axes[0].set_xlabel('Training Step')
        axes[0].set_ylabel('Gradient Norm')
        axes[0].set_title(f'Gradient Norm Evolution - {layer_name}')
        axes[0].grid(True, alpha=0.3)

        axes[1].plot(steps, stds, 'g-', linewidth=2)
        axes[1].set_xlabel('Training Step')
        axes[1].set_ylabel('Gradient Std')
        axes[1].set_title(f'Gradient Std Evolution - {layer_name}')
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

# Example usage
def gradient_monitoring_example():
    """Demonstrate gradient monitoring."""
    # Create a model with potential gradient issues
    model = nn.Sequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, 10)
    )

    monitor = GradientMonitor(model)
    monitor.register_hooks()

    # Simulate training
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    for step in range(50):
        x = torch.randn(32, 784)
        y = torch.randint(0, 10, (32,))
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

    # Visualize
    monitor.plot_gradients("Gradient Flow After 50 Steps")
    monitor.remove_hooks()

gradient_monitoring_example()
Gradient Flow in Different Architectures
Residual Connections
def residual_gradient_flow():
    """Compare gradient flow with and without residual connections."""
    class PlainBlock(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.fc1 = nn.Linear(dim, dim)
            self.fc2 = nn.Linear(dim, dim)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.fc2(self.relu(self.fc1(x))))

    class ResidualBlock(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.fc1 = nn.Linear(dim, dim)
            self.fc2 = nn.Linear(dim, dim)
            self.relu = nn.ReLU()

        def forward(self, x):
            residual = x
            out = self.relu(self.fc1(x))
            out = self.fc2(out)
            return self.relu(out + residual)  # Skip connection!

    def build_network(block_class, num_blocks, dim):
        layers = [block_class(dim) for _ in range(num_blocks)]
        return nn.Sequential(*layers)

    depth = 50
    dim = 128
    plain_net = build_network(PlainBlock, depth, dim)
    res_net = build_network(ResidualBlock, depth, dim)

    print("Gradient Flow: Plain vs Residual Networks")
    print("=" * 60)
    for name, net in [("Plain", plain_net), ("Residual", res_net)]:
        x = torch.randn(32, dim)
        y = torch.randn(32, dim)
        output = net(x)
        loss = nn.MSELoss()(output, y)
        loss.backward()

        # Collect gradients from each block
        grad_norms = []
        for module in net:
            if hasattr(module, 'fc1'):
                grad_norms.append(module.fc1.weight.grad.norm().item())

        print(f"\n{name} Network ({depth} blocks):")
        print(f"  First block gradient: {grad_norms[0]:.6f}")
        print(f"  Last block gradient:  {grad_norms[-1]:.6f}")
        print(f"  Ratio (first/last):   {grad_norms[0]/grad_norms[-1]:.4f}")

        # Plot
        plt.semilogy(grad_norms, label=name, marker='o')

    plt.xlabel('Block Index')
    plt.ylabel('Gradient Norm (log scale)')
    plt.title('Gradient Flow: Plain vs Residual')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

residual_gradient_flow()
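Why the residual curve stays nearly flat follows directly from the skip connection. Ignoring the final ReLU for simplicity, a block computes $y = x + F(x)$, so

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y}\left(I + \frac{\partial F}{\partial x}\right)$$

The identity term gives the gradient a direct path back to earlier blocks, regardless of how small $\partial F / \partial x$ becomes.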
Dense Connections (DenseNet style)
class DenseBlock(nn.Module):
    """DenseNet-style block with dense connections."""
    def __init__(self, dim, growth_rate=32):
        super().__init__()
        self.fc = nn.Linear(dim, growth_rate)
        self.relu = nn.ReLU()

    def forward(self, features):
        # features is a list of all previous feature maps
        concat = torch.cat(features, dim=1)
        out = self.relu(self.fc(concat))
        return out

def dense_gradient_flow():
    """Demonstrate gradient flow in DenseNet architecture."""
    class DenseNetwork(nn.Module):
        def __init__(self, input_dim, num_blocks, growth_rate=32):
            super().__init__()
            self.initial = nn.Linear(input_dim, growth_rate)
            # Dense blocks
            self.blocks = nn.ModuleList()
            in_features = growth_rate
            for _ in range(num_blocks):
                self.blocks.append(nn.Linear(in_features, growth_rate))
                in_features += growth_rate
            self.final = nn.Linear(in_features, 10)

        def forward(self, x):
            features = [self.initial(x)]
            for block in self.blocks:
                concat = torch.cat(features, dim=1)
                new_features = torch.relu(block(concat))
                features.append(new_features)
            concat = torch.cat(features, dim=1)
            return self.final(concat)

    model = DenseNetwork(input_dim=784, num_blocks=20, growth_rate=32)
    x = torch.randn(32, 784)
    y = torch.randint(0, 10, (32,))

    output = model(x)
    loss = nn.CrossEntropyLoss()(output, y)
    loss.backward()

    # Analyze gradients
    print("Dense Network Gradient Analysis")
    print("=" * 50)
    grad_norms = []
    for i, block in enumerate(model.blocks):
        grad_norms.append(block.weight.grad.norm().item())
        print(f"Block {i}: gradient norm = {grad_norms[-1]:.6f}")

    print(f"\nGradient variation: std = {np.std(grad_norms):.6f}")
    print(f"Ratio (first/last): {grad_norms[0]/grad_norms[-1]:.4f}")

dense_gradient_flow()
Advanced Analysis Techniques
Gradient Covariance Analysis
def gradient_covariance_analysis(model, data_loader, num_batches=10):
    """Analyze gradient covariance structure."""
    print("Gradient Covariance Analysis")
    print("=" * 50)

    # Collect gradients over multiple batches
    all_gradients = defaultdict(list)
    for batch_idx, (x, y) in enumerate(data_loader):
        if batch_idx >= num_batches:
            break
        model.zero_grad()
        output = model(x)
        loss = nn.CrossEntropyLoss()(output, y)
        loss.backward()
        for name, param in model.named_parameters():
            if param.grad is not None:
                all_gradients[name].append(param.grad.flatten().detach().cpu().numpy())

    # Analyze each layer
    for name, grads in all_gradients.items():
        grads = np.array(grads)  # [num_batches, num_params]

        # Compute covariance
        mean_grad = grads.mean(axis=0)
        centered = grads - mean_grad

        # Gradient variance
        variance = np.var(grads, axis=0).mean()

        # Gradient correlation (sample a subset for large layers)
        if grads.shape[1] > 1000:
            idx = np.random.choice(grads.shape[1], 1000, replace=False)
            grads_sample = grads[:, idx]
        else:
            grads_sample = grads
        corr = np.corrcoef(grads_sample.T)
        avg_correlation = (corr.sum() - np.trace(corr)) / (corr.size - corr.shape[0])

        print(f"\n{name}:")
        print(f"  Mean gradient magnitude: {np.abs(mean_grad).mean():.6f}")
        print(f"  Gradient variance: {variance:.6f}")
        print(f"  Avg correlation between params: {avg_correlation:.4f}")

# Create a simple example
def run_covariance_analysis():
    model = nn.Sequential(
        nn.Linear(100, 50),
        nn.ReLU(),
        nn.Linear(50, 10)
    )
    # Simple dataset
    dataset = torch.utils.data.TensorDataset(
        torch.randn(1000, 100),
        torch.randint(0, 10, (1000,))
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    gradient_covariance_analysis(model, loader)

run_covariance_analysis()
Hessian Analysis
def hessian_analysis():
    """Analyze Hessian eigenspectrum for understanding loss landscape."""
    # Simple model for tractable Hessian computation
    model = nn.Linear(10, 5)

    # Sample data
    x = torch.randn(50, 10)
    y = torch.randint(0, 5, (50,))

    # Compute Hessian using autograd
    def compute_hessian_eigenvalues(model, x, y):
        """Compute the eigenvalues of the Hessian of the loss."""
        from torch.autograd.functional import hessian

        # Flatten parameters into a single vector of differentiation variables
        names = [n for n, _ in model.named_parameters()]
        shapes = [p.shape for p in model.parameters()]
        sizes = [p.numel() for p in model.parameters()]
        params = torch.cat([p.detach().flatten() for p in model.parameters()])
        print(f"Computing Hessian for {len(params)} parameters...")

        def loss_fn(flat_params):
            # Rebuild a parameter dict from the flat vector; functional_call
            # keeps the computation differentiable w.r.t. flat_params
            # (assigning to p.data would break the autograd graph).
            chunks = flat_params.split(sizes)
            param_dict = {n: c.view(s) for n, c, s in zip(names, chunks, shapes)}
            output = torch.func.functional_call(model, param_dict, (x,))
            return nn.CrossEntropyLoss()(output, y)

        # Compute Hessian
        H = hessian(loss_fn, params)

        # Get eigenvalues
        eigenvalues = torch.linalg.eigvalsh(H)
        return eigenvalues

    eigenvalues = compute_hessian_eigenvalues(model, x, y)

    print("\nHessian Eigenvalue Analysis")
    print("=" * 50)
    print(f"Max eigenvalue: {eigenvalues.max().item():.4f}")
    print(f"Min eigenvalue: {eigenvalues.min().item():.4f}")
    print(f"Condition number: {eigenvalues.max().item() / (eigenvalues.min().item() + 1e-8):.2f}")

    # Plot eigenvalue distribution
    plt.figure(figsize=(10, 4))
    plt.hist(eigenvalues.numpy(), bins=50, edgecolor='black', alpha=0.7)
    plt.xlabel('Eigenvalue')
    plt.ylabel('Count')
    plt.title('Hessian Eigenvalue Distribution')
    plt.axvline(x=0, color='r', linestyle='--', alpha=0.5)
    plt.show()

    # Negative eigenvalues indicate saddle points
    n_negative = (eigenvalues < 0).sum().item()
    print(f"\nNegative eigenvalues: {n_negative} ({100*n_negative/len(eigenvalues):.1f}%)")
    if n_negative > 0:
        print("→ You might be at a saddle point!")

hessian_analysis()
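Materializing the full Hessian scales quadratically with the parameter count, so the approach above only works for toy models. For anything larger, the usual trick is a Hessian-vector product via double backprop, which never forms the matrix. The following is a sketch of power iteration for the largest eigenvalue magnitude, assuming a loss tensor whose graph is still alive so gradients can be taken with create_graph=True.

def top_hessian_eigenvalue(model, loss, num_iters=20):
    """Estimate the largest Hessian eigenvalue magnitude by power iteration
    on Hessian-vector products (the Hessian itself is never materialized)."""
    params = [p for p in model.parameters() if p.requires_grad]
    grads = torch.autograd.grad(loss, params, create_graph=True)

    # Random unit starting vector, shaped like the parameters
    v = [torch.randn_like(p) for p in params]
    v_norm = torch.sqrt(sum((vi ** 2).sum() for vi in v))
    v = [vi / v_norm for vi in v]

    estimate = 0.0
    for _ in range(num_iters):
        # Hessian-vector product: differentiate (grad · v) w.r.t. the parameters
        grad_dot_v = sum((g * vi).sum() for g, vi in zip(grads, v))
        hv = torch.autograd.grad(grad_dot_v, params, retain_graph=True)
        hv_norm = torch.sqrt(sum((h ** 2).sum() for h in hv))
        estimate = hv_norm.item()  # ||H v|| approaches |λ_max| as v converges
        v = [h / (hv_norm + 1e-12) for h in hv]
    return estimate

model = nn.Linear(10, 5)
x, y = torch.randn(50, 10), torch.randint(0, 5, (50,))
loss = nn.CrossEntropyLoss()(model(x), y)
print(f"Estimated top |eigenvalue|: {top_hessian_eigenvalue(model, loss):.4f}")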
Fixing Gradient Flow Issues
Comprehensive Diagnostic and Fix Toolkit
class GradientDoctor:
    """Diagnose and fix gradient flow issues."""

    def __init__(self, model):
        self.model = model
        self.issues = []

    def diagnose(self, sample_input, sample_target):
        """Run comprehensive gradient diagnostics."""
        print("╔══════════════════════════════════════════════════════════╗")
        print("║                 GRADIENT FLOW DIAGNOSIS                   ║")
        print("╚══════════════════════════════════════════════════════════╝")
        self.issues = []

        # Forward pass
        output = self.model(sample_input)

        # Check for NaN in output
        if torch.isnan(output).any():
            self.issues.append(("CRITICAL", "NaN in forward pass output"))
            print("⚠ CRITICAL: NaN detected in output!")
            return self.issues

        # Backward pass
        if output.dim() == 2 and output.size(1) > 1:
            loss = nn.CrossEntropyLoss()(output, sample_target)
        else:
            loss = nn.MSELoss()(output.flatten(), sample_target.float().flatten())
        loss.backward()

        # Check each layer
        print("\nLayer-by-layer analysis:")
        print("-" * 60)
        layer_idx = 0
        for name, param in self.model.named_parameters():
            if param.grad is None:
                self.issues.append(("WARNING", f"{name}: No gradient computed"))
                print(f"⚠ {name}: No gradient")
                continue

            grad = param.grad
            grad_norm = grad.norm().item()
            grad_mean = grad.mean().item()
            grad_std = grad.std().item()

            # Check for issues
            status = "✓"
            if torch.isnan(grad).any():
                self.issues.append(("CRITICAL", f"{name}: NaN gradient"))
                status = "⚠ NaN"
            elif grad_norm < 1e-7:
                self.issues.append(("WARNING", f"{name}: Vanishing gradient (norm={grad_norm:.2e})"))
                status = "⚠ Vanishing"
            elif grad_norm > 1e4:
                self.issues.append(("WARNING", f"{name}: Exploding gradient (norm={grad_norm:.2e})"))
                status = "⚠ Exploding"

            print(f"{name:<40} norm={grad_norm:<10.2e} std={grad_std:<10.2e} {status}")
            layer_idx += 1

        print("\n" + "=" * 60)
        if self.issues:
            print(f"Found {len(self.issues)} issues:")
            for severity, msg in self.issues:
                print(f"  [{severity}] {msg}")
        else:
            print("✓ Gradient flow looks healthy!")
        return self.issues

    def suggest_fixes(self):
        """Suggest fixes based on diagnosed issues."""
        print("\n╔══════════════════════════════════════════════════════════╗")
        print("║                     SUGGESTED FIXES                       ║")
        print("╚══════════════════════════════════════════════════════════╝")
        if not self.issues:
            print("No issues to fix!")
            return

        fixes = []
        for severity, msg in self.issues:
            if "Vanishing" in msg:
                fixes.extend([
                    "• Use He/Kaiming initialization for ReLU layers",
                    "• Add residual connections (skip connections)",
                    "• Replace sigmoid/tanh with ReLU/GELU",
                    "• Add BatchNorm or LayerNorm",
                    "• Use LSTM/GRU instead of vanilla RNN"
                ])
            elif "Exploding" in msg:
                fixes.extend([
                    "• Apply gradient clipping: torch.nn.utils.clip_grad_norm_(..., max_norm=1.0)",
                    "• Reduce learning rate",
                    "• Initialize weights with smaller variance",
                    "• Add weight decay regularization"
                ])
            elif "NaN" in msg:
                fixes.extend([
                    "• Check for numerical instability (log of 0, division by 0)",
                    "• Reduce learning rate significantly",
                    "• Use gradient clipping",
                    "• Check for correct loss function usage",
                    "• Verify data preprocessing (no NaN in inputs)"
                ])
            elif "No gradient" in msg:
                fixes.extend([
                    "• Ensure requires_grad=True for trainable parameters",
                    "• Check if the layer is actually used in forward pass",
                    "• Verify no torch.no_grad() context is active"
                ])

        # Remove duplicates
        fixes = list(dict.fromkeys(fixes))
        for fix in fixes:
            print(fix)

    def apply_quick_fixes(self):
        """Apply automatic quick fixes."""
        print("\nApplying quick fixes...")
        for name, module in self.model.named_modules():
            # Re-initialize layers that might have bad weights
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        print("✓ Re-initialized all Linear and Conv2d layers with He initialization")

# Example usage
def gradient_doctor_demo():
    # Create a problematic model
    model = nn.Sequential(
        nn.Linear(100, 256),
        nn.Sigmoid(),  # Problematic for deep networks
        nn.Linear(256, 256),
        nn.Sigmoid(),
        nn.Linear(256, 256),
        nn.Sigmoid(),
        nn.Linear(256, 10)
    )

    # Small weight initialization (will cause vanishing)
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.01)

    doctor = GradientDoctor(model)
    x = torch.randn(32, 100)
    y = torch.randint(0, 10, (32,))
    doctor.diagnose(x, y)
    doctor.suggest_fixes()

gradient_doctor_demo()
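apply_quick_fixes is not exercised by the demo above. One possible follow-up, sketched here using the classes already defined, is to re-initialize and then re-run the diagnosis to compare against the baseline:

def quick_fix_demo():
    # Same kind of deliberately poor setup as gradient_doctor_demo above
    model = nn.Sequential(nn.Linear(100, 256), nn.Sigmoid(),
                          nn.Linear(256, 256), nn.Sigmoid(),
                          nn.Linear(256, 10))
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.01)

    doctor = GradientDoctor(model)
    x = torch.randn(32, 100)
    y = torch.randint(0, 10, (32,))

    doctor.diagnose(x, y)        # baseline diagnosis
    doctor.apply_quick_fixes()   # He re-initialization of Linear/Conv2d layers
    model.zero_grad()            # clear gradients from the first diagnosis
    doctor.diagnose(x, y)        # compare against the baseline

quick_fix_demo()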
Exercises
Exercise 1: Gradient Flow Experiment
Compare gradient flow through different activation functions:
activations = [nn.Sigmoid(), nn.Tanh(), nn.ReLU(), nn.GELU(), nn.SiLU()]
for act in activations:
    model = build_deep_network(depth=30, activation=act)
    grad_norms = measure_gradient_flow(model)
    plot_gradient_profile(grad_norms, label=act.__class__.__name__)
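The helpers build_deep_network, measure_gradient_flow, and plot_gradient_profile are not defined anywhere in this chapter; they are placeholders for you to write. One possible minimal version to get you started (the width, batch size, and MSE loss are arbitrary choices):

def build_deep_network(depth, activation, width=64):
    """Stack of Linear + activation blocks with a scalar output head."""
    layers = []
    for _ in range(depth):
        layers.extend([nn.Linear(width, width), activation])
    layers.append(nn.Linear(width, 1))
    return nn.Sequential(*layers)

def measure_gradient_flow(model, width=64, batch_size=32):
    """One forward/backward pass; return per-layer weight gradient norms."""
    x = torch.randn(batch_size, width)
    y = torch.randn(batch_size, 1)
    model.zero_grad()
    nn.MSELoss()(model(x), y).backward()
    return [p.grad.norm().item()
            for name, p in model.named_parameters() if 'weight' in name]

def plot_gradient_profile(grad_norms, label):
    plt.semilogy(grad_norms, marker='o', label=label)
    plt.xlabel('Layer Index')
    plt.ylabel('Gradient Norm (log scale)')
    plt.legend()

Reusing the same activation module instance in several positions is fine here because these activations are stateless; call plt.show() after the loop to display the comparison.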
Exercise 2: Implement Gradient Noise Scale
The gradient noise scale (a rough measure of how noisy minibatch gradients are relative to the averaged gradient) indicates which regime you are in:
- Small batch regime: high noise, needs smaller LR
- Large batch regime: low noise, can use larger LR
def compute_gradient_noise_scale(model, data_loader):
    # Compute gradient with full batch
    # Compute gradients with mini-batches
    # Compare noise levels
    pass
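One way to flesh out the stub is to compare per-minibatch gradient variance against the squared norm of the averaged gradient. This is a coarse sketch rather than the exact estimator from the literature, and it assumes the loader yields classification batches:

def compute_gradient_noise_scale(model, data_loader, criterion=None, max_batches=20):
    criterion = criterion or nn.CrossEntropyLoss()

    def flat_grad(x, y):
        model.zero_grad()
        criterion(model(x), y).backward()
        return torch.cat([p.grad.flatten().clone()
                          for p in model.parameters() if p.grad is not None])

    minibatch_grads = []
    for batch_idx, (x, y) in enumerate(data_loader):
        if batch_idx >= max_batches:
            break
        minibatch_grads.append(flat_grad(x, y))

    stacked = torch.stack(minibatch_grads)      # [num_batches, num_params]
    mean_grad = stacked.mean(dim=0)             # proxy for the full-batch gradient
    total_variance = stacked.var(dim=0).sum()   # total per-parameter variance
    noise_scale = (total_variance / (mean_grad.norm() ** 2 + 1e-12)).item()
    print(f"Approximate gradient noise scale: {noise_scale:.2f}")
    return noise_scale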
Exercise 3: Build a Gradient Dashboard
Create a real-time gradient monitoring dashboard using matplotlib:
class GradientDashboard:
    def __init__(self, model):
        self.fig, self.axes = plt.subplots(2, 2)
        plt.ion()  # Interactive mode

    def update(self, step):
        # Update gradient histograms
        # Update gradient norm curves
        # Update layer-wise statistics
        plt.draw()
        plt.pause(0.01)
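A possible starting point for the exercise, kept deliberately small: one panel of current per-layer gradient norms and one panel of their history. Call update() after each loss.backward() in your training loop.

class GradientDashboard:
    def __init__(self, model):
        self.model = model
        self.norm_history = defaultdict(list)
        self.fig, self.axes = plt.subplots(1, 2, figsize=(12, 4))
        plt.ion()  # Interactive mode

    def update(self, step):
        names, norms = [], []
        for name, p in self.model.named_parameters():
            if 'weight' in name and p.grad is not None:
                names.append(name)
                norms.append(p.grad.norm().item())
                self.norm_history[name].append(norms[-1])

        self.axes[0].cla()
        self.axes[0].bar(range(len(norms)), norms)
        self.axes[0].set_title(f'Per-layer gradient norms (step {step})')

        self.axes[1].cla()
        for name, history in self.norm_history.items():
            self.axes[1].plot(history, label=name)
        self.axes[1].set_yscale('log')
        self.axes[1].set_title('Gradient norm history')

        plt.draw()
        plt.pause(0.01)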