Reproducibility
The Reproducibility Crisis
Deep learning experiments are notoriously hard to reproduce:

| Challenge | Impact |
|---|---|
| Random seeds | Different weight initializations and data shuffling order |
| Hardware differences | GPU floating-point variations |
| Software versions | Framework, CUDA, cuDNN changes |
| Data pipelines | Preprocessing, augmentation order |
| Hyperparameters | Undocumented settings |
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import os
import json
import hashlib
from datetime import datetime
from typing import Dict, Any, Optional, Callable
from dataclasses import dataclass, asdict, field
from pathlib import Path
import sys
torch.manual_seed(42)
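To see the problem concretely, here is a minimal, standalone sketch (independent of the utilities below) showing that two unseeded layers start from different weights, while re-seeding before each construction makes them bit-identical:

```python
import torch
import torch.nn as nn

# Without seeding, two "identical" layers start from different weights.
a = nn.Linear(4, 4)
b = nn.Linear(4, 4)
print(torch.allclose(a.weight, b.weight))   # almost certainly False

# With the same seed set before each construction, the weights match exactly.
torch.manual_seed(42)
c = nn.Linear(4, 4)
torch.manual_seed(42)
d = nn.Linear(4, 4)
print(torch.equal(c.weight, d.weight))      # True
```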
Seed Management
Comprehensive Seeding
class SeedManager:
"""
Manage random seeds across all sources of randomness.
"""
def __init__(self, seed: int = 42):
self.seed = seed
self.original_states = {}
def set_all_seeds(self, seed: Optional[int] = None):
"""
Set seeds for all random sources.
Note: Even with all seeds set, perfect reproducibility
requires additional measures (see below).
"""
seed = self.seed if seed is None else seed  # don't treat seed=0 as "unset"
# Python's random
random.seed(seed)
# NumPy
np.random.seed(seed)
# PyTorch
torch.manual_seed(seed)
# CUDA (if available)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Make cuDNN deterministic
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Environment variable for some operations
os.environ['PYTHONHASHSEED'] = str(seed)
print(f"All seeds set to {seed}")
def save_states(self):
"""Save random states for later restoration."""
self.original_states = {
'python': random.getstate(),
'numpy': np.random.get_state(),
'torch': torch.get_rng_state(),
}
if torch.cuda.is_available():
self.original_states['cuda'] = torch.cuda.get_rng_state_all()
return self.original_states
def restore_states(self, states: Optional[Dict] = None):
"""Restore random states."""
states = states or self.original_states
random.setstate(states['python'])
np.random.set_state(states['numpy'])
torch.set_rng_state(states['torch'])
if torch.cuda.is_available() and 'cuda' in states:
torch.cuda.set_rng_state_all(states['cuda'])
def worker_init_fn(self, worker_id: int):
"""
For DataLoader workers - each needs unique but reproducible seed.
Usage:
DataLoader(..., worker_init_fn=seed_manager.worker_init_fn)
"""
worker_seed = self.seed + worker_id
np.random.seed(worker_seed)
random.seed(worker_seed)
# Usage
seed_manager = SeedManager(42)
seed_manager.set_all_seeds()
Deterministic Operations
def enable_deterministic_mode():
"""
Enable fully deterministic operations in PyTorch.
Warning: May reduce performance significantly!
"""
# PyTorch 1.8+
if hasattr(torch, 'use_deterministic_algorithms'):
torch.use_deterministic_algorithms(True)
# cuBLAS needs a fixed workspace size to give deterministic results (CUDA >= 10.2).
# Note: this must be set before the first CUDA operation to take effect.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
# Disable tf32 for full precision
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
class DeterministicModule(nn.Module):
"""
Base class that enforces deterministic operations.
"""
def __init__(self):
super().__init__()
# Warn about non-deterministic layers
self._warned_layers = set()
def forward(self, x):
raise NotImplementedError
def check_determinism(self, x: torch.Tensor, n_runs: int = 5) -> bool:
"""Test if forward pass is deterministic."""
outputs = []
for _ in range(n_runs):
# Reset seeds
torch.manual_seed(0)
with torch.no_grad():
out = self.forward(x.clone())
outputs.append(out.cpu())
# Check if all outputs are identical
reference = outputs[0]
for out in outputs[1:]:
if not torch.allclose(reference, out, rtol=0, atol=0):
return False
return True
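A rough usage sketch, assuming the `seed_manager` instance and the helpers above are in scope; `TinyNet` is a hypothetical example module, not part of the utilities:

```python
# Minimal usage sketch for the determinism utilities defined above.
class TinyNet(DeterministicModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

    def forward(self, x):
        return self.net(x)

enable_deterministic_mode()          # opt in to deterministic kernels (may be slower)
seed_manager.set_all_seeds(123)      # reseed everything before building the model

model = TinyNet()
x = torch.randn(8, 16)
print("Deterministic forward pass:", model.check_determinism(x))
```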
Experiment Configuration
Configuration Management
@dataclass
class ExperimentConfig:
"""
Complete experiment configuration.
All hyperparameters and settings in one place.
"""
# Identifiers
experiment_name: str = "experiment"
run_id: str = field(default_factory=lambda: datetime.now().strftime("%Y%m%d_%H%M%S"))
# Seeds
seed: int = 42
# Data
data_path: str = "./data"
train_split: float = 0.8
val_split: float = 0.1
# Model
model_type: str = "resnet18"
hidden_size: int = 256
num_layers: int = 3
dropout: float = 0.1
# Training
batch_size: int = 32
learning_rate: float = 1e-3
weight_decay: float = 1e-4
epochs: int = 100
# Optimizer
optimizer: str = "adamw"
momentum: float = 0.9
# Scheduler
scheduler: str = "cosine"
warmup_epochs: int = 5
# Regularization
label_smoothing: float = 0.0
mixup_alpha: float = 0.0
# Hardware
device: str = "cuda"
num_workers: int = 4
mixed_precision: bool = True
# Logging
log_dir: str = "./logs"
checkpoint_dir: str = "./checkpoints"
log_interval: int = 100
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
def save(self, path: str):
"""Save config to JSON."""
with open(path, 'w') as f:
json.dump(self.to_dict(), f, indent=2)
@classmethod
def load(cls, path: str) -> 'ExperimentConfig':
"""Load config from JSON."""
with open(path, 'r') as f:
data = json.load(f)
return cls(**data)
def get_hash(self) -> str:
"""Get unique hash of configuration."""
config_str = json.dumps(self.to_dict(), sort_keys=True)
return hashlib.sha256(config_str.encode()).hexdigest()[:8]
def diff(self, other: 'ExperimentConfig') -> Dict[str, tuple]:
"""Show differences between two configs."""
self_dict = self.to_dict()
other_dict = other.to_dict()
diffs = {}
all_keys = set(self_dict.keys()) | set(other_dict.keys())
for key in all_keys:
v1 = self_dict.get(key)
v2 = other_dict.get(key)
if v1 != v2:
diffs[key] = (v1, v2)
return diffs
# CLI configuration loading
class ConfigLoader:
"""Load configs from multiple sources with overrides."""
@staticmethod
def from_cli(default_config: ExperimentConfig) -> ExperimentConfig:
"""
Load config with CLI overrides.
Usage:
python train.py --learning_rate 0.001 --batch_size 64
"""
import argparse
parser = argparse.ArgumentParser()
# Add all config fields as arguments
for key, value in asdict(default_config).items():
if isinstance(value, bool):
# argparse's bool() treats any non-empty string as True, so parse booleans explicitly
arg_type = lambda s: str(s).lower() in ('true', '1', 'yes')
else:
arg_type = type(value) if value is not None else str
parser.add_argument(f'--{key}', type=arg_type, default=value)
args = parser.parse_args()
return ExperimentConfig(**vars(args))
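A brief usage sketch for the config class; the file name below is illustrative:

```python
# Round-tripping and comparing configs (assumes ExperimentConfig from above).
base = ExperimentConfig(experiment_name="baseline", learning_rate=1e-3)
variant = ExperimentConfig(experiment_name="baseline", learning_rate=3e-4, batch_size=64)

base.save("config_baseline.json")          # persist alongside the results
restored = ExperimentConfig.load("config_baseline.json")

print(base.get_hash())   # short hash of the full config (note: includes the generated run_id)
print(base.diff(variant))  # e.g. {'learning_rate': (0.001, 0.0003), 'batch_size': (32, 64)},
                           # plus run_id if the generated timestamps differ
```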
YAML Configuration
class YAMLConfig:
"""
YAML-based configuration with inheritance.
Example YAML:
# base.yaml
model:
type: resnet18
hidden_size: 256
training:
batch_size: 32
learning_rate: 0.001
# experiment.yaml
defaults:
- base
training:
batch_size: 64   # overrides the value inherited from base.yaml
"""
@staticmethod
def load(path: str) -> Dict:
"""Load YAML with inheritance."""
try:
import yaml
with open(path, 'r') as f:
config = yaml.safe_load(f)
# Handle defaults/inheritance
if 'defaults' in config:
base_config = {}
for default in config['defaults']:
base_path = Path(path).parent / f"{default}.yaml"
base_config = YAMLConfig._deep_merge(
base_config,
YAMLConfig.load(str(base_path))
)
# Remove defaults key and merge
del config['defaults']
config = YAMLConfig._deep_merge(base_config, config)
return config
except ImportError:
print("Install PyYAML: pip install pyyaml")
return {}
@staticmethod
def _deep_merge(base: Dict, override: Dict) -> Dict:
"""Deep merge two dictionaries."""
result = base.copy()
for key, value in override.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = YAMLConfig._deep_merge(result[key], value)
else:
result[key] = value
return result
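A possible usage sketch, assuming PyYAML is installed and that `base.yaml` and `experiment.yaml` sit next to each other as in the docstring above (the `configs/` path is illustrative):

```python
# Resolve an experiment config against its base via the `defaults` list.
config = YAMLConfig.load("configs/experiment.yaml")
print(config["model"]["type"])   # inherited from base.yaml
print(config["training"])        # base values, overridden where experiment.yaml differs
```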
Environment Tracking
Environment Snapshot
class EnvironmentSnapshot:
"""
Capture complete environment information for reproducibility.
"""
@staticmethod
def capture() -> Dict[str, Any]:
"""Capture current environment state."""
snapshot = {
'timestamp': datetime.now().isoformat(),
'python': {
'version': sys.version,
'platform': sys.platform,
},
'hardware': EnvironmentSnapshot._get_hardware_info(),
'packages': EnvironmentSnapshot._get_packages(),
'git': EnvironmentSnapshot._get_git_info(),
'cuda': EnvironmentSnapshot._get_cuda_info(),
}
return snapshot
@staticmethod
def _get_hardware_info() -> Dict:
"""Get hardware information."""
import platform
info = {
'cpu': platform.processor(),
'machine': platform.machine(),
'node': platform.node(),
}
# GPU info
if torch.cuda.is_available():
info['gpus'] = []
for i in range(torch.cuda.device_count()):
props = torch.cuda.get_device_properties(i)
info['gpus'].append({
'name': props.name,
'memory_gb': props.total_memory / 1e9,
'capability': f"{props.major}.{props.minor}"
})
return info
@staticmethod
def _get_packages() -> Dict[str, str]:
"""Get installed package versions."""
try:
import pkg_resources
return {
pkg.key: pkg.version
for pkg in pkg_resources.working_set
}
except Exception:
return {}
@staticmethod
def _get_git_info() -> Dict:
"""Get git repository information."""
try:
import subprocess
return {
'commit': subprocess.check_output(
['git', 'rev-parse', 'HEAD']
).decode().strip(),
'branch': subprocess.check_output(
['git', 'rev-parse', '--abbrev-ref', 'HEAD']
).decode().strip(),
'dirty': bool(subprocess.check_output(
['git', 'status', '--porcelain']
).decode().strip()),
}
except Exception:
return {'commit': 'unknown', 'branch': 'unknown', 'dirty': False}
@staticmethod
def _get_cuda_info() -> Dict:
"""Get CUDA information."""
if not torch.cuda.is_available():
return {'available': False}
return {
'available': True,
'version': torch.version.cuda,
'cudnn_version': torch.backends.cudnn.version(),
'device_count': torch.cuda.device_count(),
}
@staticmethod
def save(path: str):
"""Save snapshot to file."""
snapshot = EnvironmentSnapshot.capture()
with open(path, 'w') as f:
json.dump(snapshot, f, indent=2)
# Requirements generation
def generate_requirements() -> str:
"""Generate requirements.txt content."""
try:
import pkg_resources
lines = []
for pkg in pkg_resources.working_set:
lines.append(f"{pkg.key}=={pkg.version}")
return '\n'.join(sorted(lines))
except Exception:
return ""
Experiment Tracking
Simple Logger
class ExperimentLogger:
"""
Simple experiment logger for local tracking.
"""
def __init__(
self,
experiment_name: str,
log_dir: str = "./experiments",
config: Optional[ExperimentConfig] = None
):
self.experiment_name = experiment_name
self.run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
# Create experiment directory
self.exp_dir = Path(log_dir) / experiment_name / self.run_id
self.exp_dir.mkdir(parents=True, exist_ok=True)
# Initialize logs
self.metrics = {}
self.step = 0
# Save initial info
if config:
config.save(self.exp_dir / "config.json")
EnvironmentSnapshot.save(str(self.exp_dir / "environment.json"))
print(f"Experiment logged to: {self.exp_dir}")
def log_metric(self, name: str, value: float, step: Optional[int] = None):
"""Log a metric value."""
step = self.step if step is None else step  # don't treat step=0 as "unset"
if name not in self.metrics:
self.metrics[name] = []
self.metrics[name].append({
'step': step,
'value': value,
'timestamp': datetime.now().isoformat()
})
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""Log multiple metrics."""
for name, value in metrics.items():
self.log_metric(name, value, step)
def increment_step(self):
"""Increment global step."""
self.step += 1
def save_checkpoint(
self,
model: nn.Module,
optimizer: optim.Optimizer,
epoch: int,
metrics: Dict[str, float]
):
"""Save model checkpoint."""
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'metrics': metrics,
'step': self.step,
}
path = self.exp_dir / f"checkpoint_epoch_{epoch}.pt"
torch.save(checkpoint, path)
# Also save as latest
torch.save(checkpoint, self.exp_dir / "checkpoint_latest.pt")
def load_checkpoint(self, path: str) -> Dict:
"""Load a checkpoint."""
return torch.load(path)
def save_logs(self):
"""Save all metrics to disk."""
with open(self.exp_dir / "metrics.json", 'w') as f:
json.dump(self.metrics, f, indent=2)
def finish(self):
"""Finalize experiment."""
self.save_logs()
# Create summary
summary = {
'experiment_name': self.experiment_name,
'run_id': self.run_id,
'total_steps': self.step,
'final_metrics': {
name: values[-1]['value'] if values else None
for name, values in self.metrics.items()
}
}
with open(self.exp_dir / "summary.json", 'w') as f:
json.dump(summary, f, indent=2)
print(f"Experiment finished. Logs saved to {self.exp_dir}")
Weights & Biases Integration
class WandbLogger:
"""
Weights & Biases logger wrapper.
Features:
- Automatic metric tracking
- Hyperparameter logging
- Artifact versioning
- Team collaboration
"""
def __init__(
self,
project: str,
experiment_name: str,
config: Optional[Dict] = None,
tags: Optional[list] = None
):
self.project = project
self.experiment_name = experiment_name
try:
import wandb
self.wandb = wandb
# Initialize run
self.run = wandb.init(
project=project,
name=experiment_name,
config=config,
tags=tags,
reinit=True
)
self.enabled = True
except ImportError:
print("Install wandb: pip install wandb")
self.enabled = False
def log(self, metrics: Dict[str, float], step: Optional[int] = None):
"""Log metrics."""
if self.enabled:
self.wandb.log(metrics, step=step)
def log_artifact(self, path: str, name: str, artifact_type: str = "model"):
"""Log an artifact (model, data, etc.)."""
if self.enabled:
artifact = self.wandb.Artifact(name, type=artifact_type)
if os.path.isfile(path):
artifact.add_file(path)
else:
artifact.add_dir(path)
self.run.log_artifact(artifact)
def watch_model(self, model: nn.Module, log_freq: int = 100):
"""Watch model gradients and parameters."""
if self.enabled:
self.wandb.watch(model, log="all", log_freq=log_freq)
def finish(self):
"""Finish the run."""
if self.enabled:
self.wandb.finish()
class MLflowLogger:
"""
MLflow logger wrapper.
Features:
- Experiment tracking
- Model registry
- Parameter logging
- Artifact storage
"""
def __init__(
self,
experiment_name: str,
tracking_uri: Optional[str] = None
):
self.experiment_name = experiment_name
try:
import mlflow
self.mlflow = mlflow
if tracking_uri:
mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(experiment_name)
self.run = mlflow.start_run()
self.enabled = True
except ImportError:
print("Install mlflow: pip install mlflow")
self.enabled = False
def log_params(self, params: Dict[str, Any]):
"""Log parameters."""
if self.enabled:
self.mlflow.log_params(params)
def log_metric(self, name: str, value: float, step: Optional[int] = None):
"""Log a metric."""
if self.enabled:
self.mlflow.log_metric(name, value, step=step)
def log_model(self, model: nn.Module, artifact_path: str = "model"):
"""Log PyTorch model."""
if self.enabled:
self.mlflow.pytorch.log_model(model, artifact_path)
def finish(self):
"""End the run."""
if self.enabled:
self.mlflow.end_run()
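A combined usage sketch; project and experiment names are placeholders, and each wrapper silently disables itself if its library is not installed:

```python
# Log the same run to both backends via the wrappers above.
config = ExperimentConfig(experiment_name="resnet_cifar")

wandb_logger = WandbLogger(project="my-project", experiment_name=config.experiment_name,
                           config=config.to_dict(), tags=["baseline"])
mlflow_logger = MLflowLogger(experiment_name=config.experiment_name)
mlflow_logger.log_params(config.to_dict())

for step in range(3):
    metrics = {"train/loss": 1.0 / (step + 1)}
    wandb_logger.log(metrics, step=step)
    mlflow_logger.log_metric("train/loss", metrics["train/loss"], step=step)

wandb_logger.finish()
mlflow_logger.finish()
```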
Data Versioning
class DataVersioning:
"""
Track data versions for reproducibility.
"""
@staticmethod
def compute_hash(path: str, algorithm: str = 'md5') -> str:
"""
Compute hash of file or directory.
Use this to detect data changes.
"""
hash_fn = hashlib.new(algorithm)
path = Path(path)
if path.is_file():
with open(path, 'rb') as f:
for chunk in iter(lambda: f.read(8192), b''):
hash_fn.update(chunk)
else:
# Directory: hash all files sorted by name
for file_path in sorted(path.rglob('*')):
if file_path.is_file():
hash_fn.update(str(file_path.relative_to(path)).encode())
with open(file_path, 'rb') as f:
for chunk in iter(lambda: f.read(8192), b''):
hash_fn.update(chunk)
return hash_fn.hexdigest()
@staticmethod
def create_manifest(
data_path: str,
output_path: str = "data_manifest.json"
):
"""
Create manifest of all data files with hashes.
"""
manifest = {
'created': datetime.now().isoformat(),
'root': data_path,
'files': []
}
for file_path in sorted(Path(data_path).rglob('*')):
if file_path.is_file():
manifest['files'].append({
'path': str(file_path.relative_to(data_path)),
'size': file_path.stat().st_size,
'hash': DataVersioning.compute_hash(str(file_path))
})
manifest['total_hash'] = DataVersioning.compute_hash(data_path)
with open(output_path, 'w') as f:
json.dump(manifest, f, indent=2)
return manifest
@staticmethod
def verify_manifest(manifest_path: str) -> bool:
"""
Verify data matches manifest.
"""
with open(manifest_path, 'r') as f:
manifest = json.load(f)
data_path = manifest['root']
for file_info in manifest['files']:
file_path = Path(data_path) / file_info['path']
if not file_path.exists():
print(f"Missing: {file_info['path']}")
return False
if DataVersioning.compute_hash(str(file_path)) != file_info['hash']:
print(f"Changed: {file_info['path']}")
return False
return True
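A usage sketch for pinning and later verifying a dataset; `./data` is a placeholder path:

```python
# Record the dataset fingerprint once, then check it before every training run.
manifest = DataVersioning.create_manifest("./data", output_path="data_manifest.json")
print(f"{len(manifest['files'])} files, total hash {manifest['total_hash']}")

# Later (or on another machine), confirm the data is byte-identical before training.
assert DataVersioning.verify_manifest("data_manifest.json"), "Data has changed!"
```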
Reproducible Data Loading
class ReproducibleDataLoader:
"""
Data loader with reproducible ordering.
"""
def __init__(
self,
dataset,
batch_size: int = 32,
shuffle: bool = True,
seed: int = 42,
num_workers: int = 4
):
self.seed = seed
# Create generator for reproducible shuffling
self.generator = torch.Generator()
self.generator.manual_seed(seed)
# Worker init function for reproducible augmentations
def worker_init_fn(worker_id):
worker_seed = seed + worker_id
np.random.seed(worker_seed)
random.seed(worker_seed)
self.loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
worker_init_fn=worker_init_fn,
generator=self.generator if shuffle else None,
# For full reproducibility:
persistent_workers=False,
drop_last=True
)
def __iter__(self):
return iter(self.loader)
def __len__(self):
return len(self.loader)
def reset(self):
"""Reset generator for reproducible iteration."""
self.generator.manual_seed(self.seed)
# Reproducible train/val/test splits
def create_reproducible_splits(
dataset,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
seed: int = 42
) -> tuple:
"""
Create reproducible dataset splits.
"""
n = len(dataset)
indices = list(range(n))
# Reproducible shuffle
rng = np.random.RandomState(seed)
rng.shuffle(indices)
train_end = int(n * train_ratio)
val_end = int(n * (train_ratio + val_ratio))
train_indices = indices[:train_end]
val_indices = indices[train_end:val_end]
test_indices = indices[val_end:]
train_set = torch.utils.data.Subset(dataset, train_indices)
val_set = torch.utils.data.Subset(dataset, val_indices)
test_set = torch.utils.data.Subset(dataset, test_indices)
return train_set, val_set, test_set
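A usage sketch with a synthetic dataset, showing that resetting the generator reproduces the exact batch order:

```python
# Assumes ReproducibleDataLoader and create_reproducible_splits from above.
dataset = torch.utils.data.TensorDataset(torch.randn(1000, 10), torch.randint(0, 2, (1000,)))

train_set, val_set, test_set = create_reproducible_splits(dataset, seed=42)
train_loader = ReproducibleDataLoader(train_set, batch_size=32, seed=42, num_workers=0)

first_epoch = [y for _, y in train_loader]
train_loader.reset()                       # reseed the shuffling generator
second_epoch = [y for _, y in train_loader]
print(all(torch.equal(a, b) for a, b in zip(first_epoch, second_epoch)))  # True: same order
```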
Complete Reproducibility Template
class ReproducibleExperiment:
"""
Complete template for reproducible experiments.
"""
def __init__(self, config: ExperimentConfig):
self.config = config
# Set up reproducibility
self.seed_manager = SeedManager(config.seed)
self.seed_manager.set_all_seeds()
# Enable deterministic mode if needed
if hasattr(config, 'deterministic') and config.deterministic:
enable_deterministic_mode()
# Set up logging
self.logger = ExperimentLogger(
experiment_name=config.experiment_name,
log_dir=config.log_dir,
config=config
)
# Save code snapshot
self._save_code_snapshot()
def _save_code_snapshot(self):
"""Save current code for reproducibility."""
code_dir = self.logger.exp_dir / "code"
code_dir.mkdir(exist_ok=True)
# Copy relevant Python files
for py_file in Path('.').glob('**/*.py'):
if 'experiments' not in str(py_file):
dest = code_dir / py_file
dest.parent.mkdir(parents=True, exist_ok=True)
with open(py_file, 'r') as f:
content = f.read()
with open(dest, 'w') as f:
f.write(content)
def run(self, train_fn: Callable):
"""
Run experiment with full tracking.
Args:
train_fn: Function that takes (config, logger) and trains model
"""
try:
result = train_fn(self.config, self.logger)
self.logger.log_metrics({
'final_' + k: v for k, v in result.items()
})
except Exception as e:
# Log error
with open(self.logger.exp_dir / "error.txt", 'w') as f:
import traceback
f.write(traceback.format_exc())
raise
finally:
self.logger.finish()
return result
@classmethod
def reproduce(cls, exp_dir: str) -> 'ReproducibleExperiment':
"""
Reproduce an experiment from saved logs.
"""
exp_dir = Path(exp_dir)
# Load config
config = ExperimentConfig.load(exp_dir / "config.json")
# Verify environment
with open(exp_dir / "environment.json", 'r') as f:
saved_env = json.load(f)
current_env = EnvironmentSnapshot.capture()
# Warn about differences (use .get so a CPU-only snapshot doesn't raise KeyError)
saved_cuda = saved_env.get('cuda', {}).get('version')
current_cuda = current_env.get('cuda', {}).get('version')
if saved_cuda != current_cuda:
print(f"Warning: CUDA version mismatch ({saved_cuda} vs {current_cuda})")
return cls(config)
# Usage example
def main():
"""Example reproducible training script."""
# Define configuration
config = ExperimentConfig(
experiment_name="resnet_cifar",
seed=42,
learning_rate=0.001,
batch_size=128,
epochs=100
)
# Create experiment
experiment = ReproducibleExperiment(config)
# Define training function
def train(config, logger):
# Your training code here
# Use config for all hyperparameters
# Use logger.log_metrics() for tracking
return {'accuracy': 0.95} # Example result
# Run
result = experiment.run(train)
print(f"Final result: {result}")
# Checklist for reproducibility
def reproducibility_checklist():
"""Print reproducibility checklist."""
checklist = """
╔════════════════════════════════════════════════════════════════╗
║ REPRODUCIBILITY CHECKLIST ║
╠════════════════════════════════════════════════════════════════╣
║ ║
║ SEEDS & RANDOMNESS ║
║ □ Set all seeds (Python, NumPy, PyTorch, CUDA) ║
║ □ Use worker_init_fn for DataLoader workers ║
║ □ Enable deterministic mode if needed ║
║ □ Document any non-deterministic operations ║
║ ║
║ CONFIGURATION ║
║ □ All hyperparameters in config file ║
║ □ No hardcoded values in code ║
║ □ Version control config files ║
║ ║
║ ENVIRONMENT ║
║ □ Pin package versions (requirements.txt) ║
║ □ Document CUDA/cuDNN versions ║
║ □ Use Docker/container if possible ║
║ ║
║ DATA ║
║ □ Version your data (hash, DVC, etc.) ║
║ □ Document preprocessing steps ║
║ □ Fixed train/val/test splits ║
║ ║
║ CODE ║
║ □ Git commit hash for each experiment ║
║ □ No uncommitted changes during experiments ║
║ □ Save code snapshot with results ║
║ ║
║ LOGGING ║
║ □ Log all metrics, not just final ║
║ □ Save checkpoints periodically ║
║ □ Track wall-clock time ║
║ ║
╚════════════════════════════════════════════════════════════════╝
"""
print(checklist)
reproducibility_checklist()
Exercises
Exercise 1: Cross-Platform Reproducibility
Test whether your experiment reproduces across:
- Different GPUs (NVIDIA vs AMD)
- Different OS (Linux vs Windows)
- Different PyTorch versions
Exercise 2: Build Experiment Dashboard
Create a simple dashboard to:
- Compare multiple runs
- Visualize metrics over time
- Filter by hyperparameters
Exercise 3: Docker for ML
Create a Dockerfile that:
- Pins all dependencies
- Includes data download script
- Runs experiments reproducibly