Graph Attention Networks¶
Graph attention networks replace uniform neighbour aggregation with learned, data-dependent weighting. This file covers GAT, multi-head graph attention, GATv2, Graph Transformers, positional and structural encodings, scalability, and temporal and dynamic graphs.
In the GCN (file 3), each node aggregates its neighbours' features using fixed weights determined by the graph structure (the normalised adjacency). A node with three neighbours gives each neighbour roughly equal weight (\(\approx 1/3\)). But not all neighbours are equally important: a message from a close collaborator should matter more than one from a distant acquaintance.
Graph Attention Networks solve this by learning which neighbours to attend to, using the same attention mechanism that powers transformers (chapter 7). Instead of fixed, structure-based weights, each node computes dynamic, content-based attention scores over its neighbours.
GAT: Graph Attention Network¶
GAT (Veličković et al., 2018) computes attention coefficients between each node and its neighbours. For node \(i\) and neighbour \(j\):

\[ e_{ij} = \text{LeakyReLU}\left(\mathbf{a}^T \left[W\mathbf{h}_i \,\|\, W\mathbf{h}_j\right]\right) \]
where \(W \in \mathbb{R}^{d' \times d}\) is a shared linear transformation, \(\|\) denotes concatenation, and \(\mathbf{a} \in \mathbb{R}^{2d'}\) is a learnable attention vector. The score \(e_{ij}\) measures how important node \(j\)'s features are to node \(i\).
The raw scores are normalised across all neighbours using softmax:

\[ \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})} \]

This ensures attention weights sum to 1 across each node's neighbourhood, just like transformer attention (chapter 7). The node's updated features are:

\[ \mathbf{h}_i' = \sigma\Big(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W \mathbf{h}_j\Big) \]
The crucial difference from GCN: the weights \(\alpha_{ij}\) are learned from the data, not fixed by the graph structure. A node can learn to focus on the most informative neighbours while ignoring noisy or irrelevant ones.
Note that attention is computed only over edges (node \(i\) attends only to its neighbours \(\mathcal{N}(i)\)), not over all node pairs. This keeps computation proportional to the number of edges, not the square of the number of nodes.
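To make the edge-restriction concrete, here is a minimal sketch of GAT attention with a masked softmax: non-edges are set to \(-\infty\) before normalisation, so they receive exactly zero weight. The adjacency matrix and all parameter shapes below are illustrative, not from the original.

```python
import jax
import jax.numpy as jnp

# Illustrative 4-node graph (dense adjacency, small enough to inspect)
A = jnp.array([[0, 1, 1, 0],
               [1, 0, 0, 1],
               [1, 0, 0, 1],
               [0, 1, 1, 0]], dtype=jnp.float32)

rng = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(rng, 3)
H = jax.random.normal(k1, (4, 8))          # node features
W = jax.random.normal(k2, (8, 8)) * 0.1    # shared linear transform
a = jax.random.normal(k3, (16,)) * 0.1     # attention vector

Wh = H @ W
# Pairwise scores e_ij = LeakyReLU(a^T [Wh_i || Wh_j]) for every (i, j) pair
pair = jnp.concatenate(
    [jnp.repeat(Wh[:, None, :], 4, axis=1),    # Wh_i broadcast over j
     jnp.repeat(Wh[None, :, :], 4, axis=0)],   # Wh_j broadcast over i
    axis=-1)
e = jax.nn.leaky_relu(pair @ a, negative_slope=0.2)  # (4, 4)

# Mask non-edges with -inf so softmax assigns them zero weight
e_masked = jnp.where(A > 0, e, -jnp.inf)
alpha = jax.nn.softmax(e_masked, axis=1)

print(alpha)               # zero wherever A_ij = 0
print(alpha.sum(axis=1))   # each row sums to 1
```

In practice a sparse implementation computes scores only for the existing edges, which is what keeps the cost \(O(|E|)\); the dense mask above is just the easiest way to see the same result on a small graph.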
Multi-Head Graph Attention¶
Just as in transformers (chapter 7), multi-head attention runs \(K\) independent attention mechanisms in parallel, each with its own parameters \(W^k\) and \(\mathbf{a}^k\). The results are concatenated (in intermediate layers) or averaged (in the final layer):

\[ \mathbf{h}_i' = \big\Vert_{k=1}^{K}\, \sigma\Big(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^k W^k \mathbf{h}_j\Big) \qquad \text{or} \qquad \mathbf{h}_i' = \sigma\Big(\frac{1}{K}\sum_{k=1}^{K} \sum_{j \in \mathcal{N}(i)} \alpha_{ij}^k W^k \mathbf{h}_j\Big) \]
Each head can attend to different aspects of the neighbourhood: one head might focus on structural features, another on semantic similarity. This is the same motivation as multi-head attention in transformers: different heads capture different types of relationships.
With \(K\) heads and output dimension \(d'\) per head, the concatenated output has dimension \(K \times d'\). The final layer typically averages instead of concatenating to produce a fixed-size output.
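The shape bookkeeping is easy to verify directly. A small sketch (head outputs are random stand-ins for what each attention head would produce):

```python
import jax
import jax.numpy as jnp

K, d_out, n = 4, 3, 6   # heads, per-head dim, nodes (illustrative sizes)
rng = jax.random.PRNGKey(1)

# Stand-ins for the per-head outputs, each of shape (n, d_out)
head_outputs = [jax.random.normal(k, (n, d_out))
                for k in jax.random.split(rng, K)]

# Intermediate layers: concatenate along the feature axis
h_concat = jnp.concatenate(head_outputs, axis=-1)
print(h_concat.shape)   # (6, 12) = (n, K * d_out)

# Final layer: average the heads instead, keeping dimension d_out
h_avg = jnp.mean(jnp.stack(head_outputs), axis=0)
print(h_avg.shape)      # (6, 3)
```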
GATv2: Fixing Static Attention¶
The original GAT has a subtle limitation: its attention function is static (also called ranking-based). The attention score depends on the concatenation \([W\mathbf{h}_i \| W\mathbf{h}_j]\), but because the attention vector \(\mathbf{a}\) is applied after concatenation, it can be decomposed into two independent components: \(\mathbf{a}^T [W\mathbf{h}_i \| W\mathbf{h}_j] = \mathbf{a}_1^T W\mathbf{h}_i + \mathbf{a}_2^T W\mathbf{h}_j\).
This means the ranking of neighbours for a given node \(i\) is determined entirely by the neighbours' features \(\mathbf{h}_j\) (the term \(\mathbf{a}_1^T W\mathbf{h}_i\) is constant across all neighbours of \(i\)). The attention ranking does not truly depend on the query node's features. Node \(i\) and node \(k\) will rank the same set of neighbours identically, which limits expressiveness.
GATv2 (Brody et al., 2022) fixes this by applying the nonlinearity before the attention vector:

\[ e_{ij} = \mathbf{a}^T \, \text{LeakyReLU}\big(W \left[\mathbf{h}_i \,\|\, \mathbf{h}_j\right]\big) \]

Moving LeakyReLU inside the computation means the attention score is a nonlinear function of the joint features, not decomposable into independent terms. This makes attention dynamic: the ranking of neighbours now depends on the specific query node. GATv2 is strictly more expressive than GAT with no additional computational cost.
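The static-vs-dynamic distinction can be checked numerically. In the sketch below (all parameters random and illustrative), GAT's score decomposes as \(\text{LeakyReLU}(\mathbf{a}_1^T W\mathbf{h}_i + \mathbf{a}_2^T W\mathbf{h}_j)\); since LeakyReLU is monotonic, every query node produces the same neighbour ranking. The GATv2 score has no such decomposition:

```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(2)
k1, k2, k3, k4 = jax.random.split(rng, 4)
n, d = 6, 4
H = jax.random.normal(k1, (n, d))
W = jax.random.normal(k2, (d, d))
a = jax.random.normal(k3, (2 * d,))
a1, a2 = a[:d], a[d:]

Wh = H @ W

# GAT: e_ij = LeakyReLU(a1·Wh_i + a2·Wh_j). For fixed i the per-query term
# a1·Wh_i is constant, so the ranking over j depends only on a2·Wh_j.
s_i = Wh @ a1
s_j = Wh @ a2
gat_scores = jax.nn.leaky_relu(s_i[:, None] + s_j[None, :], negative_slope=0.2)
gat_ranks = jnp.argsort(gat_scores, axis=1)
print(gat_ranks)          # every row (query node) ranks neighbours identically

# GATv2: e_ij = a·LeakyReLU(W2 [h_i || h_j]) — nonlinearity first, so the
# score no longer splits into independent i- and j-terms.
kA, kB = jax.random.split(k4)
W2 = jax.random.normal(kA, (2 * d, 2 * d))
a_v2 = jax.random.normal(kB, (2 * d,))
pair = jnp.concatenate(
    [jnp.repeat(H[:, None, :], n, axis=1),
     jnp.repeat(H[None, :, :], n, axis=0)], axis=-1)
v2_scores = jax.nn.leaky_relu(pair @ W2.T, negative_slope=0.2) @ a_v2
print(jnp.argsort(v2_scores, axis=1))   # rankings can differ per query node
```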
Graph Transformers¶
Standard message-passing GNNs are limited by the graph topology: a node can only attend to its direct neighbours. After \(k\) layers, information from \(k\)-hop neighbours has been mixed through multiple aggregation steps, losing fidelity. This local bottleneck (combined with over-smoothing, file 3) limits the ability to capture long-range dependencies.
Graph Transformers break this bottleneck by applying global self-attention to all node pairs, regardless of whether they share an edge. Every node can attend to every other node in a single layer, just like in a standard transformer (chapter 7).
The basic idea: treat all nodes as tokens and apply transformer self-attention:

\[ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V \]
where \(Q = XW_Q\), \(K = XW_K\), \(V = XW_V\) are the query, key, and value projections of the node features \(X\) (exactly as in chapter 7). This is a GNN on a fully connected graph (complete graph \(K_n\), file 2).
The problem: a fully connected graph ignores the actual graph structure. The edge information (who is actually connected to whom) is lost. Two approaches restore this:
Graphormer (Ying et al., 2021) injects graph structure into the transformer via bias terms in the attention scores:

\[ e_{ij} = \frac{(\mathbf{h}_i W_Q)(\mathbf{h}_j W_K)^T}{\sqrt{d}} + b_{\text{spatial}}(i, j) + b_{\text{edge}}(i, j) \]
The spatial bias \(b_{\text{spatial}}\) encodes the shortest-path distance between nodes \(i\) and \(j\). The edge bias \(b_{\text{edge}}\) encodes edge features along the shortest path. Additionally, Graphormer uses a centrality encoding that adds the node's degree to its input embedding, giving the model information about each node's structural role.
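A minimal sketch of the spatial-bias idea: compute shortest-path distances, then add a learnable bias per hop count to the attention scores. The graph, the Floyd–Warshall distance computation, and the bias table below are illustrative assumptions, not Graphormer's exact implementation (which also handles edge features and unreachable pairs):

```python
import jax
import jax.numpy as jnp

# Path graph 0–1–2–3 as a minimal illustration
A = jnp.array([[0, 1, 0, 0],
               [1, 0, 1, 0],
               [0, 1, 0, 1],
               [0, 0, 1, 0]])
n = 4

# Shortest-path distances via Floyd–Warshall (fine for small graphs)
INF = 99
dist = jnp.where(jnp.eye(n, dtype=bool), 0, jnp.where(A > 0, 1, INF))
for k in range(n):
    dist = jnp.minimum(dist, dist[:, k:k + 1] + dist[k:k + 1, :])

rng = jax.random.PRNGKey(3)
kx, kq, kk_, kb = jax.random.split(rng, 4)
d = 8
X = jax.random.normal(kx, (n, d))
Q = X @ (jax.random.normal(kq, (d, d)) * 0.1)
K = X @ (jax.random.normal(kk_, (d, d)) * 0.1)

# One learnable bias per hop count (hypothetical parameterisation)
b_table = jax.random.normal(kb, (n,)) * 0.1
scores = Q @ K.T / jnp.sqrt(d) + b_table[dist]
alpha = jax.nn.softmax(scores, axis=1)
print(dist)                # e.g. dist[0, 3] is 3 hops on the path graph
print(alpha.sum(axis=1))   # rows still normalise to 1
```

Because the bias is shared across all node pairs at the same distance, the model can learn, for example, to systematically down-weight distant nodes while still allowing content-based attention to override that prior.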
GPS (General, Powerful, Scalable Graph Transformer, Rampášek et al., 2022) combines local message passing with global attention in each layer:

\[ X_M = \text{MPNN}(X, A), \qquad X_T = \text{GlobalAttn}(X), \qquad X' = \text{MLP}(X_M + X_T) \]

Each layer applies both a standard GNN (for local structure) and a transformer (for global context), then combines the results. This gets the best of both worlds: local structure from message passing and long-range dependencies from attention.
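A hypothetical sketch of one such hybrid layer, with a mean-aggregation MPNN standing in for the local branch and plain softmax attention for the global branch (the published layer also uses residual connections, normalisation, and an MLP):

```python
import jax
import jax.numpy as jnp

def gps_layer(X, A, params):
    """One GPS-style layer: local MPNN branch + global attention branch.
    A simplified sketch, not the exact published architecture."""
    # Local branch: mean aggregation over neighbours, then a linear map
    deg = jnp.maximum(A.sum(axis=1, keepdims=True), 1.0)
    local = (A @ X / deg) @ params["W_local"]

    # Global branch: full self-attention over all node pairs
    Q, K, V = X @ params["W_q"], X @ params["W_k"], X @ params["W_v"]
    attn = jax.nn.softmax(Q @ K.T / jnp.sqrt(X.shape[-1]), axis=1)
    global_out = attn @ V

    # Combine branches (the paper sums them and applies an MLP)
    return jax.nn.relu(local + global_out)

rng = jax.random.PRNGKey(4)
keys = jax.random.split(rng, 6)
n, d = 6, 8
X = jax.random.normal(keys[0], (n, d))
A = (jax.random.uniform(keys[1], (n, n)) > 0.5).astype(jnp.float32)
A = jnp.maximum(A, A.T) * (1 - jnp.eye(n))    # symmetric, no self-loops
params = {name: jax.random.normal(k, (d, d)) * 0.1
          for name, k in zip(["W_local", "W_q", "W_k", "W_v"], keys[2:])}
out = gps_layer(X, A, params)
print(out.shape)   # (6, 8): same shape in and out, so layers can stack
```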
Positional and Structural Encodings¶
Transformers on sequences use positional encodings (chapter 7) to inject order information. Graphs have no canonical ordering, so graph-specific encodings are needed.
Laplacian eigenvector encodings use the eigenvectors of the graph Laplacian (file 2) as positional features. The \(k\) smallest non-trivial eigenvectors provide a spectral embedding of the graph: nodes that are "nearby" in the graph have similar eigenvector values. These are concatenated to the node features.
A subtlety: Laplacian eigenvectors have a sign ambiguity (if \(\mathbf{u}\) is an eigenvector, so is \(-\mathbf{u}\)). The model must be invariant to these sign flips. Solutions include using random sign flips as data augmentation during training, or learning sign-invariant transformations.
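The augmentation approach is a one-liner: multiply each eigenvector column by a random \(\pm 1\) during training. A minimal sketch (the encodings here are random stand-ins for real eigenvectors):

```python
import jax
import jax.numpy as jnp

def random_sign_flip(pe, key):
    """Randomly flip the sign of each eigenvector column (data augmentation)."""
    signs = jax.random.choice(key, jnp.array([-1.0, 1.0]), shape=(pe.shape[1],))
    return pe * signs[None, :]

rng = jax.random.PRNGKey(5)
pe = jax.random.normal(rng, (6, 3))       # stand-in eigenvector encodings
flipped = random_sign_flip(pe, jax.random.PRNGKey(6))

# Magnitudes are unchanged; only per-column signs may differ
print(jnp.abs(flipped) - jnp.abs(pe))     # all zeros
```

Training against randomly flipped copies pushes the model to produce the same prediction for \(\mathbf{u}\) and \(-\mathbf{u}\), i.e. approximate sign invariance.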
Random walk encodings compute the probability that a random walk starting at node \(i\) returns to node \(i\) after \(k\) steps, for \(k = 1, 2, \ldots, K\). These probabilities encode local structural information: nodes in dense clusters have high return probabilities, while nodes in sparse regions have low ones. The return probability is \(p_{ii}^{(k)} = (A_{\text{rw}}^k)_{ii}\), where \(A_{\text{rw}} = D^{-1}A\) is the random walk transition matrix.
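These encodings are a few lines of matrix powers. A sketch on an illustrative graph (a triangle attached to a path tail), where the odd-cycle structure shows up directly: triangle nodes have nonzero 3-step return probability, the tail node has none:

```python
import jax.numpy as jnp

# Triangle 0-1-2 attached to a path 2-3-4
A = jnp.array([[0, 1, 1, 0, 0],
               [1, 0, 1, 0, 0],
               [1, 1, 0, 1, 0],
               [0, 0, 1, 0, 1],
               [0, 0, 0, 1, 0]], dtype=jnp.float32)

D_inv = jnp.diag(1.0 / A.sum(axis=1))
A_rw = D_inv @ A                       # random walk transition matrix

# Return probabilities p_ii^(k) = (A_rw^k)_ii for k = 1..K
K = 4
P = A_rw
encodings = []
for k in range(1, K + 1):
    encodings.append(jnp.diag(P))      # diag of A_rw^k
    P = P @ A_rw
rw_pe = jnp.stack(encodings, axis=1)   # (n_nodes, K), one row per node
print(rw_pe)
```

Note that \(k=1\) is all zeros (no self-loops), and nodes on the triangle get a positive \(k=3\) entry while the leaf node 4 gets exactly zero: the encoding distinguishes structural roles that raw features cannot.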
Degree encodings simply add the node degree as a feature. This is surprisingly effective because degree is a strong structural signal: leaf nodes (degree 1), bridge nodes, and hub nodes behave differently.
These encodings provide the structural information that vanilla transformers lack, enabling Graph Transformers to outperform standard message-passing GNNs on tasks requiring long-range reasoning.
Scalability¶
The fundamental scalability challenge for GNNs is that graphs can have millions of nodes and billions of edges. Training a GNN on the full graph requires storing all node features and the entire adjacency matrix in memory, which is often infeasible.
Mini-batch training for GNNs is more complex than for images or sequences because nodes are interconnected. Naively sampling a batch of nodes requires their neighbours (layer 1), their neighbours' neighbours (layer 2), and so on. This neighbourhood explosion means a batch of 1000 target nodes might require millions of nodes in the computation graph.
Neighbourhood sampling (GraphSAGE-style, file 3) limits the explosion by sampling a fixed number of neighbours per node per layer. With 2 layers and 15 samples per layer, each target node's subgraph has at most \(15^2 = 225\) nodes, regardless of the full graph size.
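A sketch of two-hop sampling on a toy adjacency list (the graph, the helper, and the sample sizes are illustrative; real implementations sample without replacement and batch this over many targets):

```python
import jax
import jax.numpy as jnp

def sample_neighbours(adj_list, node, num_samples, key):
    """Sample num_samples neighbours of a node (with replacement, for brevity)."""
    nbrs = jnp.array(adj_list[node])
    idx = jax.random.randint(key, (num_samples,), 0, len(nbrs))
    return nbrs[idx]

# Hypothetical adjacency list for a small graph
adj_list = {0: [1, 2, 3, 4], 1: [0, 2], 2: [0, 1], 3: [0], 4: [0]}

rng = jax.random.PRNGKey(7)
k1, k2 = jax.random.split(rng)

# 2-hop sampled computation graph for target node 0, 3 samples per layer
hop1 = sample_neighbours(adj_list, 0, 3, k1)
hop2 = [sample_neighbours(adj_list, int(v), 3, k)
        for v, k in zip(hop1, jax.random.split(k2, 3))]
print(hop1)           # at most 3 first-hop nodes
print(len(hop2) * 3)  # at most 9 second-hop nodes, independent of graph size
```

The key property is the bound: the sampled computation graph is at most \(3^2 = 9\) nodes here (or \(15^2 = 225\) with the numbers in the text) no matter how large the full graph is.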
Cluster-GCN (Chiang et al., 2019) partitions the graph into clusters using a graph clustering algorithm (e.g., METIS), then trains on one cluster at a time. Within-cluster edges are dense (most neighbours are in the same cluster), so the subgraph captures the relevant structure. Cross-cluster edges are handled by occasionally including edges between clusters.
Graph Transformer scalability is harder because global attention is \(O(n^2)\). For graphs with millions of nodes, full attention is infeasible. Solutions include:
- Sparse attention patterns (attend only to \(k\)-nearest nodes in the graph)
- Linear attention approximations
- Combining local message passing (cheap, \(O(|E|)\)) with global attention on a coarsened graph (fewer nodes)
Temporal and Dynamic Graphs¶
The graphs we have studied so far are static: the nodes, edges, and features are fixed. But many real-world graphs evolve over time: new users join social networks, financial transactions create edges, traffic patterns shift throughout the day, and molecular interactions fluctuate.
A temporal graph augments each edge with a timestamp: \((i, j, t)\) means node \(i\) interacted with node \(j\) at time \(t\). The challenge is to learn representations that capture both the graph structure and the temporal dynamics.
There are two paradigms:
Discrete-time dynamic graphs (DTDG): the graph is represented as a sequence of snapshots \(G_1, G_2, \ldots, G_T\), one per timestep. A GNN processes each snapshot, and an RNN or temporal attention mechanism captures the evolution across snapshots. This is simple but loses fine-grained timing information (events between snapshots are lost) and requires choosing a snapshot frequency.
Continuous-time dynamic graphs (CTDG): events are modelled as a stream of timestamped interactions. Each event \((i, j, t)\) updates the representations of nodes \(i\) and \(j\) at the exact time it occurs. This preserves all temporal information.
Temporal Graph Network (TGN) (Rossi et al., 2020) is a leading CTDG architecture. Each node maintains a memory state \(\mathbf{s}_i(t)\) that is updated whenever the node participates in an interaction:

\[ \mathbf{s}_i(t) = \text{GRU}\big(\mathbf{m}_i(t),\, \mathbf{s}_i(t^-)\big) \]
where \(\mathbf{m}_i(t)\) is a message computed from the interaction (combining the features of both nodes, the edge features, and the time encoding). The GRU (chapter 6) selectively retains and forgets past information, allowing the memory to capture long-term patterns while adapting to recent events.
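A minimal sketch of the memory update, with a hand-rolled GRU cell (parameter names, shapes, and initialisation are illustrative assumptions, not TGN's implementation):

```python
import jax
import jax.numpy as jnp

def gru_update(m, s, params):
    """Minimal GRU cell: update memory s with incoming message m."""
    z = jax.nn.sigmoid(m @ params["Wz"] + s @ params["Uz"])   # update gate
    r = jax.nn.sigmoid(m @ params["Wr"] + s @ params["Ur"])   # reset gate
    s_tilde = jnp.tanh(m @ params["Wh"] + (r * s) @ params["Uh"])
    return (1 - z) * s + z * s_tilde

d = 8
rng = jax.random.PRNGKey(8)
keys = jax.random.split(rng, 8)
params = {name: jax.random.normal(k, (d, d)) * 0.1
          for name, k in zip(["Wz", "Uz", "Wr", "Ur", "Wh", "Uh"], keys[:6])}

s_i = jnp.zeros(d)                      # node i's memory before the event
m_i = jax.random.normal(keys[6], (d,))  # message built from the interaction
s_new = gru_update(m_i, s_i, params)
print(s_new.shape)   # memory keeps a fixed dimension across updates
```

The update gate \(z\) decides how much of the old memory to overwrite, which is exactly the "selectively retains and forgets" behaviour described above.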
Time encoding represents the elapsed time since the last interaction as a feature vector, analogous to positional encoding in transformers (chapter 7). A common approach uses learnable Fourier features:

\[ \phi(\Delta t) = \big[\cos(\omega_1 \Delta t),\, \sin(\omega_1 \Delta t),\, \ldots,\, \cos(\omega_{d/2} \Delta t),\, \sin(\omega_{d/2} \Delta t)\big] \]

where the frequencies \(\omega_k\) are learned.
This gives the model a rich representation of temporal gaps: "this user was last active 5 minutes ago" vs "3 months ago" are embedded differently.
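A sketch of such an encoding with fixed (rather than learned) log-spaced frequencies, chosen here purely for illustration:

```python
import jax.numpy as jnp

def time_encode(delta_t, omega):
    """Fourier features of an elapsed time (frequencies would be learned)."""
    return jnp.concatenate([jnp.cos(omega * delta_t), jnp.sin(omega * delta_t)])

# Log-spaced frequencies covering gaps from seconds to months (illustrative)
omega = 1.0 / jnp.logspace(0, 7, 8)

five_minutes = 300.0
three_months = 90 * 24 * 3600.0
e1 = time_encode(five_minutes, omega)
e2 = time_encode(three_months, omega)
print(e1)
print(e2)   # a clearly different embedding for the larger gap
```

The multi-scale frequencies are what let both short and very long gaps produce distinguishable vectors; a single frequency would alias one of the two regimes.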
Temporal Graph Attention (TGAT) applies self-attention over a node's temporal neighbourhood: the set of recent interactions, each weighted by both feature relevance (like GAT) and temporal recency. Interactions from the distant past are naturally down-weighted.
Applications include fraud detection (anomalous transaction patterns in financial graphs), traffic forecasting (predicting congestion from historical flow patterns), social network dynamics (predicting viral content spread), and drug interaction prediction over time.
Coding Tasks (use Colab or a notebook)¶
Implement a single GAT attention head from scratch. Compute attention weights between a node and its neighbours and verify they sum to 1.
```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(rng, 3)
n_nodes, d_in, d_out = 5, 4, 3

# Random node features
H = jax.random.normal(k1, (n_nodes, d_in))

# Learnable parameters
W = jax.random.normal(k2, (d_in, d_out)) * 0.5
a = jax.random.normal(k3, (2 * d_out,)) * 0.5

# Adjacency (node 0 connects to 1, 2, 3)
neighbours_of_0 = [1, 2, 3]

# Transform features
Wh = H @ W  # (n_nodes, d_out)

# Compute attention scores for node 0
h_i = Wh[0]
scores = []
for j in neighbours_of_0:
    h_j = Wh[j]
    e_ij = jnp.dot(a, jnp.concatenate([h_i, h_j]))
    e_ij = jax.nn.leaky_relu(e_ij, negative_slope=0.2)
    scores.append(float(e_ij))

scores = jnp.array(scores)
alpha = jax.nn.softmax(scores)
print(f"Raw scores: {scores}")
print(f"Attention weights: {alpha}")
print(f"Sum of weights: {alpha.sum():.4f}")

# Weighted aggregation
h_new = sum(alpha[k] * Wh[neighbours_of_0[k]] for k in range(len(neighbours_of_0)))
print(f"Updated node 0 features: {h_new}")
```
Compare GCN (fixed weights) vs GAT (learned weights) aggregation. Show that GAT can assign different weights to neighbours while GCN treats them uniformly.
```python
import jax.numpy as jnp

# 4 nodes: node 0 connects to 1, 2, 3
A = jnp.array([[0, 1, 1, 1],
               [1, 0, 0, 0],
               [1, 0, 0, 0],
               [1, 0, 0, 0]], dtype=float)

# Features: node 1 is very relevant, node 2 is noise, node 3 is moderate
H = jnp.array([[0.0, 0.0],   # node 0
               [1.0, 0.0],   # node 1 (signal)
               [0.0, 0.0],   # node 2 (noise)
               [0.5, 0.0]])  # node 3 (moderate)

# GCN: normalised adjacency weights
A_hat = A + jnp.eye(4)
D_inv = jnp.diag(1.0 / A_hat.sum(axis=1))
gcn_weights = (D_inv @ A_hat)[0]  # weights for node 0
print(f"GCN weights for node 0: {gcn_weights}")
print("  → All neighbours get roughly equal weight")

# GAT: learned attention (simulated)
# Suppose the attention mechanism learns to focus on node 1
gat_weights = jnp.array([0.1, 0.7, 0.05, 0.15])  # learned
print(f"\nGAT weights for node 0: {gat_weights}")
print("  → Node 1 (informative) gets most attention")

gcn_output = gcn_weights @ H
gat_output = gat_weights @ H
print(f"\nGCN output: {gcn_output} (diluted by noise)")
print(f"GAT output: {gat_output} (focused on signal)")
```
Demonstrate the benefit of positional encodings. Compute Laplacian eigenvector encodings for a graph and show that structurally similar nodes get similar encodings.
```python
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Barbell graph: two cliques connected by a bridge
n = 10
A = jnp.zeros((n, n))

# Clique 1: nodes 0-4
for i in range(5):
    for j in range(i + 1, 5):
        A = A.at[i, j].set(1).at[j, i].set(1)

# Clique 2: nodes 5-9
for i in range(5, 10):
    for j in range(i + 1, 10):
        A = A.at[i, j].set(1).at[j, i].set(1)

# Bridge
A = A.at[4, 5].set(1).at[5, 4].set(1)

D = jnp.diag(A.sum(axis=1))
L = D - A
eigenvalues, eigenvectors = jnp.linalg.eigh(L)

# Use first 3 non-trivial eigenvectors as positional encoding
pe = eigenvectors[:, 1:4]

print("Laplacian Positional Encodings:")
for i in range(n):
    group = "Clique 1" if i < 5 else "Clique 2"
    bridge = " (bridge)" if i in [4, 5] else ""
    print(f"  Node {i} ({group}{bridge}): {pe[i]}")

plt.scatter(pe[:5, 0], pe[:5, 1], c="#3498db", s=80, label="Clique 1")
plt.scatter(pe[5:, 0], pe[5:, 1], c="#e74c3c", s=80, label="Clique 2")
plt.scatter(pe[[4, 5], 0], pe[[4, 5], 1], c="black", s=120, marker="*",
            label="Bridge nodes", zorder=5)
plt.legend(); plt.grid(True)
plt.title("Laplacian Eigenvector Positional Encodings")
plt.xlabel("Eigenvector 1"); plt.ylabel("Eigenvector 2")
plt.show()
```