Graph Neural Networks
Graph neural networks learn from graph-structured data by passing messages between connected nodes. This file covers the message-passing framework, GCN, GraphSAGE, GIN, over-smoothing, graph pooling, heterogeneous graphs, link prediction, and node/edge/graph-level tasks: the core architectures that power molecular property prediction, social network analysis, and recommendation systems.
In the previous files, we established the mathematical foundations: geometric deep learning (file 1) tells us to exploit symmetries, and graph theory (file 2) gives us the language of nodes, edges, and adjacency. Now we build neural networks that operate directly on graphs.
The core challenge: graph data is irregular. Unlike images (fixed grid) or sequences (fixed ordering), graphs have variable numbers of nodes, variable connectivity, and no canonical node ordering. A neural network for graphs must handle all of this while being permutation-equivariant: relabelling the nodes should permute the node-level outputs in exactly the same way (and leave graph-level outputs unchanged).
The Message-Passing Framework
Nearly all GNNs follow the same recipe, called message passing (also called neighbourhood aggregation). The idea is simple and elegant: each node updates its representation by collecting information from its neighbours.
At each layer \(l\), every node \(i\) does three things:
- Message: each neighbour \(j\) of node \(i\) computes a message \(\mathbf{m}_{j \to i}\) based on its current features.
- Aggregate: node \(i\) collects all incoming messages and combines them with a permutation-invariant function (sum, mean, or max).
- Update: node \(i\) combines the aggregated message with its own features to produce a new representation.
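Formally:
\[ \mathbf{m}_{j \to i}^{(l)} = \phi\big(\mathbf{h}_i^{(l)}, \mathbf{h}_j^{(l)}, \mathbf{e}_{ij}\big), \qquad \mathbf{h}_i^{(l+1)} = \psi\Big(\mathbf{h}_i^{(l)}, \bigoplus_{j \in \mathcal{N}(i)} \mathbf{m}_{j \to i}^{(l)}\Big) \]
where \(\mathcal{N}(i)\) is the set of neighbours of node \(i\), \(\bigoplus\) is a permutation-invariant aggregation (sum, mean, max), \(\phi\) is the message function, \(\psi\) is the update function, and \(\mathbf{e}_{ij}\) is the optional edge feature.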
The aggregation \(\bigoplus\) must be permutation-invariant (it does not matter what order the neighbours are processed in) to ensure the overall function is permutation-equivariant. This directly implements the symmetry principle from file 1.
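A minimal JAX sketch of one message-passing round, assuming an edge-list representation (parallel `senders` and `receivers` arrays, one entry per directed edge) and sum aggregation. The names `message_passing_layer`, `W_msg` and `W_self` are illustrative, not a library API, and for brevity the message depends only on the sender's features. The final check relabels the nodes and confirms that the output is permuted in the same way.

```python
import jax
import jax.numpy as jnp

def message_passing_layer(h, senders, receivers, W_msg, W_self):
    """One round of message passing with sum aggregation.

    h:         (n, d) node features
    senders:   (m,) source node j of each directed edge j -> i
    receivers: (m,) target node i of each directed edge j -> i
    """
    n = h.shape[0]
    # Message: phi(h_j) = h_j @ W_msg, one message per edge
    messages = h[senders] @ W_msg
    # Aggregate: sum the messages arriving at each receiver (order does not matter)
    aggregated = jax.ops.segment_sum(messages, receivers, num_segments=n)
    # Update: psi combines each node's own features with its aggregate
    return jax.nn.relu(h @ W_self + aggregated)

# Undirected path 0-1-2-3, stored as directed edges in both directions
senders = jnp.array([0, 1, 1, 2, 2, 3])
receivers = jnp.array([1, 0, 2, 1, 3, 2])

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
h = jax.random.normal(k1, (4, 3))
W_msg = jax.random.normal(k2, (3, 3))
W_self = jax.random.normal(k3, (3, 3))
out = message_passing_layer(h, senders, receivers, W_msg, W_self)

# Permutation equivariance: new node k is old node perm[k]
perm = jnp.array([2, 0, 3, 1])
inv = jnp.argsort(perm)  # inv[i] = new index of old node i
out_perm = message_passing_layer(h[perm], inv[senders], inv[receivers], W_msg, W_self)
print(jnp.allclose(out_perm, out[perm], atol=1e-5))  # True
```

`jax.ops.segment_sum` does the aggregation: it adds every message into the row of its receiver, so the result does not depend on the order of the edge list.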
After \(k\) layers of message passing, each node's representation encodes information from its \(k\)-hop neighbourhood: all nodes reachable within \(k\) edges. Layer 1 sees immediate neighbours, layer 2 sees neighbours of neighbours, and so on. This is how local information propagates to build global understanding.
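One way to see the \(k\)-hop receptive field, assuming a dense adjacency matrix: entry \((i, j)\) of \((A + I)^k\) is nonzero exactly when node \(j\) is within \(k\) edges of node \(i\). The helper `receptive_field` is hypothetical, written for this illustration.

```python
import jax.numpy as jnp

# Path graph 0-1-2-3-4
A = jnp.array([[0, 1, 0, 0, 0],
               [1, 0, 1, 0, 0],
               [0, 1, 0, 1, 0],
               [0, 0, 1, 0, 1],
               [0, 0, 0, 1, 0]], dtype=jnp.float32)
A_self = A + jnp.eye(5)  # self-loops: a node keeps its own information

def receptive_field(A_self, k):
    """Boolean matrix R with R[i, j] True iff j is within k hops of i."""
    return jnp.linalg.matrix_power(A_self, k) > 0

for k in range(1, 4):
    field = receptive_field(A_self, k)
    print(f"k={k}: node 0 sees nodes {jnp.nonzero(field[0])[0].tolist()}")
# k=1: node 0 sees nodes [0, 1]
# k=2: node 0 sees nodes [0, 1, 2]
# k=3: node 0 sees nodes [0, 1, 2, 3]
```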
The receptive field of a GNN grows with depth, just like the receptive field of a CNN grows with layers (chapter 8). But unlike CNNs on regular grids, the receptive field shape varies per node depending on the graph topology.
Graph Convolutional Network (GCN)
The GCN (Kipf & Welling, 2017) is the foundational GNN architecture. It simplifies spectral graph convolution (from file 2) into an elegant, efficient formula.
Starting from the spectral convolution \(g_\theta \star \mathbf{x} = U \, \text{diag}(\hat{g}_\theta) \, U^T \mathbf{x}\), Kipf and Welling approximate the spectral filter with a first-order Chebyshev polynomial, which avoids computing the eigendecomposition entirely. After simplification, the layer-wise update becomes:
\[ H^{(l+1)} = \sigma\big(\hat{A} \, H^{(l)} W^{(l)}\big) \]
where:
- \(H^{(l)} \in \mathbb{R}^{n \times d}\) is the matrix of node features at layer \(l\)
- \(W^{(l)} \in \mathbb{R}^{d \times d'}\) is a learnable weight matrix
- \(\hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}\) is the symmetrically normalised adjacency matrix with self-loops
- \(\tilde{A} = A + I\) adds self-loops (so each node also receives its own message)
- \(\tilde{D}\) is the degree matrix of \(\tilde{A}\)
- \(\sigma\) is a nonlinear activation (ReLU, as in chapter 6)
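The matrix multiplication \(\hat{A} H^{(l)}\) is the aggregation step: for each node \(i\), it computes a degree-normalised weighted sum of its neighbours' features (plus its own, via the self-loop), in which neighbour \(j\) contributes with weight \(1/\sqrt{\tilde{d}_i \tilde{d}_j}\), where \(\tilde{d}_i\) is the degree of node \(i\) in \(\tilde{A}\). The weight matrix \(W^{(l)}\) is the learnable transformation, shared across all nodes. The activation adds nonlinearity.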
-
This is remarkably simple: it is just matrix multiplication followed by a learned linear map and activation. The entire GCN layer can be written in one line of code. The normalisation by \(\tilde{D}^{-1/2}\) prevents nodes with many neighbours from dominating: high-degree nodes have their messages scaled down.
In the message-passing framework, GCN uses:
- Message: \(\phi(\mathbf{h}_j) = \mathbf{h}_j\) (just send your features)
- Aggregation: normalised sum (each neighbour is weighted by its entry of \(\hat{A}\), which depends on the degrees of both endpoints)
- Update: linear transformation + activation
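To make the message-passing reading concrete, here is a sketch (illustrative names, dense adjacency, one layer) that computes the same GCN layer twice: once as the matrix product \(\sigma(\hat{A} H W)\), and once edge by edge, where each edge \(j \to i\) (self-loops included) carries \(\mathbf{h}_j\) scaled by \(1/\sqrt{\tilde{d}_i \tilde{d}_j}\), with \(\tilde{d}_i\) the degree of node \(i\) in \(\tilde{A}\).

```python
import jax
import jax.numpy as jnp

# Small 5-node graph: chain 0-1-2-3 with a branch 2-4
A = jnp.array([[0, 1, 0, 0, 0],
               [1, 0, 1, 0, 0],
               [0, 1, 0, 1, 1],
               [0, 0, 1, 0, 0],
               [0, 0, 1, 0, 0]], dtype=jnp.float32)
n = A.shape[0]
A_tilde = A + jnp.eye(n)
deg = A_tilde.sum(axis=1)  # degrees in A_tilde (self-loop included)

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
H = jax.random.normal(k1, (n, 4))
W = jax.random.normal(k2, (4, 2))

# Matrix form: sigma(A_hat @ H @ W), with A_hat = D^-1/2 A_tilde D^-1/2
A_hat = A_tilde / jnp.sqrt(deg[:, None] * deg[None, :])
dense = jax.nn.relu(A_hat @ H @ W)

# Message-passing form: every nonzero entry (i, j) of A_tilde is an edge j -> i
receivers, senders = jnp.nonzero(A_tilde)
coef = 1.0 / jnp.sqrt(deg[receivers] * deg[senders])   # equals A_hat[i, j]
messages = coef[:, None] * H[senders]                  # message: scaled h_j
aggregated = jax.ops.segment_sum(messages, receivers, num_segments=n)
message_passing = jax.nn.relu(aggregated @ W)          # update: linear map + ReLU

print(jnp.allclose(dense, message_passing, atol=1e-4))  # True
```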
GraphSAGE
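GCN, as originally used, is transductive: it is trained full-batch on a single fixed graph, so the whole graph must be available during training, and the resulting node embeddings are tied to that graph. If a new user joins a social network, the standard approach is to retrain on the entire graph. GraphSAGE (Hamilton et al., 2017) fixes this with an inductive approach.
The key idea is neighbourhood sampling: instead of using all neighbours, sample a fixed-size subset. This bounds the cost of computing each node's representation regardless of graph size, and because GraphSAGE learns aggregation functions rather than one embedding per node, the same trained model can embed unseen nodes and even entirely new graphs.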
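The GraphSAGE update for node \(i\):
\[ \mathbf{h}_i^{(l+1)} = \sigma\Big(W^{(l)} \cdot \text{CONCAT}\big(\mathbf{h}_i^{(l)}, \, \text{AGG}(\{\mathbf{h}_j^{(l)} : j \in \mathcal{S}(i)\})\big)\Big) \]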
where \(\mathcal{S}(i)\) is a sampled subset of neighbours (e.g., randomly sample 10 out of 500 neighbours). The CONCAT operation explicitly separates the node's own features from the aggregated neighbour features, letting the network learn different transformations for "self" and "neighbourhood."
GraphSAGE supports multiple aggregation functions:
- Mean: \(\text{AGG} = \frac{1}{|\mathcal{S}|} \sum_{j \in \mathcal{S}} \mathbf{h}_j\) (simple, effective)
- LSTM: feed the sampled neighbours through an LSTM in a random order (the LSTM itself is order-sensitive, so this aggregator is not strictly permutation-invariant)
- Pool: \(\text{AGG} = \max(\{\sigma(W_{\text{pool}} \mathbf{h}_j + \mathbf{b}) : j \in \mathcal{S}\})\) (a nonlinear transform of each neighbour, then an element-wise max)
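A minimal sketch of a GraphSAGE layer with the mean aggregator, assuming neighbour lists stored as one array per node and uniform sampling with replacement (so nodes with fewer than `num_samples` neighbours still yield a fixed-size sample). `sample_neighbours` and `sage_mean_layer` are hypothetical names; the final L2 normalisation follows the original paper's description, and features are row vectors, so the weight multiplies on the right.

```python
import jax
import jax.numpy as jnp

def sample_neighbours(key, neighbour_lists, num_samples):
    """Uniformly sample num_samples neighbours per node, with replacement.

    neighbour_lists: list of 1-D int arrays, one per node (assumed non-empty).
    Returns an (n, num_samples) int array of sampled neighbour ids.
    """
    keys = jax.random.split(key, len(neighbour_lists))
    return jnp.stack([
        jax.random.choice(k, nbrs, shape=(num_samples,), replace=True)
        for k, nbrs in zip(keys, neighbour_lists)
    ])

def sage_mean_layer(h, sampled, W):
    """h_i' = ReLU(CONCAT(h_i, mean of sampled h_j) @ W), then L2-normalised."""
    neigh_mean = h[sampled].mean(axis=1)          # (n, d)
    z = jnp.concatenate([h, neigh_mean], axis=1)  # (n, 2d): self | neighbourhood
    out = jax.nn.relu(z @ W)                      # W: (2d, d_out)
    norm = jnp.linalg.norm(out, axis=1, keepdims=True)
    return out / jnp.maximum(norm, 1e-12)

# Graph with edges 0-1, 0-2, 2-3
neighbour_lists = [jnp.array([1, 2]), jnp.array([0]), jnp.array([0, 3]), jnp.array([2])]
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
h = jax.random.normal(k1, (4, 8))
W = jax.random.normal(k2, (16, 5)) * 0.1
sampled = sample_neighbours(k3, neighbour_lists, num_samples=3)
h_next = sage_mean_layer(h, sampled, W)
print(sampled.shape, h_next.shape)  # (4, 3) (4, 5)
```

Keeping the self features and the neighbourhood mean in separate halves of `z` is what lets `W` learn different transformations for "self" and "neighbourhood".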
The sampling strategy makes GraphSAGE scalable to very large graphs. Training uses mini-batches of nodes: for each target node, sample \(k_1\) neighbours at the first hop, then \(k_2\) neighbours of each of those at the second hop. With \(k_1 = k_2 = 10\) and 2 layers, each node's computation tree has at most \(1 + 10 + 10 \times 10 = 111\) nodes (100 of them at the second hop), regardless of the graph size.
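The arithmetic above can be checked with a small sampler. This sketch assumes per-node neighbour arrays and a hypothetical `sample_computation_tree` helper; the test graph is a ring lattice in which every node has 20 neighbours, so the fanouts, not the degree, bound the size of the tree.

```python
import jax
import jax.numpy as jnp

def sample_computation_tree(key, neighbour_lists, target, fanouts):
    """Node ids at each hop of a sampled GraphSAGE-style computation tree.

    Hop 0 is the target; hop h holds fanouts[h-1] samples for every node of
    hop h-1. Repeats are kept: each tree position is a separate computation.
    """
    hops = [jnp.array([target])]
    for fanout in fanouts:
        key, sub = jax.random.split(key)
        keys = jax.random.split(sub, hops[-1].shape[0])
        children = [
            jax.random.choice(k, neighbour_lists[int(v)], shape=(fanout,), replace=True)
            for k, v in zip(keys, hops[-1])
        ]
        hops.append(jnp.concatenate(children))
    return hops

# Ring lattice: node i is linked to the 10 nodes on either side (20 neighbours)
n = 200
neighbour_lists = [jnp.array([(i + o) % n for o in range(-10, 11) if o != 0])
                   for i in range(n)]
hops = sample_computation_tree(jax.random.PRNGKey(0), neighbour_lists,
                               target=0, fanouts=(10, 10))
sizes = [int(h.shape[0]) for h in hops]
print(sizes, sum(sizes))  # [1, 10, 100] 111
```

Repeated node ids are kept on purpose: the same neighbour sampled twice appears as two leaves of the tree.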
Graph Isomorphism Network (GIN)
Different GNN architectures have different expressive power: their ability to distinguish structurally different graphs. GCN and GraphSAGE, despite being effective in practice, are provably limited in what graph structures they can distinguish.
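The theoretical tool for measuring GNN expressiveness is the Weisfeiler-Lehman (WL) test, a classical algorithm for testing graph isomorphism (whether two graphs are structurally identical). The WL test iteratively refines node labels by hashing each node's label together with the multiset of its neighbours' labels, then compares the resulting label histograms. It is a one-sided test: different histograms prove that two graphs are not isomorphic, but some non-isomorphic graphs (for example, a 6-cycle and two disjoint triangles) receive identical histograms.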
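A plain-Python sketch of 1-WL colour refinement (hashing is awkward to express with arrays, so no JAX here). It assumes graphs given as adjacency dicts and refines all graphs jointly, so that equal signatures receive equal colours across graphs; `wl_histograms` is a hypothetical helper. It also shows a known blind spot: a 6-cycle and two disjoint triangles are both 2-regular, so every node keeps the same colour and 1-WL cannot tell them apart.

```python
from collections import Counter

def wl_histograms(graphs, num_iters=3):
    """Run 1-WL colour refinement jointly on several graphs.

    graphs: list of dicts mapping node -> list of neighbours.
    Returns one Counter of final colours per graph.
    """
    colourings = [{v: 0 for v in g} for g in graphs]  # everyone starts with colour 0
    for _ in range(num_iters):
        table = {}  # signature -> new colour, shared by all graphs
        new_colourings = []
        for g, colours in zip(graphs, colourings):
            new = {}
            for v in g:
                # Signature: own colour plus the sorted multiset of neighbour colours
                sig = (colours[v], tuple(sorted(colours[u] for u in g[v])))
                new[v] = table.setdefault(sig, len(table))
            new_colourings.append(new)
        colourings = new_colourings
    return [Counter(c.values()) for c in colourings]

hexagon = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]}
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
star = {0: [1, 2, 3], 1: [0], 2: [0], 3: [0]}

h_hex, h_tri = wl_histograms([hexagon, two_triangles])
print(h_hex == h_tri)    # True: 1-WL cannot tell these apart
h_path, h_star = wl_histograms([path, star])
print(h_path == h_star)  # False: the degree patterns already differ
```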
GIN (Xu et al., 2019) is designed to be as expressive as the WL test, making it the most powerful message-passing GNN (within the theoretical limits of message passing). The key insight: the aggregation function must be injective on multisets (different multisets of neighbour features must produce different aggregated values).
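Sum aggregation can be made injective on multisets; mean and max cannot. On raw scalars sum is not injective either (\(\{1, 1, 2\}\) and \(\{1, 3\}\) both sum to 4), but Xu et al. show that for countable input features there exists an encoding \(f\) such that \(\sum_{x \in X} f(x)\) is injective over multisets \(X\) of bounded size, and in GIN the MLPs can learn such an encoding. Mean and max fail whatever the encoding: mean sees only the proportions of each value, so it cannot distinguish \(\{1, 2\}\) from \(\{1, 1, 2, 2\}\), and max sees only which distinct values occur, so it cannot distinguish \(\{1, 2\}\) from \(\{1, 1, 2\}\).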
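A small check of how sum, mean and max treat multisets, assuming integer node features encoded one-hot (an injective encoding); `encode` and `aggregate` are hypothetical helpers. Sum separates both pairs below; mean cannot separate \(\{1, 2\}\) from \(\{1, 1, 2, 2\}\) (same proportions); max cannot separate either pair, because it ignores how often each value occurs.

```python
import jax
import jax.numpy as jnp

def encode(multiset, num_values=3):
    """One-hot encoding of integer features: an injective f."""
    return jax.nn.one_hot(jnp.array(multiset), num_values)

def aggregate(multiset):
    x = encode(multiset)
    return {"sum": x.sum(axis=0), "mean": x.mean(axis=0), "max": x.max(axis=0)}

pairs = {
    "{1,2} vs {1,1,2,2}": ([1, 2], [1, 1, 2, 2]),
    "{1,2} vs {1,1,2}": ([1, 2], [1, 1, 2]),
}
for name, (a, b) in pairs.items():
    agg_a, agg_b = aggregate(a), aggregate(b)
    same = [k for k in agg_a if bool(jnp.array_equal(agg_a[k], agg_b[k]))]
    print(name, "identical under:", same)
# {1,2} vs {1,1,2,2} identical under: ['mean', 'max']
# {1,2} vs {1,1,2} identical under: ['max']
```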
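The GIN update is:
\[ \mathbf{h}_i^{(l+1)} = \text{MLP}^{(l)}\Big(\big(1 + \epsilon^{(l)}\big)\,\mathbf{h}_i^{(l)} + \sum_{j \in \mathcal{N}(i)} \mathbf{h}_j^{(l)}\Big) \]
where \(\epsilon\) is a learnable scalar (or fixed to 0) and the MLP provides the nonlinear mapping, which (by the universal approximation theorem) can approximate the injective functions the expressiveness argument requires. The sum aggregation preserves the multiset structure, and the MLP can learn to map different aggregated values to different outputs.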
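A minimal GIN layer, assuming a dense adjacency matrix without self-loops and a two-layer MLP; `init_mlp`, `mlp` and `gin_layer` are hypothetical helpers, and \(\epsilon\) is passed in as a plain number that a training loop could treat as a parameter.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    """Random (W, b) pairs for an MLP with the given layer widths."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def mlp(params, x):
    for W, b in params[:-1]:
        x = jax.nn.relu(x @ W + b)
    W, b = params[-1]
    return x @ W + b  # no activation on the last layer

def gin_layer(params, eps, h, A):
    """h_i' = MLP((1 + eps) * h_i + sum of neighbour features); A has no self-loops."""
    return mlp(params, (1.0 + eps) * h + A @ h)

# Star graph: node 0 is the centre, nodes 1 and 2 are leaves
A = jnp.array([[0, 1, 1],
               [1, 0, 0],
               [1, 0, 0]], dtype=jnp.float32)
h = jnp.ones((3, 4))
params = init_mlp(jax.random.PRNGKey(0), [4, 16, 8])
out = gin_layer(params, 0.0, h, A)
print(out.shape)                     # (3, 8)
print(jnp.allclose(out[1], out[2]))  # True: identical neighbourhoods
```

The two leaves have identical neighbourhoods, so any message-passing GNN, GIN included, must give them the same embedding; this is the WL limit in action.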
Over-Smoothing
A major challenge in GNNs is over-smoothing: as the number of layers increases, all node representations converge to the same value, losing the ability to distinguish different nodes.
The mechanism is intuitive. Each message-passing layer averages a node's features with its neighbours'. After many rounds of averaging, every node has seen (and blended with) every other node in its connected component. The features become a uniform average, the graph equivalent of blurring an image too many times until it becomes a solid colour.
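Formally, for a connected graph (and ignoring the weights and nonlinearities), repeated application of the normalised adjacency \(\hat{A}\) converges to a rank-1 matrix: \(\hat{A}^k \to \mathbf{u}\mathbf{u}^T\), where \(\mathbf{u}\) is the unit-norm eigenvector for the largest eigenvalue, 1, with entries \(u_i \propto \sqrt{\tilde{d}_i}\) (\(\tilde{d}_i\) is the degree of node \(i\) in \(\tilde{A}\)); the self-loops rule out an oscillating eigenvalue of \(-1\). Every node's feature vector therefore approaches the same vector scaled by \(\sqrt{\tilde{d}_i}\), so nodes end up differing only by degree. (With random-walk normalisation \(\tilde{D}^{-1}\tilde{A}\), every row instead converges to the stationary distribution of a random walk on the graph.) This is the same convergence as power iteration towards the dominant eigenvector (chapter 2).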
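A numerical check, assuming the symmetric normalisation used in this file and no weights or nonlinearities: \(\hat{A}^k\) approaches the rank-1 matrix \(\mathbf{u}\mathbf{u}^T\), where \(\mathbf{u}\) is the unit eigenvector for eigenvalue 1, with \(u_i \propto \sqrt{\tilde{d}_i}\) and \(\tilde{d}_i\) the degree of node \(i\) including its self-loop.

```python
import jax.numpy as jnp

# Two triangles {0, 1, 2} and {3, 4, 5} joined by the edge 2-3
A = jnp.array([[0, 1, 1, 0, 0, 0],
               [1, 0, 1, 0, 0, 0],
               [1, 1, 0, 1, 0, 0],
               [0, 0, 1, 0, 1, 1],
               [0, 0, 0, 1, 0, 1],
               [0, 0, 0, 1, 1, 0]], dtype=jnp.float32)
A_tilde = A + jnp.eye(6)
deg = A_tilde.sum(axis=1)  # degrees with self-loops
A_hat = A_tilde / jnp.sqrt(deg[:, None] * deg[None, :])

# Predicted limit u u^T, where u is the unit eigenvector for eigenvalue 1
u = jnp.sqrt(deg) / jnp.linalg.norm(jnp.sqrt(deg))
limit = jnp.outer(u, u)

errors = {}
for k in (1, 5, 20, 100):
    errors[k] = float(jnp.abs(jnp.linalg.matrix_power(A_hat, k) - limit).max())
    print(f"k={k:3d}  max |A_hat^k - u u^T| = {errors[k]:.2e}")
```

The gap shrinks geometrically, at the rate of the second-largest eigenvalue magnitude of \(\hat{A}\); weakly connected clusters (like these two triangles) make that eigenvalue close to 1 and the collapse slower.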
Over-smoothing limits GNNs to shallow depths (typically 2-4 layers), unlike CNNs and transformers that benefit from dozens or hundreds of layers. This means each node can only see a limited neighbourhood, which is problematic for tasks requiring long-range information.
Mitigations include:
- Residual connections (from ResNets, chapter 8): \(\mathbf{h}_i^{(l+1)} = \tilde{\mathbf{h}}_i^{(l+1)} + \mathbf{h}_i^{(l)}\), where \(\tilde{\mathbf{h}}_i^{(l+1)}\) is the output of the message-passing layer, preserving information from earlier layers.
- Jumping knowledge: concatenate or attention-pool representations from all layers, not just the last.
- DropEdge: randomly remove edges during training, slowing the information spread.
- Graph Transformers (file 4): bypass the local message-passing bottleneck with global attention.
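Two of these mitigations are easy to sketch, assuming dense adjacency matrices and hypothetical helper names: `drop_edge` removes each undirected edge with probability `p` (mirroring the mask so the graph stays symmetric; a fresh mask would be drawn at every training step and none at evaluation), and `gcn_jk_forward` implements jumping knowledge by concatenating every layer's output.

```python
import jax
import jax.numpy as jnp

def normalise(A):
    """GCN normalisation with self-loops: D^-1/2 (A + I) D^-1/2."""
    A_tilde = A + jnp.eye(A.shape[0])
    deg = A_tilde.sum(axis=1)
    return A_tilde / jnp.sqrt(deg[:, None] * deg[None, :])

def drop_edge(key, A, p):
    """DropEdge: remove each undirected edge with probability p (training only)."""
    keep = jax.random.bernoulli(key, 1.0 - p, A.shape).astype(A.dtype)
    keep = jnp.triu(keep, k=1)  # one decision per undirected edge (upper triangle)
    keep = keep + keep.T        # mirror it so the adjacency stays symmetric
    return A * keep

def gcn_jk_forward(Ws, H, A_norm):
    """GCN stack with jumping knowledge: concatenate every layer's output."""
    outputs = []
    for W in Ws:
        H = jax.nn.relu(A_norm @ H @ W)
        outputs.append(H)
    return jnp.concatenate(outputs, axis=1)

# Two triangles {0, 1, 2} and {3, 4, 5} joined by the edge 2-3
A = jnp.array([[0, 1, 1, 0, 0, 0],
               [1, 0, 1, 0, 0, 0],
               [1, 1, 0, 1, 0, 0],
               [0, 0, 1, 0, 1, 1],
               [0, 0, 0, 1, 0, 1],
               [0, 0, 0, 1, 1, 0]], dtype=jnp.float32)
keys = jax.random.split(jax.random.PRNGKey(0), 5)
A_train = drop_edge(keys[0], A, p=0.3)
H0 = jax.random.normal(keys[1], (6, 2))
Ws = [jax.random.normal(keys[2], (2, 4)),
      jax.random.normal(keys[3], (4, 4)),
      jax.random.normal(keys[4], (4, 4))]
out = gcn_jk_forward(Ws, H0, normalise(A_train))
print(out.shape)  # (6, 12): three layers of width 4, concatenated
```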
Graph Pooling
For graph-level tasks (predicting a property of the entire graph, like a molecule's toxicity), we need to collapse all node representations into a single graph-level vector. This is graph pooling, the graph analogue of global average pooling in CNNs (chapter 8).
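The simplest approach is readout: apply a permutation-invariant function to the set of all node features after the final layer \(L\):
\[ \mathbf{h}_G = \text{READOUT}\big(\{\mathbf{h}_1^{(L)}, \dots, \mathbf{h}_n^{(L)}\}\big), \quad \text{e.g.} \quad \mathbf{h}_G = \sum_{i=1}^{n} \mathbf{h}_i^{(L)} \quad \text{or} \quad \mathbf{h}_G = \frac{1}{n} \sum_{i=1}^{n} \mathbf{h}_i^{(L)} \]
This is the DeepSets aggregation from file 1, applied after the final GNN layer. Sum preserves size information (a graph with 100 nodes will typically have a much larger sum than a graph with 10 nodes of similar features), while mean normalises for size.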
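In practice, several graphs are batched into one large disconnected graph, and readout becomes a per-graph reduction. A minimal sketch, assuming a `graph_id` vector that records which graph each node belongs to (a common convention; the names here are illustrative); `jax.ops.segment_sum` and `jax.ops.segment_max` perform the sum and max readouts.

```python
import jax
import jax.numpy as jnp

# Final-layer node embeddings for a mini-batch of two graphs, stacked together:
# graph 0 has 3 nodes, graph 1 has 2 nodes
H = jnp.array([[1.0, 0.0],
               [2.0, 1.0],
               [0.0, 3.0],
               [4.0, 4.0],
               [2.0, 0.0]])
graph_id = jnp.array([0, 0, 0, 1, 1])  # which graph each node belongs to
num_graphs = 2

sum_readout = jax.ops.segment_sum(H, graph_id, num_segments=num_graphs)
counts = jax.ops.segment_sum(jnp.ones(H.shape[0]), graph_id, num_segments=num_graphs)
mean_readout = sum_readout / counts[:, None]
max_readout = jax.ops.segment_max(H, graph_id, num_segments=num_graphs)
print(sum_readout)   # rows: [3, 4] and [6, 4]
print(mean_readout)  # rows: [1, 1.333] and [3, 2]
print(max_readout)   # rows: [2, 3] and [4, 4]
```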
Hierarchical pooling progressively coarsens the graph, mirroring how CNNs progressively downsample images. At each level, groups of nodes are merged into "supernodes":
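DiffPool (Differentiable Pooling) learns a soft assignment matrix \(S^{(l)} \in \mathbb{R}^{n_l \times n_{l+1}}\) that assigns each node to a cluster:
\[ S^{(l)} = \text{softmax}\big(\text{GNN}_{\text{pool}}(A^{(l)}, H^{(l)})\big), \quad H^{(l+1)} = S^{(l)T} Z^{(l)}, \quad A^{(l+1)} = S^{(l)T} A^{(l)} S^{(l)} \]
where the softmax is taken row-wise, \(Z^{(l)} = \text{GNN}_{\text{embed}}(A^{(l)}, H^{(l)})\) are node embeddings from a second GNN, \(H^{(l+1)}\) are the supernode features, and \(A^{(l+1)}\) is the coarsened (weighted) adjacency.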
The assignment matrix is predicted by a separate GNN, making the clustering end-to-end differentiable. This creates a hierarchy: the original graph → a coarsened graph with fewer nodes → an even coarser graph → a single node (the graph representation).
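A minimal sketch of one DiffPool coarsening step, assuming dense matrices and a single GCN propagation for each of the two GNNs (the original method stacks more layers and adds auxiliary losses, omitted here); `gcn_propagate` and `diffpool_step` are hypothetical names. It computes a row-wise softmax assignment \(S\), pooled features \(S^T Z\), and the coarsened adjacency \(S^T A S\). Because each row of \(S\) sums to 1, the total edge weight of the coarsened adjacency equals that of the original.

```python
import jax
import jax.numpy as jnp

def gcn_propagate(A, H, W):
    """GCN propagation A_hat @ H @ W, without activation."""
    A_tilde = A + jnp.eye(A.shape[0])
    deg = A_tilde.sum(axis=1)
    A_hat = A_tilde / jnp.sqrt(deg[:, None] * deg[None, :])
    return A_hat @ H @ W

def diffpool_step(A, H, W_embed, W_assign):
    """Coarsen n_l nodes into n_{l+1} soft clusters."""
    Z = jax.nn.relu(gcn_propagate(A, H, W_embed))               # (n_l, d) embeddings
    S = jax.nn.softmax(gcn_propagate(A, H, W_assign), axis=1)  # (n_l, n_{l+1}), rows sum to 1
    H_next = S.T @ Z                                           # supernode features
    A_next = S.T @ A @ S                                       # soft, weighted adjacency
    return H_next, A_next, S

# Two triangles {0, 1, 2} and {3, 4, 5} joined by the edge 2-3
A = jnp.array([[0, 1, 1, 0, 0, 0],
               [1, 0, 1, 0, 0, 0],
               [1, 1, 0, 1, 0, 0],
               [0, 0, 1, 0, 1, 1],
               [0, 0, 0, 1, 0, 1],
               [0, 0, 0, 1, 1, 0]], dtype=jnp.float32)
keys = jax.random.split(jax.random.PRNGKey(0), 3)
H = jax.random.normal(keys[0], (6, 3))
W_embed = jax.random.normal(keys[1], (3, 4))
W_assign = jax.random.normal(keys[2], (3, 2))  # 6 nodes -> 2 clusters
H_next, A_next, S = diffpool_step(A, H, W_embed, W_assign)
print(H_next.shape, A_next.shape)  # (2, 4) (2, 2)
```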
TopKPool takes a simpler approach: learn a scalar score for each node, keep the top-\(k\) scoring nodes, and drop the rest. This is a hard selection (not soft assignment) and is computationally cheaper than DiffPool.
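A minimal TopK pooling sketch, assuming that the score is the projection of each node's features onto a learnable vector `p`, and that kept features are gated by `tanh` of their score so that gradients reach `p` despite the hard selection (this follows common descriptions of the method; some variants use a sigmoid gate). `topk_pool` is a hypothetical name; `jax.lax.top_k` returns the k largest scores and their indices.

```python
import jax
import jax.numpy as jnp

def topk_pool(A, H, p, k):
    """Keep the k nodes whose features project highest onto p."""
    scores = H @ p / jnp.linalg.norm(p)                # one scalar score per node
    top_scores, idx = jax.lax.top_k(scores, k)         # hard selection, descending
    H_pooled = H[idx] * jnp.tanh(top_scores)[:, None]  # gate so that p gets gradients
    A_pooled = A[idx][:, idx]                          # subgraph induced by kept nodes
    return A_pooled, H_pooled, idx

# Two triangles {0, 1, 2} and {3, 4, 5} joined by the edge 2-3
A = jnp.array([[0, 1, 1, 0, 0, 0],
               [1, 0, 1, 0, 0, 0],
               [1, 1, 0, 1, 0, 0],
               [0, 0, 1, 0, 1, 1],
               [0, 0, 0, 1, 0, 1],
               [0, 0, 0, 1, 1, 0]], dtype=jnp.float32)
H = jnp.array([[1.0, 0.0], [0.0, 1.0], [2.0, 0.0],
               [0.0, 2.0], [3.0, 0.0], [0.0, 3.0]])
p = jnp.array([1.0, 0.0])  # learnable in practice; fixed here
A_pooled, H_pooled, idx = topk_pool(A, H, p, k=3)
print(idx)       # [4 2 0]
print(A_pooled)  # only the edge between old nodes 2 and 0 survives
```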
Heterogeneous Graphs
All GNNs so far assume a homogeneous graph: one type of node, one type of edge. But most real-world graphs are heterogeneous: multiple node types and multiple edge types. A knowledge graph has person nodes, organisation nodes, and location nodes, connected by "works at," "born in," and "located in" edges. A recommender system has user nodes and item nodes connected by "purchased," "viewed," and "rated" edges.
A heterogeneous graph has a schema (also called a metagraph) that defines the allowed node types and edge types. Each edge type connects a specific source type to a specific target type. For example, "works at" connects Person → Organisation.
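Relational GCN (R-GCN) (Schlichtkrull et al., 2018) handles heterogeneous edges by using a separate weight matrix for each edge type:
\[ \mathbf{h}_i^{(l+1)} = \sigma\Big(W_0 \mathbf{h}_i^{(l)} + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)} \frac{1}{|\mathcal{N}_r(i)|} W_r \mathbf{h}_j^{(l)}\Big) \]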
where \(\mathcal{R}\) is the set of edge types, \(\mathcal{N}_r(i)\) is the set of neighbours connected to node \(i\) via relation \(r\), and \(W_r\) is the weight matrix specific to relation \(r\). The self-connection \(W_0\) handles the node's own features separately.
The problem: with many relation types, the number of parameters explodes (one \(d \times d\) matrix per relation). R-GCN mitigates this with basis decomposition: \(W_r = \sum_{b=1}^{B} a_{rb} V_b\), where \(V_b\) are shared basis matrices and \(a_{rb}\) are scalar coefficients per relation. This is analogous to low-rank factorisation (chapter 2): the relation-specific matrices live in a low-dimensional subspace.
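A minimal R-GCN layer with basis decomposition, assuming one dense adjacency matrix per relation, mean normalisation over each relation's neighbours, and row-vector features (so weights multiply on the right). `rgcn_layer` and the tensor layouts are illustrative choices. With \(R\) relations and \(B\) bases, the relation weights cost \(B \cdot d \cdot d_{\text{out}} + R \cdot B\) parameters instead of \(R \cdot d \cdot d_{\text{out}}\).

```python
import jax
import jax.numpy as jnp

def rgcn_layer(H, A_rel, W0, V, a):
    """R-GCN layer with basis decomposition.

    H:     (n, d) node features
    A_rel: (R, n, n) adjacency per relation, A_rel[r, i, j] = 1 if j -r-> i
    W0:    (d, d_out) self-connection weight
    V:     (B, d, d_out) shared basis matrices
    a:     (R, B) per-relation coefficients, so W_r = sum_b a[r, b] * V[b]
    """
    W_rel = jnp.einsum("rb,bio->rio", a, V)                # (R, d, d_out)
    deg = A_rel.sum(axis=2, keepdims=True)                 # |N_r(i)|, shape (R, n, 1)
    A_norm = A_rel / jnp.maximum(deg, 1.0)                 # mean over each relation
    # Sum over relations r, neighbours j and input features d
    messages = jnp.einsum("rij,jd,rdo->io", A_norm, H, W_rel)
    return jax.nn.relu(H @ W0 + messages)

n, d, d_out, R, B = 4, 3, 2, 3, 2
k = jax.random.split(jax.random.PRNGKey(0), 4)
A_rel = (jax.random.uniform(k[0], (R, n, n)) < 0.5).astype(jnp.float32)
H = jax.random.normal(k[1], (n, d))
W0 = jax.random.normal(k[2], (d, d_out))
V = jax.random.normal(k[3], (B, d, d_out))
a = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])  # 3 relations share 2 bases
out = rgcn_layer(H, A_rel, W0, V, a)
print(out.shape)  # (4, 2)
```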
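Heterogeneous Graph Transformer (HGT) (Hu et al., 2020) applies the attention mechanism to heterogeneous graphs. The key insight is that attention should depend on both the node types and the edge type connecting them. HGT uses type-specific projection matrices for queries, keys, and values:
\[ \mathbf{q}_i = W^Q_{\tau(i)} \mathbf{h}_i, \quad \mathbf{k}_j = W^K_{\tau(j)} \mathbf{h}_j, \quad \mathbf{v}_j = W^V_{\tau(j)} \mathbf{h}_j, \quad \alpha_{ij} = \underset{j \in \mathcal{N}(i)}{\text{softmax}}\left(\frac{\mathbf{q}_i^T W^{\text{ATT}}_{\phi(i,j)} \mathbf{k}_j}{\sqrt{d}}\right) \]
where \(\tau(i)\) is the type of node \(i\), \(\phi(i,j)\) is the edge type between them, and \(d\) is the key dimension. This ensures that the model attends differently to different relationship types: a paper attending to its authors uses different attention parameters than when attending to its references.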
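A simplified single-head sketch in the spirit of HGT, assuming an edge-list graph, square projection matrices, and hypothetical names throughout. It keeps the two ideas described above (node-type-specific query/key/value projections, and edge-type-specific attention and message matrices) and omits multi-head attention, HGT's learned relation priors, and its residual update.

```python
import jax
import jax.numpy as jnp

def hetero_attention(h, node_type, senders, receivers, edge_type, params):
    n, d = h.shape
    # Type-specific projections: each node uses the matrices of its own type
    q = jnp.einsum("nd,nde->ne", h, params["Wq"][node_type])
    k = jnp.einsum("nd,nde->ne", h, params["Wk"][node_type])
    v = jnp.einsum("nd,nde->ne", h, params["Wv"][node_type])
    # Edge-type-specific attention score for each edge j -> i
    k_edge = jnp.einsum("md,mde->me", k[senders], params["Watt"][edge_type])
    scores = jnp.sum(q[receivers] * k_edge, axis=1) / jnp.sqrt(d)
    # Softmax over each receiver's incoming edges (max-subtracted for stability)
    max_in = jax.ops.segment_max(scores, receivers, num_segments=n)
    e = jnp.exp(scores - max_in[receivers])
    alpha = e / jax.ops.segment_sum(e, receivers, num_segments=n)[receivers]
    # Edge-type-specific message transform, weighted by attention, summed per receiver
    msg = jnp.einsum("md,mde->me", v[senders], params["Wmsg"][edge_type])
    out = jax.ops.segment_sum(alpha[:, None] * msg, receivers, num_segments=n)
    return out, alpha

d = 4
node_type = jnp.array([0, 0, 1, 1])  # 0 = paper, 1 = author
# Edges into paper 0: from authors 2 and 3 (type 0, "written by")
# and from paper 1 (type 1, "cites": paper 0 cites paper 1)
senders = jnp.array([2, 3, 1])
receivers = jnp.array([0, 0, 0])
edge_type = jnp.array([0, 0, 1])
keys = jax.random.split(jax.random.PRNGKey(0), 6)
params = {
    "Wq": jax.random.normal(keys[0], (2, d, d)),    # one matrix per node type
    "Wk": jax.random.normal(keys[1], (2, d, d)),
    "Wv": jax.random.normal(keys[2], (2, d, d)),
    "Watt": jax.random.normal(keys[3], (2, d, d)),  # one matrix per edge type
    "Wmsg": jax.random.normal(keys[4], (2, d, d)),
}
h = jax.random.normal(keys[5], (4, d))
out, alpha = hetero_attention(h, node_type, senders, receivers, edge_type, params)
print(alpha)  # paper 0's attention over its three incoming edges; sums to 1
```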
Metapath-based methods define meaningful paths through the schema (e.g., Author → Paper → Author for co-authorship) and aggregate information along these paths. HAN (Heterogeneous Attention Network) applies attention at two levels: within each metapath (which neighbours along this path matter?) and across metapaths (which relationship patterns matter?).
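Metapath neighbourhoods can be computed with matrix products. A minimal sketch, assuming a hypothetical authorship matrix \(M\) (authors × papers): the Author → Paper → Author metapath is \(M M^T\), whose entry \((a, b)\) counts the papers that authors \(a\) and \(b\) share. HAN's two attention levels would then run over neighbourhoods like this one; they are not shown here.

```python
import jax.numpy as jnp

# Hypothetical authorship matrix M: M[a, p] = 1 if author a wrote paper p
M = jnp.array([[1, 1, 0],
               [1, 0, 0],
               [0, 1, 1]], dtype=jnp.float32)

# Metapath Author -> Paper -> Author: entry (a, b) counts shared papers
A_apa = M @ M.T
# Co-author neighbourhood: share at least one paper, excluding oneself
coauthor = (A_apa > 0) & ~jnp.eye(3, dtype=bool)
print(A_apa)                  # values: [[2, 1, 1], [1, 1, 0], [1, 0, 2]]
print(coauthor.astype(int))   # values: [[0, 1, 1], [1, 0, 0], [1, 0, 0]]
```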
Link Prediction and Knowledge Graph Completion
Link prediction asks: given the existing edges, which missing edges are likely to exist? This is the core task for knowledge graph completion (predict missing facts), recommendation (predict which items a user will like), and social network analysis (predict future friendships).
Embedding-based methods learn a vector for each entity and a transformation for each relation, then score potential edges by how well the entities and relation fit together:
TransE models relations as translations in embedding space: if \((h, r, t)\) is a valid triple (head entity, relation, tail entity), then \(\mathbf{h} + \mathbf{r} \approx \mathbf{t}\). The scoring function is \(f(h, r, t) = -\|\mathbf{h} + \mathbf{r} - \mathbf{t}\|\). Intuitively, the relation vector "moves" the head entity to the tail entity in embedding space.
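RotatE models relations as rotations in complex space: \(\mathbf{t} = \mathbf{h} \circ \mathbf{r}\), where \(\circ\) is element-wise complex multiplication and \(|r_k| = 1\) for every coordinate \(k\) (unit complex numbers are rotations). It can model symmetry, antisymmetry, inversion, and composition patterns; TransE handles antisymmetry, inversion, and composition but not symmetry, because \(\mathbf{h} + \mathbf{r} \approx \mathbf{t}\) and \(\mathbf{t} + \mathbf{r} \approx \mathbf{h}\) together force \(\mathbf{r} \approx \mathbf{0}\).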
ComplEx uses complex-valued embeddings with a Hermitian dot product, enabling it to model asymmetric relations (if A is the boss of B, B is not the boss of A).
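The three scoring functions fit in a few lines of JAX. This is a sketch assuming single (unbatched) embedding vectors, hypothetical function names, and the L2 norm for TransE and RotatE (L1 is also common). RotatE's relation is parameterised by phases \(\theta\) so that \(|r_k| = 1\) holds by construction, and ComplEx scores \(\text{Re}\big(\sum_k h_k r_k \bar{t}_k\big)\).

```python
import jax.numpy as jnp

def transe_score(h, r, t):
    """TransE: -||h + r - t|| (L2 here)."""
    return -jnp.linalg.norm(h + r - t)

def rotate_score(h, theta, t):
    """RotatE: r = exp(i * theta) has |r_k| = 1, so h * r rotates each coordinate."""
    r = jnp.exp(1j * theta)
    return -jnp.linalg.norm(h * r - t)

def complex_score(h, r, t):
    """ComplEx: Re(sum_k h_k r_k conj(t_k))."""
    return jnp.real(jnp.sum(h * r * jnp.conj(t)))

# TransE: a tail that is exactly head + relation gets the maximum score, 0
print(transe_score(jnp.array([0.0, 0.0]), jnp.array([1.0, 2.0]), jnp.array([1.0, 2.0])))

# RotatE: a tail built by rotating the head also scores 0
h = jnp.array([1.0 + 0.0j, 0.0 + 1.0j])
theta = jnp.array([jnp.pi / 2, jnp.pi])
t = h * jnp.exp(1j * theta)
print(rotate_score(h, theta, t))

# ComplEx: an imaginary relation scores (a, r, b) and (b, r, a) differently
a_emb = jnp.array([1.0 + 1.0j])
b_emb = jnp.array([1.0 + 0.0j])
r_asym = jnp.array([0.0 + 1.0j])
print(complex_score(a_emb, r_asym, b_emb), complex_score(b_emb, r_asym, a_emb))  # -1.0 1.0
```

With a purely real relation the two ComplEx scores coincide, so the imaginary part of \(\mathbf{r}\) is what lets the model express asymmetric relations.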
GNN-based link prediction computes node embeddings with message passing, then scores edges using the endpoint embeddings. This combines the structural reasoning of GNNs with the relational modelling of embedding methods. The GNN encoder captures multi-hop neighbourhood structure that single-embedding methods miss.
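A minimal end-to-end sketch, assuming a one-layer linear GCN encoder, a dot-product decoder, and binary cross-entropy in which observed edges are positives and a fixed list of non-edges serves as negatives (in practice negatives are usually resampled at every step). All function names are hypothetical.

```python
import jax
import jax.numpy as jnp

def gcn_encoder(W, A_hat, X):
    """One linear GCN layer: Z = A_hat @ X @ W (real encoders usually stack more)."""
    return A_hat @ X @ W

def link_logits(Z, pairs):
    """Dot-product decoder: the logit for pair (i, j) is z_i . z_j."""
    return jnp.sum(Z[pairs[:, 0]] * Z[pairs[:, 1]], axis=1)

def link_loss(W, A_hat, X, pos_pairs, neg_pairs):
    """Binary cross-entropy: observed edges labelled 1, negatives labelled 0."""
    Z = gcn_encoder(W, A_hat, X)
    pos = link_logits(Z, pos_pairs)
    neg = link_logits(Z, neg_pairs)
    return -(jax.nn.log_sigmoid(pos).mean() + jax.nn.log_sigmoid(-neg).mean())

# Two triangles {0, 1, 2} and {3, 4, 5} joined by the edge 2-3
pos_pairs = jnp.array([[0, 1], [0, 2], [1, 2], [2, 3], [3, 4], [3, 5], [4, 5]])
neg_pairs = jnp.array([[0, 3], [0, 4], [0, 5], [1, 4], [1, 5], [2, 5]])  # non-edges
A = jnp.zeros((6, 6)).at[pos_pairs[:, 0], pos_pairs[:, 1]].set(1.0)
A = A + A.T
A_tilde = A + jnp.eye(6)
deg = A_tilde.sum(axis=1)
A_hat = A_tilde / jnp.sqrt(deg[:, None] * deg[None, :])

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(k1, (6, 4))
W = 0.1 * jax.random.normal(k2, (4, 8))

loss_and_grad = jax.jit(jax.value_and_grad(link_loss))
initial_loss, _ = loss_and_grad(W, A_hat, X, pos_pairs, neg_pairs)
for _ in range(100):
    loss, g = loss_and_grad(W, A_hat, X, pos_pairs, neg_pairs)
    W = W - 0.01 * g  # plain gradient descent on the encoder weights
final_loss = link_loss(W, A_hat, X, pos_pairs, neg_pairs)
print(float(initial_loss), float(final_loss))  # the loss goes down
```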
Task Types
GNNs solve three categories of tasks:
Node-level tasks: predict a property for each node. Examples: classifying users in a social network (bot or human), predicting the function of each protein in an interaction network, semi-supervised node classification (label a few nodes, predict the rest). The output is the node embedding \(\mathbf{h}_i^{(L)}\) passed through a classifier.
Edge-level tasks: predict a property for each edge or predict whether an edge exists. Examples: link prediction (will these two users become friends?), knowledge graph completion (does this relationship hold between these entities?), drug-drug interaction prediction. The output typically uses the embeddings of both endpoint nodes: \(\hat{y}_{ij} = f(\mathbf{h}_i, \mathbf{h}_j)\), where \(f\) is a dot product, concatenation + MLP, or other combination.
Graph-level tasks: predict a property for the entire graph. Examples: molecular property prediction (is this molecule toxic?), graph classification (is this social network a bot network?), graph generation (design a molecule with desired properties). The output uses graph pooling to produce \(\mathbf{h}_G\), which is then classified or regressed.
Coding Tasks (use Colab or a notebook)
Implement a single GCN layer from scratch using the normalised adjacency matrix. Apply it to a small graph and observe how node features are smoothed.
```python
import jax
import jax.numpy as jnp

# Graph: 5 nodes, a chain 0-1-2-3 with a branch 2-4
A = jnp.array([[0, 1, 0, 0, 0],
               [1, 0, 1, 0, 0],
               [0, 1, 0, 1, 1],
               [0, 0, 1, 0, 0],
               [0, 0, 1, 0, 0]], dtype=float)

# Add self-loops, then normalise symmetrically: D^-1/2 (A + I) D^-1/2
A_hat = A + jnp.eye(5)
D_inv_sqrt = jnp.diag(1.0 / jnp.sqrt(A_hat.sum(axis=1)))
A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt

# Node features: one-hot identity
H = jnp.eye(5)

# Weight matrix (random initialisation)
rng = jax.random.PRNGKey(0)
W = jax.random.normal(rng, (5, 3)) * 0.5

# GCN layer: H' = ReLU(A_norm @ H @ W)
H_new = jax.nn.relu(A_norm @ H @ W)

print("Original features (one-hot):")
print(H)
print("\nAfter GCN layer:")
print(jnp.round(H_new, 3))
print("\nNotice: connected nodes now have similar representations")
```
Implement message passing with sum aggregation (GIN-style) and compare with mean aggregation (GCN-style). Show that sum can distinguish multisets that mean cannot.
```python
import jax.numpy as jnp

# Two different neighbourhood multisets that have the same mean
# Node A: four neighbours, all with feature 1
# Node B: two neighbours, both with feature 1
neighbours_A = jnp.array([[1.0], [1.0], [1.0], [1.0]])
neighbours_B = jnp.array([[1.0], [1.0]])

# Mean aggregation (GCN-style): identical, because only the proportions matter
mean_A = neighbours_A.mean(axis=0)
mean_B = neighbours_B.mean(axis=0)
print(f"Mean A: {mean_A}, Mean B: {mean_B}, Same: {jnp.allclose(mean_A, mean_B)}")

# Sum aggregation (GIN-style): different, because the sum also counts neighbours
sum_A = neighbours_A.sum(axis=0)
sum_B = neighbours_B.sum(axis=0)
print(f"Sum A: {sum_A}, Sum B: {sum_B}, Same: {jnp.allclose(sum_A, sum_B)}")

print("\nSum distinguishes these multisets; mean does not!")
```
Demonstrate over-smoothing. Apply the normalised adjacency repeatedly and watch node features converge.
```python
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Fixed 6-node graph: triangles {0, 1, 2} and {3, 4, 5} joined by the edge 2-3
A = jnp.array([[0, 1, 1, 0, 0, 0],
               [1, 0, 1, 0, 0, 0],
               [1, 1, 0, 1, 0, 0],
               [0, 0, 1, 0, 1, 1],
               [0, 0, 0, 1, 0, 1],
               [0, 0, 0, 1, 1, 0]], dtype=float)
A_hat = A + jnp.eye(6)
D_inv_sqrt = jnp.diag(1.0 / jnp.sqrt(A_hat.sum(axis=1)))
A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt

# Initial features: distinct per node
# (for these particular features each column is orthogonal to the dominant
# eigenvector of A_norm, so the rank-1 limit is the zero vector)
H = jnp.array([[1, 0], [0, 1], [1, 1], [-1, 0], [0, -1], [-1, -1]], dtype=float)

# Spread = std across nodes, averaged over feature dimensions
spreads = [float(jnp.std(H, axis=0).mean())]  # round 0: before any propagation
for k in range(20):
    H = A_norm @ H                             # one propagation round, no weights
    spreads.append(float(jnp.std(H, axis=0).mean()))

plt.plot(spreads, "o-")
plt.xlabel("Number of message-passing rounds")
plt.ylabel("Feature spread (std across nodes)")
plt.title("Over-Smoothing: Features Converge with Depth")
plt.show()
```