Semantic Segmentation

From Classification to Segmentation

Task                  | Output
----------------------|--------------------------------------------------
Classification        | One label per image
Object Detection      | Boxes + labels
Semantic Segmentation | One label per pixel
Instance Segmentation | Detected objects + per-pixel masks
Panoptic Segmentation | Per-pixel labels + instance IDs (things + stuff)

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Optional

torch.manual_seed(42)

Fully Convolutional Networks (FCN)

The foundation of modern segmentation:
class FCN(nn.Module):
    """
    Fully Convolutional Networks (Long et al., 2015).
    
    Key insight: Replace FC layers with conv layers to preserve spatial info.
    """
    
    def __init__(self, num_classes: int, backbone: str = 'vgg16'):
        super().__init__()
        
        # VGG-style backbone
        self.features = nn.Sequential(
            # Block 1
            self._make_block(3, 64, 2),
            nn.MaxPool2d(2, 2),
            
            # Block 2
            self._make_block(64, 128, 2),
            nn.MaxPool2d(2, 2),
            
            # Block 3
            self._make_block(128, 256, 3),
            nn.MaxPool2d(2, 2),
            
            # Block 4
            self._make_block(256, 512, 3),
            nn.MaxPool2d(2, 2),
            
            # Block 5
            self._make_block(512, 512, 3),
            nn.MaxPool2d(2, 2)
        )
        
        # FC layers converted to conv
        self.fc6 = nn.Conv2d(512, 4096, 7, padding=3)
        self.fc7 = nn.Conv2d(4096, 4096, 1)
        
        # Score layer
        self.score = nn.Conv2d(4096, num_classes, 1)
        
        # Upsample
        self.upsample = nn.ConvTranspose2d(
            num_classes, num_classes, 64, stride=32, padding=16
        )
    
    def _make_block(self, in_c: int, out_c: int, num_convs: int):
        layers = []
        for i in range(num_convs):
            layers.extend([
                nn.Conv2d(in_c if i == 0 else out_c, out_c, 3, padding=1),
                nn.ReLU(inplace=True)
            ])
        return nn.Sequential(*layers)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, 3, H, W] input image
        
        Returns:
            segmentation: [B, C, H, W] per-pixel logits
        """
        input_size = x.shape[2:]
        
        x = self.features(x)
        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))
        x = self.score(x)
        
        # Upsample to input size
        x = self.upsample(x)
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
        
        return x
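
A quick smoke test (our own sanity check, not from the paper) confirms that the network is fully convolutional and returns per-pixel logits at the input resolution:
model = FCN(num_classes=21)
out = model(torch.randn(1, 3, 64, 64))
print(out.shape)  # torch.Size([1, 21, 64, 64])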


class FCN8s(nn.Module):
    """
    FCN-8s: Multi-scale fusion for better boundaries.
    
    Combines predictions from pool3, pool4, and pool5.
    """
    
    def __init__(self, num_classes: int):
        super().__init__()
        
        # Backbone with skip connections
        self.conv1 = self._make_block(3, 64, 2)
        self.pool1 = nn.MaxPool2d(2, 2)
        
        self.conv2 = self._make_block(64, 128, 2)
        self.pool2 = nn.MaxPool2d(2, 2)
        
        self.conv3 = self._make_block(128, 256, 3)
        self.pool3 = nn.MaxPool2d(2, 2)
        
        self.conv4 = self._make_block(256, 512, 3)
        self.pool4 = nn.MaxPool2d(2, 2)
        
        self.conv5 = self._make_block(512, 512, 3)
        self.pool5 = nn.MaxPool2d(2, 2)
        
        # FC layers as conv
        self.fc6 = nn.Conv2d(512, 4096, 7, padding=3)
        self.fc7 = nn.Conv2d(4096, 4096, 1)
        
        # Score layers for each scale
        self.score_fr = nn.Conv2d(4096, num_classes, 1)
        self.score_pool4 = nn.Conv2d(512, num_classes, 1)
        self.score_pool3 = nn.Conv2d(256, num_classes, 1)
        
        # Upsampling
        self.upscore2 = nn.ConvTranspose2d(num_classes, num_classes, 4, stride=2, padding=1)
        self.upscore4 = nn.ConvTranspose2d(num_classes, num_classes, 4, stride=2, padding=1)
        self.upscore8 = nn.ConvTranspose2d(num_classes, num_classes, 16, stride=8, padding=4)
    
    def _make_block(self, in_c, out_c, num_convs):
        layers = []
        for i in range(num_convs):
            layers.extend([
                nn.Conv2d(in_c if i == 0 else out_c, out_c, 3, padding=1),
                nn.ReLU(inplace=True)
            ])
        return nn.Sequential(*layers)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_size = x.shape[2:]
        
        # Encoder
        x = self.pool1(self.conv1(x))
        x = self.pool2(self.conv2(x))
        pool3 = self.pool3(self.conv3(x))
        pool4 = self.pool4(self.conv4(pool3))
        pool5 = self.pool5(self.conv5(pool4))
        
        x = F.relu(self.fc6(pool5))
        x = F.relu(self.fc7(x))
        
        # Multi-scale fusion (align spatial sizes in case the input is not divisible by 32)
        score_fr = self.score_fr(x)
        upscore2 = self.upscore2(score_fr)
        
        score_pool4 = self.score_pool4(pool4)
        if upscore2.shape[2:] != score_pool4.shape[2:]:
            upscore2 = F.interpolate(upscore2, size=score_pool4.shape[2:], mode='bilinear', align_corners=False)
        fuse_pool4 = upscore2 + score_pool4
        upscore4 = self.upscore4(fuse_pool4)
        
        score_pool3 = self.score_pool3(pool3)
        if upscore4.shape[2:] != score_pool3.shape[2:]:
            upscore4 = F.interpolate(upscore4, size=score_pool3.shape[2:], mode='bilinear', align_corners=False)
        fuse_pool3 = upscore4 + score_pool3
        
        out = self.upscore8(fuse_pool3)
        out = F.interpolate(out, size=input_size, mode='bilinear', align_corners=False)
        
        return out
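
The same check for FCN-8s (again our own sanity check of the fusion resolutions):
model = FCN8s(num_classes=21)
out = model(torch.randn(1, 3, 64, 64))
print(out.shape)  # torch.Size([1, 21, 64, 64])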

U-Net

The encoder-decoder architecture with skip connections:
class DoubleConv(nn.Module):
    """Double convolution block."""
    
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.double_conv(x)


class UNet(nn.Module):
    """
    U-Net: Convolutional Networks for Biomedical Image Segmentation.
    
    Key features:
    1. Symmetric encoder-decoder structure
    2. Skip connections preserve fine details
    3. Works well with limited training data
    """
    
    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 2,
        features: List[int] = [64, 128, 256, 512, 1024]
    ):
        super().__init__()
        
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)
        
        # Encoder
        for feature in features:
            self.encoder.append(DoubleConv(in_channels, feature))
            in_channels = feature
        
        # Decoder
        for feature in reversed(features[:-1]):
            self.decoder.append(
                nn.ConvTranspose2d(feature * 2, feature, 2, stride=2)
            )
            self.decoder.append(DoubleConv(feature * 2, feature))
        
        # Final conv
        self.final_conv = nn.Conv2d(features[0], num_classes, 1)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, C, H, W] input
        
        Returns:
            segmentation: [B, num_classes, H, W]
        """
        skip_connections = []
        
        # Encoder
        for encoder_block in self.encoder[:-1]:
            x = encoder_block(x)
            skip_connections.append(x)
            x = self.pool(x)
        
        x = self.encoder[-1](x)  # Bottleneck
        
        skip_connections = skip_connections[::-1]  # Reverse
        
        # Decoder
        for i in range(0, len(self.decoder), 2):
            x = self.decoder[i](x)  # Upsample
            
            skip = skip_connections[i // 2]
            
            # Handle spatial size mismatch (e.g. from pooling odd input sizes)
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
            
            x = torch.cat([skip, x], dim=1)  # Skip connection
            x = self.decoder[i + 1](x)  # Double conv
        
        return self.final_conv(x)
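
Because the decoder interpolates when spatial sizes disagree, the model also handles inputs whose sides are not powers of two. A quick check (ours, with arbitrary odd dimensions):
model = UNet(in_channels=3, num_classes=2)
out = model(torch.randn(1, 3, 101, 101))
print(out.shape)  # torch.Size([1, 2, 101, 101])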


class AttentionUNet(nn.Module):
    """
    Attention U-Net: Learning Where to Look.
    
    Adds attention gates to focus on relevant regions.
    """
    
    def __init__(self, in_channels: int = 3, num_classes: int = 2):
        super().__init__()
        
        features = [64, 128, 256, 512, 1024]
        
        # Encoder
        self.encoder = nn.ModuleList([
            DoubleConv(in_channels, features[0]),
            DoubleConv(features[0], features[1]),
            DoubleConv(features[1], features[2]),
            DoubleConv(features[2], features[3]),
            DoubleConv(features[3], features[4])
        ])
        
        self.pool = nn.MaxPool2d(2, 2)
        
        # Decoder with attention
        self.upconv4 = nn.ConvTranspose2d(features[4], features[3], 2, stride=2)
        self.attn4 = AttentionGate(features[3], features[3], features[3] // 2)
        self.decoder4 = DoubleConv(features[4], features[3])
        
        self.upconv3 = nn.ConvTranspose2d(features[3], features[2], 2, stride=2)
        self.attn3 = AttentionGate(features[2], features[2], features[2] // 2)
        self.decoder3 = DoubleConv(features[3], features[2])
        
        self.upconv2 = nn.ConvTranspose2d(features[2], features[1], 2, stride=2)
        self.attn2 = AttentionGate(features[1], features[1], features[1] // 2)
        self.decoder2 = DoubleConv(features[2], features[1])
        
        self.upconv1 = nn.ConvTranspose2d(features[1], features[0], 2, stride=2)
        self.attn1 = AttentionGate(features[0], features[0], features[0] // 2)
        self.decoder1 = DoubleConv(features[1], features[0])
        
        self.final = nn.Conv2d(features[0], num_classes, 1)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Encoder
        e1 = self.encoder[0](x)
        e2 = self.encoder[1](self.pool(e1))
        e3 = self.encoder[2](self.pool(e2))
        e4 = self.encoder[3](self.pool(e3))
        e5 = self.encoder[4](self.pool(e4))
        
        # Decoder with attention
        d4 = self.upconv4(e5)
        e4_attn = self.attn4(d4, e4)
        d4 = self.decoder4(torch.cat([e4_attn, d4], dim=1))
        
        d3 = self.upconv3(d4)
        e3_attn = self.attn3(d3, e3)
        d3 = self.decoder3(torch.cat([e3_attn, d3], dim=1))
        
        d2 = self.upconv2(d3)
        e2_attn = self.attn2(d2, e2)
        d2 = self.decoder2(torch.cat([e2_attn, d2], dim=1))
        
        d1 = self.upconv1(d2)
        e1_attn = self.attn1(d1, e1)
        d1 = self.decoder1(torch.cat([e1_attn, d1], dim=1))
        
        return self.final(d1)


class AttentionGate(nn.Module):
    """Attention gate for Attention U-Net."""
    
    def __init__(self, g_channels: int, x_channels: int, inter_channels: int):
        super().__init__()
        
        self.W_g = nn.Conv2d(g_channels, inter_channels, 1, bias=False)
        self.W_x = nn.Conv2d(x_channels, inter_channels, 1, bias=False)
        
        self.psi = nn.Sequential(
            nn.Conv2d(inter_channels, 1, 1, bias=False),
            nn.Sigmoid()
        )
        
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            g: Gating signal (from decoder)
            x: Skip connection (from encoder)
        """
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        
        # Resize if needed
        if g1.shape[2:] != x1.shape[2:]:
            g1 = F.interpolate(g1, size=x1.shape[2:], mode='bilinear', align_corners=False)
        
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        
        return x * psi
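
The gate produces a [B, 1, H, W] attention map that is broadcast over the skip connection's channels. A minimal usage check (dummy tensors, our own example):
gate = AttentionGate(g_channels=64, x_channels=64, inter_channels=32)
g = torch.randn(1, 64, 32, 32)     # gating signal from the decoder
skip = torch.randn(1, 64, 32, 32)  # skip connection from the encoder
print(gate(g, skip).shape)  # torch.Size([1, 64, 32, 32])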

DeepLab Series

Atrous (Dilated) Convolutions

class AtrousConv(nn.Module):
    """
    Atrous (Dilated) Convolution.
    
    Increases receptive field without losing resolution.
    """
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        dilation: int = 2
    ):
        super().__init__()
        
        padding = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2
        
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=padding, dilation=dilation, bias=False
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.bn(self.conv(x)))
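
With stride 1 and the padding formula above, the output keeps the input's spatial size while the effective kernel grows to k + (k-1)(d-1); for a 3x3 kernel that is 2d+1. A quick demonstration (our own):
x = torch.randn(1, 64, 32, 32)
for d in [1, 2, 4, 8]:
    conv = AtrousConv(64, 64, kernel_size=3, dilation=d)
    # Spatial size stays 32x32; the effective kernel is (2d+1) x (2d+1)
    print(d, conv(x).shape)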


class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling (DeepLabV3).
    
    Parallel atrous convolutions at multiple dilation rates.
    """
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int = 256,
        dilations: List[int] = [6, 12, 18]
    ):
        super().__init__()
        
        modules = []
        
        # 1x1 conv
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ))
        
        # Atrous convolutions at different rates
        for dilation in dilations:
            modules.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=dilation, 
                          dilation=dilation, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            ))
        
        # Global average pooling
        modules.append(nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ))
        
        self.convs = nn.ModuleList(modules)
        
        # Fusion
        self.project = nn.Sequential(
            nn.Conv2d(out_channels * (len(dilations) + 2), out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = []
        
        for conv in self.convs[:-1]:
            features.append(conv(x))
        
        # Global pooling branch
        global_feat = self.convs[-1](x)
        global_feat = F.interpolate(
            global_feat, size=x.shape[2:], mode='bilinear', align_corners=False
        )
        features.append(global_feat)
        
        x = torch.cat(features, dim=1)
        x = self.project(x)
        
        return x
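
A shape check (ours). Note that the global-pooling branch applies BatchNorm to a 1x1 map, so a single-sample forward pass must run in eval() mode (or use batch size > 1):
aspp = ASPP(in_channels=2048, out_channels=256).eval()
feat = torch.randn(1, 2048, 16, 16)  # e.g. a 1/32-resolution backbone feature
print(aspp(feat).shape)  # torch.Size([1, 256, 16, 16])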


class DeepLabV3Plus(nn.Module):
    """
    DeepLab V3+: Encoder-Decoder with Atrous Separable Convolution.
    
    Combines:
    - ASPP for multi-scale context
    - Decoder for sharp boundaries
    - Depthwise separable convolutions for efficiency
    """
    
    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        low_level_channels: int = 256,
        aspp_out_channels: int = 256
    ):
        super().__init__()
        
        self.backbone = backbone
        
        # ASPP
        self.aspp = ASPP(2048, aspp_out_channels)
        
        # Low-level feature projection
        self.low_level_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Conv2d(aspp_out_channels + 48, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_size = x.shape[2:]
        
        # Backbone
        low_level_features, high_level_features = self.backbone(x)
        
        # ASPP
        aspp_out = self.aspp(high_level_features)
        
        # Upsample ASPP output
        aspp_out = F.interpolate(
            aspp_out, size=low_level_features.shape[2:],
            mode='bilinear', align_corners=False
        )
        
        # Process low-level features
        low_level_out = self.low_level_conv(low_level_features)
        
        # Concatenate and decode
        x = torch.cat([aspp_out, low_level_out], dim=1)
        x = self.decoder(x)
        
        # Upsample to input size
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
        
        return x
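
The model expects a backbone that returns (low_level_features, high_level_features). ToyBackbone below is a hypothetical stand-in we define only to exercise the shapes, not a real feature extractor:
class ToyBackbone(nn.Module):
    """Stand-in: 256-channel features at 1/4 resolution, 2048-channel at 1/16."""
    def __init__(self):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(3, 256, 3, stride=4, padding=1), nn.ReLU(inplace=True))
        self.body = nn.Sequential(nn.Conv2d(256, 2048, 3, stride=4, padding=1), nn.ReLU(inplace=True))

    def forward(self, x):
        low = self.stem(x)
        high = self.body(low)
        return low, high

model = DeepLabV3Plus(ToyBackbone(), num_classes=21).eval()  # eval() for ASPP's pooling branch, see above
print(model(torch.randn(1, 3, 128, 128)).shape)  # torch.Size([1, 21, 128, 128])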

Transformer-Based Segmentation

SETR (Segmentation Transformer)

class SETR(nn.Module):
    """
    SETR: Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective.
    
    Uses Vision Transformer as encoder.
    """
    
    def __init__(
        self,
        img_size: int = 512,
        patch_size: int = 16,
        num_classes: int = 21,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12
    ):
        super().__init__()
        
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        # Patch embedding
        self.patch_embed = nn.Conv2d(3, embed_dim, patch_size, stride=patch_size)
        
        # Position embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        
        # Decoder (Progressive Upsampling)
        self.decoder = nn.Sequential(
            nn.Conv2d(embed_dim, 256, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, num_classes, 1)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        
        # Patch embedding
        x = self.patch_embed(x)  # [B, embed_dim, H/P, W/P]
        x = x.flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]
        
        # Add position embedding
        x = x + self.pos_embed
        
        # Transformer
        x = self.transformer(x)
        
        # Reshape to spatial
        h = w = int(self.num_patches ** 0.5)
        x = x.transpose(1, 2).view(B, -1, h, w)
        
        # Decode
        x = self.decoder(x)
        
        return x
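
Note that the fixed position embedding ties the model to img_size x img_size inputs. A scaled-down configuration (our own, chosen only to keep the check fast) still exercises the full pipeline:
model = SETR(img_size=64, patch_size=16, num_classes=21, embed_dim=96, depth=2, num_heads=3)
print(model(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 21, 64, 64])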


class SegFormer(nn.Module):
    """
    SegFormer: Simple and Efficient Design for Semantic Segmentation.
    
    Key innovations:
    - Hierarchical transformer encoder
    - MLP decoder (simple but effective)
    - No positional encoding
    """
    
    def __init__(
        self,
        num_classes: int = 21,
        embed_dims: List[int] = [64, 128, 320, 512],
        num_heads: List[int] = [1, 2, 5, 8],
        depths: List[int] = [3, 6, 40, 3]
    ):
        super().__init__()
        
        # Hierarchical encoder
        self.stages = nn.ModuleList()
        in_channels = 3
        
        for i, (embed_dim, num_head, depth) in enumerate(zip(embed_dims, num_heads, depths)):
            stage = MixTransformerBlock(
                in_channels=in_channels,
                embed_dim=embed_dim,
                num_heads=num_head,
                depth=depth,
                patch_size=7 if i == 0 else 3,
                stride=4 if i == 0 else 2
            )
            self.stages.append(stage)
            in_channels = embed_dim
        
        # All-MLP decoder
        self.decode_head = MLPDecoder(
            in_channels=embed_dims,
            embed_dim=768,
            num_classes=num_classes
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = []
        
        for stage in self.stages:
            x = stage(x)
            features.append(x)
        
        out = self.decode_head(features)
        
        return out


class MixTransformerBlock(nn.Module):
    """Mix Transformer block for SegFormer."""
    
    def __init__(
        self,
        in_channels: int,
        embed_dim: int,
        num_heads: int,
        depth: int,
        patch_size: int = 3,
        stride: int = 2
    ):
        super().__init__()
        
        # Overlapping patch embedding
        self.patch_embed = nn.Conv2d(
            in_channels, embed_dim, patch_size,
            stride=stride, padding=patch_size // 2
        )
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            EfficientSelfAttentionBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])
        
        self.norm = nn.LayerNorm(embed_dim)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        
        x = x.flatten(2).transpose(1, 2)  # [B, H*W, C]
        
        for block in self.blocks:
            x = block(x, H, W)
        
        x = self.norm(x)
        x = x.transpose(1, 2).view(B, C, H, W)
        
        return x


class EfficientSelfAttentionBlock(nn.Module):
    """Efficient self-attention with reduction."""
    
    def __init__(self, dim: int, num_heads: int, sr_ratio: int = 4):
        super().__init__()
        
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        
        # Spatial reduction
        self.sr = nn.Conv2d(dim, dim, sr_ratio, stride=sr_ratio)
        self.sr_norm = nn.LayerNorm(dim)
        
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )
    
    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        # Self-attention with spatial reduction
        residual = x
        x = self.norm1(x)
        
        # Reduce K, V spatial dimensions
        B, N, C = x.shape
        kv = x.transpose(1, 2).view(B, C, H, W)
        kv = self.sr(kv).flatten(2).transpose(1, 2)
        kv = self.sr_norm(kv)
        
        x, _ = self.attn(x, kv, kv)
        x = residual + x
        
        # MLP
        x = x + self.mlp(self.norm2(x))
        
        return x


class MLPDecoder(nn.Module):
    """Simple MLP decoder for SegFormer."""
    
    def __init__(
        self,
        in_channels: List[int],
        embed_dim: int,
        num_classes: int
    ):
        super().__init__()
        
        self.linear_fuse = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_c, embed_dim, 1),
                nn.BatchNorm2d(embed_dim)
            )
            for in_c in in_channels
        ])
        
        self.linear_pred = nn.Conv2d(embed_dim * len(in_channels), num_classes, 1)
    
    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        target_size = features[0].shape[2:]
        
        fused = []
        for feat, linear in zip(features, self.linear_fuse):
            feat = linear(feat)
            if feat.shape[2:] != target_size:
                feat = F.interpolate(feat, size=target_size, mode='bilinear', align_corners=False)
            fused.append(feat)
        
        x = torch.cat(fused, dim=1)
        x = self.linear_pred(x)
        
        return x
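
As in the paper, the decoder predicts at 1/4 of the input resolution; logits are upsampled to full resolution for the loss or the final masks. A quick check with shallow stages (depths are ours, chosen for speed; the input must be large enough for the stage-4 spatial reduction):
model = SegFormer(num_classes=19, depths=[1, 1, 1, 1])
logits = model(torch.randn(1, 3, 128, 128))
print(logits.shape)  # torch.Size([1, 19, 32, 32]) -- 1/4 resolution
logits = F.interpolate(logits, size=(128, 128), mode='bilinear', align_corners=False)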

Instance & Panoptic Segmentation

class MaskRCNN(nn.Module):
    """
    Mask R-CNN: Instance Segmentation.
    
    Extends Faster R-CNN with mask prediction head.
    """
    
    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        mask_dim: int = 256
    ):
        super().__init__()
        
        self.backbone = backbone
        
        # Mask head
        self.mask_head = nn.Sequential(
            nn.Conv2d(256, mask_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mask_dim, mask_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mask_dim, mask_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mask_dim, mask_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(mask_dim, mask_dim, 2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(mask_dim, num_classes, 1)
        )
    
    def forward_mask(
        self,
        roi_features: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Predict masks for each RoI.
        
        Args:
            roi_features: [N, C, H, W] features from RoI align
            labels: [N] class labels for each RoI
        
        Returns:
            masks: [N, 1, H*2, W*2] predicted masks
        """
        masks = self.mask_head(roi_features)
        
        # Select mask for predicted class
        N = masks.size(0)
        masks = masks[torch.arange(N), labels]
        
        return masks.unsqueeze(1)
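
With 14x14 RoI features (the usual RoIAlign output for the mask branch), the head predicts 28x28 masks, one per RoI for its predicted class. A dummy-tensor check (ours; nn.Identity() stands in for a real backbone):
head = MaskRCNN(backbone=nn.Identity(), num_classes=80)
roi_feats = torch.randn(8, 256, 14, 14)   # 8 RoIs from RoIAlign
labels = torch.randint(0, 80, (8,))       # predicted class per RoI
print(head.forward_mask(roi_feats, labels).shape)  # torch.Size([8, 1, 28, 28])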


class PanopticFPN(nn.Module):
    """
    Panoptic FPN: Unified thing and stuff segmentation.
    
    Combines:
    - Instance segmentation (things: countable objects)
    - Semantic segmentation (stuff: uncountable regions)
    """
    
    def __init__(
        self,
        backbone: nn.Module,
        num_thing_classes: int,  # e.g., person, car
        num_stuff_classes: int   # e.g., sky, road
    ):
        super().__init__()
        
        self.backbone = backbone
        
        # FPN for multi-scale features
        self.fpn = nn.ModuleDict({
            'lateral': nn.ModuleList([
                nn.Conv2d(c, 256, 1) for c in [256, 512, 1024, 2048]
            ]),
            'output': nn.ModuleList([
                nn.Conv2d(256, 256, 3, padding=1) for _ in range(4)
            ])
        })
        
        # Semantic segmentation head (stuff)
        self.semantic_head = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1),
            nn.GroupNorm(32, 128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.GroupNorm(32, 128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, num_stuff_classes, 1)
        )
        
        # Instance head would be similar to Mask R-CNN
        self.num_thing_classes = num_thing_classes
    
    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Get backbone features at multiple scales
        features = self.backbone(x)  # [C1, C2, C3, C4]
        
        # FPN
        fpn_features = self._fpn_forward(features)
        
        # Semantic prediction (upsample all to same size and sum)
        semantic_features = []
        target_size = fpn_features[0].shape[2:]
        
        for feat in fpn_features:
            feat = F.interpolate(feat, size=target_size, mode='bilinear', align_corners=False)
            semantic_features.append(feat)
        
        semantic_feat = sum(semantic_features)
        semantic_out = self.semantic_head(semantic_feat)
        
        # Upsample to input resolution
        semantic_out = F.interpolate(
            semantic_out, scale_factor=4, mode='bilinear', align_corners=False
        )
        
        return {
            'semantic': semantic_out,
            'fpn_features': fpn_features  # For instance branch
        }
    
    def _fpn_forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        """Build FPN from backbone features."""
        laterals = [lat(f) for f, lat in zip(features, self.fpn['lateral'])]
        
        # Top-down pathway
        for i in range(len(laterals) - 1, 0, -1):
            laterals[i-1] = laterals[i-1] + F.interpolate(
                laterals[i], size=laterals[i-1].shape[2:], mode='nearest'
            )
        
        outputs = [out(lat) for lat, out in zip(laterals, self.fpn['output'])]
        
        return outputs
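
To produce an actual panoptic output, the instance and semantic predictions must be merged. A simplified greedy fusion heuristic (our sketch, not the exact Panoptic FPN recipe):
def panoptic_fuse(
    semantic_logits: torch.Tensor,  # [C_stuff, H, W]
    instance_masks: torch.Tensor,   # [N, H, W] binary masks
    instance_scores: torch.Tensor,  # [N] confidences
    thing_offset: int = 1000
) -> torch.Tensor:
    """Paste instances by descending confidence, fill the rest with stuff labels."""
    H, W = semantic_logits.shape[1:]
    panoptic = torch.full((H, W), -1, dtype=torch.long)
    
    for inst_id, idx in enumerate(torch.argsort(instance_scores, descending=True).tolist()):
        mask = instance_masks[idx].bool() & (panoptic == -1)  # don't overwrite earlier instances
        panoptic[mask] = thing_offset + inst_id
    
    stuff = semantic_logits.argmax(dim=0)
    panoptic[panoptic == -1] = stuff[panoptic == -1]
    return panoptic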

Loss Functions

class SegmentationLosses:
    """Common loss functions for segmentation."""
    
    @staticmethod
    def cross_entropy(
        pred: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        ignore_index: int = 255
    ) -> torch.Tensor:
        """Standard cross-entropy loss."""
        return F.cross_entropy(pred, target, weight=weight, ignore_index=ignore_index)
    
    @staticmethod
    def dice_loss(
        pred: torch.Tensor,
        target: torch.Tensor,
        smooth: float = 1e-6
    ) -> torch.Tensor:
        """
        Dice Loss for imbalanced segmentation.
        
        Dice = 2|A∩B| / (|A| + |B|)
        """
        pred = F.softmax(pred, dim=1)
        num_classes = pred.size(1)
        
        # One-hot encode target
        target_one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
        
        intersection = (pred * target_one_hot).sum(dim=(2, 3))
        cardinality = pred.sum(dim=(2, 3)) + target_one_hot.sum(dim=(2, 3))
        
        dice = (2 * intersection + smooth) / (cardinality + smooth)
        
        return 1 - dice.mean()
    
    @staticmethod
    def focal_loss(
        pred: torch.Tensor,
        target: torch.Tensor,
        alpha: float = 0.25,
        gamma: float = 2.0
    ) -> torch.Tensor:
        """Focal Loss for hard example mining."""
        ce_loss = F.cross_entropy(pred, target, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = alpha * (1 - pt) ** gamma * ce_loss
        return focal_loss.mean()
    
    @staticmethod
    def lovasz_softmax(
        pred: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        """
        Lovász-Softmax loss.
        Directly optimizes IoU.
        """
        probas = F.softmax(pred, dim=1)
        
        # Flatten each class map over batch and pixels before sorting
        C = probas.size(1)
        losses = []
        
        for c in range(C):
            fg = (target == c).float().reshape(-1)
            errors = (fg - probas[:, c].reshape(-1)).abs()
            errors_sorted, perm = torch.sort(errors, descending=True)
            fg_sorted = fg[perm]
            losses.append(SegmentationLosses._lovasz_grad(fg_sorted) @ errors_sorted)
        
        return torch.stack(losses).mean()
    
    @staticmethod
    def _lovasz_grad(gt_sorted: torch.Tensor) -> torch.Tensor:
        """Lovász gradient."""
        gts = gt_sorted.sum()
        intersection = gts - gt_sorted.cumsum(0)
        union = gts + (1 - gt_sorted).cumsum(0)
        jaccard = 1 - intersection / union
        jaccard[1:] = jaccard[1:] - jaccard[:-1]
        return jaccard
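
All four losses share the same interface: [B, C, H, W] logits and [B, H, W] integer targets. A quick check on random tensors (ours):
pred = torch.randn(2, 21, 64, 64)            # logits
target = torch.randint(0, 21, (2, 64, 64))   # class indices (no ignore pixels here)
for name in ['cross_entropy', 'dice_loss', 'focal_loss', 'lovasz_softmax']:
    loss = getattr(SegmentationLosses, name)(pred, target)
    print(name, loss.item())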

Exercises

1. Create a loss that focuses on boundaries:
def boundary_loss(pred, target):
    # Extract boundaries using Sobel/Laplacian
    # Weight loss higher at boundaries
    pass
2. Implement test-time augmentation for segmentation (a flip-only starting point is sketched after this list):
  • Inference at multiple scales
  • Flip augmentation
  • Ensemble predictions
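
A flip-only starting point (our sketch; multi-scale inference follows the same pattern of predict, undo the transform, average):
@torch.no_grad()
def flip_tta(model: nn.Module, image: torch.Tensor) -> torch.Tensor:
    """Average logits over the original and a horizontally flipped input."""
    logits = model(image)
    flipped = model(torch.flip(image, dims=[3]))  # flip width, predict
    return (logits + torch.flip(flipped, dims=[3])) / 2  # flip back and average
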
3. Add Conditional Random Field refinement:
# Use pydensecrf or implement a mean-field approximation
# Refine predictions based on appearance consistency

What’s Next?