Graph Neural Networks

Why Graphs?

Most neural networks assume data lives on a grid: images are 2D grids of pixels, text is a 1D sequence of tokens, audio is a 1D sequence of samples. But much of the real world does not fit neatly on a grid. A social network is not a grid. A molecule is not a sequence. A road network is not a rectangle. This is where graphs come in. A graph is simply a collection of nodes (things) connected by edges (relationships). This deceptively simple structure can represent an enormous range of real-world data:
  • Social networks: Users (nodes) and friendships (edges)
  • Molecules: Atoms (nodes) and chemical bonds (edges)
  • Knowledge graphs: Entities (nodes) and relations (edges)
  • Citation networks: Papers (nodes) and references (edges)
  • Traffic: Intersections (nodes) and roads (edges)
Graphs capture relational structure that traditional neural networks miss entirely. A standard MLP fed a molecule’s atoms as a flat vector has no idea which atoms are bonded to which. A GNN knows.
If you are new to GNNs, start with PyTorch Geometric (PyG). It provides optimized implementations of all the architectures discussed below, handles batching of variable-size graphs, and includes standard benchmark datasets. Building GNNs from scratch (as we do here for pedagogical purposes) is valuable for understanding, but use PyG for production work.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional, List, Dict

torch.manual_seed(42)

Graph Fundamentals

Graph Representation

class GraphData:
    """
    Graph data structure.
    
    A graph G = (V, E) with:
    - V: Set of nodes (vertices)
    - E: Set of edges
    
    Representations:
    - Adjacency matrix A: A[i,j] = 1 if edge (i,j) exists
    - Edge index: [2, E] tensor of (source, target) pairs
    """
    
    def __init__(
        self,
        x: torch.Tensor,           # Node features [N, F]
        edge_index: torch.Tensor,   # Edge indices [2, E]
        edge_attr: Optional[torch.Tensor] = None,  # Edge features [E, D]
        y: Optional[torch.Tensor] = None           # Labels
    ):
        self.x = x
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.y = y
        
        self.num_nodes = x.size(0)
        self.num_edges = edge_index.size(1)
        self.num_features = x.size(1)
    
    def to_adjacency(self) -> torch.Tensor:
        """Convert edge index to adjacency matrix."""
        adj = torch.zeros(self.num_nodes, self.num_nodes)
        adj[self.edge_index[0], self.edge_index[1]] = 1
        return adj
    
    @staticmethod
    def from_adjacency(adj: torch.Tensor, x: torch.Tensor) -> 'GraphData':
        """Create from adjacency matrix."""
        edge_index = adj.nonzero().t().contiguous()
        return GraphData(x, edge_index)
    
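    def add_self_loops(self) -> 'GraphData':
        """Add self-loop edges (does not check for loops that already exist)."""
        loops = torch.arange(self.num_nodes, device=self.edge_index.device)
        self_loops = loops.unsqueeze(0).repeat(2, 1)
        edge_index = torch.cat([self.edge_index, self_loops], dim=1)
        edge_attr = self.edge_attr
        if edge_attr is not None:
            # Keep edge_attr aligned with edge_index. Zero features for the new
            # self-loops are an arbitrary placeholder choice.
            pad = edge_attr.new_zeros(self.num_nodes, *edge_attr.shape[1:])
            edge_attr = torch.cat([edge_attr, pad], dim=0)
        return GraphData(self.x, edge_index, edge_attr, self.y)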


# Example: Create a simple graph
def create_example_graph():
    # 5 nodes with 3 features each
    x = torch.randn(5, 3)
    
    # Edges: 0->1, 0->2, 1->2, 2->3, 3->4
    edge_index = torch.tensor([
        [0, 0, 1, 2, 3],
        [1, 2, 2, 3, 4]
    ])
    
    # Make undirected (add reverse edges)
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
    
    return GraphData(x, edge_index)

graph = create_example_graph()
print(f"Nodes: {graph.num_nodes}, Edges: {graph.num_edges}")
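
To make the two representations concrete, here is a minimal sketch that converts a small edge list to a dense adjacency matrix and back using plain PyTorch. The helper names `edge_index_to_adj` and `adj_to_edge_index` are illustrative only and do not come from any library.

```python
import torch

def edge_index_to_adj(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Dense adjacency: adj[i, j] = 1 if there is an edge i -> j."""
    adj = torch.zeros(num_nodes, num_nodes)
    adj[edge_index[0], edge_index[1]] = 1.0
    return adj

def adj_to_edge_index(adj: torch.Tensor) -> torch.Tensor:
    """Edge index of shape [2, E], one column per nonzero entry."""
    return adj.nonzero().t().contiguous()

# Directed edges 0->1 and 1->2, stored as columns (sources on top, targets below)
edge_index = torch.tensor([[0, 1],
                           [1, 2]])
# Make the graph undirected by appending the reversed edges
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

adj = edge_index_to_adj(edge_index, num_nodes=3)
recovered = adj_to_edge_index(adj)
print(adj)
print(recovered)
```

Note that an undirected graph stores every edge twice in the edge index, which is why `num_edges` in the example above counts directed edges. The dense matrix needs O(N^2) memory, while the edge index needs O(E), which is why sparse edge lists are the usual choice for large graphs.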

Message Passing Framework

Message passing is the foundation of all the GNNs discussed here. The core idea is simple: each node updates its representation by collecting information from its neighbors. Think of it as a structured game of telephone: at each round, every person asks their friends “what do you know?”, gathers the answers, and updates their own understanding. After a few rounds, each person has absorbed information from an increasingly wide neighborhood of the network. Formally, with $h_v^{(k)}$ denoting the representation of node $v$ after $k$ rounds and $\mathcal{N}(v)$ its set of neighbors:

$$h_v^{(k+1)} = \text{UPDATE}\left(h_v^{(k)}, \text{AGGREGATE}\left(\{h_u^{(k)} : u \in \mathcal{N}(v)\}\right)\right)$$
class MessagePassingLayer(nn.Module):
    """
    General message passing framework.
    
    Steps:
    1. MESSAGE: Compute messages from neighbors
       (Each neighbor packages up its current knowledge into a "message")
    2. AGGREGATE: Combine messages (sum, mean, max)
       (The node collects all incoming messages and reduces them to a fixed
        size -- this must be permutation-invariant since neighbors have no
        inherent ordering)
    3. UPDATE: Update node representations
       (Combine the aggregated neighbor info with the node's own features)
    """
    
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
    
    def message(
        self,
        x_i: torch.Tensor,  # Target node features
        x_j: torch.Tensor,  # Source node features
        edge_attr: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute messages from source to target."""
        return x_j
    
    def aggregate(
        self,
        messages: torch.Tensor,
        index: torch.Tensor,
        num_nodes: int
    ) -> torch.Tensor:
        """Aggregate messages at each node."""
        # Scatter-add messages to target nodes
        out = torch.zeros(num_nodes, messages.size(-1), device=messages.device)
        out.scatter_add_(0, index.unsqueeze(-1).expand_as(messages), messages)
        return out
    
    def update(
        self,
        aggregated: torch.Tensor,
        x: torch.Tensor
    ) -> torch.Tensor:
        """Update node embeddings."""
        return aggregated
    
    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass.
        
        Args:
            x: Node features [N, F]
            edge_index: Edge indices [2, E]
        
        Returns:
            Updated node features [N, F']
        """
        source, target = edge_index
        
        # Get source and target features
        x_i = x[target]  # Target nodes
        x_j = x[source]  # Source nodes
        
        # Compute messages
        messages = self.message(x_i, x_j, edge_attr)
        
        # Aggregate messages
        aggregated = self.aggregate(messages, target, x.size(0))
        
        # Update
        out = self.update(aggregated, x)
        
        return out
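
As a worked example of the MESSAGE and AGGREGATE steps, the sketch below runs one round by hand on a three-node path graph. It uses `index_add_` instead of the `scatter_add_` call in the class above, purely for brevity; the toy feature values are made up for illustration.

```python
import torch

# Path graph 0 - 1 - 2, stored as directed edges in both directions
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
source, target = edge_index
x = torch.tensor([[1.0], [10.0], [100.0]])  # one scalar feature per node

# MESSAGE: each edge carries the features of its source node
messages = x[source]  # [E, 1]

# AGGREGATE (sum): add up the messages arriving at each target node
summed = torch.zeros_like(x).index_add_(0, target, messages)

# AGGREGATE (mean): divide the sum by the number of incoming edges
in_degree = torch.zeros(x.size(0)).index_add_(0, target, torch.ones(target.size(0)))
mean = summed / in_degree.clamp(min=1).unsqueeze(-1)

print(summed.squeeze(-1))  # per-node sums: 10, 101, 10
print(mean.squeeze(-1))    # per-node means: 10, 50.5, 10
```

Both reductions give the same result no matter how the edges are ordered, which is the permutation invariance the AGGREGATE step requires.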

Graph Convolutional Network (GCN)

GCN is the “ResNet of graph learning” — a simple, strong baseline that most practitioners reach for first. The idea: each node’s new representation is a weighted average of its neighbors’ features (including itself), passed through a linear transform and activation. The weighting is based on node degree, which prevents high-degree nodes from dominating.
class GCNConv(nn.Module):
    """
    Graph Convolutional Network layer (Kipf & Welling, 2017).
    
    Equation:
    H^{(l+1)} = sigma(D_tilde^{-1/2} A_tilde D_tilde^{-1/2} H^{(l)} W^{(l)})
    
    Where:
    - A_tilde = A + I (adjacency with self-loops, so each node includes itself)
    - D_tilde = diagonal degree matrix of A_tilde
    - The D^{-1/2} A D^{-1/2} term is symmetric normalization: it prevents
      nodes with many connections from having disproportionately large features
    
    Practical tip: GCN works best with 2-3 layers. Going deeper causes
    "over-smoothing" where all node representations converge to the same value.
    """
    
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True
    ):
        super().__init__()
        
        self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)
    
    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            x: Node features [N, F_in]
            edge_index: Edge indices [2, E]
        
        Returns:
            Updated features [N, F_out]
        """
        N = x.size(0)
        source, target = edge_index
        
        # Add self-loops
        self_loops = torch.arange(N, device=x.device)
        source = torch.cat([source, self_loops])
        target = torch.cat([target, self_loops])
        
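        # Compute degree. It is counted at the target of each edge so that it
        # matches the aggregation below; for undirected graphs this equals the
        # count at the source.
        deg = torch.zeros(N, device=x.device)
        deg.scatter_add_(0, target, torch.ones_like(target, dtype=torch.float))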
        
        # Symmetric normalization: D^{-1/2} A D^{-1/2}
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        
        norm = deg_inv_sqrt[source] * deg_inv_sqrt[target]
        
        # Transform features
        x = torch.matmul(x, self.weight)
        
        # Message passing
        out = torch.zeros_like(x)
        messages = norm.unsqueeze(-1) * x[source]
        out.scatter_add_(0, target.unsqueeze(-1).expand_as(messages), messages)
        
        if self.bias is not None:
            out = out + self.bias
        
        return out


class GCN(nn.Module):
    """Multi-layer Graph Convolutional Network."""
    
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        num_layers: int = 2,
        dropout: float = 0.5
    ):
        super().__init__()
        
        self.convs = nn.ModuleList()
        
        # First layer
        self.convs.append(GCNConv(in_features, hidden_features))
        
        # Hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_features, hidden_features))
        
        # Output layer
        self.convs.append(GCNConv(hidden_features, out_features))
        
        self.dropout = dropout
    
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        
        x = self.convs[-1](x, edge_index)
        return x
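
The symmetric normalization is easier to see in dense form. The following sketch builds D_tilde^{-1/2} A_tilde D_tilde^{-1/2} for a three-node path graph and applies one propagation step without the weight matrix or activation. The graph and feature values are toy assumptions chosen for illustration.

```python
import torch

# Undirected path graph 0 - 1 - 2
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])

A_tilde = A + torch.eye(3)                  # add self-loops
deg = A_tilde.sum(dim=1)                    # degrees including self-loops: 2, 3, 2
D_inv_sqrt = torch.diag(deg.pow(-0.5))
A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt   # symmetric normalization

# Entry (i, j) is 1 / sqrt(deg_i * deg_j) for every edge and self-loop
print(A_hat)

H = torch.tensor([[1.0], [2.0], [3.0]])     # toy node features
print(A_hat @ H)                            # one propagation step, before W and sigma
```

The sparse `GCNConv.forward` above computes the same quantity edge by edge: `norm` holds the nonzero entries of `A_hat`, and the scatter-add performs the matrix product.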

Graph Attention Network (GAT)

GCN treats all neighbors equally (modulo degree normalization). But in practice, some neighbors are more relevant than others — your best friend’s opinion matters more than an acquaintance’s. GAT addresses this by learning attention weights for each edge, so the model can decide how much to listen to each neighbor.
class GATConv(nn.Module):
    """
    Graph Attention Network layer (Veličković et al., 2018).
    
    Uses attention to weight neighbor contributions:
    alpha_ij = softmax_j(LeakyReLU(a^T [Wh_i || Wh_j]))
    
    Multi-head attention for stability -- same idea as in Transformers.
    Each head learns a different notion of "relevance" and their outputs
    are concatenated (intermediate layers) or averaged (final layer).
    
    Practical tip: GAT is more expensive than GCN but shines when
    neighbor importance genuinely varies (e.g., heterogeneous graphs).
    For homogeneous citation networks, GCN and GAT perform similarly.
    """
    
    def __init__(
        self,
        in_features: int,
        out_features: int,
        heads: int = 8,
        concat: bool = True,
        dropout: float = 0.6,
        negative_slope: float = 0.2
    ):
        super().__init__()
        
        self.in_features = in_features
        self.out_features = out_features
        self.heads = heads
        self.concat = concat
        self.dropout = dropout
        self.negative_slope = negative_slope
        
        # Linear transformation
        self.W = nn.Linear(in_features, heads * out_features, bias=False)
        
        # Attention parameters
        self.att_src = nn.Parameter(torch.Tensor(1, heads, out_features))
        self.att_dst = nn.Parameter(torch.Tensor(1, heads, out_features))
        
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.W.weight)
        nn.init.xavier_uniform_(self.att_src)
        nn.init.xavier_uniform_(self.att_dst)
    
    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            x: Node features [N, F_in]
            edge_index: Edge indices [2, E]
        
        Returns:
            Updated features [N, H*F_out] if concat else [N, F_out]
        """
        N = x.size(0)
        source, target = edge_index
        
        # Linear transformation and reshape for multi-head
        x = self.W(x).view(-1, self.heads, self.out_features)  # [N, H, F]
        
        # Compute attention scores
        # Source attention
        alpha_src = (x * self.att_src).sum(dim=-1)  # [N, H]
        # Target attention
        alpha_dst = (x * self.att_dst).sum(dim=-1)  # [N, H]
        
        # Edge attention
        alpha = alpha_src[source] + alpha_dst[target]  # [E, H]
        alpha = F.leaky_relu(alpha, negative_slope=self.negative_slope)
        
        # Softmax over neighbors
        alpha = self._softmax(alpha, target, N)  # [E, H]
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        
        # Message passing
        out = torch.zeros(N, self.heads, self.out_features, device=x.device)
        messages = alpha.unsqueeze(-1) * x[source]  # [E, H, F]
        
        # Scatter-add
        out.scatter_add_(
            0,
            target.view(-1, 1, 1).expand_as(messages),
            messages
        )
        
        if self.concat:
            out = out.view(N, -1)  # [N, H*F]
        else:
            out = out.mean(dim=1)  # [N, F]
        
        return out
    
    def _softmax(
        self,
        alpha: torch.Tensor,
        target: torch.Tensor,
        num_nodes: int
    ) -> torch.Tensor:
        """Softmax over neighbors."""
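        # Max for numerical stability. include_self=False makes the reduction
        # ignore the zero initial values, so the result is the true per-target
        # max. Nodes without incoming edges keep 0 and are never indexed below.
        alpha_max = torch.zeros(num_nodes, alpha.size(1), device=alpha.device)
        alpha_max.scatter_reduce_(
            0,
            target.unsqueeze(-1).expand_as(alpha),
            alpha,
            reduce='amax',
            include_self=False
        )
        alpha = alpha - alpha_max[target]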
        
        # Exp and sum
        alpha = alpha.exp()
        alpha_sum = torch.zeros(num_nodes, alpha.size(1), device=alpha.device)
        alpha_sum.scatter_add_(
            0,
            target.unsqueeze(-1).expand_as(alpha),
            alpha
        )
        
        return alpha / (alpha_sum[target] + 1e-8)


class GAT(nn.Module):
    """Multi-layer Graph Attention Network."""
    
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        heads: int = 8,
        dropout: float = 0.6
    ):
        super().__init__()
        
        self.conv1 = GATConv(in_features, hidden_features, heads=heads, dropout=dropout)
        self.conv2 = GATConv(
            hidden_features * heads, out_features, 
            heads=1, concat=False, dropout=dropout
        )
        
        self.dropout = dropout
    
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x
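
The step that differs most from an ordinary softmax is the normalization over each node's incoming edges. The sketch below isolates it for a single attention head. The function name `neighborhood_softmax` and the toy scores are assumptions for illustration, not part of the classes above.

```python
import torch

def neighborhood_softmax(scores: torch.Tensor, target: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Softmax over the edges that share a target node (scores has shape [E])."""
    # Subtracting the per-target maximum keeps exp() from overflowing
    max_per_node = torch.full((num_nodes,), float("-inf"))
    max_per_node = max_per_node.scatter_reduce(0, target, scores, reduce="amax")
    exp = (scores - max_per_node[target]).exp()
    # Sum of exponentials per target node, then normalize each edge
    denom = torch.zeros(num_nodes).index_add_(0, target, exp)
    return exp / denom[target]

# Three edges point at node 0 and one edge points at node 1
target = torch.tensor([0, 0, 0, 1])
scores = torch.tensor([1.0, 2.0, 3.0, 5.0])

alpha = neighborhood_softmax(scores, target, num_nodes=2)
per_target_total = torch.zeros(2).index_add_(0, target, alpha)
print(alpha)             # attention weight per edge
print(per_target_total)  # the weights into each node sum to 1
```

A node with a single incoming edge always gets weight 1 on that edge, whatever its raw score, because the softmax is taken per neighborhood rather than over all edges.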

GraphSAGE

In their standard full-batch form, GCN and GAT need the whole graph in memory during training, which breaks down for graphs with millions or billions of nodes (think: the entire Facebook social graph). GraphSAGE solves this with a simple but powerful idea: instead of using all neighbors, sample a fixed number of neighbors at each layer. This bounds the work per training node and makes mini-batch training possible on very large graphs.
class SAGEConv(nn.Module):
    """
    GraphSAGE (Hamilton et al., 2017).
    
    Key innovations:
    1. Sampling fixed-size neighborhoods (makes training scalable to
       billion-node graphs -- you no longer need the full graph in memory)
    2. Multiple aggregation functions (mean, max, LSTM -- each captures
       different aspects of the neighborhood)
    3. Inductive learning (works on unseen nodes and entirely new graphs,
       unlike GCN which is transductive by default)
    
    Practical tip: GraphSAGE is often the best starting point for
    production systems because of its scalability and inductive capability.
    """
    
    def __init__(
        self,
        in_features: int,
        out_features: int,
        aggregator: str = 'mean',  # 'mean', 'max', 'lstm'
        normalize: bool = True
    ):
        super().__init__()
        
        self.aggregator = aggregator
        self.normalize = normalize
        
        # Transform for concatenating self + aggregated neighbors
        self.lin = nn.Linear(2 * in_features, out_features)
        
        if aggregator == 'lstm':
            self.lstm = nn.LSTM(in_features, in_features, batch_first=True)
    
    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor
    ) -> torch.Tensor:
        N = x.size(0)
        source, target = edge_index
        
        # Aggregate neighbor features
        if self.aggregator == 'mean':
            agg = self._mean_aggregator(x, source, target, N)
        elif self.aggregator == 'max':
            agg = self._max_aggregator(x, source, target, N)
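        else:
            # NOTE: the 'lstm' aggregator is not implemented in this simplified
            # layer, so it falls back to mean aggregation.
            agg = self._mean_aggregator(x, source, target, N)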
        
        # Concatenate self features with aggregated neighbors
        out = torch.cat([x, agg], dim=-1)
        out = self.lin(out)
        
        if self.normalize:
            out = F.normalize(out, p=2, dim=-1)
        
        return out
    
    def _mean_aggregator(self, x, source, target, N):
        """Mean aggregation."""
        out = torch.zeros(N, x.size(-1), device=x.device)
        count = torch.zeros(N, 1, device=x.device)
        
        out.scatter_add_(0, target.unsqueeze(-1).expand(-1, x.size(-1)), x[source])
        count.scatter_add_(0, target.unsqueeze(-1), torch.ones_like(target.unsqueeze(-1), dtype=torch.float))
        
        return out / (count + 1e-8)
    
    def _max_aggregator(self, x, source, target, N):
        """Max aggregation."""
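        out = torch.full((N, x.size(-1)), float('-inf'), device=x.device)
        out.scatter_reduce_(0, target.unsqueeze(-1).expand(-1, x.size(-1)), x[source], reduce='amax')
        # Nodes with no incoming edges still hold -inf. Replace those entries
        # out-of-place, which avoids modifying a tensor that autograd may need
        # for the backward pass of the reduction.
        return out.masked_fill(torch.isinf(out), 0.0)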


class NeighborSampler:
    """
    Sample fixed-size neighborhoods for mini-batch training.
    Enables training on large graphs.
    """
    
    def __init__(self, edge_index: torch.Tensor, sizes: List[int]):
        """
        Args:
            edge_index: Full graph edge index
            sizes: Number of neighbors to sample at each layer
                   e.g., [25, 10] means 25 neighbors for layer 1, 10 for layer 2
        """
        self.edge_index = edge_index
        self.sizes = sizes
        
        # Build adjacency list
        self.adj_list = self._build_adj_list()
    
    def _build_adj_list(self) -> Dict[int, List[int]]:
        """Build adjacency list from edge index."""
        adj = {}
        source, target = self.edge_index.tolist()
        
        for s, t in zip(source, target):
            if t not in adj:
                adj[t] = []
            adj[t].append(s)
        
        return adj
    
    def sample(self, batch_nodes: torch.Tensor) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Sample subgraph for batch nodes.
        
        Returns a list with one (sampled_edge_index, nodes) pair per layer.
        Edge indices use the original (global) node IDs; they are not
        relabeled to a compact range.
        """
        sampled_layers = []
        nodes = batch_nodes.tolist()
        
        for size in self.sizes:
            # Sample neighbors for each node
            new_nodes = set()
            edges_src, edges_dst = [], []
            
            for node in nodes:
                neighbors = self.adj_list.get(node, [])
                
                if len(neighbors) > size:
                    sampled = np.random.choice(neighbors, size, replace=False)
                else:
                    sampled = neighbors
                
                for neighbor in sampled:
                    new_nodes.add(neighbor)
                    edges_src.append(neighbor)
                    edges_dst.append(node)
            
            # dtype is explicit so that an empty sample still yields a long tensor
            edge_index = torch.tensor([edges_src, edges_dst], dtype=torch.long)
            sampled_layers.append((edge_index, torch.tensor(list(new_nodes | set(nodes)))))
            
            nodes = list(new_nodes | set(nodes))
        
        return sampled_layers
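
The core of the sampler is a single step: cap each node's neighborhood at a fixed fanout. The sketch below isolates that step with Python's standard `random` module instead of NumPy. The helper `sample_neighbors` and the toy adjacency list are hypothetical names introduced for illustration.

```python
import random
from typing import Dict, List

def sample_neighbors(
    adj_list: Dict[int, List[int]],
    nodes: List[int],
    fanout: int,
    seed: int = 0,
) -> Dict[int, List[int]]:
    """Return at most `fanout` neighbors for each node, sampled without replacement."""
    rng = random.Random(seed)
    sampled = {}
    for node in nodes:
        neighbors = adj_list.get(node, [])
        if len(neighbors) > fanout:
            sampled[node] = rng.sample(neighbors, fanout)
        else:
            # Fewer neighbors than the fanout: keep them all
            sampled[node] = list(neighbors)
    return sampled

# Node 0 is a hub with five neighbors; nodes 1 and 2 have few
adj_list = {0: [1, 2, 3, 4, 5], 1: [0], 2: [0, 3]}
sampled = sample_neighbors(adj_list, nodes=[0, 1, 2], fanout=2)
print(sampled)
```

Because every node contributes at most `fanout` neighbors per layer, the number of nodes touched per training example is bounded by the product of the fanouts, regardless of how large the hubs in the full graph are.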

Graph Isomorphism Network (GIN)

GIN asks a theoretical question: how powerful can message-passing GNNs actually be? The answer: at most as powerful as the 1-dimensional Weisfeiler-Lehman (WL) graph isomorphism test, a classical heuristic for checking whether two graphs are structurally identical. The WL test never separates isomorphic graphs, but it can fail to separate some non-isomorphic ones. GIN is designed to reach this upper bound. If you need your GNN to distinguish between subtly different graph structures (common in molecular property prediction), GIN is a strong choice.
class GINConv(nn.Module):
    """
    Graph Isomorphism Network (Xu et al., 2019).
    
    As expressive as the 1-WL test, the upper bound for message-passing GNNs:
    h_v^{(k+1)} = MLP((1 + epsilon) * h_v^{(k)} + SUM_{u in N(v)} h_u^{(k)})
    
    Why SUM and not MEAN? Mean aggregation loses information about
    neighborhood SIZE. Two nodes with neighborhoods {1, 1} and {1, 1, 1, 1}
    produce the same mean but different sums. GIN uses sum to preserve
    this structural information.
    
    The learnable epsilon controls how much weight the node gives to its
    own features versus its neighbors' -- a subtle but important knob.
    """
    
    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float = 0.0,
        train_eps: bool = True
    ):
        super().__init__()
        
        self.mlp = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(),
            nn.Linear(out_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU()
        )
        
        if train_eps:
            self.eps = nn.Parameter(torch.Tensor([eps]))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))
    
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        N = x.size(0)
        source, target = edge_index
        
        # Aggregate neighbors
        out = torch.zeros(N, x.size(-1), device=x.device)
        out.scatter_add_(0, target.unsqueeze(-1).expand(-1, x.size(-1)), x[source])
        
        # Add self with learnable epsilon
        out = (1 + self.eps) * x + out
        
        # MLP
        out = self.mlp(out)
        
        return out


class GIN(nn.Module):
    """Multi-layer GIN for graph classification."""
    
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        num_layers: int = 5
    ):
        super().__init__()
        
        self.convs = nn.ModuleList()
        
        self.convs.append(GINConv(in_features, hidden_features))
        for _ in range(num_layers - 1):
            self.convs.append(GINConv(hidden_features, hidden_features))
        
        # Readout for graph classification
        self.readout = nn.Sequential(
            nn.Linear(hidden_features * num_layers, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features)
        )
    
    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            batch: Tensor indicating which graph each node belongs to
        """
        # Collect representations from all layers
        h_list = []
        
        for conv in self.convs:
            x = conv(x, edge_index)
            h_list.append(self._readout(x, batch))
        
        # Concatenate all layer representations
        h = torch.cat(h_list, dim=-1)
        
        return self.readout(h)
    
    def _readout(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        """Global sum pooling per graph."""
        num_graphs = batch.max().item() + 1
        out = torch.zeros(num_graphs, x.size(-1), device=x.device)
        out.scatter_add_(0, batch.unsqueeze(-1).expand_as(x), x)
        return out
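
The argument for sum aggregation can be checked in a few lines. This sketch compares the two neighborhoods from the `GINConv` docstring under mean, max, and sum, then applies the GIN pre-MLP update with `eps = 0`. The center-node value `h_v` is a made-up number for illustration.

```python
import torch

small = torch.tensor([1.0, 1.0])            # neighborhood {1, 1}
large = torch.tensor([1.0, 1.0, 1.0, 1.0])  # neighborhood {1, 1, 1, 1}

same_mean = bool(small.mean() == large.mean())  # mean cannot tell them apart
same_max = bool(small.max() == large.max())     # neither can max
same_sum = bool(small.sum() == large.sum())     # sum can
print(same_mean, same_max, same_sum)

# GIN update before the MLP: (1 + eps) * h_v + sum of neighbor features
eps = 0.0
h_v = torch.tensor(2.0)
pre_mlp_small = (1 + eps) * h_v + small.sum()
pre_mlp_large = (1 + eps) * h_v + large.sum()
print(pre_mlp_small.item(), pre_mlp_large.item())
```

The two nodes receive different inputs to the MLP only under sum aggregation, so only then can the network learn to treat them differently.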

Graph Pooling

Just as CNNs use pooling to reduce spatial resolution and build hierarchical features, GNNs need pooling to reduce the number of nodes and create graph-level representations. The challenge is that graphs have irregular structure — you cannot just “stride 2” across a graph. Graph pooling methods learn which nodes to keep and which to merge, creating a coarser version of the original graph.
Over-smoothing is the silent killer of deep GNNs. As message-passing layers are stacked (often noticeable beyond a handful of layers, though the exact depth depends on the graph and the model), node representations converge toward nearly the same vector, the GNN equivalent of blurring an image until everything is grey. If your validation accuracy drops when you add more layers, over-smoothing is a likely culprit. Mitigations include skip connections (as in JKNet), DropEdge during training, or graph pooling to reduce the effective depth.
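
Over-smoothing can be observed in a stripped-down setting. The sketch below repeatedly applies the GCN propagation matrix to random features on a six-node ring and tracks how much the nodes still differ. It deliberately omits learned weights and nonlinearities, so it illustrates only the diffusion effect and is not a model of any trained network.

```python
import torch

torch.manual_seed(0)

# Ring graph with 6 nodes, self-loops added, symmetrically normalized as in GCN
N = 6
idx = torch.arange(N)
A = torch.zeros(N, N)
A[idx, (idx + 1) % N] = 1.0
A[(idx + 1) % N, idx] = 1.0
A_tilde = A + torch.eye(N)
d_inv_sqrt = A_tilde.sum(dim=1).pow(-0.5)
A_hat = d_inv_sqrt.unsqueeze(1) * A_tilde * d_inv_sqrt.unsqueeze(0)

H = torch.randn(N, 4)  # random initial node features
spreads = []
for layer in range(1, 11):
    H = A_hat @ H  # propagation only: no weights, no activation
    # Standard deviation across nodes, averaged over feature dimensions
    spread = H.std(dim=0).mean().item()
    spreads.append(spread)
    print(f"layer {layer:2d}: spread across nodes = {spread:.4f}")
```

The spread shrinks with every step because each propagation averages a node with its neighbors. Trained weights and skip connections can slow this down, but the underlying pull toward a common vector remains.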
class TopKPooling(nn.Module):
    """
    Top-K pooling: Select top-k scoring nodes.
    Used for hierarchical graph representation.
    
    The idea: learn a scoring function, keep the top-scoring nodes,
    and remove the rest. Any edge with at least one removed endpoint is dropped.
    Think of it as learning which parts of the graph are most informative
    and focusing on those.
    """
    
    def __init__(self, in_features: int, ratio: float = 0.5):
        super().__init__()
        
        self.ratio = ratio
        self.score = nn.Linear(in_features, 1)
    
    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
            x: Pooled node features
            edge_index: Updated edge index
            batch: Updated batch assignment
        """
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        
        # Compute scores
        scores = self.score(x).squeeze(-1)
        scores = torch.tanh(scores)
        
        # Select top-k nodes per graph
        num_graphs = batch.max().item() + 1
        keep_mask = torch.zeros(x.size(0), dtype=torch.bool, device=x.device)
        
        for g in range(num_graphs):
            graph_mask = batch == g
            graph_scores = scores[graph_mask]
            
            k = max(1, int(graph_mask.sum() * self.ratio))
            _, top_indices = graph_scores.topk(k)
            
            graph_nodes = graph_mask.nonzero().squeeze(-1)
            keep_nodes = graph_nodes[top_indices]
            keep_mask[keep_nodes] = True
        
        # Map old node indices to new ones. The map is sized by the ORIGINAL
        # node count, so build it before x is overwritten below.
        node_map = torch.full((x.size(0),), -1, dtype=torch.long, device=x.device)
        node_map[keep_mask] = torch.arange(int(keep_mask.sum()), device=x.device)
        
        # Filter nodes and gate the kept features by their scores
        x = x[keep_mask] * scores[keep_mask].unsqueeze(-1)
        
        # Re-index edges, keeping only those whose endpoints both survive
        source, target = edge_index
        edge_mask = keep_mask[source] & keep_mask[target]
        edge_index = torch.stack([
            node_map[source[edge_mask]],
            node_map[target[edge_mask]]
        ])
        
        batch = batch[keep_mask]
        
        return x, edge_index, batch


class SAGPooling(nn.Module):
    """
    Self-Attention Graph Pooling.
    Uses GNN to compute attention scores.
    """
    
    def __init__(self, in_features: int, ratio: float = 0.5):
        super().__init__()
        
        self.ratio = ratio
        self.gnn = GCNConv(in_features, 1)
    
    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        
        # Compute attention scores using GNN
        scores = self.gnn(x, edge_index).squeeze(-1)
        scores = torch.tanh(scores)
        
        # The top-k selection step is omitted here; it would mirror TopKPooling,
        # with `scores` coming from the GNN above. As written, this method
        # returns its inputs unpooled.
        
        return x, edge_index, batch


class GlobalAttentionPooling(nn.Module):
    """Global attention pooling for graph-level representation."""
    
    def __init__(self, in_features: int, hidden_features: int):
        super().__init__()
        
        self.gate = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.Tanh(),
            nn.Linear(hidden_features, 1),
            nn.Sigmoid()
        )
        
        self.transform = nn.Linear(in_features, hidden_features)
    
    def forward(
        self,
        x: torch.Tensor,
        batch: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            x: Node features [N, F]
            batch: Batch assignment [N]
        
        Returns:
            Graph representations [B, F']
        """
        # Compute attention
        gate = self.gate(x)  # [N, 1]
        x = self.transform(x) * gate  # [N, F']
        
        # Sum per graph
        num_graphs = batch.max().item() + 1
        out = torch.zeros(num_graphs, x.size(-1), device=x.device)
        out.scatter_add_(0, batch.unsqueeze(-1).expand_as(x), x)
        
        return out
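
All the graph-level readouts above rely on a `batch` vector that assigns each node to a graph. The sketch below shows that convention on two tiny graphs and computes sum and mean readouts without any learned parameters. The feature values are toy assumptions.

```python
import torch

# Two graphs in one batch: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1
x = torch.tensor([[1.0, 0.0],
                  [2.0, 0.0],
                  [3.0, 0.0],
                  [0.0, 5.0],
                  [0.0, 7.0]])
batch = torch.tensor([0, 0, 0, 1, 1])
num_graphs = int(batch.max()) + 1

# Sum readout: add up the node features of each graph
graph_sum = torch.zeros(num_graphs, x.size(1)).index_add_(0, batch, x)

# Mean readout: divide by the number of nodes in each graph
counts = torch.bincount(batch, minlength=num_graphs).unsqueeze(-1)
graph_mean = graph_sum / counts

print(graph_sum)   # rows: [6, 0] and [0, 12]
print(graph_mean)  # rows: [2, 0] and [0, 6]
```

`GlobalAttentionPooling` follows the same pattern, except that each node's features are transformed and scaled by a learned gate before the per-graph sum.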

Practical Application: Node Classification

The classic benchmark for GNNs is semi-supervised node classification on the Cora citation network: given a graph of academic papers (nodes) with citation links (edges), classify each paper into one of 7 research topics using only the labels of 20 papers per class (140 total out of 2,708). This is a remarkably data-efficient setting — the graph structure carries so much information that you can classify most unlabeled nodes by labeling just a handful.
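def train_node_classification():
    """Example: semi-supervised node classification on synthetic, Cora-sized data."""
    
    # Create example data (simulated Cora-like). Features, edges, and labels are
    # random, so accuracy will hover near chance (about 1/7). The point is the
    # training loop, not the score.
    # Real Cora: 2708 papers, 1433 unique words (bag-of-words), 7 classes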
    num_nodes = 2708
    num_features = 1433
    num_classes = 7
    
    x = torch.randn(num_nodes, num_features)
    
    # Random edges (in practice, use actual citation graph)
    num_edges = 10556
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    
    # Labels and masks
    y = torch.randint(0, num_classes, (num_nodes,))
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[:140] = True  # 140 labeled nodes; real Cora has 20 per class, but these random labels are not balanced
    
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask[140:640] = True
    
    test_mask = ~(train_mask | val_mask)
    
    # Model
    model = GCN(num_features, 64, num_classes, num_layers=2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    
    # Training
    for epoch in range(200):
        model.train()
        optimizer.zero_grad()
        
        out = model(x, edge_index)
        loss = F.cross_entropy(out[train_mask], y[train_mask])
        
        loss.backward()
        optimizer.step()
        
        # Validation (only on reported epochs, to skip an unused forward pass)
        if epoch % 50 == 0:
            model.eval()
            with torch.no_grad():
                pred = model(x, edge_index).argmax(dim=1)
                val_acc = (pred[val_mask] == y[val_mask]).float().mean().item()
            print(f"Epoch {epoch}: Loss = {loss.item():.4f}, Val Acc = {val_acc:.4f}")
    
    # Test
    model.eval()
    with torch.no_grad():
        pred = model(x, edge_index).argmax(dim=1)
        test_acc = (pred[test_mask] == y[test_mask]).float().mean()
    
    print(f"\nTest Accuracy: {test_acc:.4f}")

# train_node_classification()
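
The simulated split above labels the first 140 nodes, which does not guarantee 20 per class. A balanced split in the spirit of the Cora setup could be sketched as follows. The helper `per_class_train_mask` is a hypothetical name introduced for this illustration, and the labels are random stand-ins rather than real Cora labels.

```python
import torch

def per_class_train_mask(y: torch.Tensor, num_per_class: int,
                         generator: torch.Generator) -> torch.Tensor:
    """Boolean mask selecting `num_per_class` random nodes from each class."""
    mask = torch.zeros(y.size(0), dtype=torch.bool)
    for c in torch.unique(y):
        idx = (y == c).nonzero(as_tuple=True)[0]            # node ids of class c
        perm = torch.randperm(idx.numel(), generator=generator)
        mask[idx[perm[:num_per_class]]] = True              # keep a random subset
    return mask

g = torch.Generator().manual_seed(0)
y = torch.randint(0, 7, (2708,), generator=g)               # random stand-in labels
train_mask = per_class_train_mask(y, num_per_class=20, generator=g)

counts = torch.bincount(y[train_mask], minlength=7)
print(counts)                    # labeled nodes per class
print(train_mask.sum().item())   # total labeled nodes
```

Using a dedicated `torch.Generator` keeps the split reproducible without touching the global seed used for model initialization.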

GNN Best Practices

def gnn_best_practices():
    """Guidelines for GNN development."""
    
    guidelines = """
    ╔════════════════════════════════════════════════════════════════╗
    ║                    GNN BEST PRACTICES                          ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                ║
    ║  1. CHOOSE THE RIGHT MODEL                                     ║
    ║     • GCN: Simple baseline, fast                               ║
    ║     • GAT: When neighbor importance varies                     ║
    ║     • GraphSAGE: Inductive, large graphs                       ║
    ║     • GIN: Maximum expressiveness needed                       ║
    ║                                                                ║
    ║  2. DEPTH CONSIDERATIONS                                       ║
    ║     • 2-3 layers often optimal (each layer = 1-hop neighbors)  ║
    ║     • Deeper = over-smoothing (all nodes converge to the same  ║
    ║       representation -- like running a heat diffusion too long ║
    ║       until everything reaches the same temperature)           ║
    ║     • Use skip connections or JK-Net for deeper networks       ║
    ║                                                                ║
    ║  3. NORMALIZATION                                              ║
    ║     • Symmetric normalization (GCN) is common                  ║
    ║     • Mean normalization for GraphSAGE                         ║
    ║     • Layer/batch norm helps training                          ║
    ║                                                                ║
    ║  4. REGULARIZATION                                             ║
    ║     • Dropout on features and attention                        ║
    ║     • DropEdge: randomly remove edges during training          ║
    ║     • L2 regularization                                        ║
    ║                                                                ║
    ║  5. SCALABILITY                                                ║
    ║     • Mini-batch with neighbor sampling (GraphSAGE)            ║
    ║     • Cluster-GCN for very large graphs                        ║
    ║     • Sparse operations for memory efficiency                  ║
    ║                                                                ║
    ║  6. EVALUATION                                                 ║
    ║     • Fixed train/val/test splits for comparison               ║
    ║     • Multiple random seeds                                    ║
    ║     • Report mean ± std                                        ║
    ║                                                                ║
    ╚════════════════════════════════════════════════════════════════╝
    """
    print(guidelines)

gnn_best_practices()
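
The over-smoothing point in guideline 2 can be observed numerically. The sketch below is a simplified illustration under stated assumptions: it applies only the normalized propagation step, with no learned weights or nonlinearities, on a small ring graph. The helper name `ring_propagation_matrix` is hypothetical. Because a ring is regular, the symmetric normalization used by GCN coincides here with mean aggregation.

```python
import numpy as np

def ring_propagation_matrix(n: int) -> np.ndarray:
    """Normalized adjacency with self-loops for an n-node ring graph."""
    a = np.zeros((n, n))
    for i in range(n):
        a[i, (i + 1) % n] = 1.0
        a[(i + 1) % n, i] = 1.0
    a_hat = a + np.eye(n)                          # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    # Symmetric normalization D^{-1/2} (A + I) D^{-1/2}
    return d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]

rng = np.random.default_rng(0)
p = ring_propagation_matrix(5)
h = rng.normal(size=(5, 4))                        # 5 nodes, 4 features each

spread = []                                        # largest per-feature std across nodes
for depth in range(17):
    spread.append(h.std(axis=0).max())
    h = p @ h                                      # one weight-free propagation step

for depth in (0, 2, 4, 8, 16):
    print(f"depth {depth:2d}: spread across nodes = {spread[depth]:.2e}")
```

The spread shrinks with every step, so after enough propagation the nodes become nearly indistinguishable. Learned weights and nonlinearities change the details, but this shrinking tendency is why skip connections or JK-Net are suggested for deeper models.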

Exercises

Extend GCN to use edge features in the convolution:
class EdgeGCN(nn.Module):
    # Include edge_attr in message computation
    pass
Build a model for graph-level classification:
  • Use GIN layers
  • Implement graph-level readout
  • Test on molecular property prediction

What’s Next?

3D Deep Learning

Point clouds, voxels, and meshes

Object Detection

YOLO, Faster R-CNN, DETR