Graph Neural Networks
Why Graphs?
Much real-world data is naturally graph-structured:
- Social networks: Users and friendships
- Molecules: Atoms and bonds
- Knowledge graphs: Entities and relations
- Citation networks: Papers and references
- Traffic: Roads and intersections
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional, List, Dict
torch.manual_seed(42)
Graph Fundamentals
Graph Representation
class GraphData:
"""
Graph data structure.
A graph G = (V, E) with:
- V: Set of nodes (vertices)
- E: Set of edges
Representations:
- Adjacency matrix A: A[i,j] = 1 if edge (i,j) exists
- Edge index: [2, E] tensor of (source, target) pairs
"""
def __init__(
self,
x: torch.Tensor, # Node features [N, F]
edge_index: torch.Tensor, # Edge indices [2, E]
edge_attr: Optional[torch.Tensor] = None, # Edge features [E, D]
y: Optional[torch.Tensor] = None # Labels
):
self.x = x
self.edge_index = edge_index
self.edge_attr = edge_attr
self.y = y
self.num_nodes = x.size(0)
self.num_edges = edge_index.size(1)
self.num_features = x.size(1)
def to_adjacency(self) -> torch.Tensor:
"""Convert edge index to adjacency matrix."""
adj = torch.zeros(self.num_nodes, self.num_nodes)
adj[self.edge_index[0], self.edge_index[1]] = 1
return adj
@staticmethod
def from_adjacency(adj: torch.Tensor, x: torch.Tensor) -> 'GraphData':
"""Create from adjacency matrix."""
edge_index = adj.nonzero().t().contiguous()
return GraphData(x, edge_index)
def add_self_loops(self) -> 'GraphData':
"""Add self-loop edges."""
self_loops = torch.arange(self.num_nodes).unsqueeze(0).repeat(2, 1)
edge_index = torch.cat([self.edge_index, self_loops], dim=1)
return GraphData(self.x, edge_index, self.edge_attr, self.y)
# Example: Create a simple graph
def create_example_graph():
# 5 nodes with 3 features each
x = torch.randn(5, 3)
# Edges: 0->1, 0->2, 1->2, 2->3, 3->4
edge_index = torch.tensor([
[0, 0, 1, 2, 3],
[1, 2, 2, 3, 4]
])
# Make undirected (add reverse edges)
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
return GraphData(x, edge_index)
graph = create_example_graph()
print(f"Nodes: {graph.num_nodes}, Edges: {graph.num_edges}")
Message Passing Framework
The foundation of all GNNs:
h_v^{(k+1)} = UPDATE(h_v^{(k)}, AGGREGATE({h_u^{(k)} : u ∈ N(v)}))
class MessagePassingLayer(nn.Module):
"""
General message passing framework.
Steps:
1. MESSAGE: Compute messages from neighbors
2. AGGREGATE: Combine messages (sum, mean, max)
3. UPDATE: Update node representations
"""
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.in_features = in_features
self.out_features = out_features
def message(
self,
x_i: torch.Tensor, # Target node features
x_j: torch.Tensor, # Source node features
edge_attr: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Compute messages from source to target."""
return x_j
def aggregate(
self,
messages: torch.Tensor,
index: torch.Tensor,
num_nodes: int
) -> torch.Tensor:
"""Aggregate messages at each node."""
# Scatter-add messages to target nodes
out = torch.zeros(num_nodes, messages.size(-1), device=messages.device)
out.scatter_add_(0, index.unsqueeze(-1).expand_as(messages), messages)
return out
def update(
self,
aggregated: torch.Tensor,
x: torch.Tensor
) -> torch.Tensor:
"""Update node embeddings."""
return aggregated
def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
edge_attr: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Forward pass.
Args:
x: Node features [N, F]
edge_index: Edge indices [2, E]
Returns:
Updated node features [N, F']
"""
source, target = edge_index
# Get source and target features
x_i = x[target] # Target nodes
x_j = x[source] # Source nodes
# Compute messages
messages = self.message(x_i, x_j, edge_attr)
# Aggregate messages
aggregated = self.aggregate(messages, target, x.size(0))
# Update
out = self.update(aggregated, x)
return out
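A minimal sketch of how the framework is meant to be subclassed (this toy layer is an illustration, not part of the original framework): only message and update need to be overridden for a sum-aggregating layer with a learned transform.
class LinearSumLayer(MessagePassingLayer):
    """Toy layer: transform neighbor features, sum them, add a transformed self term."""
    def __init__(self, in_features: int, out_features: int):
        super().__init__(in_features, out_features)
        self.lin_neigh = nn.Linear(in_features, out_features)
        self.lin_self = nn.Linear(in_features, out_features)

    def message(self, x_i, x_j, edge_attr=None):
        # Messages are linearly transformed source-node features
        return self.lin_neigh(x_j)

    def update(self, aggregated, x):
        # Combine the aggregated neighborhood with the node's own features
        return aggregated + self.lin_self(x)

layer = LinearSumLayer(3, 8)
print(layer(graph.x, graph.edge_index).shape)  # torch.Size([5, 8])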
Graph Convolutional Network (GCN)
class GCNConv(nn.Module):
"""
Graph Convolutional Network layer (Kipf & Welling, 2017).
Equation:
H^{(l+1)} = σ(D̃^{-1/2} Ã D̃^{-1/2} H^{(l)} W^{(l)})
Where:
- Ã = A + I (adjacency with self-loops)
- D̃ = diagonal degree matrix of Ã
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True
):
super().__init__()
self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor
) -> torch.Tensor:
"""
Args:
x: Node features [N, F_in]
edge_index: Edge indices [2, E]
Returns:
Updated features [N, F_out]
"""
N = x.size(0)
source, target = edge_index
# Add self-loops
self_loops = torch.arange(N, device=x.device)
source = torch.cat([source, self_loops])
target = torch.cat([target, self_loops])
# Compute degree
deg = torch.zeros(N, device=x.device)
deg.scatter_add_(0, source, torch.ones_like(source, dtype=torch.float))
# Symmetric normalization: D^{-1/2} A D^{-1/2}
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[source] * deg_inv_sqrt[target]
# Transform features
x = torch.matmul(x, self.weight)
# Message passing
out = torch.zeros_like(x)
messages = norm.unsqueeze(-1) * x[source]
out.scatter_add_(0, target.unsqueeze(-1).expand_as(messages), messages)
if self.bias is not None:
out = out + self.bias
return out
class GCN(nn.Module):
"""Multi-layer Graph Convolutional Network."""
def __init__(
self,
in_features: int,
hidden_features: int,
out_features: int,
num_layers: int = 2,
dropout: float = 0.5
):
super().__init__()
self.convs = nn.ModuleList()
# First layer
self.convs.append(GCNConv(in_features, hidden_features))
# Hidden layers
for _ in range(num_layers - 2):
self.convs.append(GCNConv(hidden_features, hidden_features))
# Output layer
self.convs.append(GCNConv(hidden_features, out_features))
self.dropout = dropout
def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
for i, conv in enumerate(self.convs[:-1]):
x = conv(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1](x, edge_index)
return x
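A quick shape check on the toy graph from earlier (the layer sizes here are illustrative):
gcn = GCN(in_features=3, hidden_features=16, out_features=2, num_layers=2)
logits = gcn(graph.x, graph.edge_index)
print(logits.shape)  # torch.Size([5, 2]) -- one logit vector per node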
Graph Attention Network (GAT)
class GATConv(nn.Module):
"""
Graph Attention Network layer (Veličković et al., 2018).
Uses attention to weight neighbor contributions:
α_ij = softmax_j(LeakyReLU(a^T [Wh_i || Wh_j]))
Multi-head attention for stability.
"""
def __init__(
self,
in_features: int,
out_features: int,
heads: int = 8,
concat: bool = True,
dropout: float = 0.6,
negative_slope: float = 0.2
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.heads = heads
self.concat = concat
self.dropout = dropout
self.negative_slope = negative_slope
# Linear transformation
self.W = nn.Linear(in_features, heads * out_features, bias=False)
# Attention parameters
self.att_src = nn.Parameter(torch.Tensor(1, heads, out_features))
self.att_dst = nn.Parameter(torch.Tensor(1, heads, out_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.W.weight)
nn.init.xavier_uniform_(self.att_src)
nn.init.xavier_uniform_(self.att_dst)
def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor
) -> torch.Tensor:
"""
Args:
x: Node features [N, F_in]
edge_index: Edge indices [2, E]
Returns:
Updated features [N, H*F_out] if concat else [N, F_out]
"""
N = x.size(0)
source, target = edge_index
# Linear transformation and reshape for multi-head
x = self.W(x).view(-1, self.heads, self.out_features) # [N, H, F]
# Compute attention scores
# Source attention
alpha_src = (x * self.att_src).sum(dim=-1) # [N, H]
# Target attention
alpha_dst = (x * self.att_dst).sum(dim=-1) # [N, H]
# Edge attention
alpha = alpha_src[source] + alpha_dst[target] # [E, H]
alpha = F.leaky_relu(alpha, negative_slope=self.negative_slope)
# Softmax over neighbors
alpha = self._softmax(alpha, target, N) # [E, H]
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
# Message passing
out = torch.zeros(N, self.heads, self.out_features, device=x.device)
messages = alpha.unsqueeze(-1) * x[source] # [E, H, F]
# Scatter-add
out.scatter_add_(
0,
target.view(-1, 1, 1).expand_as(messages),
messages
)
if self.concat:
out = out.view(N, -1) # [N, H*F]
else:
out = out.mean(dim=1) # [N, F]
return out
def _softmax(
self,
alpha: torch.Tensor,
target: torch.Tensor,
num_nodes: int
) -> torch.Tensor:
"""Softmax over neighbors."""
# Max for numerical stability
alpha_max = torch.zeros(num_nodes, alpha.size(1), device=alpha.device)
alpha_max.scatter_reduce_(
0,
target.unsqueeze(-1).expand_as(alpha),
alpha,
reduce='amax'
)
alpha = alpha - alpha_max[target]
# Exp and sum
alpha = alpha.exp()
alpha_sum = torch.zeros(num_nodes, alpha.size(1), device=alpha.device)
alpha_sum.scatter_add_(
0,
target.unsqueeze(-1).expand_as(alpha),
alpha
)
return alpha / (alpha_sum[target] + 1e-8)
class GAT(nn.Module):
"""Multi-layer Graph Attention Network."""
def __init__(
self,
in_features: int,
hidden_features: int,
out_features: int,
heads: int = 8,
dropout: float = 0.6
):
super().__init__()
self.conv1 = GATConv(in_features, hidden_features, heads=heads, dropout=dropout)
self.conv2 = GATConv(
hidden_features * heads, out_features,
heads=1, concat=False, dropout=dropout
)
self.dropout = dropout
def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
x = self.conv1(x, edge_index)
x = F.elu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.conv2(x, edge_index)
return x
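Because the first layer concatenates its heads, its output width is heads × out_features; a small illustrative check on the toy graph makes the shapes explicit:
gat_layer = GATConv(in_features=3, out_features=4, heads=2)   # concat=True by default
print(gat_layer(graph.x, graph.edge_index).shape)             # torch.Size([5, 8]) = 2 heads x 4

gat = GAT(in_features=3, hidden_features=8, out_features=2, heads=2)
print(gat(graph.x, graph.edge_index).shape)                   # torch.Size([5, 2])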
GraphSAGE
class SAGEConv(nn.Module):
"""
GraphSAGE (Hamilton et al., 2017).
Key innovations:
1. Sampling fixed-size neighborhoods (scalable)
2. Multiple aggregation functions
3. Inductive learning (works on new graphs)
"""
def __init__(
self,
in_features: int,
out_features: int,
aggregator: str = 'mean', # 'mean', 'max', 'lstm'
normalize: bool = True
):
super().__init__()
self.aggregator = aggregator
self.normalize = normalize
# Transform for concatenating self + aggregated neighbors
self.lin = nn.Linear(2 * in_features, out_features)
if aggregator == 'lstm':
self.lstm = nn.LSTM(in_features, in_features, batch_first=True)
def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor
) -> torch.Tensor:
N = x.size(0)
source, target = edge_index
# Aggregate neighbor features
if self.aggregator == 'mean':
agg = self._mean_aggregator(x, source, target, N)
elif self.aggregator == 'max':
agg = self._max_aggregator(x, source, target, N)
        else:
            # 'lstm' aggregation is not implemented in this sketch; fall back to mean
            agg = self._mean_aggregator(x, source, target, N)
# Concatenate self features with aggregated neighbors
out = torch.cat([x, agg], dim=-1)
out = self.lin(out)
if self.normalize:
out = F.normalize(out, p=2, dim=-1)
return out
def _mean_aggregator(self, x, source, target, N):
"""Mean aggregation."""
out = torch.zeros(N, x.size(-1), device=x.device)
count = torch.zeros(N, 1, device=x.device)
out.scatter_add_(0, target.unsqueeze(-1).expand(-1, x.size(-1)), x[source])
count.scatter_add_(0, target.unsqueeze(-1), torch.ones_like(target.unsqueeze(-1), dtype=torch.float))
return out / (count + 1e-8)
def _max_aggregator(self, x, source, target, N):
"""Max aggregation."""
out = torch.full((N, x.size(-1)), float('-inf'), device=x.device)
out.scatter_reduce_(0, target.unsqueeze(-1).expand(-1, x.size(-1)), x[source], reduce='amax')
out[out == float('-inf')] = 0
return out
class NeighborSampler:
"""
Sample fixed-size neighborhoods for mini-batch training.
Enables training on large graphs.
"""
def __init__(self, edge_index: torch.Tensor, sizes: List[int]):
"""
Args:
edge_index: Full graph edge index
sizes: Number of neighbors to sample at each layer
e.g., [25, 10] means 25 neighbors for layer 1, 10 for layer 2
"""
self.edge_index = edge_index
self.sizes = sizes
# Build adjacency list
self.adj_list = self._build_adj_list()
def _build_adj_list(self) -> Dict[int, List[int]]:
"""Build adjacency list from edge index."""
adj = {}
source, target = self.edge_index.tolist()
for s, t in zip(source, target):
if t not in adj:
adj[t] = []
adj[t].append(s)
return adj
def sample(self, batch_nodes: torch.Tensor) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""
Sample subgraph for batch nodes.
Returns list of (sampled_edge_index, batch_nodes) for each layer.
"""
sampled_layers = []
nodes = batch_nodes.tolist()
for size in self.sizes:
# Sample neighbors for each node
new_nodes = set()
edges_src, edges_dst = [], []
for node in nodes:
neighbors = self.adj_list.get(node, [])
if len(neighbors) > size:
sampled = np.random.choice(neighbors, size, replace=False)
else:
sampled = neighbors
for neighbor in sampled:
new_nodes.add(neighbor)
edges_src.append(neighbor)
edges_dst.append(node)
edge_index = torch.tensor([edges_src, edges_dst])
sampled_layers.append((edge_index, torch.tensor(list(new_nodes | set(nodes)))))
nodes = list(new_nodes | set(nodes))
return sampled_layers
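A sketch of how the sampler might be driven (the fan-out sizes and the seed batch below are illustrative choices, not from the original text):
sampler = NeighborSampler(graph.edge_index, sizes=[2, 2])
layers = sampler.sample(torch.tensor([4]))   # expand outwards from node 4
for edge_index, nodes in layers:
    print(f"sampled edges: {edge_index.size(1)}, nodes involved: {sorted(nodes.tolist())}")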
Graph Isomorphism Network (GIN)
class GINConv(nn.Module):
"""
Graph Isomorphism Network (Xu et al., 2019).
    Provably as expressive as the 1-WL test, the upper bound for message-passing GNNs:
h_v^{(k+1)} = MLP((1 + ε) · h_v^{(k)} + Σ_{u∈N(v)} h_u^{(k)})
"""
def __init__(
self,
in_features: int,
out_features: int,
eps: float = 0.0,
train_eps: bool = True
):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(in_features, out_features),
nn.BatchNorm1d(out_features),
nn.ReLU(),
nn.Linear(out_features, out_features),
nn.BatchNorm1d(out_features),
nn.ReLU()
)
if train_eps:
self.eps = nn.Parameter(torch.Tensor([eps]))
else:
self.register_buffer('eps', torch.Tensor([eps]))
def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
N = x.size(0)
source, target = edge_index
# Aggregate neighbors
out = torch.zeros(N, x.size(-1), device=x.device)
out.scatter_add_(0, target.unsqueeze(-1).expand(-1, x.size(-1)), x[source])
# Add self with learnable epsilon
out = (1 + self.eps) * x + out
# MLP
out = self.mlp(out)
return out
class GIN(nn.Module):
"""Multi-layer GIN for graph classification."""
def __init__(
self,
in_features: int,
hidden_features: int,
out_features: int,
num_layers: int = 5
):
super().__init__()
self.convs = nn.ModuleList()
self.convs.append(GINConv(in_features, hidden_features))
for _ in range(num_layers - 1):
self.convs.append(GINConv(hidden_features, hidden_features))
# Readout for graph classification
self.readout = nn.Sequential(
nn.Linear(hidden_features * num_layers, hidden_features),
nn.ReLU(),
nn.Linear(hidden_features, out_features)
)
def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
batch: torch.Tensor
) -> torch.Tensor:
"""
Args:
batch: Tensor indicating which graph each node belongs to
"""
# Collect representations from all layers
h_list = []
for conv in self.convs:
x = conv(x, edge_index)
h_list.append(self._readout(x, batch))
# Concatenate all layer representations
h = torch.cat(h_list, dim=-1)
return self.readout(h)
def _readout(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
"""Global sum pooling per graph."""
num_graphs = batch.max().item() + 1
out = torch.zeros(num_graphs, x.size(-1), device=x.device)
out.scatter_add_(0, batch.unsqueeze(-1).expand_as(x), x)
return out
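The batch vector is what lets one forward pass score several graphs at once. As an illustration (batching two copies of the toy graph is an assumption for demonstration, not from the original text):
gin = GIN(in_features=3, hidden_features=16, out_features=2, num_layers=3)

# Two disjoint copies of the toy graph stacked into a single batch
x_batch = torch.cat([graph.x, graph.x], dim=0)                       # [10, 3]
edge_batch = torch.cat([graph.edge_index, graph.edge_index + 5], dim=1)
batch = torch.tensor([0] * 5 + [1] * 5)                              # node -> graph id

gin.eval()  # BatchNorm layers use running statistics in eval mode
print(gin(x_batch, edge_batch, batch).shape)                         # torch.Size([2, 2])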
Graph Pooling
class TopKPooling(nn.Module):
"""
Top-K pooling: Select top-k scoring nodes.
Used for hierarchical graph representation.
"""
def __init__(self, in_features: int, ratio: float = 0.5):
super().__init__()
self.ratio = ratio
self.score = nn.Linear(in_features, 1)
def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
batch: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Returns:
x: Pooled node features
edge_index: Updated edge index
batch: Updated batch assignment
"""
if batch is None:
batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
# Compute scores
scores = self.score(x).squeeze(-1)
scores = torch.tanh(scores)
# Select top-k nodes per graph
num_graphs = batch.max().item() + 1
keep_mask = torch.zeros(x.size(0), dtype=torch.bool, device=x.device)
for g in range(num_graphs):
graph_mask = batch == g
graph_scores = scores[graph_mask]
k = max(1, int(graph_mask.sum() * self.ratio))
_, top_indices = graph_scores.topk(k)
            graph_nodes = graph_mask.nonzero().view(-1)
keep_nodes = graph_nodes[top_indices]
keep_mask[keep_nodes] = True
        # Build a map from old node indices to new compact indices (before filtering x)
        node_map = torch.full((keep_mask.size(0),), -1, dtype=torch.long, device=x.device)
        node_map[keep_mask] = torch.arange(keep_mask.sum(), device=x.device)
        # Filter nodes and gate their features by the scores
        x = x[keep_mask] * scores[keep_mask].unsqueeze(-1)
        # Re-index edges
source, target = edge_index
edge_mask = keep_mask[source] & keep_mask[target]
edge_index = torch.stack([
node_map[source[edge_mask]],
node_map[target[edge_mask]]
])
batch = batch[keep_mask]
return x, edge_index, batch
class SAGPooling(nn.Module):
"""
Self-Attention Graph Pooling.
Uses GNN to compute attention scores.
"""
def __init__(self, in_features: int, ratio: float = 0.5):
super().__init__()
self.ratio = ratio
self.gnn = GCNConv(in_features, 1)
def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
batch: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if batch is None:
batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
# Compute attention scores using GNN
scores = self.gnn(x, edge_index).squeeze(-1)
scores = torch.tanh(scores)
        # Top-k selection and re-indexing, same procedure as TopKPooling
        num_graphs = batch.max().item() + 1
        keep_mask = torch.zeros(x.size(0), dtype=torch.bool, device=x.device)
        for g in range(num_graphs):
            graph_mask = batch == g
            graph_scores = scores[graph_mask]
            k = max(1, int(graph_mask.sum() * self.ratio))
            _, top_indices = graph_scores.topk(k)
            keep_mask[graph_mask.nonzero().view(-1)[top_indices]] = True
        node_map = torch.full((keep_mask.size(0),), -1, dtype=torch.long, device=x.device)
        node_map[keep_mask] = torch.arange(keep_mask.sum(), device=x.device)
        x = x[keep_mask] * scores[keep_mask].unsqueeze(-1)
        source, target = edge_index
        edge_mask = keep_mask[source] & keep_mask[target]
        edge_index = torch.stack([
            node_map[source[edge_mask]],
            node_map[target[edge_mask]]
        ])
        batch = batch[keep_mask]
        return x, edge_index, batch
class GlobalAttentionPooling(nn.Module):
"""Global attention pooling for graph-level representation."""
def __init__(self, in_features: int, hidden_features: int):
super().__init__()
self.gate = nn.Sequential(
nn.Linear(in_features, hidden_features),
nn.Tanh(),
nn.Linear(hidden_features, 1),
nn.Sigmoid()
)
self.transform = nn.Linear(in_features, hidden_features)
def forward(
self,
x: torch.Tensor,
batch: torch.Tensor
) -> torch.Tensor:
"""
Args:
x: Node features [N, F]
batch: Batch assignment [N]
Returns:
Graph representations [B, F']
"""
# Compute attention
gate = self.gate(x) # [N, 1]
x = self.transform(x) * gate # [N, F']
# Sum per graph
num_graphs = batch.max().item() + 1
out = torch.zeros(num_graphs, x.size(-1), device=x.device)
out.scatter_add_(0, batch.unsqueeze(-1).expand_as(x), x)
return out
Practical Application: Node Classification
def train_node_classification():
"""Example: Semi-supervised node classification on Cora."""
# Create example data (simulated Cora-like)
num_nodes = 2708
num_features = 1433
num_classes = 7
x = torch.randn(num_nodes, num_features)
# Random edges (in practice, use actual graph)
num_edges = 10556
edge_index = torch.randint(0, num_nodes, (2, num_edges))
# Labels and masks
y = torch.randint(0, num_classes, (num_nodes,))
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[:140] = True # 20 per class
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
val_mask[140:640] = True
test_mask = ~(train_mask | val_mask)
# Model
model = GCN(num_features, 64, num_classes, num_layers=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# Training
for epoch in range(200):
model.train()
optimizer.zero_grad()
out = model(x, edge_index)
loss = F.cross_entropy(out[train_mask], y[train_mask])
loss.backward()
optimizer.step()
# Validation
model.eval()
with torch.no_grad():
pred = model(x, edge_index).argmax(dim=1)
val_acc = (pred[val_mask] == y[val_mask]).float().mean()
if epoch % 50 == 0:
print(f"Epoch {epoch}: Loss = {loss:.4f}, Val Acc = {val_acc:.4f}")
# Test
model.eval()
with torch.no_grad():
pred = model(x, edge_index).argmax(dim=1)
test_acc = (pred[test_mask] == y[test_mask]).float().mean()
print(f"\nTest Accuracy: {test_acc:.4f}")
# train_node_classification()
GNN Best Practices
def gnn_best_practices():
"""Guidelines for GNN development."""
guidelines = """
╔════════════════════════════════════════════════════════════════╗
║ GNN BEST PRACTICES ║
╠════════════════════════════════════════════════════════════════╣
║ ║
║ 1. CHOOSE THE RIGHT MODEL ║
║ • GCN: Simple baseline, fast ║
║ • GAT: When neighbor importance varies ║
║ • GraphSAGE: Inductive, large graphs ║
║ • GIN: Maximum expressiveness needed ║
║ ║
║ 2. DEPTH CONSIDERATIONS ║
║ • 2-3 layers often optimal ║
║ • Deeper = over-smoothing (all nodes become similar) ║
║ • Use skip connections for deeper networks ║
║ ║
║ 3. NORMALIZATION ║
║ • Symmetric normalization (GCN) is common ║
║ • Mean normalization for GraphSAGE ║
║ • Layer/batch norm helps training ║
║ ║
║ 4. REGULARIZATION ║
║ • Dropout on features and attention ║
║ • DropEdge: randomly remove edges during training ║
║ • L2 regularization ║
║ ║
║ 5. SCALABILITY ║
║ • Mini-batch with neighbor sampling (GraphSAGE) ║
║ • Cluster-GCN for very large graphs ║
║ • Sparse operations for memory efficiency ║
║ ║
║ 6. EVALUATION ║
║ • Fixed train/val/test splits for comparison ║
║ • Multiple random seeds ║
║ • Report mean ± std ║
║ ║
╚════════════════════════════════════════════════════════════════╝
"""
print(guidelines)
gnn_best_practices()
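The DropEdge regularizer mentioned in the guidelines is simple to sketch; the helper below is an assumed, illustrative implementation that drops a fraction of edges independently at training time:
def drop_edge(edge_index: torch.Tensor, p: float = 0.2, training: bool = True) -> torch.Tensor:
    """Randomly drop a fraction p of edges (DropEdge-style regularization)."""
    if not training or p <= 0:
        return edge_index
    keep = torch.rand(edge_index.size(1), device=edge_index.device) >= p
    return edge_index[:, keep]

# Usage inside a training loop (illustrative):
# out = model(x, drop_edge(edge_index, p=0.2, training=model.training))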
Exercises
Exercise 1: Implement Edge Features
Extend GCN to use edge features in the convolution:
class EdgeGCN(nn.Module):
# Include edge_attr in message computation
pass
Exercise 2: Graph Classification
Build a model for graph-level classification:
- Use GIN layers
- Implement graph-level readout
- Test on molecular property prediction
Exercise 3: Link Prediction
Implement link prediction:
# Given: node embeddings
# Predict: probability of edge existing
# Loss: binary cross-entropy on positive/negative edges