Neural Architecture Search: Automating Design

Why Automate Architecture Design?

Designing neural networks by hand is:
  • Time-consuming: Weeks of human experimentation
  • Expertise-dependent: Requires deep architectural intuition
  • Suboptimal: Humans explore only a small fraction of the possible design space
NAS algorithms can:
  • Explore vast architecture spaces systematically
  • Find novel, non-intuitive designs
  • Optimize for specific hardware constraints
  • Discover architectures that outperform human designs
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass

torch.manual_seed(42)

NAS Components

1. Search Space

Define what architectures are possible:
# Operation choices for each edge
OPERATIONS = {
    'none': lambda C, stride: Zero(stride),
    'skip_connect': lambda C, stride: Identity() if stride == 1 else FactorizedReduce(C, C),
    'sep_conv_3x3': lambda C, stride: SepConv(C, C, 3, stride, 1),
    'sep_conv_5x5': lambda C, stride: SepConv(C, C, 5, stride, 2),
    'dil_conv_3x3': lambda C, stride: DilConv(C, C, 3, stride, 2, 2),
    'dil_conv_5x5': lambda C, stride: DilConv(C, C, 5, stride, 4, 2),
    'avg_pool_3x3': lambda C, stride: nn.AvgPool2d(3, stride=stride, padding=1),
    'max_pool_3x3': lambda C, stride: nn.MaxPool2d(3, stride=stride, padding=1),
}


class SepConv(nn.Module):
    """Separable convolution."""
    
    def __init__(self, C_in, C_out, kernel_size, stride, padding):
        super().__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size, stride, padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, 1, bias=False),
            nn.BatchNorm2d(C_out),
            nn.ReLU(inplace=False),
            nn.Conv2d(C_out, C_out, kernel_size, 1, padding, groups=C_out, bias=False),
            nn.Conv2d(C_out, C_out, 1, bias=False),
            nn.BatchNorm2d(C_out)
        )
    
    def forward(self, x):
        return self.op(x)


class DilConv(nn.Module):
    """Dilated convolution."""
    
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
        super().__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, 1, bias=False),
            nn.BatchNorm2d(C_out)
        )
    
    def forward(self, x):
        return self.op(x)


class Identity(nn.Module):
    def forward(self, x):
        return x


class Zero(nn.Module):
    def __init__(self, stride):
        super().__init__()
        self.stride = stride
    
    def forward(self, x):
        if self.stride == 1:
            return x * 0
        return x[:, :, ::self.stride, ::self.stride] * 0


class FactorizedReduce(nn.Module):
    """Reduce spatial size while maintaining channels."""
    
    def __init__(self, C_in, C_out):
        super().__init__()
        self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, bias=False)
        self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, bias=False)
        self.bn = nn.BatchNorm2d(C_out)
    
    def forward(self, x):
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        return self.bn(out)
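
Each entry in OPERATIONS is a factory that takes a channel count and a stride and returns a module. A quick sanity check (input shape chosen arbitrarily here) is that, for a given stride, every candidate op produces the same output shape, which is what allows them to be mixed or swapped freely on an edge:
# All candidate ops preserve the feature-map shape at stride 1
x = torch.randn(2, 16, 32, 32)
for name, op_fn in OPERATIONS.items():
    op = op_fn(16, 1)   # C=16, stride=1
    assert op(x).shape == x.shape, name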

2. Search Strategy

How to explore the search space:
class SearchStrategy:
    """Base class for NAS search strategies."""
    
    def __init__(self, search_space: Dict):
        self.search_space = search_space
    
    def sample_architecture(self) -> Dict:
        """Sample a random architecture from the search space."""
        raise NotImplementedError
    
    def update(self, architecture: Dict, performance: float):
        """Update search strategy based on evaluation."""
        raise NotImplementedError
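
As a concrete baseline, a random strategy (a hypothetical subclass, not from any particular library) simply samples uniformly and ignores the feedback signal:
class RandomSearchStrategy(SearchStrategy):
    """Baseline: sample architectures uniformly at random."""
    
    def sample_architecture(self) -> Dict:
        return {key: random.choice(choices) for key, choices in self.search_space.items()}
    
    def update(self, architecture: Dict, performance: float):
        # Pure random search does not use evaluation feedback
        pass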

3. Performance Estimation

How to evaluate candidate architectures efficiently:
class PerformanceEstimator:
    """Base class for performance estimation strategies."""
    
    def estimate(self, architecture: Dict, dataset) -> float:
        """Estimate performance of an architecture."""
        raise NotImplementedError


class FullTrainingEstimator(PerformanceEstimator):
    """Train model fully - slow but accurate."""
    
    def estimate(self, architecture, dataset, epochs=100):
        model = build_model(architecture)
        train(model, dataset, epochs)
        return evaluate(model, dataset)


class EarlyStoppingEstimator(PerformanceEstimator):
    """Train for few epochs - faster but less accurate."""
    
    def estimate(self, architecture, dataset, epochs=10):
        model = build_model(architecture)
        train(model, dataset, epochs)
        return evaluate(model, dataset)
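
Both estimators above assume placeholder helpers (build_model, train, evaluate) defined elsewhere. Even cheaper proxies skip training entirely; one rough but common example is scoring by parameter count against a budget. A sketch, with the class name and the 5M-parameter budget chosen for illustration:
class ParameterCountEstimator(PerformanceEstimator):
    """Zero-cost proxy: favor architectures within a parameter budget."""
    
    def estimate(self, architecture, dataset=None) -> float:
        model = build_model(architecture)  # same placeholder helper as above
        num_params = sum(p.numel() for p in model.parameters())
        # Score in [0, 1]: smaller models score higher (5M-parameter budget assumed)
        return max(0.0, 1.0 - num_params / 5e6)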

Differentiable Architecture Search (DARTS)

Make architecture search differentiable through a continuous relaxation: instead of picking a single operation per edge, each edge computes a softmax-weighted sum of all candidate operations. The architecture weights (the softmax logits) can then be optimized by gradient descent on the validation loss, while the ordinary model weights are optimized on the training loss (bi-level optimization).

The DARTS Cell

class MixedOp(nn.Module):
    """Mixed operation with learnable architecture weights."""
    
    def __init__(self, C: int, stride: int):
        super().__init__()
        self._ops = nn.ModuleList()
        
        for name, op_fn in OPERATIONS.items():
            op = op_fn(C, stride)
            self._ops.append(op)
    
    def forward(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        """
        Weighted sum of all operations.
        weights: softmax of architecture parameters
        """
        return sum(w * op(x) for w, op in zip(weights, self._ops))


class DARTSCell(nn.Module):
    """
    DARTS cell with learnable architecture.
    
    A cell is a DAG of nodes where:
    - Node 0, 1: input nodes (from previous cells)
    - Node 2, 3, ...: intermediate nodes
    - Output: concatenation of all intermediate nodes
    """
    
    def __init__(
        self,
        num_nodes: int,
        C_prev_prev: int,
        C_prev: int,
        C: int,
        reduction: bool,
        reduction_prev: bool
    ):
        super().__init__()
        
        self.num_nodes = num_nodes
        self.reduction = reduction
        
        # Preprocess the two inputs to C channels. If the previous cell was a
        # reduction cell, s0 has twice the spatial resolution of s1 and must be reduced.
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
        else:
            self.preprocess0 = nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(C_prev_prev, C, 1, bias=False),
                nn.BatchNorm2d(C)
            )
        self.preprocess1 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_prev, C, 1, bias=False),
            nn.BatchNorm2d(C)
        )
        
        # Create a mixed op for each edge; edges leaving the two input nodes
        # use stride 2 in reduction cells
        self._ops = nn.ModuleList()
        for i in range(num_nodes):
            for j in range(i + 2):  # Connect to all previous nodes
                stride = 2 if reduction and j < 2 else 1
                self._ops.append(MixedOp(C, stride))
    
    def forward(
        self, 
        s0: torch.Tensor, 
        s1: torch.Tensor, 
        weights: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            s0: Output from cell k-2
            s1: Output from cell k-1
            weights: Architecture weights [num_edges, num_ops]
        """
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)
        
        states = [s0, s1]
        
        offset = 0
        for i in range(self.num_nodes):
            # Sum weighted outputs from all previous nodes
            s = sum(
                self._ops[offset + j](states[j], weights[offset + j])
                for j in range(i + 2)
            )
            offset += i + 2
            states.append(s)
        
        # Concatenate all intermediate nodes
        return torch.cat(states[2:], dim=1)


class DARTSNetwork(nn.Module):
    """Complete DARTS searchable network."""
    
    def __init__(
        self,
        C: int = 16,
        num_classes: int = 10,
        num_layers: int = 8,
        num_nodes: int = 4,
        num_ops: int = 8
    ):
        super().__init__()
        
        self.num_layers = num_layers
        self.num_nodes = num_nodes
        
        # Initial stem
        self.stem = nn.Sequential(
            nn.Conv2d(3, C, 3, padding=1, bias=False),
            nn.BatchNorm2d(C)
        )
        
        # Build cells, tracking the channel counts each cell receives
        self.cells = nn.ModuleList()
        
        C_prev_prev, C_prev, C_curr = C, C, C
        reduction_prev = False
        reduction_layers = [num_layers // 3, 2 * num_layers // 3]
        
        for i in range(num_layers):
            reduction = i in reduction_layers
            if reduction:
                C_curr *= 2
            
            cell = DARTSCell(
                num_nodes, C_prev_prev, C_prev, C_curr, reduction, reduction_prev
            )
            self.cells.append(cell)
            
            reduction_prev = reduction
            # A cell outputs the concatenation of num_nodes states with C_curr channels each
            C_prev_prev, C_prev = C_prev, num_nodes * C_curr
        
        # Classifier
        self.classifier = nn.Linear(C_prev, num_classes)
        
        # Architecture parameters
        num_edges = sum(i + 2 for i in range(num_nodes))
        self.arch_params_normal = nn.Parameter(torch.randn(num_edges, num_ops) * 1e-3)
        self.arch_params_reduce = nn.Parameter(torch.randn(num_edges, num_ops) * 1e-3)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        s0 = s1 = self.stem(x)
        
        for cell in self.cells:
            if cell.reduction:
                weights = F.softmax(self.arch_params_reduce, dim=-1)
            else:
                weights = F.softmax(self.arch_params_normal, dim=-1)
            
            s0, s1 = s1, cell(s0, s1, weights)
        
        out = F.adaptive_avg_pool2d(s1, 1)
        out = out.view(out.size(0), -1)
        return self.classifier(out)
    
    def arch_parameters(self):
        return [self.arch_params_normal, self.arch_params_reduce]
    
    def model_parameters(self):
        return [p for n, p in self.named_parameters() if 'arch_params' not in n]


class DARTSTrainer:
    """Bi-level optimization for DARTS."""
    
    def __init__(
        self,
        model: DARTSNetwork,
        train_loader,
        val_loader,
        lr_model: float = 0.025,
        lr_arch: float = 3e-4
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        
        # Two optimizers
        self.optimizer_model = torch.optim.SGD(
            model.model_parameters(),
            lr=lr_model,
            momentum=0.9,
            weight_decay=3e-4
        )
        
        self.optimizer_arch = torch.optim.Adam(
            model.arch_parameters(),
            lr=lr_arch,
            weight_decay=1e-3
        )
    
    def train_epoch(self):
        """One epoch of bi-level optimization."""
        
        self.model.train()
        train_iter = iter(self.train_loader)
        val_iter = iter(self.val_loader)
        
        for step in range(len(self.train_loader)):
            # Get training batch
            try:
                train_x, train_y = next(train_iter)
            except StopIteration:
                train_iter = iter(self.train_loader)
                train_x, train_y = next(train_iter)
            
            # Get validation batch
            try:
                val_x, val_y = next(val_iter)
            except StopIteration:
                val_iter = iter(self.val_loader)
                val_x, val_y = next(val_iter)
            
            train_x, train_y = train_x.cuda(), train_y.cuda()
            val_x, val_y = val_x.cuda(), val_y.cuda()
            
            # Update architecture parameters on validation data
            self.optimizer_arch.zero_grad()
            val_loss = F.cross_entropy(self.model(val_x), val_y)
            val_loss.backward()
            self.optimizer_arch.step()
            
            # Update model parameters on training data
            self.optimizer_model.zero_grad()
            train_loss = F.cross_entropy(self.model(train_x), train_y)
            train_loss.backward()
            self.optimizer_model.step()
    
    def derive_architecture(self) -> Dict:
        """Extract discrete architecture from continuous weights."""
        
        def parse_cell(weights):
            """Select top-2 operations for each node."""
            gene = []
            
            offset = 0
            for i in range(self.model.num_nodes):
                # Get weights for edges to this node
                node_weights = weights[offset:offset + i + 2]
                
                # Select the top-2 incoming edges, scored by each edge's strongest
                # op weight (the original DARTS excludes 'none' when scoring)
                edge_scores = node_weights.max(dim=-1)[0]
                top2 = edge_scores.topk(2)[1]
                
                for j in top2:
                    op_idx = node_weights[j].argmax().item()
                    op_name = list(OPERATIONS.keys())[op_idx]
                    gene.append((op_name, j.item()))
                
                offset += i + 2
            
            return gene
        
        with torch.no_grad():
            normal = parse_cell(F.softmax(self.model.arch_params_normal, dim=-1))
            reduce = parse_cell(F.softmax(self.model.arch_params_reduce, dim=-1))
        
        return {'normal': normal, 'reduce': reduce}
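
Putting the pieces together, a search run might look like the sketch below (data loaders, epoch count, and GPU availability are assumed; the original DARTS additionally uses a second-order approximation for the architecture gradient, which this simplified trainer omits):
# Hypothetical usage of the search phase
model = DARTSNetwork(C=16, num_classes=10, num_layers=8).cuda()
trainer = DARTSTrainer(model, train_loader, val_loader)

for epoch in range(50):
    trainer.train_epoch()

genotype = trainer.derive_architecture()
print(genotype['normal'])  # e.g. [('sep_conv_3x3', 0), ('skip_connect', 1), ...]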

Evolutionary NAS

Use evolutionary algorithms to search for architectures: keep a population of candidate architectures, create children by mutating parents chosen via tournament selection, and repeatedly discard the oldest individual (regularized evolution).
@dataclass
class Individual:
    """An individual in the population (an architecture)."""
    architecture: Dict
    fitness: float = 0.0
    age: int = 0


class EvolutionaryNAS:
    """Regularized evolution for architecture search."""
    
    def __init__(
        self,
        search_space: Dict,
        population_size: int = 100,
        tournament_size: int = 10,
        mutation_prob: float = 0.1
    ):
        self.search_space = search_space
        self.population_size = population_size
        self.tournament_size = tournament_size
        self.mutation_prob = mutation_prob
        
        self.population: List[Individual] = []
        self.history: List[Individual] = []
    
    def initialize_population(self):
        """Create initial random population."""
        for _ in range(self.population_size):
            arch = self.sample_random_architecture()
            self.population.append(Individual(architecture=arch))
    
    def sample_random_architecture(self) -> Dict:
        """Sample a random architecture from search space."""
        arch = {}
        for key, choices in self.search_space.items():
            arch[key] = random.choice(choices)
        return arch
    
    def mutate(self, parent: Dict) -> Dict:
        """Mutate an architecture."""
        child = parent.copy()
        
        # Randomly select one gene to mutate
        key = random.choice(list(self.search_space.keys()))
        choices = self.search_space[key]
        child[key] = random.choice(choices)
        
        return child
    
    def tournament_select(self) -> Individual:
        """Select parent via tournament selection."""
        candidates = random.sample(self.population, self.tournament_size)
        return max(candidates, key=lambda x: x.fitness)
    
    def evolve(self, num_generations: int, evaluator: PerformanceEstimator, dataset):
        """Run evolutionary search."""
        
        # Initialize
        self.initialize_population()
        
        # Evaluate initial population
        for ind in self.population:
            ind.fitness = evaluator.estimate(ind.architecture, dataset)
        
        best_fitness = max(ind.fitness for ind in self.population)
        print(f"Generation 0: Best fitness = {best_fitness:.4f}")
        
        for gen in range(num_generations):
            # Select parent
            parent = self.tournament_select()
            
            # Mutate to create child
            child_arch = self.mutate(parent.architecture)
            
            # Evaluate child
            child_fitness = evaluator.estimate(child_arch, dataset)
            child = Individual(architecture=child_arch, fitness=child_fitness)
            
            # Add to population
            self.population.append(child)
            
            # Remove the oldest individual (regularized evolution keeps a FIFO population)
            oldest = max(self.population, key=lambda x: x.age)
            self.population.remove(oldest)
            
            # Age all individuals
            for ind in self.population:
                ind.age += 1
            
            # Track history
            self.history.append(child)
            
            # Report
            if (gen + 1) % 50 == 0:
                best_fitness = max(ind.fitness for ind in self.population)
                print(f"Generation {gen+1}: Best fitness = {best_fitness:.4f}")
        
        # Return best
        return max(self.population, key=lambda x: x.fitness)


# Example search space
CELL_SEARCH_SPACE = {
    'op_0_0': list(OPERATIONS.keys()),
    'op_0_1': list(OPERATIONS.keys()),
    'op_1_0': list(OPERATIONS.keys()),
    'op_1_1': list(OPERATIONS.keys()),
    'op_1_2': list(OPERATIONS.keys()),
    'op_1_3': list(OPERATIONS.keys()),
    # ... more edges
}
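
A minimal usage sketch (the dataset handle cifar10 and the early-stopping evaluator, with its placeholder training helpers, are assumed):
nas = EvolutionaryNAS(CELL_SEARCH_SPACE, population_size=50, tournament_size=10)
evaluator = EarlyStoppingEstimator()
best = nas.evolve(num_generations=500, evaluator=evaluator, dataset=cifar10)
print(best.architecture, best.fitness)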

Weight Sharing / One-Shot NAS

Train a supernet once, then search by selecting subnetworks.
class SuperNet(nn.Module):
    """
    Supernet with all possible operations.
    Different architectures share weights.
    """
    
    def __init__(self, C: int = 16, num_classes: int = 10, num_layers: int = 8):
        super().__init__()
        
        self.num_layers = num_layers
        
        # Stem
        self.stem = nn.Sequential(
            nn.Conv2d(3, C, 3, padding=1, bias=False),
            nn.BatchNorm2d(C)
        )
        
        # Choice blocks
        self.layers = nn.ModuleList()
        
        for i in range(num_layers):
            stride = 2 if i in [num_layers // 3, 2 * num_layers // 3] else 1
            layer = ChoiceBlock(C, C if stride == 1 else C * 2, stride)
            self.layers.append(layer)
            
            if stride == 2:
                C *= 2
        
        # Classifier
        self.classifier = nn.Linear(C, num_classes)
    
    def forward(self, x: torch.Tensor, architecture: List[int]) -> torch.Tensor:
        """
        Forward with specific architecture.
        architecture: List of operation indices, one per layer.
        """
        x = self.stem(x)
        
        for layer, op_idx in zip(self.layers, architecture):
            x = layer(x, op_idx)
        
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        return self.classifier(x)
    
    def sample_architecture(self) -> List[int]:
        """Sample random architecture."""
        return [random.randint(0, len(OPERATIONS) - 1) for _ in range(self.num_layers)]


class ChoiceBlock(nn.Module):
    """Block with multiple operation choices."""
    
    def __init__(self, C_in: int, C_out: int, stride: int):
        super().__init__()
        
        self.ops = nn.ModuleList([
            op_fn(C_in, stride) for name, op_fn in OPERATIONS.items()
        ])
        
        # 1x1 projection to change the channel count; the chosen op already applies
        # the stride, so the projection itself keeps stride 1
        self.proj = None
        if C_in != C_out:
            self.proj = nn.Sequential(
                nn.Conv2d(C_in, C_out, 1, bias=False),
                nn.BatchNorm2d(C_out)
            )
    
    def forward(self, x: torch.Tensor, op_idx: int) -> torch.Tensor:
        out = self.ops[op_idx](x)
        
        if self.proj is not None:
            out = self.proj(out)
        
        return out


class OneShotNAS:
    """One-shot NAS with weight sharing."""
    
    def __init__(
        self,
        supernet: SuperNet,
        train_loader,
        val_loader,
        lr: float = 0.025
    ):
        self.supernet = supernet
        self.train_loader = train_loader
        self.val_loader = val_loader
        
        self.optimizer = torch.optim.SGD(
            supernet.parameters(),
            lr=lr,
            momentum=0.9,
            weight_decay=3e-4
        )
    
    def train_supernet(self, epochs: int = 50):
        """Train supernet with random architecture sampling."""
        
        for epoch in range(epochs):
            self.supernet.train()
            
            for x, y in self.train_loader:
                x, y = x.cuda(), y.cuda()
                
                # Sample random architecture
                arch = self.supernet.sample_architecture()
                
                # Forward
                self.optimizer.zero_grad()
                out = self.supernet(x, arch)
                loss = F.cross_entropy(out, y)
                loss.backward()
                self.optimizer.step()
            
            # Evaluate random architecture
            val_acc = self.evaluate_architecture(self.supernet.sample_architecture())
            print(f"Epoch {epoch+1}: Val accuracy = {val_acc:.2f}%")
    
    def evaluate_architecture(self, architecture: List[int]) -> float:
        """Evaluate a specific architecture on validation set."""
        
        self.supernet.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.cuda(), y.cuda()
                out = self.supernet(x, architecture)
                pred = out.argmax(dim=1)
                correct += (pred == y).sum().item()
                total += y.size(0)
        
        return 100 * correct / total
    
    def search(self, num_samples: int = 1000) -> Tuple[List[int], float]:
        """Random search over trained supernet."""
        
        best_arch = None
        best_acc = 0
        
        for _ in range(num_samples):
            arch = self.supernet.sample_architecture()
            acc = self.evaluate_architecture(arch)
            
            if acc > best_acc:
                best_arch = arch
                best_acc = acc
        
        return best_arch, best_acc
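
A typical one-shot workflow, sketched under the same data-loader and GPU assumptions as above: train the shared weights once, then evaluate many subnetworks cheaply. The best architecture found this way is usually retrained from scratch before deployment, since weight sharing can mis-rank strong subnetworks.
supernet = SuperNet(C=16, num_classes=10, num_layers=8).cuda()
nas = OneShotNAS(supernet, train_loader, val_loader)

nas.train_supernet(epochs=50)            # one expensive training run
best_arch, best_acc = nas.search(1000)   # many cheap evaluations, no retraining
print(best_arch, best_acc)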

Hardware-Aware NAS

Optimize for real hardware constraints.
class HardwareAwareNAS:
    """NAS with hardware constraints."""
    
    def __init__(
        self,
        search_space: Dict,
        target_latency_ms: float = 10.0,
        target_flops: float = 300e6
    ):
        self.search_space = search_space
        self.target_latency = target_latency_ms
        self.target_flops = target_flops
        
        # Latency lookup table (measured on target hardware)
        self.latency_lut = self._build_latency_lut()
    
    def _build_latency_lut(self) -> Dict:
        """Build latency lookup table for operations."""
        lut = {}
        for op_name in OPERATIONS.keys():
            # Measure latency for each operation
            lut[op_name] = self._measure_latency(op_name)
        return lut
    
    def _measure_latency(self, op_name: str, input_size: Tuple = (1, 64, 56, 56)) -> float:
        """Measure operation latency on target device."""
        import time
        
        op = OPERATIONS[op_name](64, 1)
        op.eval()
        x = torch.randn(*input_size)
        
        with torch.no_grad():
            # Warmup
            for _ in range(10):
                _ = op(x)
            
            # Measure
            times = []
            for _ in range(100):
                start = time.time()
                _ = op(x)
                times.append((time.time() - start) * 1000)  # ms
        
        return np.median(times)
    
    def estimate_latency(self, architecture: Dict) -> float:
        """Estimate total latency of architecture."""
        total = 0
        for key, op_name in architecture.items():
            total += self.latency_lut[op_name]
        return total
    
    def estimate_flops(self, architecture: Dict, input_size: Tuple = (224, 224)) -> float:
        """Estimate FLOPs of architecture."""
        # Use analytical formulas or libraries like fvcore
        pass
    
    def multi_objective_fitness(
        self,
        architecture: Dict,
        accuracy: float,
        alpha: float = 0.1,
        beta: float = 0.1
    ) -> float:
        """
        Multi-objective fitness: accuracy with latency/FLOPs penalty.
        """
        latency = self.estimate_latency(architecture)
        flops = self.estimate_flops(architecture) or 0
        
        # Soft penalty
        latency_penalty = max(0, latency - self.target_latency) / self.target_latency
        flops_penalty = max(0, flops - self.target_flops) / self.target_flops
        
        fitness = accuracy - alpha * latency_penalty - beta * flops_penalty
        
        return fitness
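
estimate_flops is left as a stub above; for plain convolutions the count can be computed analytically rather than measured. A sketch of the per-layer formula (output spatial size must be supplied; multiply-accumulates are counted as two FLOPs):
def conv2d_flops(C_in, C_out, kernel_size, H_out, W_out, groups=1):
    """FLOPs of one conv layer: 2 * MACs, MACs = k^2 * (C_in / groups) * C_out * H_out * W_out."""
    macs = (kernel_size ** 2) * (C_in // groups) * C_out * H_out * W_out
    return 2.0 * macs

# Example: a 3x3 conv, 64 -> 128 channels, on a 56x56 output map
print(conv2d_flops(64, 128, 3, 56, 56) / 1e9, "GFLOPs")  # roughly 0.46 GFLOPs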


class ProxylessNAS:
    """
    ProxylessNAS: directly search on target task and hardware.
    Uses path-level binarization to reduce memory.
    """
    
    def __init__(self, supernet: SuperNet, latency_lut: Dict, target_latency: float):
        self.supernet = supernet
        self.latency_lut = latency_lut
        self.target_latency = target_latency
        
        # Architecture parameters
        self.arch_params = nn.ParameterList([
            nn.Parameter(torch.zeros(len(OPERATIONS)))
            for _ in range(supernet.num_layers)
        ])
    
    def sample_path(self) -> Tuple[List[int], torch.Tensor]:
        """Sample a single path (binary architecture)."""
        path = []
        log_probs = []
        
        for params in self.arch_params:
            probs = F.softmax(params, dim=0)
            dist = torch.distributions.Categorical(probs)
            idx = dist.sample()
            path.append(idx.item())
            log_probs.append(dist.log_prob(idx))
        
        return path, torch.stack(log_probs).sum()
    
    def latency_loss(self, path: List[int]) -> torch.Tensor:
        """Compute latency loss for path."""
        latency = sum(
            self.latency_lut[list(OPERATIONS.keys())[idx]]
            for idx in path
        )
        return F.relu(torch.tensor(latency - self.target_latency))
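
Because only one path is active per step, the discrete choice blocks the gradient to the architecture parameters. One simple way to update them is a REINFORCE-style estimator on the sampled path, sketched below (arch_optimizer, e.g. Adam over nas.arch_params, and the validation batch are assumed; ProxylessNAS itself uses a binarized gradient-based variant rather than plain REINFORCE):
def arch_step(nas: ProxylessNAS, val_x, val_y, arch_optimizer, lam: float = 0.1):
    """One policy-gradient update of the architecture parameters."""
    path, log_prob = nas.sample_path()
    with torch.no_grad():
        val_loss = F.cross_entropy(nas.supernet(val_x, path), val_y)
    # Reward: low validation loss and low latency are both good
    reward = -(val_loss + lam * nas.latency_loss(path))
    loss = -log_prob * reward  # surrogate loss whose gradient is the REINFORCE estimator
    arch_optimizer.zero_grad()
    loss.backward()
    arch_optimizer.step()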

Practical NAS: Once-for-All

class OnceForAllNetwork(nn.Module):
    """
    Once-for-All: Train once, deploy everywhere.
    Supports dynamic depth, width, and resolution.
    """
    
    def __init__(
        self,
        base_channels: int = 64,
        max_layers: int = 12,
        width_mult_list: List[float] = [0.5, 0.75, 1.0],
        depth_list: List[int] = [2, 3, 4]
    ):
        super().__init__()
        
        self.max_layers = max_layers
        self.width_mult_list = width_mult_list
        self.depth_list = depth_list
        
        # Stem maps the RGB input to the base channel count
        self.stem = nn.Conv2d(3, base_channels, 3, padding=1, bias=False)
        
        # Build the maximum network; subnets are sliced out of it at runtime
        self.layers = nn.ModuleList()
        C = base_channels
        
        for i in range(max_layers):
            # Each layer supports multiple widths
            layer = ElasticConvBlock(
                C, C,
                width_mult_list=width_mult_list
            )
            self.layers.append(layer)
        
        # Classifier supports multiple widths
        max_width = int(C * max(width_mult_list))
        self.classifier = nn.Linear(max_width, 1000)
    
    def forward(
        self,
        x: torch.Tensor,
        depth: int,
        width_mult: float
    ) -> torch.Tensor:
        """
        Forward with specific subnet configuration.
        """
        x = self.stem(x)
        
        # Use only the first `depth` layers
        for i in range(depth):
            x = self.layers[i](x, width_mult)
        
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        
        # Slice classifier weights to the active width
        active_features = int(self.classifier.in_features * width_mult)
        weight = self.classifier.weight[:, :active_features]
        x = F.linear(x[:, :active_features], weight, self.classifier.bias)
        
        return x
    
    def sample_subnet(self) -> Dict:
        """Sample a random subnet configuration."""
        return {
            'depth': random.choice(self.depth_list),
            'width_mult': random.choice(self.width_mult_list)
        }


class ElasticConvBlock(nn.Module):
    """Convolution block that supports elastic width."""
    
    def __init__(self, C_in: int, C_out: int, width_mult_list: List[float]):
        super().__init__()
        
        max_C_out = int(C_out * max(width_mult_list))
        
        self.conv = nn.Conv2d(C_in, max_C_out, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(max_C_out)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x: torch.Tensor, width_mult: float) -> torch.Tensor:
        # Slice the kernel to the active input/output channels and convolve
        active_out = int(self.conv.out_channels * width_mult)
        weight = self.conv.weight[:active_out, :x.size(1)]
        out = F.conv2d(x, weight, stride=1, padding=1)
        
        # Normalize using only the statistics of the active channels
        out = F.batch_norm(
            out,
            self.bn.running_mean[:active_out],
            self.bn.running_var[:active_out],
            self.bn.weight[:active_out],
            self.bn.bias[:active_out],
            training=self.training
        )
        
        return self.relu(out)
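
Training the once-for-all supernet typically samples a different subnet per batch so that every depth/width combination receives gradient updates. The full method trains progressively (largest subnets first) and distills from the largest subnet; the sketch below shows only the random-sampling core, with data loaders and a GPU assumed:
ofa = OnceForAllNetwork().cuda()
optimizer = torch.optim.SGD(ofa.parameters(), lr=0.05, momentum=0.9)

for x, y in train_loader:
    x, y = x.cuda(), y.cuda()
    cfg = ofa.sample_subnet()  # e.g. {'depth': 3, 'width_mult': 0.75}
    out = ofa(x, depth=cfg['depth'], width_mult=cfg['width_mult'])
    loss = F.cross_entropy(out, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()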

Exercises

Implement random search baseline:
def random_search_nas(search_space, evaluator, n_trials=1000):
    best_arch = None
    best_score = 0
    
    for _ in range(n_trials):
        arch = sample_random(search_space)
        score = evaluator(arch)
        if score > best_score:
            best_arch, best_score = arch, score
    
    return best_arch
Train a performance predictor to speed up search:
class PerformancePredictor(nn.Module):
    def __init__(self, encoding_dim):
        super().__init__()
        # Encode the architecture as a fixed-length vector
        # Predict accuracy from the encoding
        pass
Implement Pareto-optimal architecture search:
def pareto_optimal_search(population, objectives):
    # objectives = ['accuracy', 'latency', 'params']
    # Find the Pareto frontier
    # Return the non-dominated architectures
    ...

What’s Next?