Documentation Index
Fetch the complete documentation index at: https://resources.devweekends.com/llms.txt
Use this file to discover all available pages before exploring further.
Semantic Segmentation
From Classification to Segmentation
Segmentation is classification taken to its extreme: instead of assigning one label to the entire image, you assign a label to every single pixel. This is the foundation for applications like self-driving cars (is this pixel road, sidewalk, or pedestrian?), medical imaging (is this pixel tumor or healthy tissue?), and satellite analysis (is this pixel forest, water, or urban?).

| Task | Output | Granularity |
|---|---|---|
| Classification | One label per image | Coarsest |
| Object Detection | Boxes + labels | Object-level |
| Semantic Segmentation | One label per pixel | Pixel-level (no instance distinction) |
| Instance Segmentation | Objects + per-pixel masks | Pixel-level (separates instances) |
| Panoptic Segmentation | All of the above | The complete picture |
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Optional
torch.manual_seed(42)
Segmentation models are memory-hungry because they must maintain spatial resolution throughout the network (unlike classification models that aggressively pool). A single 512x512 input with a U-Net encoder can easily consume 8+ GB of GPU memory during training. If you hit OOM errors, reducing input resolution (e.g., from 512 to 256) has a larger impact than reducing batch size by the same factor: activation memory scales with H x W, i.e. quadratically in the side length, so halving the side length cuts activation memory roughly 4x, whereas halving the batch size only cuts it 2x.
Fully Convolutional Networks (FCN)
# The foundation of modern segmentation. Before FCN, people would apply classifiers
# to each pixel independently (or to sliding windows), which was absurdly slow. FCN’s
# breakthrough was realizing that you can replace the fully connected layers at the
# end of a classification network with convolutional layers, turning the entire
# network into a spatial-in, spatial-out function that processes the whole image in
# one forward pass:
class FCN(nn.Module):
    """
    Fully Convolutional Network (Long et al., 2015), single-stream variant.

    Key insight: replace FC layers with conv layers to preserve spatial info.
    A classification network's FC layer that maps 512 features to 1000 classes
    is mathematically equivalent to a 1x1 conv with 1000 output channels, so
    any classification network can be repurposed for segmentation by swapping
    FC layers for convs and adding upsampling.

    Pitfall: this naive FCN produces very coarse output because of all the
    pooling layers (five 2x max-pools, i.e. 32x downsampling). FCN-8s fixes
    this by fusing predictions from multiple scales.

    Args:
        num_classes: number of per-pixel logit channels to predict.
        backbone: not read by this implementation, which always builds the
            VGG-style stack below.
    """

    def __init__(self, num_classes: int, backbone: str = 'vgg16'):
        super().__init__()

        # (in_channels, out_channels, convs) for the five VGG-style stages;
        # every stage is followed by a 2x max-pool.
        stage_specs = [
            (3, 64, 2),
            (64, 128, 2),
            (128, 256, 3),
            (256, 512, 3),
            (512, 512, 3),
        ]
        stages = []
        for c_in, c_out, depth in stage_specs:
            stages.append(self._make_block(c_in, c_out, depth))
            stages.append(nn.MaxPool2d(2, 2))
        self.features = nn.Sequential(*stages)

        # The classifier's FC layers, expressed as convolutions.
        self.fc6 = nn.Conv2d(512, 4096, 7, padding=3)
        self.fc7 = nn.Conv2d(4096, 4096, 1)

        # Per-class score map at 1/32 resolution.
        self.score = nn.Conv2d(4096, num_classes, 1)

        # Learned 32x upsampling (kernel 64, stride 32, padding 16).
        self.upsample = nn.ConvTranspose2d(
            num_classes, num_classes, 64, stride=32, padding=16
        )

    def _make_block(self, in_c: int, out_c: int, num_convs: int):
        """Return `num_convs` 3x3 conv + ReLU pairs; only the first conv changes width."""
        layers = []
        channels = in_c
        for _ in range(num_convs):
            layers.append(nn.Conv2d(channels, out_c, 3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            channels = out_c
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, 3, H, W] input image
        Returns:
            segmentation: [B, C, H, W] per-pixel logits
        """
        target_hw = x.shape[2:]

        feats = self.features(x)
        feats = F.relu(self.fc6(feats))
        feats = F.relu(self.fc7(feats))
        scores = self.score(feats)

        # The transposed conv gives 32x the coarse size; the bilinear resize
        # then matches the exact input size when H or W is not a multiple of 32.
        scores = self.upsample(scores)
        return F.interpolate(
            scores, size=target_hw, mode='bilinear', align_corners=False
        )
class FCN8s(nn.Module):
    """
    FCN-8s: Multi-scale fusion for better boundaries.

    Combines predictions from pool3 (1/8), pool4 (1/16), and the fc7 scores
    computed on pool5 (1/32): the coarsest scores are upsampled 2x and added
    to the pool4 scores, the sum is upsampled 2x and added to the pool3
    scores, and the result is upsampled 8x.

    Input sizes that are not multiples of 32 are supported: max-pooling
    floors odd sizes, so an upsampled map can be smaller than the skip it is
    fused with; it is bilinearly resized to the skip's size before the add.

    Args:
        num_classes: number of per-pixel logit channels to predict.
    """

    def __init__(self, num_classes: int):
        super().__init__()

        # Backbone with skip connections
        self.conv1 = self._make_block(3, 64, 2)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = self._make_block(64, 128, 2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = self._make_block(128, 256, 3)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.conv4 = self._make_block(256, 512, 3)
        self.pool4 = nn.MaxPool2d(2, 2)
        self.conv5 = self._make_block(512, 512, 3)
        self.pool5 = nn.MaxPool2d(2, 2)

        # FC layers as conv
        self.fc6 = nn.Conv2d(512, 4096, 7, padding=3)
        self.fc7 = nn.Conv2d(4096, 4096, 1)

        # Score layers for each scale
        self.score_fr = nn.Conv2d(4096, num_classes, 1)
        self.score_pool4 = nn.Conv2d(512, num_classes, 1)
        self.score_pool3 = nn.Conv2d(256, num_classes, 1)

        # Learned upsampling: 2x, 2x, then 8x
        self.upscore2 = nn.ConvTranspose2d(num_classes, num_classes, 4, stride=2, padding=1)
        self.upscore4 = nn.ConvTranspose2d(num_classes, num_classes, 4, stride=2, padding=1)
        self.upscore8 = nn.ConvTranspose2d(num_classes, num_classes, 16, stride=8, padding=4)

    def _make_block(self, in_c, out_c, num_convs):
        """Return `num_convs` 3x3 conv + ReLU pairs; only the first conv changes width."""
        layers = []
        for i in range(num_convs):
            layers.extend([
                nn.Conv2d(in_c if i == 0 else out_c, out_c, 3, padding=1),
                nn.ReLU(inplace=True)
            ])
        return nn.Sequential(*layers)

    @staticmethod
    def _resize_like(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize `x` to the spatial size of `ref` (no-op if equal)."""
        if x.shape[2:] != ref.shape[2:]:
            x = F.interpolate(x, size=ref.shape[2:], mode='bilinear', align_corners=False)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, 3, H, W] input image
        Returns:
            [B, num_classes, H, W] per-pixel logits
        """
        input_size = x.shape[2:]

        # Encoder
        x = self.pool1(self.conv1(x))
        x = self.pool2(self.conv2(x))
        pool3 = self.pool3(self.conv3(x))
        pool4 = self.pool4(self.conv4(pool3))
        pool5 = self.pool5(self.conv5(pool4))
        x = F.relu(self.fc6(pool5))
        x = F.relu(self.fc7(x))

        # Multi-scale fusion. The transposed convs produce exactly 2x their
        # input, which differs from the skip size whenever pooling floored an
        # odd dimension, so align before adding.
        score_fr = self.score_fr(x)
        score_pool4 = self.score_pool4(pool4)
        upscore2 = self._resize_like(self.upscore2(score_fr), score_pool4)
        fuse_pool4 = upscore2 + score_pool4

        score_pool3 = self.score_pool3(pool3)
        upscore4 = self._resize_like(self.upscore4(fuse_pool4), score_pool3)
        fuse_pool3 = upscore4 + score_pool3

        out = self.upscore8(fuse_pool3)
        out = F.interpolate(out, size=input_size, mode='bilinear', align_corners=False)
        return out
U-Net
# The most influential segmentation architecture, especially in medical imaging where
# labeled data is scarce. The name comes from its U-shaped architecture: the left side
# encodes (compresses) the input, the bottom is the bottleneck, and the right side
# decodes (expands) back to full resolution. The critical innovation is the skip
# connections that bridge corresponding encoder and decoder layers, allowing the
# decoder to use fine-grained spatial details from the encoder that would otherwise be
# lost during downsampling. Think of it like making a summary (encoding) and then
# expanding it back into a full essay (decoding) — without the skip connections, you
# would lose all the specific details. With them, the decoder can say “I know the big
# picture from the bottleneck features, and here are the precise boundaries from the
# encoder.”
class DoubleConv(nn.Module):
    """
    Two stacked (3x3 conv -> BatchNorm -> ReLU) stages -- the basic unit of U-Net.

    The convolutions use padding=1, so spatial size is preserved, and carry no
    bias because each is immediately followed by BatchNorm. Only the first
    conv changes the channel count (in_channels -> out_channels).
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to `x` ([B, in_channels, H, W])."""
        out = self.double_conv(x)
        return out
class UNet(nn.Module):
    """
    U-Net: Convolutional Networks for Biomedical Image Segmentation
    (Ronneberger et al., 2015).

    Key features:
    1. Symmetric encoder-decoder structure -- easy to reason about
    2. Skip connections preserve fine details (boundary pixels, textures)
    3. Works well with limited training data -- the original paper trained
       on just 30 images with heavy augmentation

    Practical tips:
    - U-Net is still competitive in 2024+ for medical and scientific imaging
    - For natural images, consider using a pretrained encoder (ResNet, etc.)
    - The feature list [64, 128, 256, 512, 1024] is a sensible default;
      halve it for smaller images or tighter memory budgets

    Args:
        in_channels: channels of the input image.
        num_classes: number of per-pixel logit channels to predict.
        features: channel width of each encoder level; the last entry is the
            bottleneck. Each width must be twice the previous one, because the
            decoder is built as ConvTranspose2d(2 * f, f). Defaults to
            [64, 128, 256, 512, 1024] when None.
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 2,
        features: Optional[List[int]] = None
    ):
        super().__init__()
        # None sentinel instead of a list literal default, so no list object
        # is shared between instances.
        features = [64, 128, 256, 512, 1024] if features is None else list(features)

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)

        # Encoder: one DoubleConv per level (the last one is the bottleneck)
        for feature in features:
            self.encoder.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder: stored as alternating [upsample, DoubleConv] pairs,
        # deepest level first.
        for feature in reversed(features[:-1]):
            self.decoder.append(
                nn.ConvTranspose2d(feature * 2, feature, 2, stride=2)
            )
            self.decoder.append(DoubleConv(feature * 2, feature))

        # Final 1x1 conv to class logits
        self.final_conv = nn.Conv2d(features[0], num_classes, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, C, H, W] input
        Returns:
            segmentation: [B, num_classes, H, W]
        """
        skips = []

        # Encoder: remember each level's output before pooling.
        for stage in self.encoder[:-1]:
            x = stage(x)
            skips.append(x)
            x = self.pool(x)

        x = self.encoder[-1](x)  # Bottleneck -- the deepest, most compressed representation

        # Decoder: mirror of encoder, consuming skips deepest-first.
        # Each step: upsample -> concatenate skip connection -> double conv
        for step, skip in enumerate(reversed(skips)):
            x = self.decoder[2 * step](x)  # Upsample
            # Pooling floors odd sizes, so the 2x upsample can be one pixel
            # short of the skip; only the spatial dims need to agree.
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
            x = torch.cat([skip, x], dim=1)  # Skip connection
            x = self.decoder[2 * step + 1](x)  # Double conv

        return self.final_conv(x)
class AttentionUNet(nn.Module):
    """
    Attention U-Net: Learning Where to Look.

    A five-level U-Net (widths 64, 128, 256, 512, 1024) in which every skip
    connection is passed through an attention gate, driven by the upsampled
    decoder feature, before being concatenated with that feature.

    Input sizes that are not multiples of 16 are supported: when pooling has
    floored an odd size, the upsampled decoder feature is bilinearly resized
    to the skip's spatial size before gating and concatenation.

    Args:
        in_channels: channels of the input image.
        num_classes: number of per-pixel logit channels to predict.
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 2):
        super().__init__()

        features = [64, 128, 256, 512, 1024]

        # Encoder
        self.encoder = nn.ModuleList([
            DoubleConv(in_channels, features[0]),
            DoubleConv(features[0], features[1]),
            DoubleConv(features[1], features[2]),
            DoubleConv(features[2], features[3]),
            DoubleConv(features[3], features[4])
        ])
        self.pool = nn.MaxPool2d(2, 2)

        # Decoder with attention. Each level: 2x transposed conv, attention
        # gate on the skip, then DoubleConv on the concatenation (2f -> f).
        self.upconv4 = nn.ConvTranspose2d(features[4], features[3], 2, stride=2)
        self.attn4 = AttentionGate(features[3], features[3], features[3] // 2)
        self.decoder4 = DoubleConv(features[4], features[3])

        self.upconv3 = nn.ConvTranspose2d(features[3], features[2], 2, stride=2)
        self.attn3 = AttentionGate(features[2], features[2], features[2] // 2)
        self.decoder3 = DoubleConv(features[3], features[2])

        self.upconv2 = nn.ConvTranspose2d(features[2], features[1], 2, stride=2)
        self.attn2 = AttentionGate(features[1], features[1], features[1] // 2)
        self.decoder2 = DoubleConv(features[2], features[1])

        self.upconv1 = nn.ConvTranspose2d(features[1], features[0], 2, stride=2)
        self.attn1 = AttentionGate(features[0], features[0], features[0] // 2)
        self.decoder1 = DoubleConv(features[1], features[0])

        self.final = nn.Conv2d(features[0], num_classes, 1)

    @staticmethod
    def _decode_level(upconv, gate, block, x, skip):
        """Run one decoder level: upsample `x`, gate `skip`, concatenate, convolve."""
        x = upconv(x)
        # The gate returns a tensor at the skip's resolution, so the decoder
        # feature must match it spatially before torch.cat.
        if x.shape[2:] != skip.shape[2:]:
            x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
        gated_skip = gate(x, skip)
        return block(torch.cat([gated_skip, x], dim=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, in_channels, H, W] input
        Returns:
            [B, num_classes, H, W] per-pixel logits
        """
        # Encoder
        e1 = self.encoder[0](x)
        e2 = self.encoder[1](self.pool(e1))
        e3 = self.encoder[2](self.pool(e2))
        e4 = self.encoder[3](self.pool(e3))
        e5 = self.encoder[4](self.pool(e4))

        # Decoder with attention, deepest level first
        d4 = self._decode_level(self.upconv4, self.attn4, self.decoder4, e5, e4)
        d3 = self._decode_level(self.upconv3, self.attn3, self.decoder3, d4, e3)
        d2 = self._decode_level(self.upconv2, self.attn2, self.decoder2, d3, e2)
        d1 = self._decode_level(self.upconv1, self.attn1, self.decoder1, d2, e1)

        return self.final(d1)
class AttentionGate(nn.Module):
    """
    Attention gate for Attention U-Net.

    The gating signal (g) from the decoder tells the skip connection (x)
    from the encoder which spatial regions to focus on. Think of it as
    the decoder saying: "I know from the global context that the object
    of interest is in the top-left -- encoder, please only send me the
    fine details from that region."

    Both inputs are projected to `inter_channels` with bias-free 1x1 convs,
    summed, passed through ReLU, and reduced to a single-channel sigmoid map
    that multiplies the skip connection.
    """

    def __init__(self, g_channels: int, x_channels: int, inter_channels: int):
        super().__init__()

        # 1x1 projections of gate and skip into a shared space
        self.W_g = nn.Conv2d(g_channels, inter_channels, 1, bias=False)
        self.W_x = nn.Conv2d(x_channels, inter_channels, 1, bias=False)

        # Collapse to one attention coefficient per pixel, in (0, 1)
        self.psi = nn.Sequential(
            nn.Conv2d(inter_channels, 1, 1, bias=False),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            g: Gating signal (from decoder)
            x: Skip connection (from encoder)
        Returns:
            `x` scaled per pixel by the attention map (same shape as `x`).
        """
        gate_proj = self.W_g(g)
        skip_proj = self.W_x(x)

        # Bring the gate projection to the skip's resolution when they differ.
        if gate_proj.shape[2:] != skip_proj.shape[2:]:
            gate_proj = F.interpolate(
                gate_proj, size=skip_proj.shape[2:], mode='bilinear', align_corners=False
            )

        attention = self.psi(self.relu(gate_proj + skip_proj))
        return x * attention
DeepLab Series
DeepLab takes a different approach from the encoder-decoder style: instead of downsampling and then upsampling, use dilated (atrous) convolutions to maintain resolution while still capturing large-scale context. This avoids the information loss from repeated pooling.

Atrous (Dilated) Convolutions
class AtrousConv(nn.Module):
    """
    Atrous (dilated) convolution followed by BatchNorm and ReLU.

    Increases receptive field without losing resolution or adding parameters.
    Think of it like a regular convolution with "holes" (atrous = "with holes"
    in French): instead of looking at a contiguous 3x3 patch, the filter
    samples every 2nd (or 3rd, or 4th) pixel. A 3x3 kernel with dilation=2
    covers the same area as a 5x5 kernel but uses only 9 parameters.

    Practical tip: Use multiple dilation rates in parallel (ASPP)
    to capture objects at multiple scales simultaneously.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        dilation: int = 2
    ):
        super().__init__()

        # A dilated kernel spans dilation * (kernel_size - 1) + 1 pixels;
        # padding by half of (span - 1) keeps the spatial size for odd kernels.
        same_padding = (dilation * (kernel_size - 1)) // 2

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            padding=same_padding,
            dilation=dilation,
            bias=False,  # BatchNorm follows, so a conv bias would be redundant
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> BatchNorm -> ReLU to `x` ([B, in_channels, H, W])."""
        out = self.conv(x)
        out = self.bn(out)
        return self.relu(out)
class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling (DeepLabV3).

    Parallel atrous convolutions at multiple dilation rates. The idea:
    objects appear at different scales in an image. A person close-up fills
    most of the image; a person far away is just a few pixels. ASPP captures
    context at multiple scales simultaneously by running several dilated
    convolutions in parallel (each with a different dilation rate) and
    concatenating their outputs. It also includes a global average pooling
    branch to capture image-level context ("this is an indoor scene").

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of every branch and of the fused output.
        dilations: dilation rate of each 3x3 branch; defaults to
            [6, 12, 18] when None.

    Note: the pooling branch applies BatchNorm to a 1x1 map, so in training
    mode PyTorch raises for a batch of size 1 (only one value per channel).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = 256,
        dilations: Optional[List[int]] = None
    ):
        super().__init__()
        # None sentinel instead of a list literal default, so no list object
        # is shared between instances.
        dilations = [6, 12, 18] if dilations is None else list(dilations)

        modules = []

        # 1x1 conv branch
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ))

        # Atrous 3x3 branches; padding == dilation keeps the spatial size
        for dilation in dilations:
            modules.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=dilation,
                          dilation=dilation, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            ))

        # Global average pooling branch (always last in self.convs)
        modules.append(nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ))

        self.convs = nn.ModuleList(modules)

        # Fusion: 1x1 conv over the concatenation of all branches
        # (len(dilations) atrous branches + the 1x1 branch + the pooled branch)
        self.project = nn.Sequential(
            nn.Conv2d(out_channels * (len(dilations) + 2), out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, in_channels, H, W] feature map
        Returns:
            [B, out_channels, H, W] fused multi-scale features
        """
        # All branches except the pooled one already keep the spatial size.
        features = [branch(x) for branch in self.convs[:-1]]

        # Global pooling branch: 1x1 result broadcast back to H x W
        global_feat = self.convs[-1](x)
        global_feat = F.interpolate(
            global_feat, size=x.shape[2:], mode='bilinear', align_corners=False
        )
        features.append(global_feat)

        return self.project(torch.cat(features, dim=1))
class DeepLabV3Plus(nn.Module):
    """
    DeepLab V3+: Encoder-Decoder with Atrous Separable Convolution.

    Combines:
    - ASPP for multi-scale context
    - Decoder for sharp boundaries
    - Depthwise separable convolutions for efficiency (the original design;
      this simplified version uses standard convolutions in its decoder)

    Args:
        backbone: module whose forward returns a pair
            (low_level_features, high_level_features).
        num_classes: number of per-pixel logit channels to predict.
        low_level_channels: channels of the low-level feature map.
        aspp_out_channels: channels produced by the ASPP module.
        high_level_channels: channels of the high-level feature map fed to
            ASPP. Defaults to 2048; set it to match the backbone in use.
    """

    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        low_level_channels: int = 256,
        aspp_out_channels: int = 256,
        high_level_channels: int = 2048
    ):
        super().__init__()

        self.backbone = backbone

        # ASPP on the deepest backbone features
        self.aspp = ASPP(high_level_channels, aspp_out_channels)

        # Low-level feature projection: squeeze to 48 channels so the
        # fine-detail branch does not dominate the ASPP features.
        self.low_level_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        # Decoder: two 3x3 conv-BN-ReLU stages, then a 1x1 classifier
        self.decoder = nn.Sequential(
            nn.Conv2d(aspp_out_channels + 48, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, C, H, W] input image
        Returns:
            [B, num_classes, H, W] per-pixel logits
        """
        input_size = x.shape[2:]

        # Backbone
        low_level_features, high_level_features = self.backbone(x)

        # ASPP, then upsample to the low-level feature resolution
        aspp_out = self.aspp(high_level_features)
        aspp_out = F.interpolate(
            aspp_out, size=low_level_features.shape[2:],
            mode='bilinear', align_corners=False
        )

        # Process low-level features
        low_level_out = self.low_level_conv(low_level_features)

        # Concatenate and decode
        x = torch.cat([aspp_out, low_level_out], dim=1)
        x = self.decoder(x)

        # Upsample to input size
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)

        return x
Transformer-Based Segmentation
Transformers have entered segmentation with force. Their key advantage: global receptive field from layer 1. A CNN needs many layers of convolutions to “see” the full image, but a transformer’s self-attention can relate any two pixels in a single layer. This makes transformers particularly good at capturing long-range dependencies (e.g., understanding that a road continues behind a building).

Transformer-based segmentation models are significantly more data-hungry than CNN-based ones. A U-Net can achieve strong results on datasets as small as 30 images (with heavy augmentation). SegFormer or SETR typically require pretraining on ImageNet-21k and fine-tuning on thousands of labeled segmentation images. If your dataset is small (under 1,000 annotated images), a pretrained U-Net or DeepLabV3+ with a ResNet backbone is almost always the better choice.
SETR (Segmentation Transformer)
class SETR(nn.Module):
    """
    SETR: Rethinking Semantic Segmentation from a Sequence-to-Sequence
    Perspective (Zheng et al., 2021).

    Uses Vision Transformer as encoder. The image is split into patches,
    treated as a sequence, processed by a standard transformer, then decoded
    back to pixel-level predictions.

    Practical tip: SETR requires large-scale pretraining (ImageNet-21k or
    similar) to work well. Without it, the transformer encoder lacks the
    visual priors that CNNs get for free from their inductive bias.

    Args:
        img_size: side length the position embedding is laid out for.
        patch_size: side length of each square patch.
        num_classes: number of per-pixel logit channels to predict.
        embed_dim: token width of the transformer.
        depth: number of transformer encoder layers.
        num_heads: attention heads per layer.

    The forward pass accepts inputs whose size differs from `img_size`
    (including non-square ones): the position embedding is bilinearly
    resized to the actual patch grid, and the logits are resized to the
    input resolution. The latter also corrects the output size when
    `patch_size` is not 16, since the decoder always upsamples by 16.
    """

    def __init__(
        self,
        img_size: int = 512,
        patch_size: int = 16,
        num_classes: int = 21,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12
    ):
        super().__init__()

        self.patch_size = patch_size
        # Patches per side / in total for the reference image size
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size ** 2

        # Patch embedding: non-overlapping patches via a strided conv
        self.patch_embed = nn.Conv2d(3, embed_dim, patch_size, stride=patch_size)

        # Learned position embedding, one vector per reference-grid patch
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # Decoder (Progressive Upsampling): four 2x bilinear steps = 16x
        self.decoder = nn.Sequential(
            nn.Conv2d(embed_dim, 256, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, num_classes, 1)
        )

    def _pos_embed_for(self, grid_h: int, grid_w: int) -> torch.Tensor:
        """Return the position embedding laid out for a grid_h x grid_w patch grid."""
        side = self.grid_size
        if (grid_h, grid_w) == (side, side):
            return self.pos_embed
        # Treat the embedding as a [1, D, side, side] image and resample it.
        grid = self.pos_embed.transpose(1, 2).reshape(1, -1, side, side)
        grid = F.interpolate(
            grid, size=(grid_h, grid_w), mode='bilinear', align_corners=False
        )
        return grid.flatten(2).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, 3, H, W] input image (H and W at least `patch_size`)
        Returns:
            [B, num_classes, H, W] per-pixel logits
        """
        B, _, H, W = x.shape

        # Patch embedding
        x = self.patch_embed(x)  # [B, embed_dim, H/P, W/P]
        grid_h, grid_w = x.shape[2:]  # actual grid, not assumed square
        x = x.flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]

        # Add position embedding
        x = x + self._pos_embed_for(grid_h, grid_w)

        # Transformer
        x = self.transformer(x)

        # Reshape tokens back to a spatial map
        x = x.transpose(1, 2).reshape(B, -1, grid_h, grid_w)

        # Decode (16x upsampling)
        x = self.decoder(x)

        # Match the input resolution exactly; a no-op when patch_size == 16
        # and H, W are multiples of 16.
        if x.shape[2:] != (H, W):
            x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False)

        return x
class SegFormer(nn.Module):
    """
    SegFormer: Simple and Efficient Design for Semantic Segmentation
    (Xie et al., 2021).

    Key innovations:
    - Hierarchical transformer encoder (produces multi-scale features
      like a CNN, not flat tokens like ViT)
    - MLP decoder (shockingly simple but effective -- proves that most
      of the heavy lifting happens in the encoder)
    - No positional encoding (uses overlapping patch embeddings instead,
      which naturally encode position through the conv's local structure)

    Practical tip: SegFormer is arguably the best transformer-based
    segmentation model for practical use. It is simpler than SETR, faster
    than SegViT, and the MLP decoder makes it memory-efficient. The B0-B5
    variants offer a clean accuracy/speed tradeoff.

    Args:
        num_classes: number of per-pixel logit channels to predict.
        embed_dims: channel width of each encoder stage;
            defaults to [64, 128, 320, 512] when None.
        num_heads: attention heads of each stage;
            defaults to [1, 2, 5, 8] when None.
        depths: number of transformer blocks in each stage;
            defaults to [3, 6, 40, 3] when None.

    The three lists are consumed in lockstep, one entry per stage, so they
    should have the same length.
    """

    def __init__(
        self,
        num_classes: int = 21,
        embed_dims: Optional[List[int]] = None,
        num_heads: Optional[List[int]] = None,
        depths: Optional[List[int]] = None
    ):
        super().__init__()
        # None sentinels instead of list literal defaults, so no list object
        # is shared between instances.
        embed_dims = [64, 128, 320, 512] if embed_dims is None else list(embed_dims)
        num_heads = [1, 2, 5, 8] if num_heads is None else list(num_heads)
        depths = [3, 6, 40, 3] if depths is None else list(depths)

        # Hierarchical encoder: the first stage embeds with a 7x7 / stride-4
        # conv, later stages with 3x3 / stride-2.
        self.stages = nn.ModuleList()
        in_channels = 3

        for i, (embed_dim, num_head, depth) in enumerate(zip(embed_dims, num_heads, depths)):
            stage = MixTransformerBlock(
                in_channels=in_channels,
                embed_dim=embed_dim,
                num_heads=num_head,
                depth=depth,
                patch_size=7 if i == 0 else 3,
                stride=4 if i == 0 else 2
            )
            self.stages.append(stage)
            in_channels = embed_dim

        # All-MLP decoder over every stage's output
        self.decode_head = MLPDecoder(
            in_channels=embed_dims,
            embed_dim=768,
            num_classes=num_classes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, 3, H, W] input image
        Returns:
            Logits from the decode head. No upsampling to the input size
            happens here; with this file's MLPDecoder the output is at the
            first stage's resolution (stride 4).
        """
        features = []

        for stage in self.stages:
            x = stage(x)
            features.append(x)

        return self.decode_head(features)
class MixTransformerBlock(nn.Module):
    """
    One Mix Transformer stage for SegFormer.

    A strided, overlapping conv embeds the input into `embed_dim` channels,
    `depth` efficient self-attention blocks process the resulting tokens,
    and a final LayerNorm is applied before the tokens are folded back into
    a [B, embed_dim, H', W'] feature map.
    """

    def __init__(
        self,
        in_channels: int,
        embed_dim: int,
        num_heads: int,
        depth: int,
        patch_size: int = 3,
        stride: int = 2
    ):
        super().__init__()

        # Overlapping patch embedding (kernel larger than stride)
        self.patch_embed = nn.Conv2d(
            in_channels,
            embed_dim,
            patch_size,
            stride=stride,
            padding=patch_size // 2,
        )

        # Transformer blocks
        attention_blocks = []
        for _ in range(depth):
            attention_blocks.append(EfficientSelfAttentionBlock(embed_dim, num_heads))
        self.blocks = nn.ModuleList(attention_blocks)

        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed `x` ([B, in_channels, H, W]) and return the stage's feature map."""
        fmap = self.patch_embed(x)
        batch, channels, height, width = fmap.shape

        # [B, C, H', W'] -> [B, H'*W', C] token sequence
        tokens = fmap.flatten(2).transpose(1, 2)

        # Each block needs the grid size to undo the flattening internally.
        for blk in self.blocks:
            tokens = blk(tokens, height, width)

        tokens = self.norm(tokens)

        # Back to a spatial map
        return tokens.transpose(1, 2).view(batch, channels, height, width)
class EfficientSelfAttentionBlock(nn.Module):
    """
    Pre-norm transformer block whose keys/values are spatially reduced.

    Queries are the full token sequence; keys and values come from the same
    tokens after a strided conv (kernel = stride = `sr_ratio`) followed by
    LayerNorm. The feature map must be at least `sr_ratio` pixels on each
    side for that conv to apply.
    """

    def __init__(self, dim: int, num_heads: int, sr_ratio: int = 4):
        super().__init__()

        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

        # Spatial reduction of the key/value sequence
        self.sr = nn.Conv2d(dim, dim, sr_ratio, stride=sr_ratio)
        self.sr_norm = nn.LayerNorm(dim)

        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        Args:
            x: [B, H*W, C] token sequence
            H, W: spatial grid the tokens were flattened from
        Returns:
            [B, H*W, C] tokens after attention and MLP, each with a residual
        """
        queries = self.norm1(x)

        # Fold the normalized tokens back to [B, C, H, W], shrink them with
        # the strided conv, and flatten again to form keys/values.
        batch, _, channels = queries.shape
        grid = queries.transpose(1, 2).view(batch, channels, H, W)
        reduced = self.sr(grid).flatten(2).transpose(1, 2)
        reduced = self.sr_norm(reduced)

        attended, _ = self.attn(queries, reduced, reduced)
        x = x + attended

        # MLP with its own residual
        return x + self.mlp(self.norm2(x))
class MLPDecoder(nn.Module):
    """
    Simple MLP decoder for SegFormer.

    Every input feature map is projected to `embed_dim` channels (1x1 conv +
    BatchNorm), resized to the resolution of the first map, concatenated
    along channels, and classified with a single 1x1 conv.
    """

    def __init__(
        self,
        in_channels: List[int],
        embed_dim: int,
        num_classes: int
    ):
        super().__init__()

        # One projection per encoder stage
        projections = []
        for stage_channels in in_channels:
            projections.append(
                nn.Sequential(
                    nn.Conv2d(stage_channels, embed_dim, 1),
                    nn.BatchNorm2d(embed_dim)
                )
            )
        self.linear_fuse = nn.ModuleList(projections)

        self.linear_pred = nn.Conv2d(embed_dim * len(in_channels), num_classes, 1)

    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            features: per-stage maps [B, C_i, H_i, W_i], first one sets the
                output resolution
        Returns:
            [B, num_classes, H_0, W_0] logits
        """
        out_hw = features[0].shape[2:]

        projected = []
        for fmap, proj in zip(features, self.linear_fuse):
            fmap = proj(fmap)
            if fmap.shape[2:] != out_hw:
                fmap = F.interpolate(
                    fmap, size=out_hw, mode='bilinear', align_corners=False
                )
            projected.append(fmap)

        return self.linear_pred(torch.cat(projected, dim=1))
Instance & Panoptic Segmentation
class MaskRCNN(nn.Module):
    """
    Mask R-CNN: Instance Segmentation (He et al., 2017).

    Extends Faster R-CNN with a mask prediction head. The key difference
    from semantic segmentation: Mask R-CNN produces a separate binary mask
    for each detected object instance. Two people standing side by side
    get two distinct masks, not one "person" blob.

    Architecture: Faster R-CNN (detects boxes) + parallel mask branch
    that predicts a per-class binary mask for each detected region.

    Practical tip: The mask head operates on RoI-aligned features and
    predicts one mask per class (not per instance). The final mask is
    selected based on the classification head's prediction.

    This sketch implements only the mask branch (`forward_mask`); the
    backbone is stored but detection and RoI-align are not part of it.

    Args:
        backbone: feature extractor, stored as `self.backbone`.
        num_classes: number of mask channels predicted per RoI.
        mask_dim: hidden width of the mask head.
        in_channels: channels of the RoI features fed to the mask head
            (256 by default).
    """

    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        mask_dim: int = 256,
        in_channels: int = 256
    ):
        super().__init__()

        self.backbone = backbone

        # Mask head: four 3x3 convs, a 2x transposed conv, then a 1x1
        # conv producing one mask logit map per class.
        self.mask_head = nn.Sequential(
            nn.Conv2d(in_channels, mask_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mask_dim, mask_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mask_dim, mask_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mask_dim, mask_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(mask_dim, mask_dim, 2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(mask_dim, num_classes, 1)
        )

    def forward_mask(
        self,
        roi_features: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Predict masks for each RoI.

        Args:
            roi_features: [N, C, H, W] features from RoI align
            labels: [N] class labels for each RoI (integer tensor)

        Returns:
            masks: [N, 1, H*2, W*2] predicted mask logits, one per RoI,
                taken from the channel of that RoI's label
        """
        masks = self.mask_head(roi_features)  # [N, num_classes, 2H, 2W]

        # Select mask for predicted class. The row index is created on the
        # same device as `labels` so both index tensors agree.
        roi_index = torch.arange(masks.size(0), device=labels.device)
        masks = masks[roi_index, labels]

        return masks.unsqueeze(1)
class PanopticFPN(nn.Module):
    """
    Panoptic FPN: Unified thing and stuff segmentation (Kirillov et al., 2019).

    Panoptic segmentation is the grand unification of segmentation tasks.
    It answers: "for every pixel, what class is it, and if it is a 'thing'
    (countable object like a car or person), which specific instance?"

    Combines:
    - Instance segmentation for "things" (countable: car, person, dog)
    - Semantic segmentation for "stuff" (uncountable: sky, road, grass)

    Practical tip: Panoptic segmentation metrics (PQ, SQ, RQ) are more
    informative than mIoU alone because they evaluate both recognition
    quality and segmentation quality separately.

    Only the semantic ("stuff") head is implemented here; the FPN features
    are returned so an instance branch can consume them.

    Args:
        backbone: callable returning four feature maps, finest first, with
            256, 512, 1024 and 2048 channels respectively.
        num_thing_classes: number of "thing" classes (stored, not used by
            the semantic head).
        num_stuff_classes: number of semantic logit channels.
    """

    def __init__(
        self,
        backbone: nn.Module,
        num_thing_classes: int,  # e.g., person, car
        num_stuff_classes: int   # e.g., sky, road
    ):
        super().__init__()

        self.backbone = backbone

        # FPN for multi-scale features: 1x1 lateral convs to a common
        # 256-channel width, then a 3x3 smoothing conv per level.
        self.fpn = nn.ModuleDict({
            'lateral': nn.ModuleList([
                nn.Conv2d(c, 256, 1) for c in [256, 512, 1024, 2048]
            ]),
            'output': nn.ModuleList([
                nn.Conv2d(256, 256, 3, padding=1) for _ in range(4)
            ])
        })

        # Semantic segmentation head (stuff)
        self.semantic_head = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1),
            nn.GroupNorm(32, 128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.GroupNorm(32, 128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, num_stuff_classes, 1)
        )

        # Instance head would be similar to Mask R-CNN
        self.num_thing_classes = num_thing_classes

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: [B, C, H, W] input image
        Returns:
            dict with
              'semantic': [B, num_stuff_classes, H, W] logits
              'fpn_features': list of the four FPN maps, finest first
        """
        input_size = x.shape[2:]

        # Get backbone features at multiple scales
        features = self.backbone(x)  # [C1, C2, C3, C4]

        # FPN
        fpn_features = self._fpn_forward(features)

        # Semantic prediction: bring every level to the finest level's
        # size and sum them.
        target_size = fpn_features[0].shape[2:]
        semantic_feat = sum(
            feat if feat.shape[2:] == target_size
            else F.interpolate(feat, size=target_size, mode='bilinear', align_corners=False)
            for feat in fpn_features
        )
        semantic_out = self.semantic_head(semantic_feat)

        # Upsample to input resolution. Resizing to the actual input size
        # (rather than a fixed 4x) stays correct when H or W is not a
        # multiple of 4 or the backbone's finest stride is not 4.
        semantic_out = F.interpolate(
            semantic_out, size=input_size, mode='bilinear', align_corners=False
        )

        return {
            'semantic': semantic_out,
            'fpn_features': fpn_features  # For instance branch
        }

    def _fpn_forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        """Build FPN from backbone features (finest first)."""
        laterals = [lat(f) for f, lat in zip(features, self.fpn['lateral'])]

        # Top-down pathway: add each coarser level, nearest-upsampled,
        # into the next finer one.
        for i in range(len(laterals) - 1, 0, -1):
            laterals[i-1] = laterals[i-1] + F.interpolate(
                laterals[i], size=laterals[i-1].shape[2:], mode='nearest'
            )

        outputs = [out(lat) for lat, out in zip(laterals, self.fpn['output'])]
        return outputs
Loss Functions
class SegmentationLosses:
    """
    Common loss functions for segmentation.

    All losses take raw logits `pred` of shape [B, C, H, W] and integer
    class labels `target` of shape [B, H, W], and return a scalar tensor.
    """

    @staticmethod
    def cross_entropy(
        pred: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        ignore_index: int = 255
    ) -> torch.Tensor:
        """Standard cross-entropy loss; pixels labelled `ignore_index` are skipped."""
        return F.cross_entropy(pred, target, weight=weight, ignore_index=ignore_index)

    @staticmethod
    def dice_loss(
        pred: torch.Tensor,
        target: torch.Tensor,
        smooth: float = 1e-6
    ) -> torch.Tensor:
        """
        Dice Loss for imbalanced segmentation.

        Dice = 2|A∩B| / (|A| + |B|), computed per sample and per class on
        softmax probabilities, then averaged; the loss is 1 - mean Dice.

        Labels must lie in [0, C): they are one-hot encoded, so an ignore
        label such as 255 has to be remapped by the caller first.
        """
        pred = F.softmax(pred, dim=1)
        num_classes = pred.size(1)

        # One-hot encode target: [B, H, W] -> [B, C, H, W]
        target_one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()

        intersection = (pred * target_one_hot).sum(dim=(2, 3))
        cardinality = pred.sum(dim=(2, 3)) + target_one_hot.sum(dim=(2, 3))

        # `smooth` keeps the ratio defined when a class is absent.
        dice = (2 * intersection + smooth) / (cardinality + smooth)

        return 1 - dice.mean()

    @staticmethod
    def focal_loss(
        pred: torch.Tensor,
        target: torch.Tensor,
        alpha: float = 0.25,
        gamma: float = 2.0
    ) -> torch.Tensor:
        """
        Focal Loss for hard example mining.

        Scales the per-pixel cross-entropy by alpha * (1 - p_t) ** gamma,
        where p_t is the probability of the true class, so well-classified
        pixels contribute little.
        """
        ce_loss = F.cross_entropy(pred, target, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal_loss = alpha * (1 - pt) ** gamma * ce_loss
        return focal_loss.mean()

    @staticmethod
    def lovasz_softmax(
        pred: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        """
        Lovász-Softmax loss.

        Directly optimizes IoU (a convex surrogate of the Jaccard index),
        averaged over all C classes. Every pixel of the batch is pooled into
        one list per class before sorting.
        """
        probas = F.softmax(pred, dim=1)
        num_classes = probas.size(1)

        # Flatten to one row per pixel: [B, C, H, W] -> [B*H*W, C], with the
        # same (b, h, w) ordering as the flattened target.
        flat_probas = probas.permute(0, 2, 3, 1).reshape(-1, num_classes)
        flat_target = target.reshape(-1)

        losses = []

        for c in range(num_classes):
            fg = (flat_target == c).float()
            errors = (fg - flat_probas[:, c]).abs()
            # Sort pixels by decreasing error and reorder the labels with
            # the same permutation.
            errors_sorted, perm = torch.sort(errors, descending=True)
            fg_sorted = fg[perm]
            grad = SegmentationLosses._lovasz_grad(fg_sorted)
            losses.append(torch.dot(grad, errors_sorted))

        return torch.stack(losses).mean()

    @staticmethod
    def _lovasz_grad(gt_sorted: torch.Tensor) -> torch.Tensor:
        """
        Lovász gradient.

        Given 1-D binary labels sorted by decreasing error, return the
        increments of the Jaccard loss as pixels are added in that order.
        """
        gts = gt_sorted.sum()
        intersection = gts - gt_sorted.cumsum(0)
        union = gts + (1 - gt_sorted).cumsum(0)
        jaccard = 1 - intersection / union
        # Convert cumulative Jaccard values into per-pixel differences.
        jaccard[1:] = jaccard[1:] - jaccard[:-1]
        return jaccard
Exercises
Exercise 1: Implement Boundary Loss
Exercise 1: Implement Boundary Loss
Create a loss that focuses on boundaries:
def boundary_loss(pred, target):
    """Exercise stub for a boundary-focused loss; not implemented (returns None)."""
    # Extract boundaries using Sobel/Laplacian
    # Weight loss higher at boundaries
    pass
Exercise 2: Multi-Scale Inference
Exercise 2: Multi-Scale Inference
Implement test-time augmentation for segmentation:
- Inference at multiple scales
- Flip augmentation
- Ensemble predictions
Exercise 3: CRF Post-Processing
Exercise 3: CRF Post-Processing
Add Conditional Random Field refinement:
# Use pydensecrf or implement mean-field approximation
# Refine predictions based on appearance consistency
What’s Next?
Hyperparameter Tuning
Systematic optimization with Optuna and Ray Tune
Reproducibility
Experiment tracking and reproducible research