
Autoencoders & Variational Autoencoders

The Bottleneck Concept

Imagine you need to describe a complex image using only 10 numbers. You’d have to capture the essential features and discard the noise. That’s exactly what an autoencoder does! An autoencoder learns to:
  1. Compress data into a lower-dimensional representation (encoding)
  2. Reconstruct the original data from this compressed form (decoding)
The magic happens in the bottleneck - a narrow layer that forces the network to learn efficient representations.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

Standard Autoencoder

[Figure: Standard Autoencoder Architecture]
The basic autoencoder architecture:
| Component | Function | Shape (MNIST) |
|---|---|---|
| Encoder | Compress input to latent space | 784 → 256 → 128 → 64 → 32 |
| Latent Space | Compressed representation | 32 dimensions |
| Decoder | Reconstruct from latent space | 32 → 64 → 128 → 256 → 784 |
class Autoencoder(nn.Module):
    """
    Standard Autoencoder with symmetric encoder-decoder.
    """
    def __init__(self, input_dim=784, latent_dim=32):
        super().__init__()
        self.latent_dim = latent_dim
        
        # Encoder: input -> latent
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, latent_dim)
        )
        
        # Decoder: latent -> output
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid()  # Output in [0, 1] for images
        )
    
    def encode(self, x):
        """Compress input to latent representation."""
        return self.encoder(x)
    
    def decode(self, z):
        """Reconstruct from latent representation."""
        return self.decoder(z)
    
    def forward(self, x):
        z = self.encode(x)
        return self.decode(z)


# Create autoencoder
autoencoder = Autoencoder(latent_dim=32).to(device)
print(f"Autoencoder parameters: {sum(p.numel() for p in autoencoder.parameters()):,}")

# Test forward pass
test_input = torch.randn(4, 784).to(device)
output = autoencoder(test_input)
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {output.shape}")
Output:
Autoencoder parameters: 489,136
Input shape: torch.Size([4, 784])
Output shape: torch.Size([4, 784])

Training the Autoencoder

The autoencoder is trained to minimize the reconstruction loss - the difference between input and output:

$$\mathcal{L}_{recon} = \frac{1}{N} \sum_{i=1}^{N} \|x_i - \hat{x}_i\|^2$$

Where:
  • $x_i$ is the original input
  • $\hat{x}_i$ is the reconstructed output
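To connect this formula to the training code below: nn.MSELoss with its default reduction='mean' averages over every element rather than only over samples, so it equals the formula divided by the input dimensionality (784 for flattened MNIST); the two differ only by a constant factor and lead to the same optimum. A quick sanity check on made-up tensors (a sketch, not part of the training pipeline):
# Toy batch of 4 "flattened images" and pretend reconstructions
x = torch.rand(4, 784)
x_hat = torch.rand(4, 784)

# Formula above: per-sample squared error, averaged over the batch
l_recon = ((x - x_hat) ** 2).sum(dim=1).mean()

# nn.MSELoss averages over all 4 * 784 elements
mse = nn.MSELoss()(x_hat, x)

print(torch.allclose(l_recon / 784, mse))  # True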
def train_autoencoder(model, train_loader, num_epochs=20, lr=1e-3):
    """
    Train autoencoder with reconstruction loss.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()  # Mean Squared Error for reconstruction
    
    losses = []
    
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        
        for batch_idx, (data, _) in enumerate(train_loader):
            # Flatten images
            data = data.view(data.size(0), -1).to(device)
            
            optimizer.zero_grad()
            
            # Forward pass
            reconstructed = model(data)
            
            # Reconstruction loss
            loss = criterion(reconstructed, data)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(train_loader)
        losses.append(avg_loss)
        
        if (epoch + 1) % 5 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] | Loss: {avg_loss:.6f}")
    
    return losses


# Train the autoencoder
print("Training autoencoder...")
losses = train_autoencoder(autoencoder, train_loader, num_epochs=20)

# Plot loss curve
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Reconstruction Loss')
plt.title('Autoencoder Training Loss')
plt.grid(True)
plt.savefig('autoencoder_loss.png')
plt.close()
Output:
Training autoencoder...
Epoch [5/20] | Loss: 0.024512
Epoch [10/20] | Loss: 0.018234
Epoch [15/20] | Loss: 0.015891
Epoch [20/20] | Loss: 0.014523

Visualizing Reconstructions

Let’s see how well our autoencoder reconstructs images:
def visualize_reconstructions(model, test_loader, n_samples=10):
    """
    Compare original images with their reconstructions.
    """
    model.eval()
    
    # Get a batch of test images
    data, labels = next(iter(test_loader))
    data = data[:n_samples]
    
    with torch.no_grad():
        data_flat = data.view(data.size(0), -1).to(device)
        reconstructed = model(data_flat)
        reconstructed = reconstructed.view(-1, 1, 28, 28).cpu()
    
    # Plot original vs reconstructed
    fig, axes = plt.subplots(2, n_samples, figsize=(15, 3))
    
    for i in range(n_samples):
        # Original
        axes[0, i].imshow(data[i].squeeze(), cmap='gray')
        axes[0, i].axis('off')
        if i == 0:
            axes[0, i].set_title('Original', fontsize=12)
        
        # Reconstructed
        axes[1, i].imshow(reconstructed[i].squeeze(), cmap='gray')
        axes[1, i].axis('off')
        if i == 0:
            axes[1, i].set_title('Reconstructed', fontsize=12)
    
    plt.tight_layout()
    plt.savefig('reconstructions.png')
    plt.close()
    print("Reconstructions saved to 'reconstructions.png'")


visualize_reconstructions(autoencoder, test_loader)

Latent Space Visualization

[Figure: Latent Space Visualization]
The latent space is where the magic happens. Let’s visualize it using t-SNE:
from sklearn.manifold import TSNE

def visualize_latent_space(model, test_loader, n_samples=3000):
    """
    Visualize latent space using t-SNE.
    """
    model.eval()
    
    latent_vectors = []
    labels_list = []
    
    with torch.no_grad():
        for data, labels in test_loader:
            data = data.view(data.size(0), -1).to(device)
            z = model.encode(data)
            latent_vectors.append(z.cpu())
            labels_list.append(labels)
            
            if sum(len(l) for l in latent_vectors) >= n_samples:
                break
    
    # Concatenate all latent vectors
    latent_vectors = torch.cat(latent_vectors, dim=0)[:n_samples].numpy()
    labels_list = torch.cat(labels_list, dim=0)[:n_samples].numpy()
    
    # Apply t-SNE
    print("Running t-SNE (this may take a minute)...")
    tsne = TSNE(n_components=2, random_state=42, perplexity=30)
    latent_2d = tsne.fit_transform(latent_vectors)
    
    # Plot
    plt.figure(figsize=(12, 10))
    scatter = plt.scatter(latent_2d[:, 0], latent_2d[:, 1], 
                         c=labels_list, cmap='tab10', alpha=0.6, s=5)
    plt.colorbar(scatter, label='Digit')
    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.title('Latent Space Visualization (t-SNE)')
    plt.savefig('latent_space_tsne.png', dpi=150)
    plt.close()
    print("Latent space visualization saved!")


visualize_latent_space(autoencoder, test_loader)

Convolutional Autoencoder

For images, convolutional autoencoders preserve spatial structure:
class ConvAutoencoder(nn.Module):
    """
    Convolutional Autoencoder - better for images.
    Uses Conv2d for encoding and ConvTranspose2d for decoding.
    """
    def __init__(self, latent_dim=64):
        super().__init__()
        self.latent_dim = latent_dim
        
        # Encoder: 28x28 -> 7x7 -> latent
        self.encoder = nn.Sequential(
            # 28x28 -> 14x14
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            
            # 14x14 -> 7x7
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            # Flatten
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, latent_dim)
        )
        
        # Decoder: latent -> 7x7 -> 28x28
        self.decoder_fc = nn.Linear(latent_dim, 64 * 7 * 7)
        
        self.decoder_conv = nn.Sequential(
            # 7x7 -> 14x14
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            
            # 14x14 -> 28x28
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )
    
    def encode(self, x):
        return self.encoder(x)
    
    def decode(self, z):
        x = self.decoder_fc(z)
        x = x.view(-1, 64, 7, 7)
        return self.decoder_conv(x)
    
    def forward(self, x):
        z = self.encode(x)
        return self.decode(z)


# Create and test
conv_ae = ConvAutoencoder(latent_dim=64).to(device)
print(f"Conv Autoencoder parameters: {sum(p.numel() for p in conv_ae.parameters()):,}")

# Test with image batch
test_images = torch.randn(4, 1, 28, 28).to(device)
output = conv_ae(test_images)
print(f"Input shape: {test_images.shape}")
print(f"Output shape: {output.shape}")
Output:
Conv Autoencoder parameters: 442,433
Input shape: torch.Size([4, 1, 28, 28])
Output shape: torch.Size([4, 1, 28, 28])

Denoising Autoencoder

[Figure: Denoising Autoencoder]
A denoising autoencoder learns to remove noise from corrupted inputs:

$$\mathcal{L}_{denoise} = \|x - D(E(\tilde{x}))\|^2$$

where $\tilde{x} = x + \epsilon$ is the noisy input.
class DenoisingAutoencoder(nn.Module):
    """
    Denoising Autoencoder - learns to remove noise.
    """
    def __init__(self, input_dim=784, latent_dim=64, noise_factor=0.3):
        super().__init__()
        self.noise_factor = noise_factor
        
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim)
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid()
        )
    
    def add_noise(self, x):
        """Add Gaussian noise to input."""
        noise = torch.randn_like(x) * self.noise_factor
        noisy = x + noise
        return torch.clamp(noisy, 0, 1)  # Keep in valid range
    
    def forward(self, x, add_noise=True):
        if add_noise:
            x = self.add_noise(x)
        z = self.encoder(x)
        return self.decoder(z)


def train_denoising_ae(model, train_loader, num_epochs=20):
    """
    Train denoising autoencoder.
    """
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        
        for data, _ in train_loader:
            data = data.view(data.size(0), -1).to(device)
            
            optimizer.zero_grad()
            
            # Forward pass with noisy input
            noisy_data = model.add_noise(data)
            reconstructed = model(noisy_data, add_noise=False)
            
            # Loss against CLEAN data
            loss = criterion(reconstructed, data)
            
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
        
        if (epoch + 1) % 5 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] | Loss: {epoch_loss/len(train_loader):.6f}")


# Train denoising autoencoder
denoising_ae = DenoisingAutoencoder(noise_factor=0.5).to(device)
print("Training Denoising Autoencoder...")
train_denoising_ae(denoising_ae, train_loader, num_epochs=20)


def visualize_denoising(model, test_loader, n_samples=10):
    """
    Show: Original -> Noisy -> Denoised
    """
    model.eval()
    
    data, _ = next(iter(test_loader))
    data = data[:n_samples]
    data_flat = data.view(data.size(0), -1).to(device)
    
    with torch.no_grad():
        noisy = model.add_noise(data_flat)
        denoised = model(noisy, add_noise=False)
    
    fig, axes = plt.subplots(3, n_samples, figsize=(15, 4.5))
    
    for i in range(n_samples):
        axes[0, i].imshow(data[i].squeeze(), cmap='gray')
        axes[0, i].axis('off')
        
        axes[1, i].imshow(noisy[i].cpu().view(28, 28), cmap='gray')
        axes[1, i].axis('off')
        
        axes[2, i].imshow(denoised[i].cpu().view(28, 28), cmap='gray')
        axes[2, i].axis('off')
    
    axes[0, 0].set_title('Original', fontsize=12)
    axes[1, 0].set_title('Noisy', fontsize=12)
    axes[2, 0].set_title('Denoised', fontsize=12)
    
    plt.tight_layout()
    plt.savefig('denoising_results.png')
    plt.close()
    print("Denoising results saved!")


visualize_denoising(denoising_ae, test_loader)

Variational Autoencoder (VAE)

[Figure: Variational Autoencoder Architecture]
VAEs are generative models that learn a probabilistic latent space. Instead of encoding to fixed points, VAEs encode to distributions.

Key Differences from Standard Autoencoders

| Aspect | Standard AE | VAE |
|---|---|---|
| Latent representation | Deterministic point | Probability distribution |
| Encoder output | Single vector $z$ | Mean $\mu$ and variance $\sigma^2$ |
| Sampling | Not applicable | Sample from $\mathcal{N}(\mu, \sigma^2)$ |
| Generation | Poor | Can generate new samples |

The VAE Objective: ELBO

The Evidence Lower BOund (ELBO) is:

$$\mathcal{L}_{VAE} = \underbrace{\mathbb{E}_{q(z|x)}[\log p(x|z)]}_{\text{Reconstruction}} - \underbrace{D_{KL}(q(z|x) \,\|\, p(z))}_{\text{KL Divergence}}$$
| Term | Meaning | Purpose |
|---|---|---|
| Reconstruction term | Expected log-likelihood | Make outputs similar to inputs |
| KL Divergence | Distance from prior | Keep latent distribution close to $\mathcal{N}(0, I)$ |

The KL divergence for Gaussians has a closed form:

$$D_{KL}(q(z|x) \,\|\, p(z)) = -\frac{1}{2} \sum_{j=1}^{J}\left(1 + \log(\sigma_j^2) - \mu_j^2 - \sigma_j^2\right)$$
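The closed form is easy to verify numerically against torch.distributions. The snippet below is a small sanity check on made-up encoder outputs (not part of the tutorial's training code); it confirms that the expression above equals the KL divergence between $\mathcal{N}(\mu, \sigma^2)$ and the standard normal prior, summed over latent dimensions:
from torch.distributions import Normal, kl_divergence

# Made-up encoder outputs for one example with a 4-dimensional latent space
mu = torch.tensor([0.5, -1.0, 0.3, 2.0])
log_var = torch.tensor([0.1, -0.5, 0.0, 0.7])
std = torch.exp(0.5 * log_var)

# Closed-form KL(q(z|x) || N(0, I)), summed over latent dimensions
kl_closed = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# The same quantity via torch.distributions
q = Normal(mu, std)
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_ref = kl_divergence(q, p).sum()

print(torch.allclose(kl_closed, kl_ref))  # True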
class VAE(nn.Module):
    """
    Variational Autoencoder.
    Encoder outputs mu and log_var (for numerical stability).
    """
    def __init__(self, input_dim=784, hidden_dim=256, latent_dim=20):
        super().__init__()
        self.latent_dim = latent_dim
        
        # Encoder: x -> hidden
        self.encoder_layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Latent space parameters
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        
        # Decoder: z -> x
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )
    
    def encode(self, x):
        """
        Encode input to latent distribution parameters.
        Returns: mu, log_var
        """
        h = self.encoder_layers(x)
        mu = self.fc_mu(h)
        log_var = self.fc_logvar(h)
        return mu, log_var
    
    def reparameterize(self, mu, log_var):
        """
        Reparameterization trick: z = mu + std * epsilon
        Allows backpropagation through random sampling.
        """
        std = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        z = mu + std * epsilon
        return z
    
    def decode(self, z):
        """Decode latent vector to reconstruction."""
        return self.decoder(z)
    
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        reconstruction = self.decode(z)
        return reconstruction, mu, log_var


def vae_loss(reconstruction, original, mu, log_var, beta=1.0):
    """
    VAE Loss = Reconstruction Loss + β * KL Divergence
    
    Args:
        reconstruction: Decoder output
        original: Original input
        mu: Mean of latent distribution
        log_var: Log variance of latent distribution
        beta: Weight for KL divergence (beta-VAE)
    """
    # Reconstruction loss (binary cross entropy)
    recon_loss = F.binary_cross_entropy(reconstruction, original, reduction='sum')
    
    # KL divergence: -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    
    return recon_loss + beta * kl_loss, recon_loss, kl_loss


# Create VAE
vae = VAE(latent_dim=20).to(device)
print(f"VAE parameters: {sum(p.numel() for p in vae.parameters()):,}")
Output:
VAE parameters: 549,688

The Reparameterization Trick

[Figure: Reparameterization Trick]
The reparameterization trick is key to training VAEs. Here’s why:

Problem: We need to sample $z \sim \mathcal{N}(\mu, \sigma^2)$, but sampling is not differentiable!

Solution: Instead of sampling directly, we sample $\epsilon \sim \mathcal{N}(0, I)$ and compute:

$$z = \mu + \sigma \odot \epsilon$$

Now the gradient can flow through $\mu$ and $\sigma$!
def visualize_reparameterization():
    """
    Demonstrate the reparameterization trick.
    """
    # Parameters (learned by encoder)
    mu = torch.tensor([2.0])
    log_var = torch.tensor([0.5])  # variance = exp(0.5) ≈ 1.65
    
    # Standard deviation
    std = torch.exp(0.5 * log_var)
    
    # Sample epsilon from N(0, 1)
    n_samples = 1000
    epsilon = torch.randn(n_samples)
    
    # Reparameterized samples
    z_samples = mu + std * epsilon
    
    # Verify distribution
    print(f"Target mean: {mu.item():.2f}, Sample mean: {z_samples.mean():.2f}")
    print(f"Target std: {std.item():.2f}, Sample std: {z_samples.std():.2f}")
    
    # Plot
    plt.figure(figsize=(10, 4))
    plt.hist(z_samples.numpy(), bins=50, density=True, alpha=0.7)
    plt.axvline(mu.item(), color='r', linestyle='--', label=f'μ = {mu.item():.1f}')
    plt.xlabel('z')
    plt.ylabel('Density')
    plt.title('Samples using Reparameterization Trick')
    plt.legend()
    plt.savefig('reparameterization_demo.png')
    plt.close()


visualize_reparameterization()
Output:
Target mean: 2.00, Sample mean: 2.01
Target std: 1.28, Sample std: 1.29

Training the VAE

def train_vae(model, train_loader, num_epochs=30, lr=1e-3, beta=1.0):
    """
    Train VAE with ELBO loss.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    train_losses = []
    recon_losses = []
    kl_losses = []
    
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        epoch_recon = 0
        epoch_kl = 0
        
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.view(data.size(0), -1).to(device)
            
            optimizer.zero_grad()
            
            # Forward pass
            reconstruction, mu, log_var = model(data)
            
            # Calculate loss
            loss, recon_loss, kl_loss = vae_loss(
                reconstruction, data, mu, log_var, beta=beta
            )
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            epoch_recon += recon_loss.item()
            epoch_kl += kl_loss.item()
        
        # Average losses
        n = len(train_loader.dataset)
        train_losses.append(epoch_loss / n)
        recon_losses.append(epoch_recon / n)
        kl_losses.append(epoch_kl / n)
        
        if (epoch + 1) % 5 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] | "
                  f"Total: {train_losses[-1]:.4f} | "
                  f"Recon: {recon_losses[-1]:.4f} | "
                  f"KL: {kl_losses[-1]:.4f}")
    
    return train_losses, recon_losses, kl_losses


# Train VAE
print("Training VAE...")
train_losses, recon_losses, kl_losses = train_vae(vae, train_loader, num_epochs=30)

# Plot losses
plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.plot(train_losses)
plt.title('Total Loss')
plt.xlabel('Epoch')

plt.subplot(1, 3, 2)
plt.plot(recon_losses)
plt.title('Reconstruction Loss')
plt.xlabel('Epoch')

plt.subplot(1, 3, 3)
plt.plot(kl_losses)
plt.title('KL Divergence')
plt.xlabel('Epoch')

plt.tight_layout()
plt.savefig('vae_training.png')
plt.close()
Output:
Training VAE...
Epoch [5/30] | Total: 163.4521 | Recon: 151.2134 | KL: 12.2387
Epoch [10/30] | Total: 141.8934 | Recon: 127.5621 | KL: 14.3313
Epoch [15/30] | Total: 135.2178 | Recon: 119.8934 | KL: 15.3244
Epoch [20/30] | Total: 131.5623 | Recon: 115.4521 | KL: 16.1102
Epoch [25/30] | Total: 129.3421 | Recon: 112.8934 | KL: 16.4487
Epoch [30/30] | Total: 127.8912 | Recon: 111.2345 | KL: 16.6567

Generating New Samples

The true power of VAEs - generating new data by sampling from the latent space!
def generate_samples(model, n_samples=64):
    """
    Generate new samples by sampling from the prior N(0, I).
    """
    model.eval()
    
    with torch.no_grad():
        # Sample from standard normal
        z = torch.randn(n_samples, model.latent_dim).to(device)
        
        # Decode
        samples = model.decode(z)
        samples = samples.view(-1, 1, 28, 28).cpu()
    
    return samples


def visualize_generated(model, n_samples=64):
    """
    Display grid of generated samples.
    """
    samples = generate_samples(model, n_samples)
    
    # Create grid
    n_row = int(np.sqrt(n_samples))
    fig, axes = plt.subplots(n_row, n_row, figsize=(10, 10))
    
    for i, ax in enumerate(axes.flat):
        ax.imshow(samples[i].squeeze(), cmap='gray')
        ax.axis('off')
    
    plt.suptitle('VAE Generated Samples', fontsize=16)
    plt.tight_layout()
    plt.savefig('vae_generated.png')
    plt.close()
    print(f"Generated {n_samples} new samples!")


visualize_generated(vae, n_samples=64)

Latent Space Interpolation

[Figure: Latent Space Interpolation]
We can smoothly transition between images by interpolating in latent space:
def interpolate_latent(model, start_img, end_img, n_steps=10):
    """
    Interpolate between two images in latent space.
    """
    model.eval()
    
    with torch.no_grad():
        # Encode both images
        start_flat = start_img.view(1, -1).to(device)
        end_flat = end_img.view(1, -1).to(device)
        
        start_mu, start_logvar = model.encode(start_flat)
        end_mu, end_logvar = model.encode(end_flat)
        
        # Interpolate between means
        interpolations = []
        for alpha in np.linspace(0, 1, n_steps):
            z = (1 - alpha) * start_mu + alpha * end_mu
            decoded = model.decode(z)
            interpolations.append(decoded.view(28, 28).cpu())
    
    return interpolations


def visualize_interpolation(model, test_loader):
    """
    Show interpolation between two random digits.
    """
    # Get two different digits
    data, labels = next(iter(test_loader))
    
    # Find a 3 and a 7
    idx_3 = (labels == 3).nonzero()[0].item()
    idx_7 = (labels == 7).nonzero()[0].item()
    
    img_3 = data[idx_3]
    img_7 = data[idx_7]
    
    # Interpolate
    interpolations = interpolate_latent(vae, img_3, img_7, n_steps=10)
    
    # Plot
    fig, axes = plt.subplots(1, 10, figsize=(15, 1.5))
    
    for i, ax in enumerate(axes):
        ax.imshow(interpolations[i], cmap='gray')
        ax.axis('off')
        ax.set_title(f'{i/(len(axes)-1):.1f}')
    
    plt.suptitle('Latent Space Interpolation: 3 → 7', fontsize=14)
    plt.tight_layout()
    plt.savefig('interpolation.png')
    plt.close()
    print("Interpolation saved!")


visualize_interpolation(vae, test_loader)

Convolutional VAE

For better image generation, use convolutional layers:
class ConvVAE(nn.Module):
    """
    Convolutional VAE for better image generation.
    """
    def __init__(self, latent_dim=32):
        super().__init__()
        self.latent_dim = latent_dim
        
        # Encoder
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),  # 28->14
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 14->7
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # 7->4
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        
        self.flatten_size = 128 * 4 * 4
        
        self.fc_mu = nn.Linear(self.flatten_size, latent_dim)
        self.fc_logvar = nn.Linear(self.flatten_size, latent_dim)
        
        # Decoder
        self.decoder_fc = nn.Linear(latent_dim, self.flatten_size)
        
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=0),  # 4->7
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # 7->14
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # 14->28
            nn.Sigmoid()
        )
    
    def encode(self, x):
        h = self.encoder_conv(x)
        h = h.view(h.size(0), -1)
        return self.fc_mu(h), self.fc_logvar(h)
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps
    
    def decode(self, z):
        h = self.decoder_fc(z)
        h = h.view(-1, 128, 4, 4)
        return self.decoder_conv(h)
    
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var


# Create Conv VAE
conv_vae = ConvVAE(latent_dim=32).to(device)
print(f"Conv VAE parameters: {sum(p.numel() for p in conv_vae.parameters()):,}")

# Test
test_batch = torch.randn(4, 1, 28, 28).to(device)
output, mu, logvar = conv_vae(test_batch)
print(f"Input: {test_batch.shape}, Output: {output.shape}")
print(f"Latent: mu {mu.shape}, logvar {logvar.shape}")

Beta-VAE: Disentangled Representations

[Figure: Beta-VAE Disentanglement]
Beta-VAE encourages disentangled representations by increasing the weight of the KL divergence:

$$\mathcal{L}_{\beta\text{-VAE}} = \mathbb{E}[\log p(x|z)] - \beta \cdot D_{KL}(q(z|x) \,\|\, p(z))$$

| β Value | Effect |
|---|---|
| β = 1 | Standard VAE |
| β > 1 | More disentanglement, lower reconstruction quality |
| β < 1 | Better reconstruction, more entangled |
def train_beta_vae(model, train_loader, num_epochs=30, beta=4.0):
    """
    Train Beta-VAE for disentangled representations.
    """
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        
        for data, _ in train_loader:
            # Branch on the model type, not the batch shape: MNIST batches are
            # always 4D, so the linear VAE must be given flattened inputs.
            if hasattr(model, 'encoder_conv'):
                data = data.to(device)  # Conv VAE: keep (B, 1, 28, 28)
            else:
                data = data.view(data.size(0), -1).to(device)  # Linear VAE: (B, 784)
            
            optimizer.zero_grad()
            
            if hasattr(model, 'encoder_conv'):
                # Conv VAE
                reconstruction, mu, log_var = model(data)
                recon_loss = F.binary_cross_entropy(
                    reconstruction.view(-1), data.view(-1), reduction='sum'
                )
            else:
                # Linear VAE
                reconstruction, mu, log_var = model(data)
                recon_loss = F.binary_cross_entropy(
                    reconstruction, data.view(-1, 784), reduction='sum'
                )
            
            kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
            
            # Beta-weighted loss
            loss = recon_loss + beta * kl_loss
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] | "
                  f"Loss: {total_loss/len(train_loader.dataset):.4f}")
    
    return model


# Train Beta-VAE with different beta values
print("\nTraining Beta-VAE (β=4)...")
beta_vae = VAE(latent_dim=10).to(device)
train_beta_vae(beta_vae, train_loader, num_epochs=20, beta=4.0)

Exercises

Exercise 1: Sparse Autoencoder. Add an L1 sparsity constraint to encourage sparse representations.
# TODO: Modify the autoencoder to add sparsity regularization
# Hint: Add L1 penalty on latent activations

class SparseAutoencoder(nn.Module):
    def __init__(self, input_dim=784, latent_dim=64, sparsity_weight=1e-3):
        # Your implementation here
        pass
Exercise 2: Conditional VAE. Create a VAE that can generate specific digits by conditioning on class labels.
# TODO: Implement CVAE that takes class label as input
# The encoder and decoder should both receive the class information

class ConditionalVAE(nn.Module):
    def __init__(self, input_dim=784, latent_dim=20, num_classes=10):
        # Your implementation here
        pass
Exercise 3: VQ-VAE. A Vector Quantized VAE uses discrete latent codes instead of continuous ones; implement the vector quantization layer.
# TODO: Implement the vector quantization layer
# Hint: Map continuous vectors to nearest codebook entries

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings=512, embedding_dim=64):
        # Your implementation here
        pass

Key Takeaways

What You Learned:
  • Autoencoders - Encoder-decoder architecture with bottleneck for compression
  • Latent Space - Lower-dimensional representation that captures essential features
  • Denoising AE - Learn to remove noise by training with corrupted inputs
  • VAE Theory - Probabilistic latent space with ELBO objective
  • KL Divergence - Regularizes latent space to match prior distribution
  • Reparameterization - Enables backpropagation through sampling
  • Generation - Sample from latent space to create new data
  • Beta-VAE - Control disentanglement with β hyperparameter

Common Pitfalls

Autoencoder Mistakes to Avoid:
  1. Latent dim too large - No compression = just memorization
  2. Latent dim too small - Poor reconstructions, lost information
  3. Ignoring KL collapse - VAE may ignore the latent space; use β-VAE or KL annealing (see the sketch after this list)
  4. Wrong reconstruction loss - Use BCE for [0,1] images, MSE otherwise
  5. Not normalizing inputs - Autoencoders work best with normalized data
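KL annealing is mentioned in pitfall 3 but not shown elsewhere in this tutorial, so here is a minimal sketch of the idea: ramp the KL weight from 0 up to its target value over the first few epochs so the decoder learns to reconstruct before the KL term pulls the posterior toward the prior. The linear schedule and the warmup_epochs=10 default below are illustrative assumptions, not a tuned recipe; the loop reuses the vae_loss function defined earlier.
def kl_annealing_weight(epoch, warmup_epochs=10, beta_max=1.0):
    """Linearly ramp the KL weight from 0 to beta_max over warmup_epochs."""
    return beta_max * min(1.0, (epoch + 1) / warmup_epochs)


def train_vae_with_annealing(model, train_loader, num_epochs=30, lr=1e-3):
    """Same loop as train_vae, but with a per-epoch KL weight."""
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        model.train()
        beta = kl_annealing_weight(epoch)  # small early on, 1.0 after warm-up
        epoch_loss = 0

        for data, _ in train_loader:
            data = data.view(data.size(0), -1).to(device)
            optimizer.zero_grad()

            reconstruction, mu, log_var = model(data)
            loss, _, _ = vae_loss(reconstruction, data, mu, log_var, beta=beta)

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        if (epoch + 1) % 5 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] | beta: {beta:.2f} | "
                  f"Loss: {epoch_loss / len(train_loader.dataset):.4f}")


# Example usage (hypothetical): train a fresh VAE with KL annealing
# train_vae_with_annealing(VAE(latent_dim=20).to(device), train_loader)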

Next: Diffusion Models

Learn about the cutting-edge generative models behind Stable Diffusion and DALL-E