Semantic Segmentation

From Classification to Segmentation

Segmentation is classification taken to its extreme: instead of assigning one label to the entire image, you assign a label to every single pixel. This is the foundation for applications like self-driving cars (is this pixel road, sidewalk, or pedestrian?), medical imaging (is this pixel tumor or healthy tissue?), and satellite analysis (is this pixel forest, water, or urban?).
| Task | Output | Granularity |
|---|---|---|
| Classification | One label per image | Coarsest |
| Object Detection | Boxes + labels | Object-level |
| Semantic Segmentation | One label per pixel | Pixel-level (no instance distinction) |
| Instance Segmentation | Objects + per-pixel masks | Pixel-level (separates instances) |
| Panoptic Segmentation | Class label for every pixel + instance IDs for countable objects | The complete picture |
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Optional

torch.manual_seed(42)
Segmentation models are memory-hungry because they must maintain spatial resolution throughout the network (unlike classification models, which pool aggressively). Training a U-Net on a batch of 512x512 inputs can easily consume 8+ GB of GPU memory. If you hit OOM errors, reducing input resolution (e.g., from 512 to 256) usually helps far more than reducing batch size: activation memory grows linearly with the number of pixels, i.e., quadratically with the side length, so halving the side cuts it roughly 4x, while halving the batch cuts it only 2x.
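A back-of-the-envelope sketch of that scaling argument, assuming the default U-Net channel list [64, 128, 256, 512, 1024] and counting only one float32 feature map per encoder level. This is a deliberate underestimate (training also stores BatchNorm/ReLU outputs, decoder maps, and gradients), and the helper names are hypothetical; only the ratios matter.

```python
from typing import Sequence


def feature_map_elements(height: int, width: int, channels: Sequence[int]) -> int:
    """Elements in one feature map per U-Net level (crude lower bound).

    Level i runs at 1/2**i of the input resolution with channels[i] channels.
    """
    total = 0
    for i, c in enumerate(channels):
        total += c * (height // 2**i) * (width // 2**i)
    return total


def activation_megabytes(batch: int, side: int,
                         channels: Sequence[int] = (64, 128, 256, 512, 1024),
                         bytes_per_elem: int = 4) -> float:
    """Megabytes for those feature maps across a batch of side x side images."""
    elems = batch * feature_map_elements(side, side, channels)
    return elems * bytes_per_elem / 2**20


base = activation_megabytes(batch=8, side=512)
half_res = activation_megabytes(batch=8, side=256)    # halve the side length
half_batch = activation_megabytes(batch=4, side=512)  # halve the batch
print(f"baseline {base:.0f} MB, half resolution {half_res:.0f} MB, half batch {half_batch:.0f} MB")
# prints: baseline 992 MB, half resolution 248 MB, half batch 496 MB
```

The absolute numbers are far below real usage, but the 4x versus 2x ratio carries over because every stored activation scales the same way.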

Fully Convolutional Networks (FCN)

This is the foundation of modern segmentation. Before FCN, a common approach was to classify a patch around each pixel independently (sliding windows), which was extremely slow because overlapping patches repeated most of the computation. FCN’s breakthrough was realizing that you can replace the fully connected layers at the end of a classification network with convolutional layers, turning the entire network into a spatial-in, spatial-out function that processes the whole image in one forward pass:
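A small PyTorch check of the FC-to-conv equivalence described here: copy a `nn.Linear` layer's weights into a 1x1 `nn.Conv2d` and compare outputs. The sizes (512 features, 10 classes, a 7x9 feature map) are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

in_features, num_classes = 512, 10
fc = nn.Linear(in_features, num_classes)

# A 1x1 conv whose weights are the FC weights reshaped to [out, in, 1, 1]
conv = nn.Conv2d(in_features, num_classes, kernel_size=1)
with torch.no_grad():
    conv.weight.copy_(fc.weight.view(num_classes, in_features, 1, 1))
    conv.bias.copy_(fc.bias)

# On a 1x1 "feature map" the two layers compute the same thing
x = torch.randn(2, in_features)
out_fc = fc(x)                                    # [2, 10]
out_conv = conv(x.view(2, in_features, 1, 1))     # [2, 10, 1, 1]
print(torch.allclose(out_fc, out_conv.flatten(1), atol=1e-5))

# On a larger map, the conv applies the same classifier at every position
feature_map = torch.randn(2, in_features, 7, 9)
dense_scores = conv(feature_map)                  # [2, 10, 7, 9]
print(torch.allclose(dense_scores[:, :, 3, 4], fc(feature_map[:, :, 3, 4]), atol=1e-5))
```

The second check is the whole idea of FCN: one forward pass yields a class score map instead of a single score vector.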
class FCN(nn.Module):
    """
    Fully Convolutional Networks (Long et al., 2015).
    
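    Key insight: Replace FC layers with conv layers to preserve spatial info.
    A classification network's FC layer that maps a 512-dim feature vector
    to 1000 classes is mathematically equivalent to a 1x1 conv with 1000
    output channels; on a larger feature map, that conv applies the same
    classifier at every spatial position. Likewise, VGG's fc6 (which sees a
    7x7x512 input) becomes a 7x7 conv, as in self.fc6 below. This means any
    classification network can be repurposed for segmentation simply by
    swapping FC layers for convs and adding upsampling.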
    
    Pitfall: The naive FCN produces very coarse output because of all the
    pooling layers. FCN-8s (below) fixes this by fusing predictions from
    multiple scales.
    """
    
    def __init__(self, num_classes: int, backbone: str = 'vgg16'):
        # `backbone` is currently unused: only the VGG-16-style encoder below is built
        super().__init__()
        
        # VGG-style backbone
        self.features = nn.Sequential(
            # Block 1
            self._make_block(3, 64, 2),
            nn.MaxPool2d(2, 2),
            
            # Block 2
            self._make_block(64, 128, 2),
            nn.MaxPool2d(2, 2),
            
            # Block 3
            self._make_block(128, 256, 3),
            nn.MaxPool2d(2, 2),
            
            # Block 4
            self._make_block(256, 512, 3),
            nn.MaxPool2d(2, 2),
            
            # Block 5
            self._make_block(512, 512, 3),
            nn.MaxPool2d(2, 2)
        )
        
        # FC layers converted to conv
        self.fc6 = nn.Conv2d(512, 4096, 7, padding=3)
        self.fc7 = nn.Conv2d(4096, 4096, 1)
        
        # Score layer
        self.score = nn.Conv2d(4096, num_classes, 1)
        
        # Upsample
        self.upsample = nn.ConvTranspose2d(
            num_classes, num_classes, 64, stride=32, padding=16
        )
    
    def _make_block(self, in_c: int, out_c: int, num_convs: int):
        layers = []
        for i in range(num_convs):
            layers.extend([
                nn.Conv2d(in_c if i == 0 else out_c, out_c, 3, padding=1),
                nn.ReLU(inplace=True)
            ])
        return nn.Sequential(*layers)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, 3, H, W] input image
        
        Returns:
            segmentation: [B, C, H, W] per-pixel logits
        """
        input_size = x.shape[2:]
        
        x = self.features(x)
        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))
        x = self.score(x)
        
        # Upsample to input size
        x = self.upsample(x)
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
        
        return x


class FCN8s(nn.Module):
    """
    FCN-8s: Multi-scale fusion for better boundaries.
    
    Combines predictions from pool3, pool4, and pool5.
    """
    
    def __init__(self, num_classes: int):
        super().__init__()
        
        # Backbone with skip connections
        self.conv1 = self._make_block(3, 64, 2)
        self.pool1 = nn.MaxPool2d(2, 2)
        
        self.conv2 = self._make_block(64, 128, 2)
        self.pool2 = nn.MaxPool2d(2, 2)
        
        self.conv3 = self._make_block(128, 256, 3)
        self.pool3 = nn.MaxPool2d(2, 2)
        
        self.conv4 = self._make_block(256, 512, 3)
        self.pool4 = nn.MaxPool2d(2, 2)
        
        self.conv5 = self._make_block(512, 512, 3)
        self.pool5 = nn.MaxPool2d(2, 2)
        
        # FC layers as conv
        self.fc6 = nn.Conv2d(512, 4096, 7, padding=3)
        self.fc7 = nn.Conv2d(4096, 4096, 1)
        
        # Score layers for each scale
        self.score_fr = nn.Conv2d(4096, num_classes, 1)
        self.score_pool4 = nn.Conv2d(512, num_classes, 1)
        self.score_pool3 = nn.Conv2d(256, num_classes, 1)
        
        # Upsampling
        self.upscore2 = nn.ConvTranspose2d(num_classes, num_classes, 4, stride=2, padding=1)
        self.upscore4 = nn.ConvTranspose2d(num_classes, num_classes, 4, stride=2, padding=1)
        self.upscore8 = nn.ConvTranspose2d(num_classes, num_classes, 16, stride=8, padding=4)
    
    def _make_block(self, in_c, out_c, num_convs):
        layers = []
        for i in range(num_convs):
            layers.extend([
                nn.Conv2d(in_c if i == 0 else out_c, out_c, 3, padding=1),
                nn.ReLU(inplace=True)
            ])
        return nn.Sequential(*layers)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_size = x.shape[2:]
        
        # Encoder
        x = self.pool1(self.conv1(x))
        x = self.pool2(self.conv2(x))
        pool3 = self.pool3(self.conv3(x))
        pool4 = self.pool4(self.conv4(pool3))
        pool5 = self.pool5(self.conv5(pool4))
        
        x = F.relu(self.fc6(pool5))
        x = F.relu(self.fc7(x))
        
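        # Multi-scale fusion: upsample coarse scores 2x and add finer ones.
        # If the input side is not a multiple of 32, floor-pooling can leave
        # the maps one pixel apart, so resize before each addition.
        score_fr = self.score_fr(x)                  # stride 32
        upscore2 = self.upscore2(score_fr)           # stride 16
        
        score_pool4 = self.score_pool4(pool4)
        if upscore2.shape[2:] != score_pool4.shape[2:]:
            upscore2 = F.interpolate(upscore2, size=score_pool4.shape[2:],
                                     mode='bilinear', align_corners=False)
        fuse_pool4 = upscore2 + score_pool4
        upscore4 = self.upscore4(fuse_pool4)         # stride 8
        
        score_pool3 = self.score_pool3(pool3)
        if upscore4.shape[2:] != score_pool3.shape[2:]:
            upscore4 = F.interpolate(upscore4, size=score_pool3.shape[2:],
                                     mode='bilinear', align_corners=False)
        fuse_pool3 = upscore4 + score_pool3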
        
        out = self.upscore8(fuse_pool3)
        out = F.interpolate(out, size=input_size, mode='bilinear', align_corners=False)
        
        return out
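The kernel/stride/padding choices in the two FCN classes follow from PyTorch's transposed-convolution size formula (with dilation 1 and no `output_padding`): out = (in - 1) * stride - 2 * padding + kernel. A plain-Python sketch with hypothetical helper names checks the upsampling factors and shows why inputs whose side is not a multiple of 32 need a resize or crop before the FCN-8s additions:

```python
from typing import List


def conv_transpose_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of nn.ConvTranspose2d along one axis (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


def pool_sizes(side: int, levels: int = 5) -> List[int]:
    """Spatial size after each nn.MaxPool2d(2, 2), which floors odd sizes."""
    sizes = []
    for _ in range(levels):
        side //= 2
        sizes.append(side)
    return sizes


# Each (kernel, stride, padding) triple from the FCN code upsamples by exactly `stride`
for kernel, stride, padding in [(64, 32, 16), (4, 2, 1), (16, 8, 4)]:
    print(f"k={kernel} s={stride} p={padding}: 10 -> {conv_transpose_out(10, kernel, stride, padding)}")

# FCN-8s adds the 2x-upsampled pool5 scores to the pool4 scores
for side in (512, 500):
    p1, p2, p3, p4, p5 = pool_sizes(side)
    up = conv_transpose_out(p5, 4, 2, 1)
    print(f"input {side}: pool4 {p4}, upsampled pool5 {up}, match={p4 == up}")
```

For a 500-pixel side the pool4 map is 31 wide while the upsampled pool5 map is 30, so a plain addition would fail; the original FCN handles this by cropping.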

U-Net

Arguably the most influential segmentation architecture, especially in medical imaging where labeled data is scarce. The name comes from its U-shaped architecture: the left side encodes (compresses) the input, the bottom is the bottleneck, and the right side decodes (expands) back to full resolution. The critical innovation is the skip connections that bridge corresponding encoder and decoder levels, giving the decoder the fine-grained spatial details from the encoder that would otherwise be lost during downsampling. Think of it like writing a summary (encoding) and then expanding it back into a full essay (decoding): without the skip connections, most of the specific details would be gone. With them, the decoder can say “I know the big picture from the bottleneck features, and here are the precise boundaries from the encoder.”
class DoubleConv(nn.Module):
    """Double convolution block -- the basic unit of U-Net."""
    
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.double_conv(x)


class UNet(nn.Module):
    """
    U-Net: Convolutional Networks for Biomedical Image Segmentation
    (Ronneberger et al., 2015).
    
    Key features:
    1. Symmetric encoder-decoder structure -- easy to reason about
    2. Skip connections preserve fine details (boundary pixels, textures)
    3. Works well with limited training data -- the original paper trained
       on just 30 images with heavy augmentation
    
    Practical tips:
    - U-Net is still competitive in 2024+ for medical and scientific imaging
    - For natural images, consider using a pretrained encoder (ResNet, etc.)
    - The feature list [64, 128, 256, 512, 1024] is a sensible default;
      halve it for smaller images or tighter memory budgets
    """
    
    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 2,
        features: List[int] = [64, 128, 256, 512, 1024]
    ):
        super().__init__()
        
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)
        
        # Encoder
        for feature in features:
            self.encoder.append(DoubleConv(in_channels, feature))
            in_channels = feature
        
        # Decoder
        for feature in reversed(features[:-1]):
            self.decoder.append(
                nn.ConvTranspose2d(feature * 2, feature, 2, stride=2)
            )
            self.decoder.append(DoubleConv(feature * 2, feature))
        
        # Final conv
        self.final_conv = nn.Conv2d(features[0], num_classes, 1)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, C, H, W] input
        
        Returns:
            segmentation: [B, num_classes, H, W]
        """
        skip_connections = []
        
        # Encoder
        for encoder_block in self.encoder[:-1]:
            x = encoder_block(x)
            skip_connections.append(x)
            x = self.pool(x)
        
        x = self.encoder[-1](x)  # Bottleneck -- the deepest, most compressed representation
        
        skip_connections = skip_connections[::-1]  # Reverse to match decoder order
        
        # Decoder: mirror of encoder, but with skip connections
        # Each step: upsample -> concatenate skip connection -> double conv
        for i in range(0, len(self.decoder), 2):
            x = self.decoder[i](x)  # Upsample
            
            skip = skip_connections[i // 2]
            
            # Handle size mismatch (MaxPool2d floors odd sizes)
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
            
            x = torch.cat([skip, x], dim=1)  # Skip connection
            x = self.decoder[i + 1](x)  # Double conv
        
        return self.final_conv(x)


class AttentionUNet(nn.Module):
    """
    Attention U-Net: Learning Where to Look.
    
    Adds attention gates to focus on relevant regions.
    """
    
    def __init__(self, in_channels: int = 3, num_classes: int = 2):
        super().__init__()
        
        features = [64, 128, 256, 512, 1024]
        
        # Encoder
        self.encoder = nn.ModuleList([
            DoubleConv(in_channels, features[0]),
            DoubleConv(features[0], features[1]),
            DoubleConv(features[1], features[2]),
            DoubleConv(features[2], features[3]),
            DoubleConv(features[3], features[4])
        ])
        
        self.pool = nn.MaxPool2d(2, 2)
        
        # Decoder with attention
        self.upconv4 = nn.ConvTranspose2d(features[4], features[3], 2, stride=2)
        self.attn4 = AttentionGate(features[3], features[3], features[3] // 2)
        self.decoder4 = DoubleConv(features[4], features[3])
        
        self.upconv3 = nn.ConvTranspose2d(features[3], features[2], 2, stride=2)
        self.attn3 = AttentionGate(features[2], features[2], features[2] // 2)
        self.decoder3 = DoubleConv(features[3], features[2])
        
        self.upconv2 = nn.ConvTranspose2d(features[2], features[1], 2, stride=2)
        self.attn2 = AttentionGate(features[1], features[1], features[1] // 2)
        self.decoder2 = DoubleConv(features[2], features[1])
        
        self.upconv1 = nn.ConvTranspose2d(features[1], features[0], 2, stride=2)
        self.attn1 = AttentionGate(features[0], features[0], features[0] // 2)
        self.decoder1 = DoubleConv(features[1], features[0])
        
        self.final = nn.Conv2d(features[0], num_classes, 1)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Encoder
        e1 = self.encoder[0](x)
        e2 = self.encoder[1](self.pool(e1))
        e3 = self.encoder[2](self.pool(e2))
        e4 = self.encoder[3](self.pool(e3))
        e5 = self.encoder[4](self.pool(e4))
        
        # Decoder with attention
        d4 = self.upconv4(e5)
        e4_attn = self.attn4(d4, e4)
        d4 = self.decoder4(torch.cat([e4_attn, d4], dim=1))
        
        d3 = self.upconv3(d4)
        e3_attn = self.attn3(d3, e3)
        d3 = self.decoder3(torch.cat([e3_attn, d3], dim=1))
        
        d2 = self.upconv2(d3)
        e2_attn = self.attn2(d2, e2)
        d2 = self.decoder2(torch.cat([e2_attn, d2], dim=1))
        
        d1 = self.upconv1(d2)
        e1_attn = self.attn1(d1, e1)
        d1 = self.decoder1(torch.cat([e1_attn, d1], dim=1))
        
        return self.final(d1)


class AttentionGate(nn.Module):
    """
    Attention gate for Attention U-Net.
    
    The gating signal (g) from the decoder tells the skip connection (x)
    from the encoder which spatial regions to focus on. Think of it as
    the decoder saying: "I know from the global context that the object
    of interest is in the top-left -- encoder, please only send me the
    fine details from that region."
    """
    
    def __init__(self, g_channels: int, x_channels: int, inter_channels: int):
        super().__init__()
        
        self.W_g = nn.Conv2d(g_channels, inter_channels, 1, bias=False)
        self.W_x = nn.Conv2d(x_channels, inter_channels, 1, bias=False)
        
        self.psi = nn.Sequential(
            nn.Conv2d(inter_channels, 1, 1, bias=False),
            nn.Sigmoid()
        )
        
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            g: Gating signal (from decoder)
            x: Skip connection (from encoder)
        """
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        
        # Resize if needed
        if g1.shape[2:] != x1.shape[2:]:
            g1 = F.interpolate(g1, size=x1.shape[2:], mode='bilinear', align_corners=False)
        
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        
        return x * psi
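The size-mismatch branch in `UNet.forward` (and the resize inside `AttentionGate`) exists because `MaxPool2d(2, 2)` floors odd sizes, while a kernel-2, stride-2 `ConvTranspose2d` exactly doubles. A plain-Python trace (hypothetical helper, assuming padded 3x3 convs as in `DoubleConv` and four pooling steps) shows when the upsampled decoder map comes out one pixel short of its skip connection:

```python
from typing import List, Tuple


def unet_size_trace(side: int, depth: int = 4) -> List[Tuple[int, int]]:
    """Trace one spatial axis through a U-Net with padded 3x3 convs,
    `depth` MaxPool2d(2, 2) steps and ConvTranspose2d(kernel=2, stride=2).

    Returns (skip_size, upsampled_size) pairs in decoder order.
    """
    skips = []
    for _ in range(depth):
        skips.append(side)   # padded convs keep the size; skip saved before pooling
        side //= 2           # MaxPool2d(2, 2) floors odd sizes
    pairs = []
    for skip in reversed(skips):
        upsampled = side * 2           # (n - 1) * 2 + 2 = 2n
        pairs.append((skip, upsampled))
        side = skip                    # decoder continues at the skip's size
    return pairs


for side in (512, 100):
    for skip, up in unet_size_trace(side):
        note = "" if skip == up else "  <- mismatch, needs resizing"
        print(f"input {side}: skip {skip}, upsampled {up}{note}")
```

With a 100-pixel input, the 25-pixel skip meets a 24-pixel upsampled map. Inputs whose side is divisible by 2^4 = 16 avoid the mismatch entirely, which matters for `AttentionUNet`, whose concatenations have no resize step.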

DeepLab Series

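The original DeepLab models take a different approach from the encoder-decoder style. Instead of downsampling aggressively and then upsampling, they replace the strides in the last backbone stages with dilated (atrous) convolutions, keeping feature maps at 1/8 or 1/16 of the input resolution (the "output stride") while still capturing large-scale context. This reduces the information lost to repeated pooling.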
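The arithmetic behind dilation can be sketched in plain Python. The helper names are hypothetical; the formulas are the standard ones for stride-1 convolutions: a k x k kernel with dilation d spans d(k - 1) + 1 pixels, and stacked stride-1 layers have a receptive field of 1 + sum of (k - 1) * d. The last helper counts how many of a 3x3 kernel's taps land inside a small feature map (taps outside hit zero padding), assuming a 32x32 map, i.e., a 512x512 input at output stride 16.

```python
from typing import Sequence


def effective_kernel(kernel: int, dilation: int) -> int:
    """Span covered by a dilated kernel: dilation * (kernel - 1) + 1."""
    return dilation * (kernel - 1) + 1


def same_padding(kernel: int, dilation: int) -> int:
    """Padding that preserves spatial size for stride 1 and odd kernels."""
    return (effective_kernel(kernel, dilation) - 1) // 2


def receptive_field(kernel: int, dilations: Sequence[int]) -> int:
    """Receptive field of stacked stride-1 convs: 1 + sum((kernel - 1) * d)."""
    return 1 + sum((kernel - 1) * d for d in dilations)


def avg_valid_taps(size: int, dilation: int, kernel: int = 3) -> float:
    """Average number of kernel taps that fall inside a size x size map."""
    half = kernel // 2
    offsets = [k * dilation for k in range(-half, half + 1)]
    total = 0
    for i in range(size):
        for j in range(size):
            total += sum(1 for di in offsets for dj in offsets
                         if 0 <= i + di < size and 0 <= j + dj < size)
    return total / (size * size)


print(effective_kernel(3, 2), same_padding(3, 2))                    # 5 2
print(receptive_field(3, [1, 1, 1]), receptive_field(3, [1, 2, 4]))  # 7 15
for rate in (1, 6, 12, 18, 32):
    print(f"rate {rate:2d}: {avg_valid_taps(32, rate):.2f} of 9 taps inside the map")
```

Two takeaways: same-size padding for a 3x3 kernel equals the dilation rate (which is why the ASPP code uses `padding=dilation`), and as the rate approaches the map size the 3x3 filter degenerates toward a 1x1 filter. DeepLabV3 cites that degeneration as the motivation for ASPP's image-level (global pooling) branch.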

Atrous (Dilated) Convolutions

class AtrousConv(nn.Module):
    """
    Atrous (Dilated) Convolution.
    
    Increases receptive field without losing resolution or adding parameters.
    
    Think of it like a regular convolution with "holes" (atrous = "with holes"
    in French): instead of looking at a contiguous 3x3 patch, the filter
    samples every 2nd (or 3rd, or 4th) pixel. A 3x3 kernel with dilation=2
    covers the same area as a 5x5 kernel but uses only 9 parameters.
    
    Practical tip: Use multiple dilation rates in parallel (ASPP, below)
    to capture objects at multiple scales simultaneously.
    """
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        dilation: int = 2
    ):
        super().__init__()
        
        padding = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2
        
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=padding, dilation=dilation, bias=False
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.bn(self.conv(x)))


class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling (DeepLabV3).
    
    Parallel atrous convolutions at multiple dilation rates. The idea:
    objects appear at different scales in an image. A person close-up fills
    most of the image; a person far away is just a few pixels. ASPP captures
    context at multiple scales simultaneously by running several dilated
    convolutions in parallel (each with a different dilation rate) and
    concatenating their outputs. It also includes a global average pooling
    branch to capture image-level context ("this is an indoor scene").
    """
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int = 256,
        dilations: List[int] = [6, 12, 18]
    ):
        super().__init__()
        
        modules = []
        
        # 1x1 conv
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ))
        
        # Atrous convolutions at different rates
        for dilation in dilations:
            modules.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=dilation, 
                          dilation=dilation, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            ))
        
        # Global average pooling
        modules.append(nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ))
        
        self.convs = nn.ModuleList(modules)
        
        # Fusion
        self.project = nn.Sequential(
            nn.Conv2d(out_channels * (len(dilations) + 2), out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = []
        
        for conv in self.convs[:-1]:
            features.append(conv(x))
        
        # Global pooling branch
        global_feat = self.convs[-1](x)
        global_feat = F.interpolate(
            global_feat, size=x.shape[2:], mode='bilinear', align_corners=False
        )
        features.append(global_feat)
        
        x = torch.cat(features, dim=1)
        x = self.project(x)
        
        return x


class DeepLabV3Plus(nn.Module):
    """
    DeepLab V3+: Encoder-Decoder with Atrous Separable Convolution.
    
    Combines:
    - ASPP for multi-scale context
    - Decoder for sharp boundaries
    - Depthwise separable convolutions for efficiency
    """
    
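    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        low_level_channels: int = 256,
        aspp_out_channels: int = 256,
        high_level_channels: int = 2048
    ):
        super().__init__()
        
        # backbone(x) must return (low_level_features, high_level_features)
        self.backbone = backbone
        
        # ASPP on the high-level features (2048 channels for ResNet-50/101)
        self.aspp = ASPP(high_level_channels, aspp_out_channels)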
        
        # Low-level feature projection
        self.low_level_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Conv2d(aspp_out_channels + 48, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_size = x.shape[2:]
        
        # Backbone
        low_level_features, high_level_features = self.backbone(x)
        
        # ASPP
        aspp_out = self.aspp(high_level_features)
        
        # Upsample ASPP output
        aspp_out = F.interpolate(
            aspp_out, size=low_level_features.shape[2:],
            mode='bilinear', align_corners=False
        )
        
        # Process low-level features
        low_level_out = self.low_level_conv(low_level_features)
        
        # Concatenate and decode
        x = torch.cat([aspp_out, low_level_out], dim=1)
        x = self.decoder(x)
        
        # Upsample to input size
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
        
        return x
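The DeepLabV3Plus docstring lists depthwise separable convolutions, but the decoder above uses standard 3x3 convolutions. A minimal sketch of the separable alternative, using a hypothetical `separable_conv` helper (BatchNorm/ReLU between the two convs omitted for brevity), compares parameter counts at the decoder's first layer, whose input is 256 ASPP channels + 48 low-level channels:

```python
import torch
import torch.nn as nn


def separable_conv(in_c: int, out_c: int, dilation: int = 1) -> nn.Sequential:
    """Depthwise 3x3 (optionally dilated) followed by a pointwise 1x1 (hypothetical helper)."""
    return nn.Sequential(
        # Depthwise: groups=in_c gives one 3x3 filter per input channel
        nn.Conv2d(in_c, in_c, 3, padding=dilation, dilation=dilation,
                  groups=in_c, bias=False),
        # Pointwise: 1x1 conv mixes information across channels
        nn.Conv2d(in_c, out_c, 1, bias=False),
    )


def n_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


in_c, out_c = 256 + 48, 256
standard = nn.Conv2d(in_c, out_c, 3, padding=1, bias=False)
separable = separable_conv(in_c, out_c)

print(n_params(standard), n_params(separable))   # 700416 80560

x = torch.randn(1, in_c, 32, 32)
print(standard(x).shape, separable(x).shape)     # both [1, 256, 32, 32]
```

The separable version has roughly 8.7x fewer weights while producing the same output shape; whether accuracy holds up depends on the model and data.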

Transformer-Based Segmentation

Transformers have made major inroads into segmentation. Their key advantage is a global receptive field from the first layer. A CNN needs many stacked convolutions to “see” the full image, but a transformer’s self-attention can relate any two patches (tokens) in a single layer. This makes transformers particularly good at capturing long-range dependencies (e.g., understanding that a road continues behind a building).
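A quick way to see the receptive-field difference is to backpropagate from a single output position and count which input positions receive gradient. This sketch compares one 3x3 `nn.Conv2d` with one `nn.MultiheadAttention` layer over the flattened pixels of an 8x8 map; the sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C, H, W = 16, 8, 8

conv = nn.Conv2d(C, C, kernel_size=3, padding=1)
attn = nn.MultiheadAttention(C, num_heads=2, batch_first=True)

x = torch.randn(1, C, H, W, requires_grad=True)

# Which input positions influence the conv output at pixel (0, 0)?
conv(x)[0, :, 0, 0].sum().backward()
conv_reach = int((x.grad.abs().sum(dim=1)[0] > 0).sum())

# Same question for one self-attention layer over the H*W tokens
x.grad = None
tokens = x.flatten(2).transpose(1, 2)   # [1, H*W, C]
out, _ = attn(tokens, tokens, tokens)
out[0, 0].sum().backward()              # output token for pixel (0, 0)
attn_reach = int((x.grad.abs().sum(dim=1)[0] > 0).sum())

print(conv_reach, attn_reach)           # 4 64
```

The corner pixel of the conv output depends on only its 2x2 in-bounds neighborhood, while the attention output depends on all 64 positions after a single layer.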
Transformer-based segmentation models are significantly more data-hungry than CNN-based ones. A U-Net can achieve strong results on datasets as small as 30 images (with heavy augmentation), whereas SETR relies on a ViT encoder pretrained on ImageNet-21k and SegFormer on an encoder pretrained on ImageNet-1k, both followed by fine-tuning on thousands of labeled segmentation images. As a rough rule of thumb, if your dataset is small (under about 1,000 annotated images), a pretrained U-Net or DeepLabV3+ with a ResNet backbone is usually the safer choice.

SETR (Segmentation Transformer)

class SETR(nn.Module):
    """
    SETR: Rethinking Semantic Segmentation from a Sequence-to-Sequence
    Perspective (Zheng et al., 2021).
    
    Uses Vision Transformer as encoder. The image is split into patches,
    treated as a sequence, processed by a standard transformer, then decoded
    back to pixel-level predictions.
    
    Practical tip: SETR requires large-scale pretraining (ImageNet-21k or
    similar) to work well. Without it, the transformer encoder lacks the
    visual priors that CNNs get for free from their inductive bias.
    """
    
    def __init__(
        self,
        img_size: int = 512,
        patch_size: int = 16,
        num_classes: int = 21,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12
    ):
        super().__init__()
        
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        # Patch embedding
        self.patch_embed = nn.Conv2d(3, embed_dim, patch_size, stride=patch_size)
        
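        # Learned position embedding, fixed to (img_size // patch_size)^2 patches,
        # so this sketch expects img_size x img_size inputs (ViT-style init)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)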
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        
        # Decoder ("PUP", progressive upsampling): four 2x stages = 16x total, matching patch_size=16
        self.decoder = nn.Sequential(
            nn.Conv2d(embed_dim, 256, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, num_classes, 1)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        
        # Patch embedding
        x = self.patch_embed(x)  # [B, embed_dim, H/P, W/P]
        x = x.flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]
        
        # Add position embedding
        x = x + self.pos_embed
        
        # Transformer
        x = self.transformer(x)
        
        # Reshape tokens back to a [B, embed_dim, H/P, W/P] grid
        h, w = H // self.patch_size, W // self.patch_size
        x = x.transpose(1, 2).reshape(B, -1, h, w)
        
        # Decode
        x = self.decoder(x)
        
        return x


class SegFormer(nn.Module):
    """
    SegFormer: Simple and Efficient Design for Semantic Segmentation
    (Xie et al., 2021).
    
    Key innovations:
    - Hierarchical transformer encoder (produces multi-scale features
      like a CNN, not flat tokens like ViT)
    - MLP decoder (shockingly simple but effective -- suggests that most
      of the heavy lifting happens in the encoder)
    - No positional encoding: Mix-FFN (an MLP with a 3x3 depthwise conv)
      leaks location information through zero padding, and overlapping
      patch embeddings preserve local continuity between patches. This
      simplified sketch uses a plain MLP instead of Mix-FFN.
    
    Practical tip: SegFormer is arguably one of the best transformer-based
    segmentation models for practical use. It is simpler than SETR, and
    the lightweight MLP decoder keeps it memory-efficient. The B0-B5
    variants offer a clean accuracy/speed tradeoff.
    """
    
    def __init__(
        self,
        num_classes: int = 21,
        embed_dims: List[int] = [64, 128, 320, 512],
        num_heads: List[int] = [1, 2, 5, 8],
        depths: List[int] = [3, 6, 40, 3],      # MiT-B5 depths
        sr_ratios: List[int] = [8, 4, 2, 1]     # K/V spatial reduction per stage
    ):
        super().__init__()
        
        # Hierarchical encoder: stages at strides 4, 8, 16, 32
        self.stages = nn.ModuleList()
        in_channels = 3
        
        for i, (embed_dim, num_head, depth, sr_ratio) in enumerate(
            zip(embed_dims, num_heads, depths, sr_ratios)
        ):
            stage = MixTransformerBlock(
                in_channels=in_channels,
                embed_dim=embed_dim,
                num_heads=num_head,
                depth=depth,
                patch_size=7 if i == 0 else 3,
                stride=4 if i == 0 else 2,
                sr_ratio=sr_ratio
            )
            self.stages.append(stage)
            in_channels = embed_dim
        
        # All-MLP decoder
        self.decode_head = MLPDecoder(
            in_channels=embed_dims,
            embed_dim=768,
            num_classes=num_classes
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_size = x.shape[2:]
        features = []
        
        for stage in self.stages:
            x = stage(x)
            features.append(x)
        
        # The decoder predicts at 1/4 of the input resolution; upsample to
        # full size to match the other models in this section
        out = self.decode_head(features)
        return F.interpolate(out, size=input_size, mode='bilinear', align_corners=False)


class MixTransformerBlock(nn.Module):
    """One Mix Transformer (MiT) stage for SegFormer: an overlapping patch
    embedding followed by `depth` efficient self-attention blocks."""
    
    def __init__(
        self,
        in_channels: int,
        embed_dim: int,
        num_heads: int,
        depth: int,
        patch_size: int = 3,
        stride: int = 2,
        sr_ratio: int = 4
    ):
        super().__init__()
        
        # Overlapping patch embedding (kernel larger than stride)
        self.patch_embed = nn.Conv2d(
            in_channels, embed_dim, patch_size,
            stride=stride, padding=patch_size // 2
        )
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            EfficientSelfAttentionBlock(embed_dim, num_heads, sr_ratio)
            for _ in range(depth)
        ])
        
        self.norm = nn.LayerNorm(embed_dim)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        
        x = x.flatten(2).transpose(1, 2)  # [B, H*W, C]
        
        for block in self.blocks:
            x = block(x, H, W)
        
        x = self.norm(x)
        x = x.transpose(1, 2).reshape(B, C, H, W)
        
        return x


class EfficientSelfAttentionBlock(nn.Module):
    """Pre-norm transformer block whose keys/values are spatially reduced by
    `sr_ratio` per axis, cutting attention cost from O(N^2) to
    O(N^2 / sr_ratio^2). Queries keep full resolution."""
    
    def __init__(self, dim: int, num_heads: int, sr_ratio: int = 4):
        super().__init__()
        
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        
        # Spatial reduction of K, V (skipped when sr_ratio == 1; the last
        # stage's 1/32-resolution map is already small)
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, sr_ratio, stride=sr_ratio)
            self.sr_norm = nn.LayerNorm(dim)
        
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )
    
    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        # Self-attention with spatially reduced keys/values
        residual = x
        x = self.norm1(x)
        
        kv = x
        if self.sr_ratio > 1:
            B, N, C = x.shape
            kv = x.transpose(1, 2).reshape(B, C, H, W)
            kv = self.sr(kv).flatten(2).transpose(1, 2)  # [B, N / sr_ratio^2, C]
            kv = self.sr_norm(kv)
        
        x, _ = self.attn(x, kv, kv, need_weights=False)
        x = residual + x
        
        # MLP
        x = x + self.mlp(self.norm2(x))
        
        return x


class MLPDecoder(nn.Module):
    """Simple MLP decoder for SegFormer."""
    
    def __init__(
        self,
        in_channels: List[int],
        embed_dim: int,
        num_classes: int
    ):
        super().__init__()
        
        self.linear_fuse = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_c, embed_dim, 1),
                nn.BatchNorm2d(embed_dim)
            )
            for in_c in in_channels
        ])
        
        self.linear_pred = nn.Conv2d(embed_dim * len(in_channels), num_classes, 1)
    
    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        target_size = features[0].shape[2:]
        
        fused = []
        for feat, linear in zip(features, self.linear_fuse):
            feat = linear(feat)
            if feat.shape[2:] != target_size:
                feat = F.interpolate(feat, size=target_size, mode='bilinear', align_corners=False)
            fused.append(feat)
        
        x = torch.cat(fused, dim=1)
        x = self.linear_pred(x)
        
        return x
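Why spatial reduction matters: a hierarchical encoder's first stage runs at 1/4 resolution, so full self-attention there would be enormous. A plain-Python count (hypothetical helper) of attention-matrix entries per head for a 512x512 input, assuming stage strides 4/8/16/32 and the per-stage reduction ratios 8/4/2/1 reported for SegFormer, and comparing against a SETR-style flat sequence of 16x16 patches:

```python
def attention_entries(side: int, stride: int, sr_ratio: int = 1) -> int:
    """Entries in one head's attention matrix for a side x side input.

    Queries: one token per position at the given stride.
    Keys/values: the same grid reduced by `sr_ratio` per axis.
    """
    grid = side // stride
    n_queries = grid * grid
    n_keys = (grid // sr_ratio) ** 2
    return n_queries * n_keys


side = 512
print(f"SETR (stride 16, no reduction): {attention_entries(side, 16):,}")

# SegFormer-style stages: strides 4/8/16/32 with reduction ratios 8/4/2/1
for stride, sr in zip((4, 8, 16, 32), (8, 4, 2, 1)):
    full = attention_entries(side, stride)
    reduced = attention_entries(side, stride, sr)
    print(f"stride {stride:2d}: {full:>11,} -> {reduced:>9,} (sr_ratio={sr})")
```

At stride 4, reduction shrinks the per-head attention matrix 64x (from about 268M to about 4.2M entries). Note that `EfficientSelfAttentionBlock` reduces only keys and values; queries stay at full resolution, so the block still outputs one token per position.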

Instance & Panoptic Segmentation
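Before the architecture code, a short sketch of how Mask R-CNN trains its per-class masks. Per the Mask R-CNN paper, the mask loss is the average binary cross-entropy with a per-pixel sigmoid, computed only on the mask channel of each RoI's ground-truth class. The function and tensor names here are hypothetical, and RoI matching plus cropping/resizing of the targets to the mask grid are assumed to have happened already.

```python
import torch
import torch.nn.functional as F


def mask_loss(mask_logits: torch.Tensor, gt_labels: torch.Tensor,
              gt_masks: torch.Tensor) -> torch.Tensor:
    """Sketch of a Mask R-CNN-style mask loss.

    mask_logits: [N, K, M, M] one mask logit map per class for each positive RoI
    gt_labels:   [N] ground-truth class index of each RoI
    gt_masks:    [N, M, M] binary targets already resampled to the RoI grid
    """
    idx = torch.arange(mask_logits.size(0), device=mask_logits.device)
    selected = mask_logits[idx, gt_labels]               # [N, M, M]
    # Per-pixel sigmoid + BCE: classes do not compete with each other
    return F.binary_cross_entropy_with_logits(selected, gt_masks.float())


torch.manual_seed(0)
logits = torch.randn(3, 5, 28, 28, requires_grad=True)  # 3 RoIs, 5 classes, 28x28 masks
labels = torch.tensor([1, 4, 1])
targets = torch.rand(3, 28, 28) > 0.5

loss = mask_loss(logits, labels, targets)
loss.backward()

# Only the ground-truth class channel of each RoI receives gradient
print(bool(logits.grad[0, 1].abs().sum() > 0), bool(logits.grad[0, 0].abs().sum() == 0))
```

Because the other classes' masks get no gradient, mask quality is learned independently of classification, which is the decoupling the docstring below refers to.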

class MaskRCNN(nn.Module):
    """
    Mask R-CNN: Instance Segmentation (He et al., 2017).
    
    Extends Faster R-CNN with a mask prediction head. The key difference
    from semantic segmentation: Mask R-CNN produces a separate binary mask
    for each detected object instance. Two people standing side by side
    get two distinct masks, not one "person" blob.
    
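    Architecture: Faster R-CNN (detects boxes) + parallel mask branch
    that predicts a per-class binary mask for each detected region. This
    sketch implements only the mask branch; box detection and RoIAlign
    are assumed to run elsewhere.
    
    Practical tip: The mask head operates on RoI-aligned features and
    predicts K binary masks per RoI, one for each class. The final mask
    is the one for the class predicted by the box classification head, so
    mask and class prediction are decoupled (per-pixel sigmoid rather
    than a softmax over classes).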
    """
    
    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        mask_dim: int = 256
    ):
        super().__init__()
        
        self.backbone = backbone
        
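        # Mask head: four 3x3 convs, a 2x transposed conv, and a 1x1 conv that
        # outputs one mask logit map per class (e.g. 14x14 RoI features -> 28x28 masks).
        # Assumes 256-channel RoI features, as produced by an FPN.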
        self.mask_head = nn.Sequential(
            nn.Conv2d(256, mask_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mask_dim, mask_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mask_dim, mask_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mask_dim, mask_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(mask_dim, mask_dim, 2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(mask_dim, num_classes, 1)
        )
    
    def forward_mask(
        self,
        roi_features: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Predict masks for each RoI.
        
        Args:
            roi_features: [N, C, H, W] features from RoIAlign
            labels: [N] class index for each RoI (predicted at inference,
                ground truth during training)
        
        Returns:
            masks: [N, 1, H*2, W*2] mask logits (apply sigmoid for probabilities)
        """
        masks = self.mask_head(roi_features)  # [N, num_classes, H*2, W*2]
        
        # Keep only the mask channel for each RoI's class
        N = masks.size(0)
        masks = masks[torch.arange(N, device=masks.device), labels]
        
        return masks.unsqueeze(1)
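The mask branch above only produces logits. During training, Mask R-CNN's mask loss is the average per-pixel binary cross-entropy on the mask channel of each positive RoI's ground-truth class; the other channels receive no gradient. A minimal sketch, assuming the ground-truth masks have already been cropped to each RoI and resized to the head's M x M output (28x28 in the paper); `mask_rcnn_mask_loss` is a hypothetical helper name, not a library function.

```python
import torch
import torch.nn.functional as F

def mask_rcnn_mask_loss(
    mask_logits: torch.Tensor,  # [N, K, M, M] raw mask-head outputs, one channel per class
    gt_labels: torch.Tensor,    # [N] ground-truth class index of each positive RoI
    gt_masks: torch.Tensor,     # [N, M, M] binary targets, already cropped/resized to M x M
) -> torch.Tensor:
    if mask_logits.size(0) == 0:
        # No positive RoIs: a zero that still participates in autograd
        return mask_logits.sum() * 0.0
    idx = torch.arange(mask_logits.size(0), device=mask_logits.device)
    # Keep only the channel of each RoI's own class; other channels get no gradient
    selected = mask_logits[idx, gt_labels]  # [N, M, M]
    # Per-pixel sigmoid + binary cross-entropy, averaged over all pixels
    return F.binary_cross_entropy_with_logits(selected, gt_masks.float())

logits = torch.zeros(3, 5, 28, 28, requires_grad=True)  # 3 RoIs, 5 classes
labels = torch.tensor([0, 2, 4])
targets = torch.rand(3, 28, 28) > 0.5
loss = mask_rcnn_mask_loss(logits, labels, targets)
loss.backward()
print(round(loss.item(), 4))  # 0.6931 (log 2, since every logit is 0)
```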


class PanopticFPN(nn.Module):
    """
    Panoptic FPN: Unified thing and stuff segmentation (Kirillov et al., 2019).
    
    Panoptic segmentation is the grand unification of segmentation tasks.
    It answers: "for every pixel, what class is it, and if it is a 'thing'
    (countable object like a car or person), which specific instance?"
    
    Combines:
    - Instance segmentation for "things" (countable: car, person, dog)
    - Semantic segmentation for "stuff" (uncountable: sky, road, grass)
    
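    Practical tip: Panoptic Quality (PQ) factors into segmentation quality
    (SQ, the mean IoU of matched segments) times recognition quality (RQ, an
    F1-style detection score), so it reports both separately. Unlike mIoU,
    it also penalizes merging two adjacent instances into a single segment.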
    """
    
    def __init__(
        self,
        backbone: nn.Module,
        num_thing_classes: int,  # e.g., person, car
        num_stuff_classes: int   # e.g., sky, road
    ):
        super().__init__()
        
        self.backbone = backbone
        
        # FPN for multi-scale features
        self.fpn = nn.ModuleDict({
            'lateral': nn.ModuleList([
                nn.Conv2d(c, 256, 1) for c in [256, 512, 1024, 2048]
            ]),
            'output': nn.ModuleList([
                nn.Conv2d(256, 256, 3, padding=1) for _ in range(4)
            ])
        })
        
        # Semantic segmentation head (stuff)
        self.semantic_head = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1),
            nn.GroupNorm(32, 128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.GroupNorm(32, 128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, num_stuff_classes, 1)
        )
        
        # Instance head would be similar to Mask R-CNN
        self.num_thing_classes = num_thing_classes
    
    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Get backbone features at multiple scales. Assumed to be [C2, C3, C4, C5]:
        # strides 4, 8, 16, 32 with 256/512/1024/2048 channels (ResNet-50 style)
        features = self.backbone(x)
        
        # FPN
        fpn_features = self._fpn_forward(features)
        
        # Semantic prediction: upsample every level to the finest (stride-4) size and sum.
        # Simplified: the paper runs conv + 2x upsample stages per level instead.
        semantic_features = []
        target_size = fpn_features[0].shape[2:]
        
        for feat in fpn_features:
            feat = F.interpolate(feat, size=target_size, mode='bilinear', align_corners=False)
            semantic_features.append(feat)
        
        semantic_feat = sum(semantic_features)
        semantic_out = self.semantic_head(semantic_feat)
        
        # Upsample to input resolution (using the input's size avoids assuming
        # that fpn_features[0] is at exactly 1/4 scale)
        semantic_out = F.interpolate(
            semantic_out, size=x.shape[-2:], mode='bilinear', align_corners=False
        )
        
        return {
            'semantic': semantic_out,
            'fpn_features': fpn_features  # For instance branch
        }
    
    def _fpn_forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        """Build FPN from backbone features."""
        laterals = [lat(f) for f, lat in zip(features, self.fpn['lateral'])]
        
        # Top-down pathway
        for i in range(len(laterals) - 1, 0, -1):
            laterals[i-1] = laterals[i-1] + F.interpolate(
                laterals[i], size=laterals[i-1].shape[2:], mode='nearest'
            )
        
        outputs = [out(lat) for lat, out in zip(laterals, self.fpn['output'])]
        
        return outputs
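The PQ mentioned in the docstring is computed per class. A predicted and a ground-truth segment match when their IoU exceeds 0.5, which makes the matching unique when segments do not overlap. Then PQ = (sum of matched IoUs) / (|TP| + 0.5|FP| + 0.5|FN|) = SQ x RQ, where SQ is the mean IoU of the matches and RQ = |TP| / (|TP| + 0.5|FP| + 0.5|FN|). A minimal single-image, single-class sketch; the full metric accumulates these counts over the whole dataset before averaging PQ over classes, and also handles void pixels and crowd regions, which this omits. The function and helper names are illustrative.

```python
import math
import torch
from typing import Dict, List

def panoptic_quality(pred_segments: List[torch.Tensor],
                     gt_segments: List[torch.Tensor]) -> Dict[str, float]:
    """PQ, SQ and RQ for one class in one image.

    Each segment is a boolean [H, W] mask; segments within each list must not
    overlap (as in a panoptic output, where every pixel has one segment id).
    """
    matched_ious = []
    for p in pred_segments:
        for g in gt_segments:
            union = (p | g).sum().item()
            iou = (p & g).sum().item() / union if union > 0 else 0.0
            # With non-overlapping segments, IoU > 0.5 allows at most one match each
            if iou > 0.5:
                matched_ious.append(iou)
    tp = len(matched_ious)
    fp = len(pred_segments) - tp  # unmatched predictions
    fn = len(gt_segments) - tp    # unmatched ground-truth segments
    denom = tp + 0.5 * fp + 0.5 * fn
    if denom == 0:
        # Class absent from both prediction and ground truth: undefined, skip it
        return {'PQ': math.nan, 'SQ': math.nan, 'RQ': math.nan}
    sq = sum(matched_ious) / tp if tp > 0 else 0.0
    rq = tp / denom
    return {'PQ': sq * rq, 'SQ': sq, 'RQ': rq}

def box_mask(rows: slice, cols: slice, size=(4, 4)) -> torch.Tensor:
    m = torch.zeros(size, dtype=torch.bool)
    m[rows, cols] = True
    return m

gt = [box_mask(slice(0, 4), slice(0, 2)),    # left half
      box_mask(slice(0, 4), slice(2, 4))]    # right half
pred = [box_mask(slice(0, 4), slice(0, 2)),  # IoU 1.0 with the left half
        box_mask(slice(0, 3), slice(2, 4)),  # IoU 0.75 with the right half
        box_mask(slice(3, 4), slice(2, 4))]  # IoU 0.25: a false positive
scores = panoptic_quality(pred, gt)
print({k: round(v, 3) for k, v in scores.items()})  # {'PQ': 0.7, 'SQ': 0.875, 'RQ': 0.8}
```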

Loss Functions

class SegmentationLosses:
    """Common loss functions for segmentation."""
    
    @staticmethod
    def cross_entropy(
        pred: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        ignore_index: int = 255
    ) -> torch.Tensor:
        """Standard cross-entropy loss."""
        return F.cross_entropy(pred, target, weight=weight, ignore_index=ignore_index)
    
    @staticmethod
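    def dice_loss(
        pred: torch.Tensor,
        target: torch.Tensor,
        smooth: float = 1e-6,
        ignore_index: int = 255
    ) -> torch.Tensor:
        """
        Soft Dice loss, often used for class-imbalanced segmentation.
        
        Dice = 2|A∩B| / (|A| + |B|)
        
        pred: [B, C, H, W] logits; target: [B, H, W] int64 class indices.
        """
        pred = F.softmax(pred, dim=1)
        num_classes = pred.size(1)
        
        # Mask out ignored pixels; F.one_hot fails on labels >= num_classes (e.g. 255)
        valid = (target != ignore_index).unsqueeze(1).float()  # [B, 1, H, W]
        safe_target = target.masked_fill(target == ignore_index, 0)
        target_one_hot = F.one_hot(safe_target, num_classes).permute(0, 3, 1, 2).float()
        pred = pred * valid
        target_one_hot = target_one_hot * valid
        
        # Per-image, per-class overlap and cardinality
        intersection = (pred * target_one_hot).sum(dim=(2, 3))
        cardinality = pred.sum(dim=(2, 3)) + target_one_hot.sum(dim=(2, 3))
        
        dice = (2 * intersection + smooth) / (cardinality + smooth)
        
        # Average over batch and classes
        return 1 - dice.mean()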
    
    @staticmethod
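    def focal_loss(
        pred: torch.Tensor,
        target: torch.Tensor,
        alpha: float = 0.25,
        gamma: float = 2.0,
        ignore_index: int = 255
    ) -> torch.Tensor:
        """
        Focal Loss (Lin et al., 2017): scales cross-entropy by (1 - p_t)^gamma so
        well-classified pixels contribute little and training concentrates on
        hard ones. With a scalar alpha (as here), alpha only rescales the loss.
        """
        ce_loss = F.cross_entropy(pred, target, reduction='none', ignore_index=ignore_index)
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal_loss = alpha * (1 - pt) ** gamma * ce_loss
        # Average over non-ignored pixels only (ignored pixels have ce_loss == 0)
        valid = target != ignore_index
        return focal_loss[valid].sum() / valid.sum().clamp(min=1)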
    
    @staticmethod
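    def lovasz_softmax(
        pred: torch.Tensor,
        target: torch.Tensor,
        ignore_index: int = 255
    ) -> torch.Tensor:
        """
        Lovász-Softmax loss (Berman et al., 2018).
        A tractable surrogate for the Jaccard (IoU) loss built from the Lovász
        extension, so it optimizes IoU more directly than cross-entropy.
        """
        probas = F.softmax(pred, dim=1)
        C = probas.size(1)
        
        # Flatten to [P, C] and [P], dropping ignored pixels
        probas = probas.permute(0, 2, 3, 1).reshape(-1, C)
        target = target.reshape(-1)
        valid = target != ignore_index
        probas, target = probas[valid], target[valid]
        
        losses = []
        for c in range(C):
            fg = (target == c).float()
            if fg.sum() == 0:
                continue  # skip classes absent from the ground truth
            errors = (fg - probas[:, c]).abs()
            errors_sorted, perm = torch.sort(errors, descending=True)
            fg_sorted = fg[perm]
            losses.append(torch.dot(errors_sorted, SegmentationLosses._lovasz_grad(fg_sorted)))
        
        if not losses:
            return pred.sum() * 0.0  # no valid pixels: zero loss that keeps the graph
        return torch.stack(losses).mean()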
    
    @staticmethod
    def _lovasz_grad(gt_sorted: torch.Tensor) -> torch.Tensor:
        """Gradient of the Lovász extension of the Jaccard loss w.r.t. the sorted errors."""
        gts = gt_sorted.sum()
        intersection = gts - gt_sorted.cumsum(0)
        union = gts + (1 - gt_sorted).cumsum(0)
        jaccard = 1 - intersection / union
        jaccard[1:] = jaccard[1:] - jaccard[:-1]
        return jaccard
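The `weight` argument of `cross_entropy` is a common remedy for class imbalance. One standard choice is median frequency balancing (Eigen & Fergus, 2015): freq(c) is the number of pixels of class c divided by the total pixels of the images that contain c, and weight(c) = median(freq) / freq(c), so rare classes get weights above 1. A minimal sketch; the function name is ours, counting only non-ignored pixels is an assumption, and `torch.median` returns the lower of the two middle values when the number of present classes is even.

```python
import torch

def median_frequency_weights(labels: torch.Tensor, num_classes: int,
                             ignore_index: int = 255) -> torch.Tensor:
    """Per-class weights via median frequency balancing.

    labels: [N, H, W] int64 label maps with values in [0, num_classes) or ignore_index.
    """
    class_pixels = torch.zeros(num_classes, dtype=torch.float64)
    image_pixels = torch.zeros(num_classes, dtype=torch.float64)
    for lab in labels:
        valid = lab[lab != ignore_index]  # 1-D tensor of valid labels in this image
        counts = torch.bincount(valid, minlength=num_classes).double()
        class_pixels += counts
        # Every class present in this image is charged the image's valid pixel count
        image_pixels += (counts > 0).double() * valid.numel()
    present = class_pixels > 0
    freq = class_pixels[present] / image_pixels[present]
    weights = torch.zeros(num_classes, dtype=torch.float64)  # absent classes get 0
    weights[present] = freq.median() / freq
    return weights.float()

labels = torch.tensor([[[0, 0], [0, 1]],
                       [[0, 2], [255, 2]]])
w = median_frequency_weights(labels, num_classes=3)
# freq = [4/7, 1/4, 2/3] -> weights ≈ [1.00, 2.29, 0.86]; rare class 1 gets the most
```

The result can be passed as `weight` to `SegmentationLosses.cross_entropy` (moved to the logits' device); in practice the counts are accumulated once over the training set.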

Exercises

Create a loss that focuses on boundaries:
def boundary_loss(pred, target):
    # Extract boundaries using Sobel/Laplacian
    # Weight loss higher at boundaries
    pass
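One possible starting point for this exercise, not the only answer. Instead of a Sobel or Laplacian filter, this sketch uses a morphological shortcut: a pixel counts as a boundary if the max-pooled and min-pooled label maps disagree in its 3x3 window. This avoids a pitfall of applying a Laplacian to raw label ids, which can cancel out (labels 1, 2, 3 in a row give 1 + 3 - 2*2 = 0). The helper names and the `boundary_weight` default of 5.0 are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def label_boundaries(target: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
    """Boolean [B, H, W] map: True where the window holds more than one label."""
    t = target.unsqueeze(1).float()  # [B, 1, H, W]
    pad = kernel_size // 2
    # max_pool2d pads with -inf, so image borders are not mistaken for boundaries
    local_max = F.max_pool2d(t, kernel_size, stride=1, padding=pad)
    local_min = -F.max_pool2d(-t, kernel_size, stride=1, padding=pad)
    return (local_max != local_min).squeeze(1)

def boundary_weighted_ce(pred: torch.Tensor, target: torch.Tensor,
                         boundary_weight: float = 5.0,
                         ignore_index: int = 255) -> torch.Tensor:
    """Cross-entropy weighted by boundary_weight near label changes, 1 elsewhere.

    Note: pixels next to ignore_index regions are also treated as boundaries.
    """
    weights = torch.ones(target.shape, dtype=pred.dtype, device=pred.device)
    weights[label_boundaries(target)] = boundary_weight
    # reduction='none' gives a [B, H, W] map; ignored pixels come back as 0
    ce = F.cross_entropy(pred, target, ignore_index=ignore_index, reduction='none')
    weights = weights * (target != ignore_index).to(pred.dtype)
    return (ce * weights).sum() / weights.sum().clamp(min=1e-6)

target = torch.zeros(1, 4, 4, dtype=torch.long)
target[:, :, 2:] = 1  # left half class 0, right half class 1
b = label_boundaries(target)[0]  # columns 1 and 2 are marked as boundary
pred = torch.randn(1, 2, 4, 4, requires_grad=True)
loss = boundary_weighted_ce(pred, target)
loss.backward()
```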
Implement test-time augmentation for segmentation:
  • Inference at multiple scales
  • Flip augmentation
  • Ensemble predictions
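A minimal sketch covering the three bullets, assuming `model` maps a [B, 3, H, W] batch to [B, num_classes, h, w] logits for any input size. The scale set (0.75, 1.0, 1.25), horizontal flips only, and averaging softmax probabilities rather than logits are assumptions; all are common choices. The `tta_predict` name is ours, and the toy `nn.Conv2d` model exists only to make the example run.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Sequence

@torch.no_grad()
def tta_predict(model: nn.Module, image: torch.Tensor,
                scales: Sequence[float] = (0.75, 1.0, 1.25),
                flip: bool = True) -> torch.Tensor:
    """Return [B, num_classes, H, W] probabilities averaged over scales and flips."""
    model.eval()
    H, W = image.shape[-2:]
    prob_sum, n = None, 0
    for s in scales:
        size = (max(1, int(round(H * s))), max(1, int(round(W * s))))
        scaled = F.interpolate(image, size=size, mode='bilinear', align_corners=False)
        views = [(scaled, False)]
        if flip:
            views.append((torch.flip(scaled, dims=[-1]), True))
        for view, flipped in views:
            logits = model(view)
            if flipped:
                logits = torch.flip(logits, dims=[-1])  # undo the horizontal flip
            # Bring every view back to the original resolution before ensembling
            logits = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=False)
            probs = logits.softmax(dim=1)
            prob_sum = probs if prob_sum is None else prob_sum + probs
            n += 1
    return prob_sum / n

model = nn.Conv2d(3, 4, kernel_size=3, padding=1)  # stand-in for a segmentation network
image = torch.randn(2, 3, 64, 48)
probs = tta_predict(model, image)
pred = probs.argmax(dim=1)  # [2, 64, 48]
```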
Add Conditional Random Field refinement:
# Use pydensecrf or implement mean-field approximation
# Refine predictions based on appearance consistency

What’s Next?

Hyperparameter Tuning

Systematic optimization with Optuna and Ray Tune

Reproducibility

Experiment tracking and reproducible research