Graph Neural Networks

Why Graphs?

Much real-world data is naturally graph-structured:
  • Social networks: Users and friendships
  • Molecules: Atoms and bonds
  • Knowledge graphs: Entities and relations
  • Citation networks: Papers and references
  • Traffic: Roads and intersections
Graphs capture relational structure that traditional neural networks miss.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional, List, Dict

torch.manual_seed(42)

Graph Fundamentals

Graph Representation

class GraphData:
    """
    Graph data structure.
    
    A graph G = (V, E) with:
    - V: Set of nodes (vertices)
    - E: Set of edges
    
    Representations:
    - Adjacency matrix A: A[i,j] = 1 if edge (i,j) exists
    - Edge index: [2, E] tensor of (source, target) pairs
    """
    
    def __init__(
        self,
        x: torch.Tensor,           # Node features [N, F]
        edge_index: torch.Tensor,   # Edge indices [2, E]
        edge_attr: Optional[torch.Tensor] = None,  # Edge features [E, D]
        y: Optional[torch.Tensor] = None           # Labels
    ):
        self.x = x
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.y = y
        
        self.num_nodes = x.size(0)
        self.num_edges = edge_index.size(1)
        self.num_features = x.size(1)
    
    def to_adjacency(self) -> torch.Tensor:
        """Convert edge index to adjacency matrix."""
        adj = torch.zeros(self.num_nodes, self.num_nodes)
        adj[self.edge_index[0], self.edge_index[1]] = 1
        return adj
    
    @staticmethod
    def from_adjacency(adj: torch.Tensor, x: torch.Tensor) -> 'GraphData':
        """Create from adjacency matrix."""
        edge_index = adj.nonzero().t().contiguous()
        return GraphData(x, edge_index)
    
    def add_self_loops(self) -> 'GraphData':
        """Add self-loop edges."""
        self_loops = torch.arange(self.num_nodes).unsqueeze(0).repeat(2, 1)
        edge_index = torch.cat([self.edge_index, self_loops], dim=1)
        return GraphData(self.x, edge_index, self.edge_attr, self.y)


# Example: Create a simple graph
def create_example_graph():
    # 5 nodes with 3 features each
    x = torch.randn(5, 3)
    
    # Edges: 0->1, 0->2, 1->2, 2->3, 3->4
    edge_index = torch.tensor([
        [0, 0, 1, 2, 3],
        [1, 2, 2, 3, 4]
    ])
    
    # Make undirected (add reverse edges)
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
    
    return GraphData(x, edge_index)

graph = create_example_graph()
print(f"Nodes: {graph.num_nodes}, Edges: {graph.num_edges}")

Message Passing Framework

The foundation of all GNNs:

$$h_v^{(k+1)} = \text{UPDATE}\left(h_v^{(k)}, \text{AGGREGATE}\left(\{h_u^{(k)} : u \in \mathcal{N}(v)\}\right)\right)$$
class MessagePassingLayer(nn.Module):
    """
    General message passing framework.
    
    Steps:
    1. MESSAGE: Compute messages from neighbors
    2. AGGREGATE: Combine messages (sum, mean, max)
    3. UPDATE: Update node representations
    """
    
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
    
    def message(
        self,
        x_i: torch.Tensor,  # Target node features
        x_j: torch.Tensor,  # Source node features
        edge_attr: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute messages from source to target."""
        return x_j
    
    def aggregate(
        self,
        messages: torch.Tensor,
        index: torch.Tensor,
        num_nodes: int
    ) -> torch.Tensor:
        """Aggregate messages at each node."""
        # Scatter-add messages to target nodes
        out = torch.zeros(num_nodes, messages.size(-1), device=messages.device)
        out.scatter_add_(0, index.unsqueeze(-1).expand_as(messages), messages)
        return out
    
    def update(
        self,
        aggregated: torch.Tensor,
        x: torch.Tensor
    ) -> torch.Tensor:
        """Update node embeddings."""
        return aggregated
    
    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass.
        
        Args:
            x: Node features [N, F]
            edge_index: Edge indices [2, E]
        
        Returns:
            Updated node features [N, F']
        """
        source, target = edge_index
        
        # Get source and target features
        x_i = x[target]  # Target nodes
        x_j = x[source]  # Source nodes
        
        # Compute messages
        messages = self.message(x_i, x_j, edge_attr)
        
        # Aggregate messages
        aggregated = self.aggregate(messages, target, x.size(0))
        
        # Update
        out = self.update(aggregated, x)
        
        return out
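
To use the framework, subclass it and override the three hooks. Below is an illustrative MeanConv (a made-up example, not from any library) that sends linearly transformed messages and averages them by in-degree:

class MeanConv(MessagePassingLayer):
    """Illustrative subclass: linear messages, degree-normalized (mean) aggregation."""
    
    def __init__(self, in_features: int, out_features: int):
        super().__init__(in_features, out_features)
        self.lin = nn.Linear(in_features, out_features)
    
    def message(self, x_i, x_j, edge_attr=None):
        # MESSAGE: transform source-node features
        return self.lin(x_j)
    
    def aggregate(self, messages, index, num_nodes):
        # AGGREGATE: sum, then divide by in-degree to get a mean
        summed = super().aggregate(messages, index, num_nodes)
        deg = torch.zeros(num_nodes, device=messages.device)
        deg.scatter_add_(0, index, torch.ones_like(index, dtype=torch.float))
        return summed / deg.clamp(min=1).unsqueeze(-1)

layer = MeanConv(3, 8)
print(layer(graph.x, graph.edge_index).shape)  # torch.Size([5, 8])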

Graph Convolutional Network (GCN)

class GCNConv(nn.Module):
    """
    Graph Convolutional Network layer (Kipf & Welling, 2017).
    
    Equation:
    H^{(l+1)} = σ(D̃^{-1/2} Ã D̃^{-1/2} H^{(l)} W^{(l)})
    
    Where:
    - Ã = A + I (adjacency with self-loops)
    - D̃ = diagonal degree matrix of Ã
    """
    
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True
    ):
        super().__init__()
        
        self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)
    
    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            x: Node features [N, F_in]
            edge_index: Edge indices [2, E]
        
        Returns:
            Updated features [N, F_out]
        """
        N = x.size(0)
        source, target = edge_index
        
        # Add self-loops
        self_loops = torch.arange(N, device=x.device)
        source = torch.cat([source, self_loops])
        target = torch.cat([target, self_loops])
        
        # Compute degree (assumes an undirected edge list, so in- and out-degree agree)
        deg = torch.zeros(N, device=x.device)
        deg.scatter_add_(0, source, torch.ones_like(source, dtype=torch.float))
        
        # Symmetric normalization: D^{-1/2} A D^{-1/2}
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        
        norm = deg_inv_sqrt[source] * deg_inv_sqrt[target]
        
        # Transform features
        x = torch.matmul(x, self.weight)
        
        # Message passing
        out = torch.zeros_like(x)
        messages = norm.unsqueeze(-1) * x[source]
        out.scatter_add_(0, target.unsqueeze(-1).expand_as(messages), messages)
        
        if self.bias is not None:
            out = out + self.bias
        
        return out


class GCN(nn.Module):
    """Multi-layer Graph Convolutional Network."""
    
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        num_layers: int = 2,
        dropout: float = 0.5
    ):
        super().__init__()
        
        self.convs = nn.ModuleList()
        
        # First layer
        self.convs.append(GCNConv(in_features, hidden_features))
        
        # Hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_features, hidden_features))
        
        # Output layer
        self.convs.append(GCNConv(hidden_features, out_features))
        
        self.dropout = dropout
    
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        
        x = self.convs[-1](x, edge_index)
        return x
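
Running the GCN on the toy graph from earlier (dimensions chosen for illustration):

model = GCN(in_features=3, hidden_features=16, out_features=2)
model.eval()  # disable dropout for a deterministic forward pass
with torch.no_grad():
    logits = model(graph.x, graph.edge_index)
print(logits.shape)  # torch.Size([5, 2])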

Graph Attention Network (GAT)

class GATConv(nn.Module):
    """
    Graph Attention Network layer (Veličković et al., 2018).
    
    Uses attention to weight neighbor contributions:
    α_ij = softmax_j(LeakyReLU(a^T [Wh_i || Wh_j]))
    
    Multi-head attention for stability.
    """
    
    def __init__(
        self,
        in_features: int,
        out_features: int,
        heads: int = 8,
        concat: bool = True,
        dropout: float = 0.6,
        negative_slope: float = 0.2
    ):
        super().__init__()
        
        self.in_features = in_features
        self.out_features = out_features
        self.heads = heads
        self.concat = concat
        self.dropout = dropout
        self.negative_slope = negative_slope
        
        # Linear transformation
        self.W = nn.Linear(in_features, heads * out_features, bias=False)
        
        # Attention parameters
        self.att_src = nn.Parameter(torch.Tensor(1, heads, out_features))
        self.att_dst = nn.Parameter(torch.Tensor(1, heads, out_features))
        
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.W.weight)
        nn.init.xavier_uniform_(self.att_src)
        nn.init.xavier_uniform_(self.att_dst)
    
    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            x: Node features [N, F_in]
            edge_index: Edge indices [2, E]
        
        Returns:
            Updated features [N, H*F_out] if concat else [N, F_out]
        """
        N = x.size(0)
        source, target = edge_index
        
        # Linear transformation and reshape for multi-head
        x = self.W(x).view(-1, self.heads, self.out_features)  # [N, H, F]
        
        # Compute attention scores
        # Source attention
        alpha_src = (x * self.att_src).sum(dim=-1)  # [N, H]
        # Target attention
        alpha_dst = (x * self.att_dst).sum(dim=-1)  # [N, H]
        
        # Edge attention
        alpha = alpha_src[source] + alpha_dst[target]  # [E, H]
        alpha = F.leaky_relu(alpha, negative_slope=self.negative_slope)
        
        # Softmax over neighbors
        alpha = self._softmax(alpha, target, N)  # [E, H]
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        
        # Message passing
        out = torch.zeros(N, self.heads, self.out_features, device=x.device)
        messages = alpha.unsqueeze(-1) * x[source]  # [E, H, F]
        
        # Scatter-add
        out.scatter_add_(
            0,
            target.view(-1, 1, 1).expand_as(messages),
            messages
        )
        
        if self.concat:
            out = out.view(N, -1)  # [N, H*F]
        else:
            out = out.mean(dim=1)  # [N, F]
        
        return out
    
    def _softmax(
        self,
        alpha: torch.Tensor,
        target: torch.Tensor,
        num_nodes: int
    ) -> torch.Tensor:
        """Softmax over neighbors."""
        # Max for numerical stability (zero-initialized, so the subtracted max
        # is at least 0 and the exp arguments below can never overflow)
        alpha_max = torch.zeros(num_nodes, alpha.size(1), device=alpha.device)
        alpha_max.scatter_reduce_(
            0,
            target.unsqueeze(-1).expand_as(alpha),
            alpha,
            reduce='amax'
        )
        alpha = alpha - alpha_max[target]
        
        # Exp and sum
        alpha = alpha.exp()
        alpha_sum = torch.zeros(num_nodes, alpha.size(1), device=alpha.device)
        alpha_sum.scatter_add_(
            0,
            target.unsqueeze(-1).expand_as(alpha),
            alpha
        )
        
        return alpha / (alpha_sum[target] + 1e-8)


class GAT(nn.Module):
    """Multi-layer Graph Attention Network."""
    
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        heads: int = 8,
        dropout: float = 0.6
    ):
        super().__init__()
        
        self.conv1 = GATConv(in_features, hidden_features, heads=heads, dropout=dropout)
        self.conv2 = GATConv(
            hidden_features * heads, out_features, 
            heads=1, concat=False, dropout=dropout
        )
        
        self.dropout = dropout
    
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x
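
The same toy graph through the GAT (hidden size and head count are arbitrary here):

gat = GAT(in_features=3, hidden_features=8, out_features=2, heads=4)
gat.eval()
with torch.no_grad():
    logits = gat(graph.x, graph.edge_index)
print(logits.shape)  # torch.Size([5, 2])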

GraphSAGE

class SAGEConv(nn.Module):
    """
    GraphSAGE (Hamilton et al., 2017).
    
    Key innovations:
    1. Sampling fixed-size neighborhoods (scalable)
    2. Multiple aggregation functions
    3. Inductive learning (works on new graphs)
    """
    
    def __init__(
        self,
        in_features: int,
        out_features: int,
        aggregator: str = 'mean',  # 'mean', 'max', 'lstm'
        normalize: bool = True
    ):
        super().__init__()
        
        self.aggregator = aggregator
        self.normalize = normalize
        
        # Transform for concatenating self + aggregated neighbors
        self.lin = nn.Linear(2 * in_features, out_features)
        
        if aggregator == 'lstm':
            self.lstm = nn.LSTM(in_features, in_features, batch_first=True)
    
    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor
    ) -> torch.Tensor:
        N = x.size(0)
        source, target = edge_index
        
        # Aggregate neighbor features
        if self.aggregator == 'mean':
            agg = self._mean_aggregator(x, source, target, N)
        elif self.aggregator == 'max':
            agg = self._max_aggregator(x, source, target, N)
        else:
            agg = self._mean_aggregator(x, source, target, N)
        
        # Concatenate self features with aggregated neighbors
        out = torch.cat([x, agg], dim=-1)
        out = self.lin(out)
        
        if self.normalize:
            out = F.normalize(out, p=2, dim=-1)
        
        return out
    
    def _mean_aggregator(self, x, source, target, N):
        """Mean aggregation."""
        out = torch.zeros(N, x.size(-1), device=x.device)
        count = torch.zeros(N, 1, device=x.device)
        
        out.scatter_add_(0, target.unsqueeze(-1).expand(-1, x.size(-1)), x[source])
        count.scatter_add_(0, target.unsqueeze(-1), torch.ones_like(target.unsqueeze(-1), dtype=torch.float))
        
        return out / (count + 1e-8)
    
    def _max_aggregator(self, x, source, target, N):
        """Max aggregation."""
        out = torch.full((N, x.size(-1)), float('-inf'), device=x.device)
        out.scatter_reduce_(0, target.unsqueeze(-1).expand(-1, x.size(-1)), x[source], reduce='amax')
        out[out == float('-inf')] = 0
        return out
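
A quick check of SAGEConv with the max aggregator; note the unit-norm outputs from the L2 normalization:

sage = SAGEConv(in_features=3, out_features=8, aggregator='max')
out = sage(graph.x, graph.edge_index)
print(out.shape)         # torch.Size([5, 8])
print(out.norm(dim=-1))  # ~1.0 per node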


class NeighborSampler:
    """
    Sample fixed-size neighborhoods for mini-batch training.
    Enables training on large graphs.
    """
    
    def __init__(self, edge_index: torch.Tensor, sizes: List[int]):
        """
        Args:
            edge_index: Full graph edge index
            sizes: Number of neighbors to sample at each layer
                   e.g., [25, 10] means 25 neighbors for layer 1, 10 for layer 2
        """
        self.edge_index = edge_index
        self.sizes = sizes
        
        # Build adjacency list
        self.adj_list = self._build_adj_list()
    
    def _build_adj_list(self) -> Dict[int, List[int]]:
        """Build adjacency list from edge index."""
        adj = {}
        source, target = self.edge_index.tolist()
        
        for s, t in zip(source, target):
            if t not in adj:
                adj[t] = []
            adj[t].append(s)
        
        return adj
    
    def sample(self, batch_nodes: torch.Tensor) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Sample subgraph for batch nodes.
        
        Returns list of (sampled_edge_index, batch_nodes) for each layer.
        """
        sampled_layers = []
        nodes = batch_nodes.tolist()
        
        for size in self.sizes:
            # Sample neighbors for each node
            new_nodes = set()
            edges_src, edges_dst = [], []
            
            for node in nodes:
                neighbors = self.adj_list.get(node, [])
                
                if len(neighbors) > size:
                    sampled = np.random.choice(neighbors, size, replace=False)
                else:
                    sampled = neighbors
                
                for neighbor in sampled:
                    new_nodes.add(neighbor)
                    edges_src.append(neighbor)
                    edges_dst.append(node)
            
            edge_index = torch.tensor([edges_src, edges_dst])
            sampled_layers.append((edge_index, torch.tensor(list(new_nodes | set(nodes)))))
            
            nodes = list(new_nodes | set(nodes))
        
        return sampled_layers
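
Sampling a two-hop computation graph around a small seed batch (the fan-out sizes here are illustrative):

sampler = NeighborSampler(graph.edge_index, sizes=[2, 2])
for i, (sampled_edges, nodes) in enumerate(sampler.sample(torch.tensor([0, 4]))):
    print(f"Hop {i}: {sampled_edges.size(1)} edges, {nodes.numel()} nodes")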

Graph Isomorphism Network (GIN)

class GINConv(nn.Module):
    """
    Graph Isomorphism Network (Xu et al., 2019).
    
    Matches the expressive power of the 1-WL isomorphism test, the upper
    bound for message-passing GNNs:
    h_v^{(k+1)} = MLP((1 + ε) · h_v^{(k)} + Σ_{u∈N(v)} h_u^{(k)})
    """
    
    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float = 0.0,
        train_eps: bool = True
    ):
        super().__init__()
        
        self.mlp = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(),
            nn.Linear(out_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU()
        )
        
        if train_eps:
            self.eps = nn.Parameter(torch.Tensor([eps]))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))
    
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        N = x.size(0)
        source, target = edge_index
        
        # Aggregate neighbors
        out = torch.zeros(N, x.size(-1), device=x.device)
        out.scatter_add_(0, target.unsqueeze(-1).expand(-1, x.size(-1)), x[source])
        
        # Add self with learnable epsilon
        out = (1 + self.eps) * x + out
        
        # MLP
        out = self.mlp(out)
        
        return out


class GIN(nn.Module):
    """Multi-layer GIN for graph classification."""
    
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        num_layers: int = 5
    ):
        super().__init__()
        
        self.convs = nn.ModuleList()
        
        self.convs.append(GINConv(in_features, hidden_features))
        for _ in range(num_layers - 1):
            self.convs.append(GINConv(hidden_features, hidden_features))
        
        # Readout for graph classification
        self.readout = nn.Sequential(
            nn.Linear(hidden_features * num_layers, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features)
        )
    
    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            batch: Tensor indicating which graph each node belongs to
        """
        # Collect representations from all layers
        h_list = []
        
        for conv in self.convs:
            x = conv(x, edge_index)
            h_list.append(self._readout(x, batch))
        
        # Concatenate all layer representations
        h = torch.cat(h_list, dim=-1)
        
        return self.readout(h)
    
    def _readout(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        """Global sum pooling per graph."""
        num_graphs = batch.max().item() + 1
        out = torch.zeros(num_graphs, x.size(-1), device=x.device)
        out.scatter_add_(0, batch.unsqueeze(-1).expand_as(x), x)
        return out
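
Graph classification needs the batch vector mapping each node to its graph. A sketch with two small graphs batched together (node counts and edges are made up for illustration):

# Nodes 0-4 belong to graph 0, nodes 5-7 to graph 1
x_batch = torch.randn(8, 3)
edge_index_batch = torch.tensor([
    [0, 1, 2, 3, 5, 6],
    [1, 2, 3, 4, 6, 7]
])
batch = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1])

gin = GIN(in_features=3, hidden_features=16, out_features=2, num_layers=3)
gin.eval()  # keep BatchNorm on running stats for this tiny example
with torch.no_grad():
    logits = gin(x_batch, edge_index_batch, batch)
print(logits.shape)  # torch.Size([2, 2]): one prediction per graph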

Graph Pooling

class TopKPooling(nn.Module):
    """
    Top-K pooling: Select top-k scoring nodes.
    Used for hierarchical graph representation.
    """
    
    def __init__(self, in_features: int, ratio: float = 0.5):
        super().__init__()
        
        self.ratio = ratio
        self.score = nn.Linear(in_features, 1)
    
    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
            x: Pooled node features
            edge_index: Updated edge index
            batch: Updated batch assignment
        """
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        
        # Compute scores
        scores = self.score(x).squeeze(-1)
        scores = torch.tanh(scores)
        
        # Select top-k nodes per graph
        num_graphs = batch.max().item() + 1
        keep_mask = torch.zeros(x.size(0), dtype=torch.bool, device=x.device)
        
        for g in range(num_graphs):
            graph_mask = batch == g
            graph_scores = scores[graph_mask]
            
            k = max(1, int(graph_mask.sum() * self.ratio))
            _, top_indices = graph_scores.topk(k)
            
            graph_nodes = graph_mask.nonzero().view(-1)
            keep_nodes = graph_nodes[top_indices]
            keep_mask[keep_nodes] = True
        
        # Filter nodes, scaling kept features by their scores (keeps pooling differentiable)
        new_x = x[keep_mask] * scores[keep_mask].unsqueeze(-1)
        
        # Re-index edges (node_map must be sized by the original node count)
        node_map = torch.full((x.size(0),), -1, dtype=torch.long, device=x.device)
        node_map[keep_mask] = torch.arange(int(keep_mask.sum()), device=x.device)
        
        source, target = edge_index
        edge_mask = keep_mask[source] & keep_mask[target]
        edge_index = torch.stack([
            node_map[source[edge_mask]],
            node_map[target[edge_mask]]
        ])
        
        batch = batch[keep_mask]
        
        return new_x, edge_index, batch
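
Pooling the toy graph with ratio 0.5 keeps the two highest-scoring nodes and re-indexes the surviving edges:

pool = TopKPooling(in_features=3, ratio=0.5)
px, pei, pbatch = pool(graph.x, graph.edge_index)
print(f"Nodes: {graph.num_nodes} -> {px.size(0)}")            # 5 -> 2
print(f"Edges: {graph.edge_index.size(1)} -> {pei.size(1)}")  # only edges between kept nodes survive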


class SAGPooling(nn.Module):
    """
    Self-Attention Graph Pooling.
    Uses GNN to compute attention scores.
    """
    
    def __init__(self, in_features: int, ratio: float = 0.5):
        super().__init__()
        
        self.ratio = ratio
        self.gnn = GCNConv(in_features, 1)
    
    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        
        # Compute attention scores with a GNN, so scores see graph structure
        # (unlike TopKPooling's purely feature-based linear scorer)
        scores = self.gnn(x, edge_index).squeeze(-1)
        scores = torch.tanh(scores)
        
        # Top-k selection per graph, same procedure as TopKPooling
        num_graphs = batch.max().item() + 1
        keep_mask = torch.zeros(x.size(0), dtype=torch.bool, device=x.device)
        
        for g in range(num_graphs):
            graph_mask = batch == g
            k = max(1, int(graph_mask.sum() * self.ratio))
            _, top_indices = scores[graph_mask].topk(k)
            keep_mask[graph_mask.nonzero().view(-1)[top_indices]] = True
        
        new_x = x[keep_mask] * scores[keep_mask].unsqueeze(-1)
        
        # Re-index surviving edges to the compacted node IDs
        node_map = torch.full((x.size(0),), -1, dtype=torch.long, device=x.device)
        node_map[keep_mask] = torch.arange(int(keep_mask.sum()), device=x.device)
        
        source, target = edge_index
        edge_mask = keep_mask[source] & keep_mask[target]
        edge_index = torch.stack([
            node_map[source[edge_mask]],
            node_map[target[edge_mask]]
        ])
        
        return new_x, edge_index, batch[keep_mask]


class GlobalAttentionPooling(nn.Module):
    """Global attention pooling for graph-level representation."""
    
    def __init__(self, in_features: int, hidden_features: int):
        super().__init__()
        
        self.gate = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.Tanh(),
            nn.Linear(hidden_features, 1),
            nn.Sigmoid()
        )
        
        self.transform = nn.Linear(in_features, hidden_features)
    
    def forward(
        self,
        x: torch.Tensor,
        batch: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            x: Node features [N, F]
            batch: Batch assignment [N]
        
        Returns:
            Graph representations [B, F']
        """
        # Compute attention
        gate = self.gate(x)  # [N, 1]
        x = self.transform(x) * gate  # [N, F']
        
        # Sum per graph
        num_graphs = batch.max().item() + 1
        out = torch.zeros(num_graphs, x.size(-1), device=x.device)
        out.scatter_add_(0, batch.unsqueeze(-1).expand_as(x), x)
        
        return out
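
For a single graph the batch vector is all zeros; the pooled output has one row per graph:

gap = GlobalAttentionPooling(in_features=3, hidden_features=8)
batch = torch.zeros(graph.num_nodes, dtype=torch.long)
print(gap(graph.x, batch).shape)  # torch.Size([1, 8])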

Practical Application: Node Classification

def train_node_classification():
    """Example: Semi-supervised node classification on Cora."""
    
    # Create example data (simulated Cora-like)
    num_nodes = 2708
    num_features = 1433
    num_classes = 7
    
    x = torch.randn(num_nodes, num_features)
    
    # Random edges (in practice, use actual graph)
    num_edges = 10556
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    
    # Labels and masks
    y = torch.randint(0, num_classes, (num_nodes,))
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[:140] = True  # 140 labeled nodes (Cora's standard split: 20 per class)
    
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask[140:640] = True
    
    test_mask = ~(train_mask | val_mask)
    
    # Model
    model = GCN(num_features, 64, num_classes, num_layers=2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    
    # Training
    for epoch in range(200):
        model.train()
        optimizer.zero_grad()
        
        out = model(x, edge_index)
        loss = F.cross_entropy(out[train_mask], y[train_mask])
        
        loss.backward()
        optimizer.step()
        
        # Validation
        model.eval()
        with torch.no_grad():
            pred = model(x, edge_index).argmax(dim=1)
            val_acc = (pred[val_mask] == y[val_mask]).float().mean()
        
        if epoch % 50 == 0:
            print(f"Epoch {epoch}: Loss = {loss:.4f}, Val Acc = {val_acc:.4f}")
    
    # Test
    model.eval()
    with torch.no_grad():
        pred = model(x, edge_index).argmax(dim=1)
        test_acc = (pred[test_mask] == y[test_mask]).float().mean()
    
    print(f"\nTest Accuracy: {test_acc:.4f}")

# train_node_classification()

GNN Best Practices

def gnn_best_practices():
    """Guidelines for GNN development."""
    
    guidelines = """
    ╔════════════════════════════════════════════════════════════════╗
    ║                    GNN BEST PRACTICES                          ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                ║
    ║  1. CHOOSE THE RIGHT MODEL                                     ║
    ║     • GCN: Simple baseline, fast                               ║
    ║     • GAT: When neighbor importance varies                     ║
    ║     • GraphSAGE: Inductive, large graphs                       ║
    ║     • GIN: Maximum expressiveness needed                       ║
    ║                                                                ║
    ║  2. DEPTH CONSIDERATIONS                                       ║
    ║     • 2-3 layers often optimal                                 ║
    ║     • Deeper = over-smoothing (all nodes become similar)       ║
    ║     • Use skip connections for deeper networks                 ║
    ║                                                                ║
    ║  3. NORMALIZATION                                              ║
    ║     • Symmetric normalization (GCN) is common                  ║
    ║     • Mean normalization for GraphSAGE                         ║
    ║     • Layer/batch norm helps training                          ║
    ║                                                                ║
    ║  4. REGULARIZATION                                             ║
    ║     • Dropout on features and attention                        ║
    ║     • DropEdge: randomly remove edges during training          ║
    ║     • L2 regularization                                        ║
    ║                                                                ║
    ║  5. SCALABILITY                                                ║
    ║     • Mini-batch with neighbor sampling (GraphSAGE)            ║
    ║     • Cluster-GCN for very large graphs                        ║
    ║     • Sparse operations for memory efficiency                  ║
    ║                                                                ║
    ║  6. EVALUATION                                                 ║
    ║     • Fixed train/val/test splits for comparison               ║
    ║     • Multiple random seeds                                    ║
    ║     • Report mean ± std                                        ║
    ║                                                                ║
    ╚════════════════════════════════════════════════════════════════╝
    """
    print(guidelines)

gnn_best_practices()
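
Of the regularizers above, DropEdge is especially easy to add. A minimal sketch (the function name and rate are illustrative, not a library API):

def drop_edge(edge_index: torch.Tensor, p: float = 0.2) -> torch.Tensor:
    """Randomly drop a fraction p of edges; apply only during training."""
    keep = torch.rand(edge_index.size(1)) >= p
    return edge_index[:, keep]

# In a training loop, resample the retained edges on every forward pass:
# out = model(x, drop_edge(edge_index, p=0.2))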

Exercises

Extend GCN to use edge features in the convolution:

class EdgeGCN(nn.Module):
    # Include edge_attr in message computation
    pass

Build a model for graph-level classification:
  • Use GIN layers
  • Implement graph-level readout
  • Test on molecular property prediction

What’s Next?