Graph Neural Networks
Why Graphs?
Most neural networks assume data lives on a grid: images are 2D grids of pixels, text is a 1D sequence of tokens, audio is a 1D sequence of samples. But much of the real world does not fit neatly on a grid. A social network is not a grid. A molecule is not a sequence. A road network is not a rectangle. This is where graphs come in. A graph is simply a collection of nodes (things) connected by edges (relationships). This deceptively simple structure can represent an enormous range of real-world data:
- Social networks: Users (nodes) and friendships (edges)
- Molecules: Atoms (nodes) and chemical bonds (edges)
- Knowledge graphs: Entities (nodes) and relations (edges)
- Citation networks: Papers (nodes) and references (edges)
- Traffic: Intersections (nodes) and roads (edges)
If you are new to GNNs, start with PyTorch Geometric (PyG). It provides optimized implementations of all the architectures discussed below, handles batching of variable-size graphs, and includes standard benchmark datasets. Building GNNs from scratch (as we do here for pedagogical purposes) is valuable for understanding, but use PyG for production work.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional, List, Dict
torch.manual_seed(42)
Graph Fundamentals
Graph Representation
class GraphData:
    """
    Graph data structure.

    A graph G = (V, E) with:
    - V: Set of nodes (vertices)
    - E: Set of edges

    Representations:
    - Adjacency matrix A: A[i,j] = 1 if edge (i,j) exists
    - Edge index: [2, E] tensor of (source, target) pairs
    """

    def __init__(
        self,
        x: torch.Tensor,                           # Node features [N, F]
        edge_index: torch.Tensor,                  # Edge indices [2, E]
        edge_attr: Optional[torch.Tensor] = None,  # Edge features [E, D]
        y: Optional[torch.Tensor] = None           # Labels
    ):
        self.x = x
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.y = y
        self.num_nodes = x.size(0)
        self.num_edges = edge_index.size(1)
        self.num_features = x.size(1)

    def to_adjacency(self) -> torch.Tensor:
        """Convert edge index to adjacency matrix."""
        adj = torch.zeros(self.num_nodes, self.num_nodes)
        adj[self.edge_index[0], self.edge_index[1]] = 1
        return adj

    @staticmethod
    def from_adjacency(adj: torch.Tensor, x: torch.Tensor) -> 'GraphData':
        """Create from adjacency matrix."""
        edge_index = adj.nonzero().t().contiguous()
        return GraphData(x, edge_index)

    def add_self_loops(self) -> 'GraphData':
        """Add self-loop edges."""
        self_loops = torch.arange(self.num_nodes).unsqueeze(0).repeat(2, 1)
        edge_index = torch.cat([self.edge_index, self_loops], dim=1)
        return GraphData(self.x, edge_index, self.edge_attr, self.y)
# Example: Create a simple graph
def create_example_graph():
    # 5 nodes with 3 features each
    x = torch.randn(5, 3)
    # Edges: 0->1, 0->2, 1->2, 2->3, 3->4
    edge_index = torch.tensor([
        [0, 0, 1, 2, 3],
        [1, 2, 2, 3, 4]
    ])
    # Make undirected (add reverse edges)
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
    return GraphData(x, edge_index)
graph = create_example_graph()
print(f"Nodes: {graph.num_nodes}, Edges: {graph.num_edges}")
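To make the two representations concrete, here is a minimal standalone sketch that converts an edge index to a dense adjacency matrix and back. It uses only plain PyTorch calls; the triangle graph and the variable names are illustrative assumptions, not part of the classes above.

```python
import torch

# Undirected triangle 0-1-2, stored as directed pairs in both directions.
edge_index = torch.tensor([
    [0, 1, 1, 2, 0, 2],   # source nodes
    [1, 0, 2, 1, 2, 0],   # target nodes
])
num_nodes = 3

# Edge index -> dense adjacency matrix.
adj = torch.zeros(num_nodes, num_nodes)
adj[edge_index[0], edge_index[1]] = 1

# Dense adjacency -> edge index (nonzero entries in row-major order).
recovered = adj.nonzero().t().contiguous()

# For an unweighted graph, node degree is the row sum of the adjacency matrix.
degree = adj.sum(dim=1)

print(adj)
print(recovered.shape)   # torch.Size([2, 6])
print(degree.tolist())   # [2.0, 2.0, 2.0]
```

The edge index costs memory proportional to the number of edges, while the dense matrix costs N x N regardless of sparsity, which is why the layers in this chapter operate on the edge index.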
Message Passing Framework
Message passing is the foundation of nearly all GNNs. The core idea is simple: each node updates its representation by collecting information from its neighbors. Think of it like a game of telephone, but structured: at each round, every person asks their friends “what do you know?”, gathers the answers, and updates their own understanding. After a few rounds, each person has absorbed information from an increasingly wide neighborhood of the network:
$$h_v^{(k+1)} = \mathrm{UPDATE}\left(h_v^{(k)},\ \mathrm{AGGREGATE}\left(\left\{h_u^{(k)} : u \in \mathcal{N}(v)\right\}\right)\right)$$
class MessagePassingLayer(nn.Module):
    """
    General message passing framework.

    Steps:
    1. MESSAGE: Compute messages from neighbors
       (Each neighbor packages up its current knowledge into a "message")
    2. AGGREGATE: Combine messages (sum, mean, max)
       (The node collects all incoming messages and reduces them to a fixed
       size -- this must be permutation-invariant since neighbors have no
       inherent ordering)
    3. UPDATE: Update node representations
       (Combine the aggregated neighbor info with the node's own features)
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

    def message(
        self,
        x_i: torch.Tensor,  # Target node features
        x_j: torch.Tensor,  # Source node features
        edge_attr: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute messages from source to target."""
        return x_j

    def aggregate(
        self,
        messages: torch.Tensor,
        index: torch.Tensor,
        num_nodes: int
    ) -> torch.Tensor:
        """Aggregate messages at each node."""
        # Scatter-add messages to target nodes
        out = torch.zeros(num_nodes, messages.size(-1), device=messages.device)
        out.scatter_add_(0, index.unsqueeze(-1).expand_as(messages), messages)
        return out

    def update(
        self,
        aggregated: torch.Tensor,
        x: torch.Tensor
    ) -> torch.Tensor:
        """Update node embeddings."""
        return aggregated

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Node features [N, F]
            edge_index: Edge indices [2, E]
        Returns:
            Updated node features [N, F']
        """
        source, target = edge_index
        # Get source and target features
        x_i = x[target]  # Target nodes
        x_j = x[source]  # Source nodes
        # Compute messages
        messages = self.message(x_i, x_j, edge_attr)
        # Aggregate messages
        aggregated = self.aggregate(messages, target, x.size(0))
        # Update
        out = self.update(aggregated, x)
        return out
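As a worked example of the three steps, the sketch below runs one round of message passing by hand on a three-node path graph. The scalar features and the "add own feature" update rule are assumptions chosen so the arithmetic is easy to check; they are not the update rule of any particular architecture.

```python
import torch

# Path graph 0 - 1 - 2, stored with both directions of every edge.
edge_index = torch.tensor([
    [0, 1, 1, 2],   # source
    [1, 0, 2, 1],   # target
])
source, target = edge_index

# One scalar feature per node so the sums are easy to follow.
h = torch.tensor([[1.0], [10.0], [100.0]])

# MESSAGE: every edge carries the source node's current feature.
messages = h[source]

# AGGREGATE: sum the incoming messages at each target node.
aggregated = torch.zeros_like(h).index_add_(0, target, messages)

# UPDATE (illustrative choice): add the node's own feature to the aggregate.
h_next = h + aggregated

print(aggregated.squeeze(-1).tolist())  # [10.0, 101.0, 10.0]
print(h_next.squeeze(-1).tolist())      # [11.0, 111.0, 110.0]
```

Node 1 sits between the other two, so it receives 1 + 100 = 101, while each end node only hears from node 1.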
Graph Convolutional Network (GCN)
GCN is the “ResNet of graph learning”: a simple, strong baseline that most practitioners reach for first. The idea: each node’s new representation is a degree-normalized weighted sum of its neighbors’ features (including its own), passed through a linear transform and an activation. The weighting is based on node degree, which prevents high-degree nodes from dominating.
class GCNConv(nn.Module):
    """
    Graph Convolutional Network layer (Kipf & Welling, 2017).

    Equation:
        H^{(l+1)} = sigma(D_tilde^{-1/2} A_tilde D_tilde^{-1/2} H^{(l)} W^{(l)})

    Where:
    - A_tilde = A + I (adjacency with self-loops, so each node includes itself)
    - D_tilde = diagonal degree matrix of A_tilde
    - The D^{-1/2} A D^{-1/2} term is symmetric normalization: it prevents
      nodes with many connections from having disproportionately large features

    This layer returns the pre-activation output; the nonlinearity sigma is
    applied by the caller.

    Practical tip: GCN works best with 2-3 layers. Going deeper causes
    "over-smoothing" where all node representations converge to the same value.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            x: Node features [N, F_in]
            edge_index: Edge indices [2, E]
        Returns:
            Updated features [N, F_out]
        """
        N = x.size(0)
        source, target = edge_index
        # Add self-loops
        self_loops = torch.arange(N, device=x.device)
        source = torch.cat([source, self_loops])
        target = torch.cat([target, self_loops])
        # Compute degree (counted over source nodes; for an undirected graph
        # stored in both directions this equals the in-degree)
        deg = torch.zeros(N, device=x.device)
        deg.scatter_add_(0, source, torch.ones_like(source, dtype=torch.float))
        # Symmetric normalization: D^{-1/2} A D^{-1/2}
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[source] * deg_inv_sqrt[target]
        # Transform features
        x = torch.matmul(x, self.weight)
        # Message passing
        out = torch.zeros_like(x)
        messages = norm.unsqueeze(-1) * x[source]
        out.scatter_add_(0, target.unsqueeze(-1).expand_as(messages), messages)
        if self.bias is not None:
            out = out + self.bias
        return out
class GCN(nn.Module):
    """Multi-layer Graph Convolutional Network."""

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        num_layers: int = 2,
        dropout: float = 0.5
    ):
        super().__init__()
        self.convs = nn.ModuleList()
        # First layer
        self.convs.append(GCNConv(in_features, hidden_features))
        # Hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_features, hidden_features))
        # Output layer
        self.convs.append(GCNConv(hidden_features, out_features))
        self.dropout = dropout

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index)
        return x
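To see what the symmetric normalization does numerically, the following sketch builds the normalized matrix D_tilde^{-1/2} A_tilde D_tilde^{-1/2} densely for a three-node path graph. The dense formulation is an assumption made for readability; the layer above computes the same coefficients edge by edge.

```python
import torch

# Path graph 0 - 1 - 2 as a dense adjacency matrix.
A = torch.tensor([
    [0.0, 1.0, 0.0],
    [1.0, 0.0, 1.0],
    [0.0, 1.0, 0.0],
])

A_tilde = A + torch.eye(3)               # add self-loops
deg = A_tilde.sum(dim=1)                 # degrees of A_tilde: [2, 3, 2]
D_inv_sqrt = torch.diag(deg.pow(-0.5))
A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt

print(A_hat)
```

Each entry A_hat[i, j] equals 1 / sqrt(deg_i * deg_j) wherever an edge or self-loop exists. The edge between node 0 (degree 2) and node 1 (degree 3) gets weight 1 / sqrt(6), roughly 0.408, and the self-loop on node 1 gets 1/3, so the better-connected node contributes less per edge.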
Graph Attention Network (GAT)
GCN treats all neighbors equally (modulo degree normalization). But in practice, some neighbors are more relevant than others: your best friend’s opinion matters more than an acquaintance’s. GAT addresses this by learning attention weights for each edge, so the model can decide how much to listen to each neighbor.
class GATConv(nn.Module):
    """
    Graph Attention Network layer (Velickovic et al., 2018).

    Uses attention to weight neighbor contributions:
        alpha_ij = softmax_j(LeakyReLU(a^T [Wh_i || Wh_j]))

    Multi-head attention for stability -- same idea as in Transformers.
    Each head learns a different notion of "relevance" and their outputs
    are concatenated (intermediate layers) or averaged (final layer).

    Note: this layer attends only over the edges in edge_index. Add
    self-loops beforehand if each node should also attend to itself.

    Practical tip: GAT is more expensive than GCN but shines when
    neighbor importance genuinely varies (e.g., heterogeneous graphs).
    For homogeneous citation networks, GCN and GAT often perform similarly.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        heads: int = 8,
        concat: bool = True,
        dropout: float = 0.6,
        negative_slope: float = 0.2
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.heads = heads
        self.concat = concat
        self.dropout = dropout
        self.negative_slope = negative_slope
        # Linear transformation
        self.W = nn.Linear(in_features, heads * out_features, bias=False)
        # Attention parameters
        self.att_src = nn.Parameter(torch.empty(1, heads, out_features))
        self.att_dst = nn.Parameter(torch.empty(1, heads, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.W.weight)
        nn.init.xavier_uniform_(self.att_src)
        nn.init.xavier_uniform_(self.att_dst)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            x: Node features [N, F_in]
            edge_index: Edge indices [2, E]
        Returns:
            Updated features [N, H*F_out] if concat else [N, F_out]
        """
        N = x.size(0)
        source, target = edge_index
        # Linear transformation and reshape for multi-head
        x = self.W(x).view(-1, self.heads, self.out_features)  # [N, H, F]
        # Compute attention scores. The dot product a^T [Wh_i || Wh_j]
        # splits into a target term plus a source term, computed per node.
        # Source attention
        alpha_src = (x * self.att_src).sum(dim=-1)  # [N, H]
        # Target attention
        alpha_dst = (x * self.att_dst).sum(dim=-1)  # [N, H]
        # Edge attention
        alpha = alpha_src[source] + alpha_dst[target]  # [E, H]
        alpha = F.leaky_relu(alpha, negative_slope=self.negative_slope)
        # Softmax over neighbors
        alpha = self._softmax(alpha, target, N)  # [E, H]
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        # Message passing
        out = torch.zeros(N, self.heads, self.out_features, device=x.device)
        messages = alpha.unsqueeze(-1) * x[source]  # [E, H, F]
        # Scatter-add
        out.scatter_add_(
            0,
            target.view(-1, 1, 1).expand_as(messages),
            messages
        )
        if self.concat:
            out = out.view(N, -1)  # [N, H*F]
        else:
            out = out.mean(dim=1)  # [N, F]
        return out

    def _softmax(
        self,
        alpha: torch.Tensor,
        target: torch.Tensor,
        num_nodes: int
    ) -> torch.Tensor:
        """Softmax over neighbors."""
        # Max for numerical stability. Start from -inf (not 0) so the result
        # is the true per-node maximum even when every score is negative.
        alpha_max = torch.full(
            (num_nodes, alpha.size(1)), float('-inf'), device=alpha.device
        )
        alpha_max.scatter_reduce_(
            0,
            target.unsqueeze(-1).expand_as(alpha),
            alpha,
            reduce='amax'
        )
        alpha = alpha - alpha_max[target]
        # Exp and sum
        alpha = alpha.exp()
        alpha_sum = torch.zeros(num_nodes, alpha.size(1), device=alpha.device)
        alpha_sum.scatter_add_(
            0,
            target.unsqueeze(-1).expand_as(alpha),
            alpha
        )
        return alpha / (alpha_sum[target] + 1e-8)
class GAT(nn.Module):
    """Multi-layer Graph Attention Network."""

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        heads: int = 8,
        dropout: float = 0.6
    ):
        super().__init__()
        self.conv1 = GATConv(in_features, hidden_features, heads=heads, dropout=dropout)
        self.conv2 = GATConv(
            hidden_features * heads, out_features,
            heads=1, concat=False, dropout=dropout
        )
        self.dropout = dropout

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x
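The step that makes graph attention different from ordinary softmax is that normalization happens separately for each target node, over a variable number of incoming edges. Here is a minimal single-head sketch of that per-target softmax. The edge scores are made-up numbers, and `Tensor.scatter_reduce_` is assumed to be available (PyTorch 1.12 or later).

```python
import torch

# Three edges point at node 0 and one edge points at node 1.
target = torch.tensor([0, 0, 0, 1])
scores = torch.tensor([1.0, 2.0, 3.0, 5.0])   # raw attention logits, one per edge
num_nodes = 2

# Subtract the per-target maximum for numerical stability.
max_per_target = torch.full((num_nodes,), float('-inf'))
max_per_target.scatter_reduce_(0, target, scores, reduce='amax')
exp_scores = (scores - max_per_target[target]).exp()

# Normalize by the per-target sum so the weights into each node sum to 1.
sum_per_target = torch.zeros(num_nodes).index_add_(0, target, exp_scores)
alpha = exp_scores / sum_per_target[target]

print(alpha)
```

The three edges into node 0 receive the same weights as `torch.softmax` applied to `[1, 2, 3]`, and the single edge into node 1 receives weight 1 no matter how large its raw score is.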
GraphSAGE
In their standard full-batch form, GCN and GAT need the whole graph in memory during training, which breaks down for graphs with millions or billions of nodes (think: the entire Facebook social graph). GraphSAGE solves this with a simple but powerful idea: instead of using all neighbors, sample a fixed number of neighbors at each layer. This makes mini-batch training possible on very large graphs.
class SAGEConv(nn.Module):
    """
    GraphSAGE (Hamilton et al., 2017).

    Key innovations:
    1. Sampling fixed-size neighborhoods (makes training scalable to
       billion-node graphs -- you no longer need the full graph in memory)
    2. Multiple aggregation functions (mean, max, LSTM -- each captures
       different aspects of the neighborhood)
    3. Inductive learning (works on unseen nodes and entirely new graphs,
       unlike GCN which is transductive by default)

    Practical tip: GraphSAGE is often the best starting point for
    production systems because of its scalability and inductive capability.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        aggregator: str = 'mean',  # 'mean', 'max', 'lstm'
        normalize: bool = True
    ):
        super().__init__()
        self.aggregator = aggregator
        self.normalize = normalize
        # Transform for concatenating self + aggregated neighbors
        self.lin = nn.Linear(2 * in_features, out_features)
        if aggregator == 'lstm':
            self.lstm = nn.LSTM(in_features, in_features, batch_first=True)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor
    ) -> torch.Tensor:
        N = x.size(0)
        source, target = edge_index
        # Aggregate neighbor features
        if self.aggregator == 'mean':
            agg = self._mean_aggregator(x, source, target, N)
        elif self.aggregator == 'max':
            agg = self._max_aggregator(x, source, target, N)
        else:
            # NOTE: the LSTM aggregator is not implemented in this sketch;
            # 'lstm' (and any other value) falls back to mean aggregation.
            agg = self._mean_aggregator(x, source, target, N)
        # Concatenate self features with aggregated neighbors
        out = torch.cat([x, agg], dim=-1)
        out = self.lin(out)
        if self.normalize:
            out = F.normalize(out, p=2, dim=-1)
        return out

    def _mean_aggregator(self, x, source, target, N):
        """Mean aggregation."""
        out = torch.zeros(N, x.size(-1), device=x.device)
        count = torch.zeros(N, 1, device=x.device)
        out.scatter_add_(0, target.unsqueeze(-1).expand(-1, x.size(-1)), x[source])
        count.scatter_add_(
            0,
            target.unsqueeze(-1),
            torch.ones_like(target.unsqueeze(-1), dtype=torch.float)
        )
        return out / (count + 1e-8)

    def _max_aggregator(self, x, source, target, N):
        """Max aggregation."""
        out = torch.full((N, x.size(-1)), float('-inf'), device=x.device)
        out.scatter_reduce_(
            0, target.unsqueeze(-1).expand(-1, x.size(-1)), x[source], reduce='amax'
        )
        # Nodes with no incoming edges keep -inf; replace with 0
        out[out == float('-inf')] = 0
        return out
class NeighborSampler:
    """
    Sample fixed-size neighborhoods for mini-batch training.
    Enables training on large graphs.
    """

    def __init__(self, edge_index: torch.Tensor, sizes: List[int]):
        """
        Args:
            edge_index: Full graph edge index
            sizes: Number of neighbors to sample at each layer
                e.g., [25, 10] means 25 neighbors for layer 1, 10 for layer 2
        """
        self.edge_index = edge_index
        self.sizes = sizes
        # Build adjacency list
        self.adj_list = self._build_adj_list()

    def _build_adj_list(self) -> Dict[int, List[int]]:
        """Build adjacency list (target -> list of sources) from edge index."""
        adj = {}
        source, target = self.edge_index.tolist()
        for s, t in zip(source, target):
            if t not in adj:
                adj[t] = []
            adj[t].append(s)
        return adj

    def sample(self, batch_nodes: torch.Tensor) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Sample subgraph for batch nodes.

        Returns list of (sampled_edge_index, nodes) for each layer. Edge
        indices use the original (global) node ids; they are not re-indexed.
        """
        sampled_layers = []
        nodes = batch_nodes.tolist()
        for size in self.sizes:
            # Sample neighbors for each node
            new_nodes = set()
            edges_src, edges_dst = [], []
            for node in nodes:
                neighbors = self.adj_list.get(node, [])
                if len(neighbors) > size:
                    # .tolist() converts NumPy integers back to plain Python ints
                    sampled = np.random.choice(neighbors, size, replace=False).tolist()
                else:
                    sampled = neighbors
                for neighbor in sampled:
                    new_nodes.add(neighbor)
                    edges_src.append(neighbor)
                    edges_dst.append(node)
            # Explicit dtype keeps an empty edge list as a [2, 0] long tensor
            edge_index = torch.tensor([edges_src, edges_dst], dtype=torch.long)
            nodes = sorted(new_nodes | set(nodes))
            sampled_layers.append((edge_index, torch.tensor(nodes, dtype=torch.long)))
        return sampled_layers
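The reason fixed-size sampling scales is that it caps how many nodes a single seed node can pull in, independent of the size of the full graph. The short sketch below computes that upper bound from the per-layer fan-outs. The helper name `max_receptive_field` is hypothetical and introduced only for this illustration.

```python
from typing import List


def max_receptive_field(sizes: List[int]) -> int:
    """Upper bound on nodes touched per seed node, given per-hop fan-outs."""
    total, frontier = 1, 1          # start with the seed node itself
    for size in sizes:
        frontier *= size            # each frontier node adds at most `size` neighbors
        total += frontier
    return total


print(max_receptive_field([25, 10]))      # 276  (1 + 25 + 250)
print(max_receptive_field([10, 10, 10]))  # 1111
```

With `sizes=[25, 10]`, a batch of 512 seed nodes touches at most 512 x 276 nodes, whether the full graph has ten thousand nodes or a billion. Real batches are usually smaller than the bound because sampled neighborhoods overlap and low-degree nodes have fewer neighbors than the fan-out.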
Graph Isomorphism Network (GIN)
GIN asks the theoretical question: how powerful can message-passing GNNs actually be? The answer turns out to be: at most as powerful as the Weisfeiler-Lehman (WL) graph isomorphism test, a classical algorithm for checking whether two graphs could be structurally identical. GIN is designed to match this upper bound. If you need your GNN to distinguish between subtly different graph structures (common in molecular property prediction), GIN is your tool.
class GINConv(nn.Module):
    """
    Graph Isomorphism Network (Xu et al., 2019).

    Most expressive among 1-WL equivalent GNNs:
        h_v^{(k+1)} = MLP((1 + epsilon) * h_v^{(k)} + SUM_{u in N(v)} h_u^{(k)})

    Why SUM and not MEAN? Mean aggregation loses information about
    neighborhood SIZE. Two nodes with neighborhoods {1, 1} and {1, 1, 1, 1}
    produce the same mean but different sums. GIN uses sum to preserve
    this structural information.

    The learnable epsilon controls how much weight the node gives to its
    own features versus its neighbors' -- a subtle but important knob.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float = 0.0,
        train_eps: bool = True
    ):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(),
            nn.Linear(out_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU()
        )
        if train_eps:
            self.eps = nn.Parameter(torch.tensor([eps]))
        else:
            self.register_buffer('eps', torch.tensor([eps]))

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        N = x.size(0)
        source, target = edge_index
        # Aggregate neighbors
        out = torch.zeros(N, x.size(-1), device=x.device)
        out.scatter_add_(0, target.unsqueeze(-1).expand(-1, x.size(-1)), x[source])
        # Add self with learnable epsilon
        out = (1 + self.eps) * x + out
        # MLP
        out = self.mlp(out)
        return out
class GIN(nn.Module):
    """Multi-layer GIN for graph classification."""

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        num_layers: int = 5
    ):
        super().__init__()
        self.convs = nn.ModuleList()
        self.convs.append(GINConv(in_features, hidden_features))
        for _ in range(num_layers - 1):
            self.convs.append(GINConv(hidden_features, hidden_features))
        # Readout for graph classification
        self.readout = nn.Sequential(
            nn.Linear(hidden_features * num_layers, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features)
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            batch: Tensor indicating which graph each node belongs to
        """
        # Collect representations from all layers
        h_list = []
        for conv in self.convs:
            x = conv(x, edge_index)
            h_list.append(self._readout(x, batch))
        # Concatenate all layer representations
        h = torch.cat(h_list, dim=-1)
        return self.readout(h)

    def _readout(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        """Global sum pooling per graph."""
        num_graphs = int(batch.max().item()) + 1
        out = torch.zeros(num_graphs, x.size(-1), device=x.device)
        out.scatter_add_(0, batch.unsqueeze(-1).expand_as(x), x)
        return out
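The sum-versus-mean argument in the GINConv docstring can be checked directly. This small sketch applies three common aggregators to the two neighborhoods from that example, {1, 1} and {1, 1, 1, 1}; the dictionary `results` is just a convenience introduced here.

```python
import torch

# Two neighborhoods whose members all carry the same feature value 1.0.
small = torch.tensor([1.0, 1.0])
large = torch.tensor([1.0, 1.0, 1.0, 1.0])

results = {}
for name, fn in [("sum", torch.sum), ("mean", torch.mean), ("max", torch.max)]:
    results[name] = (fn(small).item(), fn(large).item())
    print(f"{name}: small={results[name][0]}, large={results[name][1]}")
```

Only the sum separates the two cases (2.0 versus 4.0). Mean and max both return 1.0 for either neighborhood, so a layer built on them cannot tell a node with two such neighbors from a node with four.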
Graph Pooling
Just as CNNs use pooling to reduce spatial resolution and build hierarchical features, GNNs need pooling to reduce the number of nodes and create graph-level representations. The challenge is that graphs have irregular structure: you cannot just “stride 2” across a graph. Graph pooling methods learn which nodes to keep and which to merge, creating a coarser version of the original graph.
Over-smoothing is the silent killer of deep GNNs. As message-passing layers are stacked (often after only about 5-6), all node representations converge to nearly the same vector, the GNN equivalent of blurring an image until everything is grey. If your validation accuracy drops when you add more layers, over-smoothing is a likely culprit. Mitigations include skip connections (as in JK-Net), DropEdge during training, or graph pooling to reduce the effective depth.
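Over-smoothing can be reproduced in a few lines. The sketch below is a deliberately simplified assumption: it applies pure mean aggregation with no learned weights or nonlinearities, and far more rounds than a real network would have layers, so it shows the limiting behavior rather than what any specific trained model does.

```python
import torch

torch.manual_seed(0)

# Small connected graph: triangle 0-1-2 with a tail 2-3-4.
A = torch.zeros(5, 5)
for i, j in [(0, 1), (0, 2), (1, 2), (2, 3), (3, 4)]:
    A[i, j] = A[j, i] = 1.0
A_tilde = A + torch.eye(5)              # add self-loops

# Row-normalized propagation: each node averages itself and its neighbors.
P = A_tilde / A_tilde.sum(dim=1, keepdim=True)

h = torch.randn(5, 4)
spread_before = h.std(dim=0).mean().item()   # how much nodes differ, per feature

for _ in range(100):                          # 100 rounds of pure averaging
    h = P @ h

spread_after = h.std(dim=0).mean().item()
print(f"spread before: {spread_before:.4f}, after: {spread_after:.2e}")
```

After enough rounds every row of `h` is nearly identical, so nothing is left for a classifier to separate. Learned weights and nonlinearities slow this collapse but do not remove the underlying averaging.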
class TopKPooling(nn.Module):
    """
    Top-K pooling: Select top-k scoring nodes.
    Used for hierarchical graph representation.

    The idea: learn a scoring function, keep the top-scoring nodes,
    and remove the rest. Edges that touch a removed node are dropped.
    Think of it as learning which parts of the graph are most informative
    and focusing on those.
    """

    def __init__(self, in_features: int, ratio: float = 0.5):
        super().__init__()
        self.ratio = ratio
        self.score = nn.Linear(in_features, 1)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
            x: Pooled node features
            edge_index: Updated edge index
            batch: Updated batch assignment
        """
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        num_nodes = x.size(0)  # original node count, needed for re-indexing
        # Compute scores
        scores = self.score(x).squeeze(-1)
        scores = torch.tanh(scores)
        # Select top-k nodes per graph
        num_graphs = int(batch.max().item()) + 1
        keep_mask = torch.zeros(num_nodes, dtype=torch.bool, device=x.device)
        for g in range(num_graphs):
            # view(-1) keeps this 1-D even when a graph has a single node
            graph_nodes = (batch == g).nonzero().view(-1)
            k = max(1, int(graph_nodes.numel() * self.ratio))
            _, top_indices = scores[graph_nodes].topk(k)
            keep_mask[graph_nodes[top_indices]] = True
        # Re-index edges. node_map must cover the ORIGINAL nodes, so it is
        # built before x is filtered.
        node_map = torch.full((num_nodes,), -1, dtype=torch.long, device=x.device)
        node_map[keep_mask] = torch.arange(int(keep_mask.sum()), device=x.device)
        source, target = edge_index
        edge_mask = keep_mask[source] & keep_mask[target]
        edge_index = torch.stack([
            node_map[source[edge_mask]],
            node_map[target[edge_mask]]
        ])
        # Filter nodes, gating the kept features by their score
        x = x[keep_mask] * scores[keep_mask].unsqueeze(-1)
        batch = batch[keep_mask]
        return x, edge_index, batch
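class SAGPooling(nn.Module):
    """
    Self-Attention Graph Pooling.
    Uses GNN to compute attention scores.
    """

    def __init__(self, in_features: int, ratio: float = 0.5):
        super().__init__()
        self.ratio = ratio
        self.gnn = GCNConv(in_features, 1)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        num_nodes = x.size(0)
        # Compute attention scores using GNN
        scores = self.gnn(x, edge_index).squeeze(-1)
        scores = torch.tanh(scores)
        # Top-k selection per graph (same procedure as TopKPooling)
        keep_mask = torch.zeros(num_nodes, dtype=torch.bool, device=x.device)
        for g in range(int(batch.max().item()) + 1):
            graph_nodes = (batch == g).nonzero().view(-1)
            k = max(1, int(graph_nodes.numel() * self.ratio))
            _, top_indices = scores[graph_nodes].topk(k)
            keep_mask[graph_nodes[top_indices]] = True
        # Re-index the surviving edges against the kept nodes
        node_map = torch.full((num_nodes,), -1, dtype=torch.long, device=x.device)
        node_map[keep_mask] = torch.arange(int(keep_mask.sum()), device=x.device)
        source, target = edge_index
        edge_mask = keep_mask[source] & keep_mask[target]
        edge_index = torch.stack([
            node_map[source[edge_mask]],
            node_map[target[edge_mask]]
        ])
        # Gate the kept features by their attention score
        x = x[keep_mask] * scores[keep_mask].unsqueeze(-1)
        return x, edge_index, batch[keep_mask]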
class GlobalAttentionPooling(nn.Module):
    """Global attention pooling for graph-level representation."""

    def __init__(self, in_features: int, hidden_features: int):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.Tanh(),
            nn.Linear(hidden_features, 1),
            nn.Sigmoid()
        )
        self.transform = nn.Linear(in_features, hidden_features)

    def forward(
        self,
        x: torch.Tensor,
        batch: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            x: Node features [N, F]
            batch: Batch assignment [N]
        Returns:
            Graph representations [B, F']
        """
        # Compute attention
        gate = self.gate(x)  # [N, 1]
        x = self.transform(x) * gate  # [N, F']
        # Sum per graph
        num_graphs = int(batch.max().item()) + 1
        out = torch.zeros(num_graphs, x.size(-1), device=x.device)
        out.scatter_add_(0, batch.unsqueeze(-1).expand_as(x), x)
        return out
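All the graph-level readouts in this section rely on the same trick: a `batch` vector that maps every node to its graph, so several graphs of different sizes can be pooled in one call. Here is a minimal sketch of sum and mean readout over such a batch. The toy features are assumptions chosen so the result can be verified by hand.

```python
import torch

# Five nodes in one mini-batch: three belong to graph 0, two to graph 1.
x = torch.tensor([[1.0], [2.0], [3.0], [10.0], [20.0]])
batch = torch.tensor([0, 0, 0, 1, 1])
num_graphs = int(batch.max().item()) + 1

# Sum readout: add up node features per graph.
graph_sum = torch.zeros(num_graphs, x.size(-1)).index_add_(0, batch, x)

# Mean readout: divide by the number of nodes in each graph.
counts = torch.bincount(batch, minlength=num_graphs).unsqueeze(-1)
graph_mean = graph_sum / counts

print(graph_sum.squeeze(-1).tolist())   # [6.0, 30.0]
print(graph_mean.squeeze(-1).tolist())  # [2.0, 15.0]
```

Attention pooling follows the same pattern, except that each node's features are scaled by a learned gate before the per-graph sum.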
Practical Application: Node Classification
The classic benchmark for GNNs is semi-supervised node classification on the Cora citation network: given a graph of academic papers (nodes) with citation links (edges), classify each paper into one of 7 research topics using only the labels of 20 papers per class (140 total out of 2,708). This is a remarkably data-efficient setting: the graph structure carries so much information that most unlabeled nodes can be classified from just a handful of labels. The example below uses random stand-in data with Cora's dimensions, so it exercises the training loop but its accuracy stays near chance (about 1/7).
def train_node_classification():
    """Example: Semi-supervised node classification on Cora."""
    # Create example data (simulated Cora-like)
    # Real Cora: 2708 papers, 1433 unique words (bag-of-words), 7 classes
    num_nodes = 2708
    num_features = 1433
    num_classes = 7
    x = torch.randn(num_nodes, num_features)
    # Random edges (in practice, use actual citation graph)
    num_edges = 10556
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    # Labels and masks
    y = torch.randint(0, num_classes, (num_nodes,))
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[:140] = True  # 140 labeled nodes (20 per class in the real split)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask[140:640] = True
    test_mask = ~(train_mask | val_mask)
    # Model
    model = GCN(num_features, 64, num_classes, num_layers=2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    # Training
    for epoch in range(200):
        model.train()
        optimizer.zero_grad()
        out = model(x, edge_index)
        loss = F.cross_entropy(out[train_mask], y[train_mask])
        loss.backward()
        optimizer.step()
        # Validation
        model.eval()
        with torch.no_grad():
            pred = model(x, edge_index).argmax(dim=1)
            val_acc = (pred[val_mask] == y[val_mask]).float().mean()
        if epoch % 50 == 0:
            print(f"Epoch {epoch}: Loss = {loss:.4f}, Val Acc = {val_acc:.4f}")
    # Test
    model.eval()
    with torch.no_grad():
        pred = model(x, edge_index).argmax(dim=1)
        test_acc = (pred[test_mask] == y[test_mask]).float().mean()
    print(f"\nTest Accuracy: {test_acc:.4f}")
# train_node_classification()
GNN Best Practices
def gnn_best_practices():
    """Guidelines for GNN development."""
    guidelines = """
    ╔════════════════════════════════════════════════════════════════╗
    ║                       GNN BEST PRACTICES                       ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                ║
    ║   1. CHOOSE THE RIGHT MODEL                                    ║
    ║      • GCN: Simple baseline, fast                              ║
    ║      • GAT: When neighbor importance varies                    ║
    ║      • GraphSAGE: Inductive, large graphs                      ║
    ║      • GIN: Maximum expressiveness needed                      ║
    ║                                                                ║
    ║   2. DEPTH CONSIDERATIONS                                      ║
    ║      • 2-3 layers often optimal (each layer = 1-hop neighbors) ║
    ║      • Deeper = over-smoothing (all nodes converge to the same ║
    ║        representation -- like running a heat diffusion too long║
    ║        until everything reaches the same temperature)          ║
    ║      • Use skip connections or JK-Net for deeper networks      ║
    ║                                                                ║
    ║   3. NORMALIZATION                                             ║
    ║      • Symmetric normalization (GCN) is common                 ║
    ║      • Mean normalization for GraphSAGE                        ║
    ║      • Layer/batch norm helps training                         ║
    ║                                                                ║
    ║   4. REGULARIZATION                                            ║
    ║      • Dropout on features and attention                       ║
    ║      • DropEdge: randomly remove edges during training         ║
    ║      • L2 regularization                                       ║
    ║                                                                ║
    ║   5. SCALABILITY                                               ║
    ║      • Mini-batch with neighbor sampling (GraphSAGE)           ║
    ║      • Cluster-GCN for very large graphs                       ║
    ║      • Sparse operations for memory efficiency                 ║
    ║                                                                ║
    ║   6. EVALUATION                                                ║
    ║      • Fixed train/val/test splits for comparison              ║
    ║      • Multiple random seeds                                   ║
    ║      • Report mean ± std                                       ║
    ║                                                                ║
    ╚════════════════════════════════════════════════════════════════╝
    """
    print(guidelines)
gnn_best_practices()
Exercises
Exercise 1: Implement Edge Features
Extend GCN to use edge features in the convolution:
class EdgeGCN(nn.Module):
    # Include edge_attr in message computation
    pass
Exercise 2: Graph Classification
Build a model for graph-level classification:
- Use GIN layers
- Implement graph-level readout
- Test on molecular property prediction
Exercise 3: Link Prediction
Implement link prediction:
# Given: node embeddings
# Predict: probability of edge existing
# Loss: binary cross-entropy on positive/negative edges
What’s Next?
3D Deep Learning
Point clouds, voxels, and meshes
Object Detection
YOLO, Faster R-CNN, DETR