Diffusion Models

The Core Idea

Diffusion models work by:
  1. Forward process: Gradually add noise to data until it becomes pure noise
  2. Reverse process: Learn to denoise step by step, recovering the original data
Think of it like this:
  • Forward: Dropping ink into water (ink diffuses until water is uniformly colored)
  • Reverse: Learning to “un-diffuse” the ink back to its original drop
[Figure: Diffusion Process]

Mathematical Foundation

Forward Diffusion (Adding Noise)

At each step $t$, we add Gaussian noise:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\ \sqrt{1-\beta_t}\, x_{t-1},\ \beta_t I\big)$$

where $\beta_t$ is the noise schedule (the variance of the noise added at step $t$). We can also jump directly to any step:

$$q(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar{\alpha}_t}\, x_0,\ (1-\bar{\alpha}_t) I\big)$$

where $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$.
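As a quick sanity check on the jump-to-any-step formula, the sketch below (pure Python; the helper names are illustrative, not from any library) tracks the mean and variance of $x_t$ for a fixed scalar $x_0$ in two ways: by applying $q(x_t \mid x_{t-1})$ step by step, and by the closed form with $\bar{\alpha}_t$. It assumes the same linear schedule endpoints as the `DiffusionSchedule` defaults.

```python
import math

# Linear schedule, matching torch.linspace(1e-4, 0.02, 1000)
T = 1000
beta_start, beta_end = 1e-4, 0.02
betas = [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]

def iterate_forward(x0, betas):
    """Mean and variance of x_t | x_0 from applying q(x_t | x_{t-1}) once per beta."""
    mean, var = x0, 0.0
    for b in betas:
        mean = math.sqrt(1 - b) * mean   # the previous sample is scaled down
        var = (1 - b) * var + b          # scaled old variance plus fresh noise variance
    return mean, var

def closed_form(x0, betas):
    """Mean and variance of q(x_t | x_0) using alpha_bar_t = prod(1 - beta_s)."""
    alpha_bar = math.prod(1 - b for b in betas)
    return math.sqrt(alpha_bar) * x0, 1 - alpha_bar

for t in (1, 10, 500, T):
    print(t, iterate_forward(2.0, betas[:t]), closed_form(2.0, betas[:t]))
```

The variance recursion $v_t = (1-\beta_t)v_{t-1} + \beta_t$ with $v_0 = 0$ unrolls to exactly $1-\bar{\alpha}_t$, which is why both columns agree up to floating-point error.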
```python
import torch
import torch.nn as nn

class DiffusionSchedule:
    """Precomputed noise schedule for the forward diffusion process."""

    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02):
        self.timesteps = timesteps

        # Linear schedule: beta_t grows from beta_start to beta_end
        self.betas = torch.linspace(beta_start, beta_end, timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)  # alpha_bar_t
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

    def add_noise(self, x_0, t, noise=None):
        """Sample x_t ~ q(x_t | x_0) in one step for a batch of timesteps t (shape [B])."""
        if noise is None:
            noise = torch.randn_like(x_0)

        # Look up per-sample coefficients (tables live on the CPU), move them to
        # x_0's device, and reshape to [B, 1, 1, ...] so they broadcast over x_0
        shape = (-1,) + (1,) * (x_0.dim() - 1)
        sqrt_alpha = self.sqrt_alphas_cumprod[t.cpu()].to(x_0.device).view(shape)
        sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t.cpu()].to(x_0.device).view(shape)

        return sqrt_alpha * x_0 + sqrt_one_minus_alpha * noise
```
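A minimal check of what the linear schedule does at its endpoints, assuming the default `timesteps=1000`, `beta_start=1e-4`, `beta_end=0.02` (recomputed standalone here rather than imported): the signal coefficient $\sqrt{\bar{\alpha}_t}$ starts near 1 and ends below 0.01, so the final noisy sample is almost pure Gaussian noise regardless of $x_0$.

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# Coefficients on x_0 (signal) and on the noise in x_t
signal = alphas_cumprod.sqrt()
noise_scale = (1.0 - alphas_cumprod).sqrt()

for t in (0, 250, 500, 999):
    print(f"t={t:4d}  signal={signal[t].item():.4f}  noise={noise_scale[t].item():.4f}")

# Noise a constant "image" of ones at the last step: the result is close to N(0, 1)
torch.manual_seed(0)
x_0 = torch.ones(10_000)
x_T = signal[-1] * x_0 + noise_scale[-1] * torch.randn_like(x_0)
print(f"mean={x_T.mean().item():.3f}  std={x_T.std().item():.3f}")
```

This is why sampling can start from `torch.randn`: at the last step the forward process has essentially erased the data.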

Reverse Process (Learning to Denoise)

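We train a neural network $\epsilon_\theta$ to predict the noise $\epsilon \sim \mathcal{N}(0, I)$ that was added to produce $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon$ at step $t$:

$$\mathcal{L} = \mathbb{E}_{x_0, t, \epsilon}\left[ \| \epsilon - \epsilon_\theta(x_t, t) \|^2 \right]$$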
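Predicting the noise is enough to estimate the clean data, because the closed-form forward process can be inverted: $\hat{x}_0 = \big(x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon\big)/\sqrt{\bar{\alpha}_t}$. A minimal sketch, where the true noise stands in for a perfect network and the value $\bar{\alpha}_t = 0.3$ is an arbitrary illustrative choice:

```python
import torch

def predict_x0(x_t, eps, alpha_bar_t):
    """Invert x_t = sqrt(a) * x_0 + sqrt(1 - a) * eps for x_0, given a noise estimate eps."""
    return (x_t - torch.sqrt(1 - alpha_bar_t) * eps) / torch.sqrt(alpha_bar_t)

torch.manual_seed(0)
alpha_bar_t = torch.tensor(0.3)
x_0 = torch.randn(4, 1, 28, 28)
eps = torch.randn_like(x_0)
x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * eps

# With the true noise the reconstruction is exact up to float error;
# a trained epsilon_theta only approximates eps, so x_0_hat is an estimate.
x_0_hat = predict_x0(x_t, eps, alpha_bar_t)
print(torch.allclose(x_0_hat, x_0, atol=1e-5))
```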
```python
class SimpleDiffusion(nn.Module):
    """Small U-Net-style noise predictor epsilon_theta(x_t, t).

    Expects height and width divisible by 4 (e.g. 28x28 MNIST) so the
    decoder outputs line up with the encoder skip connections.
    """

    def __init__(self, channels=1, time_emb_dim=32):
        super().__init__()

        # Time embedding: small MLP on the scaled scalar timestep
        self.time_mlp = nn.Sequential(
            nn.Linear(1, time_emb_dim),
            nn.GELU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        # Encoder: full resolution -> 1/2 -> 1/4
        self.enc1 = nn.Conv2d(channels, 64, 3, padding=1)
        self.enc2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.enc3 = nn.Conv2d(128, 256, 3, stride=2, padding=1)

        # Decoder: input channels include the time embedding or skip features
        self.dec3 = nn.ConvTranspose2d(256 + time_emb_dim, 128, 4, stride=2, padding=1)
        self.dec2 = nn.ConvTranspose2d(128 + 128, 64, 4, stride=2, padding=1)
        self.dec1 = nn.Conv2d(64 + 64, channels, 3, padding=1)

    def forward(self, x, t):
        # Time embedding; dividing by 1000 assumes a 1000-step schedule
        t_emb = self.time_mlp(t.float().unsqueeze(-1) / 1000)

        # Encode
        e1 = torch.relu(self.enc1(x))
        e2 = torch.relu(self.enc2(e1))
        e3 = torch.relu(self.enc3(e2))

        # Broadcast the time embedding over the bottleneck and concatenate it
        t_emb = t_emb.view(t_emb.size(0), -1, 1, 1).expand(-1, -1, e3.size(2), e3.size(3))
        e3 = torch.cat([e3, t_emb], dim=1)

        # Decode with skip connections
        d3 = torch.relu(self.dec3(e3))
        d2 = torch.relu(self.dec2(torch.cat([d3, e2], dim=1)))
        d1 = self.dec1(torch.cat([d2, e1], dim=1))

        return d1  # predicted noise, same shape as x
```

Training Loop

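```python
import torch.nn.functional as F

def train_diffusion(model, dataloader, schedule, epochs=10, lr=1e-4, device="cpu"):
    """Train `model` to predict the noise added by `schedule.add_noise`."""
    model.to(device).train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        for batch in dataloader:
            # Assumes (images, labels) batches with images scaled to [-1, 1]
            x_0 = batch[0].to(device)
            batch_size = x_0.size(0)

            # One random timestep per sample
            t = torch.randint(0, schedule.timesteps, (batch_size,), device=device)

            # Add noise: x_t ~ q(x_t | x_0)
            noise = torch.randn_like(x_0)
            x_t = schedule.add_noise(x_0, t, noise)

            # Predict the noise and regress it onto the true noise
            predicted_noise = model(x_t, t)
            loss = F.mse_loss(predicted_noise, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"epoch {epoch + 1}/{epochs}  last-batch loss {loss.item():.4f}")
```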
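The loop above needs a real dataset and the convolutional model. For a quick end-to-end run on CPU, here is a self-contained sketch of the same recipe on toy 2-D data; the circle dataset, the tiny MLP denoiser (which receives $t/T$ as an extra input feature), the batch size and the learning rate are all illustrative assumptions, not part of the original module.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# Toy data: points on the unit circle, each a 2-D "image"
angles = torch.rand(2048) * 2 * math.pi
data = torch.stack([angles.cos(), angles.sin()], dim=1)

# Tiny MLP denoiser: input is (x_t, t / T), output is the predicted noise
model = nn.Sequential(
    nn.Linear(3, 128), nn.SiLU(),
    nn.Linear(128, 128), nn.SiLU(),
    nn.Linear(128, 2),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

losses = []
for step in range(500):
    x_0 = data[torch.randint(0, len(data), (256,))]
    t = torch.randint(0, T, (x_0.size(0),))
    noise = torch.randn_like(x_0)
    a = alphas_cumprod[t].unsqueeze(-1)              # shape (B, 1), broadcasts over coords
    x_t = a.sqrt() * x_0 + (1 - a).sqrt() * noise    # closed-form forward process
    pred = model(torch.cat([x_t, (t.float() / T).unsqueeze(-1)], dim=1))
    loss = F.mse_loss(pred, noise)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first 50 steps: {sum(losses[:50]) / 50:.3f}  last 50 steps: {sum(losses[-50:]) / 50:.3f}")
```

For image data such as MNIST, pixels are usually rescaled from [0, 1] to [-1, 1] before training so the data has roughly the same scale as the unit-variance noise.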

Sampling (Generation)

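```python
@torch.no_grad()
def sample(model, schedule, shape, device="cpu"):
    """Generate samples by running the learned reverse process from t = T-1 down to 0."""
    model.eval()
    # Start from pure noise x_T ~ N(0, I)
    x = torch.randn(shape, device=device)

    for t in reversed(range(schedule.timesteps)):
        t_batch = torch.full((shape[0],), t, dtype=torch.long, device=device)

        # Predict the noise contained in x_t
        predicted_noise = model(x, t_batch)

        # Scalar schedule values for this step
        alpha = schedule.alphas[t]
        alpha_cumprod = schedule.alphas_cumprod[t]
        beta = schedule.betas[t]

        # Fresh noise at every step except the last, which returns the mean
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)

        # x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t) + sqrt(beta_t) * z
        x = (1 / torch.sqrt(alpha)) * (
            x - (beta / torch.sqrt(1 - alpha_cumprod)) * predicted_noise
        ) + torch.sqrt(beta) * noise

    return x
```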
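The denoising step uses the noise prediction in place of the unknown $x_0$. A numeric check (float64, illustrative timestep $t = 500$) that the mean in `sample()` equals the posterior mean of $q(x_{t-1} \mid x_t, x_0)$, written as a weighted average of $x_0$ and $x_t$:

```python
import torch

torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

t = 500
a_t, ab_t, ab_prev, b_t = alphas[t], alphas_cumprod[t], alphas_cumprod[t - 1], betas[t]

x_0 = torch.randn(8, dtype=torch.float64)
eps = torch.randn(8, dtype=torch.float64)
x_t = ab_t.sqrt() * x_0 + (1 - ab_t).sqrt() * eps

# Mean used in sample(), expressed through the noise eps
mean_eps = (x_t - b_t / (1 - ab_t).sqrt() * eps) / a_t.sqrt()

# Posterior mean of q(x_{t-1} | x_t, x_0), expressed through x_0 and x_t
mean_post = (ab_prev.sqrt() * b_t / (1 - ab_t)) * x_0 + (a_t.sqrt() * (1 - ab_prev) / (1 - ab_t)) * x_t

print(torch.allclose(mean_eps, mean_post))
```

The variance term in `sample()` uses $\sigma_t^2 = \beta_t$; another common choice is the posterior variance $\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t$.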

Classifier-Free Guidance

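Classifier-free guidance enables controlling generation with text prompts or class labels $c$. The network is trained with $c$ randomly replaced by a null condition $\varnothing$, so the same model can predict the noise both with and without conditioning. At sampling time the two predictions are combined:

$$\tilde{\epsilon}_\theta(x_t, c) = \epsilon_\theta(x_t, \varnothing) + s \cdot \big(\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing)\big)$$

where $s$ is the guidance scale. With $s = 1$ this is the plain conditional prediction; $s > 1$ pushes samples harder toward the condition (typically $s = 7.5$ for Stable Diffusion).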
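A sketch of applying the guidance formula at each sampling step, assuming a conditional denoiser with the hypothetical signature `model(x, t, c)` and a null label `null_cond` that the model saw during training. Both branches are evaluated in one batched forward pass; `toy_model` is a stand-in so the code runs, not a real denoiser.

```python
import torch

def guided_noise(model, x_t, t, cond, null_cond, guidance_scale):
    """Classifier-free guidance: combine unconditional and conditional noise predictions."""
    # Stack the unconditional and conditional inputs into one batch
    x_in = torch.cat([x_t, x_t], dim=0)
    t_in = torch.cat([t, t], dim=0)
    c_in = torch.cat([null_cond, cond], dim=0)
    eps_uncond, eps_cond = model(x_in, t_in, c_in).chunk(2, dim=0)
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

def toy_model(x, t, c):
    """Hypothetical stand-in: output depends on x and on the label c."""
    return 0.5 * x + c.float().view(-1, 1)

x = torch.randn(3, 2)
t = torch.full((3,), 10, dtype=torch.long)
cond = torch.tensor([1, 2, 3])
null = torch.zeros(3, dtype=torch.long)   # assumed "empty" label

guided = guided_noise(toy_model, x, t, cond, null, guidance_scale=7.5)
print(guided.shape)
```

In a sampler, the guided prediction simply replaces `predicted_noise` in the reverse step. With this toy model, `guidance_scale=1.0` returns the conditional prediction and `0.0` returns the unconditional one.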

Connection to Stable Diffusion

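Stable Diffusion runs the diffusion process in a compressed latent space for efficiency:
  1. VAE Encoder: Compresses a 512×512 RGB image into a 64×64 latent with 4 channels (8× downsampling per side)
  2. U-Net: Denoises in latent space, which is much cheaper because the latent holds far fewer values than the image
  3. VAE Decoder: Maps the denoised latent back to a 512×512 image
  4. CLIP Text Encoder: Embeds the text prompt; the U-Net is conditioned on these embeddings through cross-attention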
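The efficiency gain is easy to quantify. Assuming a Stable Diffusion v1-style VAE (8× downsampling per side, 4 latent channels), each U-Net step processes 48× fewer values than a pixel-space model working on the same image:

```python
pixel_values = 512 * 512 * 3    # RGB image
latent_values = 64 * 64 * 4     # latent: 512 / 8 = 64 per side, 4 channels
print(pixel_values, latent_values, pixel_values / latent_values)  # 786432 16384 48.0
```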

Exercises

Train a diffusion model on MNIST. Generate digit samples and visualize the denoising process.
Implement and compare linear, cosine, and quadratic noise schedules.
Add class conditioning to generate specific digits.

What’s Next

Module 15: Residual & Skip Connections

Learn how to train very deep networks with identity mappings.