Vision Transformers

From CNNs to Transformers

CNNs dominated vision for a decade with their inductive biases:
  • Locality (convolutions): each filter only looks at a small spatial neighborhood, just as a doctor examines a tissue sample under a microscope rather than trying to see the whole body at once
  • Translation equivariance: a feature detector that recognizes a cat ear in the top-left corner will recognize it in the bottom-right corner — no need to re-learn the same pattern at every position
  • Hierarchical features: early layers detect edges, middle layers combine edges into textures and parts, and deep layers recognize objects
Vision Transformers (ViT) showed that pure attention can match or exceed CNNs, especially at scale. The key insight: CNNs assume locality is important and hardcode it into the architecture. ViTs make no such assumption — they let the model learn which spatial relationships matter from data. This is more flexible, but requires much more data to compensate for the lack of built-in inductive bias.
The critical trade-off: CNNs encode strong priors about images (locality, translation equivariance) that help with small datasets but limit the model’s ceiling. ViTs encode almost no image-specific priors, so they need 10-100x more data to learn what CNNs get for free, but they have a higher ceiling because they can discover spatial patterns that convolutions cannot express. This is why ViTs dominate at ImageNet-21k scale (14M images) but struggle at CIFAR-10 scale (50K images) without heavy augmentation or distillation.

ViT Architecture

Core Idea

Split the image into patches, treat each patch as a token, and apply a standard Transformer.

The analogy: Imagine reading a book by cutting each page into a grid of sticky notes, then feeding all the sticky notes into a reading comprehension model. Each sticky note (patch) is a “word” in the model’s vocabulary. The self-attention mechanism lets every patch attend to every other patch, so a patch showing a wheel can directly attend to a patch showing a road, even if they’re on opposite sides of the image. A CNN would need many stacked layers to propagate information across that distance.

The math: For a 224x224 image with 16x16 patches, you get $(224/16)^2 = 196$ tokens, each of dimension $16 \times 16 \times 3 = 768$ (flattened patch pixels). The self-attention cost is $O(N^2 \cdot d)$ where $N = 196$, which is manageable. But try 4x4 patches on 224x224 and you get 3,136 tokens, and attention becomes the bottleneck.
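
The arithmetic above can be checked with a few lines of plain Python. This is a minimal sketch; the helper name `patch_stats` is an illustrative assumption and not part of any library.

```python
def patch_stats(img_size: int, patch_size: int, channels: int = 3):
    """Return (num_tokens, flattened_patch_dim, attention_scores_per_head)."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    num_tokens = (img_size // patch_size) ** 2
    patch_dim = patch_size * patch_size * channels
    attn_scores = num_tokens ** 2  # one score per (query, key) pair, per head
    return num_tokens, patch_dim, attn_scores


for p in (32, 16, 8, 4):
    n, d, a = patch_stats(224, p)
    print(f"patch {p:>2}x{p:<2} -> {n:>5} tokens, patch dim {d:>4}, {a:>10,} attention scores")
```

Halving the patch side quadruples the token count and multiplies the number of attention scores by 16, which is why very small patches quickly become expensive.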

Patch Embedding

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split image into patches and embed them."""
    
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2  # 224/16 = 14, 14*14 = 196 patches
        
        # Clever trick: a Conv2d with kernel_size=stride=patch_size is mathematically
        # identical to extracting non-overlapping patches and projecting each through
        # a linear layer. But it's much faster because it leverages optimized CUDA kernels.
        self.proj = nn.Conv2d(in_channels, embed_dim, 
                              kernel_size=patch_size, stride=patch_size)
    
    def forward(self, x):
        # x: (B, C, H, W) -> (B, embed_dim, H/P, W/P) = (B, 768, 14, 14)
        x = self.proj(x)
        # Flatten spatial grid into a sequence: (B, 768, 196) -> (B, 196, 768)
        # Now each of the 196 patches is a 768-dim token, just like word embeddings in NLP
        x = x.flatten(2).transpose(1, 2)
        return x
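
The claim in the comment above, that a strided convolution is the same operation as "cut into patches, then apply a linear layer", can be checked numerically. The sketch below uses small, arbitrary sizes (an assumption chosen only to keep it fast) and reuses the convolution's own weights for the linear projection.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W, P, D = 2, 3, 32, 32, 8, 16  # small illustrative sizes

conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
x = torch.randn(B, C, H, W)

# Path 1: strided convolution, as in PatchEmbed
tokens_conv = conv(x).flatten(2).transpose(1, 2)       # (B, N, D)

# Path 2: explicit patch extraction followed by a linear projection
patches = F.unfold(x, kernel_size=P, stride=P)          # (B, C*P*P, N)
patches = patches.transpose(1, 2)                       # (B, N, C*P*P)
weight = conv.weight.reshape(D, C * P * P)              # same parameters, viewed as a matrix
tokens_linear = patches @ weight.t() + conv.bias        # (B, N, D)

max_diff = (tokens_conv - tokens_linear).abs().max().item()
print(tokens_conv.shape, max_diff)
```

The two paths should agree up to floating-point rounding, so the difference printed at the end is expected to be tiny rather than exactly zero.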

Full ViT Implementation

class ViT(nn.Module):
    """Vision Transformer."""
    
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.1,
    ):
        super().__init__()
        
        # Patch embedding
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches
        
        # Learnable [CLS] token -- a "summary" token that aggregates information
        # from all patches via attention. Classification is done on this token's
        # final representation, similar to BERT's [CLS] token in NLP.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # Positional embedding (learnable): without this, the model has no idea
        # about spatial arrangement -- it would treat a shuffled image identically
        # to the original. The +1 accounts for the [CLS] token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        
        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        
        # Initialize with small values -- truncated normal prevents outliers
        # that could destabilize early training. 0.02 std is the ViT default.
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
    
    def forward(self, x):
        B = x.shape[0]
        
        # Patch embedding
        x = self.patch_embed(x)  # (B, N, D)
        
        # Prepend [CLS] token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, N+1, D)
        
        # Add positional embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)
        
        # Transformer blocks
        for block in self.blocks:
            x = block(x)
        
        # Classification on [CLS] token
        x = self.norm(x)
        return self.head(x[:, 0])


class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(dropout),
        )
    
    def forward(self, x):
        # Pre-norm architecture: normalize BEFORE attention/MLP, not after.
        # This provides more stable gradients than post-norm (used in original Transformer).
        # The residual connection (x + ...) ensures gradient flow even through deep networks.
        normed = self.norm1(x)
        x = x + self.attn(normed, normed, normed)[0]  # Self-attention: Q=K=V
        x = x + self.mlp(self.norm2(x))                # Feedforward: expand then compress
        return x

ViT Variants

DeiT (Data-efficient Image Transformer)

ViT’s original paper required 300M images (JFT-300M) for strong results. DeiT, from Facebook AI, showed you can train ViTs competitively on ImageNet alone (1.2M images) using better training recipes and knowledge distillation. The key innovation is a distillation token that learns to mimic a CNN teacher’s predictions:
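# DeiT adds a distillation token alongside the [CLS] token.
# The [CLS] token is trained with the ground-truth label (hard loss),
# while the distillation token is trained to match the CNN teacher (soft loss).
# At inference, both predictions are averaged for the final output.
class DeiT(ViT):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The ViT class above does not store embed_dim / num_classes,
        # so recover them from its existing modules.
        embed_dim = self.cls_token.shape[-1]
        num_classes = self.head.out_features
        num_patches = self.patch_embed.num_patches
        
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # The sequence now carries two extra tokens ([CLS] + distillation), so the
        # positional table needs num_patches + 2 rows instead of num_patches + 1.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, embed_dim))
        self.head_dist = nn.Linear(embed_dim, num_classes)  # Separate head for distillation
        nn.init.trunc_normal_(self.dist_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
    
    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        
        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_tokens = self.dist_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, dist_tokens, x], dim=1)  # (B, N+2, D)
        
        x = self.pos_drop(x + self.pos_embed)
        for block in self.blocks:
            x = block(x)
        
        x = self.norm(x)
        # Returns (class logits, distillation logits)
        return self.head(x[:, 0]), self.head_dist(x[:, 1])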
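
The two outputs are typically combined into a single training loss. The sketch below is one plausible way to do that, not the DeiT authors' exact recipe: the function name `distillation_loss`, the weight `alpha`, and the temperature `tau` are illustrative assumptions. It supports a soft variant (KL divergence to the teacher's distribution) and a hard variant (the teacher's argmax used as a second label).

```python
import torch
import torch.nn.functional as F


def distillation_loss(cls_logits, dist_logits, teacher_logits, targets,
                      alpha=0.5, tau=1.0, hard=False):
    """Supervised loss on the [CLS] head plus a teacher loss on the distillation head."""
    base_loss = F.cross_entropy(cls_logits, targets)
    if hard:
        # Hard variant: the teacher's predicted class acts as a second label
        teacher_loss = F.cross_entropy(dist_logits, teacher_logits.argmax(dim=1))
    else:
        # Soft variant: KL divergence between temperature-softened distributions
        teacher_loss = F.kl_div(
            F.log_softmax(dist_logits / tau, dim=1),
            F.log_softmax(teacher_logits / tau, dim=1),
            reduction="batchmean",
            log_target=True,
        ) * (tau ** 2)
    return (1 - alpha) * base_loss + alpha * teacher_loss


torch.manual_seed(0)
cls_logits = torch.randn(4, 10)
dist_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)   # stand-in for a frozen CNN teacher's output
targets = torch.randint(0, 10, (4,))

loss_soft = distillation_loss(cls_logits, dist_logits, teacher_logits, targets)
loss_hard = distillation_loss(cls_logits, dist_logits, teacher_logits, targets, hard=True)
print(loss_soft.item(), loss_hard.item())
```

In a real training loop the teacher would run under `torch.no_grad()` so that only the student receives gradients.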

Swin Transformer

The problem with standard ViT: self-attention over all 196 patches computes $196^2 \approx 38\text{K}$ pairwise scores per head. Scale to 1024x1024 images with 16x16 patches and you get 4,096 tokens, so attention needs $4096^2 \approx 17\text{M}$ scores per head per layer. That’s impractical for dense prediction tasks (segmentation, detection).
Swin’s solution: Restrict attention to local windows (e.g., 7x7 patches), reducing complexity from $O(N^2)$ to $O(N \cdot W^2)$, where $W = 7$ is the window side length and $W^2 = 49$ is the number of tokens per window. To allow cross-window communication, alternate layers shift the window grid by half a window, so patterns that span a window boundary in one layer fall inside a window in the next. This is hierarchical (like a CNN) but uses attention (like a ViT):
class WindowAttention(nn.Module):
    """Attention within local windows -- the core building block of Swin."""
    
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5  # 1/sqrt(d_k) scaling for stable attention
        
        self.qkv = nn.Linear(dim, dim * 3)  # Project to Q, K, V in one shot for efficiency
        self.proj = nn.Linear(dim, dim)
        
        # Relative position bias: unlike ViT's absolute positional embeddings,
        # Swin uses relative position bias -- the attention score between patches
        # depends on their relative position, not absolute location. This makes
        # the model naturally generalizable to different image sizes.
        # Table size: (2*W-1)^2 covers all possible relative positions in a WxW window
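        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) ** 2, num_heads)
        )
        # For every (query, key) pair inside a window, precompute which table row to use
        coords = torch.stack(torch.meshgrid(
            torch.arange(window_size), torch.arange(window_size), indexing="ij"
        )).flatten(1)                                   # (2, W*W)
        rel = coords[:, :, None] - coords[:, None, :]   # (2, W*W, W*W), each in [-(W-1), W-1]
        rel = rel + (window_size - 1)                   # shift offsets to [0, 2W-2]
        index = rel[0] * (2 * window_size - 1) + rel[1]
        self.register_buffer("relative_position_index", index)  # (W*W, W*W)
    
    def _get_relative_position_bias(self):
        # Look up the learned bias per pair: (W*W, W*W, heads) -> (1, heads, W*W, W*W)
        bias = self.relative_position_bias_table[self.relative_position_index]
        return bias.permute(2, 0, 1).unsqueeze(0)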
    
    def forward(self, x, mask=None):
        B_, N, C = x.shape  # N = window_size^2
        
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Add relative position bias
        attn = attn + self._get_relative_position_bias()
        
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
        
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj(x)
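
Window attention expects its input already grouped into windows. A minimal sketch of that grouping, and of the half-window shift used in alternating layers, is shown below. The helper names `window_partition` and `window_reverse` and the 56x56 feature-map size are illustrative assumptions; the attention mask that a full implementation applies to the wrapped-around regions after the shift is omitted here.

```python
import torch


def window_partition(x, window_size):
    """(B, H, W, C) -> (B * num_windows, window_size * window_size, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(-1, window_size * window_size, C)


def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: back to (B, H, W, C)."""
    C = windows.shape[-1]
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(B, H, W, C)


x = torch.arange(2 * 56 * 56 * 4, dtype=torch.float32).view(2, 56, 56, 4)
windows = window_partition(x, window_size=7)     # 8 x 8 = 64 windows per image
restored = window_reverse(windows, 7, 56, 56)

# Shifted layer: roll the grid by half a window before partitioning
shift = 7 // 2
shifted = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))
shifted_windows = window_partition(shifted, window_size=7)

print(windows.shape, shifted_windows.shape, torch.equal(restored, x))
```

Each row of `windows` can be passed to a window attention module as an independent sequence of 49 tokens, which is where the $O(N \cdot W^2)$ cost comes from.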

Using Pretrained ViTs

import timm

# Load pretrained ViT
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)

# Common variants
models = {
    'vit_small': 'vit_small_patch16_224',
    'vit_base': 'vit_base_patch16_224',
    'vit_large': 'vit_large_patch16_224',
    'deit_small': 'deit_small_patch16_224',
    'swin_tiny': 'swin_tiny_patch4_window7_224',
    'swin_base': 'swin_base_patch4_window7_224',
}

# List all available ViT models
vit_models = timm.list_models('*vit*', pretrained=True)
print(f"Available ViT models: {len(vit_models)}")

ViT vs CNN Comparison

| Aspect | CNN | ViT |
| --- | --- | --- |
| Inductive bias | Strong (locality, translation equivariance) | Weak (learns spatial relationships from data) |
| Data efficiency | Better with small data (10K-100K images) | Needs large data (1M+) or distillation from a CNN teacher |
| Compute | Efficient; FLOPs scale linearly with image size | O(N^2) attention; quadratic in token count |
| Scalability | Accuracy saturates beyond ~300M params | Continues improving with more params and data |
| Interpretability | Filter visualization, GradCAM | Attention maps show which patches attend to which |
| Resolution flexibility | Works at any resolution natively | Requires interpolating positional embeddings for new resolutions |
| Dense prediction | Natural hierarchical features for detection/segmentation | Requires modifications (Swin, ViTDet) for multi-scale features |
A senior engineer’s rule of thumb: if your training set is under 1 million images and you can’t use pretrained weights, use a CNN (ResNet, EfficientNet). If you can fine-tune from ImageNet-21k or larger pretrained checkpoints, ViTs will likely outperform. For production systems, the choice often comes down to latency: CNNs are easier to optimize for edge deployment (TensorRT, CoreML), while ViTs shine on server-side inference where batch processing amortizes the attention cost.

Visualizing Attention

import torch.nn.functional as F

def visualize_attention(model, image, patch_size=16):
    """
    Visualize attention from [CLS] token to patches.
    This reveals what the model "looks at" when making a classification decision.
    High-attention patches typically correspond to the most discriminative image regions.

    Assumes `model` is the ViT defined above (nn.MultiheadAttention blocks)
    and `image` is a single (C, H, W) tensor.
    """
    model.eval()
    
    # Get attention weights from last layer
    attentions = []
    def get_attention(module, input, output):
        attentions.append(output[1])  # nn.MultiheadAttention returns (output, weights)
    
    # Register hook on last attention layer, and remove it afterwards
    # so repeated calls do not accumulate hooks
    handle = model.blocks[-1].attn.register_forward_hook(get_attention)
    with torch.no_grad():
        _ = model(image.unsqueeze(0))
    handle.remove()
    
    # Shape: (1, N+1, N+1). nn.MultiheadAttention averages the weights over heads
    # by default (average_attn_weights=True), so there is no head dimension here.
    attn = attentions[0]
    
    # Get attention from [CLS] to all patches.
    # Index 0 = [CLS] token, indices 1: = patch tokens.
    # To inspect individual heads and their "attention specializations", the block
    # would have to call attention with average_attn_weights=False.
    cls_attn = attn[0, 0, 1:]  # (N,)
    
    # Reshape to the patch grid
    H, W = image.shape[-2:]
    attn_map = cls_attn.reshape(H // patch_size, W // patch_size)
    
    # Upsample to image size
    attn_map = F.interpolate(
        attn_map.unsqueeze(0).unsqueeze(0),
        size=(H, W),
        mode='bilinear',
        align_corners=False,
    )[0, 0]
    
    return attn_map.cpu().numpy()

Exercises

Implement a complete ViT and train it on CIFAR-10 with proper augmentation.
Visualize attention maps for different images. What does the model attend to?
Compare ViT and ResNet on the same dataset. Analyze accuracy vs compute tradeoffs.

Training Tips

Practical ViT training advice from practitioners:
  • Data augmentation is non-negotiable. Without strong augmentation (RandAugment, Mixup, CutMix, random erasing), ViTs overfit severely on datasets under 1M images. DeiT’s success came largely from aggressive augmentation, not just distillation.
  • Learning rate warmup matters more than for CNNs. ViTs are sensitive to the initial learning rate. Use linear warmup for 5-10 epochs, then cosine decay. The peak learning rate for ViT-Base is typically 1e-3 with AdamW.
  • Weight decay is essential. Use AdamW with weight decay 0.05-0.3. ViTs have more parameters than equivalent CNNs and regularize less well from architecture alone.
  • Positional embedding interpolation. When fine-tuning at a different resolution than pretraining (e.g., pretrained at 224, fine-tuning at 384), bicubically interpolate the positional embeddings. This is imperfect but works well in practice — the model quickly adapts during fine-tuning.
  • Patch size trade-off. Smaller patches (8x8 instead of 16x16) give better accuracy but quadruple the token count, increasing attention cost 16x. For many tasks, 16x16 patches with a larger model are more efficient than 8x8 patches with a smaller model.
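
The warmup-then-cosine schedule from the list above can be written as a small pure-Python function. This is a sketch under stated assumptions: the function name `lr_at_epoch`, the 300-epoch budget, and the `min_lr` floor are illustrative choices, not values taken from a specific recipe.

```python
import math


def lr_at_epoch(epoch, peak_lr=1e-3, warmup_epochs=5, total_epochs=300, min_lr=1e-5):
    """Linear warmup to peak_lr, then cosine decay towards min_lr."""
    if epoch < warmup_epochs:
        return peak_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))


for e in (0, 4, 5, 150, 299):
    print(e, f"{lr_at_epoch(e):.6f}")
```

In PyTorch the same function could drive `torch.optim.lr_scheduler.LambdaLR` by returning the ratio to the base learning rate instead of the absolute value.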

Interview Deep-Dive

Strong Answer:
  • ViTs split an image into fixed-size, non-overlapping patches (typically 16x16 pixels), flatten each patch into a vector, and linearly project it into an embedding space. These patch embeddings become the “tokens” for a standard Transformer encoder — identical to word tokens in NLP.
  • A learnable [CLS] token is prepended to the sequence. After passing through $L$ Transformer layers (self-attention + MLP), the [CLS] token’s representation is used for classification. The intuition: through self-attention, the [CLS] token aggregates information from all patches, producing a global image summary.
  • Why positional embeddings are necessary: self-attention is permutation-equivariant — swapping two tokens in the input swaps the corresponding outputs, but doesn’t change the attention computation. Without positional embeddings, the model cannot distinguish between an original image and a randomly shuffled version (same patches, different arrangement). Positional embeddings inject spatial order by adding a unique learned vector to each token based on its position.
  • Without positional embeddings: the model can still learn to classify some images (since patch content alone carries information), but accuracy drops by 3-5% on ImageNet. Interestingly, the learned positional embeddings exhibit a 2D spatial structure when visualized — nearby patches have similar positional embeddings, and the model effectively recovers a notion of spatial locality from data.
  • A senior engineer would note: the choice between learned absolute positional embeddings (ViT), sinusoidal (original Transformer), relative position bias (Swin), and rotary embeddings (RoPE, used in some recent vision models) significantly affects the model’s ability to generalize to new resolutions. Relative and rotary embeddings generalize better because they encode distances rather than absolute positions.
Follow-up: How do you fine-tune a ViT pretrained at 224x224 on 384x384 images?
The patch embedding works at any resolution (it is just a convolution), but the positional embeddings have a fixed length (196 patch positions for 224/16, plus one for [CLS]). At 384x384, you get $(384/16)^2 = 576$ patch tokens. The standard approach is to bicubically interpolate the 14x14 positional embedding grid to 24x24, treating it as a 2D image, while leaving the [CLS] embedding unchanged. Then fine-tune for a few epochs; the model quickly adapts. This works because the positional embeddings are smooth in 2D space, making interpolation reasonable.
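
The interpolation step for moving from 224x224 to 384x384 can be sketched as follows. The function name `resize_pos_embed` and its arguments are illustrative assumptions, and a random tensor stands in for pretrained weights.

```python
import torch
import torch.nn.functional as F


def resize_pos_embed(pos_embed, old_grid, new_grid, num_prefix_tokens=1):
    """Resize a (1, prefix + old_grid**2, D) positional table to a new_grid x new_grid layout."""
    prefix = pos_embed[:, :num_prefix_tokens]     # [CLS] (and any other extra tokens), kept as-is
    grid = pos_embed[:, num_prefix_tokens:]       # (1, old_grid**2, D)
    D = grid.shape[-1]
    # Treat the patch positions as a D-channel image of size old_grid x old_grid
    grid = grid.reshape(1, old_grid, old_grid, D).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(new_grid, new_grid), mode="bicubic", align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, D)
    return torch.cat([prefix, grid], dim=1)


pos_embed_224 = torch.randn(1, 1 + 14 * 14, 768)   # stand-in for a pretrained table
pos_embed_384 = resize_pos_embed(pos_embed_224, old_grid=14, new_grid=24)
print(pos_embed_384.shape)
```

For a model with a distillation token as well, `num_prefix_tokens=2` would keep both extra embeddings untouched.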
Strong Answer:
  • ViT (Dosovitskiy et al., 2020): the original vision transformer. Pure self-attention over all patches. Demonstrated that transformers can match CNNs on vision, but required pretraining on JFT-300M (300 million images). Without massive pretraining, ViT underperforms ResNets on ImageNet. Choose when: you have access to large-scale pretraining data or pretrained checkpoints, and your task is image classification.
  • DeiT (Touvron et al., 2021): same architecture as ViT, but with a better training recipe — strong data augmentation (RandAugment, Mixup, CutMix), knowledge distillation from a RegNet CNN teacher, and stochastic depth. Achieves 83.1% top-1 on ImageNet with only ImageNet-1K training data (1.2M images). The key insight: ViT’s poor performance on smaller datasets was a training problem, not an architecture problem. Choose when: you want ViT-level performance without JFT-scale pretraining data. Use the deit_* variants from timm as your go-to starting point.
  • Swin Transformer (Liu et al., 2021): introduces hierarchical feature maps (like a CNN) and local window attention with shifted windows. Attention is $O(N)$ instead of $O(N^2)$ in image size, and the hierarchical structure produces multi-scale feature maps naturally. Choose when: you need a vision backbone for dense prediction tasks (object detection, semantic segmentation, instance segmentation) where multi-scale features are essential. Swin has largely replaced CNNs as the backbone in modern detection/segmentation frameworks (Mask R-CNN, Cascade R-CNN).
  • Decision framework in practice: for classification tasks with pretrained checkpoints, DeiT or plain ViT fine-tuning is hard to beat. For detection/segmentation, Swin or ViTDet (ViT adapted for detection with simple feature pyramids). For edge deployment where latency matters, EfficientNet or MobileNet may still win because ViTs are harder to quantize and optimize for mobile hardware.
Strong Answer:
  • Approach 1: Window attention (Swin). Restrict attention to local windows of size $W \times W$. Complexity drops from $O(N^2)$ to $O(N \cdot W^2)$, which is linear in image size. Shifted windows across alternating layers allow cross-window information flow. Sacrifice: global attention is only achieved after multiple layers of shifted windows; in early layers, distant patches cannot directly communicate. This is fine for most vision tasks but can hurt tasks requiring long-range pixel-level dependencies.
  • Approach 2: Linear attention / kernel approximation. Replace the softmax attention kernel $\exp(QK^T/\sqrt{d})$ with a factored form $\phi(Q)\phi(K)^T$ where $\phi$ is a feature map (e.g., random Fourier features, ELU+1). This allows computing attention in $O(N \cdot d^2)$ instead of $O(N^2 \cdot d)$, which is cheaper when $d \ll N$. Used in Performer, Linear Transformer. Sacrifice: the approximation of softmax attention is imperfect. Sharp, peaked attention patterns (common in vision) are poorly approximated, leading to 1-3% accuracy drops on ImageNet.
  • Approach 3: Token reduction / token merging. Progressively reduce the number of tokens through the network. Methods include: average pooling of neighboring tokens (PoolFormer), learned token merging based on similarity (ToMe — Token Merging), or keeping only the top-K most informative tokens (DynamicViT). Sacrifice: spatial resolution is lost, which is acceptable for classification but problematic for dense prediction tasks that need per-pixel outputs.
  • A senior engineer would note: in practice, the $O(N^2)$ cost of ViT at 224x224 resolution (196 tokens) is already fast enough for most use cases; a ViT-Small/DeiT-Small is comparable to a ResNet-50 in FLOPs. The efficiency question becomes critical at high resolutions (1024x1024+) or in video (where tokens multiply by frame count). FlashAttention (Dao et al., 2022) is often the most practical solution: it doesn’t change the asymptotic compute complexity but reduces attention memory usage from $O(N^2)$ to $O(N)$ and achieves 2-4x wall-clock speedup through IO-aware tiling. It’s a systems optimization, not an algorithmic one, and it preserves exact attention semantics.
Follow-up: What about hybrid approaches like CoAtNet?
CoAtNet (Dai et al., 2021) uses depthwise convolutions in early stages (where spatial resolution is high and locality is most useful) and self-attention in later stages (where the token count is reduced by pooling and global context matters more). This gives the best of both worlds: CNN-like efficiency and locality bias in early layers, Transformer-like global modeling in later layers. It set state-of-the-art on ImageNet at its release. The general principle, using the right inductive bias at the right scale, is increasingly how production vision architectures are designed.
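
The reordering behind Approach 2 (linear attention) can be seen in a few lines. The sketch below uses the ELU+1 feature map mentioned above, in a single-head, unbatched-heads form chosen for brevity; the function names are illustrative assumptions. It compares the fast ordering, which never builds an N x N matrix, with the same computation done the quadratic way.

```python
import torch
import torch.nn.functional as F


def linear_attention(q, k, v, eps=1e-6):
    """Kernelized attention with phi(x) = elu(x) + 1, computed without an N x N matrix."""
    q, k = F.elu(q) + 1, F.elu(k) + 1                        # positive feature maps
    kv = k.transpose(-2, -1) @ v                             # (B, d, d)
    normalizer = q @ k.sum(dim=1, keepdim=True).transpose(-2, -1)  # (B, N, 1)
    return (q @ kv) / (normalizer + eps)


def linear_attention_quadratic(q, k, v, eps=1e-6):
    """Same result computed the slow way, to check the reordering."""
    q, k = F.elu(q) + 1, F.elu(k) + 1
    weights = q @ k.transpose(-2, -1)                        # (B, N, N)
    return (weights @ v) / (weights.sum(dim=-1, keepdim=True) + eps)


torch.manual_seed(0)
q, k, v = (torch.randn(2, 196, 64) for _ in range(3))
out_fast = linear_attention(q, k, v)
out_slow = linear_attention_quadratic(q, k, v)
print(out_fast.shape, (out_fast - out_slow).abs().max().item())
```

The two functions agree up to rounding because matrix multiplication is associative; neither reproduces softmax attention exactly, which is the accuracy trade-off described above.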
Strong Answer:
  • Step 1: Model selection. Start with the smallest ViT variant that meets your accuracy requirement. DeiT-Tiny (5.7M params) or DeiT-Small (22M params) are good baselines. If latency is extremely tight, consider MobileViT, EfficientFormer, or FastViT — these are hybrid architectures specifically designed for mobile deployment.
  • Step 2: Knowledge distillation. Train a smaller student ViT using a larger pretrained ViT or CNN ensemble as the teacher. This typically recovers 60-80% of the accuracy gap between the small and large models. The DeiT distillation approach (using a distillation token) is particularly effective.
  • Step 3: Quantization. ViTs are harder to quantize than CNNs because attention logits can have high dynamic range. Use post-training quantization (PTQ) with careful calibration — quantize to INT8 for weights and activations. The LayerNorm and Softmax operations may need to stay in FP32 (mixed-precision quantization). Expect 1.5-2x speedup with less than 1% accuracy drop for INT8.
  • Step 4: Token reduction at inference. Apply Token Merging (ToMe) to reduce the number of tokens by 30-50% after the first few layers. Visually similar neighboring patches are merged (averaged), dramatically reducing the attention computation for deeper layers. This is a free lunch — 2x speedup with less than 0.5% accuracy loss on ImageNet.
  • Step 5: Export and runtime. Export via ONNX, then optimize with TensorRT (NVIDIA), CoreML (Apple), or TFLite (Android). These runtimes fuse operations (LayerNorm + Linear, attention kernel fusion) and use hardware-specific optimizations. On Apple devices with the Neural Engine, CoreML can run DeiT-Small at roughly 5ms per image on an iPhone 15.
  • Key metric: measure end-to-end latency on the target device, not just FLOPs. ViTs have different bottlenecks than CNNs — attention is memory-bandwidth bound (not compute bound), so reducing FLOPs doesn’t always translate linearly to latency gains.
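
For the latency measurement in the last point, a simple timing harness might look like the following. It is a sketch only: `measure_latency` is a hypothetical helper, and the lambda is a stand-in workload that would be replaced by a call into the exported model on the target device.

```python
import statistics
import time


def measure_latency(fn, warmup=5, iters=30):
    """Return (median_ms, p95_ms) for calling fn() after a few warmup runs."""
    for _ in range(warmup):
        fn()  # warm caches and lazy initialization before timing
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    samples.sort()
    p95 = samples[min(len(samples) - 1, int(0.95 * len(samples)))]
    return statistics.median(samples), p95


# Stand-in workload; swap in real model inference when benchmarking.
median_ms, p95_ms = measure_latency(lambda: sum(i * i for i in range(10_000)))
print(f"median {median_ms:.3f} ms, p95 {p95_ms:.3f} ms")
```

When the workload runs on a GPU, the timed function also needs to wait for the device to finish (for example with `torch.cuda.synchronize()`), otherwise only the kernel launch is measured.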

What’s Next

Module 24: Multimodal Models

CLIP, vision-language models, and connecting modalities.