Semantic Segmentation
From Classification to Segmentation
Semantic segmentation assigns a class label to every pixel of an image, rather than a single label to the whole image:
| Task | Output |
|---|---|
| Classification | One label per image |
| Object Detection | Boxes + labels |
| Semantic Segmentation | One label per pixel |
| Instance Segmentation | Objects + per-pixel masks |
| Panoptic Segmentation | Class label + instance ID for every pixel |
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Optional
torch.manual_seed(42)
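To make the table concrete, here is a minimal sketch of the tensors involved (using the imports above); the batch size, class count, and resolution are arbitrary choices for illustration:

```python
# Classification: one integer label per image.
cls_target = torch.randint(0, 21, (4,))             # [B]

# Semantic segmentation: one integer label per pixel.
seg_target = torch.randint(0, 21, (4, 256, 256))     # [B, H, W]

# A segmentation model therefore outputs one score per class per pixel.
seg_logits = torch.randn(4, 21, 256, 256)            # [B, C, H, W]

# Per-pixel cross-entropy works directly on these shapes.
loss = F.cross_entropy(seg_logits, seg_target)
print(seg_logits.shape, seg_target.shape, loss.item())
```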
Fully Convolutional Networks (FCN)
Fully convolutional networks are the foundation of modern segmentation: every layer is convolutional, so the network outputs a spatial map of class scores instead of a single prediction.
class FCN(nn.Module):
"""
Fully Convolutional Networks (Long et al., 2015).
Key insight: Replace FC layers with conv layers to preserve spatial info.
"""
def __init__(self, num_classes: int, backbone: str = 'vgg16'):
super().__init__()
# VGG-style backbone
self.features = nn.Sequential(
# Block 1
self._make_block(3, 64, 2),
nn.MaxPool2d(2, 2),
# Block 2
self._make_block(64, 128, 2),
nn.MaxPool2d(2, 2),
# Block 3
self._make_block(128, 256, 3),
nn.MaxPool2d(2, 2),
# Block 4
self._make_block(256, 512, 3),
nn.MaxPool2d(2, 2),
# Block 5
self._make_block(512, 512, 3),
nn.MaxPool2d(2, 2)
)
# FC layers converted to conv
self.fc6 = nn.Conv2d(512, 4096, 7, padding=3)
self.fc7 = nn.Conv2d(4096, 4096, 1)
# Score layer
self.score = nn.Conv2d(4096, num_classes, 1)
# Upsample
self.upsample = nn.ConvTranspose2d(
num_classes, num_classes, 64, stride=32, padding=16
)
def _make_block(self, in_c: int, out_c: int, num_convs: int):
layers = []
for i in range(num_convs):
layers.extend([
nn.Conv2d(in_c if i == 0 else out_c, out_c, 3, padding=1),
nn.ReLU(inplace=True)
])
return nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [B, 3, H, W] input image
Returns:
segmentation: [B, C, H, W] per-pixel logits
"""
input_size = x.shape[2:]
x = self.features(x)
x = F.relu(self.fc6(x))
x = F.relu(self.fc7(x))
x = self.score(x)
# Upsample to input size
x = self.upsample(x)
x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
return x
class FCN8s(nn.Module):
"""
FCN-8s: Multi-scale fusion for better boundaries.
Combines predictions from pool3, pool4, and pool5.
"""
def __init__(self, num_classes: int):
super().__init__()
# Backbone with skip connections
self.conv1 = self._make_block(3, 64, 2)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = self._make_block(64, 128, 2)
self.pool2 = nn.MaxPool2d(2, 2)
self.conv3 = self._make_block(128, 256, 3)
self.pool3 = nn.MaxPool2d(2, 2)
self.conv4 = self._make_block(256, 512, 3)
self.pool4 = nn.MaxPool2d(2, 2)
self.conv5 = self._make_block(512, 512, 3)
self.pool5 = nn.MaxPool2d(2, 2)
# FC layers as conv
self.fc6 = nn.Conv2d(512, 4096, 7, padding=3)
self.fc7 = nn.Conv2d(4096, 4096, 1)
# Score layers for each scale
self.score_fr = nn.Conv2d(4096, num_classes, 1)
self.score_pool4 = nn.Conv2d(512, num_classes, 1)
self.score_pool3 = nn.Conv2d(256, num_classes, 1)
# Upsampling
self.upscore2 = nn.ConvTranspose2d(num_classes, num_classes, 4, stride=2, padding=1)
self.upscore4 = nn.ConvTranspose2d(num_classes, num_classes, 4, stride=2, padding=1)
self.upscore8 = nn.ConvTranspose2d(num_classes, num_classes, 16, stride=8, padding=4)
def _make_block(self, in_c, out_c, num_convs):
layers = []
for i in range(num_convs):
layers.extend([
nn.Conv2d(in_c if i == 0 else out_c, out_c, 3, padding=1),
nn.ReLU(inplace=True)
])
return nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
input_size = x.shape[2:]
# Encoder
x = self.pool1(self.conv1(x))
x = self.pool2(self.conv2(x))
pool3 = self.pool3(self.conv3(x))
pool4 = self.pool4(self.conv4(pool3))
pool5 = self.pool5(self.conv5(pool4))
x = F.relu(self.fc6(pool5))
x = F.relu(self.fc7(x))
# Multi-scale fusion
score_fr = self.score_fr(x)
upscore2 = self.upscore2(score_fr)
score_pool4 = self.score_pool4(pool4)
fuse_pool4 = upscore2 + score_pool4
upscore4 = self.upscore4(fuse_pool4)
score_pool3 = self.score_pool3(pool3)
fuse_pool3 = upscore4 + score_pool3
out = self.upscore8(fuse_pool3)
out = F.interpolate(out, size=input_size, mode='bilinear', align_corners=False)
return out
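A quick smoke test, assuming the two classes above are defined; the 21-class, 64×64 configuration is arbitrary and only keeps the check cheap. Note that the converted fc6/fc7 layers still carry most of VGG's parameters (well over 100M), so both models are large despite the tiny input:

```python
with torch.no_grad():
    dummy = torch.randn(1, 3, 64, 64)     # H and W divisible by 32
    fcn32 = FCN(num_classes=21).eval()
    fcn8 = FCN8s(num_classes=21).eval()
    print(fcn32(dummy).shape)  # torch.Size([1, 21, 64, 64])
    print(fcn8(dummy).shape)   # torch.Size([1, 21, 64, 64])
```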
U-Net
U-Net is a symmetric encoder-decoder whose skip connections carry fine spatial detail from the encoder directly to the decoder:
class DoubleConv(nn.Module):
"""Double convolution block."""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
class UNet(nn.Module):
"""
U-Net: Convolutional Networks for Biomedical Image Segmentation.
Key features:
1. Symmetric encoder-decoder structure
2. Skip connections preserve fine details
3. Works well with limited training data
"""
def __init__(
self,
in_channels: int = 3,
num_classes: int = 2,
features: List[int] = [64, 128, 256, 512, 1024]
):
super().__init__()
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
self.pool = nn.MaxPool2d(2, 2)
# Encoder
for feature in features:
self.encoder.append(DoubleConv(in_channels, feature))
in_channels = feature
# Decoder
for feature in reversed(features[:-1]):
self.decoder.append(
nn.ConvTranspose2d(feature * 2, feature, 2, stride=2)
)
self.decoder.append(DoubleConv(feature * 2, feature))
# Final conv
self.final_conv = nn.Conv2d(features[0], num_classes, 1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [B, C, H, W] input
Returns:
segmentation: [B, num_classes, H, W]
"""
skip_connections = []
# Encoder
for encoder_block in self.encoder[:-1]:
x = encoder_block(x)
skip_connections.append(x)
x = self.pool(x)
x = self.encoder[-1](x) # Bottleneck
skip_connections = skip_connections[::-1] # Reverse
# Decoder
for i in range(0, len(self.decoder), 2):
x = self.decoder[i](x) # Upsample
skip = skip_connections[i // 2]
# Handle size mismatch
if x.shape != skip.shape:
x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
x = torch.cat([skip, x], dim=1) # Skip connection
x = self.decoder[i + 1](x) # Double conv
return self.final_conv(x)
class AttentionUNet(nn.Module):
"""
Attention U-Net: Learning Where to Look.
Adds attention gates to focus on relevant regions.
"""
def __init__(self, in_channels: int = 3, num_classes: int = 2):
super().__init__()
features = [64, 128, 256, 512, 1024]
# Encoder
self.encoder = nn.ModuleList([
DoubleConv(in_channels, features[0]),
DoubleConv(features[0], features[1]),
DoubleConv(features[1], features[2]),
DoubleConv(features[2], features[3]),
DoubleConv(features[3], features[4])
])
self.pool = nn.MaxPool2d(2, 2)
# Decoder with attention
self.upconv4 = nn.ConvTranspose2d(features[4], features[3], 2, stride=2)
self.attn4 = AttentionGate(features[3], features[3], features[3] // 2)
self.decoder4 = DoubleConv(features[4], features[3])
self.upconv3 = nn.ConvTranspose2d(features[3], features[2], 2, stride=2)
self.attn3 = AttentionGate(features[2], features[2], features[2] // 2)
self.decoder3 = DoubleConv(features[3], features[2])
self.upconv2 = nn.ConvTranspose2d(features[2], features[1], 2, stride=2)
self.attn2 = AttentionGate(features[1], features[1], features[1] // 2)
self.decoder2 = DoubleConv(features[2], features[1])
self.upconv1 = nn.ConvTranspose2d(features[1], features[0], 2, stride=2)
self.attn1 = AttentionGate(features[0], features[0], features[0] // 2)
self.decoder1 = DoubleConv(features[1], features[0])
self.final = nn.Conv2d(features[0], num_classes, 1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Encoder
e1 = self.encoder[0](x)
e2 = self.encoder[1](self.pool(e1))
e3 = self.encoder[2](self.pool(e2))
e4 = self.encoder[3](self.pool(e3))
e5 = self.encoder[4](self.pool(e4))
# Decoder with attention
d4 = self.upconv4(e5)
e4_attn = self.attn4(d4, e4)
d4 = self.decoder4(torch.cat([e4_attn, d4], dim=1))
d3 = self.upconv3(d4)
e3_attn = self.attn3(d3, e3)
d3 = self.decoder3(torch.cat([e3_attn, d3], dim=1))
d2 = self.upconv2(d3)
e2_attn = self.attn2(d2, e2)
d2 = self.decoder2(torch.cat([e2_attn, d2], dim=1))
d1 = self.upconv1(d2)
e1_attn = self.attn1(d1, e1)
d1 = self.decoder1(torch.cat([e1_attn, d1], dim=1))
return self.final(d1)
class AttentionGate(nn.Module):
"""Attention gate for Attention U-Net."""
def __init__(self, g_channels: int, x_channels: int, inter_channels: int):
super().__init__()
self.W_g = nn.Conv2d(g_channels, inter_channels, 1, bias=False)
self.W_x = nn.Conv2d(x_channels, inter_channels, 1, bias=False)
self.psi = nn.Sequential(
nn.Conv2d(inter_channels, 1, 1, bias=False),
nn.Sigmoid()
)
self.relu = nn.ReLU(inplace=True)
def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
Args:
g: Gating signal (from decoder)
x: Skip connection (from encoder)
"""
g1 = self.W_g(g)
x1 = self.W_x(x)
# Resize if needed
if g1.shape[2:] != x1.shape[2:]:
g1 = F.interpolate(g1, size=x1.shape[2:], mode='bilinear', align_corners=False)
psi = self.relu(g1 + x1)
psi = self.psi(psi)
return x * psi
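A minimal usage sketch for both U-Net variants, assuming the classes above (input size and class count are arbitrary). Because the decoder returns logits at the input resolution, a per-pixel prediction is just an argmax over the class dimension:

```python
with torch.no_grad():
    image = torch.randn(1, 3, 96, 96)
    unet = UNet(in_channels=3, num_classes=2).eval()
    attn_unet = AttentionUNet(in_channels=3, num_classes=2).eval()

    logits = unet(image)
    print(logits.shape)                 # torch.Size([1, 2, 96, 96])
    print(attn_unet(image).shape)       # torch.Size([1, 2, 96, 96])

    pred_mask = logits.argmax(dim=1)    # [B, H, W] class indices
    print(pred_mask.shape)              # torch.Size([1, 96, 96])
```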
DeepLab Series
Atrous (Dilated) Convolutions
Atrous (dilated) convolutions space out the kernel taps, enlarging the receptive field without downsampling the feature map:
class AtrousConv(nn.Module):
"""
Atrous (Dilated) Convolution.
Increases receptive field without losing resolution.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
dilation: int = 2
):
super().__init__()
padding = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2
self.conv = nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=padding, dilation=dilation, bias=False
)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.relu(self.bn(self.conv(x)))
class ASPP(nn.Module):
"""
Atrous Spatial Pyramid Pooling (DeepLabV3).
Parallel atrous convolutions at multiple dilation rates.
"""
def __init__(
self,
in_channels: int,
out_channels: int = 256,
dilations: List[int] = [6, 12, 18]
):
super().__init__()
modules = []
# 1x1 conv
modules.append(nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
))
# Atrous convolutions at different rates
for dilation in dilations:
modules.append(nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=dilation,
dilation=dilation, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
))
# Global average pooling
modules.append(nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
))
self.convs = nn.ModuleList(modules)
# Fusion
self.project = nn.Sequential(
nn.Conv2d(out_channels * (len(dilations) + 2), out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Dropout(0.5)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
features = []
for conv in self.convs[:-1]:
features.append(conv(x))
# Global pooling branch
global_feat = self.convs[-1](x)
global_feat = F.interpolate(
global_feat, size=x.shape[2:], mode='bilinear', align_corners=False
)
features.append(global_feat)
x = torch.cat(features, dim=1)
x = self.project(x)
return x
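# Illustrative shape check (sizes are arbitrary): ASPP fuses context from
# several dilation rates while preserving the spatial resolution of its input.
# eval() is used because the global-pooling branch yields 1x1 maps, which
# BatchNorm cannot normalize in training mode with a batch size of 1.
aspp_demo = ASPP(in_channels=2048, out_channels=256).eval()
with torch.no_grad():
    print(aspp_demo(torch.randn(1, 2048, 16, 16)).shape)  # [1, 256, 16, 16]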
class DeepLabV3Plus(nn.Module):
"""
DeepLab V3+: Encoder-Decoder with Atrous Separable Convolution.
Combines:
- ASPP for multi-scale context
- Decoder for sharp boundaries
- Depthwise separable convolutions for efficiency
"""
def __init__(
self,
backbone: nn.Module,
num_classes: int,
low_level_channels: int = 256,
aspp_out_channels: int = 256
):
super().__init__()
self.backbone = backbone
# ASPP
self.aspp = ASPP(2048, aspp_out_channels)
# Low-level feature projection
self.low_level_conv = nn.Sequential(
nn.Conv2d(low_level_channels, 48, 1, bias=False),
nn.BatchNorm2d(48),
nn.ReLU(inplace=True)
)
# Decoder
self.decoder = nn.Sequential(
nn.Conv2d(aspp_out_channels + 48, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, num_classes, 1)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
input_size = x.shape[2:]
# Backbone
low_level_features, high_level_features = self.backbone(x)
# ASPP
aspp_out = self.aspp(high_level_features)
# Upsample ASPP output
aspp_out = F.interpolate(
aspp_out, size=low_level_features.shape[2:],
mode='bilinear', align_corners=False
)
# Process low-level features
low_level_out = self.low_level_conv(low_level_features)
# Concatenate and decode
x = torch.cat([aspp_out, low_level_out], dim=1)
x = self.decoder(x)
# Upsample to input size
x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
return x
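DeepLabV3Plus expects a backbone that returns a (low-level, high-level) feature pair; with a ResNet-50/101 encoder these are typically the 256-channel stride-4 features and the 2048-channel deep features. The stand-in backbone below is purely hypothetical and only mimics those shapes so the model can be exercised end to end:

```python
class ToyBackbone(nn.Module):
    """Hypothetical stand-in mimicking ResNet-style feature shapes."""
    def __init__(self):
        super().__init__()
        # 256-channel "low-level" features at stride 4.
        self.low = nn.Sequential(nn.Conv2d(3, 256, 3, stride=4, padding=1),
                                 nn.ReLU(inplace=True))
        # 2048-channel "high-level" features at stride 16.
        self.high = nn.Sequential(nn.Conv2d(256, 2048, 3, stride=4, padding=1),
                                  nn.ReLU(inplace=True))

    def forward(self, x):
        low = self.low(x)
        return low, self.high(low)

deeplab = DeepLabV3Plus(ToyBackbone(), num_classes=21).eval()
with torch.no_grad():
    print(deeplab(torch.randn(1, 3, 128, 128)).shape)  # torch.Size([1, 21, 128, 128])
```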
Transformer-Based Segmentation
SETR (Segmentation Transformer)
SETR casts segmentation as sequence-to-sequence prediction: a Vision Transformer encodes image patches, and a convolutional decoder progressively upsamples them back to a dense prediction:
class SETR(nn.Module):
"""
SETR: Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective.
Uses Vision Transformer as encoder.
"""
def __init__(
self,
img_size: int = 512,
patch_size: int = 16,
num_classes: int = 21,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12
):
super().__init__()
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
# Patch embedding
self.patch_embed = nn.Conv2d(3, embed_dim, patch_size, stride=patch_size)
# Position embedding
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
# Transformer encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=num_heads,
dim_feedforward=embed_dim * 4,
dropout=0.1,
batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
# Decoder (Progressive Upsampling)
self.decoder = nn.Sequential(
nn.Conv2d(embed_dim, 256, 1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(256, 256, 3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(256, 256, 3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(256, 256, 3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(256, num_classes, 1)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape
# Patch embedding
x = self.patch_embed(x) # [B, embed_dim, H/P, W/P]
x = x.flatten(2).transpose(1, 2) # [B, num_patches, embed_dim]
# Add position embedding
x = x + self.pos_embed
# Transformer
x = self.transformer(x)
# Reshape to spatial
h = w = int(self.num_patches ** 0.5)
x = x.transpose(1, 2).view(B, -1, h, w)
# Decode
x = self.decoder(x)
return x
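# Quick check with a deliberately tiny configuration (the defaults above match
# a ViT-Base-sized encoder, which is heavier than needed for a shape test).
# The decoder's four x2 upsampling stages undo the 16x patch downsampling.
setr_demo = SETR(img_size=64, patch_size=16, num_classes=21,
                 embed_dim=64, depth=1, num_heads=4).eval()
with torch.no_grad():
    print(setr_demo(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 21, 64, 64])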
class SegFormer(nn.Module):
"""
SegFormer: Simple and Efficient Design for Semantic Segmentation.
Key innovations:
- Hierarchical transformer encoder
- MLP decoder (simple but effective)
- No positional encoding
"""
def __init__(
self,
num_classes: int = 21,
embed_dims: List[int] = [64, 128, 320, 512],
num_heads: List[int] = [1, 2, 5, 8],
depths: List[int] = [3, 6, 40, 3]
):
super().__init__()
# Hierarchical encoder
self.stages = nn.ModuleList()
in_channels = 3
for i, (embed_dim, num_head, depth) in enumerate(zip(embed_dims, num_heads, depths)):
stage = MixTransformerBlock(
in_channels=in_channels,
embed_dim=embed_dim,
num_heads=num_head,
depth=depth,
patch_size=7 if i == 0 else 3,
stride=4 if i == 0 else 2
)
self.stages.append(stage)
in_channels = embed_dim
# All-MLP decoder
self.decode_head = MLPDecoder(
in_channels=embed_dims,
embed_dim=768,
num_classes=num_classes
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
features = []
for stage in self.stages:
x = stage(x)
features.append(x)
out = self.decode_head(features)
return out
class MixTransformerBlock(nn.Module):
"""Mix Transformer block for SegFormer."""
def __init__(
self,
in_channels: int,
embed_dim: int,
num_heads: int,
depth: int,
patch_size: int = 3,
stride: int = 2
):
super().__init__()
# Overlapping patch embedding
self.patch_embed = nn.Conv2d(
in_channels, embed_dim, patch_size,
stride=stride, padding=patch_size // 2
)
# Transformer blocks
self.blocks = nn.ModuleList([
EfficientSelfAttentionBlock(embed_dim, num_heads)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
B, C, H, W = x.shape
x = x.flatten(2).transpose(1, 2) # [B, H*W, C]
for block in self.blocks:
x = block(x, H, W)
x = self.norm(x)
x = x.transpose(1, 2).view(B, C, H, W)
return x
class EfficientSelfAttentionBlock(nn.Module):
"""Efficient self-attention with reduction."""
def __init__(self, dim: int, num_heads: int, sr_ratio: int = 4):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
# Spatial reduction
self.sr = nn.Conv2d(dim, dim, sr_ratio, stride=sr_ratio)
self.sr_norm = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
)
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
# Self-attention with spatial reduction
residual = x
x = self.norm1(x)
# Reduce K, V spatial dimensions
B, N, C = x.shape
kv = x.transpose(1, 2).view(B, C, H, W)
kv = self.sr(kv).flatten(2).transpose(1, 2)
kv = self.sr_norm(kv)
x, _ = self.attn(x, kv, kv)
x = residual + x
# MLP
x = x + self.mlp(self.norm2(x))
return x
class MLPDecoder(nn.Module):
"""Simple MLP decoder for SegFormer."""
def __init__(
self,
in_channels: List[int],
embed_dim: int,
num_classes: int
):
super().__init__()
self.linear_fuse = nn.ModuleList([
nn.Sequential(
nn.Conv2d(in_c, embed_dim, 1),
nn.BatchNorm2d(embed_dim)
)
for in_c in in_channels
])
self.linear_pred = nn.Conv2d(embed_dim * len(in_channels), num_classes, 1)
def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
target_size = features[0].shape[2:]
fused = []
for i, (feat, linear) in enumerate(zip(features, self.linear_fuse)):
feat = linear(feat)
if feat.shape[2:] != target_size:
feat = F.interpolate(feat, size=target_size, mode='bilinear', align_corners=False)
fused.append(feat)
x = torch.cat(fused, dim=1)
x = self.linear_pred(x)
return x
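A shape check for SegFormer, assuming the classes above; the depths are reduced from the defaults (which correspond to a large MiT-style encoder) purely to keep the example light. The MLP decoder produces logits at 1/4 of the input resolution, which are normally upsampled to full resolution before the loss is computed:

```python
segformer = SegFormer(num_classes=21, depths=[1, 1, 1, 1]).eval()
with torch.no_grad():
    image = torch.randn(1, 3, 128, 128)
    quarter_logits = segformer(image)
    print(quarter_logits.shape)   # torch.Size([1, 21, 32, 32]) -- 1/4 resolution

    # Upsample to the input resolution for per-pixel training/evaluation.
    logits = F.interpolate(quarter_logits, size=image.shape[2:],
                           mode='bilinear', align_corners=False)
    print(logits.shape)           # torch.Size([1, 21, 128, 128])
```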
Instance & Panoptic Segmentation
Instance segmentation predicts a separate mask for each detected object (Mask R-CNN), while panoptic segmentation unifies instance masks for countable "things" with semantic labels for "stuff" regions:
class MaskRCNN(nn.Module):
"""
Mask R-CNN: Instance Segmentation.
Extends Faster R-CNN with mask prediction head.
"""
def __init__(
self,
backbone: nn.Module,
num_classes: int,
mask_dim: int = 256
):
super().__init__()
self.backbone = backbone
# Mask head
self.mask_head = nn.Sequential(
nn.Conv2d(256, mask_dim, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(mask_dim, mask_dim, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(mask_dim, mask_dim, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(mask_dim, mask_dim, 3, padding=1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(mask_dim, mask_dim, 2, stride=2),
nn.ReLU(inplace=True),
nn.Conv2d(mask_dim, num_classes, 1)
)
def forward_mask(
self,
roi_features: torch.Tensor,
labels: torch.Tensor
) -> torch.Tensor:
"""
Predict masks for each RoI.
Args:
roi_features: [N, C, H, W] features from RoI align
labels: [N] class labels for each RoI
Returns:
masks: [N, 1, H*2, W*2] predicted masks
"""
masks = self.mask_head(roi_features)
# Select mask for predicted class
N = masks.size(0)
masks = masks[torch.arange(N), labels]
return masks.unsqueeze(1)
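# Illustrative check of the mask branch in isolation: RoI-aligned features for
# the mask head are typically 14x14, and the transposed conv doubles them to
# 28x28 mask logits. nn.Identity() stands in for a real detection backbone.
maskrcnn_demo = MaskRCNN(backbone=nn.Identity(), num_classes=80).eval()
with torch.no_grad():
    roi_feats = torch.randn(5, 256, 14, 14)      # 5 RoIs from RoIAlign
    roi_labels = torch.randint(0, 80, (5,))      # predicted class per RoI
    print(maskrcnn_demo.forward_mask(roi_feats, roi_labels).shape)  # [5, 1, 28, 28]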
class PanopticFPN(nn.Module):
"""
Panoptic FPN: Unified thing and stuff segmentation.
Combines:
- Instance segmentation (things: countable objects)
- Semantic segmentation (stuff: uncountable regions)
"""
def __init__(
self,
backbone: nn.Module,
num_thing_classes: int, # e.g., person, car
num_stuff_classes: int # e.g., sky, road
):
super().__init__()
self.backbone = backbone
# FPN for multi-scale features
self.fpn = nn.ModuleDict({
'lateral': nn.ModuleList([
nn.Conv2d(c, 256, 1) for c in [256, 512, 1024, 2048]
]),
'output': nn.ModuleList([
nn.Conv2d(256, 256, 3, padding=1) for _ in range(4)
])
})
# Semantic segmentation head (stuff)
self.semantic_head = nn.Sequential(
nn.Conv2d(256, 128, 3, padding=1),
nn.GroupNorm(32, 128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, padding=1),
nn.GroupNorm(32, 128),
nn.ReLU(inplace=True),
nn.Conv2d(128, num_stuff_classes, 1)
)
# Instance head would be similar to Mask R-CNN
self.num_thing_classes = num_thing_classes
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
# Get backbone features at multiple scales
features = self.backbone(x) # [C1, C2, C3, C4]
# FPN
fpn_features = self._fpn_forward(features)
# Semantic prediction (upsample all to same size and sum)
semantic_features = []
target_size = fpn_features[0].shape[2:]
for feat in fpn_features:
feat = F.interpolate(feat, size=target_size, mode='bilinear', align_corners=False)
semantic_features.append(feat)
semantic_feat = sum(semantic_features)
semantic_out = self.semantic_head(semantic_feat)
# Upsample to input resolution
semantic_out = F.interpolate(
semantic_out, scale_factor=4, mode='bilinear', align_corners=False
)
return {
'semantic': semantic_out,
'fpn_features': fpn_features # For instance branch
}
def _fpn_forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
"""Build FPN from backbone features."""
laterals = [lat(f) for f, lat in zip(features, self.fpn['lateral'])]
# Top-down pathway
for i in range(len(laterals) - 1, 0, -1):
laterals[i-1] = laterals[i-1] + F.interpolate(
laterals[i], size=laterals[i-1].shape[2:], mode='nearest'
)
outputs = [out(lat) for lat, out in zip(laterals, self.fpn['output'])]
return outputs
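To exercise the semantic ("stuff") branch end to end, the sketch below uses a hypothetical pyramid backbone that only mimics ResNet channel counts (256/512/1024/2048) at strides 4/8/16/32; the class counts are arbitrary:

```python
class ToyPyramidBackbone(nn.Module):
    """Hypothetical backbone producing a 4-level, ResNet-like feature pyramid."""
    def __init__(self):
        super().__init__()
        channels = [256, 512, 1024, 2048]
        strides = [4, 2, 2, 2]  # cumulative strides: 4, 8, 16, 32
        stages, in_c = [], 3
        for c, s in zip(channels, strides):
            stages.append(nn.Sequential(
                nn.Conv2d(in_c, c, 3, stride=s, padding=1),
                nn.ReLU(inplace=True)
            ))
            in_c = c
        self.stages = nn.ModuleList(stages)

    def forward(self, x):
        feats = []
        for stage in self.stages:
            x = stage(x)
            feats.append(x)
        return feats

panoptic = PanopticFPN(ToyPyramidBackbone(),
                       num_thing_classes=80, num_stuff_classes=53).eval()
with torch.no_grad():
    out = panoptic(torch.randn(1, 3, 128, 128))
    print(out['semantic'].shape)                              # torch.Size([1, 53, 128, 128])
    print([tuple(f.shape[2:]) for f in out['fpn_features']])  # (32,32), (16,16), (8,8), (4,4)
```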
Loss Functions
Segmentation losses are computed per pixel; the variants below address class imbalance (Dice), hard examples (focal), and direct IoU optimization (Lovász-Softmax):
class SegmentationLosses:
"""Common loss functions for segmentation."""
@staticmethod
def cross_entropy(
pred: torch.Tensor,
target: torch.Tensor,
weight: Optional[torch.Tensor] = None,
ignore_index: int = 255
) -> torch.Tensor:
"""Standard cross-entropy loss."""
return F.cross_entropy(pred, target, weight=weight, ignore_index=ignore_index)
@staticmethod
def dice_loss(
pred: torch.Tensor,
target: torch.Tensor,
smooth: float = 1e-6
) -> torch.Tensor:
"""
Dice Loss for imbalanced segmentation.
Dice = 2|A∩B| / (|A| + |B|)
"""
pred = F.softmax(pred, dim=1)
num_classes = pred.size(1)
# One-hot encode target
target_one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
intersection = (pred * target_one_hot).sum(dim=(2, 3))
cardinality = pred.sum(dim=(2, 3)) + target_one_hot.sum(dim=(2, 3))
dice = (2 * intersection + smooth) / (cardinality + smooth)
return 1 - dice.mean()
@staticmethod
def focal_loss(
pred: torch.Tensor,
target: torch.Tensor,
alpha: float = 0.25,
gamma: float = 2.0
) -> torch.Tensor:
"""Focal Loss for hard example mining."""
ce_loss = F.cross_entropy(pred, target, reduction='none')
pt = torch.exp(-ce_loss)
focal_loss = alpha * (1 - pt) ** gamma * ce_loss
return focal_loss.mean()
@staticmethod
def lovasz_softmax(
pred: torch.Tensor,
target: torch.Tensor
) -> torch.Tensor:
"""
Lovász-Softmax loss.
Directly optimizes IoU.
"""
        probas = F.softmax(pred, dim=1)
        # Flatten to [B*H*W, C] probabilities and [B*H*W] labels so the sort
        # runs over all pixels, not just the last spatial dimension
        C = probas.size(1)
        probas_flat = probas.permute(0, 2, 3, 1).reshape(-1, C)
        target_flat = target.reshape(-1)
        losses = []
        for c in range(C):
            fg = (target_flat == c).float()
            if fg.sum() == 0:
                continue  # class absent from this batch
            errors = (fg - probas_flat[:, c]).abs()
            errors_sorted, perm = torch.sort(errors, descending=True)
            fg_sorted = fg[perm]
            losses.append(SegmentationLosses._lovasz_grad(fg_sorted) @ errors_sorted)
        return torch.stack(losses).mean()
@staticmethod
def _lovasz_grad(gt_sorted: torch.Tensor) -> torch.Tensor:
"""Lovász gradient."""
gts = gt_sorted.sum()
intersection = gts - gt_sorted.cumsum(0)
union = gts + (1 - gt_sorted).cumsum(0)
jaccard = 1 - intersection / union
jaccard[1:] = jaccard[1:] - jaccard[:-1]
return jaccard
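A quick comparison on random data, assuming the class above. The values themselves are meaningless on random inputs; the point is the shared call signature, with [B, C, H, W] logits and a [B, H, W] integer label map:

```python
pred = torch.randn(2, 5, 32, 32)              # [B, C, H, W] logits
target = torch.randint(0, 5, (2, 32, 32))     # [B, H, W] class indices

print('cross-entropy:', SegmentationLosses.cross_entropy(pred, target).item())
print('dice:', SegmentationLosses.dice_loss(pred, target).item())
print('focal:', SegmentationLosses.focal_loss(pred, target).item())
print('lovasz:', SegmentationLosses.lovasz_softmax(pred, target).item())
```

In practice, cross-entropy is commonly combined with Dice (as a weighted sum) on class-imbalanced datasets.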
Exercises
Exercise 1: Implement Boundary Loss
Create a loss that focuses on boundaries:
def boundary_loss(pred, target):
# Extract boundaries using Sobel/Laplacian
# Weight loss higher at boundaries
pass
Exercise 2: Multi-Scale Inference
Implement test-time augmentation for segmentation:
- Inference at multiple scales
- Flip augmentation
- Ensemble predictions
Exercise 3: CRF Post-Processing
Add Conditional Random Field refinement:
# Use pydensecrf or implement mean-field approximation
# Refine predictions based on appearance consistency