Generative Adversarial Networks
The Counterfeiter vs Detective Game
Imagine a world with two players locked in an eternal battle:
- The Counterfeiter (Generator): Creates fake money, trying to make it indistinguishable from real currency
- The Detective (Discriminator): Examines bills and tries to identify which are real and which are fake
Copy
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
# Set random seeds for reproducibility (note: full determinism on GPU would
# also require torch.backends.cudnn settings; this covers CPU-side sampling).
torch.manual_seed(42)
np.random.seed(42)
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
# Select the compute device once; every model/tensor below is moved to it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
GAN Architecture Overview
| Component | Role | Input | Output |
|---|---|---|---|
| Generator (G) | Creates fake data | Random noise z | Fake sample G(z) |
| Discriminator (D) | Classifies real/fake | Sample x | Probability D(x) |
Copy
class SimpleGenerator(nn.Module):
    """Fully-connected generator mapping latent noise to flattened images.

    Three widening hidden layers (256 -> 512 -> 1024) upsample the latent
    vector; the Tanh output squashes pixels to [-1, 1] to match data that
    has been normalized to the same range.
    """

    def __init__(self, latent_dim=100, output_dim=784):
        super().__init__()
        self.latent_dim = latent_dim
        hidden_widths = (256, 512, 1024)
        stages = []
        n_in = latent_dim
        for width in hidden_widths:
            stages += [
                nn.Linear(n_in, width),
                nn.LeakyReLU(0.2),
                nn.BatchNorm1d(width),
            ]
            n_in = width
        stages += [nn.Linear(n_in, output_dim), nn.Tanh()]
        self.model = nn.Sequential(*stages)

    def forward(self, z):
        """Map a batch of latent vectors z ([N, latent_dim]) to fake samples."""
        return self.model(z)
class SimpleDiscriminator(nn.Module):
    """Fully-connected binary classifier: outputs P(input is real) in [0, 1]."""

    def __init__(self, input_dim=784):
        super().__init__()
        widths = (input_dim, 512, 256)
        stages = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Linear(n_in, n_out),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
            ]
        stages += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Return the probability that each flattened image in x is real."""
        return self.model(x)
# Initialize networks
latent_dim = 100  # dimensionality of the noise vector fed to the generator
generator = SimpleGenerator(latent_dim=latent_dim).to(device)
discriminator = SimpleDiscriminator().to(device)
# Parameter counts as a quick sanity check of model capacity.
print(f"Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
Copy
Generator parameters: 1,076,244
Discriminator parameters: 407,297
The Minimax Loss Function
The GAN training objective is a minimax game: $\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$ Let’s break this down:
| Term | Meaning | What It Does |
|---|---|---|
| $\mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)]$ | Expected log probability for real data | D wants this HIGH (correctly identify real) |
| $\mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$ | Expected log probability for fake data | D wants this HIGH, G wants this LOW |
- Correctly classify real images → D(x)≈1
- Correctly classify fake images → D(G(z))≈0
- Fool the discriminator → D(G(z))≈1
Copy
def discriminator_loss(real_output, fake_output):
    """Standard discriminator loss: sum of two BCE terms.

    Rewards the discriminator for scoring real samples near 1 and fake
    samples near 0.
    """
    bce = nn.BCELoss()
    loss_on_real = bce(real_output, torch.ones_like(real_output))
    loss_on_fake = bce(fake_output, torch.zeros_like(fake_output))
    return loss_on_real + loss_on_fake
def generator_loss(fake_output):
    """Non-saturating generator loss: push D(G(z)) toward 1.

    Maximizing log D(G(z)) gives stronger gradients early in training than
    minimizing log(1 - D(G(z))).
    """
    target = torch.ones_like(fake_output)
    return nn.BCELoss()(fake_output, target)
# Alternative: Using BCEWithLogitsLoss for numerical stability
class GANLoss:
    """Logit-based GAN losses using BCEWithLogitsLoss for numerical stability.

    Expects raw (pre-sigmoid) discriminator outputs.
    """

    def __init__(self):
        self.criterion = nn.BCEWithLogitsLoss()

    def discriminator_loss(self, real_logits, fake_logits):
        """Sum of BCE terms: real logits pushed toward 1, fake toward 0."""
        loss_real = self.criterion(real_logits, torch.ones_like(real_logits))
        loss_fake = self.criterion(fake_logits, torch.zeros_like(fake_logits))
        return loss_real + loss_fake

    def generator_loss(self, fake_logits):
        """Non-saturating loss: label the fakes as real."""
        return self.criterion(fake_logits, torch.ones_like(fake_logits))
Complete GAN Training Loop
Let’s train a GAN on MNIST digits:
# Data preparation
transform = transforms.Compose([
    transforms.ToTensor(),
    # Shift pixels from [0, 1] to [-1, 1] to match the generator's Tanh range.
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
batch_size = 128
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Hyperparameters
lr = 0.0002       # learning rate from the DCGAN paper
beta1 = 0.5       # lower Adam beta1 stabilizes adversarial training
num_epochs = 50
# Optimizers - Adam with specific betas for GANs
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
# Loss function: BCE over the discriminator's sigmoid probabilities.
criterion = nn.BCELoss()
def train_gan(generator, discriminator, dataloader, num_epochs):
    """Alternating GAN training loop on flattened images.

    NOTE(review): relies on the module-level globals `criterion`,
    `g_optimizer`, `d_optimizer`, `latent_dim`, `device`, and
    `visualize_samples` — confirm they are initialized before calling.

    Args:
        generator: Maps [N, latent_dim] noise to flat images in [-1, 1].
        discriminator: Maps flat images to a probability of being real.
        dataloader: Yields (images, labels); labels are ignored.
        num_epochs: Number of passes over the dataset.

    Returns:
        (g_losses, d_losses): lists of per-epoch average losses.
    """
    g_losses = []
    d_losses = []
    # Fixed noise so progress snapshots are comparable across epochs.
    fixed_noise = torch.randn(64, latent_dim, device=device)
    for epoch in range(num_epochs):
        g_loss_epoch = 0
        d_loss_epoch = 0
        for batch_idx, (real_images, _) in enumerate(dataloader):
            batch_size = real_images.size(0)
            # Flatten images (e.g. 28x28 -> 784) for the MLP models.
            real_images = real_images.view(batch_size, -1).to(device)
            # Target labels: 1 = real, 0 = fake.
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)
            # ==================
            # Train Discriminator
            # ==================
            d_optimizer.zero_grad()
            # Real images should be classified as real (1).
            real_output = discriminator(real_images)
            d_loss_real = criterion(real_output, real_labels)
            # Fake images should be classified as fake (0).
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            fake_output = discriminator(fake_images.detach())  # Detach to not update G
            d_loss_fake = criterion(fake_output, fake_labels)
            # Total discriminator loss
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            d_optimizer.step()
            # ===============
            # Train Generator
            # ===============
            g_optimizer.zero_grad()
            # Generate a fresh batch of fakes (the previous ones were detached).
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            fake_output = discriminator(fake_images)
            # Non-saturating loss: G wants D to output 1 on fakes.
            g_loss = criterion(fake_output, real_labels)
            g_loss.backward()
            g_optimizer.step()
            g_loss_epoch += g_loss.item()
            d_loss_epoch += d_loss.item()
        # Average losses over the epoch
        g_loss_epoch /= len(dataloader)
        d_loss_epoch /= len(dataloader)
        g_losses.append(g_loss_epoch)
        d_losses.append(d_loss_epoch)
        # Print progress and save a sample grid every 10 epochs.
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] | "
                  f"D Loss: {d_loss_epoch:.4f} | G Loss: {g_loss_epoch:.4f}")
            # Generate sample images from the fixed noise for comparison.
            with torch.no_grad():
                sample_images = generator(fixed_noise).view(-1, 1, 28, 28)
            visualize_samples(sample_images[:16], epoch + 1)
    return g_losses, d_losses
def visualize_samples(images, epoch):
    """Save a 4x4 grid of generated images to gan_samples_epoch_<epoch>.png.

    Args:
        images: Tensor of shape [N, 1, H, W] with values in [-1, 1]. Only
            the first 16 images are drawn; if fewer are provided the
            remaining cells stay blank (the original raised IndexError
            for N < 16).
        epoch: Epoch number used in the title and output filename.
    """
    fig, axes = plt.subplots(4, 4, figsize=(8, 8))
    for ax in axes.flat:
        ax.axis('off')  # hide axes even on cells with no image
    # zip stops at the shorter of (images, cells), so N < 16 is safe.
    for img, ax in zip(images, axes.flat):
        arr = img.detach().cpu().squeeze().numpy()
        arr = (arr + 1) / 2  # Denormalize from [-1, 1] to [0, 1]
        ax.imshow(arr, cmap='gray')
    plt.suptitle(f'Generated Samples - Epoch {epoch}')
    plt.tight_layout()
    plt.savefig(f'gan_samples_epoch_{epoch}.png')
    plt.close(fig)  # close this figure specifically, not just the current one
# Train the GAN (writes a sample-grid PNG to disk every 10 epochs).
print("Starting GAN training...")
g_losses, d_losses = train_gan(generator, discriminator, dataloader, num_epochs)
Copy
Starting GAN training...
Epoch [10/50] | D Loss: 0.8234 | G Loss: 1.4521
Epoch [20/50] | D Loss: 0.6912 | G Loss: 1.2847
Epoch [30/50] | D Loss: 0.5823 | G Loss: 1.5234
Epoch [40/50] | D Loss: 0.5412 | G Loss: 1.6891
Epoch [50/50] | D Loss: 0.5234 | G Loss: 1.7234
Mode Collapse: The GAN’s Achilles Heel
Why Does Mode Collapse Happen?
| Cause | Explanation |
|---|---|
| Easy shortcut | Generator finds one mode that consistently fools D |
| Discriminator forgetting | D gets fooled by one mode, G exploits it |
| Training imbalance | G or D becomes too strong too quickly |
Copy
def detect_mode_collapse(generator, num_samples=1000, threshold=0.1):
    """Heuristically detect mode collapse from sample diversity.

    Draws `num_samples` fakes and computes the mean pairwise Euclidean
    distance between them; a very small mean distance means the generator
    emits near-identical samples. Relies on module-level `latent_dim` and
    `device`.

    Args:
        generator: Model mapping [N, latent_dim] noise to flat samples.
        num_samples: Number of fake samples to draw.
        threshold: Mean-distance cutoff below which collapse is reported.

    Returns:
        True if mode collapse is suspected, else False.
    """
    # Remember the mode so we can restore it (the original left the
    # generator permanently in eval mode).
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            noise = torch.randn(num_samples, latent_dim, device=device)
            samples = generator(noise).cpu().numpy()
    finally:
        generator.train(was_training)
    # Calculate pairwise distances
    from scipy.spatial.distance import pdist
    distances = pdist(samples, metric='euclidean')
    mean_distance = np.mean(distances)
    std_distance = np.std(distances)
    # NOTE(review): the collapse decision uses only mean_distance;
    # diversity_score is reported for information, matching the original.
    diversity_score = mean_distance / (std_distance + 1e-8)
    print(f"Mean pairwise distance: {mean_distance:.4f}")
    print(f"Std pairwise distance: {std_distance:.4f}")
    print(f"Diversity score: {diversity_score:.4f}")
    if mean_distance < threshold:
        print("⚠️ Warning: Possible mode collapse detected!")
        return True
    return False
# Solutions to mode collapse
class AntiModeCollapseTraining:
    """Collection of techniques that discourage mode collapse."""

    @staticmethod
    def minibatch_discrimination(features, num_kernels=5, kernel_dim=3, T=None):
        """Minibatch-discrimination features (Salimans et al., 2016).

        Projects each sample, then measures its similarity to every other
        sample in the batch, letting the discriminator notice when all
        fakes look alike.

        Args:
            features: [B, F] discriminator features.
            num_kernels: Number of similarity kernels K.
            kernel_dim: Dimensionality D of each kernel.
            T: Optional [F, K*D] projection matrix. If None, a fresh random
               matrix is drawn each call (the original behavior). For real
               training, register T as a module parameter and pass it in:
               the original wrapped it in nn.Parameter inside this function,
               which never registers it with any module, so it was neither
               trained nor reused between calls.

        Returns:
            [B, K] tensor of per-sample similarity features.
        """
        batch_size = features.size(0)
        if T is None:
            T = torch.randn(features.size(1), num_kernels * kernel_dim,
                            device=features.device)
        M = (features @ T).view(batch_size, num_kernels, kernel_dim)
        # Pairwise L1 distance between all sample pairs: [B, B, K]
        L1_distance = torch.abs(M.unsqueeze(0) - M.unsqueeze(1)).sum(dim=3)
        similarity = torch.exp(-L1_distance)
        # Zero out each sample's similarity with itself
        mask = 1 - torch.eye(batch_size, device=features.device)
        similarity = similarity * mask.unsqueeze(2)
        return similarity.sum(dim=1)  # [B, K]

    @staticmethod
    def feature_matching_loss(real_features, fake_features):
        """Match the batch-mean of discriminator features instead of fooling
        the final output — a softer objective that resists collapse."""
        return torch.mean((real_features.mean(0) - fake_features.mean(0)) ** 2)

    @staticmethod
    def historical_averaging(current_params, historical_params, beta=0.99):
        """Penalize large deviations from historical parameter averages.

        NOTE(review): `beta` is accepted but unused — the caller is expected
        to maintain the historical (e.g. EMA) parameters externally; kept
        for interface compatibility.
        """
        loss = 0
        for curr, hist in zip(current_params, historical_params):
            loss += torch.mean((curr - hist) ** 2)
        return loss
DCGAN: Deep Convolutional GAN
| Guideline | Generator | Discriminator |
|---|---|---|
| Pooling | Use transposed conv | Use strided conv |
| Normalization | BatchNorm (except output) | BatchNorm (except input) |
| Activation | ReLU (except output: Tanh) | LeakyReLU |
| Architecture | Fully convolutional | Fully convolutional |
Copy
class DCGANGenerator(nn.Module):
    """DCGAN generator: latent vector -> 64x64 image via transposed convs.

    Every stage doubles the spatial resolution; BatchNorm + ReLU are used
    everywhere except the Tanh output layer, per the DCGAN guidelines.
    """

    def __init__(self, latent_dim=100, feature_maps=64, channels=3):
        super().__init__()
        self.latent_dim = latent_dim

        def up_block(c_in, c_out, stride, padding):
            # ConvTranspose2d upsamples; BN + ReLU per DCGAN guidelines.
            return [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]

        fm = feature_maps
        stages = (
            up_block(latent_dim, fm * 8, 1, 0)   # 1x1   -> 4x4
            + up_block(fm * 8, fm * 4, 2, 1)     # 4x4   -> 8x8
            + up_block(fm * 4, fm * 2, 2, 1)     # 8x8   -> 16x16
            + up_block(fm * 2, fm, 2, 1)         # 16x16 -> 32x32
            + [nn.ConvTranspose2d(fm, channels, 4, 2, 1, bias=False),
               nn.Tanh()]                        # 32x32 -> 64x64, in [-1, 1]
        )
        self.model = nn.Sequential(*stages)
        self._init_weights()

    def _init_weights(self):
        """DCGAN paper init: N(0, 0.02) conv weights, N(1, 0.02) BN scales."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(module.weight, 0.0, 0.02)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.normal_(module.weight, 1.0, 0.02)
                nn.init.zeros_(module.bias)

    def forward(self, z):
        """Reshape [N, latent_dim] noise to [N, latent_dim, 1, 1] and decode."""
        return self.model(z.view(-1, self.latent_dim, 1, 1))
class DCGANDiscriminator(nn.Module):
    """DCGAN discriminator: 64x64 image -> probability of being real.

    Strided convolutions halve the resolution at each stage; LeakyReLU
    throughout, and no BatchNorm on the input stage per DCGAN guidelines.
    """

    def __init__(self, feature_maps=64, channels=3):
        super().__init__()
        fm = feature_maps
        stages = [
            # First stage: no BatchNorm, per the DCGAN guidelines.
            nn.Conv2d(channels, fm, 4, 2, 1, bias=False),      # 64 -> 32
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out in ((fm, fm * 2), (fm * 2, fm * 4), (fm * 4, fm * 8)):
            stages += [
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),   # halve H, W
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        stages += [
            nn.Conv2d(fm * 8, 1, 4, 1, 0, bias=False),         # 4x4 -> 1x1
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*stages)
        self._init_weights()

    def _init_weights(self):
        """DCGAN paper init: N(0, 0.02) conv weights, N(1, 0.02) BN scales."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(module.weight, 0.0, 0.02)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.normal_(module.weight, 1.0, 0.02)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return a [N, 1] tensor of real-probabilities."""
        return self.model(x).view(-1, 1)
# Instantiate DCGAN
dc_generator = DCGANGenerator(latent_dim=100).to(device)
dc_discriminator = DCGANDiscriminator().to(device)
# Test forward pass: smoke-check the expected tensor shapes end to end.
test_noise = torch.randn(4, 100, device=device)
test_images = dc_generator(test_noise)
print(f"Generator output shape: {test_images.shape}")  # [4, 3, 64, 64]
test_pred = dc_discriminator(test_images)
print(f"Discriminator output shape: {test_pred.shape}")  # [4, 1]
Copy
Generator output shape: torch.Size([4, 3, 64, 64])
Discriminator output shape: torch.Size([4, 1])
Wasserstein GAN (WGAN)
Why Wasserstein Distance?
$W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma}[\lVert x - y \rVert]$
| Advantage | Explanation |
|---|---|
| Continuous gradients | Provides gradients even when distributions don’t overlap |
| Meaningful loss | Loss correlates with image quality |
| Stable training | No need for careful balance between G and D |
Copy
class WGANCritic(nn.Module):
    """WGAN critic: scores images with an unbounded real number.

    Unlike a discriminator there is no sigmoid. Normalization uses
    InstanceNorm2d rather than BatchNorm2d: the WGAN-GP gradient penalty
    is defined per-sample, and batch normalization couples samples within
    a batch, invalidating the penalty (Gulrajani et al., 2017 recommend
    no batch norm in the critic).
    """

    def __init__(self, feature_maps=64, channels=3):
        super().__init__()
        fm = feature_maps
        self.model = nn.Sequential(
            nn.Conv2d(channels, fm, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(fm, fm * 2, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(fm * 2, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(fm * 2, fm * 4, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(fm * 4, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(fm * 4, fm * 8, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(fm * 8, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(fm * 8, 1, 4, 1, 0, bias=False),
            # No sigmoid! Unbounded output
        )

    def forward(self, x):
        """Return a 1-D tensor of critic scores, one per input image."""
        return self.model(x).view(-1)
def wgan_critic_loss(real_output, fake_output):
    """Critic loss: minimize E[C(fake)] - E[C(real)].

    Minimizing this is equivalent to maximizing the Wasserstein estimate
    E[C(real)] - E[C(fake)].
    """
    return torch.mean(fake_output) - torch.mean(real_output)
def wgan_generator_loss(fake_output):
    """Generator loss: maximize E[C(fake)] by minimizing its negation."""
    return -torch.mean(fake_output)
def gradient_penalty(critic, real_data, fake_data, device):
    """WGAN-GP gradient penalty enforcing the 1-Lipschitz constraint.

    Samples random points on straight lines between real and fake samples
    and penalizes the squared deviation of the critic's gradient norm
    from 1 at those points.

    Args:
        critic: Callable mapping a data batch to per-sample scores.
        real_data: Real batch of any shape [B, ...].
        fake_data: Generated batch, same shape as real_data.
        device: Device on which to sample interpolation coefficients.

    Returns:
        Scalar penalty tensor (mean over the batch).
    """
    batch_size = real_data.size(0)
    # One coefficient per sample, broadcast over the remaining dims.
    # (The original hard-coded shape [B, 1, 1, 1], which assumes 4-D data
    # and silently mis-broadcasts for other ranks.)
    alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)
    # Detach both endpoints so the penalty does not backpropagate into the
    # generator graph (the original left fake_data attached).
    interpolated = alpha * real_data.detach() + (1 - alpha) * fake_data.detach()
    interpolated.requires_grad_(True)
    critic_interpolated = critic(interpolated)
    # Gradients of the critic score w.r.t. the interpolated inputs.
    gradients = torch.autograd.grad(
        outputs=critic_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_interpolated),
        create_graph=True,   # the penalty itself must be differentiable
        retain_graph=True
    )[0]
    # Per-sample L2 norm of the flattened gradient
    gradient_norm = gradients.view(batch_size, -1).norm(2, dim=1)
    # Penalty: (||grad|| - 1)^2
    return ((gradient_norm - 1) ** 2).mean()
def train_wgan_gp(generator, critic, dataloader, num_epochs,
                  n_critic=5, lambda_gp=10):
    """WGAN-GP training loop: n_critic critic updates per generator update.

    NOTE(review): relies on the module-level globals `latent_dim`, `device`,
    `gradient_penalty`, `wgan_critic_loss`, and `wgan_generator_loss`.
    Optimizers are created here with the Adam betas (0.0, 0.9) used in the
    WGAN-GP paper.

    Args:
        generator: Maps [N, latent_dim] noise to images.
        critic: Scores images with unbounded real values.
        dataloader: Yields (images, labels); labels are ignored.
        num_epochs: Number of passes over the dataset.
        n_critic: Critic updates per generator update.
        lambda_gp: Weight of the gradient penalty term.
    """
    g_optimizer = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.0, 0.9))
    c_optimizer = optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))
    for epoch in range(num_epochs):
        for batch_idx, (real_images, _) in enumerate(dataloader):
            real_images = real_images.to(device)
            batch_size = real_images.size(0)
            # Train Critic (more often than generator)
            for _ in range(n_critic):
                c_optimizer.zero_grad()
                # Generate fake images
                noise = torch.randn(batch_size, latent_dim, device=device)
                fake_images = generator(noise)
                # Critic scores; detach fakes so this step doesn't update G.
                real_validity = critic(real_images)
                fake_validity = critic(fake_images.detach())
                # Gradient penalty enforcing the 1-Lipschitz constraint.
                gp = gradient_penalty(critic, real_images, fake_images, device)
                # Critic loss: Wasserstein estimate plus weighted penalty.
                c_loss = wgan_critic_loss(real_validity, fake_validity) + lambda_gp * gp
                c_loss.backward()
                c_optimizer.step()
            # Train Generator
            g_optimizer.zero_grad()
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            fake_validity = critic(fake_images)
            g_loss = wgan_generator_loss(fake_validity)
            g_loss.backward()
            g_optimizer.step()
        # NOTE(review): reports the losses of the last batch only, not an
        # epoch average.
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] | "
                  f"C Loss: {c_loss.item():.4f} | G Loss: {g_loss.item():.4f}")
Conditional GANs (cGAN)
Copy
class ConditionalGenerator(nn.Module):
    """Class-conditional MLP generator.

    A learned label embedding is concatenated with the noise vector so the
    caller can request a specific class, e.g. "generate a 7".
    """

    def __init__(self, latent_dim=100, num_classes=10, embed_dim=50, img_size=28):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_size = img_size
        # Dense vector representation for each class id.
        self.label_embedding = nn.Embedding(num_classes, embed_dim)
        widths = (latent_dim + embed_dim, 256, 512, 1024)
        stages = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Linear(n_in, n_out),
                nn.LeakyReLU(0.2),
                nn.BatchNorm1d(n_out),
            ]
        stages += [nn.Linear(widths[-1], img_size * img_size), nn.Tanh()]
        self.model = nn.Sequential(*stages)

    def forward(self, z, labels):
        """Generate flattened images from noise z conditioned on class labels."""
        conditioned = torch.cat([z, self.label_embedding(labels)], dim=1)
        return self.model(conditioned)
class ConditionalDiscriminator(nn.Module):
    """Class-conditional discriminator: judges (image, label) pairs."""

    def __init__(self, num_classes=10, embed_dim=50, img_size=28):
        super().__init__()
        self.img_size = img_size
        # Dense vector representation for each class id.
        self.label_embedding = nn.Embedding(num_classes, embed_dim)
        widths = (img_size * img_size + embed_dim, 512, 256)
        stages = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Linear(n_in, n_out),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
            ]
        stages += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*stages)

    def forward(self, img, labels):
        """Return P(real) for each image given its claimed class label."""
        flat = img.view(img.size(0), -1)
        joined = torch.cat([flat, self.label_embedding(labels)], dim=1)
        return self.model(joined)
# Example: Generate specific digits
def generate_digit(generator, digit, num_samples=16):
    """Generate `num_samples` images of a specific class from a cGAN generator.

    Relies on module-level `latent_dim` and `device`.

    Args:
        generator: Conditional generator taking (noise, labels).
        digit: Class index to condition on.
        num_samples: Number of images to draw.

    Returns:
        Tensor of shape [num_samples, 1, size, size], where size is read
        from generator.img_size (the original hard-coded 28, breaking for
        generators built with a different img_size), falling back to 28.
    """
    size = getattr(generator, "img_size", 28)
    # Remember mode so we can restore it (the original left eval mode on).
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            noise = torch.randn(num_samples, latent_dim, device=device)
            labels = torch.full((num_samples,), digit, dtype=torch.long, device=device)
            generated = generator(noise, labels)
    finally:
        generator.train(was_training)
    return generated.view(-1, 1, size, size)
# Initialize conditional GAN
cond_generator = ConditionalGenerator().to(device)
cond_discriminator = ConditionalDiscriminator().to(device)
# Generate all digits 0-9
# NOTE(review): the generator is untrained at this point, so the samples
# are noise — this only demonstrates the conditioning API.
print("Conditional GAN allows generating specific classes!")
for digit in range(10):
    samples = generate_digit(cond_generator, digit, num_samples=8)
    print(f"Generated {samples.shape[0]} samples of digit {digit}")
GAN Evaluation Metrics
Evaluating GANs is Hard! Unlike supervised learning, there’s no clear “correct answer” to compare against.
| Metric | Measures | Formula/Approach |
|---|---|---|
| Inception Score (IS) | Quality + Diversity | $\exp(\mathbb{E}_x[D_{KL}(p(y \mid x) \parallel p(y))])$ |
| FID (Fréchet Inception Distance) | Feature similarity | Compare Gaussian fits in feature space |
| LPIPS | Perceptual similarity | Distance in learned feature space |
Copy
def calculate_fid_simple(real_features, fake_features):
    """Simplified Fréchet distance between two feature sets.

    Fits a Gaussian (mean + covariance) to each set and returns
    ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 (S_r S_f)^{1/2}).
    Lower is better. The full FID uses InceptionV3 activations as the
    features; here any feature matrix [N, F] works.
    """
    from scipy.linalg import sqrtm

    # Gaussian statistics of each feature set
    mu_real = np.mean(real_features, axis=0)
    mu_fake = np.mean(fake_features, axis=0)
    cov_real = np.cov(real_features, rowvar=False)
    cov_fake = np.cov(fake_features, rowvar=False)

    # Matrix square root of the covariance product
    covmean = sqrtm(cov_real @ cov_fake)
    if np.iscomplexobj(covmean):
        # sqrtm can produce tiny imaginary components from numerical error
        covmean = covmean.real

    mean_diff = mu_real - mu_fake
    return mean_diff @ mean_diff + np.trace(cov_real + cov_fake - 2 * covmean)
# Practical evaluation
def evaluate_gan(generator, dataloader, num_samples=1000):
    """Evaluate a GAN with a simplified FID and a diversity score.

    NOTE(review): operates on raw flattened pixels rather than Inception
    features, so the FID here is only a rough proxy. Relies on the
    module-level globals `latent_dim`, `device`, and `calculate_fid_simple`.

    Args:
        generator: Maps [N, latent_dim] noise to flat samples.
        dataloader: Source of real (images, labels) batches.
        num_samples: Number of real/fake samples to compare.

    Returns:
        Dict with "fid" and "diversity" entries.
    """
    generator.eval()  # NOTE(review): training mode is not restored afterwards
    # Collect real and fake samples
    real_samples = []
    fake_samples = []
    with torch.no_grad():
        for real_images, _ in dataloader:
            real_samples.append(real_images.view(real_images.size(0), -1))
            # Stop once enough real samples have been gathered.
            if len(real_samples) * real_images.size(0) >= num_samples:
                break
        noise = torch.randn(num_samples, latent_dim, device=device)
        fake_images = generator(noise)
        fake_samples.append(fake_images.cpu())
    real_samples = torch.cat(real_samples, dim=0)[:num_samples].numpy()
    fake_samples = torch.cat(fake_samples, dim=0).numpy()
    # Calculate metrics
    fid = calculate_fid_simple(real_samples, fake_samples)
    # Diversity: average pairwise distance over the first 100 fakes only,
    # to keep the O(n^2) distance computation cheap.
    from scipy.spatial.distance import pdist
    diversity = np.mean(pdist(fake_samples[:100]))
    print(f"FID Score: {fid:.2f} (lower is better)")
    print(f"Diversity Score: {diversity:.4f} (higher is better)")
    return {"fid": fid, "diversity": diversity}
Exercises
Exercise 1: Implement Label Smoothing
Exercise 1: Implement Label Smoothing
Label smoothing can stabilize GAN training. Instead of using hard labels (0 and 1), use soft labels (e.g., 0.1 and 0.9).
Copy
# TODO: Modify the training loop to use label smoothing
# Real labels: 0.9 instead of 1.0
# Fake labels: 0.1 instead of 0.0
def train_with_label_smoothing(generator, discriminator, dataloader, num_epochs):
    """Exercise stub: GAN training loop with label smoothing.

    Replace the hard targets in the discriminator update with soft ones
    (e.g. 0.9 for real, 0.1 for fake) to keep D from becoming overconfident.
    """
    # Your implementation here
    pass
Exercise 2: Build a Progressive Training Schedule
Exercise 2: Build a Progressive Training Schedule
Implement a training schedule that gradually increases image resolution.
Copy
# TODO: Start with 4x4 images and progressively grow to 64x64
# Hint: Add layers during training
class ProgressiveGenerator(nn.Module):
    """Exercise stub: generator that grows its output resolution during training."""

    def __init__(self):
        # NOTE(review): remember to call super().__init__() when implementing.
        # Your implementation here
        pass
Exercise 3: Implement Spectral Normalization
Exercise 3: Implement Spectral Normalization
Spectral normalization constrains the Lipschitz constant of discriminator layers.
Copy
# TODO: Implement spectral normalization for discriminator weights
# Hint: Normalize weights by their largest singular value
class SpectralNormDiscriminator(nn.Module):
    """Exercise stub: discriminator whose layers use spectral normalization."""
    # Your implementation here
    pass
Key Takeaways
What You Learned:
- ✅ GAN Fundamentals - Generator creates, Discriminator classifies in a minimax game
- ✅ Minimax Loss - Adversarial training objective and its components
- ✅ Mode Collapse - Common failure mode and solutions (minibatch discrimination, feature matching)
- ✅ DCGAN - Convolutional architecture guidelines for stable training
- ✅ WGAN - Wasserstein distance for improved training stability
- ✅ Conditional GANs - Control generation with class labels or other conditioning
- ✅ Evaluation - FID, Inception Score, and diversity metrics
Common Pitfalls
GAN Training Mistakes to Avoid:
- Imbalanced training - Don’t let D or G become too strong too quickly
- Ignoring mode collapse - Monitor sample diversity throughout training
- Wrong learning rates - GANs are sensitive; start with proven hyperparameters
- Batch size too small - BatchNorm needs sufficient batch size (≥16)
- Not using proper initialization - DCGAN weight init matters significantly
Next: Autoencoders & VAEs
Learn about variational autoencoders and latent space representations