Vision Transformers

From CNNs to Transformers

CNNs dominated computer vision for a decade, thanks to their built-in inductive biases:
  • Locality (convolutions)
  • Translation equivariance
  • Hierarchical features
Vision Transformers (ViT) showed that pure attention can match or exceed CNNs, especially at scale.

ViT Architecture

Core Idea

Split image into patches → treat each patch as a token → apply transformer.

Patch Embedding

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split image into patches and embed them."""
    
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        # Conv layer = linear projection of flattened patches
        self.proj = nn.Conv2d(in_channels, embed_dim, 
                              kernel_size=patch_size, stride=patch_size)
    
    def forward(self, x):
        # x: (B, C, H, W) -> (B, embed_dim, H/P, W/P)
        x = self.proj(x)
        # Flatten spatial dims, then put tokens first:
        # (B, embed_dim, H/P, W/P) -> (B, embed_dim, N) -> (B, N, embed_dim)
        x = x.flatten(2).transpose(1, 2)
        return x
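
A quick shape check for the embedding above, using its default configuration (a sketch on a dummy batch):

x = torch.randn(2, 3, 224, 224)          # dummy batch of two RGB images
embed = PatchEmbed()                     # 16x16 patches, 768-dim tokens
print(embed(x).shape)                    # torch.Size([2, 196, 768]); 196 = (224 // 16) ** 2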

Full ViT Implementation

class ViT(nn.Module):
    """Vision Transformer."""
    
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.1,
    ):
        super().__init__()
        
        # Patch embedding
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches
        
        # Learnable [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # Positional embedding (learnable)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        
        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        
        # Initialize
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
    
    def forward(self, x):
        B = x.shape[0]
        
        # Patch embedding
        x = self.patch_embed(x)  # (B, N, D)
        
        # Prepend [CLS] token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, N+1, D)
        
        # Add positional embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)
        
        # Transformer blocks
        for block in self.blocks:
            x = block(x)
        
        # Classification on [CLS] token
        x = self.norm(x)
        return self.head(x[:, 0])


class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(dropout),
        )
    
    def forward(self, x):
        # Pre-norm: LayerNorm before attention/MLP, residual connection around each
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]
        x = x + self.mlp(self.norm2(x))
        return x
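
A quick smoke test of the model above on a dummy batch (a sketch using the ViT-Base/16 defaults):

model = ViT()                            # default configuration from the class above
x = torch.randn(2, 3, 224, 224)          # dummy batch of two images
logits = model(x)
print(logits.shape)                      # torch.Size([2, 1000])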

ViT Variants

DeiT (Data-efficient Image Transformer)

Training-recipe and knowledge-distillation improvements that let ViTs train well without massive pretraining data. DeiT adds a distillation token that learns from a CNN teacher alongside the [CLS] token:
# DeiT adds a distillation token alongside [CLS]
class DeiT(ViT):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        embed_dim = self.head.in_features       # reuse dimensions from the parent ViT
        num_classes = self.head.out_features
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.head_dist = nn.Linear(embed_dim, num_classes)
        # Positional embedding must now cover [CLS] + [DIST] + all patches
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.patch_embed.num_patches + 2, embed_dim)
        )
        nn.init.trunc_normal_(self.dist_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
    
    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        
        # Prepend both special tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_tokens = self.dist_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, dist_tokens, x], dim=1)
        
        x = x + self.pos_embed
        for block in self.blocks:
            x = block(x)
        
        x = self.norm(x)
        # Classification head on [CLS], distillation head on [DIST]
        return self.head(x[:, 0]), self.head_dist(x[:, 1])
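
For reference, one way the two heads are trained is hard-label distillation, as in the DeiT paper. A minimal sketch; `deit_model`, `teacher_model`, `images`, and `labels` are assumed placeholders:

import torch.nn.functional as F

# [CLS] head learns from ground-truth labels, distillation head from the teacher's predictions
cls_logits, dist_logits = deit_model(images)
with torch.no_grad():
    teacher_labels = teacher_model(images).argmax(dim=-1)
loss = 0.5 * F.cross_entropy(cls_logits, labels) \
     + 0.5 * F.cross_entropy(dist_logits, teacher_labels)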

Swin Transformer

Hierarchical feature maps with self-attention restricted to shifted local windows, avoiding the quadratic cost of global attention:
class WindowAttention(nn.Module):
    """Multi-head self-attention within local (window_size x window_size) windows."""
    
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        
        # Relative position bias: one learnable scalar per head for each possible
        # relative offset between two tokens in a window
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) ** 2, num_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
        
        # Precompute, for every pair of tokens in a window, its index into the bias table
        coords = torch.stack(torch.meshgrid(
            torch.arange(window_size), torch.arange(window_size), indexing='ij'
        ))                                               # (2, Wh, Ww)
        coords = coords.flatten(1)                       # (2, N)
        rel = coords[:, :, None] - coords[:, None, :]    # (2, N, N)
        rel = rel.permute(1, 2, 0)                       # (N, N, 2)
        rel[:, :, 0] += window_size - 1                  # shift to start from 0
        rel[:, :, 1] += window_size - 1
        rel[:, :, 0] *= 2 * window_size - 1
        self.register_buffer('relative_position_index', rel.sum(-1))  # (N, N)
    
    def _get_relative_position_bias(self):
        N = self.window_size ** 2
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        return bias.view(N, N, self.num_heads).permute(2, 0, 1).unsqueeze(0)  # (1, heads, N, N)
    
    def forward(self, x, mask=None):
        B_, N, C = x.shape  # N = window_size^2, B_ = batch * num_windows
        
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Add relative position bias
        attn = attn + self._get_relative_position_bias()
        
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
        
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj(x)
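
Swin applies this attention independently inside each non-overlapping window. A minimal sketch of the partitioning step (`window_partition` is an assumed helper, not part of the code above):

def window_partition(x, window_size):
    """(B, H, W, C) -> (B * num_windows, window_size * window_size, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # Group the two window axes together, keep tokens within a window contiguous
    windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size, C)
    return windows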

Using Pretrained ViTs

import timm

# Load pretrained ViT
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)

# Common variants
models = {
    'vit_small': 'vit_small_patch16_224',
    'vit_base': 'vit_base_patch16_224',
    'vit_large': 'vit_large_patch16_224',
    'deit_small': 'deit_small_patch16_224',
    'swin_tiny': 'swin_tiny_patch4_window7_224',
    'swin_base': 'swin_base_patch4_window7_224',
}

# List all available ViT models
vit_models = timm.list_models('*vit*', pretrained=True)
print(f"Available ViT models: {len(vit_models)}")

ViT vs CNN Comparison

Aspect           | CNN                    | ViT
-----------------|------------------------|----------------------------------
Inductive bias   | Strong (locality)      | Weak (learns from data)
Data efficiency  | Better with small data | Needs large data or distillation
Compute          | Efficient              | O(N²) attention
Scalability      | Saturates              | Scales well
Interpretability | Filter visualization   | Attention maps
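
To make the O(N²) row concrete, a quick back-of-the-envelope count for the 224x224, patch-16 configuration used above:

num_tokens = (224 // 16) ** 2 + 1        # 196 patches + [CLS] = 197
attn_entries = num_tokens ** 2           # 38,809 attention scores per head, per layer
print(num_tokens, attn_entries)          # 197 38809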

Visualizing Attention

import torch.nn.functional as F

def visualize_attention(model, image, patch_size=16):
    """Visualize attention from the [CLS] token to the image patches."""
    model.eval()
    
    # Capture the attention weights of the last block with a forward hook.
    # nn.MultiheadAttention returns (output, attn_weights); with the default
    # settings the weights are already averaged over heads: (B, N+1, N+1).
    attentions = []
    def get_attention(module, input, output):
        attentions.append(output[1])
    
    handle = model.blocks[-1].attn.register_forward_hook(get_attention)
    
    with torch.no_grad():
        _ = model(image.unsqueeze(0))
    handle.remove()
    
    # Shape: (1, N+1, N+1)
    attn = attentions[0]
    
    # Attention from [CLS] (index 0) to the patch tokens
    cls_attn = attn[0, 0, 1:]  # (N,)
    
    # Reshape to the patch grid
    num_patches = int(cls_attn.shape[0] ** 0.5)
    attn_map = cls_attn.reshape(num_patches, num_patches)
    
    # Upsample to image resolution
    attn_map = F.interpolate(
        attn_map.unsqueeze(0).unsqueeze(0),
        size=(224, 224),
        mode='bilinear',
        align_corners=False,
    )[0, 0]
    
    return attn_map.cpu().numpy()
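
A usage sketch for the function above (`model` is the ViT defined earlier; `img` is an assumed preprocessed (3, 224, 224) tensor):

import matplotlib.pyplot as plt

attn_map = visualize_attention(model, img)
plt.imshow(attn_map, cmap='viridis')
plt.title('[CLS] attention, last layer')
plt.axis('off')
plt.show()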

Exercises

1. Implement a complete ViT and train it on CIFAR-10 with proper augmentation.
2. Visualize attention maps for different images. What does the model attend to?
3. Compare ViT and ResNet on the same dataset. Analyze accuracy vs. compute tradeoffs.

What’s Next

Module 24: Multimodal Models

CLIP, vision-language models, and connecting modalities.