Generative Adversarial Networks - Generator vs Discriminator

Generative Adversarial Networks

The Counterfeiter vs Detective Game

Imagine a world with two players locked in an eternal battle:
  • The Counterfeiter (Generator): Creates fake money, trying to make it indistinguishable from real currency
  • The Detective (Discriminator): Examines bills and tries to identify which are real and which are fake
At first, the counterfeiter is terrible — their fake bills look obviously fake. But with each rejection, they learn and improve. Meanwhile, the detective gets better at spotting fakes, forcing the counterfeiter to up their game. This is exactly how GANs work. The Generator never sees real data directly — it only receives gradient signals from the Discriminator telling it “you’re getting warmer” or “you’re getting colder.” Think of it like learning to paint while blindfolded, where your only feedback is a critic’s score. Over time, the Generator learns to produce outputs so realistic that the Discriminator can’t do better than flipping a coin.
The key mathematical insight: at the Nash equilibrium of this game, the Generator’s output distribution $p_g$ exactly matches the real data distribution $p_{data}$, and the Discriminator outputs $D(x) = 0.5$ for all inputs. In practice, we rarely reach this equilibrium perfectly, which is why GAN training is notoriously finicky.
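To make the equilibrium claim concrete, here is a minimal sketch using only the standard library. It relies on a standard result that this page does not derive: for a fixed generator, the best possible discriminator is $D^*(x) = p_{data}(x) / (p_{data}(x) + p_g(x))$. The 1-D Gaussian densities and the names `gaussian_pdf` and `optimal_discriminator` are illustrative assumptions, not part of the tutorial's code.

```python
import math


def gaussian_pdf(x, mean, std):
    """Density of a 1-D normal distribution at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2 * math.pi))


def optimal_discriminator(x, p_data, p_g):
    """D*(x) = p_data(x) / (p_data(x) + p_g(x)) for a fixed generator."""
    real, fake = p_data(x), p_g(x)
    return real / (real + fake)


# Toy 1-D "real" distribution (hypothetical, for illustration only).
def p_data(x):
    return gaussian_pdf(x, mean=0.0, std=1.0)


# A poor generator whose samples are shifted away from the real data.
def p_g_bad(x):
    return gaussian_pdf(x, mean=3.0, std=1.0)


# A perfect generator has exactly the real density.
p_g_perfect = p_data

for x in (-1.0, 0.0, 1.5, 3.0):
    d_bad = optimal_discriminator(x, p_data, p_g_bad)
    d_eq = optimal_discriminator(x, p_data, p_g_perfect)
    print(f"x={x:+.1f}  D* vs shifted G: {d_bad:.3f}  D* at equilibrium: {d_eq:.3f}")
```

Against the shifted generator, the ideal discriminator is confident almost everywhere. Once the two densities coincide, it can do no better than 0.5 at every point.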
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

GAN Architecture Overview

GAN Architecture - Generator and Discriminator
A GAN consists of two neural networks trained simultaneously:
| Component | Role | Input | Output |
| --- | --- | --- | --- |
| Generator (G) | Creates fake data | Random noise $z$ | Fake sample $G(z)$ |
| Discriminator (D) | Classifies real/fake | Sample $x$ | Probability $D(x)$ |
The networks compete in a minimax game - the generator tries to fool the discriminator, while the discriminator tries to catch the generator.
class SimpleGenerator(nn.Module):
    """
    Generator: Maps random noise to data space.
    Think of it as learning to "draw" realistic images from scratch.
    The noise vector z is like a set of "knobs" -- each dimension controls
    some learned aspect of the output (pose, lighting, style, etc.).
    """
    def __init__(self, latent_dim=100, output_dim=784):
        super().__init__()
        self.latent_dim = latent_dim
        
        self.model = nn.Sequential(
            # Progressively expand from low-dim noise to high-dim image space.
            # Each layer adds capacity for the generator to learn richer mappings.
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),       # LeakyReLU prevents "dead neuron" problem
            nn.BatchNorm1d(256),      # Stabilizes training by normalizing activations
            
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            
            nn.Linear(1024, output_dim),
            nn.Tanh()  # Output in [-1, 1] to match normalized image range
        )
    
    def forward(self, z):
        return self.model(z)


class SimpleDiscriminator(nn.Module):
    """
    Discriminator: Binary classifier - real or fake?
    Outputs probability that input is real.
    Unlike the Generator, this simple Discriminator omits BatchNorm and relies on
    Dropout instead; BN in D couples samples within a batch, which can destabilize training.
    """
    def __init__(self, input_dim=784):
        super().__init__()
        
        self.model = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LeakyReLU(0.2),       # LeakyReLU in D is critical -- ReLU causes dead neurons
            nn.Dropout(0.3),          # Dropout regularizes D to avoid overpowering G
            
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(256, 1),
            nn.Sigmoid()  # Output probability [0, 1]: 1 = "I think this is real"
        )
    
    def forward(self, x):
        return self.model(x)


# Initialize networks
latent_dim = 100
generator = SimpleGenerator(latent_dim=latent_dim).to(device)
discriminator = SimpleDiscriminator().to(device)

print(f"Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
Output:
Generator parameters: 1,489,936
Discriminator parameters: 533,505

The Minimax Loss Function

The GAN training objective is a minimax game:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

Intuition behind the math: The Discriminator is playing a classification game in which it wants to output 1 for real samples and 0 for fakes. The logarithm amplifies mistakes: $\log(D(x))$ punishes D heavily when it assigns low probability to a real sample (e.g., $\log(0.01) = -4.6$), but barely rewards it for being right ($\log(0.99) = -0.01$). This asymmetric penalty is what drives both networks to improve rapidly.
| Term | Meaning | What It Does |
| --- | --- | --- |
| $\mathbb{E}_{x \sim p_{data}}[\log D(x)]$ | Expected log probability for real data | D wants this HIGH (correctly identify real) |
| $\mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$ | Expected log probability for fake data | D wants this HIGH, G wants this LOW |
GAN Minimax Game Visualization
The Discriminator wants to maximize $V(D,G)$:
  • Correctly classify real images → $D(x) \approx 1$
  • Correctly classify fake images → $D(G(z)) \approx 0$
The Generator wants to minimize $V(D,G)$:
  • Fool the discriminator → $D(G(z)) \approx 1$
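A quick worked example may help here. The sketch below evaluates a single-sample estimate of $V(D,G)$ for three hand-picked situations. The probabilities 0.9, 0.5, and 0.1 are arbitrary illustrative values, and `value_fn` is a hypothetical helper, not part of the tutorial's code.

```python
import math


def value_fn(d_real, d_fake):
    """Single-sample estimate of V(D, G) = log D(x) + log(1 - D(G(z)))."""
    return math.log(d_real) + math.log(1.0 - d_fake)


# D is confident and correct: real -> 0.9, fake -> 0.1
v_d_winning = value_fn(d_real=0.9, d_fake=0.1)
# D cannot tell the difference and outputs 0.5 for everything
v_equilibrium = value_fn(d_real=0.5, d_fake=0.5)
# G fools D: fake samples are scored 0.9
v_g_winning = value_fn(d_real=0.9, d_fake=0.9)

print(f"D winning:   V = {v_d_winning:.3f}")
print(f"Equilibrium: V = {v_equilibrium:.3f}")
print(f"G winning:   V = {v_g_winning:.3f}")
```

The ordering is the point: $V$ is largest when D classifies both kinds of sample correctly and smallest when G's fakes are scored as real. The equilibrium value of $-\log 4$ sits in between.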
def discriminator_loss(real_output, fake_output):
    """
    Discriminator wants to:
    - Output 1 for real images (so log(D(x)) is high)
    - Output 0 for fake images (so log(1 - D(G(z))) is high)
    """
    real_loss = nn.BCELoss()(real_output, torch.ones_like(real_output))
    fake_loss = nn.BCELoss()(fake_output, torch.zeros_like(fake_output))
    return real_loss + fake_loss


def generator_loss(fake_output):
    """
    Generator wants discriminator to output 1 for fake images.
    We use the "non-saturating" loss: instead of minimizing log(1-D(G(z))),
    we maximize log(D(G(z))). Why? Early in training D easily rejects G's
    outputs, so (1-D(G(z))) is close to 1 and log(1-D(G(z))) is near 0 --
    the gradient vanishes. Flipping the objective gives much stronger gradients.
    """
    return nn.BCELoss()(fake_output, torch.ones_like(fake_output))


# Alternative: Using BCEWithLogitsLoss for numerical stability
class GANLoss:
    def __init__(self):
        self.criterion = nn.BCEWithLogitsLoss()
    
    def discriminator_loss(self, real_logits, fake_logits):
        real_labels = torch.ones_like(real_logits)
        fake_labels = torch.zeros_like(fake_logits)
        
        real_loss = self.criterion(real_logits, real_labels)
        fake_loss = self.criterion(fake_logits, fake_labels)
        return real_loss + fake_loss
    
    def generator_loss(self, fake_logits):
        real_labels = torch.ones_like(fake_logits)
        return self.criterion(fake_logits, real_labels)
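The docstring of `generator_loss` argues that the original minimax generator loss gives vanishing gradients early in training. The sketch below checks that numerically, assuming the discriminator output is `sigmoid(a)` for a pre-sigmoid logit `a`. The derivative formulas follow from the chain rule, and the function names are illustrative only.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def saturating_grad(a):
    """d/da of log(1 - sigmoid(a)): the original minimax generator objective."""
    return -sigmoid(a)


def non_saturating_grad(a):
    """d/da of -log(sigmoid(a)): the non-saturating generator loss."""
    return -(1.0 - sigmoid(a))


# Negative logits mean D confidently rejects the fake sample (early training).
for logit in (-6.0, -2.0, 0.0, 2.0):
    print(
        f"logit={logit:+.1f}  D(G(z))={sigmoid(logit):.4f}  "
        f"saturating grad={saturating_grad(logit):+.4f}  "
        f"non-saturating grad={non_saturating_grad(logit):+.4f}"
    )
```

When D confidently rejects a fake, the saturating gradient with respect to the logit is close to zero while the non-saturating one is close to -1. This is why the flipped objective trains faster at the start.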

Complete GAN Training Loop

Let’s train a GAN on MNIST digits:
# Data preparation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

train_dataset = datasets.MNIST(
    root='./data', 
    train=True, 
    download=True, 
    transform=transform
)

batch_size = 128
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Hyperparameters -- these are the DCGAN defaults that work surprisingly well
lr = 0.0002        # Lower than typical Adam lr (1e-3) -- GANs need gentle updates
beta1 = 0.5        # Lower momentum than default (0.9) -- prevents oscillation in adversarial dynamics
num_epochs = 50

# Optimizers - Adam with specific betas for GANs
# The (0.5, 0.999) betas were found empirically by the DCGAN authors and remain
# the go-to starting point. Higher beta1 causes training to oscillate.
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# Loss function
criterion = nn.BCELoss()


def train_gan(generator, discriminator, dataloader, num_epochs):
    """
    Complete GAN training loop with monitoring.
    """
    g_losses = []
    d_losses = []
    
    # Fixed noise for visualization
    fixed_noise = torch.randn(64, latent_dim, device=device)
    
    for epoch in range(num_epochs):
        g_loss_epoch = 0
        d_loss_epoch = 0
        
        for batch_idx, (real_images, _) in enumerate(dataloader):
            batch_size = real_images.size(0)
            real_images = real_images.view(batch_size, -1).to(device)
            
            # Labels
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)
            
            # ==================
            # Train Discriminator
            # ==================
            d_optimizer.zero_grad()
            
            # Real images
            real_output = discriminator(real_images)
            d_loss_real = criterion(real_output, real_labels)
            
            # Fake images
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            fake_output = discriminator(fake_images.detach())  # detach() is critical: we train D without backpropagating into G
            d_loss_fake = criterion(fake_output, fake_labels)
            
            # Total discriminator loss
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            d_optimizer.step()
            
            # ===============
            # Train Generator
            # ===============
            g_optimizer.zero_grad()
            
            # Generate new fake images
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            fake_output = discriminator(fake_images)
            
            # Generator wants discriminator to think fakes are real
            g_loss = criterion(fake_output, real_labels)
            g_loss.backward()
            g_optimizer.step()
            
            g_loss_epoch += g_loss.item()
            d_loss_epoch += d_loss.item()
        
        # Average losses
        g_loss_epoch /= len(dataloader)
        d_loss_epoch /= len(dataloader)
        g_losses.append(g_loss_epoch)
        d_losses.append(d_loss_epoch)
        
        # Print progress
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] | "
                  f"D Loss: {d_loss_epoch:.4f} | G Loss: {g_loss_epoch:.4f}")
            
            # Generate sample images (eval mode so BatchNorm uses its running statistics)
            generator.eval()
            with torch.no_grad():
                sample_images = generator(fixed_noise).view(-1, 1, 28, 28)
                visualize_samples(sample_images[:16], epoch + 1)
            generator.train()
    
    return g_losses, d_losses


def visualize_samples(images, epoch):
    """Display generated images in a grid."""
    fig, axes = plt.subplots(4, 4, figsize=(8, 8))
    for i, ax in enumerate(axes.flat):
        img = images[i].cpu().squeeze().numpy()
        img = (img + 1) / 2  # Denormalize from [-1, 1] to [0, 1]
        ax.imshow(img, cmap='gray')
        ax.axis('off')
    plt.suptitle(f'Generated Samples - Epoch {epoch}')
    plt.tight_layout()
    plt.savefig(f'gan_samples_epoch_{epoch}.png')
    plt.close()


# Train the GAN
print("Starting GAN training...")
g_losses, d_losses = train_gan(generator, discriminator, dataloader, num_epochs)
Output:
Starting GAN training...
Epoch [10/50] | D Loss: 0.8234 | G Loss: 1.4521
Epoch [20/50] | D Loss: 0.6912 | G Loss: 1.2847
Epoch [30/50] | D Loss: 0.5823 | G Loss: 1.5234
Epoch [40/50] | D Loss: 0.5412 | G Loss: 1.6891
Epoch [50/50] | D Loss: 0.5234 | G Loss: 1.7234
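When reading loss logs like these, it helps to know what the two losses would be at the theoretical equilibrium, where the discriminator outputs 0.5 for every input. The short calculation below derives those reference values from the binary cross-entropy definition. `bce` is a hypothetical scalar helper written for this illustration, not the `nn.BCELoss` module itself.

```python
import math


def bce(prediction, target):
    """Binary cross-entropy for a single prediction in (0, 1)."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))


d_out = 0.5  # discriminator output at equilibrium, for real and fake inputs alike

# D loss = BCE(real, 1) + BCE(fake, 0); G loss = BCE(fake, 1)
d_loss_eq = bce(d_out, 1.0) + bce(d_out, 0.0)
g_loss_eq = bce(d_out, 1.0)

print(f"Equilibrium D loss: {d_loss_eq:.4f}  (2 * ln 2)")
print(f"Equilibrium G loss: {g_loss_eq:.4f}  (ln 2)")
```

A D loss well below $2 \ln 2$ together with a G loss well above $\ln 2$ suggests the discriminator is currently ahead. That is one plausible reading of the log above, not a guarantee about sample quality.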

Mode Collapse: The GAN’s Achilles Heel

Mode Collapse in GANs
Mode collapse occurs when the generator produces only a limited variety of outputs, essentially “memorizing” a few samples that fool the discriminator. Think of it this way: if you’re a student trying to pass an exam, and you discover the teacher always accepts the same essay, you’d stop writing anything else. The Generator does the same thing — it finds one “safe” output that consistently gets a high score from the Discriminator, then maps every noise vector to that output (or a small cluster of outputs). The result? A generator that can only produce three types of faces, or always generates the digit “1” regardless of input noise.

Why Does Mode Collapse Happen?

| Cause | Explanation |
| --- | --- |
| Easy shortcut | Generator finds one mode that consistently fools D |
| Discriminator forgetting | D gets fooled by one mode, G exploits it |
| Training imbalance | G or D becomes too strong too quickly |
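def detect_mode_collapse(generator, num_samples=1000, threshold=0.1):
    """
    Detect mode collapse by checking diversity of generated samples.
    The threshold is data-dependent: calibrate it against the mean
    pairwise distance of real samples at the same scale.
    """
    was_training = generator.training
    generator.eval()
    with torch.no_grad():
        noise = torch.randn(num_samples, latent_dim, device=device)
        samples = generator(noise).cpu().numpy()
    generator.train(was_training)  # Restore the caller's train/eval mode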
    
    # Calculate pairwise distances
    from scipy.spatial.distance import pdist
    distances = pdist(samples, metric='euclidean')
    
    mean_distance = np.mean(distances)
    std_distance = np.std(distances)
    
    # Low mean distance indicates mode collapse
    diversity_score = mean_distance / (std_distance + 1e-8)
    
    print(f"Mean pairwise distance: {mean_distance:.4f}")
    print(f"Std pairwise distance: {std_distance:.4f}")
    print(f"Diversity score: {diversity_score:.4f}")
    
    if mean_distance < threshold:
        print("⚠️ Warning: Possible mode collapse detected!")
        return True
    return False


# Solutions to mode collapse
class AntiModeCollapseTraining:
    """
    Techniques to prevent mode collapse.
    """
    
    @staticmethod
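    def minibatch_discrimination(features, num_kernels=5, kernel_dim=3):
        """
        Compute features based on relationships with other samples.
        Helps discriminator catch mode collapse.
        Note: T is re-sampled on every call here, so it is never trained.
        In a real discriminator, register T once as a module parameter.
        """
        batch_size = features.size(0)
        # Create T on the same device as the features to avoid a CPU/GPU mismatch
        T = torch.randn(features.size(1), num_kernels * kernel_dim, device=features.device)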
        
        M = features @ T
        M = M.view(batch_size, num_kernels, kernel_dim)
        
        # Compute L1 distance between all pairs
        M_expanded = M.unsqueeze(0)  # [1, B, K, D]
        M_T = M.unsqueeze(1)  # [B, 1, K, D]
        
        L1_distance = torch.abs(M_expanded - M_T).sum(dim=3)  # [B, B, K]
        
        # Compute similarity
        similarity = torch.exp(-L1_distance)
        # Exclude self-similarity
        mask = 1 - torch.eye(batch_size, device=features.device)
        similarity = similarity * mask.unsqueeze(2)
        
        o = similarity.sum(dim=1)  # [B, K]
        return o
    
    @staticmethod
    def feature_matching_loss(real_features, fake_features):
        """
        Match statistics of discriminator features instead of output.
        """
        return torch.mean((real_features.mean(0) - fake_features.mean(0)) ** 2)
    
    @staticmethod
    def historical_averaging(current_params, historical_params, beta=0.99):
        """
        Penalize large deviations from historical parameter averages.
        """
        loss = 0
        for curr, hist in zip(current_params, historical_params):
            loss += torch.mean((curr - hist) ** 2)
        return loss
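For minibatch discrimination to help, the projection tensor `T` has to be learned alongside the discriminator. The sketch below wraps the same computation in an `nn.Module` so that `T` is a registered parameter. The class name, the 0.1 initialization scale, and the choice to concatenate the statistics onto the input features are assumptions made for this illustration.

```python
import torch
import torch.nn as nn


class MinibatchDiscrimination(nn.Module):
    """Hypothetical module form: T is registered, so an optimizer can update it."""

    def __init__(self, in_features, num_kernels=5, kernel_dim=3):
        super().__init__()
        self.num_kernels = num_kernels
        self.kernel_dim = kernel_dim
        self.T = nn.Parameter(torch.randn(in_features, num_kernels * kernel_dim) * 0.1)

    def forward(self, features):
        batch_size = features.size(0)
        M = (features @ self.T).view(batch_size, self.num_kernels, self.kernel_dim)
        # Pairwise L1 distances between samples, per kernel: [B, B, K]
        l1 = (M.unsqueeze(0) - M.unsqueeze(1)).abs().sum(dim=3)
        similarity = torch.exp(-l1)
        # Zero out each sample's similarity to itself
        mask = 1.0 - torch.eye(batch_size, device=features.device)
        o = (similarity * mask.unsqueeze(2)).sum(dim=1)  # [B, K]
        # Append the batch-level statistics to the per-sample features
        return torch.cat([features, o], dim=1)


torch.manual_seed(0)
layer = MinibatchDiscrimination(in_features=16)

diverse = torch.randn(8, 16)                  # eight different feature vectors
collapsed = torch.randn(1, 16).repeat(8, 1)   # eight copies of one vector

diverse_stats = layer(diverse)[:, 16:]
collapsed_stats = layer(collapsed)[:, 16:]
print("Mean similarity, diverse batch:  ", diverse_stats.mean().item())
print("Mean similarity, collapsed batch:", collapsed_stats.mean().item())
```

A collapsed batch produces the maximum possible similarity (batch size minus one per kernel). Layers after this one in the discriminator can use that as evidence that the batch is fake.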

DCGAN: Deep Convolutional GAN

DCGAN Architecture with Convolutions
DCGANs use convolutional layers for better image generation. Key architectural guidelines:
| Guideline | Generator | Discriminator |
| --- | --- | --- |
| Pooling | Use transposed conv | Use strided conv |
| Normalization | BatchNorm (except output) | BatchNorm (except input) |
| Activation | ReLU (except output: Tanh) | LeakyReLU |
| Architecture | Fully convolutional | Fully convolutional |
class DCGANGenerator(nn.Module):
    """
    DCGAN Generator using transposed convolutions.
    Maps noise vector to 64x64 RGB image.
    """
    def __init__(self, latent_dim=100, feature_maps=64, channels=3):
        super().__init__()
        self.latent_dim = latent_dim
        
        self.model = nn.Sequential(
            # Input: latent_dim x 1 x 1
            nn.ConvTranspose2d(latent_dim, feature_maps * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.ReLU(True),
            # State: (feature_maps*8) x 4 x 4
            
            nn.ConvTranspose2d(feature_maps * 8, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.ReLU(True),
            # State: (feature_maps*4) x 8 x 8
            
            nn.ConvTranspose2d(feature_maps * 4, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.ReLU(True),
            # State: (feature_maps*2) x 16 x 16
            
            nn.ConvTranspose2d(feature_maps * 2, feature_maps, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps),
            nn.ReLU(True),
            # State: feature_maps x 32 x 32
            
            nn.ConvTranspose2d(feature_maps, channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # Output: channels x 64 x 64
        )
        
        self._init_weights()
    
    def _init_weights(self):
        """Initialize weights as per DCGAN paper."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.zeros_(m.bias)
    
    def forward(self, z):
        # Reshape to [batch, latent_dim, 1, 1]
        z = z.view(-1, self.latent_dim, 1, 1)
        return self.model(z)


class DCGANDiscriminator(nn.Module):
    """
    DCGAN Discriminator using strided convolutions.
    Classifies 64x64 RGB images as real or fake.
    """
    def __init__(self, feature_maps=64, channels=3):
        super().__init__()
        
        self.model = nn.Sequential(
            # Input: channels x 64 x 64
            nn.Conv2d(channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State: feature_maps x 32 x 32
            
            nn.Conv2d(feature_maps, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_maps*2) x 16 x 16
            
            nn.Conv2d(feature_maps * 2, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_maps*4) x 8 x 8
            
            nn.Conv2d(feature_maps * 4, feature_maps * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_maps*8) x 4 x 4
            
            nn.Conv2d(feature_maps * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # Output: 1 x 1 x 1
        )
        
        self._init_weights()
    
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.zeros_(m.bias)
    
    def forward(self, x):
        return self.model(x).view(-1, 1)


# Instantiate DCGAN
dc_generator = DCGANGenerator(latent_dim=100).to(device)
dc_discriminator = DCGANDiscriminator().to(device)

# Test forward pass
test_noise = torch.randn(4, 100, device=device)
test_images = dc_generator(test_noise)
print(f"Generator output shape: {test_images.shape}")  # [4, 3, 64, 64]

test_pred = dc_discriminator(test_images)
print(f"Discriminator output shape: {test_pred.shape}")  # [4, 1]
Output:
Generator output shape: torch.Size([4, 3, 64, 64])
Discriminator output shape: torch.Size([4, 1])
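The `# State:` comments in both networks can be verified by hand with the output-size formulas for `nn.ConvTranspose2d` and `nn.Conv2d`. The sketch below assumes `dilation=1` and `output_padding=0`, which matches the layers above. The helper names are illustrative.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Spatial output size of nn.ConvTranspose2d (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


def conv_out(size, kernel, stride, padding):
    """Spatial output size of nn.Conv2d (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) for each layer, copied from the DCGAN classes above
generator_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
discriminator_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

gen_sizes = [1]  # the latent vector is reshaped to latent_dim x 1 x 1
for kernel, stride, padding in generator_layers:
    gen_sizes.append(conv_transpose_out(gen_sizes[-1], kernel, stride, padding))

disc_sizes = [64]  # the discriminator receives 64 x 64 images
for kernel, stride, padding in discriminator_layers:
    disc_sizes.append(conv_out(disc_sizes[-1], kernel, stride, padding))

print("Generator spatial sizes:    ", gen_sizes)
print("Discriminator spatial sizes:", disc_sizes)
```

The same two helpers are handy when adapting the architecture to other resolutions, for example to check how many stride-2 stages a 32x32 or 128x128 variant would need.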

Wasserstein GAN (WGAN)

Wasserstein Distance vs JS Divergence
WGAN addresses training instability by using the Wasserstein distance (Earth Mover’s Distance) instead of JS divergence.

Why Wasserstein Distance?

The analogy: Imagine you have two piles of sand (the real and generated distributions) and you want to measure how different they are. The Wasserstein distance measures the minimum amount of “work” (mass times distance) needed to reshape one pile into the other. Unlike JS divergence, which stays pinned at $\log 2$ whenever the two distributions don’t overlap (and drops to 0 only when they coincide), the Wasserstein distance changes smoothly, giving the generator useful gradient signal even when the discriminator can perfectly separate real from fake.

$$W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma}[\|x - y\|]$$
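The difference between the two measures is easy to see on a toy problem. The sketch below places a "real" point mass at position 0 on a small grid, moves a "generated" point mass toward it, and computes both quantities. The grid, the helper names, and the CDF-based formula for the 1-D Wasserstein distance are assumptions of this illustration.

```python
import math


def js_divergence(p, q):
    """Jensen-Shannon divergence between two discrete distributions (natural log)."""
    def kl(a, b):
        return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)

    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def wasserstein_1d(p, q, spacing=1.0):
    """Earth Mover's Distance on an evenly spaced 1-D grid, via the CDF difference."""
    cdf_p = cdf_q = total = 0.0
    for pi, qi in zip(p, q):
        cdf_p += pi
        cdf_q += qi
        total += abs(cdf_p - cdf_q) * spacing
    return total


def point_mass(position, size=11):
    """All probability mass on one grid cell."""
    return [1.0 if i == position else 0.0 for i in range(size)]


real = point_mass(0)
for shift in (8, 4, 1, 0):
    fake = point_mass(shift)
    print(
        f"shift={shift}  JS={js_divergence(real, fake):.4f}  "
        f"W={wasserstein_1d(real, fake):.1f}"
    )
```

JS reports the same value for every non-zero shift, so it cannot tell the generator that moving from 8 to 1 is progress. The Wasserstein distance shrinks in proportion to the shift.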
Training tip: The Wasserstein distance is computed via the Kantorovich-Rubinstein duality, which requires the critic to be 1-Lipschitz (its gradient norm is at most 1 everywhere). The original WGAN enforced this with weight clipping, but WGAN-GP (gradient penalty) is usually preferred in practice because it avoids the capacity underuse and exploding/vanishing gradients that clipping can cause.
| Advantage | Explanation |
| --- | --- |
| Continuous gradients | Provides gradients even when distributions don’t overlap |
| Meaningful loss | Loss correlates with image quality |
| Stable training | No need for careful balance between G and D |
class WGANCritic(nn.Module):
    """
    WGAN uses a 'Critic' instead of 'Discriminator'.
    No sigmoid -- outputs unbounded real numbers (Wasserstein distance is unconstrained).
    The name changes from "discriminator" to "critic" because it no longer classifies
    real/fake; instead it scores how "real" something looks on a continuous scale.
    """
    def __init__(self, feature_maps=64, channels=3):
        super().__init__()
        
        # The gradient penalty is computed per sample, so the WGAN-GP paper advises
        # against BatchNorm in the critic (it couples samples within a batch).
        # GroupNorm with one group normalizes each sample over (C, H, W), i.e. layer norm.
        self.model = nn.Sequential(
            nn.Conv2d(channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(feature_maps, feature_maps * 2, 4, 2, 1, bias=False),
            nn.GroupNorm(1, feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(feature_maps * 2, feature_maps * 4, 4, 2, 1, bias=False),
            nn.GroupNorm(1, feature_maps * 4),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(feature_maps * 4, feature_maps * 8, 4, 2, 1, bias=False),
            nn.GroupNorm(1, feature_maps * 8),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(feature_maps * 8, 1, 4, 1, 0, bias=False),
            # No sigmoid! Unbounded output
        )
    
    def forward(self, x):
        return self.model(x).view(-1)


def wgan_critic_loss(real_output, fake_output):
    """
    WGAN Critic Loss: Maximize E[C(real)] - E[C(fake)]
    Equivalent to minimizing: E[C(fake)] - E[C(real)]
    """
    return fake_output.mean() - real_output.mean()


def wgan_generator_loss(fake_output):
    """
    WGAN Generator Loss: Minimize -E[C(fake)]
    Equivalent to maximizing E[C(fake)]
    """
    return -fake_output.mean()


def gradient_penalty(critic, real_data, fake_data, device):
    """
    WGAN-GP: Gradient penalty for 1-Lipschitz constraint.
    Instead of clipping weights (which underuses network capacity),
    we penalize the gradient norm along random interpolations between
    real and fake data. The penalty softly pushes ||grad(C(x))|| toward 1
    at those points; it encourages, but does not guarantee, a 1-Lipschitz critic.
    """
    batch_size = real_data.size(0)
    
    # Random interpolation coefficient -- sample along the line between real and fake
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    
    # Interpolate between real and fake
    interpolated = alpha * real_data + (1 - alpha) * fake_data
    interpolated.requires_grad_(True)
    
    # Get critic score for interpolated
    critic_interpolated = critic(interpolated)
    
    # Compute gradients
    gradients = torch.autograd.grad(
        outputs=critic_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_interpolated),
        create_graph=True,
        retain_graph=True
    )[0]
    
    # Gradient norm
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    
    # Penalty: (||grad|| - 1)^2
    penalty = ((gradient_norm - 1) ** 2).mean()
    
    return penalty


def train_wgan_gp(generator, critic, dataloader, num_epochs, 
                   n_critic=5, lambda_gp=10):
    """
    Training loop for WGAN with gradient penalty.
    """
    # WGAN uses lower lr and different betas than DCGAN: beta1=0.0 disables momentum
    # entirely, which prevents the critic from "remembering" outdated gradient directions
    g_optimizer = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.0, 0.9))
    c_optimizer = optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))
    
    for epoch in range(num_epochs):
        for batch_idx, (real_images, _) in enumerate(dataloader):
            real_images = real_images.to(device)
            batch_size = real_images.size(0)
            
            # Train Critic n_critic times per generator step.
            # Why? The Wasserstein estimate is only accurate when the critic is close
            # to optimal. Training the critic more often keeps the gradient signal to G meaningful.
            for _ in range(n_critic):
                c_optimizer.zero_grad()
                
                # Generate fake images
                noise = torch.randn(batch_size, latent_dim, device=device)
                fake_images = generator(noise)
                
                # Critic scores
                real_validity = critic(real_images)
                fake_validity = critic(fake_images.detach())
                
                # Gradient penalty (detach fakes so the penalty does not backpropagate into G)
                gp = gradient_penalty(critic, real_images, fake_images.detach(), device)
                
                # Critic loss
                c_loss = wgan_critic_loss(real_validity, fake_validity) + lambda_gp * gp
                c_loss.backward()
                c_optimizer.step()
            
            # Train Generator
            g_optimizer.zero_grad()
            
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            fake_validity = critic(fake_images)
            
            g_loss = wgan_generator_loss(fake_validity)
            g_loss.backward()
            g_optimizer.step()
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] | "
                  f"C Loss: {c_loss.item():.4f} | G Loss: {g_loss.item():.4f}")
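A useful sanity check for any gradient-penalty implementation is a critic whose input gradient is known in closed form. For a linear critic $C(x) = \sum_i w_i x_i$, the gradient is $w$ at every point, so the penalty must equal $(\|w\| - 1)^2$ regardless of the data. The sketch below restates the penalty compactly so that it runs on its own. `gradient_penalty_check` and `linear_critic` are hypothetical names for this test only.

```python
import torch


def gradient_penalty_check(critic, real, fake):
    """Compact restatement of the WGAN-GP penalty for 4-D image batches."""
    alpha = torch.rand(real.size(0), 1, 1, 1)
    interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interpolated)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )[0]
    grad_norm = grads.reshape(real.size(0), -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()


# Linear critic C(x) = sum(w * x) with ||w|| = 3, so the expected penalty is (3 - 1)^2.
w = torch.zeros(1, 2, 2)
w[0, 0, 0] = 3.0


def linear_critic(x):
    return (x * w).flatten(1).sum(dim=1)


torch.manual_seed(0)
real = torch.randn(4, 1, 2, 2)
fake = torch.randn(4, 1, 2, 2)
penalty = gradient_penalty_check(linear_critic, real, fake)
print(f"Penalty for ||w|| = 3: {penalty.item():.4f}")
```

If a penalty function returns anything other than 4 for this critic, the likely culprits are a norm taken over the wrong dimension or gradients computed with respect to the wrong tensor.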

Conditional GANs (cGAN)

Conditional GAN Architecture
Conditional GANs allow us to control what we generate by conditioning on additional information (class labels, text, images).

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x|y)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z|y)|y))]$$
class ConditionalGenerator(nn.Module):
    """
    Generator conditioned on class labels.
    Generate specific digit: "Generate a 7"
    """
    def __init__(self, latent_dim=100, num_classes=10, embed_dim=50, img_size=28):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_size = img_size
        
        # Embedding for class labels
        self.label_embedding = nn.Embedding(num_classes, embed_dim)
        
        # Input: noise + embedded label
        self.model = nn.Sequential(
            nn.Linear(latent_dim + embed_dim, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            
            nn.Linear(1024, img_size * img_size),
            nn.Tanh()
        )
    
    def forward(self, z, labels):
        # Embed labels
        label_embedding = self.label_embedding(labels)
        
        # Concatenate noise and label embedding
        gen_input = torch.cat([z, label_embedding], dim=1)
        
        return self.model(gen_input)


class ConditionalDiscriminator(nn.Module):
    """
    Discriminator that also receives class labels.
    """
    def __init__(self, num_classes=10, embed_dim=50, img_size=28):
        super().__init__()
        self.img_size = img_size
        
        # Embedding for class labels
        self.label_embedding = nn.Embedding(num_classes, embed_dim)
        
        # Input: flattened image + embedded label
        self.model = nn.Sequential(
            nn.Linear(img_size * img_size + embed_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, img, labels):
        # Flatten image
        img_flat = img.view(img.size(0), -1)
        
        # Embed labels
        label_embedding = self.label_embedding(labels)
        
        # Concatenate image and label embedding
        d_input = torch.cat([img_flat, label_embedding], dim=1)
        
        return self.model(d_input)


# Example: Generate specific digits
def generate_digit(generator, digit, num_samples=16):
    """Generate samples of a specific digit."""
    generator.eval()
    with torch.no_grad():
        noise = torch.randn(num_samples, latent_dim, device=device)
        labels = torch.full((num_samples,), digit, dtype=torch.long, device=device)
        generated = generator(noise, labels)
    return generated.view(-1, 1, 28, 28)


# Initialize conditional GAN
cond_generator = ConditionalGenerator().to(device)
cond_discriminator = ConditionalDiscriminator().to(device)

# Generate all digits 0-9
print("Conditional GAN allows generating specific classes!")
for digit in range(10):
    samples = generate_digit(cond_generator, digit, num_samples=8)
    print(f"Generated {samples.shape[0]} samples of digit {digit}")
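The conditional networks above are only instantiated, not trained. As a rough sketch of what one training step could look like, the block below pairs every image with its label in both the discriminator and generator updates. To stay self-contained and fast on CPU it uses two tiny stand-in networks (`TinyCondG`, `TinyCondD`), logits with `BCEWithLogitsLoss`, and a random batch in place of MNIST. All of these are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes for a quick CPU-only illustration.
LATENT_DIM, NUM_CLASSES, EMBED_DIM, IMG_DIM = 16, 10, 8, 28 * 28


class TinyCondG(nn.Module):
    """Small stand-in for ConditionalGenerator (noise + label embedding in)."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(NUM_CLASSES, EMBED_DIM)
        self.net = nn.Sequential(
            nn.Linear(LATENT_DIM + EMBED_DIM, 64), nn.LeakyReLU(0.2),
            nn.Linear(64, IMG_DIM), nn.Tanh(),
        )

    def forward(self, z, labels):
        return self.net(torch.cat([z, self.embed(labels)], dim=1))


class TinyCondD(nn.Module):
    """Small stand-in for ConditionalDiscriminator; returns logits (no Sigmoid)."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(NUM_CLASSES, EMBED_DIM)
        self.net = nn.Sequential(
            nn.Linear(IMG_DIM + EMBED_DIM, 64), nn.LeakyReLU(0.2),
            nn.Linear(64, 1),
        )

    def forward(self, img, labels):
        flat = img.view(img.size(0), -1)
        return self.net(torch.cat([flat, self.embed(labels)], dim=1))


def cgan_train_step(G, D, g_opt, d_opt, real_images, labels):
    """One discriminator update followed by one generator update."""
    criterion = nn.BCEWithLogitsLoss()
    batch = real_images.size(0)
    ones, zeros = torch.ones(batch, 1), torch.zeros(batch, 1)

    # Discriminator: real (image, label) pairs -> 1, generated pairs with the same labels -> 0
    d_opt.zero_grad()
    fake_images = G(torch.randn(batch, LATENT_DIM), labels)
    d_loss = (criterion(D(real_images, labels), ones)
              + criterion(D(fake_images.detach(), labels), zeros))
    d_loss.backward()
    d_opt.step()

    # Generator: wants D to score (fake image, label) pairs as real
    g_opt.zero_grad()
    g_loss = criterion(D(fake_images, labels), ones)
    g_loss.backward()
    g_opt.step()
    return d_loss.item(), g_loss.item()


G, D = TinyCondG(), TinyCondD()
g_opt = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
d_opt = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Random stand-in batch in [-1, 1]; swap in real (image, label) batches from a DataLoader.
real_images = torch.rand(32, IMG_DIM) * 2 - 1
labels = torch.randint(0, NUM_CLASSES, (32,))
d_loss_val, g_loss_val = cgan_train_step(G, D, g_opt, d_opt, real_images, labels)
print(f"D loss: {d_loss_val:.4f} | G loss: {g_loss_val:.4f}")
```

The only structural difference from the unconditional loop is that labels are passed to both networks. The discriminator therefore learns to reject images that look realistic but do not match their label.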

GAN Evaluation Metrics

Evaluating GANs is Hard! Unlike supervised learning, there’s no clear “correct answer” to compare against. The generator loss alone is almost meaningless — it can decrease while image quality gets worse, or increase while images improve. Always use sample-based metrics alongside visual inspection.
| Metric | Measures | Formula/Approach |
| --- | --- | --- |
| Inception Score (IS) | Quality + Diversity | $\exp(\mathbb{E}_x[D_{KL}(p(y \mid x) \parallel p(y))])$ |
| FID (Fréchet Inception Distance) | Feature similarity | Compare Gaussian fits in feature space |
| LPIPS | Perceptual similarity | Distance in learned feature space |
def calculate_fid_simple(real_features, fake_features):
    """
    Simplified FID calculation.
    Lower FID = Better quality and diversity.
    
    Full FID uses InceptionV3 features.
    """
    # Calculate statistics
    mu_real = np.mean(real_features, axis=0)
    sigma_real = np.cov(real_features, rowvar=False)
    
    mu_fake = np.mean(fake_features, axis=0)
    sigma_fake = np.cov(fake_features, rowvar=False)
    
    # Calculate FID
    diff = mu_real - mu_fake
    
    # Matrix square root
    from scipy.linalg import sqrtm
    covmean = sqrtm(sigma_real @ sigma_fake)
    
    # Handle numerical issues
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    
    fid = diff @ diff + np.trace(sigma_real + sigma_fake - 2 * covmean)
    
    return fid


# Practical evaluation
def evaluate_gan(generator, dataloader, num_samples=1000):
    """
    Evaluate GAN using multiple metrics.
    """
    generator.eval()
    
    # Collect real and fake samples
    real_samples = []
    fake_samples = []
    
    with torch.no_grad():
        for real_images, _ in dataloader:
            real_samples.append(real_images.view(real_images.size(0), -1))
            if len(real_samples) * real_images.size(0) >= num_samples:
                break
        
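        noise = torch.randn(num_samples, latent_dim, device=device)
        fake_images = generator(noise)
        # Flatten to (N, features) so fake samples match the flattened real samples
        fake_samples.append(fake_images.view(fake_images.size(0), -1).cpu())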
    
    real_samples = torch.cat(real_samples, dim=0)[:num_samples].numpy()
    fake_samples = torch.cat(fake_samples, dim=0).numpy()
    
    # Calculate metrics
    fid = calculate_fid_simple(real_samples, fake_samples)
    
    # Diversity (average pairwise distance)
    from scipy.spatial.distance import pdist
    diversity = np.mean(pdist(fake_samples[:100]))
    
    print(f"FID Score: {fid:.2f} (lower is better)")
    print(f"Diversity Score: {diversity:.4f} (higher is better)")
    
    return {"fid": fid, "diversity": diversity}

Exercises

Label smoothing can stabilize GAN training. Instead of using hard labels (0 and 1), use soft labels (e.g., 0.1 and 0.9).
# TODO: Modify the training loop to use label smoothing
# Real labels: 0.9 instead of 1.0
# Fake labels: 0.1 instead of 0.0

def train_with_label_smoothing(generator, discriminator, dataloader, num_epochs):
    # Your implementation here
    pass
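As a starting hint for the label-smoothing exercise above, the sketch below only builds the smoothed targets and applies them to one discriminator loss term; the training loop itself is left to you. The helper name `smoothed_targets` and the toy discriminator outputs are assumptions made for illustration, not part of the original training code.

```python
import torch
import torch.nn as nn


def smoothed_targets(batch_size, real_value=0.9, fake_value=0.1, device="cpu"):
    """Return (real_targets, fake_targets), each of shape (batch_size, 1)."""
    real = torch.full((batch_size, 1), real_value, device=device)
    fake = torch.full((batch_size, 1), fake_value, device=device)
    return real, fake


criterion = nn.BCELoss()
real_t, fake_t = smoothed_targets(batch_size=4)

# Hypothetical sigmoid outputs of a discriminator on a batch of real images
d_out_real = torch.tensor([[0.80], [0.90], [0.70], [0.95]])
loss_real = criterion(d_out_real, real_t)
print(f"D loss on real batch with smoothed targets: {loss_real.item():.4f}")
```

A common variant, one-sided label smoothing, softens only the real targets and keeps the fake targets at 0.0; it may be worth comparing both when you complete the exercise.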
Implement a training schedule that gradually increases image resolution.
# TODO: Start with 4x4 images and progressively grow to 64x64
# Hint: Add layers during training

class ProgressiveGenerator(nn.Module):
    def __init__(self):
        # Your implementation here
        pass
Spectral normalization constrains the Lipschitz constant of discriminator layers.
# TODO: Implement spectral normalization for discriminator weights
# Hint: Normalize weights by their largest singular value

class SpectralNormDiscriminator(nn.Module):
    # Your implementation here
    pass

Key Takeaways

What You Learned:
  • GAN Fundamentals - Generator creates, Discriminator classifies in a minimax game
  • Minimax Loss - Adversarial training objective and its components
  • Mode Collapse - Common failure mode and solutions (minibatch discrimination, feature matching)
  • DCGAN - Convolutional architecture guidelines for stable training
  • WGAN - Wasserstein distance for improved training stability
  • Conditional GANs - Control generation with class labels or other conditioning
  • Evaluation - FID, Inception Score, and diversity metrics

Training Tips from the Trenches

GAN Training Mistakes to Avoid:
  1. Imbalanced training — Don’t let D or G become too strong too quickly. Monitor the ratio of D accuracy on real vs fake: if D accuracy hits 100% early, G will get vanishing gradients. A healthy D accuracy hovers around 60-80%.
  2. Ignoring mode collapse — Monitor sample diversity throughout training. Generate a grid of 64+ images every N epochs and visually inspect. If all outputs look nearly identical, you have mode collapse.
  3. Wrong learning rates — GANs are sensitive; start with the DCGAN defaults (lr=0.0002, betas=(0.5, 0.999)) and only deviate with reason.
  4. Batch size too small — BatchNorm needs sufficient batch size (at least 16, ideally 64+). Small batches cause noisy BN statistics that destabilize training.
  5. Not using proper initialization — DCGAN weight init (Normal(0, 0.02)) matters significantly. Xavier or He init, which work great for classifiers, can cause GAN training to diverge.
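Several of the points above (DCGAN-style initialization, the default Adam settings, and watching D accuracy) can be sketched in a few lines. This is a minimal illustration rather than a reference implementation: `dcgan_weights_init`, `discriminator_accuracy`, and the toy discriminator are hypothetical names, and initializing the BatchNorm scale around 1.0 is an assumption borrowed from common DCGAN practice.

```python
import torch
import torch.nn as nn
import torch.optim as optim


def dcgan_weights_init(module):
    """Normal(0, 0.02) for conv/linear weights; Normal(1, 0.02) for BatchNorm scale."""
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
        nn.init.normal_(module.weight, mean=1.0, std=0.02)
        nn.init.zeros_(module.bias)


def discriminator_accuracy(d_real, d_fake, threshold=0.5):
    """Mean of D's accuracy on real and fake batches (sigmoid outputs expected)."""
    correct_real = (d_real > threshold).float().mean()
    correct_fake = (d_fake <= threshold).float().mean()
    return 0.5 * (correct_real + correct_fake).item()


# Toy discriminator, used only to demonstrate the helpers
toy_d = nn.Sequential(
    nn.Linear(784, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid()
)
toy_d.apply(dcgan_weights_init)
d_optimizer = optim.Adam(toy_d.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Hypothetical D outputs on four real and four fake samples
acc = discriminator_accuracy(
    torch.tensor([0.9, 0.8, 0.3, 0.7]), torch.tensor([0.1, 0.6, 0.2, 0.4])
)
print(f"D accuracy: {acc:.2f}")
```

Logging this accuracy every few hundred steps gives an early warning when D starts to dominate, before the generated samples visibly degrade.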
Practical training checklist a senior engineer would follow:
  • Start with a known-working architecture (DCGAN or StyleGAN2) before experimenting
  • Log both G and D losses AND generated samples — losses alone are misleading
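  • Use torch.no_grad() when generating samples for visualization so autograd does not build and retain a computation graph, which wastes memory on activations you never backpropagate through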
  • If training with mixed precision (fp16), keep the discriminator in fp32 — the sigmoid output is numerically sensitive
  • Save checkpoints frequently: GAN training is non-monotonic, and the best checkpoint is often not the last one
  • Use FID on a held-out set as your primary quality metric, not visual inspection alone
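Because the best checkpoint is often not the last one, it can help to track the lowest FID seen so far and keep a copy of those weights. The sketch below is one possible way to do that; `BestCheckpointTracker` is a hypothetical helper, the stand-in generator is a single linear layer, and the FID values are made-up placeholders for illustration, not measured results.

```python
import copy

import torch.nn as nn


class BestCheckpointTracker:
    """Keep an in-memory copy of the generator weights with the lowest FID so far."""

    def __init__(self):
        self.best_fid = float("inf")
        self.best_step = None
        self.best_state = None

    def update(self, step, fid, generator):
        """Record the generator state if `fid` improves on the best value seen."""
        if fid < self.best_fid:
            self.best_fid = fid
            self.best_step = step
            # deepcopy so later optimizer steps do not mutate the saved tensors
            self.best_state = copy.deepcopy(generator.state_dict())
            return True
        return False


toy_generator = nn.Linear(8, 4)  # stand-in for a real generator
tracker = BestCheckpointTracker()

# Placeholder (step, FID) pairs showing non-monotonic training
for step, fid in [(1000, 85.0), (2000, 42.5), (3000, 57.1)]:
    improved = tracker.update(step, fid, toy_generator)
    print(f"step {step}: FID {fid:.1f} (new best: {improved})")
```

In a real run you would likely also write `tracker.best_state` to disk with `torch.save` so the best weights survive a crash.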

Interview Deep-Dive

Strong Answer:
  • The minimax objective is $\min_G \max_D \mathbb{E}[\log D(x)] + \mathbb{E}[\log(1 - D(G(z)))]$. The Discriminator maximizes this expression (correctly classifying real and fake), while the Generator minimizes it (fooling D).
  • Vanishing gradient problem: Early in training, the Generator produces obviously fake outputs. The Discriminator quickly learns to reject them with $D(G(z)) \approx 0$. The Generator’s gradient comes from $\log(1 - D(G(z)))$, which is flat in this regime: writing $D = \sigma(a)$ for the Discriminator’s logit $a$, the derivative with respect to $a$ is $-\sigma(a) = -D(G(z)) \approx 0$. The Generator receives almost no learning signal precisely when it needs the most guidance.
  • Non-saturating fix: Instead of minimizing $\log(1 - D(G(z)))$, the Generator maximizes $\log D(G(z))$. When $D(G(z)) \approx 0$, $\log D(G(z)) \rightarrow -\infty$ and its derivative with respect to the logit is $1 - D(G(z)) \approx 1$, the largest value it can take. The Generator now receives a strong signal to improve even when the Discriminator easily rejects its outputs.
  • A senior engineer would note: the non-saturating loss changes the optimization landscape but doesn’t change the theoretical equilibrium point. Both formulations converge to $p_g = p_{data}$ at the Nash equilibrium. The practical difference is entirely about gradient magnitude during early training.
Follow-up: Does the non-saturating loss introduce any new problems?
Yes. It can cause training instability through mode-seeking behavior. The original loss is mode-covering (G tries to spread mass over all of $p_{data}$), while the non-saturating loss is mode-seeking (G concentrates on modes that currently fool D the most). This is one mechanism behind mode collapse. WGAN’s Wasserstein distance addresses both problems simultaneously.
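To make the gradient argument concrete, the following sketch uses autograd to compare the two Generator losses at a point where the Discriminator confidently rejects a fake. It assumes the Discriminator output is a sigmoid of a scalar logit; the function name `generator_logit_gradients` and the logit value of -6 are illustrative choices.

```python
import torch


def generator_logit_gradients(logit_value):
    """Return gradients of both G losses w.r.t. the D logit on a fake sample."""
    # Saturating loss: G minimizes log(1 - D(G(z)))
    a = torch.tensor(logit_value, requires_grad=True)
    torch.log(1 - torch.sigmoid(a)).backward()
    saturating = a.grad.item()

    # Non-saturating loss: G minimizes -log D(G(z))
    b = torch.tensor(logit_value, requires_grad=True)
    (-torch.log(torch.sigmoid(b))).backward()
    non_saturating = b.grad.item()

    return saturating, non_saturating


# Logit of -6 corresponds to D(G(z)) of roughly 0.0025: a confident rejection
sat, non_sat = generator_logit_gradients(-6.0)
print(f"saturating gradient:     {sat:.4f}")
print(f"non-saturating gradient: {non_sat:.4f}")
```

The saturating gradient has magnitude close to zero, while the non-saturating one has magnitude close to one, which is the difference in learning signal described above.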
Strong Answer:
  • Mode collapse is when the Generator maps many different noise vectors to a small set of outputs, producing limited diversity. In the extreme case (“complete collapse”), every input produces the same image. Partial collapse is more common: a GAN trained on MNIST might only generate digits 1, 7, and 9, ignoring the other seven classes.
  • Technique 1: Minibatch Discrimination. The Discriminator receives additional features computed across the batch, allowing it to detect when all generated samples look too similar. Trade-off: adds computational overhead and couples predictions within a batch, which complicates distributed training.
  • Technique 2: Wasserstein loss (WGAN-GP). By replacing the JS divergence with the Wasserstein distance, the critic provides meaningful gradients even when distributions don’t overlap, reducing the “shortcut incentive” that causes collapse. Trade-off: requires training the critic for multiple steps per generator step (typically 5), increasing wall-clock time by 3-5x. Also requires removing BatchNorm from the critic when using gradient penalty.
  • Technique 3: Unrolled GANs. The Generator anticipates future Discriminator updates by unrolling K steps of D’s optimization. This prevents G from over-exploiting D’s current weaknesses. Trade-off: memory-intensive (must store K computation graphs) and adds significant complexity. Rarely used in production — more of a research technique.
  • A senior engineer would add: in practice, monitoring for mode collapse matters more than any single technique. Track the diversity of generated samples using FID, the number of distinct modes in generated class distributions, or simply visual inspection grids. If collapse is detected, the most pragmatic fix is often reducing the learning rate of G relative to D, or switching to a progressive training schedule.
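One way to track "the number of distinct modes" mentioned above is to run generated samples through a pretrained classifier and summarize the predicted labels. The sketch below assumes those predictions already exist as an integer array; the classifier itself, the function name `mode_coverage`, and the 1% coverage threshold are assumptions for illustration.

```python
import numpy as np


def mode_coverage(predicted_classes, num_classes=10, min_fraction=0.01):
    """Summarize which classes appear among generated samples."""
    counts = np.bincount(np.asarray(predicted_classes), minlength=num_classes)
    fractions = counts / counts.sum()
    covered = np.flatnonzero(fractions >= min_fraction)

    # Normalized entropy: 1.0 means uniform over classes, 0.0 means a single class
    nonzero = fractions[fractions > 0]
    entropy = -(nonzero * np.log(nonzero)).sum() / np.log(num_classes)

    return {
        "covered": covered.tolist(),
        "num_covered": int(covered.size),
        "normalized_entropy": float(entropy),
    }


# Hypothetical predictions from a partially collapsed MNIST generator
collapsed = np.array([1] * 40 + [7] * 35 + [9] * 25)
report = mode_coverage(collapsed)
print(report["covered"], report["num_covered"])
```

A drop in `num_covered` or `normalized_entropy` between evaluation rounds is a cheap signal to inspect sample grids before committing to a heavier fix.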
Strong Answer:
  • Original GAN loss (BCE): Simple to implement, works well with DCGAN architecture and careful hyperparameter tuning. Choose this when you want a quick prototype and your dataset is well-behaved (balanced, sufficient data). The main risk is training instability and mode collapse.
  • WGAN-GP: Uses the Wasserstein distance approximated via a gradient penalty on the critic. The critic loss directly correlates with sample quality (unlike BCE loss), making it a useful training diagnostic. Choose WGAN-GP when training stability is paramount or when the original GAN loss diverges. The cost is 3-5x slower training due to multiple critic updates per generator step, plus the gradient penalty computation doubles backward-pass cost.
  • Spectral Normalization (SN-GAN): Constrains the Lipschitz constant of the Discriminator by normalizing each weight matrix by its spectral norm (largest singular value). Unlike WGAN-GP, it requires only one D update per G update and adds negligible computational overhead (one power iteration step per forward pass). Choose SN when you want WGAN-level stability without the training cost. In practice, spectral normalization has become a default in many architectures, including SAGAN and BigGAN.
  • Decision framework: for production image generation, start with SN-GAN. If quality plateaus, try WGAN-GP. Only fall back to vanilla BCE loss for simple datasets (MNIST, CIFAR) where you know the hyperparameters work. For large-scale generation, BigGAN (ImageNet) applies SN to both networks, while StyleGAN2 (faces) pairs the non-saturating logistic loss with an R1 gradient penalty. Both regularize how sharply the Discriminator can change, which is the same idea behind the Lipschitz constraint.
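The R1 gradient penalty mentioned above penalizes the squared norm of the Discriminator's gradient on real samples only. Here is a minimal sketch, assuming the Discriminator returns raw scores (logits) rather than sigmoid outputs; `r1_penalty` and the linear toy critic are hypothetical names used to check the arithmetic.

```python
import torch
import torch.nn as nn


def r1_penalty(discriminator, real_images, gamma=10.0):
    """R1 regularizer: (gamma / 2) * E[ ||grad_x D(x)||^2 ] over real samples."""
    real_images = real_images.detach().requires_grad_(True)
    scores = discriminator(real_images)
    # create_graph=True keeps the penalty differentiable w.r.t. D's parameters
    (grads,) = torch.autograd.grad(
        outputs=scores.sum(), inputs=real_images, create_graph=True
    )
    squared_norm = grads.flatten(start_dim=1).pow(2).sum(dim=1)
    return 0.5 * gamma * squared_norm.mean()


# Sanity check on a linear critic D(x) = w . x, where grad_x D(x) = w
toy_critic = nn.Linear(4, 1, bias=False)
with torch.no_grad():
    toy_critic.weight.copy_(torch.tensor([[1.0, 2.0, 0.0, -2.0]]))

penalty = r1_penalty(toy_critic, torch.randn(8, 4), gamma=10.0)
print(f"R1 penalty: {penalty.item():.2f}")  # 0.5 * 10 * (1 + 4 + 0 + 4) = 45
```

In a training loop the penalty would be added to the Discriminator loss on the real batch; because it needs a second-order graph, it is often computed only every few steps to limit cost.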
Strong Answer:
  • Inception Score (IS): measures two things: (1) quality: high-quality images should have confident, peaked class predictions $p(y|x)$, and (2) diversity: the marginal distribution $p(y) = \int p(y|x)\,p_g(x)\,dx$ should be uniform across all classes. $IS = \exp(\mathbb{E}[D_{KL}(p(y|x) \| p(y))])$. Higher is better.
  • IS limitations: it uses InceptionV3 trained on ImageNet, so it’s biased toward ImageNet-like images. A GAN generating perfect medical images would get a low IS. It also can’t detect intra-class mode collapse (generating a single image per class still scores well as long as every class is represented). And critically, IS doesn’t compare against real data at all: a GAN could generate images from a completely different distribution and still get high IS.
  • Fréchet Inception Distance (FID): computes the Fréchet distance between two multivariate Gaussians fitted to InceptionV3 features of real and generated images: $FID = \|\mu_r - \mu_g\|^2 + \text{Tr}(\Sigma_r + \Sigma_g - 2(\Sigma_r\Sigma_g)^{1/2})$. Lower is better. FID captures both quality and diversity and compares against the actual data distribution.
  • FID limitations: assumes Gaussian feature distributions (which is a rough approximation), sensitive to sample size (need at least 10,000 samples for reliable estimates, ideally 50,000), and still depends on InceptionV3 features. For domains far from natural images, consider domain-specific metrics or Kernel Inception Distance (KID), which has an unbiased estimator and is less sensitive to sample size.
  • A senior engineer would add: never rely on a single metric. Use FID as the primary quantitative measure, but always supplement with visual inspection (grid plots), precision/recall curves (to separate quality from diversity failures), and domain-specific metrics where applicable.
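The Inception Score formula above can be sketched directly from a matrix of class probabilities. This simplified version assumes you already have softmax outputs of shape (num_samples, num_classes); in the full metric those come from InceptionV3, and scores are often averaged over several splits, which is omitted here. The function name and the `eps` smoothing term are illustrative assumptions.

```python
import numpy as np


def inception_score(class_probs, eps=1e-12):
    """exp(E_x[ KL(p(y|x) || p(y)) ]) from per-sample class probabilities."""
    probs = np.asarray(class_probs, dtype=np.float64)
    # Marginal p(y): average prediction over all generated samples
    marginal = probs.mean(axis=0, keepdims=True)
    kl_per_sample = (probs * (np.log(probs + eps) - np.log(marginal + eps))).sum(axis=1)
    return float(np.exp(kl_per_sample.mean()))


# Confident predictions spread evenly over 10 classes: the best case
diverse = np.eye(10)
# Confident predictions that all land on class 0: collapsed generator
collapsed = np.tile(np.eye(10)[0], (10, 1))

print(f"diverse:   {inception_score(diverse):.2f}")
print(f"collapsed: {inception_score(collapsed):.2f}")
```

The diverse case approaches the number of classes and the collapsed case approaches 1, which matches the "quality plus diversity" reading of the formula.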
Strong Answer:
  • Architecture choice: StyleGAN2 or StyleGAN3 as the backbone. These architectures produce the highest-quality images for structured objects and offer fine-grained control via the style-based generator. The mapping network transforms the latent code $z$ into an intermediate space $w$, which controls different aspects of the image at different resolutions (coarse features like shape at low resolutions, fine details like texture at high resolutions).
  • Data pipeline: collect at least 50,000 product images per category (shoes, bags, electronics). Clean the dataset rigorously — remove duplicates, watermarks, and low-quality images. Apply standardized backgrounds (white or transparent). Resize to a consistent resolution (512x512 or 1024x1024). Augment with horizontal flips only (not rotations — product orientation matters).
  • Training strategy: progressive growing is unnecessary for StyleGAN2+ (the architecture handles it internally). Train with R1 gradient penalty ($\gamma = 10$), non-saturating logistic loss, and path length regularization every 16 minibatches. Use 4-8 GPUs with a total batch size of 32-64. Train for 25M+ images seen (not epochs) and track FID against a held-out validation set. Stop early when FID plateaus.
  • Quality assurance pipeline: (1) automated FID/KID checks against the real dataset, (2) LPIPS-based diversity check (reject batches with mean pairwise LPIPS below a threshold), (3) human evaluation panel rating realism on a 1-5 scale for a random sample of 200 images, (4) A/B testing on the platform — do generated images lead to similar click-through and conversion rates as real product photos?
  • Production considerations: serve the generator with ONNX Runtime or TensorRT for 10-50ms inference latency. Cache generated images rather than generating on-the-fly. Implement a moderation pipeline to catch any artifacts or inappropriate content before serving. Version the model and track FID over time to detect quality regression.

Next: Autoencoders & VAEs

Learn about variational autoencoders and latent space representations