Residual Networks

Residual & Skip Connections

The Depth Problem

In 2015, researchers tried training networks with 100+ layers. They expected deeper = better. What actually happened: the deeper networks performed WORSE than shallower ones! This wasn't overfitting; even the training error was higher. The deep plain networks simply couldn't be optimized.
The Degradation Problem: adding more layers can hurt performance even though the deeper network has strictly more capacity. A large part of the problem is that gradient signals shrink (or blow up) as they pass back through many stacked layers, so the early layers barely receive a learning signal.
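You can see this directly. Below is a quick, self-contained sketch (not from the original text; the depth, width, and tanh nonlinearity are arbitrary choices) that measures the gradient reaching the input of a 50-layer plain stack versus the same layers wired with additive skips:

import torch
import torch.nn as nn

depth, width = 50, 64
layers = [nn.Linear(width, width) for _ in range(depth)]

def forward_plain(x):
    # Plain stack: each layer feeds the next, no skips
    for layer in layers:
        x = torch.tanh(layer(x))
    return x

def forward_residual(x):
    # Same weights, but each layer's output is added to its input
    for layer in layers:
        x = x + torch.tanh(layer(x))
    return x

x = torch.randn(8, width, requires_grad=True)
for name, fn in [("plain", forward_plain), ("residual", forward_residual)]:
    x.grad = None
    fn(x).sum().backward()
    print(name, x.grad.abs().mean().item())
# Typically the plain stack's input gradient is many orders of magnitude smaller.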

The Residual Insight

The solution is beautifully simple: skip connections. Instead of learning the desired mapping $H(x)$ directly, each block learns the residual $F(x) = H(x) - x$, so its output is $\text{Output} = F(x) + x$.
Input (x)

    ├───────────────────────────────┐
    │                               │
    ▼                               │
┌────────┐                          │
│ Conv   │                          │
└────────┘                          │
    │                               │
    ▼                               │
┌────────┐                          │
│ Conv   │                          │  Skip Connection
└────────┘                          │
    │                               │
    ▼                               │
    (+) ◄────────────────────────────┘
     │
     ▼
Output (F(x) + x)

Why This Works

  1. Identity is easy: if the optimal mapping is close to the identity, the block only has to push $F(x)$ toward zero, which is easy for the optimizer
  2. Gradient highway: gradients flow back through the skip connection unchanged (made precise just below)
  3. Ensemble effect: the network behaves like an ensemble of many shorter paths, and each block can refine its input or effectively be skipped
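
To make the "gradient highway" concrete: with $y = F(x) + x$, the chain rule gives $\frac{\partial \mathcal{L}}{\partial x} = \frac{\partial \mathcal{L}}{\partial y}\left(I + \frac{\partial F}{\partial x}\right)$, so even when $\frac{\partial F}{\partial x}$ is tiny, the identity term carries the upstream gradient back unchanged. A basic residual block in PyTorch: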
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Basic residual block."""
    
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        identity = x
        
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        
        out += identity  # Skip connection!
        out = self.relu(out)
        
        return out
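
A quick sanity check: the block preserves both spatial size and channel count, which is exactly what makes the elementwise $F(x) + x$ valid.

block = ResidualBlock(64)
x = torch.randn(1, 64, 32, 32)
print(block(x).shape)  # torch.Size([1, 64, 32, 32]) -- same shape in and out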


class BottleneckBlock(nn.Module):
    """Bottleneck block for deeper networks (ResNet-50+)."""
    
    expansion = 4
    
    def __init__(self, in_channels, bottleneck_channels):
        super().__init__()
        out_channels = bottleneck_channels * self.expansion
        
        # 1x1 reduce
        self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, 1)
        self.bn1 = nn.BatchNorm2d(bottleneck_channels)
        
        # 3x3 convolve
        self.conv2 = nn.Conv2d(bottleneck_channels, bottleneck_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)
        
        # 1x1 expand
        self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, 1)
        self.bn3 = nn.BatchNorm2d(out_channels)
        
        self.relu = nn.ReLU(inplace=True)
        
        # Projection shortcut if dimensions change
        self.shortcut = nn.Identity()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        identity = self.shortcut(x)
        
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        
        out += identity
        out = self.relu(out)
        
        return out
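
The point of the bottleneck is parameter efficiency at large widths: squeezing to 64 channels before the 3x3 keeps the block cheap even when the residual stream is 256 channels wide. A rough comparison (the 256/64 widths match ResNet-50's first stage; counts are approximate):

block = BottleneckBlock(256, 64)
x = torch.randn(2, 256, 56, 56)
print(block(x).shape)                                            # torch.Size([2, 256, 56, 56])
print(sum(p.numel() for p in block.parameters()))                # ~70k parameters
print(sum(p.numel() for p in ResidualBlock(256).parameters()))   # ~1.2M for two 3x3 convs at width 256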

ResNet Architecture
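
The ResNet-18 below builds each stage from a ResidualBlockWithDownsample, which the text references but does not define. A minimal sketch under the obvious interpretation: the basic block above, plus a stride on the first conv and a 1x1 projection shortcut whenever the shape changes.

class ResidualBlockWithDownsample(nn.Module):
    """Basic residual block whose first conv may stride, with a projection
    shortcut when the input and output shapes differ (sketch, not from the text)."""
    
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        
        # Projection shortcut if spatial size or channel count changes
        self.shortcut = nn.Identity()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)
        return self.relu(out)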

class ResNet18(nn.Module):
    """Simplified ResNet-18."""
    
    def __init__(self, num_classes=1000):
        super().__init__()
        
        # Initial convolution
        self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        
        # Residual blocks
        self.layer1 = self._make_layer(64, 64, 2)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        self.layer4 = self._make_layer(256, 512, 2, stride=2)
        
        # Classifier
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)
    
    def _make_layer(self, in_channels, out_channels, num_blocks, stride=1):
        layers = []
        
        # First block may downsample
        layers.append(self._make_block(in_channels, out_channels, stride))
        
        # Remaining blocks
        for _ in range(1, num_blocks):
            layers.append(self._make_block(out_channels, out_channels))
        
        return nn.Sequential(*layers)
    
    def _make_block(self, in_channels, out_channels, stride=1):
        return ResidualBlockWithDownsample(in_channels, out_channels, stride)
    
    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = x.flatten(1)
        x = self.fc(x)
        
        return x
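
A smoke test, assuming the downsampling block sketched above:

model = ResNet18(num_classes=10)
x = torch.randn(2, 3, 224, 224)
print(model(x).shape)  # torch.Size([2, 10])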

DenseNet: Dense Connections

Instead of adding, DenseNet concatenates all previous features:
class DenseBlock(nn.Module):
    """Dense block with growth rate k."""
    
    def __init__(self, in_channels, growth_rate, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()
        
        for i in range(num_layers):
            self.layers.append(nn.Sequential(
                nn.BatchNorm2d(in_channels + i * growth_rate),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels + i * growth_rate, growth_rate, 3, padding=1)
            ))
    
    def forward(self, x):
        features = [x]
        
        for layer in self.layers:
            out = layer(torch.cat(features, dim=1))
            features.append(out)
        
        return torch.cat(features, dim=1)
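
Each layer concatenates its growth_rate new channels onto everything before it, so the block emits in_channels + num_layers * growth_rate channels:

block = DenseBlock(in_channels=64, growth_rate=32, num_layers=4)
x = torch.randn(1, 64, 32, 32)
print(block(x).shape)  # torch.Size([1, 192, 32, 32]) -- 64 + 4 * 32 channels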

U-Net: Skip Connections for Segmentation

U-Net combines encoder-decoder with skip connections for pixel-level predictions:
class UNet(nn.Module):
    """U-Net for image segmentation."""
    
    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        
        # Encoder
        self.enc1 = self._block(in_channels, 64)
        self.enc2 = self._block(64, 128)
        self.enc3 = self._block(128, 256)
        self.enc4 = self._block(256, 512)
        
        self.pool = nn.MaxPool2d(2)
        
        # Bottleneck
        self.bottleneck = self._block(512, 1024)
        
        # Decoder
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = self._block(1024, 512)  # 512 + 512 from skip
        
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = self._block(512, 256)
        
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = self._block(256, 128)
        
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = self._block(128, 64)
        
        self.out = nn.Conv2d(64, out_channels, 1)
    
    def _block(self, in_ch, out_ch):
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        # Encoder path (save for skip connections)
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        
        # Bottleneck
        b = self.bottleneck(self.pool(e4))
        
        # Decoder path (concatenate skip connections)
        d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        
        return self.out(d1)
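
A quick shape check: the input's height and width should be divisible by 16 so the four poolings and transposed convolutions line up with the skip connections.

model = UNet(in_channels=3, out_channels=1)
x = torch.randn(1, 3, 256, 256)
print(model(x).shape)  # torch.Size([1, 1, 256, 256]) -- one logit per pixel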
[Figure: U-Net architecture]

Comparison

Architecture    Connection Type    Best For
ResNet          Add                Image classification
DenseNet        Concatenate        Feature reuse, fewer params
U-Net           Skip + Concat      Segmentation
Highway         Gated add          Sequence modeling
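
Highway connections (last row of the table) replace the plain sum with a learned gate that decides, per unit, how much of the transform versus the input to pass through. A minimal sketch of one highway layer, not tied to any particular library implementation:

class HighwayLayer(nn.Module):
    """y = g * H(x) + (1 - g) * x, with a learned gate g (sketch)."""
    
    def __init__(self, dim):
        super().__init__()
        self.transform = nn.Linear(dim, dim)
        self.gate = nn.Linear(dim, dim)
    
    def forward(self, x):
        h = torch.relu(self.transform(x))   # candidate transform H(x)
        g = torch.sigmoid(self.gate(x))     # per-unit gate in (0, 1)
        return g * h + (1 - g) * x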

Exercises

  1. Compare gradient magnitudes at early layers for a 50-layer network with and without skip connections.
  2. Implement ResNet-50 and ResNet-101 using bottleneck blocks.
  3. Train U-Net on a simple segmentation task (e.g., cell segmentation).

What’s Next