# Pooling, Stride & CNN Design

> Master modern CNN architectures - VGG, ResNet, EfficientNet - and learn the design principles behind effective convolutional networks

<Frame>
  <img src="https://mintlify.s3.us-west-1.amazonaws.com/devweeekends/images/courses/deep-learning-mastery/cnn-design-concept.svg" alt="CNN Design Principles" />
</Frame>

## From Building Blocks to Architectures

In the previous chapter, we learned about convolutions - the fundamental operation that makes CNNs powerful. Now we'll explore how to **combine these building blocks** into effective architectures.

Think of it like LEGO: knowing what a brick is doesn't make you an architect. Understanding **how to stack bricks** into stable, beautiful structures does. The same convolution operation, arranged differently, gives you VGG (brute force depth), ResNet (skip connections), or EfficientNet (smart scaling). The building block is the same -- the architecture is everything.

<Note>
  **The Evolution of CNNs**: From LeNet-5's seven layers in 1998 to modern architectures with 1000+ layers, CNN design has undergone a remarkable evolution. Each breakthrough taught us new principles about what makes networks learn better.
</Note>

***

## Deep Dive: Pooling Operations

### Why Pooling Matters

Pooling serves three critical purposes:

1. **Dimensionality Reduction**: Reduce computational cost and memory
2. **Translation Invariance**: Small shifts in the input (within a pooling window) barely change the output
3. **Feature Summarization**: Capture the "presence" of a feature, not its exact location

<Frame>
  <img src="https://mintlify.s3.us-west-1.amazonaws.com/devweeekends/images/courses/deep-learning-mastery/pooling-purposes.svg" alt="Pooling Purposes" />
</Frame>
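A small sketch with PyTorch's `nn.MaxPool2d` makes the "local" part of that invariance concrete. A single activation spike (built by a hypothetical `spike` helper) is moved around an 8×8 map and pooled with a 2×2 window and stride 2:

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, stride=2)

def spike(row, col, size=8):
    """Hypothetical helper: an (N=1, C=1, size, size) map with a single 1.0."""
    x = torch.zeros(1, 1, size, size)
    x[0, 0, row, col] = 1.0
    return x

a = pool(spike(2, 2))  # lands in pooling window (1, 1)
b = pool(spike(3, 3))  # shifted by one pixel, still inside window (1, 1)
c = pool(spike(2, 4))  # shifted by two columns, now in window (1, 2)

print(a.shape)            # torch.Size([1, 1, 4, 4])
print(torch.equal(a, b))  # True: same pooled output
print(torch.equal(a, c))  # False: the feature moved to a different cell
```

Moving the spike inside its 2×2 window leaves the pooled map unchanged, but moving it into a neighboring window does not: pooling buys tolerance to small shifts, not full translation invariance.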

### Max Pooling: The Dominant Choice

Max pooling takes the maximum value in each window -- like asking "Is this feature present anywhere in this region?" If a strong edge exists in any of the four pixels of a 2x2 window, the max pool preserves it. The exact position is lost, but the presence is retained. This is why max pooling provides a limited, local form of *translation invariance*: a feature can shift within its pooling window without changing the pooled output, although a shift that crosses a window boundary can still change it.

Think of max pooling like scanning a crowd for your friend's red hat. You divide the crowd into sections and for each section you just note: "Is there a red hat here? Yes or no, and how bright?" You don't record the exact seat number -- just the strongest signal per section. That is exactly what max pooling does to feature maps.

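$$
\text{MaxPool}_{i,j} = \max_{m,n \in \mathcal{R}_{i,j}} X_{m,n}
$$

where $\mathcal{R}_{i,j}$ is the set of input positions covered by the pooling window that produces output position $(i, j)$.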

```python theme={null}
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Manual max pooling implementation
def max_pool_2d(x, kernel_size=2, stride=2):
    """
    Manual 2D max pooling implementation.
    
    Why implement from scratch? Understanding pooling mechanics helps
    debug spatial dimension mismatches -- one of the most common CNN errors.
    
    Args:
        x: NumPy array of shape (H, W) or (C, H, W)
        kernel_size: Size of pooling window
        stride: Step between windows (usually == kernel_size to avoid overlap)
    
    Returns:
        Pooled output tensor
    """
    if x.ndim == 2:
        x = x[np.newaxis, ...]  # Add channel dimension
    
    C, H, W = x.shape
    out_H = (H - kernel_size) // stride + 1
    out_W = (W - kernel_size) // stride + 1
    
    output = np.zeros((C, out_H, out_W))
    
    for c in range(C):
        for i in range(out_H):
            for j in range(out_W):
                h_start = i * stride
                w_start = j * stride
                window = x[c, h_start:h_start+kernel_size, w_start:w_start+kernel_size]
                output[c, i, j] = np.max(window)
    
    return output.squeeze() if output.shape[0] == 1 else output


# Example: Visualize max pooling effect
np.random.seed(42)
feature_map = np.random.rand(8, 8)

pooled = max_pool_2d(feature_map, kernel_size=2, stride=2)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))

axes[0].imshow(feature_map, cmap='viridis')
axes[0].set_title(f'Original: {feature_map.shape}')
axes[0].set_xlabel('Width')
axes[0].set_ylabel('Height')

axes[1].imshow(pooled, cmap='viridis')
axes[1].set_title(f'After 2×2 Max Pool: {pooled.shape}')
axes[1].set_xlabel('Width')
axes[1].set_ylabel('Height')

plt.tight_layout()
plt.show()

print(f"Reduction: {feature_map.size} → {pooled.size} = {feature_map.size/pooled.size}× fewer values")
```
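The triple loop above is easy to follow but slow. When `kernel_size == stride` and the height and width divide evenly, non-overlapping pooling can instead be written as a reshape followed by a reduction. The sketch below (`max_pool_reshape` is a hypothetical helper, not a library function) uses that trick and checks the result against PyTorch's `torch.nn.functional.max_pool2d`:

```python
import numpy as np
import torch
import torch.nn.functional as F

def max_pool_reshape(x, k=2):
    """Non-overlapping k×k max pooling on a (C, H, W) array.

    Assumes kernel_size == stride == k and that H and W are divisible by k.
    Element x[c, i*k + a, j*k + b] maps to index (c, i, a, j, b) after the
    reshape, so reducing over axes 2 and 4 takes the max of each window.
    """
    C, H, W = x.shape
    return x.reshape(C, H // k, k, W // k, k).max(axis=(2, 4))

rng = np.random.default_rng(0)
x = rng.random((3, 8, 8)).astype(np.float32)

ours = max_pool_reshape(x)
ref = F.max_pool2d(torch.from_numpy(x).unsqueeze(0), kernel_size=2, stride=2)
ref = ref.squeeze(0).numpy()

print(ours.shape)              # (3, 4, 4)
print(np.allclose(ours, ref))  # True
```

Swapping `.max` for `.mean` gives the equivalent average pooling under the same assumptions.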

### Average Pooling: Smooth Aggregation

Average pooling computes the mean - useful when you care about the overall "intensity" of activations:

$$
\text{AvgPool}_{i,j} = \frac{1}{|\mathcal{R}_{i,j}|} \sum_{m,n \in \mathcal{R}_{i,j}} X_{m,n}
$$

```python theme={null}
def avg_pool_2d(x, kernel_size=2, stride=2):
    """2D average pooling implementation."""
    if x.ndim == 2:
        x = x[np.newaxis, ...]
    
    C, H, W = x.shape
    out_H = (H - kernel_size) // stride + 1
    out_W = (W - kernel_size) // stride + 1
    
    output = np.zeros((C, out_H, out_W))
    
    for c in range(C):
        for i in range(out_H):
            for j in range(out_W):
                h_start = i * stride
                w_start = j * stride
                window = x[c, h_start:h_start+kernel_size, w_start:w_start+kernel_size]
                output[c, i, j] = np.mean(window)  # Mean instead of max
    
    return output.squeeze() if output.shape[0] == 1 else output


# Compare max vs average pooling
avg_pooled = avg_pool_2d(feature_map, kernel_size=2, stride=2)
max_pooled = max_pool_2d(feature_map, kernel_size=2, stride=2)

print(f"Max pooled values: {max_pooled.flatten()[:4]}")
print(f"Avg pooled values: {avg_pooled.flatten()[:4]}")
print("Max values are always >= average values:", bool(np.all(max_pooled >= avg_pooled)))
```

### Global Average Pooling (GAP): The Modern Approach

Global Average Pooling reduces each feature map to a single value. In modern CNNs it replaces the flatten-plus-large-FC stack at the end of the network, usually leaving just one linear classifier on top:

$$
\text{GAP}(X_c) = \frac{1}{H \times W} \sum_{i=1}^{H}\sum_{j=1}^{W} X_{c,i,j}
$$

```python theme={null}
class GlobalAveragePooling(nn.Module):
    """Global Average Pooling layer."""
    
    def forward(self, x):
        # x: (batch, channels, height, width)
        return x.mean(dim=[2, 3])  # Average over spatial dimensions


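# PyTorch's built-in alternative: adaptive pooling to a 1×1 output
gap = nn.AdaptiveAvgPool2d(1)

x = torch.randn(16, 512, 7, 7)  # Batch of 16, 512 channels, 7×7 spatial
output = gap(x).flatten(1)      # (16, 512, 1, 1) → (16, 512)

# The hand-written module and the built-in layer compute the same thing
assert torch.allclose(output, GlobalAveragePooling()(x), atol=1e-6)

print(f"Input: {x.shape}")
print(f"After GAP: {output.shape}")
print(f"Parameters: {sum(p.numel() for p in gap.parameters())} (GAP is parameter-free!)")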
```

<Warning>
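  **GAP vs Fully Connected**: Traditional CNNs ended with large FC layers. In VGG-16, the three FC layers hold about 124 million of its roughly 138 million parameters; the first one alone (25,088 → 4,096) has about 103 million, and the 4,096 → 4,096 layer about 16.8 million. GAP itself has zero parameters, so the head shrinks to a single small linear classifier (512 → 1,000 is about 0.5 million parameters), often with better generalization.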
</Warning>
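The parameter counts behind that comparison are plain arithmetic. This sketch assumes an ImageNet-style 1,000-class head on top of a 512×7×7 feature map, as in VGG-16; `linear_params` is a hypothetical helper:

```python
def linear_params(n_in, n_out, bias=True):
    """Weights plus (optional) biases of a fully connected layer."""
    return n_in * n_out + (n_out if bias else 0)

# VGG-16 classifier: flatten 512×7×7 → 4096 → 4096 → 1000
vgg_fc = [linear_params(512 * 7 * 7, 4096),
          linear_params(4096, 4096),
          linear_params(4096, 1000)]

# GAP head: 512×7×7 → global average (no parameters) → 512 → 1000
gap_head = 0 + linear_params(512, 1000)

print(f"VGG FC layers: {[f'{p:,}' for p in vgg_fc]}, total {sum(vgg_fc):,}")
print(f"GAP head:      {gap_head:,}")
print(f"Ratio:         {sum(vgg_fc) / gap_head:.0f}x")
```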

### Pooling Comparison Table

| Pooling Type        | Formula                | Best For                       | Downsides                                        |
| ------------------- | ---------------------- | ------------------------------ | ------------------------------------------------ |
| **Max Pooling**     | $\max(x)$              | Detecting presence of features | Loses spatial info, gradient to one element only |
| **Average Pooling** | $\text{mean}(x)$       | Smooth downsampling            | May dilute strong activations                    |
| **Global Average**  | $\text{mean over all}$ | Classification heads           | Loses all spatial info                           |
| **Strided Conv**    | Learned                | Modern architectures           | More parameters                                  |
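
The "gradient to one element only" entry is easy to verify with autograd. In this sketch, a single 2×2 window is pooled both ways and the gradient of the pooled value with respect to each input pixel is inspected:

```python
import torch
import torch.nn.functional as F

# One 2×2 window; the pooled output is a single value
x = torch.tensor([[[[1.0, 3.0],
                    [2.0, 0.5]]]], requires_grad=True)

F.max_pool2d(x, kernel_size=2).sum().backward()
max_grad = x.grad.clone()

x.grad = None  # reset before the second backward pass
F.avg_pool2d(x, kernel_size=2).sum().backward()
avg_grad = x.grad.clone()

print(max_grad)  # 1.0 at the position of the max (3.0), zeros elsewhere
print(avg_grad)  # 0.25 everywhere
```

Max pooling routes the entire gradient to the winning pixel, while average pooling spreads it evenly, 1/4 per pixel.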

***

## Understanding Stride

### Stride as Downsampling

Stride controls how much the convolution kernel "jumps" between positions:

```python theme={null}
def visualize_stride_effect():
    """Visualize how stride affects output dimensions."""
    
    x = torch.randn(1, 3, 32, 32)
    
    conv_s1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
    conv_s2 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
    conv_s4 = nn.Conv2d(3, 64, kernel_size=3, stride=4, padding=1)
    
    print("Input:", x.shape)
    print("Stride 1:", conv_s1(x).shape)  # Same spatial size
    print("Stride 2:", conv_s2(x).shape)  # Half the size
    print("Stride 4:", conv_s4(x).shape)  # Quarter the size
    
visualize_stride_effect()
```

### The Output Size Formula

For any convolutional or pooling layer (assuming dilation 1 and PyTorch's default floor rounding, i.e. `ceil_mode=False` for pooling):

$$
\text{Output Size} = \left\lfloor \frac{\text{Input Size} - \text{Kernel Size} + 2 \times \text{Padding}}{\text{Stride}} \right\rfloor + 1
$$

```python theme={null}
def compute_output_size(input_size, kernel_size, stride, padding):
    """
    Compute output spatial dimensions for conv/pool layers.
    
    Args:
        input_size: Input height or width
        kernel_size: Size of kernel/filter
        stride: Step size
        padding: Zero padding
    
    Returns:
        Output size
    """
    return (input_size - kernel_size + 2 * padding) // stride + 1


# Examples
print("32x32 input, 3x3 kernel, stride 1, padding 1:", 
      compute_output_size(32, 3, 1, 1))  # 32

print("32x32 input, 3x3 kernel, stride 2, padding 1:", 
      compute_output_size(32, 3, 2, 1))  # 16

print("224x224 input, 7x7 kernel, stride 2, padding 3:", 
      compute_output_size(224, 7, 2, 3))  # 112 (ResNet first layer)


def compute_network_sizes(input_size, layers):
    """
    Compute sizes through a sequence of layers.
    
    Args:
        input_size: Initial spatial size (H=W assumed)
        layers: List of (kernel, stride, padding) tuples
    
    Returns:
        List of output sizes at each layer
    """
    sizes = [input_size]
    current = input_size
    
    for kernel, stride, padding in layers:
        current = compute_output_size(current, kernel, stride, padding)
        sizes.append(current)
    
    return sizes


# Trace sizes through a simple network
layers = [
    (3, 1, 1),   # Conv: 32 → 32
    (2, 2, 0),   # Pool: 32 → 16
    (3, 1, 1),   # Conv: 16 → 16
    (2, 2, 0),   # Pool: 16 → 8
    (3, 1, 1),   # Conv: 8 → 8
    (2, 2, 0),   # Pool: 8 → 4
]

sizes = compute_network_sizes(32, layers)
print("\nSize progression:", " → ".join(map(str, sizes)))
```
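As a sanity check, the formula can be compared against the shapes PyTorch actually produces over a grid of sizes, kernels, strides, and paddings. This sketch assumes dilation 1 and the default `ceil_mode=False`; with `ceil_mode=True`, pooling rounds up instead of down. The grid values are arbitrary, and `output_size` is a hypothetical stand-in for the formula:

```python
import itertools
import torch
import torch.nn as nn

def output_size(n, k, s, p):
    """Floor-based output size formula (dilation 1)."""
    return (n - k + 2 * p) // s + 1

mismatches = []
for n, k, s, p in itertools.product([7, 28, 32, 33], [1, 3, 5], [1, 2, 3], [0, 1, 2]):
    x = torch.zeros(1, 1, n, n)
    expected = output_size(n, k, s, p)

    conv_out = nn.Conv2d(1, 1, k, stride=s, padding=p)(x).shape[-1]
    if conv_out != expected:
        mismatches.append(("conv", n, k, s, p, conv_out, expected))

    # MaxPool2d only accepts padding <= kernel_size // 2
    if p <= k // 2:
        pool_out = nn.MaxPool2d(k, stride=s, padding=p)(x).shape[-1]
        if pool_out != expected:
            mismatches.append(("pool", n, k, s, p, pool_out, expected))

print(f"Mismatches: {len(mismatches)}")
```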

### Stride vs Pooling for Downsampling

A recurring design question in modern architectures: **strided convolution or pooling?**

```python theme={null}
class DownsampleComparison(nn.Module):
    """Compare different downsampling strategies."""
    
    def __init__(self, in_channels, out_channels):
        super().__init__()
        
        # Option 1: Conv + Max Pool
        self.conv_pool = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        
        # Option 2: Strided Convolution
        self.strided_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1),
            nn.ReLU()
        )
        
        # Option 3: Conv + Average Pool
        self.conv_avgpool = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(),
            nn.AvgPool2d(2, 2)
        )
    
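    def forward(self, x, method='strided'):
        if method == 'maxpool':
            return self.conv_pool(x)
        elif method == 'strided':
            return self.strided_conv(x)
        elif method == 'avgpool':
            return self.conv_avgpool(x)
        raise ValueError(f"Unknown method: {method!r}")


# Compare parameter counts and output shapes
model = DownsampleComparison(64, 128)
x = torch.randn(1, 64, 32, 32)

def count_params(module):
    return sum(p.numel() for p in module.parameters())

# All three variants have identical parameter counts (one 3×3 conv each);
# they differ in compute: the strided conv evaluates its kernel at only a
# quarter of the spatial positions, while conv + pool runs at full resolution.
for method, branch in [('maxpool', model.conv_pool),
                       ('strided', model.strided_conv),
                       ('avgpool', model.conv_avgpool)]:
    out = model(x, method=method)
    print(f"{method:<8} params: {count_params(branch):,}  output: {tuple(out.shape)}")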
```
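Parameter count is only half the story. A rough multiply-accumulate (MAC) count, using a hypothetical `conv_macs` helper and ignoring bias, ReLU, and the pooling itself, shows why strided convolution is cheaper for a 64→128 channel layer on a 32×32 input:

```python
def conv_macs(h_out, w_out, k, c_in, c_out):
    """Multiply-accumulates of a dense k×k convolution (bias ignored)."""
    return h_out * w_out * k * k * c_in * c_out

H = W = 32
c_in, c_out = 64, 128

# Conv + pool: the 3×3 conv runs at full 32×32 resolution, then pooling halves it
conv_then_pool = conv_macs(H, W, 3, c_in, c_out)
# Strided conv: the 3×3 conv is evaluated directly at the 16×16 output positions
strided = conv_macs(H // 2, W // 2, 3, c_in, c_out)

print(f"Conv + pool MACs:  {conv_then_pool:,}")   # 75,497,472
print(f"Strided conv MACs: {strided:,}")          # 18,874,368
print(f"Ratio: {conv_then_pool / strided:.1f}x")  # 4.0x
```

The parameters are identical, but conv-then-pool evaluates the 3×3 kernel at all 32×32 positions and then keeps only a quarter of the results, so it costs about 4× more compute.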

<Tip>
  **Modern Trend**: Apart from the max pool in its stem, ResNet downsamples with strided convolutions, as do many later architectures. Strided convolutions are learnable: the network can learn *what* information to preserve during downsampling, not just *where* the maximum is.
</Tip>

***

## Classic CNN Architectures

### VGGNet: The Power of Depth

VGG's key insight: **Use many small (3x3) filters instead of few large ones.**

Two stacked 3x3 convolutions have the same receptive field as one 5x5 (and three stacked 3x3s match one 7x7), but with:

* Fewer parameters: $2 \times (3^2 \times C^2)$ vs $5^2 \times C^2$, i.e. $18C^2$ vs $25C^2$ -- 28% fewer parameters for the same field of view
* More non-linearity: Two ReLU activations instead of one, giving the network more expressive power per receptive field size
* Implicit regularization: the stack can only express large filters that factor into small ones with a non-linearity in between, which the VGG authors describe as a regularizing constraint

This is a general principle that shows up repeatedly in CNN design: **decompose one large operation into multiple smaller ones**. You trade kernel size for depth, and you usually come out ahead.
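
The receptive-field claim can be checked with the standard recurrence for stacked layers: each layer with kernel $k$ grows the receptive field by $(k-1)$ times the current cumulative stride (the "jump"). This sketch assumes dilation 1; `receptive_field` and `conv_params` are hypothetical helpers, and `conv_params` counts weights only, with equal input and output channels:

```python
def receptive_field(layers):
    """Receptive field of a stack of (kernel, stride) layers, dilation 1.

    r grows by (k - 1) * j per layer, where j is the cumulative stride.
    """
    r, j = 1, 1
    for k, s in layers:
        r += (k - 1) * j
        j *= s
    return r

def conv_params(k, c, n_layers=1):
    """Weights of n_layers k×k convs with c input and c output channels."""
    return n_layers * k * k * c * c

C = 64
print(receptive_field([(3, 1)] * 2), receptive_field([(5, 1)]))  # 5 5
print(receptive_field([(3, 1)] * 3), receptive_field([(7, 1)]))  # 7 7
print(conv_params(3, C, 2) / conv_params(5, C))  # 0.72  (28% fewer)
print(conv_params(3, C, 3) / conv_params(7, C))  # ~0.55 (45% fewer)

# Downsampling inside the stack widens later layers' view:
# conv3 → pool2 (stride 2) → conv3 sees an 8×8 input patch
print(receptive_field([(3, 1), (2, 2), (3, 1)]))  # 8
```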

```python theme={null}
class VGGBlock(nn.Module):
    """A VGG-style convolutional block."""
    
    def __init__(self, in_channels, out_channels, num_convs):
        super().__init__()
        
        layers = []
        for i in range(num_convs):
            layers.append(
                nn.Conv2d(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    kernel_size=3,
                    padding=1
                )
            )
            layers.append(nn.ReLU(inplace=True))
        
        layers.append(nn.MaxPool2d(2, 2))
        
        self.block = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.block(x)


class VGG16(nn.Module):
    """
    VGG-16 architecture.
    
    Configuration: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 
                    512, 512, 512, 'M', 512, 512, 512, 'M']
    
    Total: 13 conv layers + 3 FC layers = 16 weight layers
    """
    
    def __init__(self, num_classes=1000):
        super().__init__()
        
        self.features = nn.Sequential(
            # Block 1: 224 → 112
            VGGBlock(3, 64, num_convs=2),
            
            # Block 2: 112 → 56
            VGGBlock(64, 128, num_convs=2),
            
            # Block 3: 56 → 28
            VGGBlock(128, 256, num_convs=3),
            
            # Block 4: 28 → 14
            VGGBlock(256, 512, num_convs=3),
            
            # Block 5: 14 → 7
            VGGBlock(512, 512, num_convs=3),
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, num_classes),
        )
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


# Analyze VGG16
model = VGG16()
total_params = sum(p.numel() for p in model.parameters())
print(f"VGG16 Total Parameters: {total_params:,}")

# Parameter breakdown
features_params = sum(p.numel() for p in model.features.parameters())
classifier_params = sum(p.numel() for p in model.classifier.parameters())
print(f"  Features (Conv layers): {features_params:,} ({100*features_params/total_params:.1f}%)")
print(f"  Classifier (FC layers): {classifier_params:,} ({100*classifier_params/total_params:.1f}%)")
```

<Frame>
  <img src="https://mintlify.s3.us-west-1.amazonaws.com/devweeekends/images/courses/deep-learning-mastery/vgg-architecture.svg" alt="VGG Architecture" />
</Frame>

***

## ResNet: The Skip Connection Revolution

### The Degradation Problem

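As networks get deeper, they should perform at least as well as shallower ones (the extra layers could just learn the identity). But in practice, beyond a certain depth, deeper *plain* networks (no skip connections) performed **worse**, on the training set as well as the test set: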

```python theme={null}
def demonstrate_degradation_problem():
    """
    Illustrate the degradation problem that ResNet solved.
    
    Deeper networks were harder to train, even with batch normalization.
    NOTE: the numbers below are illustrative placeholders, not measured
    results; they only show a deeper plain network doing worse than a
    shallower one, as reported for plain networks in the ResNet paper.
    """
    illustrative_errors = {
        '20-layer': {'train_error': 8.5, 'test_error': 8.8},
        '32-layer': {'train_error': 7.9, 'test_error': 8.2},
        '56-layer': {'train_error': 9.0, 'test_error': 9.5},  # Worse!
    }
    
    print("Plain Network Performance (Before ResNet, illustrative numbers):")
    print("-" * 50)
    for depth, errors in illustrative_errors.items():
        print(f"{depth}: Train Error = {errors['train_error']}%, "
              f"Test Error = {errors['test_error']}%")
    
    print("\nKey Insight: 56-layer performs WORSE than 20-layer!")
    print("This isn't overfitting - training error is also higher.")
    print("The optimizer failed to find weights as good as the shallower net's,")
    print("even though they exist (copy the 20 layers, make the rest identity).")

demonstrate_degradation_problem()
```

### The Residual Connection

ResNet's solution: Instead of learning $H(x)$, learn the **residual** $F(x) = H(x) - x$:

$$
\mathbf{y} = F(\mathbf{x}, \{W_i\}) + \mathbf{x}
$$

**Why this is mathematically profound**: Consider the gradient flow. For a plain network, the gradient must pass through every layer's transformation. For a residual block, the gradient of the output $y$ with respect to the input $x$ is:

$$
\frac{\partial y}{\partial x} = \frac{\partial F}{\partial x} + I
$$

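That $+ I$ (identity matrix) is everything. In a plain network, the gradient reaching an early layer is a product of one Jacobian per layer, so across 100 layers it can shrink (or explode) exponentially. Stacking residual blocks instead multiplies factors of the form $(I + \partial F / \partial x)$; expanding that product gives a sum of paths, one of which is the pure identity. That identity path gives gradients an "express lane" that skips the long multiplicative chain. (Strictly, the original ResNet applies a ReLU after each addition, so the shortcut is only close to identity; the later pre-activation ResNet variant makes it an exact identity.)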

Think of it like building a highway system. A plain network forces every car (gradient) to drive through every small town (layer) on the route. A ResNet builds an expressway alongside the local roads -- traffic can take either path. Even if the local roads are congested (saturated activations), the expressway keeps things flowing.
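
A toy experiment makes the gradient argument tangible. This sketch uses fully connected `nn.Linear` + `tanh` layers as stand-ins for convolutional blocks (a simplifying assumption; depth 30 and width 64 are arbitrary) and measures the gradient norm that reaches the input, with and without skip connections:

```python
import torch
import torch.nn as nn

def input_grad_norm(depth=30, width=64, residual=False, seed=0):
    """Norm of d(sum of outputs)/d(input) through a stack of Linear+tanh layers.

    Linear layers stand in for conv blocks here (a simplifying assumption).
    """
    torch.manual_seed(seed)
    layers = [nn.Linear(width, width) for _ in range(depth)]
    x = torch.randn(8, width, requires_grad=True)
    h = x
    for layer in layers:
        f = torch.tanh(layer(h))
        h = h + f if residual else f  # skip connection vs plain stacking
    h.sum().backward()
    return x.grad.norm().item()

plain = input_grad_norm(residual=False)
res = input_grad_norm(residual=True)
print(f"Plain stack:    {plain:.2e}")
print(f"Residual stack: {res:.2e}")
```

Exact values depend on the seed and initialization, but the plain stack's input gradient collapses by several orders of magnitude, while the residual stack's does not vanish (in this unnormalized toy setup it can even grow somewhat).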

<Frame>
  <img src="https://mintlify.s3.us-west-1.amazonaws.com/devweeekends/images/courses/deep-learning-mastery/residual-block.svg" alt="Residual Block" />
</Frame>

```python theme={null}
class ResidualBlock(nn.Module):
    """
    Basic Residual Block for ResNet-18/34.
    
    The key insight: If identity mapping is optimal,
    it's easier for the network to push F(x) -> 0
    than to learn F(x) -> x in a plain network.
    
    Analogy: Imagine editing a document. A plain network has to
    rewrite the entire document from scratch at each layer.
    A residual network only writes the *changes* (the diff).
    Learning "change nothing" (output zeros) is trivial;
    learning "copy everything perfectly" is surprisingly hard.
    """
    
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, 
            kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # Shortcut connection
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            # Need to match dimensions
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 
                         kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        identity = x
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = torch.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        
        out += self.shortcut(identity)  # The skip connection!
        out = torch.relu(out)
        
        return out


# Visualize residual learning
def visualize_residual_learning():
    block = ResidualBlock(64, 64)
    
    x = torch.randn(1, 64, 32, 32)
    
    with torch.no_grad():
        # Compute the residual branch F(x) of this *untrained* block
        identity = x
        residual = block.bn2(block.conv2(
            torch.relu(block.bn1(block.conv1(x)))
        ))
        output = torch.relu(residual + identity)
    
    print(f"Input norm:    {x.norm():.2f}")
    print(f"Residual norm: {residual.norm():.2f}")
    print(f"Output norm:   {output.norm():.2f}")
    print(f"\nResidual/Input ratio: {residual.norm()/x.norm():.4f}")
    print("(Random init here; after training, a lower ratio means the block is closer to identity)")

visualize_residual_learning()
```
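The "learning to change nothing is trivial" claim can be made literal. If the last BatchNorm in the residual branch has its scale $\gamma$ set to zero, the branch outputs zeros and the block reduces to `relu(x)`, which is the identity for non-negative inputs. `TinyResidual` below is a hypothetical, stripped-down stand-in for the block above (no downsampling shortcut):

```python
import torch
import torch.nn as nn

class TinyResidual(nn.Module):
    """Hypothetical minimal basic block: conv-BN-ReLU-conv-BN + identity."""

    def __init__(self, c):
        super().__init__()
        self.conv1 = nn.Conv2d(c, c, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(c)
        self.conv2 = nn.Conv2d(c, c, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(c)

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return torch.relu(out + x)

torch.manual_seed(0)
block = TinyResidual(16).eval()
nn.init.zeros_(block.bn2.weight)  # gamma = 0; beta is already 0 by default
x = torch.rand(2, 16, 8, 8)       # non-negative input, so the final ReLU is a no-op

with torch.no_grad():
    y = block(x)

print(torch.allclose(y, x))  # True: the block is exactly the identity
```

Some training recipes use this as an initialization so that every residual block starts out as an identity; torchvision's ResNet constructors expose it as a `zero_init_residual` option.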

### Bottleneck Block for Deeper Networks

For ResNet-50/101/152, bottleneck blocks reduce computation:

```python theme={null}
class BottleneckBlock(nn.Module):
    """
    Bottleneck Residual Block for ResNet-50/101/152.
    
    Uses 1×1 convolutions to reduce/restore dimensions,
    making the 3×3 convolution cheaper.
    
    Structure: 1×1 (reduce) → 3×3 (process) → 1×1 (expand)
    """
    
    expansion = 4  # Output channels = 4 × internal channels
    
    def __init__(self, in_channels, internal_channels, stride=1):
        super().__init__()
        
        out_channels = internal_channels * self.expansion
        
        # 1×1 reduce
        self.conv1 = nn.Conv2d(
            in_channels, internal_channels,
            kernel_size=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(internal_channels)
        
        # 3×3 process
        self.conv2 = nn.Conv2d(
            internal_channels, internal_channels,
            kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(internal_channels)
        
        # 1×1 expand
        self.conv3 = nn.Conv2d(
            internal_channels, out_channels,
            kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(out_channels)
        
        # Shortcut
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                         kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        identity = x
        
        out = torch.relu(self.bn1(self.conv1(x)))
        out = torch.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        
        out += self.shortcut(identity)
        out = torch.relu(out)
        
        return out


# Compare parameters
basic = ResidualBlock(256, 256)
bottleneck = BottleneckBlock(256, 64)  # 64 internal, 256 output

basic_params = sum(p.numel() for p in basic.parameters())
bottleneck_params = sum(p.numel() for p in bottleneck.parameters())

print(f"Basic Block (256→256):      {basic_params:,} parameters")
print(f"Bottleneck Block (256→256): {bottleneck_params:,} parameters")
print(f"Reduction: {100*(1 - bottleneck_params/basic_params):.1f}%")
```

### Complete ResNet Implementation

```python theme={null}
class ResNet(nn.Module):
    """
    ResNet architecture for ImageNet.
    
    Variants:
    - ResNet-18:  [2, 2, 2, 2] with BasicBlock
    - ResNet-34:  [3, 4, 6, 3] with BasicBlock
    - ResNet-50:  [3, 4, 6, 3] with Bottleneck
    - ResNet-101: [3, 4, 23, 3] with Bottleneck
    - ResNet-152: [3, 8, 36, 3] with Bottleneck
    """
    
    def __init__(self, block, num_blocks, num_classes=1000):
        super().__init__()
        
        self.in_channels = 64
        
        # Initial layers (stem)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, 
                               stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        # Residual layers
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        
        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        
        # Output channels depend on block type (BasicBlock has no expansion → 1)
        expansion = getattr(block, 'expansion', 1)
        final_channels = 512 * expansion
            
        self.fc = nn.Linear(final_channels, num_classes)
        
        # Initialize weights
        self._initialize_weights()
    
    def _make_layer(self, block, out_channels, num_blocks, stride):
        # Only the first block in a stage downsamples; the rest use stride 1
        strides = [stride] + [1] * (num_blocks - 1)
        expansion = getattr(block, 'expansion', 1)
        layers = []
        
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels * expansion
        
        return nn.Sequential(*layers)
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', 
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        # Stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.maxpool(x)
        
        # Residual blocks
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        # Classification
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        
        return x


def ResNet18(num_classes=1000):
    return ResNet(ResidualBlock, [2, 2, 2, 2], num_classes)

def ResNet50(num_classes=1000):
    return ResNet(BottleneckBlock, [3, 4, 6, 3], num_classes)


# Compare architectures
resnet18 = ResNet18()
resnet50 = ResNet50()

print(f"ResNet-18: {sum(p.numel() for p in resnet18.parameters()):,} parameters")
print(f"ResNet-50: {sum(p.numel() for p in resnet50.parameters()):,} parameters")

# Test forward pass
x = torch.randn(2, 3, 224, 224)
print(f"\nInput: {x.shape}")
print(f"ResNet-18 output: {resnet18(x).shape}")
print(f"ResNet-50 output: {resnet50(x).shape}")
```

***

## EfficientNet: Compound Scaling

### The Scaling Problem

How do you scale a network when you have more compute? Previous approaches:

* **Width scaling**: More channels (WideResNet)
* **Depth scaling**: More layers (ResNet-152)
* **Resolution scaling**: Higher input resolution

EfficientNet's insight: **Scale all three together!**

$$
\text{depth} = \alpha^\phi, \quad \text{width} = \beta^\phi, \quad \text{resolution} = \gamma^\phi
$$

Subject to: $\alpha \cdot \beta^2 \cdot \gamma^2 \approx 2$ with $\alpha, \beta, \gamma \geq 1$, so total compute grows by roughly $2^\phi$ (about double for each unit increase in $\phi$)

**Mathematical intuition behind the constraint**: FLOPS scale linearly with depth ($\alpha$), quadratically with width ($\beta^2$, because both input and output channels grow), and quadratically with resolution ($\gamma^2$, because spatial dimensions are 2D). So the total compute scales as $\alpha \cdot \beta^2 \cdot \gamma^2$ per unit of $\phi$. Keeping this product near 2 means each increment of $\phi$ roughly doubles compute, giving you a clean scaling knob. The specific values $\alpha=1.2, \beta=1.1, \gamma=1.15$ were found by a small grid search on the baseline EfficientNet-B0 (with $\phi = 1$), then held fixed while scaling up; they give $1.2 \times 1.1^2 \times 1.15^2 \approx 1.92$.

This is the equivalent of a budget allocation problem: given twice the money, should you hire more workers (width), add more assembly steps (depth), or buy higher-resolution materials (resolution)? EfficientNet's answer is "a bit of each, in a fixed ratio," which consistently outperforms dumping all the budget into one dimension.
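
To see the knob in action, this sketch plugs the reported $\alpha, \beta, \gamma$ into the scaling rule for a few values of $\phi$, assuming a 224-pixel base resolution. `compound_scale` is a hypothetical helper; the released B1-B7 models use multipliers that do not sit exactly on this curve, so treat it as a back-of-the-envelope calculator:

```python
ALPHA, BETA, GAMMA = 1.2, 1.1, 1.15  # depth, width, resolution bases from the text

def compound_scale(phi, base_resolution=224):
    """Return (depth mult, width mult, resolution, FLOPs mult) for a given phi."""
    depth = ALPHA ** phi
    width = BETA ** phi
    res_mult = GAMMA ** phi
    flops = depth * width ** 2 * res_mult ** 2  # FLOPs ∝ d · w² · r²
    return depth, width, base_resolution * res_mult, flops

for phi in range(4):
    d, w, r, f = compound_scale(phi)
    print(f"phi={phi}: depth x{d:.2f}, width x{w:.2f}, "
          f"resolution {r:.0f}px, FLOPs x{f:.2f}")
```

Each step multiplies FLOPs by about 1.92 rather than exactly 2, which is why the constraint is written with $\approx$.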

```python theme={null}
class MBConvBlock(nn.Module):
    """
    Mobile Inverted Bottleneck Conv Block (MBConv).
    
    The building block of EfficientNet, originally from MobileNetV2.
    Uses inverted residuals: narrow → wide → narrow
    (Simplified: no stochastic depth / drop-connect.)
    """
    
    def __init__(self, in_channels, out_channels, expand_ratio, 
                 kernel_size, stride, se_ratio=0.25):
        super().__init__()
        
        self.use_residual = (stride == 1 and in_channels == out_channels)
        expanded = in_channels * expand_ratio
        
        layers = []
        
        # Expansion: 1×1 conv to widen channels (skipped when expand_ratio == 1)
        if expand_ratio != 1:
            layers.extend([
                nn.Conv2d(in_channels, expanded, 1, bias=False),
                nn.BatchNorm2d(expanded),
                nn.SiLU(inplace=True),  # Swish activation
            ])
        
        # Depthwise convolution: one k×k filter per channel (groups=expanded)
        layers.extend([
            nn.Conv2d(expanded, expanded, kernel_size,
                     stride=stride, padding=kernel_size//2,
                     groups=expanded, bias=False),
            nn.BatchNorm2d(expanded),
            nn.SiLU(inplace=True),
        ])
        self.conv = nn.Sequential(*layers)
        
        # Squeeze-and-Excitation (bottleneck width based on in_channels)
        squeezed = max(1, int(in_channels * se_ratio))
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(expanded, squeezed, 1),
            nn.SiLU(inplace=True),
            nn.Conv2d(squeezed, expanded, 1),
            nn.Sigmoid(),
        )
        
        # Projection: linear 1×1 conv back to out_channels (no activation)
        self.project = nn.Sequential(
            nn.Conv2d(expanded, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
    
    def forward(self, x):
        identity = x
        
        out = self.conv(x)
        
        # SE attention: rescale each channel by a learned gate in (0, 1)
        out = out * self.se(out)
        
        # Project
        out = self.project(out)
        
        # Residual connection (only when input and output shapes match)
        if self.use_residual:
            out = out + identity
        
        return out


# Quick shape check
block = MBConvBlock(24, 24, expand_ratio=6, kernel_size=3, stride=1)
print(block(torch.randn(2, 24, 56, 56)).shape)  # torch.Size([2, 24, 56, 56])


# Simplified EfficientNet-B0 base configuration
efficientnet_b0_config = [
    # (expand_ratio, channels, num_blocks, kernel, stride)
    (1, 16, 1, 3, 1),
    (6, 24, 2, 3, 2),
    (6, 40, 2, 5, 2),
    (6, 80, 3, 3, 2),
    (6, 112, 3, 5, 1),
    (6, 192, 4, 5, 2),
    (6, 320, 1, 3, 1),
]

print("EfficientNet-B0 Configuration:")
print("-" * 60)
for i, (expand, channels, blocks, kernel, stride) in enumerate(efficientnet_b0_config):
    print(f"Stage {i+1}: {blocks} × MBConv{expand} {kernel}×{kernel}, {channels} channels, stride {stride}")
```
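The depthwise convolution (`groups=expanded`) is what keeps MBConv cheap despite its wide middle. This sketch compares a standard 3×3 convolution with a depthwise-separable pair (depthwise 3×3 followed by pointwise 1×1) at 144 channels, an arbitrary width chosen to match a 24-channel input expanded 6×:

```python
import torch
import torch.nn as nn

def n_params(m):
    return sum(p.numel() for p in m.parameters())

c_in = c_out = 144  # e.g. 24 channels expanded 6× (arbitrary choice)
k = 3

standard = nn.Conv2d(c_in, c_out, k, padding=1, bias=False)
depthwise = nn.Conv2d(c_in, c_in, k, padding=1, groups=c_in, bias=False)  # one k×k filter per channel
pointwise = nn.Conv2d(c_in, c_out, 1, bias=False)                          # 1×1 channel mixing

x = torch.randn(1, c_in, 28, 28)
with torch.no_grad():
    same_shape = standard(x).shape == pointwise(depthwise(x)).shape

sep = n_params(depthwise) + n_params(pointwise)
print(f"Standard 3×3:        {n_params(standard):,}")  # 186,624
print(f"Depthwise-separable: {sep:,}")                 # 22,032
print(f"Ratio: {sep / n_params(standard):.3f}")        # 0.118
print(f"Same output shape: {same_shape}")              # True
```

The separable pair is not mathematically equivalent to a full 3×3 convolution (it separates spatial filtering from channel mixing), but it produces the same output shape at roughly an eighth of the parameters here.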

### Scaling Coefficients

```python theme={null}
# EfficientNet compound scaling
efficientnet_scales = {
    'B0': {'width': 1.0, 'depth': 1.0, 'resolution': 224, 'dropout': 0.2},
    'B1': {'width': 1.0, 'depth': 1.1, 'resolution': 240, 'dropout': 0.2},
    'B2': {'width': 1.1, 'depth': 1.2, 'resolution': 260, 'dropout': 0.3},
    'B3': {'width': 1.2, 'depth': 1.4, 'resolution': 300, 'dropout': 0.3},
    'B4': {'width': 1.4, 'depth': 1.8, 'resolution': 380, 'dropout': 0.4},
    'B5': {'width': 1.6, 'depth': 2.2, 'resolution': 456, 'dropout': 0.4},
    'B6': {'width': 1.8, 'depth': 2.6, 'resolution': 528, 'dropout': 0.5},
    'B7': {'width': 2.0, 'depth': 3.1, 'resolution': 600, 'dropout': 0.5},
}

print("EfficientNet Scaling:")
print("-" * 65)
print(f"{'Model':<8} {'Width':<8} {'Depth':<8} {'Resolution':<12} {'Params (M)':<12}")
print("-" * 65)

# Approximate parameter counts (in millions)
param_counts = {'B0': 5.3, 'B1': 7.8, 'B2': 9.2, 'B3': 12, 
                'B4': 19, 'B5': 30, 'B6': 43, 'B7': 66}

for model, config in efficientnet_scales.items():
    print(f"{model:<8} {config['width']:<8.1f} {config['depth']:<8.1f} "
          f"{config['resolution']:<12} {param_counts[model]:<12}")
```
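These multipliers come from compound scaling. In the EfficientNet paper, a single coefficient φ scales depth by α^φ, width by β^φ and resolution by γ^φ, subject to α·β²·γ² ≈ 2 so that FLOPs grow by roughly 2^φ; a small grid search on B0 gave α = 1.2, β = 1.1, γ = 1.15. The released B1 to B7 values in the table are rounded and do not follow α^φ exactly. The sketch below checks the constraint and then applies the B4 width/depth multipliers to the B0 stage table. The channel rounding rule (round to a multiple of 8, never dropping more than 10%) mirrors what common reference implementations do, but treat it as an assumption; `round_filters` and `round_repeats` are illustrative names.

```python
import math

ALPHA, BETA, GAMMA = 1.2, 1.1, 1.15  # depth, width, resolution bases reported in the paper

def flops_multiplier(phi):
    # Conv FLOPs scale ~linearly with depth and ~quadratically with width and resolution
    return (ALPHA ** phi) * (BETA ** phi) ** 2 * (GAMMA ** phi) ** 2

def round_filters(channels, width_mult, divisor=8):
    """Scale a channel count and round to a multiple of `divisor` (assumed rounding rule)."""
    scaled = channels * width_mult
    new = max(divisor, int(scaled + divisor / 2) // divisor * divisor)
    if new < 0.9 * scaled:  # never round down by more than 10%
        new += divisor
    return new

def round_repeats(num_blocks, depth_mult):
    """Scale a stage's block count, rounding up."""
    return int(math.ceil(num_blocks * depth_mult))

print(f"alpha * beta^2 * gamma^2 = {flops_multiplier(1):.3f}")

# B0 stage table: (expand_ratio, channels, num_blocks, kernel, stride)
b0_stages = [(1, 16, 1, 3, 1), (6, 24, 2, 3, 2), (6, 40, 2, 5, 2), (6, 80, 3, 3, 2),
             (6, 112, 3, 5, 1), (6, 192, 4, 5, 2), (6, 320, 1, 3, 1)]

# Apply the B4 multipliers from the table above (width 1.4, depth 1.8)
for _, ch, n, _, _ in b0_stages:
    print(f"{ch:>4} ch x {n} blocks -> {round_filters(ch, 1.4):>4} ch x {round_repeats(n, 1.8)} blocks")
```

With these rules, B4's stages come out at 24/32/56/112/160/272/448 channels and 2/4/4/6/6/8/2 blocks: 32 MBConv blocks in total versus 16 for B0.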

***

## Modern CNN Design Principles

### Principle 1: Use Small Filters

```python theme={null}
import torch.nn as nn


def compare_filter_sizes():
    """
    Demonstrate why 3×3 filters are preferred over larger ones.
    """
    # Two 3×3 convs = same receptive field as one 5×5
    # But with more non-linearity and fewer params
    
    in_ch, out_ch = 64, 64
    
    # Single 5×5
    conv_5x5 = nn.Conv2d(in_ch, out_ch, 5, padding=2)
    params_5x5 = sum(p.numel() for p in conv_5x5.parameters())
    
    # Two 3×3
    conv_3x3_stack = nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(out_ch, out_ch, 3, padding=1),
    )
    params_3x3 = sum(p.numel() for p in conv_3x3_stack.parameters())
    
    print("Comparison: 5×5 vs Two 3×3")
    print(f"  5×5 params: {params_5x5:,}")
    print(f"  2×(3×3) params: {params_3x3:,}")
    print(f"  Savings: {100*(1-params_3x3/params_5x5):.1f}%")
    print(f"  Plus: 2 ReLUs instead of 1 = more non-linearity")

compare_filter_sizes()
```
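The same argument generalizes. A stack of n stride-1 k×k convolutions has a receptive field of n(k − 1) + 1 pixels, and with C input and output channels it holds n·k²·C² weights (biases ignored). The sketch below uses a hypothetical helper, `stacked_conv_stats`, to compare the two classic substitutions: two 3×3 layers for one 5×5, and three 3×3 layers for one 7×7.

```python
def stacked_conv_stats(num_layers, kernel_size, channels):
    """Receptive field and weight count (biases ignored) of a stride-1, C->C conv stack."""
    receptive_field = num_layers * (kernel_size - 1) + 1
    weights = num_layers * kernel_size ** 2 * channels ** 2
    return receptive_field, weights

C = 64
for n_small, big in [(2, 5), (3, 7)]:
    rf_small, w_small = stacked_conv_stats(n_small, 3, C)
    rf_big, w_big = stacked_conv_stats(1, big, C)
    print(f"{n_small}x(3x3): RF {rf_small}, {w_small:,} weights | "
          f"1x({big}x{big}): RF {rf_big}, {w_big:,} weights | ratio {w_small / w_big:.2f}")
```

Three 3×3 layers reach a 7×7 receptive field with 27C² weights instead of 49C² (about 55%), which is the trade VGG made throughout its design.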

### Principle 2: Batch Normalization Everywhere

```python theme={null}
import torch.nn as nn


class ConvBNReLU(nn.Module):
    """The standard Conv → BatchNorm → ReLU block."""
    
    def __init__(self, in_channels, out_channels, kernel_size=3, 
                 stride=1, padding=None, groups=1):
        super().__init__()
        
        # Default to "same"-style padding for odd kernels (1 for 3×3, 0 for 1×1)
        if padding is None:
            padding = kernel_size // 2
        
        # bias=False: BatchNorm's learnable shift (beta) makes a conv bias redundant
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, groups=groups, bias=False
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
```
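Why `bias=False` is safe here: in training mode, BatchNorm subtracts each channel's batch mean, so any constant per-channel offset added by the conv disappears. The sketch below checks this numerically with two convs that share weights and differ only in their bias (the bias values are arbitrary).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two convs with identical weights; only one has a (deliberately large) bias
conv_with_bias = nn.Conv2d(3, 8, 3, padding=1, bias=True)
conv_no_bias = nn.Conv2d(3, 8, 3, padding=1, bias=False)
with torch.no_grad():
    conv_no_bias.weight.copy_(conv_with_bias.weight)
    conv_with_bias.bias.uniform_(-3.0, 3.0)

bn = nn.BatchNorm2d(8)  # a fresh module is in training mode: it normalizes with batch statistics
x = torch.randn(4, 3, 16, 16)

out_with_bias = bn(conv_with_bias(x))
out_no_bias = bn(conv_no_bias(x))

# Subtracting the per-channel batch mean removes any per-channel constant, including the conv bias
max_diff = (out_with_bias - out_no_bias).abs().max().item()
print(f"max |difference| = {max_diff:.1e}")
```

In eval mode BatchNorm normalizes with running statistics instead; since those averages are collected from the biased conv outputs, they absorb the bias in the same way.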

### Principle 3: Skip Connections for Deep Networks

```python theme={null}
import torch
import torch.nn as nn


class ModernBlock(nn.Module):
    """
    A modern residual block combining best practices:
    - Pre-activation (BN-ReLU before Conv)
    - Bottleneck design
    - SE attention
    """
    
    def __init__(self, in_channels, out_channels, stride=1, se_ratio=0.25):
        super().__init__()
        
        mid_channels = out_channels // 4
        
        # Pre-activation
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, bias=False)
        
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, 
                               stride=stride, padding=1, bias=False)
        
        self.bn3 = nn.BatchNorm2d(mid_channels)
        self.conv3 = nn.Conv2d(mid_channels, out_channels, 1, bias=False)
        
        # SE block
        squeezed = max(1, int(out_channels * se_ratio))
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_channels, squeezed, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, out_channels, 1),
            nn.Sigmoid()
        )
        
        # Shortcut
        self.shortcut = nn.Identity()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Conv2d(
                in_channels, out_channels, 1, stride=stride, bias=False
            )
    
    def forward(self, x):
        identity = self.shortcut(x)
        
        out = torch.relu(self.bn1(x))
        out = self.conv1(out)
        
        out = torch.relu(self.bn2(out))
        out = self.conv2(out)
        
        out = torch.relu(self.bn3(out))
        out = self.conv3(out)
        
        # SE attention
        out = out * self.se(out)
        
        return out + identity
```
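The claim that skip connections keep gradients alive in deep stacks can be checked with a toy model. The sketch below is a hypothetical 50-layer MLP rather than `ModernBlock`; the depth, width, and deliberately small `std=0.01` initialization are arbitrary choices that make the effect obvious. It measures the gradient norm at the input with and without the `h + F(h)` skip.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
depth, width = 50, 64

# A deep stack of small-init linear layers (a toy, not a real CNN)
layers = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])
for lin in layers:
    nn.init.normal_(lin.weight, std=0.01)
    nn.init.zeros_(lin.bias)

def input_grad_norm(residual):
    """Norm of d(sum of outputs)/d(input) through the stack, with or without skips."""
    x = torch.randn(8, width, requires_grad=True)
    h = x
    for lin in layers:
        f = torch.relu(lin(h))        # the residual branch F(h)
        h = h + f if residual else f  # skip: h + F(h); plain: F(h)
    h.sum().backward()
    return x.grad.norm().item()

plain = input_grad_norm(residual=False)
resid = input_grad_norm(residual=True)
print(f"plain: {plain:.2e}   residual: {resid:.2e}")
```

In the plain stack the gradient is multiplied by a small Jacobian at every layer and collapses toward zero; with the skip, each layer's Jacobian is I + ∂F/∂h, so the identity term carries the gradient through largely intact.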

### Principle 4: Global Average Pooling for Classification

```python theme={null}
import torch.nn as nn


class ModernCNN(nn.Module):
    """
    Modern CNN following current best practices
    (built from the pre-activation ModernBlock defined above).
    """
    
    def __init__(self, num_classes=1000):
        super().__init__()
        
        # Stem: aggressive downsampling (224×224 -> 56×56)
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1),
        )
        
        # Main blocks (ResNet-50-style stage layout)
        self.stage1 = self._make_stage(64, 256, 3, stride=1)
        self.stage2 = self._make_stage(256, 512, 4, stride=2)
        self.stage3 = self._make_stage(512, 1024, 6, stride=2)
        self.stage4 = self._make_stage(1024, 2048, 3, stride=2)
        
        # Pre-activation blocks end in a bare addition, so apply one final
        # BN-ReLU before pooling, as pre-activation ResNets do
        self.final_act = nn.Sequential(nn.BatchNorm2d(2048), nn.ReLU(inplace=True))
        
        # Head: GAP + a single linear layer instead of a stack of FC layers
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, num_classes)
    
    def _make_stage(self, in_ch, out_ch, num_blocks, stride):
        # Only the first block of a stage changes channels/resolution
        blocks = [ModernBlock(in_ch, out_ch, stride)]
        for _ in range(1, num_blocks):
            blocks.append(ModernBlock(out_ch, out_ch))
        return nn.Sequential(*blocks)
    
    def forward(self, x):
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.final_act(x)
        x = self.avgpool(x)
        x = x.flatten(1)
        x = self.fc(x)
        return x
```
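Global average pooling is what keeps this head small. The sketch below compares, for a ResNet-50-like final feature map of 2048×7×7 (an assumption matching a 224×224 input), a head that flattens straight into a 1000-way linear layer against a GAP + linear head. The flatten head is only counted, not built, because its weight matrix alone would take about 400 MB in float32.

```python
import torch
import torch.nn as nn

def linear_params(in_features, out_features):
    """Parameter count of nn.Linear: weight matrix + bias vector."""
    return in_features * out_features + out_features

channels, spatial, num_classes = 2048, 7, 1000

# Head A (counted only): flatten 2048x7x7 straight into a linear classifier
flatten_fc_params = linear_params(channels * spatial * spatial, num_classes)

# Head B (built): global average pooling to a 2048-d vector, then a linear classifier
gap_head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(channels, num_classes))
gap_params = sum(p.numel() for p in gap_head.parameters())

print(f"flatten + FC: {flatten_fc_params:,} params")
print(f"GAP + FC:     {gap_params:,} params")

# GAP also makes the head independent of input resolution: a 10x10 feature map works too
print(gap_head(torch.randn(2, channels, 10, 10)).shape)
```

That is roughly a 49× reduction (the 7×7 spatial factor disappears from the input dimension), and unlike the flatten head, the GAP head accepts feature maps of any spatial size.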

***

## Architecture Comparison

```python theme={null}
import torchvision.models as models


def count_params_millions(model):
    """Parameter count in millions, measured from the model itself."""
    return sum(p.numel() for p in model.parameters()) / 1e6


def compare_architectures():
    """Compare popular CNN architectures."""
    
    # Model constructors (randomly initialized; no weights are downloaded)
    builders = {
        'VGG-16': models.vgg16,
        'ResNet-50': models.resnet50,
        'DenseNet-121': models.densenet121,
        'EfficientNet-B0': models.efficientnet_b0,
        'MobileNetV3-Large': models.mobilenet_v3_large,
    }
    
    # Approximate reported values at 224×224: (GMACs, ImageNet top-1 %)
    reported = {
        'VGG-16': (15.5, 71.6),
        'ResNet-50': (4.1, 76.1),
        'DenseNet-121': (2.9, 74.4),
        'EfficientNet-B0': (0.39, 77.1),
        'MobileNetV3-Large': (0.22, 75.2),
    }
    
    print("Architecture Comparison:")
    print("-" * 75)
    print(f"{'Model':<20} {'Params (M)':<15} {'MACs (G)':<15} {'Top-1 Acc':<15}")
    print("-" * 75)
    
    for name, build in builders.items():
        params = count_params_millions(build())  # build one model at a time to limit memory
        macs, acc = reported[name]
        print(f"{name:<20} {params:<15.1f} {macs:<15.2f} {acc:<15.1f}")


compare_architectures()
```

### When to Use Which Architecture

| Architecture     | Best For                           | Strengths                   | Weaknesses             |
| ---------------- | ---------------------------------- | --------------------------- | ---------------------- |
| **VGG**          | Teaching, feature extraction       | Simple, uniform structure   | Very large, slow       |
| **ResNet**       | General purpose, transfer learning | Robust, well-studied        | Large memory footprint |
| **DenseNet**     | Limited data, feature reuse        | Parameter efficient         | Memory intensive       |
| **EfficientNet** | Best accuracy/compute              | High accuracy per FLOP      | Complex to implement   |
| **MobileNet**    | Mobile/edge deployment             | Extremely fast              | Lower accuracy         |

***

## Practical Tips for CNN Design

<Tip>
  **Design Checklist for New CNNs:**

  1. Start with a proven architecture (ResNet, EfficientNet)
  2. Use 3×3 convolutions (or 1×1 for channel mixing)
  3. Pair every conv layer with BatchNorm (Conv-BN-ReLU, or BN-ReLU-Conv in pre-activation blocks)
  4. Use skip connections if depth > 10 layers
  5. Use GAP instead of large FC layers
  6. Consider SE attention for accuracy boost
  7. Profile memory and compute before deploying
</Tip>

<Warning>
  **Common Mistakes:**

  * Too many FC layers at the end (use GAP instead)
  * Not using skip connections in deep networks
  * Leaving `bias=True` on conv layers that are immediately followed by BatchNorm
  * Using large kernels (7x7) everywhere
  * Not considering inference speed vs accuracy tradeoff
</Warning>

<Tip>
  **Training Pitfalls and Debugging Hints for CNN Design:**

  **Spatial dimension mismatch**: The single most common CNN bug. If your forward pass crashes with a shape error, trace the spatial dimensions layer by layer using the formula: `output = (input - kernel + 2*padding) // stride + 1`. Write a helper that prints shapes at each layer during the first forward pass.

  **BatchNorm + bias=False**: When an `nn.BatchNorm2d` follows an `nn.Conv2d`, set `bias=False` in the conv layer. BatchNorm subtracts the per-channel batch mean, which cancels any constant conv bias, and then adds its own learnable shift (beta). The conv bias is therefore redundant and only wastes parameters and compute.

  **ResNet identity shortcut dimension mismatch**: If your residual block crashes, check that the shortcut projection handles both channel count changes AND spatial downsampling (stride). A common error is matching channels but forgetting that stride=2 also halves the spatial dimensions.

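  **Checkerboard artifacts from strided transposed convolutions**: If you use `nn.ConvTranspose2d` for upsampling (common in generators and decoders), watch for checkerboard patterns in the output. They are most pronounced when `kernel_size` is not divisible by `stride`, because neighboring output pixels then receive contributions from unequal numbers of input positions. Prefer `nn.Upsample` followed by a regular conv, or at least use a divisible setting such as `kernel_size=4, stride=2, padding=1`, which exactly doubles the spatial size.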

  **Frozen BatchNorm during fine-tuning**: When fine-tuning a pretrained CNN with a very small batch size (fewer than about 8), the batch statistics become noisy and can destabilize training. Either freeze the BatchNorm layers by calling `.eval()` on those modules (and re-apply it after every `model.train()`, which switches them back to training mode), or switch to GroupNorm, which does not depend on batch size.
</Tip>
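The first and last tips above lend themselves to small helpers. The sketch below is one way to write them; `conv_output_size`, `trace_shapes` and `freeze_batchnorm` are hypothetical names. `conv_output_size` is the floor formula from the first tip extended with the dilation term used in PyTorch's `Conv2d` shape formula, `trace_shapes` prints every leaf module's output shape through forward hooks, and `freeze_batchnorm` puts BatchNorm layers in eval mode and stops updates to their affine parameters.

```python
import torch
import torch.nn as nn

def conv_output_size(size, kernel, stride=1, padding=0, dilation=1):
    """Output side length of Conv2d/MaxPool2d along one dimension (floor formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

def trace_shapes(model, input_shape):
    """Print the output shape of every leaf module during one forward pass."""
    hooks = []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            hooks.append(module.register_forward_hook(
                lambda mod, inp, out, name=name: print(f"{name:<8} {tuple(out.shape)}")))
    try:
        with torch.no_grad():
            model(torch.zeros(*input_shape))
    finally:
        for h in hooks:
            h.remove()  # leave the model without hooks afterwards

def freeze_batchnorm(model):
    """Eval-mode BatchNorm with frozen gamma/beta. Re-apply after every model.train()."""
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.eval()  # use running statistics instead of batch statistics
            for p in module.parameters():
                p.requires_grad = False

# A ResNet-like stem as a usage example
net = nn.Sequential(
    nn.Conv2d(3, 16, 7, stride=2, padding=3), nn.BatchNorm2d(16), nn.ReLU(),
    nn.MaxPool2d(3, stride=2, padding=1),
)
trace_shapes(net, (1, 3, 224, 224))

net.train()
freeze_batchnorm(net)  # conv stays in training mode; BN is now frozen
```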

***

## Exercises

<AccordionGroup>
  <Accordion title="Exercise 1: Implement Custom Pooling">
    Implement these pooling variants and compare them on CIFAR-10:

    1. **Mixed Pooling**: $\alpha \cdot \text{max} + (1-\alpha) \cdot \text{avg}$, with $\alpha \in [0, 1]$
    2. **Stochastic Pooling**: Sample based on activation magnitudes
    3. **LP Pooling**: $\left(\frac{1}{N}\sum_i x_i^p\right)^{1/p}$

    ```python theme={null}
    class MixedPooling(nn.Module):
        def __init__(self, kernel_size, stride=None, alpha=0.5):
            super().__init__()
            # TODO: implement alpha * max_pool(x) + (1 - alpha) * avg_pool(x)

        def forward(self, x):
            raise NotImplementedError
    ```
  </Accordion>

  <Accordion title="Exercise 2: Build ResNet Variants">
    Implement and compare these ResNet variations:

    1. **Pre-activation ResNet**: BN-ReLU-Conv instead of Conv-BN-ReLU
    2. **Wide ResNet**: Increase width instead of depth
    3. **ResNeXt**: Grouped convolutions in bottleneck

    Train each on CIFAR-10 and compare accuracy vs parameters.
  </Accordion>

  <Accordion title="Exercise 3: Analyze Receptive Fields">
    Write a function to compute the receptive field of any CNN:

    ```python theme={null}
    def compute_receptive_field(model, input_size=224):
        """
        Compute the receptive field of the final feature map.
        
        Returns:
            rf_size: Size of receptive field in input space
            rf_centers: Locations of receptive field centers
        """
        pass
    ```

    Use it to analyze VGG-16, ResNet-50, and EfficientNet-B0.
  </Accordion>

  <Accordion title="Exercise 4: Architecture Search">
    Implement a simple neural architecture search:

    1. Define a search space (number of layers, channels, kernel sizes)
    2. Use random search to sample architectures
    3. Train each for 10 epochs on CIFAR-10
    4. Find the Pareto frontier of accuracy vs parameters

    ```python theme={null}
    def sample_architecture():
        """Sample a random CNN architecture."""
        pass

    def evaluate_architecture(model, train_loader, val_loader):
        """Train and evaluate an architecture."""
        pass
    ```
  </Accordion>

  <Accordion title="Exercise 5: Transfer Learning Deep Dive">
    Using a pretrained ResNet-50:

    1. Fine-tune only the last layer on a new dataset
    2. Fine-tune all layers with different learning rates
    3. Extract intermediate features and train a separate classifier
    4. Compare accuracy, training time, and generalization

    Use the Oxford Pets or Stanford Cars dataset.
  </Accordion>
</AccordionGroup>

***

## Key Takeaways

| Concept           | Key Insight                                                   |
| ----------------- | ------------------------------------------------------------- |
| **Max Pooling**   | Detects presence of features, provides local translation invariance             |
| **Stride**        | Strided convolutions give learnable downsampling, often replacing pooling       |
| **GAP**           | Replaces large flatten+FC heads, cutting parameters and improving generalization |
| **VGGNet**        | Small filters + depth = large receptive field                 |
| **ResNet**        | Skip connections enable training very deep networks           |
| **EfficientNet**  | Scale depth, width, and resolution together                   |
| **Modern Design** | BN everywhere, skip connections, attention, GAP head          |

***

## What's Next

<CardGroup cols={1}>
  <Card title="Module 8: Recurrent Neural Networks" icon="rotate-right" href="/courses/deep-learning-mastery/08-rnns">
    Move from images to sequences — how neural networks process time-series data, text, and other sequential inputs.
  </Card>
</CardGroup>

***

## Interview Deep-Dive

<AccordionGroup>
  <Accordion title="Explain the degradation problem that motivated ResNet. Why is it surprising, and how do skip connections solve it?">
    **Strong Answer:**

    * The degradation problem is counterintuitive: a 56-layer plain network has higher TRAINING error than a 20-layer one. This is not overfitting (where test error is high but training error is low). The deeper network has strictly more capacity -- it could, in theory, learn the 20-layer solution by making the extra 36 layers identity mappings. But standard training cannot find this solution.
    * The fundamental issue is optimization, not representation. Gradient-based training cannot effectively learn identity mappings through a chain of nonlinear layers. Learning $H(x) = x$ when $H$ involves two conv layers, two BN layers, and two ReLU activations is surprisingly difficult because the optimization landscape has poor conditioning.
    * Skip connections reframe the learning task: instead of learning $H(x) = x$ (identity), the network learns $F(x) = H(x) - x = 0$ (the residual). Learning to output zero is trivially easy -- just push all weights toward zero. The output becomes $F(x) + x = 0 + x = x$, recovering the identity.
    * The gradient flow benefit: $\partial(F(x) + x)/\partial x = \partial F/\partial x + 1$. The additive 1 gives gradients a direct path to earlier layers that does not pass through $F$'s weights, so even when $\partial F/\partial x$ is small the gradient is not driven toward zero. Stacking many such blocks therefore avoids the multiplicative shrinkage that causes vanishing gradients in plain deep networks. It is the add gate pattern from backpropagation: addition distributes gradients without attenuation.

    **Follow-up: Pre-activation ResNet (He et al., 2016) places BN and ReLU before the conv layers rather than after. Why does this improve training?**

    In the original "post-activation" ResNet, the ReLU after the addition $\text{ReLU}(F(x) + x)$ means the skip connection output passes through a non-linearity, which can break the clean identity path. Pre-activation places BN-ReLU-Conv-BN-ReLU-Conv inside the residual branch, leaving the skip connection completely clean: $y = F(x) + x$ with no non-linearity applied to the sum. This creates a truly unimpeded gradient highway for the identity path. Empirically, pre-activation ResNets train more easily for very deep networks (1000+ layers) and achieve slightly better accuracy on CIFAR benchmarks.
  </Accordion>

  <Accordion title="Max pooling versus strided convolution for downsampling: what are the trade-offs, and which would you choose?">
    **Strong Answer:**

    * **Max pooling** (parameter-free): takes the maximum value in each window. Provides some local translation invariance (a small shift that keeps the maximum inside the same window leaves the output unchanged), acts as a feature detector ("is this feature present anywhere in this region?"), and adds zero parameters. However, during backpropagation, gradients only flow to the single maximum element; all other positions in the window receive zero gradient, discarding information about sub-maximum activations.
    * **Strided convolution** (learnable): a regular convolution with stride > 1. The downsampling operation itself is learned, allowing the network to optimize how information is aggregated. It adds parameters ($K^2 \times C_{in} \times C_{out}$) and is more expressive. However, strided convolutions can introduce aliasing artifacts (high-frequency patterns can alias into low-frequency ones when downsampling without proper anti-aliasing).
    * **Modern trend**: strided convolutions are preferred for downsampling in most recent architectures (ResNet's stages, EfficientNet, ConvNeXt; ResNet still keeps one max pool in its stem). The "All Convolutional Net" paper (Springenberg et al., 2015) showed that, on the benchmarks it studied (such as CIFAR-10), replacing every pooling layer with a strided convolution matched or exceeded the pooled baseline. Global Average Pooling (GAP) remains standard for the final spatial reduction before the classification head.
    * **When to use max pooling**: when you want some built-in robustness to small shifts without having to learn it, when you want to minimize parameters (embedded/mobile), or in SegNet-style encoder-decoders, where the decoder reuses the stored max-pooling indices (`nn.MaxPool2d(..., return_indices=True)` paired with `nn.MaxUnpool2d`) for precise localization.

    **Follow-up: What is the anti-aliasing problem with strided convolutions, and how was it addressed?**

    Zhang (2019, "Making Convolutional Networks Shift-Invariant Again") showed that standard downsampling layers, including strided convolutions and max pooling, break shift invariance: shifting the input by one pixel changes which positions survive the stride, and the output can change substantially. The fix is the classic signal-processing remedy motivated by the Nyquist-Shannon sampling theorem: low-pass filter before subsampling. The "BlurPool" approach evaluates the original operation densely (at stride 1), then applies a fixed blur kernel (a small binomial filter approximating a Gaussian) and subsamples, suppressing the high-frequency components that would otherwise alias. In the paper, this improved shift consistency and, for the ImageNet classifiers tested, also modestly improved accuracy.
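
    A minimal sketch of the blur-then-subsample step, written in the spirit of BlurPool rather than taken from the paper's code: the 3-tap binomial filter [1, 2, 1] and reflect padding are assumptions, and `BlurPool2d` here is a hypothetical class. In the paper the blur follows a stride-1 version of the original operation (for example max pooling at stride 1); only the final blur + subsample is shown. The test signal is a vertical-stripe image and a copy shifted by one pixel.

    ```python
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class BlurPool2d(nn.Module):
        """Anti-aliased downsampling sketch: fixed low-pass blur, then stride-2 subsampling."""

        def __init__(self, channels, stride=2):
            super().__init__()
            taps = torch.tensor([1.0, 2.0, 1.0])   # 1D binomial filter (assumed size 3)
            kernel = torch.outer(taps, taps)
            kernel = kernel / kernel.sum()         # normalize so the blur preserves the mean
            # One fixed copy per channel for a depthwise conv; a buffer, so it is never trained
            self.register_buffer("kernel", kernel.expand(channels, 1, 3, 3).clone())
            self.channels = channels
            self.stride = stride

        def forward(self, x):
            x = F.pad(x, (1, 1, 1, 1), mode="reflect")
            return F.conv2d(x, self.kernel, stride=self.stride, groups=self.channels)

    # Vertical stripes: the highest horizontal frequency an image can contain
    stripes = torch.tensor([1.0, 0.0] * 8).repeat(16, 1)[None, None]  # shape (1, 1, 16, 16)
    shifted = torch.roll(stripes, shifts=1, dims=-1)                   # same image, moved 1 pixel

    def naive(t):
        return t[..., ::2, ::2]  # plain stride-2 subsampling, no blur

    blur = BlurPool2d(channels=1)

    naive_diff = (naive(stripes) - naive(shifted)).abs().max().item()
    blur_diff = (blur(stripes) - blur(shifted)).abs().max().item()
    print(f"naive subsampling: max change {naive_diff:.2f}")
    print(f"blur + subsample:  max change {blur_diff:.2f}")
    ```

    Naive stride-2 subsampling turns the stripes into all ones or all zeros depending on a one-pixel shift, while the blurred version produces the same constant 0.5 map for both inputs.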
  </Accordion>

  <Accordion title="You need to design a CNN for a custom image classification task with 50 classes and 10,000 training images. Walk through your architecture decisions.">
    **Strong Answer:**

    * With 10,000 images across 50 classes (200 per class), this is firmly in the transfer learning regime. Training from scratch would overfit catastrophically on a modern architecture.
    * **Architecture choice**: start with a pretrained ResNet-50 or EfficientNet-B0. Both are well-validated, widely available, and have strong pretrained ImageNet features. For 200 images per class, feature extraction (frozen backbone + trained classifier) is likely sufficient. If performance is inadequate, gradually unfreeze later layers.
    * **Classifier head**: replace the final FC layer with a lightweight head: `nn.Sequential(nn.Dropout(0.3), nn.Linear(2048, 256), nn.ReLU(), nn.Dropout(0.2), nn.Linear(256, 50))`, where 2048 is ResNet-50's pooled feature size (EfficientNet-B0's is 1280). Dropout is critical at this data scale.
    * **Data augmentation**: aggressive augmentation is essential. Use RandomResizedCrop, HorizontalFlip, ColorJitter, RandAugment, and potentially MixUp/CutMix. Every epoch then sees a different random view of each image, which substantially reduces overfitting on a dataset this small.
    * **Training recipe**: AdamW optimizer, learning rate 1e-3 for the classifier head, cosine annealing schedule, weight decay 0.01. If fine-tuning pretrained layers, use discriminative learning rates (10-100x smaller for earlier layers).
    * **Regularization stack**: weight decay + dropout + data augmentation + label smoothing (epsilon=0.1) + early stopping based on validation loss.
    * **What I would NOT do**: train from scratch, use a very deep or very wide custom architecture, skip data augmentation, or use a single learning rate for all layers during fine-tuning.
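
    The feature-extraction setup above might look like the following sketch. It assumes torchvision 0.13+ (the `weights=` argument); `build_transfer_model` and `make_optimizer` are hypothetical helpers, and `weights=None` keeps the example offline. In practice you would pass `models.ResNet50_Weights.IMAGENET1K_V2` to start from ImageNet features.

    ```python
    import torch
    import torch.nn as nn
    import torchvision.models as models

    def build_transfer_model(num_classes=50, weights=None):
        """ResNet-50 backbone (frozen) with a small dropout-regularized head."""
        model = models.resnet50(weights=weights)
        for p in model.parameters():
            p.requires_grad = False            # feature extraction: freeze the backbone
        in_features = model.fc.in_features     # 2048 for ResNet-50
        model.fc = nn.Sequential(              # new layers are trainable by default
            nn.Dropout(0.3),
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        )
        return model

    def make_optimizer(model, head_lr=1e-3, backbone_lr_scale=0.1, weight_decay=0.01):
        """AdamW: full LR on the head, a scaled-down LR on any unfrozen backbone params."""
        head_params = list(model.fc.parameters())
        head_ids = {id(p) for p in head_params}
        backbone_params = [p for p in model.parameters()
                           if p.requires_grad and id(p) not in head_ids]
        groups = [{"params": head_params, "lr": head_lr}]
        if backbone_params:
            groups.append({"params": backbone_params, "lr": head_lr * backbone_lr_scale})
        return torch.optim.AdamW(groups, weight_decay=weight_decay)

    model = build_transfer_model()

    # If frozen features are not enough, unfreeze the last stage with a 10x smaller LR
    for p in model.layer4.parameters():
        p.requires_grad = True

    optimizer = make_optimizer(model)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)  # e.g. 30 epochs
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    ```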

    **Follow-up: How would your approach change if you had 1 million training images instead?**

    With 1M images (20,000 per class), full fine-tuning becomes viable and likely outperforms feature extraction. I would still start from a pretrained backbone (at this scale pretraining mostly buys faster convergence and rarely hurts, even if training from scratch might eventually catch up), but I would unfreeze all layers from the start with discriminative learning rates. I would also consider training a larger model (EfficientNet-B4 or ViT-B) since the data can support more parameters. Data augmentation remains important but can be less aggressive. The training schedule would be longer (100-200 epochs with cosine decay) and I would use larger batch sizes with mixed-precision training. At this scale, I would also run proper ablation studies to validate each architectural choice, since the compute investment justifies rigorous experimentation.
  </Accordion>
</AccordionGroup>
