Neural Architecture Search: Automating Design
Why Automate Architecture Design?
Designing neural networks by hand is:
- Time-consuming: weeks of human experimentation
- Expertise-dependent: requires deep intuition
- Suboptimal: humans explore only a limited search space

Neural Architecture Search (NAS) can instead:
- Explore vast architecture spaces systematically
- Find novel, non-intuitive designs
- Optimize for specific hardware constraints
- Discover architectures that outperform human designs
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
torch.manual_seed(42)
NAS Components
1. Search Space
Define what architectures are possible:
# Operation choices for each edge
OPERATIONS = {
'none': lambda C, stride: Zero(stride),
'skip_connect': lambda C, stride: Identity() if stride == 1 else FactorizedReduce(C, C),
'sep_conv_3x3': lambda C, stride: SepConv(C, C, 3, stride, 1),
'sep_conv_5x5': lambda C, stride: SepConv(C, C, 5, stride, 2),
'dil_conv_3x3': lambda C, stride: DilConv(C, C, 3, stride, 2, 2),
'dil_conv_5x5': lambda C, stride: DilConv(C, C, 5, stride, 4, 2),
'avg_pool_3x3': lambda C, stride: nn.AvgPool2d(3, stride=stride, padding=1),
'max_pool_3x3': lambda C, stride: nn.MaxPool2d(3, stride=stride, padding=1),
}
class SepConv(nn.Module):
"""Separable convolution."""
def __init__(self, C_in, C_out, kernel_size, stride, padding):
super().__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size, stride, padding, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, 1, bias=False),
nn.BatchNorm2d(C_out),
nn.ReLU(inplace=False),
nn.Conv2d(C_out, C_out, kernel_size, 1, padding, groups=C_out, bias=False),
nn.Conv2d(C_out, C_out, 1, bias=False),
nn.BatchNorm2d(C_out)
)
def forward(self, x):
return self.op(x)
class DilConv(nn.Module):
"""Dilated convolution."""
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
super().__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, 1, bias=False),
nn.BatchNorm2d(C_out)
)
def forward(self, x):
return self.op(x)
class Identity(nn.Module):
def forward(self, x):
return x
class Zero(nn.Module):
def __init__(self, stride):
super().__init__()
self.stride = stride
def forward(self, x):
if self.stride == 1:
return x * 0
return x[:, :, ::self.stride, ::self.stride] * 0
class FactorizedReduce(nn.Module):
"""Reduce spatial size while maintaining channels."""
def __init__(self, C_in, C_out):
super().__init__()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, bias=False)
self.bn = nn.BatchNorm2d(C_out)
def forward(self, x):
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
return self.bn(out)
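A quick sanity check on one candidate operation (a small sketch; the channel count and input size are illustrative):

# Instantiate one op from the table and run a dummy tensor through it
op = OPERATIONS['sep_conv_3x3'](16, 1)      # 16 channels, stride 1
y = op(torch.randn(2, 16, 32, 32))
print(y.shape)                              # torch.Size([2, 16, 32, 32])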
2. Search Strategy
How to explore the search space:
class SearchStrategy:
"""Base class for NAS search strategies."""
def __init__(self, search_space: Dict):
self.search_space = search_space
def sample_architecture(self) -> Dict:
"""Sample a random architecture from the search space."""
raise NotImplementedError
def update(self, architecture: Dict, performance: float):
"""Update search strategy based on evaluation."""
raise NotImplementedError
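A minimal concrete strategy is plain random sampling (a sketch; it assumes the search space is a dict mapping decision names to lists of choices, as in CELL_SEARCH_SPACE later on):

class RandomSearchStrategy(SearchStrategy):
    """Samples architectures uniformly at random and ignores feedback."""
    def sample_architecture(self) -> Dict:
        # One independent random choice per decision in the search space
        return {key: random.choice(choices) for key, choices in self.search_space.items()}

    def update(self, architecture: Dict, performance: float):
        # Random search is stateless, so there is nothing to update
        pass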
3. Performance Estimation
How to evaluate candidate architectures efficiently:
class PerformanceEstimator:
"""Base class for performance estimation strategies."""
def estimate(self, architecture: Dict, dataset) -> float:
"""Estimate performance of an architecture."""
raise NotImplementedError
# NOTE: build_model, train and evaluate below are placeholders for your own
# model-construction, training and evaluation routines.
class FullTrainingEstimator(PerformanceEstimator):
    """Train the model fully - slow but accurate."""
    def estimate(self, architecture, dataset, epochs=100):
        model = build_model(architecture)
        train(model, dataset, epochs)
        return evaluate(model, dataset)
class EarlyStoppingEstimator(PerformanceEstimator):
"""Train for few epochs - faster but less accurate."""
def estimate(self, architecture, dataset, epochs=10):
model = build_model(architecture)
train(model, dataset, epochs)
return evaluate(model, dataset)
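A generic search loop then just alternates sampling, estimation, and strategy updates. This is a sketch; dataset and num_trials are assumptions here:

def run_search(
    strategy: SearchStrategy,
    estimator: PerformanceEstimator,
    dataset,
    num_trials: int = 100
) -> Tuple[Dict, float]:
    """Sample, evaluate, and keep the best architecture seen so far."""
    best_arch, best_score = None, float('-inf')
    for _ in range(num_trials):
        arch = strategy.sample_architecture()
        score = estimator.estimate(arch, dataset)
        strategy.update(arch, score)   # no-op for random search
        if score > best_score:
            best_arch, best_score = arch, score
    return best_arch, best_score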
DARTS: Differentiable Architecture Search
Make architecture search differentiable by using a continuous relaxation of the discrete operation choice.
The DARTS Cell
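Each edge of the cell computes a softmax-weighted mixture of all candidate operations,

$$\bar{o}^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp(\alpha_o^{(i,j)})}{\sum_{o' \in \mathcal{O}} \exp(\alpha_{o'}^{(i,j)})}\, o(x),$$

so the architecture parameters $\alpha$ can be learned by gradient descent. DARTS optimizes $\alpha$ on validation data and the network weights $w$ on training data, a bi-level problem:

$$\min_{\alpha} \; \mathcal{L}_{val}\big(w^*(\alpha), \alpha\big) \quad \text{s.t.} \quad w^*(\alpha) = \arg\min_{w} \mathcal{L}_{train}(w, \alpha)$$

MixedOp below implements the weighted mixture; DARTSTrainer further down implements a first-order approximation of the bi-level optimization.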
class MixedOp(nn.Module):
"""Mixed operation with learnable architecture weights."""
def __init__(self, C: int, stride: int):
super().__init__()
self._ops = nn.ModuleList()
for name, op_fn in OPERATIONS.items():
op = op_fn(C, stride)
self._ops.append(op)
def forward(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
"""
Weighted sum of all operations.
weights: softmax of architecture parameters
"""
return sum(w * op(x) for w, op in zip(weights, self._ops))
class DARTSCell(nn.Module):
"""
DARTS cell with learnable architecture.
A cell is a DAG of nodes where:
- Node 0, 1: input nodes (from previous cells)
- Node 2, 3, ...: intermediate nodes
- Output: concatenation of all intermediate nodes
"""
    def __init__(
        self,
        num_nodes: int,
        C_prev_prev: int,
        C_prev: int,
        C: int,
        reduction: bool,
        reduction_prev: bool
    ):
        super().__init__()
        self.num_nodes = num_nodes
        self.reduction = reduction
        self._ops = nn.ModuleList()
        # Create mixed ops for each edge
        for i in range(num_nodes):
            for j in range(i + 2):  # Connect to all previous nodes
                stride = 2 if reduction and j < 2 else 1
                self._ops.append(MixedOp(C, stride))
        # Preprocess both inputs to C channels; if the previous cell reduced
        # the spatial size, s0 (from cell k-2) must be reduced to match
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
        else:
            self.preprocess0 = nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(C_prev_prev, C, 1, bias=False),
                nn.BatchNorm2d(C)
            )
        self.preprocess1 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_prev, C, 1, bias=False),
            nn.BatchNorm2d(C)
        )
def forward(
self,
s0: torch.Tensor,
s1: torch.Tensor,
weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
s0: Output from cell k-2
s1: Output from cell k-1
weights: Architecture weights [num_edges, num_ops]
"""
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
offset = 0
for i in range(self.num_nodes):
# Sum weighted outputs from all previous nodes
s = sum(
self._ops[offset + j](states[j], weights[offset + j])
for j in range(i + 2)
)
offset += i + 2
states.append(s)
# Concatenate all intermediate nodes
return torch.cat(states[2:], dim=1)
class DARTSNetwork(nn.Module):
"""Complete DARTS searchable network."""
def __init__(
self,
C: int = 16,
num_classes: int = 10,
num_layers: int = 8,
num_nodes: int = 4,
num_ops: int = 8
):
super().__init__()
self.num_layers = num_layers
self.num_nodes = num_nodes
# Initial stem
self.stem = nn.Sequential(
nn.Conv2d(3, C, 3, padding=1, bias=False),
nn.BatchNorm2d(C)
)
        # Build cells, tracking channel counts: each cell outputs
        # num_nodes * C_curr channels (the concatenation of its intermediate nodes)
        self.cells = nn.ModuleList()
        C_prev_prev, C_prev, C_curr = C, C, C
        reduction_prev = False
        reduction_layers = [num_layers // 3, 2 * num_layers // 3]
        for i in range(num_layers):
            reduction = i in reduction_layers
            if reduction:
                C_curr *= 2
            cell = DARTSCell(num_nodes, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            self.cells.append(cell)
            reduction_prev = reduction
            C_prev_prev, C_prev = C_prev, num_nodes * C_curr
        # Classifier operates on the final cell's concatenated output
        self.classifier = nn.Linear(C_prev, num_classes)
# Architecture parameters
num_edges = sum(i + 2 for i in range(num_nodes))
self.arch_params_normal = nn.Parameter(torch.randn(num_edges, num_ops) * 1e-3)
self.arch_params_reduce = nn.Parameter(torch.randn(num_edges, num_ops) * 1e-3)
def forward(self, x: torch.Tensor) -> torch.Tensor:
s0 = s1 = self.stem(x)
for cell in self.cells:
if cell.reduction:
weights = F.softmax(self.arch_params_reduce, dim=-1)
else:
weights = F.softmax(self.arch_params_normal, dim=-1)
s0, s1 = s1, cell(s0, s1, weights)
out = F.adaptive_avg_pool2d(s1, 1)
out = out.view(out.size(0), -1)
return self.classifier(out)
def arch_parameters(self):
return [self.arch_params_normal, self.arch_params_reduce]
def model_parameters(self):
return [p for n, p in self.named_parameters() if 'arch_params' not in n]
class DARTSTrainer:
"""Bi-level optimization for DARTS."""
def __init__(
self,
model: DARTSNetwork,
train_loader,
val_loader,
lr_model: float = 0.025,
lr_arch: float = 3e-4
):
self.model = model
self.train_loader = train_loader
self.val_loader = val_loader
# Two optimizers
self.optimizer_model = torch.optim.SGD(
model.model_parameters(),
lr=lr_model,
momentum=0.9,
weight_decay=3e-4
)
self.optimizer_arch = torch.optim.Adam(
model.arch_parameters(),
lr=lr_arch,
weight_decay=1e-3
)
def train_epoch(self):
"""One epoch of bi-level optimization."""
self.model.train()
train_iter = iter(self.train_loader)
val_iter = iter(self.val_loader)
for step in range(len(self.train_loader)):
# Get training batch
try:
train_x, train_y = next(train_iter)
except StopIteration:
train_iter = iter(self.train_loader)
train_x, train_y = next(train_iter)
# Get validation batch
try:
val_x, val_y = next(val_iter)
except StopIteration:
val_iter = iter(self.val_loader)
val_x, val_y = next(val_iter)
train_x, train_y = train_x.cuda(), train_y.cuda()
val_x, val_y = val_x.cuda(), val_y.cuda()
# Update architecture parameters on validation data
self.optimizer_arch.zero_grad()
val_loss = F.cross_entropy(self.model(val_x), val_y)
val_loss.backward()
self.optimizer_arch.step()
# Update model parameters on training data
self.optimizer_model.zero_grad()
train_loss = F.cross_entropy(self.model(train_x), train_y)
train_loss.backward()
self.optimizer_model.step()
def derive_architecture(self) -> Dict:
"""Extract discrete architecture from continuous weights."""
def parse_cell(weights):
"""Select top-2 operations for each node."""
gene = []
offset = 0
for i in range(self.model.num_nodes):
# Get weights for edges to this node
node_weights = weights[offset:offset + i + 2]
# Select top-2 edges
edge_scores = node_weights.max(dim=-1)[0]
top2 = edge_scores.topk(2)[1]
for j in top2:
op_idx = node_weights[j].argmax().item()
op_name = list(OPERATIONS.keys())[op_idx]
gene.append((op_name, j.item()))
offset += i + 2
return gene
with torch.no_grad():
normal = parse_cell(F.softmax(self.model.arch_params_normal, dim=-1))
reduce = parse_cell(F.softmax(self.model.arch_params_reduce, dim=-1))
return {'normal': normal, 'reduce': reduce}
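A minimal usage sketch (hypothetical driver code: it assumes CIFAR-10-style train_loader/val_loader and a CUDA device, which are not defined here):

model = DARTSNetwork(C=16, num_classes=10, num_layers=8).cuda()
trainer = DARTSTrainer(model, train_loader, val_loader)

for epoch in range(50):
    trainer.train_epoch()

# Discretize the continuous architecture weights into a fixed genotype
genotype = trainer.derive_architecture()
print(genotype['normal'])
print(genotype['reduce'])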
Evolutionary Search
Use evolutionary algorithms to search for architectures.
@dataclass
class Individual:
"""An individual in the population (an architecture)."""
architecture: Dict
fitness: float = 0.0
age: int = 0
class EvolutionaryNAS:
"""Regularized evolution for architecture search."""
def __init__(
self,
search_space: Dict,
population_size: int = 100,
tournament_size: int = 10,
mutation_prob: float = 0.1
):
self.search_space = search_space
self.population_size = population_size
self.tournament_size = tournament_size
self.mutation_prob = mutation_prob
self.population: List[Individual] = []
self.history: List[Individual] = []
def initialize_population(self):
"""Create initial random population."""
for _ in range(self.population_size):
arch = self.sample_random_architecture()
self.population.append(Individual(architecture=arch))
def sample_random_architecture(self) -> Dict:
"""Sample a random architecture from search space."""
arch = {}
for key, choices in self.search_space.items():
arch[key] = random.choice(choices)
return arch
def mutate(self, parent: Dict) -> Dict:
"""Mutate an architecture."""
child = parent.copy()
# Randomly select one gene to mutate
key = random.choice(list(self.search_space.keys()))
choices = self.search_space[key]
child[key] = random.choice(choices)
return child
def tournament_select(self) -> Individual:
"""Select parent via tournament selection."""
candidates = random.sample(self.population, self.tournament_size)
return max(candidates, key=lambda x: x.fitness)
def evolve(self, num_generations: int, evaluator: PerformanceEstimator, dataset):
"""Run evolutionary search."""
# Initialize
self.initialize_population()
# Evaluate initial population
for ind in self.population:
ind.fitness = evaluator.estimate(ind.architecture, dataset)
best_fitness = max(ind.fitness for ind in self.population)
print(f"Generation 0: Best fitness = {best_fitness:.4f}")
for gen in range(num_generations):
# Select parent
parent = self.tournament_select()
# Mutate to create child
child_arch = self.mutate(parent.architecture)
# Evaluate child
child_fitness = evaluator.estimate(child_arch, dataset)
child = Individual(architecture=child_arch, fitness=child_fitness)
# Add to population
self.population.append(child)
            # Remove the oldest individual (regularized evolution kills by age,
            # not by fitness); the oldest has the highest age counter
            oldest = max(self.population, key=lambda ind: ind.age)
            self.population.remove(oldest)
            # Age the survivors
            for ind in self.population:
                ind.age += 1
# Track history
self.history.append(child)
# Report
if (gen + 1) % 50 == 0:
best_fitness = max(ind.fitness for ind in self.population)
print(f"Generation {gen+1}: Best fitness = {best_fitness:.4f}")
# Return best
return max(self.population, key=lambda x: x.fitness)
# Example search space
CELL_SEARCH_SPACE = {
'op_0_0': list(OPERATIONS.keys()),
'op_0_1': list(OPERATIONS.keys()),
'op_1_0': list(OPERATIONS.keys()),
'op_1_1': list(OPERATIONS.keys()),
'op_1_2': list(OPERATIONS.keys()),
'op_1_3': list(OPERATIONS.keys()),
# ... more edges
}
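A usage sketch with a toy evaluator (a stand-in; a real run would plug in one of the PerformanceEstimator classes above):

# Toy fitness: count non-'none' ops plus noise, just to exercise the loop
class DummyEstimator(PerformanceEstimator):
    def estimate(self, architecture, dataset) -> float:
        return sum(op != 'none' for op in architecture.values()) + random.random()

evo = EvolutionaryNAS(CELL_SEARCH_SPACE, population_size=20, tournament_size=5)
best = evo.evolve(num_generations=200, evaluator=DummyEstimator(), dataset=None)
print(best.architecture, best.fitness)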
Weight Sharing / One-Shot NAS
Train a supernet once, then search by selecting subnetworks.
class SuperNet(nn.Module):
"""
Supernet with all possible operations.
Different architectures share weights.
"""
def __init__(self, C: int = 16, num_classes: int = 10, num_layers: int = 8):
super().__init__()
self.num_layers = num_layers
# Stem
self.stem = nn.Sequential(
nn.Conv2d(3, C, 3, padding=1, bias=False),
nn.BatchNorm2d(C)
)
# Choice blocks
self.layers = nn.ModuleList()
for i in range(num_layers):
stride = 2 if i in [num_layers // 3, 2 * num_layers // 3] else 1
layer = ChoiceBlock(C, C if stride == 1 else C * 2, stride)
self.layers.append(layer)
if stride == 2:
C *= 2
# Classifier
self.classifier = nn.Linear(C, num_classes)
def forward(self, x: torch.Tensor, architecture: List[int]) -> torch.Tensor:
"""
Forward with specific architecture.
architecture: List of operation indices, one per layer.
"""
x = self.stem(x)
for layer, op_idx in zip(self.layers, architecture):
x = layer(x, op_idx)
x = F.adaptive_avg_pool2d(x, 1)
x = x.view(x.size(0), -1)
return self.classifier(x)
def sample_architecture(self) -> List[int]:
"""Sample random architecture."""
return [random.randint(0, len(OPERATIONS) - 1) for _ in range(self.num_layers)]
class ChoiceBlock(nn.Module):
"""Block with multiple operation choices."""
def __init__(self, C_in: int, C_out: int, stride: int):
super().__init__()
self.ops = nn.ModuleList([
op_fn(C_in, stride) for name, op_fn in OPERATIONS.items()
])
        # The candidate ops already apply the stride, so the projection only
        # needs to adjust the channel count (1x1 conv, stride 1)
        self.proj = None
        if C_in != C_out:
            self.proj = nn.Sequential(
                nn.Conv2d(C_in, C_out, 1, bias=False),
                nn.BatchNorm2d(C_out)
            )
def forward(self, x: torch.Tensor, op_idx: int) -> torch.Tensor:
out = self.ops[op_idx](x)
if self.proj is not None:
out = self.proj(out)
return out
class OneShotNAS:
"""One-shot NAS with weight sharing."""
def __init__(
self,
supernet: SuperNet,
train_loader,
val_loader,
lr: float = 0.025
):
self.supernet = supernet
self.train_loader = train_loader
self.val_loader = val_loader
self.optimizer = torch.optim.SGD(
supernet.parameters(),
lr=lr,
momentum=0.9,
weight_decay=3e-4
)
def train_supernet(self, epochs: int = 50):
"""Train supernet with random architecture sampling."""
for epoch in range(epochs):
self.supernet.train()
for x, y in self.train_loader:
x, y = x.cuda(), y.cuda()
# Sample random architecture
arch = self.supernet.sample_architecture()
# Forward
self.optimizer.zero_grad()
out = self.supernet(x, arch)
loss = F.cross_entropy(out, y)
loss.backward()
self.optimizer.step()
# Evaluate random architecture
val_acc = self.evaluate_architecture(self.supernet.sample_architecture())
print(f"Epoch {epoch+1}: Val accuracy = {val_acc:.2f}%")
def evaluate_architecture(self, architecture: List[int]) -> float:
"""Evaluate a specific architecture on validation set."""
self.supernet.eval()
correct = 0
total = 0
with torch.no_grad():
for x, y in self.val_loader:
x, y = x.cuda(), y.cuda()
out = self.supernet(x, architecture)
pred = out.argmax(dim=1)
correct += (pred == y).sum().item()
total += y.size(0)
return 100 * correct / total
def search(self, num_samples: int = 1000) -> Tuple[List[int], float]:
"""Random search over trained supernet."""
best_arch = None
best_acc = 0
for _ in range(num_samples):
arch = self.supernet.sample_architecture()
acc = self.evaluate_architecture(arch)
if acc > best_acc:
best_arch = arch
best_acc = acc
return best_arch, best_acc
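Putting it together (a sketch assuming train_loader/val_loader and a GPU, as in the DARTS example):

supernet = SuperNet(C=16, num_classes=10, num_layers=8).cuda()
nas = OneShotNAS(supernet, train_loader, val_loader)

nas.train_supernet(epochs=50)                      # shared weights, random paths
best_arch, best_acc = nas.search(num_samples=200)  # rank subnets with shared weights
print(best_arch, best_acc)

In practice the best subnet found this way is retrained from scratch before deployment, since supernet accuracy is only a ranking proxy.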
Hardware-Aware NAS
Optimize for real hardware constraints.
class HardwareAwareNAS:
"""NAS with hardware constraints."""
def __init__(
self,
search_space: Dict,
target_latency_ms: float = 10.0,
target_flops: float = 300e6
):
self.search_space = search_space
self.target_latency = target_latency_ms
self.target_flops = target_flops
# Latency lookup table (measured on target hardware)
self.latency_lut = self._build_latency_lut()
def _build_latency_lut(self) -> Dict:
"""Build latency lookup table for operations."""
lut = {}
for op_name in OPERATIONS.keys():
# Measure latency for each operation
lut[op_name] = self._measure_latency(op_name)
return lut
    def _measure_latency(self, op_name: str, input_size: Tuple = (1, 64, 56, 56)) -> float:
        """Measure operation latency on the target device."""
        import time
        op = OPERATIONS[op_name](64, 1)
        op.eval()  # use inference mode for BatchNorm during timing
        x = torch.randn(*input_size)
        with torch.no_grad():
            # Warmup
            for _ in range(10):
                _ = op(x)
            # Measure
            times = []
            for _ in range(100):
                start = time.time()
                _ = op(x)
                times.append((time.time() - start) * 1000)  # ms
        return float(np.median(times))
def estimate_latency(self, architecture: Dict) -> float:
"""Estimate total latency of architecture."""
total = 0
for key, op_name in architecture.items():
total += self.latency_lut[op_name]
return total
    def estimate_flops(self, architecture: Dict, input_size: Tuple = (224, 224)) -> float:
        """Estimate FLOPs of architecture (placeholder)."""
        # Use analytical formulas (see the per-op sketch below) or a profiler such as fvcore
        return 0.0
def multi_objective_fitness(
self,
architecture: Dict,
accuracy: float,
alpha: float = 0.1,
beta: float = 0.1
) -> float:
"""
Multi-objective fitness: accuracy with latency/FLOPs penalty.
"""
latency = self.estimate_latency(architecture)
flops = self.estimate_flops(architecture) or 0
# Soft penalty
latency_penalty = max(0, latency - self.target_latency) / self.target_latency
flops_penalty = max(0, flops - self.target_flops) / self.target_flops
fitness = accuracy - alpha * latency_penalty - beta * flops_penalty
return fitness
class ProxylessNAS:
"""
ProxylessNAS: directly search on target task and hardware.
Uses path-level binarization to reduce memory.
"""
def __init__(self, supernet: SuperNet, latency_lut: Dict, target_latency: float):
self.supernet = supernet
self.latency_lut = latency_lut
self.target_latency = target_latency
# Architecture parameters
self.arch_params = nn.ParameterList([
nn.Parameter(torch.zeros(len(OPERATIONS)))
for _ in range(supernet.num_layers)
])
def sample_path(self) -> Tuple[List[int], torch.Tensor]:
"""Sample a single path (binary architecture)."""
path = []
log_probs = []
for params in self.arch_params:
probs = F.softmax(params, dim=0)
dist = torch.distributions.Categorical(probs)
idx = dist.sample()
path.append(idx.item())
log_probs.append(dist.log_prob(idx))
return path, torch.stack(log_probs).sum()
def latency_loss(self, path: List[int]) -> torch.Tensor:
"""Compute latency loss for path."""
latency = sum(
self.latency_lut[list(OPERATIONS.keys())[idx]]
for idx in path
)
return F.relu(torch.tensor(latency - self.target_latency))
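The estimate_flops placeholder above can be filled in analytically. A minimal sketch, assuming each op runs on a fixed 64-channel 56x56 feature map (matching the latency table) and ignoring stride:

def count_conv_flops(C_in: int, C_out: int, k: int, H: int, W: int, groups: int = 1) -> float:
    """Multiply-accumulate count of a k x k convolution producing an H x W map."""
    return (C_in // groups) * C_out * k * k * H * W

def estimate_op_flops(op_name: str, C: int = 64, H: int = 56, W: int = 56) -> float:
    """Rough per-operation FLOPs, mirroring the latency lookup table."""
    if op_name in ('sep_conv_3x3', 'sep_conv_5x5'):
        k = 3 if op_name == 'sep_conv_3x3' else 5
        # two depthwise k x k convs plus two pointwise 1x1 convs
        return 2 * (count_conv_flops(C, C, k, H, W, groups=C) + count_conv_flops(C, C, 1, H, W))
    if op_name in ('dil_conv_3x3', 'dil_conv_5x5'):
        k = 3 if op_name == 'dil_conv_3x3' else 5
        # one depthwise k x k conv plus one pointwise 1x1 conv
        return count_conv_flops(C, C, k, H, W, groups=C) + count_conv_flops(C, C, 1, H, W)
    # pooling, identity, and 'none' are negligible by comparison
    return 0.0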
Practical NAS: Once-for-All
class OnceForAllNetwork(nn.Module):
"""
Once-for-All: Train once, deploy everywhere.
Supports dynamic depth, width, and resolution.
"""
def __init__(
self,
base_channels: int = 64,
max_layers: int = 12,
width_mult_list: List[float] = [0.5, 0.75, 1.0],
depth_list: List[int] = [2, 3, 4]
):
super().__init__()
self.max_layers = max_layers
self.width_mult_list = width_mult_list
self.depth_list = depth_list
        # Stem maps the 3-channel input to the base width
        self.stem = nn.Sequential(
            nn.Conv2d(3, base_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True)
        )
        # Build maximum network
        self.layers = nn.ModuleList()
        C = base_channels
for i in range(max_layers):
# Each layer supports multiple widths
layer = ElasticConvBlock(
C, C,
width_mult_list=width_mult_list
)
self.layers.append(layer)
# Classifier supports multiple widths
max_width = int(C * max(width_mult_list))
self.classifier = nn.Linear(max_width, 1000)
def forward(
self,
x: torch.Tensor,
depth: int,
width_mult: float
) -> torch.Tensor:
"""
Forward with specific subnet configuration.
"""
        x = self.stem(x)
        # Use only the first `depth` layers
        for i in range(depth):
            x = self.layers[i](x, width_mult)
x = F.adaptive_avg_pool2d(x, 1)
x = x.view(x.size(0), -1)
        # Slice the classifier's input weights to the active feature width
        active_features = int(self.classifier.in_features * width_mult)
        weight = self.classifier.weight[:, :active_features]
        x = F.linear(x[:, :active_features], weight, self.classifier.bias)
return x
def sample_subnet(self) -> Dict:
"""Sample a random subnet configuration."""
return {
'depth': random.choice(self.depth_list),
'width_mult': random.choice(self.width_mult_list)
}
class ElasticConvBlock(nn.Module):
    """Convolution block whose output width can be sliced at run time.
    All widths share the same (maximum-size) kernel, as in Once-for-All."""
    def __init__(self, C_in: int, C_out: int, width_mult_list: List[float]):
        super().__init__()
        self.max_C_out = int(C_out * max(width_mult_list))
        self.conv = nn.Conv2d(C_in, self.max_C_out, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(self.max_C_out)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x: torch.Tensor, width_mult: float) -> torch.Tensor:
        active_C = int(self.max_C_out * width_mult)
        # Slice the shared kernel to the active output channels and to the
        # (possibly already sliced) input channels of x
        weight = self.conv.weight[:active_C, :x.size(1)]
        out = F.conv2d(x, weight, stride=1, padding=1)
        # Batch norm over the active channels only (statistics are shared)
        out = F.batch_norm(
            out,
            self.bn.running_mean[:active_C],
            self.bn.running_var[:active_C],
            self.bn.weight[:active_C],
            self.bn.bias[:active_C],
            training=self.bn.training
        )
        return self.relu(out)
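A sampling sketch (hypothetical; the progressive-shrinking training that Once-for-All uses to make all subnets accurate is omitted here):

# Sample and run a few subnet configurations from the same set of weights
ofa = OnceForAllNetwork(base_channels=64, max_layers=12)
x = torch.randn(2, 3, 224, 224)

for _ in range(3):
    cfg = ofa.sample_subnet()
    logits = ofa(x, depth=cfg['depth'], width_mult=cfg['width_mult'])
    print(cfg, logits.shape)   # torch.Size([2, 1000])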
Exercises
Exercise 1: Implement Random Search NAS
Implement random search baseline:
def random_search_nas(search_space, evaluator, n_trials=1000):
best_arch = None
best_score = 0
for _ in range(n_trials):
arch = sample_random(search_space)
score = evaluator(arch)
if score > best_score:
best_arch, best_score = arch, score
return best_arch
Exercise 2: Add Predictor-Based NAS
Train a performance predictor to speed up search:
class PerformancePredictor(nn.Module):
def __init__(self, encoding_dim):
# Encode architecture as vector
# Predict accuracy from encoding
pass
Exercise 3: Multi-Objective NAS
Implement Pareto-optimal architecture search:
def pareto_optimal_search(population, objectives):
# objectives = ['accuracy', 'latency', 'params']
# Find Pareto frontier
# Return non-dominated architectures