Vision Transformers

From CNNs to Transformers

CNNs dominated vision for a decade with their inductive biases:
  • Locality (convolutions)
  • Translation equivariance
  • Hierarchical features
Vision Transformers (ViT) showed that pure attention can match or exceed CNNs, especially at scale.

ViT Architecture

Core Idea

Split the image into fixed-size patches → treat each patch as a token → apply a standard transformer encoder.
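
For concrete numbers, the short check below works out the sequence length for the standard ViT-Base setup (224×224 input, 16×16 patches):

# Worked example: how many tokens does ViT-Base see for a 224x224 image?
img_size, patch_size = 224, 16

num_patches = (img_size // patch_size) ** 2   # 14 * 14 = 196 patches
patch_dim = 3 * patch_size * patch_size       # 3 * 16 * 16 = 768 raw values per patch

print(num_patches)  # 196 tokens (plus one [CLS] token)
print(patch_dim)    # 768, linearly projected to the embedding dimension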

Patch Embedding

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split image into patches and embed them."""
    
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        # Conv layer = linear projection of flattened patches
        self.proj = nn.Conv2d(in_channels, embed_dim, 
                              kernel_size=patch_size, stride=patch_size)
    
    def forward(self, x):
        # x: (B, C, H, W) -> (B, embed_dim, H/P, W/P)
        x = self.proj(x)
        # Flatten spatial dims, then transpose: (B, embed_dim, H/P, W/P) -> (B, embed_dim, N) -> (B, N, embed_dim)
        x = x.flatten(2).transpose(1, 2)
        return x
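
A quick shape check for PatchEmbed on a random input (a minimal usage sketch):

# One 224x224 RGB image -> 196 patch tokens of dimension 768
patch_embed = PatchEmbed(img_size=224, patch_size=16, in_channels=3, embed_dim=768)
dummy = torch.randn(1, 3, 224, 224)
tokens = patch_embed(dummy)
print(tokens.shape)  # torch.Size([1, 196, 768])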

Full ViT Implementation

class ViT(nn.Module):
    """Vision Transformer."""
    
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.1,
    ):
        super().__init__()
        
        # Patch embedding
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches
        
        # Learnable [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # Positional embedding (learnable)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        
        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        
        # Initialize
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
    
    def forward(self, x):
        B = x.shape[0]
        
        # Patch embedding
        x = self.patch_embed(x)  # (B, N, D)
        
        # Prepend [CLS] token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, N+1, D)
        
        # Add positional embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)
        
        # Transformer blocks
        for block in self.blocks:
            x = block(x)
        
        # Classification on [CLS] token
        x = self.norm(x)
        return self.head(x[:, 0])


class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(dropout),
        )
    
    def forward(self, x):
        # Pre-norm: LayerNorm before attention and MLP, residual connections around both
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]
        x = x + self.mlp(self.norm2(x))
        return x
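
With both classes in place, a forward pass on random data confirms the expected shapes (a small configuration chosen here just for the check):

# Sanity check: a small ViT on random data
model = ViT(img_size=224, patch_size=16, num_classes=10, embed_dim=192, depth=4, num_heads=3)
images = torch.randn(2, 3, 224, 224)
logits = model(images)
print(logits.shape)  # torch.Size([2, 10])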

ViT Variants

DeiT (Data-efficient Image Transformer)

DeiT trains competitive ViTs on ImageNet-1k alone, without extra data, by combining strong augmentation with knowledge distillation from a CNN teacher through an extra distillation token:
# DeiT adds distillation token
class DeiT(ViT):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        embed_dim = self.head.in_features
        num_classes = self.head.out_features
        num_patches = self.patch_embed.num_patches
        
        # Distillation token, its own head, and one extra position slot
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.head_dist = nn.Linear(embed_dim, num_classes)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, embed_dim))
        nn.init.trunc_normal_(self.dist_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
    
    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        
        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_tokens = self.dist_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, dist_tokens, x], dim=1)
        
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for block in self.blocks:
            x = block(x)
        
        x = self.norm(x)
        # Class logits from [CLS], distillation logits from the distillation token
        return self.head(x[:, 0]), self.head_dist(x[:, 1])
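
The two heads are supervised separately during training: the class head against the ground-truth label, the distillation head against a teacher's prediction. Below is a sketch of DeiT-style hard distillation; the `teacher` model and input tensors in the usage comment are placeholders:

# Hard-label distillation loss (sketch): the distillation head learns to match
# the teacher's argmax prediction, the class head learns the true label.
def deit_hard_distillation_loss(cls_logits, dist_logits, labels, teacher_logits):
    ce = nn.CrossEntropyLoss()
    loss_cls = ce(cls_logits, labels)
    loss_dist = ce(dist_logits, teacher_logits.argmax(dim=-1))
    return 0.5 * loss_cls + 0.5 * loss_dist

# Usage (the teacher is typically a pretrained CNN):
# with torch.no_grad():
#     teacher_logits = teacher(images)
# cls_logits, dist_logits = deit_model(images)
# loss = deit_hard_distillation_loss(cls_logits, dist_logits, labels, teacher_logits)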

Swin Transformer

Swin computes self-attention within local windows and shifts the window grid between consecutive layers, giving a hierarchical feature pyramid with complexity linear in image size:
class WindowAttention(nn.Module):
    """Multi-head self-attention within local (window_size x window_size) windows."""
    
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        
        # Learnable relative position bias: one entry per (dy, dx) offset per head
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) ** 2, num_heads)
        )
        
        # Precompute the pairwise relative position index for tokens in a window
        coords = torch.stack(torch.meshgrid(
            torch.arange(window_size), torch.arange(window_size), indexing='ij'
        ))  # (2, Wh, Ww)
        coords_flat = coords.flatten(1)  # (2, N)
        relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :]  # (2, N, N)
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # (N, N, 2)
        relative_coords += window_size - 1  # shift offsets to start at 0
        relative_coords[:, :, 0] *= 2 * window_size - 1
        self.register_buffer("relative_position_index", relative_coords.sum(-1))  # (N, N)
    
    def _get_relative_position_bias(self):
        N = self.window_size ** 2
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        return bias.view(N, N, -1).permute(2, 0, 1).unsqueeze(0)  # (1, num_heads, N, N)
    
    def forward(self, x, mask=None):
        B_, N, C = x.shape  # N = window_size^2
        
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B_, num_heads, N, head_dim)
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Add relative position bias
        attn = attn + self._get_relative_position_bias()
        
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
        
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj(x)
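
WindowAttention expects tokens that have already been grouped into non-overlapping windows. A sketch of the partition/reverse reshaping helpers (the names are illustrative), assuming the feature map height and width are divisible by the window size:

def window_partition(x, window_size):
    """(B, H, W, C) -> (num_windows*B, window_size*window_size, C)."""
    B, H, W, C = x.shape
    x = x.reshape(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size, C)

def window_reverse(windows, window_size, H, W):
    """(num_windows*B, window_size*window_size, C) -> (B, H, W, C)."""
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)

# Each window of window_size^2 tokens is then attended to independently:
# windows = window_partition(features, window_size=7)   # (B * num_windows, 49, C)
# windows = window_attn(windows)                         # WindowAttention instance
# features = window_reverse(windows, 7, H, W)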

Using Pretrained ViTs

import timm

# Load pretrained ViT
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)

# Common variants
models = {
    'vit_small': 'vit_small_patch16_224',
    'vit_base': 'vit_base_patch16_224',
    'vit_large': 'vit_large_patch16_224',
    'deit_small': 'deit_small_patch16_224',
    'swin_tiny': 'swin_tiny_patch4_window7_224',
    'swin_base': 'swin_base_patch4_window7_224',
}

# List all available ViT models
vit_models = timm.list_models('*vit*', pretrained=True)
print(f"Available ViT models: {len(vit_models)}")

ViT vs CNN Comparison

| Aspect           | CNN                    | ViT                              |
|------------------|------------------------|----------------------------------|
| Inductive bias   | Strong (locality)      | Weak (learns from data)          |
| Data efficiency  | Better with small data | Needs large data or distillation |
| Compute          | Efficient              | O(N²) attention                  |
| Scalability      | Saturates              | Scales well                      |
| Interpretability | Filter visualization   | Attention maps                   |
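
To make the O(N²) row concrete, the sketch below counts attention-matrix entries per head at two input resolutions with 16×16 patches:

# Attention cost grows quadratically with the number of tokens.
def num_tokens(img_size, patch_size=16):
    return (img_size // patch_size) ** 2 + 1  # +1 for the [CLS] token

for size in (224, 384):
    n = num_tokens(size)
    print(f"{size}px: {n} tokens -> {n * n:,} attention entries per head")
# 224px: 197 tokens -> 38,809 attention entries per head
# 384px: 577 tokens -> 332,929 attention entries per head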

Visualizing Attention

import torch.nn.functional as F

def visualize_attention(model, image, patch_size=16):
    """Visualize attention from the [CLS] token to the image patches."""
    model.eval()
    
    # Capture attention weights from the last block with a forward hook.
    # nn.MultiheadAttention returns (output, attn_weights); by default the
    # weights are averaged over heads, giving shape (B, N+1, N+1).
    attentions = []
    def get_attention(module, input, output):
        attentions.append(output[1])
    
    handle = model.blocks[-1].attn.register_forward_hook(get_attention)
    with torch.no_grad():
        _ = model(image.unsqueeze(0))
    handle.remove()
    
    attn = attentions[0]  # (1, N+1, N+1), head-averaged
    
    # Attention from [CLS] (token 0) to the patch tokens
    cls_attn = attn[0, 0, 1:]  # (N,)
    
    # Reshape to the patch grid
    num_patches = int(cls_attn.shape[0] ** 0.5)
    attn_map = cls_attn.reshape(num_patches, num_patches)
    
    # Upsample to the input resolution
    attn_map = F.interpolate(
        attn_map.unsqueeze(0).unsqueeze(0),
        size=(224, 224),
        mode='bilinear',
        align_corners=False,
    )[0, 0]
    
    return attn_map.numpy()
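
One possible way to call it and overlay the result, assuming matplotlib and a preprocessed `image` tensor of shape (3, 224, 224):

import matplotlib.pyplot as plt

# `image`: a preprocessed (3, 224, 224) tensor; `model`: the ViT defined above
attn_map = visualize_attention(model, image)

plt.imshow(image.permute(1, 2, 0).numpy())   # base image (denormalize first for true colors)
plt.imshow(attn_map, cmap='jet', alpha=0.5)  # attention heatmap overlay
plt.axis('off')
plt.show()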

Exercises

1. Implement a complete ViT and train it on CIFAR-10 with proper augmentation.
2. Visualize attention maps for different images. What does the model attend to?
3. Compare ViT and ResNet on the same dataset. Analyze accuracy vs compute tradeoffs.

What’s Next