Autoencoders & Variational Autoencoders
The Bottleneck Concept
Imagine you need to describe a complex image using only 10 numbers. You'd have to capture the essential features and discard the noise. That's exactly what an autoencoder does! An autoencoder learns to:
- Compress data into a lower-dimensional representation (encoding)
- Reconstruct the original data from this compressed form (decoding)
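To make the bottleneck concrete before any training, here is a minimal sketch with untrained linear layers (shapes only; the `encoder`/`decoder` names here are illustrative, not the model built below):

```python
import torch
import torch.nn as nn

# Untrained layers, shown only to make the shapes concrete
encoder = nn.Linear(784, 10)   # 784 pixels -> a 10-number description
decoder = nn.Linear(10, 784)   # 10 numbers -> a 784-pixel reconstruction

x = torch.rand(1, 784)         # one flattened 28x28 image
code = encoder(x)
x_hat = decoder(code)
print(code.shape)   # torch.Size([1, 10])
print(x_hat.shape)  # torch.Size([1, 784])
```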
First, the standard setup: imports, seeds, and the MNIST data.

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
# Set random seeds
torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load MNIST
transform = transforms.Compose([
transforms.ToTensor(),
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
Standard Autoencoder
| Component | Function | Shape (MNIST) |
|---|---|---|
| Encoder | Compress input to latent space | 784 → 256 → 128 → 64 → 32 |
| Latent Space | Compressed representation | 32 dimensions |
| Decoder | Reconstruct from latent space | 32 → 64 → 128 → 256 → 784 |
```python
class Autoencoder(nn.Module):
"""
Standard Autoencoder with symmetric encoder-decoder.
"""
def __init__(self, input_dim=784, latent_dim=32):
super().__init__()
self.latent_dim = latent_dim
# Encoder: input -> latent
self.encoder = nn.Sequential(
nn.Linear(input_dim, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, latent_dim)
)
# Decoder: latent -> output
self.decoder = nn.Sequential(
nn.Linear(latent_dim, 64),
nn.ReLU(),
nn.Linear(64, 128),
nn.ReLU(),
nn.Linear(128, 256),
nn.ReLU(),
nn.Linear(256, input_dim),
nn.Sigmoid() # Output in [0, 1] for images
)
def encode(self, x):
"""Compress input to latent representation."""
return self.encoder(x)
def decode(self, z):
"""Reconstruct from latent representation."""
return self.decoder(z)
def forward(self, x):
z = self.encode(x)
return self.decode(z)
# Create autoencoder
autoencoder = Autoencoder(latent_dim=32).to(device)
print(f"Autoencoder parameters: {sum(p.numel() for p in autoencoder.parameters()):,}")
# Test forward pass
test_input = torch.randn(4, 784).to(device)
output = autoencoder(test_input)
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {output.shape}")
```
Autoencoder parameters: 489,136
Input shape: torch.Size([4, 784])
Output shape: torch.Size([4, 784])
```
Training the Autoencoder
The autoencoder is trained to minimize the reconstruction loss, i.e. the difference between input and output:

$$\mathcal{L}_{\text{recon}} = \frac{1}{N} \sum_{i=1}^{N} \| x_i - \hat{x}_i \|^2$$

where:
- $x_i$ is the original input
- $\hat{x}_i$ is the reconstructed output
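As a quick sanity check on this formula, here is a minimal sketch (random tensors stand in for real images). Note that PyTorch's `nn.MSELoss` with the default `reduction='mean'` averages over every element, so it equals the per-sample formula above divided by the input dimension:

```python
import torch
import torch.nn as nn

x = torch.rand(8, 784)       # stand-in originals
x_hat = torch.rand(8, 784)   # stand-in reconstructions

# Formula above: mean over samples of the squared L2 norm
per_sample = ((x - x_hat) ** 2).sum(dim=1).mean()

# nn.MSELoss averages over all N * 784 elements
mse = nn.MSELoss()(x_hat, x)

print(torch.allclose(per_sample / 784, mse))  # True
```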
```python
def train_autoencoder(model, train_loader, num_epochs=20, lr=1e-3):
"""
Train autoencoder with reconstruction loss.
"""
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss() # Mean Squared Error for reconstruction
losses = []
for epoch in range(num_epochs):
model.train()
epoch_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
# Flatten images
data = data.view(data.size(0), -1).to(device)
optimizer.zero_grad()
# Forward pass
reconstructed = model(data)
# Reconstruction loss
loss = criterion(reconstructed, data)
# Backward pass
loss.backward()
optimizer.step()
epoch_loss += loss.item()
avg_loss = epoch_loss / len(train_loader)
losses.append(avg_loss)
if (epoch + 1) % 5 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}] | Loss: {avg_loss:.6f}")
return losses
# Train the autoencoder
print("Training autoencoder...")
losses = train_autoencoder(autoencoder, train_loader, num_epochs=20)
# Plot loss curve
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Reconstruction Loss')
plt.title('Autoencoder Training Loss')
plt.grid(True)
plt.savefig('autoencoder_loss.png')
plt.close()
```
```
Training autoencoder...
Epoch [5/20] | Loss: 0.024512
Epoch [10/20] | Loss: 0.018234
Epoch [15/20] | Loss: 0.015891
Epoch [20/20] | Loss: 0.014523
```
Visualizing Reconstructions
Let's see how well our autoencoder reconstructs images:

```python
def visualize_reconstructions(model, test_loader, n_samples=10):
"""
Compare original images with their reconstructions.
"""
model.eval()
# Get a batch of test images
data, labels = next(iter(test_loader))
data = data[:n_samples]
with torch.no_grad():
data_flat = data.view(data.size(0), -1).to(device)
reconstructed = model(data_flat)
reconstructed = reconstructed.view(-1, 1, 28, 28).cpu()
# Plot original vs reconstructed
fig, axes = plt.subplots(2, n_samples, figsize=(15, 3))
for i in range(n_samples):
# Original
axes[0, i].imshow(data[i].squeeze(), cmap='gray')
axes[0, i].axis('off')
if i == 0:
axes[0, i].set_title('Original', fontsize=12)
# Reconstructed
axes[1, i].imshow(reconstructed[i].squeeze(), cmap='gray')
axes[1, i].axis('off')
if i == 0:
axes[1, i].set_title('Reconstructed', fontsize=12)
plt.tight_layout()
plt.savefig('reconstructions.png')
plt.close()
print("Reconstructions saved to 'reconstructions.png'")
visualize_reconstructions(autoencoder, test_loader)
```
Latent Space Visualization
To see what the encoder has learned, we can project the latent codes to two dimensions with t-SNE and color them by digit class:

```python
from sklearn.manifold import TSNE
def visualize_latent_space(model, test_loader, n_samples=3000):
"""
Visualize latent space using t-SNE.
"""
model.eval()
latent_vectors = []
labels_list = []
with torch.no_grad():
for data, labels in test_loader:
data = data.view(data.size(0), -1).to(device)
z = model.encode(data)
latent_vectors.append(z.cpu())
labels_list.append(labels)
if sum(len(l) for l in latent_vectors) >= n_samples:
break
# Concatenate all latent vectors
latent_vectors = torch.cat(latent_vectors, dim=0)[:n_samples].numpy()
labels_list = torch.cat(labels_list, dim=0)[:n_samples].numpy()
# Apply t-SNE
print("Running t-SNE (this may take a minute)...")
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
latent_2d = tsne.fit_transform(latent_vectors)
# Plot
plt.figure(figsize=(12, 10))
scatter = plt.scatter(latent_2d[:, 0], latent_2d[:, 1],
c=labels_list, cmap='tab10', alpha=0.6, s=5)
plt.colorbar(scatter, label='Digit')
plt.xlabel('t-SNE Dimension 1')
plt.ylabel('t-SNE Dimension 2')
plt.title('Latent Space Visualization (t-SNE)')
plt.savefig('latent_space_tsne.png', dpi=150)
plt.close()
print("Latent space visualization saved!")
visualize_latent_space(autoencoder, test_loader)
```
Convolutional Autoencoder
For images, convolutional autoencoders preserve spatial structure:

```python
class ConvAutoencoder(nn.Module):
"""
Convolutional Autoencoder - better for images.
Uses Conv2d for encoding and ConvTranspose2d for decoding.
"""
def __init__(self, latent_dim=64):
super().__init__()
self.latent_dim = latent_dim
# Encoder: 28x28 -> 7x7 -> latent
self.encoder = nn.Sequential(
# 28x28 -> 14x14
nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
# 14x14 -> 7x7
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
# Flatten
nn.Flatten(),
nn.Linear(64 * 7 * 7, latent_dim)
)
# Decoder: latent -> 7x7 -> 28x28
self.decoder_fc = nn.Linear(latent_dim, 64 * 7 * 7)
self.decoder_conv = nn.Sequential(
# 7x7 -> 14x14
nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
# 14x14 -> 28x28
nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.Sigmoid()
)
def encode(self, x):
return self.encoder(x)
def decode(self, z):
x = self.decoder_fc(z)
x = x.view(-1, 64, 7, 7)
return self.decoder_conv(x)
def forward(self, x):
z = self.encode(x)
return self.decode(z)
# Create and test
conv_ae = ConvAutoencoder(latent_dim=64).to(device)
print(f"Conv Autoencoder parameters: {sum(p.numel() for p in conv_ae.parameters()):,}")
# Test with image batch
test_images = torch.randn(4, 1, 28, 28).to(device)
output = conv_ae(test_images)
print(f"Input shape: {test_images.shape}")
print(f"Output shape: {output.shape}")
```
Conv Autoencoder parameters: 442,433
Input shape: torch.Size([4, 1, 28, 28])
Output shape: torch.Size([4, 1, 28, 28])
```
Denoising Autoencoder
A denoising autoencoder is trained to reconstruct the clean input from a corrupted version; since the noise cannot be memorized, the latent code is forced to capture robust structure:

```python
class DenoisingAutoencoder(nn.Module):
"""
Denoising Autoencoder - learns to remove noise.
"""
def __init__(self, input_dim=784, latent_dim=64, noise_factor=0.3):
super().__init__()
self.noise_factor = noise_factor
self.encoder = nn.Sequential(
nn.Linear(input_dim, 256),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, latent_dim)
)
self.decoder = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.ReLU(),
nn.Linear(128, 256),
nn.ReLU(),
nn.Linear(256, input_dim),
nn.Sigmoid()
)
def add_noise(self, x):
"""Add Gaussian noise to input."""
noise = torch.randn_like(x) * self.noise_factor
noisy = x + noise
return torch.clamp(noisy, 0, 1) # Keep in valid range
def forward(self, x, add_noise=True):
if add_noise:
x = self.add_noise(x)
z = self.encoder(x)
return self.decoder(z)
def train_denoising_ae(model, train_loader, num_epochs=20):
"""
Train denoising autoencoder.
"""
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
for epoch in range(num_epochs):
model.train()
epoch_loss = 0
for data, _ in train_loader:
data = data.view(data.size(0), -1).to(device)
optimizer.zero_grad()
# Forward pass with noisy input
noisy_data = model.add_noise(data)
reconstructed = model(noisy_data, add_noise=False)
# Loss against CLEAN data
loss = criterion(reconstructed, data)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
if (epoch + 1) % 5 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}] | Loss: {epoch_loss/len(train_loader):.6f}")
# Train denoising autoencoder
denoising_ae = DenoisingAutoencoder(noise_factor=0.5).to(device)
print("Training Denoising Autoencoder...")
train_denoising_ae(denoising_ae, train_loader, num_epochs=20)
def visualize_denoising(model, test_loader, n_samples=10):
"""
Show: Original -> Noisy -> Denoised
"""
model.eval()
data, _ = next(iter(test_loader))
data = data[:n_samples]
data_flat = data.view(data.size(0), -1).to(device)
with torch.no_grad():
noisy = model.add_noise(data_flat)
denoised = model(noisy, add_noise=False)
fig, axes = plt.subplots(3, n_samples, figsize=(15, 4.5))
for i in range(n_samples):
axes[0, i].imshow(data[i].squeeze(), cmap='gray')
axes[0, i].axis('off')
axes[1, i].imshow(noisy[i].cpu().view(28, 28), cmap='gray')
axes[1, i].axis('off')
axes[2, i].imshow(denoised[i].cpu().view(28, 28), cmap='gray')
axes[2, i].axis('off')
axes[0, 0].set_title('Original', fontsize=12)
axes[1, 0].set_title('Noisy', fontsize=12)
axes[2, 0].set_title('Denoised', fontsize=12)
plt.tight_layout()
plt.savefig('denoising_results.png')
plt.close()
print("Denoising results saved!")
visualize_denoising(denoising_ae, test_loader)
```
Variational Autoencoder (VAE)
Key Differences from Standard Autoencoders
| Aspect | Standard AE | VAE |
|---|---|---|
| Latent representation | Deterministic point | Probability distribution |
| Encoder output | Single vector z | Mean μ and variance σ² |
| Sampling | Not applicable | Sample from N(μ, σ²) |
| Generation | Poor | Can generate new samples |
The VAE Objective: ELBO
The Evidence Lower BOund (ELBO) is:

$$\mathcal{L}_{\text{VAE}} = \underbrace{\mathbb{E}_{q(z|x)}[\log p(x|z)]}_{\text{Reconstruction}} - \underbrace{D_{\text{KL}}\big(q(z|x) \,\|\, p(z)\big)}_{\text{KL Divergence}}$$

Training maximizes the ELBO; in practice we minimize its negative, which is what `vae_loss` below implements.

| Term | Meaning | Purpose |
|---|---|---|
| Reconstruction term | Expected log-likelihood | Make outputs similar to inputs |
| KL Divergence | Distance from prior | Keep latent distribution close to N(0, I) |
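When q(z|x) is a diagonal Gaussian and the prior is N(0, I), the KL term has a closed form. Here is a minimal sketch cross-checking the expression used in `vae_loss` below against `torch.distributions` (the random `mu`/`log_var` stand in for encoder outputs):

```python
import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(5, 20)       # stand-in encoder means
log_var = torch.randn(5, 20)  # stand-in encoder log-variances
std = torch.exp(0.5 * log_var)

# Closed form: -0.5 * (1 + log_var - mu^2 - exp(log_var)), per dimension
closed = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())

# Library computation of KL(N(mu, std) || N(0, 1)), also per dimension
prior = Normal(torch.zeros_like(mu), torch.ones_like(std))
print(torch.allclose(closed, kl_divergence(Normal(mu, std), prior), atol=1e-5))  # True
```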
```python
class VAE(nn.Module):
"""
Variational Autoencoder.
Encoder outputs mu and log_var (for numerical stability).
"""
def __init__(self, input_dim=784, hidden_dim=256, latent_dim=20):
super().__init__()
self.latent_dim = latent_dim
# Encoder: x -> hidden
self.encoder_layers = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
# Latent space parameters
self.fc_mu = nn.Linear(hidden_dim, latent_dim)
self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
# Decoder: z -> x
self.decoder = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, input_dim),
nn.Sigmoid()
)
def encode(self, x):
"""
Encode input to latent distribution parameters.
Returns: mu, log_var
"""
h = self.encoder_layers(x)
mu = self.fc_mu(h)
log_var = self.fc_logvar(h)
return mu, log_var
def reparameterize(self, mu, log_var):
"""
Reparameterization trick: z = mu + std * epsilon
Allows backpropagation through random sampling.
"""
std = torch.exp(0.5 * log_var)
epsilon = torch.randn_like(std)
z = mu + std * epsilon
return z
def decode(self, z):
"""Decode latent vector to reconstruction."""
return self.decoder(z)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
reconstruction = self.decode(z)
return reconstruction, mu, log_var
def vae_loss(reconstruction, original, mu, log_var, beta=1.0):
"""
VAE Loss = Reconstruction Loss + β * KL Divergence
Args:
reconstruction: Decoder output
original: Original input
mu: Mean of latent distribution
log_var: Log variance of latent distribution
beta: Weight for KL divergence (beta-VAE)
"""
# Reconstruction loss (binary cross entropy)
recon_loss = F.binary_cross_entropy(reconstruction, original, reduction='sum')
# KL divergence: -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
return recon_loss + beta * kl_loss, recon_loss, kl_loss
# Create VAE
vae = VAE(latent_dim=20).to(device)
print(f"VAE parameters: {sum(p.numel() for p in vae.parameters()):,}")
```
VAE parameters: 549,688
```
The Reparameterization Trick
Sampling z directly from N(μ, σ²) is not differentiable with respect to μ and σ. The trick is to rewrite the sample as z = μ + σ·ε with ε ~ N(0, I): the randomness moves into ε, so gradients can flow through μ and σ.

```python
def visualize_reparameterization():
"""
Demonstrate the reparameterization trick.
"""
# Parameters (learned by encoder)
mu = torch.tensor([2.0])
log_var = torch.tensor([0.5]) # variance = exp(0.5) ≈ 1.65
# Standard deviation
std = torch.exp(0.5 * log_var)
# Sample epsilon from N(0, 1)
n_samples = 1000
epsilon = torch.randn(n_samples)
# Reparameterized samples
z_samples = mu + std * epsilon
# Verify distribution
print(f"Target mean: {mu.item():.2f}, Sample mean: {z_samples.mean():.2f}")
print(f"Target std: {std.item():.2f}, Sample std: {z_samples.std():.2f}")
# Plot
plt.figure(figsize=(10, 4))
plt.hist(z_samples.numpy(), bins=50, density=True, alpha=0.7)
plt.axvline(mu.item(), color='r', linestyle='--', label=f'μ = {mu.item():.1f}')
plt.xlabel('z')
plt.ylabel('Density')
plt.title('Samples using Reparameterization Trick')
plt.legend()
plt.savefig('reparameterization_demo.png')
plt.close()
visualize_reparameterization()
```
```
Target mean: 2.00, Sample mean: 2.01
Target std: 1.28, Sample std: 1.29
```
Training the VAE
```python
def train_vae(model, train_loader, num_epochs=30, lr=1e-3, beta=1.0):
"""
Train VAE with ELBO loss.
"""
optimizer = optim.Adam(model.parameters(), lr=lr)
train_losses = []
recon_losses = []
kl_losses = []
for epoch in range(num_epochs):
model.train()
epoch_loss = 0
epoch_recon = 0
epoch_kl = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.view(data.size(0), -1).to(device)
optimizer.zero_grad()
# Forward pass
reconstruction, mu, log_var = model(data)
# Calculate loss
loss, recon_loss, kl_loss = vae_loss(
reconstruction, data, mu, log_var, beta=beta
)
# Backward pass
loss.backward()
optimizer.step()
epoch_loss += loss.item()
epoch_recon += recon_loss.item()
epoch_kl += kl_loss.item()
# Average losses
n = len(train_loader.dataset)
train_losses.append(epoch_loss / n)
recon_losses.append(epoch_recon / n)
kl_losses.append(epoch_kl / n)
if (epoch + 1) % 5 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}] | "
f"Total: {train_losses[-1]:.4f} | "
f"Recon: {recon_losses[-1]:.4f} | "
f"KL: {kl_losses[-1]:.4f}")
return train_losses, recon_losses, kl_losses
# Train VAE
print("Training VAE...")
train_losses, recon_losses, kl_losses = train_vae(vae, train_loader, num_epochs=30)
# Plot losses
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.plot(train_losses)
plt.title('Total Loss')
plt.xlabel('Epoch')
plt.subplot(1, 3, 2)
plt.plot(recon_losses)
plt.title('Reconstruction Loss')
plt.xlabel('Epoch')
plt.subplot(1, 3, 3)
plt.plot(kl_losses)
plt.title('KL Divergence')
plt.xlabel('Epoch')
plt.tight_layout()
plt.savefig('vae_training.png')
plt.close()
```
```
Training VAE...
Epoch [5/30] | Total: 163.4521 | Recon: 151.2134 | KL: 12.2387
Epoch [10/30] | Total: 141.8934 | Recon: 127.5621 | KL: 14.3313
Epoch [15/30] | Total: 135.2178 | Recon: 119.8934 | KL: 15.3244
Epoch [20/30] | Total: 131.5623 | Recon: 115.4521 | KL: 16.1102
Epoch [25/30] | Total: 129.3421 | Recon: 112.8934 | KL: 16.4487
Epoch [30/30] | Total: 127.8912 | Recon: 111.2345 | KL: 16.6567
```
Generating New Samples
The true power of VAEs: generating new data by sampling from the latent space!

```python
def generate_samples(model, n_samples=64):
"""
Generate new samples by sampling from the prior N(0, I).
"""
model.eval()
with torch.no_grad():
# Sample from standard normal
z = torch.randn(n_samples, model.latent_dim).to(device)
# Decode
samples = model.decode(z)
samples = samples.view(-1, 1, 28, 28).cpu()
return samples
def visualize_generated(model, n_samples=64):
"""
Display grid of generated samples.
"""
samples = generate_samples(model, n_samples)
# Create grid
n_row = int(np.sqrt(n_samples))
fig, axes = plt.subplots(n_row, n_row, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
ax.imshow(samples[i].squeeze(), cmap='gray')
ax.axis('off')
plt.suptitle('VAE Generated Samples', fontsize=16)
plt.tight_layout()
plt.savefig('vae_generated.png')
plt.close()
print(f"Generated {n_samples} new samples!")
visualize_generated(vae, n_samples=64)
```
Latent Space Interpolation
```python
def interpolate_latent(model, start_img, end_img, n_steps=10):
"""
Interpolate between two images in latent space.
"""
model.eval()
with torch.no_grad():
# Encode both images
start_flat = start_img.view(1, -1).to(device)
end_flat = end_img.view(1, -1).to(device)
start_mu, start_logvar = model.encode(start_flat)
end_mu, end_logvar = model.encode(end_flat)
# Interpolate between means
interpolations = []
for alpha in np.linspace(0, 1, n_steps):
z = (1 - alpha) * start_mu + alpha * end_mu
decoded = model.decode(z)
interpolations.append(decoded.view(28, 28).cpu())
return interpolations
def visualize_interpolation(model, test_loader):
"""
Show interpolation between two random digits.
"""
# Get two different digits
data, labels = next(iter(test_loader))
# Find a 3 and a 7
idx_3 = (labels == 3).nonzero()[0].item()
idx_7 = (labels == 7).nonzero()[0].item()
img_3 = data[idx_3]
img_7 = data[idx_7]
# Interpolate
    interpolations = interpolate_latent(model, img_3, img_7, n_steps=10)
# Plot
fig, axes = plt.subplots(1, 10, figsize=(15, 1.5))
for i, ax in enumerate(axes):
ax.imshow(interpolations[i], cmap='gray')
ax.axis('off')
ax.set_title(f'{i/(len(axes)-1):.1f}')
plt.suptitle('Latent Space Interpolation: 3 → 7', fontsize=14)
plt.tight_layout()
plt.savefig('interpolation.png')
plt.close()
print("Interpolation saved!")
visualize_interpolation(vae, test_loader)
```
Convolutional VAE
For better image generation, use convolutional layers:

```python
class ConvVAE(nn.Module):
"""
Convolutional VAE for better image generation.
"""
def __init__(self, latent_dim=32):
super().__init__()
self.latent_dim = latent_dim
# Encoder
self.encoder_conv = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1), # 28->14
nn.BatchNorm2d(32),
nn.LeakyReLU(0.2),
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 14->7
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 7->4
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2),
)
self.flatten_size = 128 * 4 * 4
self.fc_mu = nn.Linear(self.flatten_size, latent_dim)
self.fc_logvar = nn.Linear(self.flatten_size, latent_dim)
# Decoder
self.decoder_fc = nn.Linear(latent_dim, self.flatten_size)
self.decoder_conv = nn.Sequential(
nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=0), # 4->7
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2),
nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1), # 7->14
nn.BatchNorm2d(32),
nn.LeakyReLU(0.2),
nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1), # 14->28
nn.Sigmoid()
)
def encode(self, x):
h = self.encoder_conv(x)
h = h.view(h.size(0), -1)
return self.fc_mu(h), self.fc_logvar(h)
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + std * eps
def decode(self, z):
h = self.decoder_fc(z)
h = h.view(-1, 128, 4, 4)
return self.decoder_conv(h)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
return self.decode(z), mu, log_var
# Create Conv VAE
conv_vae = ConvVAE(latent_dim=32).to(device)
print(f"Conv VAE parameters: {sum(p.numel() for p in conv_vae.parameters()):,}")
# Test
test_batch = torch.randn(4, 1, 28, 28).to(device)
output, mu, logvar = conv_vae(test_batch)
print(f"Input: {test_batch.shape}, Output: {output.shape}")
print(f"Latent: mu {mu.shape}, logvar {logvar.shape}")
Beta-VAE: Disentangled Representations
| β Value | Effect |
|---|---|
| β = 1 | Standard VAE |
| β > 1 | More disentanglement, lower reconstruction quality |
| β < 1 | Better reconstruction, more entangled latents |
```python
def train_beta_vae(model, train_loader, num_epochs=30, beta=4.0):
"""
Train Beta-VAE for disentangled representations.
"""
optimizer = optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(num_epochs):
model.train()
total_loss = 0
for data, _ in train_loader:
            # The conv VAE consumes images; the linear VAE needs flattened vectors
            if hasattr(model, 'encoder_conv'):
                data = data.to(device)
            else:
                data = data.view(data.size(0), -1).to(device)
optimizer.zero_grad()
if hasattr(model, 'encoder_conv'):
# Conv VAE
reconstruction, mu, log_var = model(data)
recon_loss = F.binary_cross_entropy(
reconstruction.view(-1), data.view(-1), reduction='sum'
)
else:
# Linear VAE
reconstruction, mu, log_var = model(data)
recon_loss = F.binary_cross_entropy(
reconstruction, data.view(-1, 784), reduction='sum'
)
kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
# Beta-weighted loss
loss = recon_loss + beta * kl_loss
loss.backward()
optimizer.step()
total_loss += loss.item()
if (epoch + 1) % 10 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}] | "
f"Loss: {total_loss/len(train_loader.dataset):.4f}")
return model
# Train Beta-VAE with different beta values
print("\nTraining Beta-VAE (β=4)...")
beta_vae = VAE(latent_dim=10).to(device)
train_beta_vae(beta_vae, train_loader, num_epochs=20, beta=4.0)
```
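To probe what β buys you, a common check is a latent traversal: hold a latent vector fixed, sweep one dimension, and decode. In a disentangled model, each dimension should change one visual factor at a time. A minimal sketch (the helper name and sweep range are illustrative choices, not from the original):

```python
def latent_traversal(model, z_base, dim, values):
    """Decode copies of z_base with one latent dimension swept over `values`."""
    model.eval()
    frames = []
    with torch.no_grad():
        for v in values:
            z = z_base.clone()
            z[0, dim] = v
            frames.append(model.decode(z).view(28, 28).cpu())
    return frames

# Sweep dimension 0 of the trained beta-VAE from -3 to 3
z = torch.zeros(1, beta_vae.latent_dim).to(device)
frames = latent_traversal(beta_vae, z, dim=0, values=torch.linspace(-3, 3, 8))
```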
Exercises
Exercise 1: Implement Sparse Autoencoder
Add an L1 sparsity constraint to encourage sparse representations.
```python
# TODO: Modify the autoencoder to add sparsity regularization
# Hint: Add L1 penalty on latent activations
class SparseAutoencoder(nn.Module):
def __init__(self, input_dim=784, latent_dim=64, sparsity_weight=1e-3):
# Your implementation here
pass
```
Exercise 2: Implement Conditional VAE
Create a VAE that can generate specific digits by conditioning on class labels.
```python
# TODO: Implement CVAE that takes class label as input
# The encoder and decoder should both receive the class information
class ConditionalVAE(nn.Module):
def __init__(self, input_dim=784, latent_dim=20, num_classes=10):
# Your implementation here
pass
```
Exercise 3: Implement VQ-VAE
Vector Quantized VAE uses discrete latent codes instead of continuous.
```python
# TODO: Implement the vector quantization layer
# Hint: Map continuous vectors to nearest codebook entries
class VectorQuantizer(nn.Module):
def __init__(self, num_embeddings=512, embedding_dim=64):
# Your implementation here
pass
```
Key Takeaways
What You Learned:
- ✅ Autoencoders - Encoder-decoder architecture with bottleneck for compression
- ✅ Latent Space - Lower-dimensional representation that captures essential features
- ✅ Denoising AE - Learn to remove noise by training with corrupted inputs
- ✅ VAE Theory - Probabilistic latent space with ELBO objective
- ✅ KL Divergence - Regularizes latent space to match prior distribution
- ✅ Reparameterization - Enables backpropagation through sampling
- ✅ Generation - Sample from latent space to create new data
- ✅ Beta-VAE - Control disentanglement with β hyperparameter
Common Pitfalls
Autoencoder Mistakes to Avoid:
- Latent dim too large - No compression = just memorization
- Latent dim too small - Poor reconstructions, lost information
- Ignoring KL collapse - VAE may ignore the latent space; use β-VAE or KL annealing (see the sketch below)
- Wrong reconstruction loss - Use BCE for [0,1] images, MSE otherwise
- Not normalizing inputs - Autoencoders work best with normalized data
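A minimal sketch of linear KL annealing, assuming the `vae_loss` helper defined earlier; the warm-up length and schedule shape are illustrative choices:

```python
def kl_weight(epoch, warmup_epochs=10, beta_max=1.0):
    """Ramp the KL weight linearly from 0 to beta_max over the first warmup_epochs."""
    return beta_max * min(1.0, epoch / warmup_epochs)

# Hypothetical use inside the training loop from train_vae:
# beta = kl_weight(epoch)
# loss, recon_loss, kl_loss = vae_loss(reconstruction, data, mu, log_var, beta=beta)
```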
Next: Diffusion Models
Learn about the cutting-edge generative models behind Stable Diffusion and DALL-E