
Generative Adversarial Networks

The Counterfeiter vs Detective Game

Imagine a world with two players locked in an eternal battle:
  • The Counterfeiter (Generator): Creates fake money, trying to make it indistinguishable from real currency
  • The Detective (Discriminator): Examines bills and tries to identify which are real and which are fake
At first, the counterfeiter is terrible - their fake bills look obviously fake. But with each rejection, they learn and improve. Meanwhile, the detective gets better at spotting fakes, forcing the counterfeiter to up their game. This is exactly how GANs work!
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

GAN Architecture Overview

[Figure: GAN Architecture - Generator and Discriminator]
A GAN consists of two neural networks trained simultaneously:
| Component | Role | Input | Output |
|---|---|---|---|
| Generator (G) | Creates fake data | Random noise $z$ | Fake sample $G(z)$ |
| Discriminator (D) | Classifies real/fake | Sample $x$ | Probability $D(x)$ |
The networks compete in a minimax game - the generator tries to fool the discriminator, while the discriminator tries to catch the generator.
class SimpleGenerator(nn.Module):
    """
    Generator: Maps random noise to data space.
    Think of it as learning to "draw" realistic images from scratch.
    """
    def __init__(self, latent_dim=100, output_dim=784):
        super().__init__()
        self.latent_dim = latent_dim
        
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            
            nn.Linear(1024, output_dim),
            nn.Tanh()  # Output in range [-1, 1]
        )
    
    def forward(self, z):
        return self.model(z)


class SimpleDiscriminator(nn.Module):
    """
    Discriminator: Binary classifier - real or fake?
    Outputs probability that input is real.
    """
    def __init__(self, input_dim=784):
        super().__init__()
        
        self.model = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(256, 1),
            nn.Sigmoid()  # Output probability [0, 1]
        )
    
    def forward(self, x):
        return self.model(x)


# Initialize networks
latent_dim = 100
generator = SimpleGenerator(latent_dim=latent_dim).to(device)
discriminator = SimpleDiscriminator().to(device)

print(f"Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
Output:
Generator parameters: 1,076,244
Discriminator parameters: 407,297

The Minimax Loss Function

The GAN training objective is a minimax game:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

Let's break this down:
| Term | Meaning | What It Does |
|---|---|---|
| $\mathbb{E}_{x \sim p_{data}}[\log D(x)]$ | Expected log-probability on real data | D wants this HIGH (correctly identify real) |
| $\mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$ | Expected log-probability on fake data | D wants this HIGH, G wants this LOW |
[Figure: GAN Minimax Game Visualization]
The Discriminator wants to maximize $V(D,G)$:
  • Correctly classify real images → $D(x) \approx 1$
  • Correctly classify fake images → $D(G(z)) \approx 0$
The Generator wants to minimize $V(D,G)$:
  • Fool the discriminator → $D(G(z)) \approx 1$
def discriminator_loss(real_output, fake_output):
    """
    Discriminator wants to:
    - Output 1 for real images (so log(D(x)) is high)
    - Output 0 for fake images (so log(1 - D(G(z))) is high)
    """
    real_loss = nn.BCELoss()(real_output, torch.ones_like(real_output))
    fake_loss = nn.BCELoss()(fake_output, torch.zeros_like(fake_output))
    return real_loss + fake_loss


def generator_loss(fake_output):
    """
    Generator wants discriminator to output 1 for fake images.
    We use the "non-saturating" loss for better gradients.
    """
    return nn.BCELoss()(fake_output, torch.ones_like(fake_output))


# Alternative: Using BCEWithLogitsLoss for numerical stability
class GANLoss:
    def __init__(self):
        self.criterion = nn.BCEWithLogitsLoss()
    
    def discriminator_loss(self, real_logits, fake_logits):
        real_labels = torch.ones_like(real_logits)
        fake_labels = torch.zeros_like(fake_logits)
        
        real_loss = self.criterion(real_logits, real_labels)
        fake_loss = self.criterion(fake_logits, fake_labels)
        return real_loss + fake_loss
    
    def generator_loss(self, fake_logits):
        real_labels = torch.ones_like(fake_logits)
        return self.criterion(fake_logits, real_labels)
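
The "non-saturating" loss mentioned in `generator_loss` is worth a one-line justification (a standard argument, summarized here). Writing $D(G(z)) = \sigma(a)$ for the discriminator's logit $a$, the gradient of the original objective vanishes exactly when the generator is losing:

```latex
\frac{\partial}{\partial a}\log\bigl(1-\sigma(a)\bigr) = -\sigma(a) \;\to\; 0
\quad\text{as } a \to -\infty,
\qquad
\frac{\partial}{\partial a}\log \sigma(a) = 1-\sigma(a) \;\to\; 1
\quad\text{as } a \to -\infty.
```

Early in training $D(G(z)) \approx 0$ (i.e. $a \to -\infty$), so minimizing $-\log D(G(z))$, as `generator_loss` does, keeps a strong gradient precisely where $\log(1 - D(G(z)))$ would saturate.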

Complete GAN Training Loop

Let’s train a GAN on MNIST digits:
# Data preparation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

train_dataset = datasets.MNIST(
    root='./data', 
    train=True, 
    download=True, 
    transform=transform
)

batch_size = 128
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Hyperparameters
lr = 0.0002
beta1 = 0.5
num_epochs = 50

# Optimizers - Adam with specific betas for GANs
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# Loss function
criterion = nn.BCELoss()


def train_gan(generator, discriminator, dataloader, num_epochs):
    """
    Complete GAN training loop with monitoring.
    """
    g_losses = []
    d_losses = []
    
    # Fixed noise for visualization
    fixed_noise = torch.randn(64, latent_dim, device=device)
    
    for epoch in range(num_epochs):
        g_loss_epoch = 0
        d_loss_epoch = 0
        
        for batch_idx, (real_images, _) in enumerate(dataloader):
            batch_size = real_images.size(0)
            real_images = real_images.view(batch_size, -1).to(device)
            
            # Labels
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)
            
            # ==================
            # Train Discriminator
            # ==================
            d_optimizer.zero_grad()
            
            # Real images
            real_output = discriminator(real_images)
            d_loss_real = criterion(real_output, real_labels)
            
            # Fake images
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            fake_output = discriminator(fake_images.detach())  # Detach to not update G
            d_loss_fake = criterion(fake_output, fake_labels)
            
            # Total discriminator loss
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            d_optimizer.step()
            
            # ===============
            # Train Generator
            # ===============
            g_optimizer.zero_grad()
            
            # Generate new fake images
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            fake_output = discriminator(fake_images)
            
            # Generator wants discriminator to think fakes are real
            g_loss = criterion(fake_output, real_labels)
            g_loss.backward()
            g_optimizer.step()
            
            g_loss_epoch += g_loss.item()
            d_loss_epoch += d_loss.item()
        
        # Average losses
        g_loss_epoch /= len(dataloader)
        d_loss_epoch /= len(dataloader)
        g_losses.append(g_loss_epoch)
        d_losses.append(d_loss_epoch)
        
        # Print progress
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] | "
                  f"D Loss: {d_loss_epoch:.4f} | G Loss: {g_loss_epoch:.4f}")
            
            # Generate sample images
            with torch.no_grad():
                sample_images = generator(fixed_noise).view(-1, 1, 28, 28)
                visualize_samples(sample_images[:16], epoch + 1)
    
    return g_losses, d_losses


def visualize_samples(images, epoch):
    """Display generated images in a grid."""
    fig, axes = plt.subplots(4, 4, figsize=(8, 8))
    for i, ax in enumerate(axes.flat):
        img = images[i].cpu().squeeze().numpy()
        img = (img + 1) / 2  # Denormalize from [-1, 1] to [0, 1]
        ax.imshow(img, cmap='gray')
        ax.axis('off')
    plt.suptitle(f'Generated Samples - Epoch {epoch}')
    plt.tight_layout()
    plt.savefig(f'gan_samples_epoch_{epoch}.png')
    plt.close()


# Train the GAN
print("Starting GAN training...")
g_losses, d_losses = train_gan(generator, discriminator, dataloader, num_epochs)
Output:
Starting GAN training...
Epoch [10/50] | D Loss: 0.8234 | G Loss: 1.4521
Epoch [20/50] | D Loss: 0.6912 | G Loss: 1.2847
Epoch [30/50] | D Loss: 0.5823 | G Loss: 1.5234
Epoch [40/50] | D Loss: 0.5412 | G Loss: 1.6891
Epoch [50/50] | D Loss: 0.5234 | G Loss: 1.7234

Mode Collapse: The GAN’s Achilles Heel

[Figure: Mode Collapse in GANs]
Mode collapse occurs when the generator produces only a limited variety of outputs, essentially “memorizing” a few samples that fool the discriminator.

Why Does Mode Collapse Happen?

| Cause | Explanation |
|---|---|
| Easy shortcut | Generator finds one mode that consistently fools D |
| Discriminator forgetting | D gets fooled by one mode, and G exploits it |
| Training imbalance | G or D becomes too strong too quickly |
def detect_mode_collapse(generator, num_samples=1000, threshold=0.1):
    """
    Detect mode collapse by checking diversity of generated samples.
    """
    generator.eval()
    with torch.no_grad():
        noise = torch.randn(num_samples, latent_dim, device=device)
        samples = generator(noise).cpu().numpy()
    
    # Calculate pairwise distances
    from scipy.spatial.distance import pdist
    distances = pdist(samples, metric='euclidean')
    
    mean_distance = np.mean(distances)
    std_distance = np.std(distances)
    
    # Low mean distance indicates mode collapse
    diversity_score = mean_distance / (std_distance + 1e-8)
    
    print(f"Mean pairwise distance: {mean_distance:.4f}")
    print(f"Std pairwise distance: {std_distance:.4f}")
    print(f"Diversity score: {diversity_score:.4f}")
    
    if mean_distance < threshold:
        print("⚠️ Warning: Possible mode collapse detected!")
        return True
    return False


# Solutions to mode collapse
class AntiModeCollapseTraining:
    """
    Techniques to prevent mode collapse.
    """
    
    @staticmethod
    def minibatch_discrimination(features, num_kernels=5, kernel_dim=3):
        """
        Compute features based on relationships with other samples.
        Helps discriminator catch mode collapse.
        """
        batch_size = features.size(0)
        # NOTE: for illustration only. In a real model, T is a learned
        # parameter registered on the discriminator, not re-sampled per call.
        T = torch.randn(features.size(1), num_kernels * kernel_dim,
                        device=features.device)
        
        M = features @ T
        M = M.view(batch_size, num_kernels, kernel_dim)
        
        # Compute L1 distance between all pairs
        M_expanded = M.unsqueeze(0)  # [1, B, K, D]
        M_T = M.unsqueeze(1)  # [B, 1, K, D]
        
        L1_distance = torch.abs(M_expanded - M_T).sum(dim=3)  # [B, B, K]
        
        # Compute similarity
        similarity = torch.exp(-L1_distance)
        # Exclude self-similarity
        mask = 1 - torch.eye(batch_size, device=features.device)
        similarity = similarity * mask.unsqueeze(2)
        
        o = similarity.sum(dim=1)  # [B, K]
        return o
    
    @staticmethod
    def feature_matching_loss(real_features, fake_features):
        """
        Match statistics of discriminator features instead of output.
        """
        return torch.mean((real_features.mean(0) - fake_features.mean(0)) ** 2)
    
    @staticmethod
    def historical_averaging(current_params, historical_params, beta=0.99):
        """
        Penalize large deviations from historical parameter averages.
        """
        loss = 0
        for curr, hist in zip(current_params, historical_params):
            loss += torch.mean((curr - hist) ** 2)
        return loss

DCGAN: Deep Convolutional GAN

[Figure: DCGAN Architecture with Convolutions]
DCGANs use convolutional layers for better image generation. Key architectural guidelines:
| Guideline | Generator | Discriminator |
|---|---|---|
| Pooling | Use transposed conv | Use strided conv |
| Normalization | BatchNorm (except output layer) | BatchNorm (except input layer) |
| Activation | ReLU (output: Tanh) | LeakyReLU |
| Architecture | Fully convolutional | Fully convolutional |
class DCGANGenerator(nn.Module):
    """
    DCGAN Generator using transposed convolutions.
    Maps noise vector to 64x64 RGB image.
    """
    def __init__(self, latent_dim=100, feature_maps=64, channels=3):
        super().__init__()
        self.latent_dim = latent_dim
        
        self.model = nn.Sequential(
            # Input: latent_dim x 1 x 1
            nn.ConvTranspose2d(latent_dim, feature_maps * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.ReLU(True),
            # State: (feature_maps*8) x 4 x 4
            
            nn.ConvTranspose2d(feature_maps * 8, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.ReLU(True),
            # State: (feature_maps*4) x 8 x 8
            
            nn.ConvTranspose2d(feature_maps * 4, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.ReLU(True),
            # State: (feature_maps*2) x 16 x 16
            
            nn.ConvTranspose2d(feature_maps * 2, feature_maps, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps),
            nn.ReLU(True),
            # State: feature_maps x 32 x 32
            
            nn.ConvTranspose2d(feature_maps, channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # Output: channels x 64 x 64
        )
        
        self._init_weights()
    
    def _init_weights(self):
        """Initialize weights as per DCGAN paper."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.zeros_(m.bias)
    
    def forward(self, z):
        # Reshape to [batch, latent_dim, 1, 1]
        z = z.view(-1, self.latent_dim, 1, 1)
        return self.model(z)


class DCGANDiscriminator(nn.Module):
    """
    DCGAN Discriminator using strided convolutions.
    Classifies 64x64 RGB images as real or fake.
    """
    def __init__(self, feature_maps=64, channels=3):
        super().__init__()
        
        self.model = nn.Sequential(
            # Input: channels x 64 x 64
            nn.Conv2d(channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State: feature_maps x 32 x 32
            
            nn.Conv2d(feature_maps, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_maps*2) x 16 x 16
            
            nn.Conv2d(feature_maps * 2, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_maps*4) x 8 x 8
            
            nn.Conv2d(feature_maps * 4, feature_maps * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_maps*8) x 4 x 4
            
            nn.Conv2d(feature_maps * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # Output: 1 x 1 x 1
        )
        
        self._init_weights()
    
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.zeros_(m.bias)
    
    def forward(self, x):
        return self.model(x).view(-1, 1)


# Instantiate DCGAN
dc_generator = DCGANGenerator(latent_dim=100).to(device)
dc_discriminator = DCGANDiscriminator().to(device)

# Test forward pass
test_noise = torch.randn(4, 100, device=device)
test_images = dc_generator(test_noise)
print(f"Generator output shape: {test_images.shape}")  # [4, 3, 64, 64]

test_pred = dc_discriminator(test_images)
print(f"Discriminator output shape: {test_pred.shape}")  # [4, 1]
Output:
Generator output shape: torch.Size([4, 3, 64, 64])
Discriminator output shape: torch.Size([4, 1])

Wasserstein GAN (WGAN)

[Figure: Wasserstein Distance vs. JS Divergence]
WGAN addresses training instability by using the Wasserstein distance (Earth Mover’s Distance) instead of JS divergence.

Why Wasserstein Distance?

$$W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma}[\|x - y\|]$$
| Advantage | Explanation |
|---|---|
| Continuous gradients | Provides gradients even when the distributions don't overlap |
| Meaningful loss | Loss correlates with sample quality |
| Stable training | No need for a careful balance between G and D |
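To see why this matters in practice, consider a one-dimensional example (an illustrative aside, not from the original text). In 1-D the optimal coupling simply matches sorted samples, so $W_1$ between two equal-size empirical distributions reduces to the mean absolute difference of order statistics, and it grows smoothly with the gap between the distributions even when they barely overlap:

```python
import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(0)
real = rng.normal(0.0, 1.0, size=10_000)  # samples from p_r
fake = rng.normal(3.0, 1.0, size=10_000)  # samples from p_g, shifted by 3

# In 1-D, the optimal transport plan is the monotone (sorted) coupling
w1_manual = np.abs(np.sort(real) - np.sort(fake)).mean()
w1_scipy = wasserstein_distance(real, fake)

print(f"manual: {w1_manual:.3f}  scipy: {w1_scipy:.3f}")  # both close to 3.0
```

By contrast, the JS divergence between these two nearly disjoint distributions is already close to its maximum of $\log 2$, so moving the generator's mean from 3.0 to 2.9 barely changes it, while $W_1$ decreases by 0.1. That smooth signal is exactly what WGAN exploits.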
class WGANCritic(nn.Module):
    """
    WGAN uses a 'Critic' instead of 'Discriminator'.
    No sigmoid - outputs unbounded real numbers.
    """
    def __init__(self, feature_maps=64, channels=3):
        super().__init__()
        
        self.model = nn.Sequential(
            nn.Conv2d(channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(feature_maps, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(feature_maps * 2, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(feature_maps * 4, feature_maps * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(feature_maps * 8, 1, 4, 1, 0, bias=False),
            # No sigmoid! Unbounded output
        )
    
    def forward(self, x):
        return self.model(x).view(-1)


def wgan_critic_loss(real_output, fake_output):
    """
    WGAN Critic Loss: Maximize E[C(real)] - E[C(fake)]
    Equivalent to minimizing: E[C(fake)] - E[C(real)]
    """
    return fake_output.mean() - real_output.mean()


def wgan_generator_loss(fake_output):
    """
    WGAN Generator Loss: Minimize -E[C(fake)]
    Equivalent to maximizing E[C(fake)]
    """
    return -fake_output.mean()


def gradient_penalty(critic, real_data, fake_data, device):
    """
    WGAN-GP: Gradient penalty for 1-Lipschitz constraint.
    Penalizes gradients that deviate from 1.
    """
    batch_size = real_data.size(0)
    
    # Random interpolation coefficient
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    
    # Interpolate between real and fake
    interpolated = alpha * real_data + (1 - alpha) * fake_data
    interpolated.requires_grad_(True)
    
    # Get critic score for interpolated
    critic_interpolated = critic(interpolated)
    
    # Compute gradients
    gradients = torch.autograd.grad(
        outputs=critic_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_interpolated),
        create_graph=True,
        retain_graph=True
    )[0]
    
    # Gradient norm
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    
    # Penalty: (||grad|| - 1)^2
    penalty = ((gradient_norm - 1) ** 2).mean()
    
    return penalty
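
As a quick sanity check (an illustrative aside, not part of the original tutorial), a linear critic $f(x) = w \cdot x$ with $\|w\|_2 = 1$ has gradient $w$ everywhere, hence gradient norm exactly 1, so the penalty should be numerically zero for any inputs. The same steps are inlined here so the snippet runs standalone:

```python
import torch

torch.manual_seed(0)

dim = 3 * 64 * 64
w = torch.randn(dim)
w = w / w.norm()  # unit-norm weights -> f(x) = w . x is exactly 1-Lipschitz

def critic(x):
    return x.view(x.size(0), -1) @ w

real = torch.randn(8, 3, 64, 64)
fake = torch.randn(8, 3, 64, 64)

# Same computation as gradient_penalty, inlined
alpha = torch.rand(8, 1, 1, 1)
interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
scores = critic(interp)
grads = torch.autograd.grad(scores, interp,
                            grad_outputs=torch.ones_like(scores),
                            create_graph=True)[0]
penalty = ((grads.view(8, -1).norm(2, dim=1) - 1) ** 2).mean()

print(f"penalty for a 1-Lipschitz linear critic: {penalty.item():.2e}")
```

A critic that overshoots (say, $w$ scaled by 2) would instead incur $(2-1)^2 = 1$ per sample, which `lambda_gp` then scales in the critic loss.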


def train_wgan_gp(generator, critic, dataloader, num_epochs, 
                   n_critic=5, lambda_gp=10):
    """
    Training loop for WGAN with gradient penalty.
    """
    g_optimizer = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.0, 0.9))
    c_optimizer = optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))
    
    for epoch in range(num_epochs):
        for batch_idx, (real_images, _) in enumerate(dataloader):
            real_images = real_images.to(device)
            batch_size = real_images.size(0)
            
            # Train Critic (more often than generator)
            for _ in range(n_critic):
                c_optimizer.zero_grad()
                
                # Generate fake images
                noise = torch.randn(batch_size, latent_dim, device=device)
                fake_images = generator(noise)
                
                # Critic scores
                real_validity = critic(real_images)
                fake_validity = critic(fake_images.detach())
                
                # Gradient penalty
                gp = gradient_penalty(critic, real_images, fake_images, device)
                
                # Critic loss
                c_loss = wgan_critic_loss(real_validity, fake_validity) + lambda_gp * gp
                c_loss.backward()
                c_optimizer.step()
            
            # Train Generator
            g_optimizer.zero_grad()
            
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            fake_validity = critic(fake_images)
            
            g_loss = wgan_generator_loss(fake_validity)
            g_loss.backward()
            g_optimizer.step()
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] | "
                  f"C Loss: {c_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

Conditional GANs (cGAN)

[Figure: Conditional GAN Architecture]
Conditional GANs allow us to control what we generate by conditioning on additional information (class labels, text, images):

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x|y)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z|y)|y))]$$
class ConditionalGenerator(nn.Module):
    """
    Generator conditioned on class labels.
    Generate specific digit: "Generate a 7"
    """
    def __init__(self, latent_dim=100, num_classes=10, embed_dim=50, img_size=28):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_size = img_size
        
        # Embedding for class labels
        self.label_embedding = nn.Embedding(num_classes, embed_dim)
        
        # Input: noise + embedded label
        self.model = nn.Sequential(
            nn.Linear(latent_dim + embed_dim, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            
            nn.Linear(1024, img_size * img_size),
            nn.Tanh()
        )
    
    def forward(self, z, labels):
        # Embed labels
        label_embedding = self.label_embedding(labels)
        
        # Concatenate noise and label embedding
        gen_input = torch.cat([z, label_embedding], dim=1)
        
        return self.model(gen_input)


class ConditionalDiscriminator(nn.Module):
    """
    Discriminator that also receives class labels.
    """
    def __init__(self, num_classes=10, embed_dim=50, img_size=28):
        super().__init__()
        self.img_size = img_size
        
        # Embedding for class labels
        self.label_embedding = nn.Embedding(num_classes, embed_dim)
        
        # Input: flattened image + embedded label
        self.model = nn.Sequential(
            nn.Linear(img_size * img_size + embed_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, img, labels):
        # Flatten image
        img_flat = img.view(img.size(0), -1)
        
        # Embed labels
        label_embedding = self.label_embedding(labels)
        
        # Concatenate image and label embedding
        d_input = torch.cat([img_flat, label_embedding], dim=1)
        
        return self.model(d_input)


# Example: Generate specific digits
def generate_digit(generator, digit, num_samples=16):
    """Generate samples of a specific digit."""
    generator.eval()
    with torch.no_grad():
        noise = torch.randn(num_samples, latent_dim, device=device)
        labels = torch.full((num_samples,), digit, dtype=torch.long, device=device)
        generated = generator(noise, labels)
    return generated.view(-1, 1, 28, 28)


# Initialize conditional GAN
cond_generator = ConditionalGenerator().to(device)
cond_discriminator = ConditionalDiscriminator().to(device)

# Generate all digits 0-9
print("Conditional GAN allows generating specific classes!")
for digit in range(10):
    samples = generate_digit(cond_generator, digit, num_samples=8)
    print(f"Generated {samples.shape[0]} samples of digit {digit}")

GAN Evaluation Metrics

Evaluating GANs is Hard! Unlike supervised learning, there’s no clear “correct answer” to compare against.
| Metric | Measures | Formula/Approach |
|---|---|---|
| Inception Score (IS) | Quality + diversity | $\exp(\mathbb{E}_x[D_{KL}(p(y \mid x) \,\Vert\, p(y))])$ |
| FID (Fréchet Inception Distance) | Feature similarity | Compare Gaussian fits in feature space |
| LPIPS | Perceptual similarity | Distance in a learned feature space |
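The IS row above can be made concrete. With class posteriors $p(y|x)$ from a pretrained classifier (InceptionV3 in the full metric), IS exponentiates the average KL divergence between each posterior and the marginal $p(y)$: confident and diverse predictions score high, while a mode-collapsed generator scores 1. A minimal sketch on synthetic posteriors (the probability matrices here are made up for illustration):

```python
import numpy as np

def inception_score(p_yx, eps=1e-12):
    """IS = exp(E_x[KL(p(y|x) || p(y))]); p_yx has shape [N, num_classes]."""
    p_y = p_yx.mean(axis=0, keepdims=True)  # marginal label distribution
    kl = (p_yx * (np.log(p_yx + eps) - np.log(p_y + eps))).sum(axis=1)
    return float(np.exp(kl.mean()))

# Confident and diverse: each class predicted equally often -> IS near 10
diverse = np.eye(10)[np.arange(100) % 10]
# Mode collapse: every sample assigned to the same class -> IS = 1
collapsed = np.tile(np.eye(10)[0], (100, 1))

print(inception_score(diverse), inception_score(collapsed))
```

Note that IS never looks at real data, which is why FID (which compares against real statistics) is usually preferred.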
def calculate_fid_simple(real_features, fake_features):
    """
    Simplified FID calculation.
    Lower FID = Better quality and diversity.
    
    Full FID uses InceptionV3 features.
    """
    # Calculate statistics
    mu_real = np.mean(real_features, axis=0)
    sigma_real = np.cov(real_features, rowvar=False)
    
    mu_fake = np.mean(fake_features, axis=0)
    sigma_fake = np.cov(fake_features, rowvar=False)
    
    # Calculate FID
    diff = mu_real - mu_fake
    
    # Matrix square root
    from scipy.linalg import sqrtm
    covmean = sqrtm(sigma_real @ sigma_fake)
    
    # Handle numerical issues
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    
    fid = diff @ diff + np.trace(sigma_real + sigma_fake - 2 * covmean)
    
    return fid


# Practical evaluation
def evaluate_gan(generator, dataloader, num_samples=1000):
    """
    Evaluate GAN using multiple metrics.
    """
    generator.eval()
    
    # Collect real and fake samples
    real_samples = []
    fake_samples = []
    
    with torch.no_grad():
        for real_images, _ in dataloader:
            real_samples.append(real_images.view(real_images.size(0), -1))
            if sum(t.size(0) for t in real_samples) >= num_samples:
                break
        
        noise = torch.randn(num_samples, latent_dim, device=device)
        fake_images = generator(noise)
        fake_samples.append(fake_images.cpu())
    
    real_samples = torch.cat(real_samples, dim=0)[:num_samples].numpy()
    fake_samples = torch.cat(fake_samples, dim=0).numpy()
    
    # Calculate metrics
    fid = calculate_fid_simple(real_samples, fake_samples)
    
    # Diversity (average pairwise distance)
    from scipy.spatial.distance import pdist
    diversity = np.mean(pdist(fake_samples[:100]))
    
    print(f"FID Score: {fid:.2f} (lower is better)")
    print(f"Diversity Score: {diversity:.4f} (higher is better)")
    
    return {"fid": fid, "diversity": diversity}

Exercises

Label smoothing can stabilize GAN training. Instead of using hard labels (0 and 1), use soft labels (e.g., 0.1 and 0.9).
# TODO: Modify the training loop to use label smoothing
# Real labels: 0.9 instead of 1.0
# Fake labels: 0.1 instead of 0.0

def train_with_label_smoothing(generator, discriminator, dataloader, num_epochs):
    # Your implementation here
    pass

Implement a training schedule that gradually increases image resolution.
# TODO: Start with 4x4 images and progressively grow to 64x64
# Hint: Add layers during training

class ProgressiveGenerator(nn.Module):
    def __init__(self):
        # Your implementation here
        pass

Spectral normalization constrains the Lipschitz constant of discriminator layers.
# TODO: Implement spectral normalization for discriminator weights
# Hint: Normalize weights by their largest singular value

class SpectralNormDiscriminator(nn.Module):
    # Your implementation here
    pass
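As a reference point for this exercise (not a solution), PyTorch ships a built-in wrapper that does exactly this via power iteration; the sketch below wraps each layer and checks that, after some forward passes, the largest singular value of the effective weight sits near 1:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Newer releases also offer nn.utils.parametrizations.spectral_norm
sn = nn.utils.spectral_norm

disc = nn.Sequential(
    sn(nn.Linear(784, 512)), nn.LeakyReLU(0.2),
    sn(nn.Linear(512, 256)), nn.LeakyReLU(0.2),
    sn(nn.Linear(256, 1)),
)

x = torch.randn(8, 784)
for _ in range(100):  # each forward pass runs one power iteration
    out = disc(x)

# Spectral norm of the first layer's effective weight should be near 1
sigma = torch.linalg.matrix_norm(disc[0].weight.detach(), ord=2).item()
print(out.shape, f"top singular value: {sigma:.3f}")
```

Implementing the power iteration yourself and comparing against this wrapper is a good way to validate your solution.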

Key Takeaways

What You Learned:
  • GAN Fundamentals - Generator creates, Discriminator classifies in a minimax game
  • Minimax Loss - Adversarial training objective and its components
  • Mode Collapse - Common failure mode and solutions (minibatch discrimination, feature matching)
  • DCGAN - Convolutional architecture guidelines for stable training
  • WGAN - Wasserstein distance for improved training stability
  • Conditional GANs - Control generation with class labels or other conditioning
  • Evaluation - FID, Inception Score, and diversity metrics

Common Pitfalls

GAN Training Mistakes to Avoid:
  1. Imbalanced training - Don’t let D or G become too strong too quickly
  2. Ignoring mode collapse - Monitor sample diversity throughout training
  3. Wrong learning rates - GANs are sensitive; start with proven hyperparameters
  4. Batch size too small - BatchNorm needs sufficient batch size (≥16)
  5. Not using proper initialization - DCGAN weight init matters significantly
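
Pitfall 1 is easiest to catch by logging the discriminator's average scores on real and fake batches: healthy training keeps both away from the extremes, while D(real) ≈ 1 and D(fake) ≈ 0 over many consecutive steps means the discriminator has won and the generator is receiving almost no gradient. A sketch of such a health check (the thresholds are illustrative, not canonical):

```python
import torch

def gan_health_check(real_scores: torch.Tensor, fake_scores: torch.Tensor) -> str:
    """Classify a training step from the discriminator's sigmoid outputs."""
    d_real = real_scores.mean().item()
    d_fake = fake_scores.mean().item()
    if d_real > 0.95 and d_fake < 0.05:
        return "discriminator winning: G gradients are vanishing"
    if d_real < 0.55 and d_fake > 0.45:
        return "generator winning: D is being fooled (or has collapsed)"
    return "balanced"

# Example: D is confidently separating real from fake
print(gan_health_check(torch.full((128,), 0.98), torch.full((128,), 0.02)))
```

Calling this on `real_output` and `fake_output` every few hundred steps in the training loop above gives an early warning before the loss curves diverge.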

Next: Autoencoders & VAEs

Learn about variational autoencoders and latent space representations