Distributed Deep Learning

Distributed training spreads computation across multiple GPUs and machines to train models that are too large or too slow for a single device. This file covers mixed precision, data parallelism, model parallelism, pipeline parallelism, ZeRO, FSDP, tensor parallelism, and communication primitives like all-reduce -- essential for training LLMs at scale.

  • Training a large neural network on a single GPU eventually hits a wall. The model might not fit in memory, or training might take months. Distributed training spreads the work across multiple devices (GPUs, TPUs, or entire machines) to train faster and train bigger models. This file covers the techniques that make that possible.

  • To understand why distribution matters, start with the computational cost of training. A single forward pass through a dense layer with \(d_{\text{in}}\) inputs and \(d_{\text{out}}\) outputs on a batch of \(B\) examples requires roughly \(2 \cdot B \cdot d_{\text{in}} \cdot d_{\text{out}}\) FLOPs (floating-point operations): each of the \(B \cdot d_{\text{out}}\) output elements is a dot product of length \(d_{\text{in}}\), which costs one multiply and one add per term. The backward pass costs roughly twice the forward pass (it computes gradients with respect to both the inputs and the weights), so one training step on a dense layer is about \(6 \cdot B \cdot d_{\text{in}} \cdot d_{\text{out}}\) FLOPs.
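
  • The dense-layer estimate above can be turned into a small calculator. This is a rough sketch only: the function name `dense_layer_flops` and the example sizes are illustrative assumptions, and bias terms and activation functions are ignored.

```python
def dense_layer_flops(batch, d_in, d_out):
    """Approximate FLOPs for one dense layer (bias and activation ignored)."""
    forward = 2 * batch * d_in * d_out  # one multiply + one add per weight, per example
    backward = 2 * forward              # gradients w.r.t. the inputs and w.r.t. the weights
    return {"forward": forward, "backward": backward, "step": forward + backward}

# Hypothetical example sizes, chosen only for illustration.
flops = dense_layer_flops(batch=32, d_in=4096, d_out=4096)
for name, value in flops.items():
    print(f"{name:>8}: {value / 1e9:.2f} GFLOPs")
```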

  • For a transformer layer with hidden dimension \(d\), the self-attention block involves four projections (Q, K, V, and output) each costing \(O(B \cdot n \cdot d^2)\) FLOPs (where \(n\) is the sequence length), plus the attention matrix computation at \(O(B \cdot n^2 \cdot d)\). The feed-forward block has two dense layers, typically expanding to \(4d\) and back: \(O(B \cdot n \cdot 8d^2)\). Total per layer: roughly \(O(B \cdot n \cdot 12d^2 + B \cdot n^2 \cdot d)\). Multiply by the number of layers and you see why training GPT-scale models requires thousands of GPU-hours.

  • The memory wall is often the tighter constraint. During training, GPU memory must hold four things simultaneously:

Stacked bar showing training memory breakdown: parameters, gradients, optimizer states, activations

  • Parameters: the model weights. A 7-billion parameter model in FP32 (4 bytes per parameter) needs 28 GB just for weights.
  • Gradients: same size as the parameters. Another 28 GB.
  • Optimizer states: Adam maintains two additional buffers (first and second moment estimates), each the size of the parameters. These are kept in FP32 for numerical stability, even when the model uses lower precision. For our 7B model, that is \(2 \times 28 = 56\) GB.
  • Activations: intermediate values saved during the forward pass for use in the backward pass. The size depends on batch size, sequence length, and model width. This is often the largest component and grows linearly with batch size.

  • For our 7B model with FP32 Adam: 28 (params) + 28 (grads) + 56 (optimizer) = 112 GB, before we even count activations. A single 80 GB A100 GPU cannot hold this. This is why distributed strategies are essential.
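
  • The arithmetic above is easy to reproduce for other model sizes. The helper below is a minimal sketch; its name and default byte counts are assumptions that mirror the FP32 Adam setup described here (4 bytes per weight, 4 per gradient, 8 for the two Adam moments), and activations are deliberately left out.

```python
def training_memory_gb(n_params, param_bytes=4, grad_bytes=4, adam_state_bytes=8):
    """Memory in GB (1e9 bytes) for weights, gradients and Adam states; activations excluded."""
    params = n_params * param_bytes / 1e9
    grads = n_params * grad_bytes / 1e9
    optimizer = n_params * adam_state_bytes / 1e9  # first + second moment, FP32 each
    return {"params": params, "grads": grads, "optimizer": optimizer,
            "total": params + grads + optimizer}

mem = training_memory_gb(7e9)
for name, gb in mem.items():
    print(f"{name:>9}: {gb:.0f} GB")
print("Fits on one 80 GB GPU:", mem["total"] <= 80)
```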

  • Mixed precision training is the first line of defence. Instead of storing everything in FP32 (32-bit floating point), you train using FP16 or BF16 (16-bit) for the forward and backward passes, while keeping a master copy of weights in FP32 for the optimizer update.

  • FP16 has high precision (10-bit mantissa) but a limited range, which can cause overflow/underflow. Loss scaling (multiplying the loss by a large factor before the backward pass, then dividing gradients by the same factor) mitigates this.

  • BF16 (brain float) has the same exponent range as FP32 (8-bit exponent) but less precision (7-bit mantissa). It almost never overflows and rarely needs loss scaling, making it simpler to use. BF16 is the default for modern transformer training.

  • Mixed precision roughly halves the memory for activations and gradients (the dominant costs during forward/backward passes), while keeping optimizer states in FP32 for numerical stability.
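
  • The range and precision differences between FP16 and BF16 can be seen directly by casting a few values. This is an illustrative sketch, not a training recipe: the helper `cast`, the test values, and the scale factor of \(2^{16}\) are assumptions chosen to make the effects visible, and real loss scaling multiplies the loss before the backward pass rather than a single number.

```python
import jax.numpy as jnp

def cast(value, dtype):
    """Round a Python float to `dtype` via FP32, as a cast inside a model would."""
    return jnp.asarray(value, dtype=jnp.float32).astype(dtype)

# Range: FP16's largest finite value is 65504, so 70000 overflows; BF16 stays finite.
fp16_big = cast(70000.0, jnp.float16)
bf16_big = cast(70000.0, jnp.bfloat16)

# Precision: BF16 has fewer mantissa bits, so 1.001 rounds to exactly 1.0.
fp16_fine = cast(1.001, jnp.float16)
bf16_fine = cast(1.001, jnp.bfloat16)

# Underflow: a value of 1e-8 is below FP16's smallest subnormal (about 6e-8).
tiny_grad = 1e-8
fp16_tiny = cast(tiny_grad, jnp.float16)

# Loss-scaling idea: scale up before the FP16 cast, then unscale in FP32.
scale = 2.0 ** 16
recovered = cast(tiny_grad * scale, jnp.float16).astype(jnp.float32) / scale

print("70000 ->", float(fp16_big), "(FP16),", float(bf16_big), "(BF16)")
print("1.001 ->", float(fp16_fine), "(FP16),", float(bf16_fine), "(BF16)")
print("1e-8  ->", float(fp16_tiny), "(FP16),", float(recovered), "(FP16 with scaling)")
```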

  • Data parallelism is the simplest distributed strategy. You replicate the entire model on \(N\) GPUs, split each mini-batch into \(N\) equal chunks, and send one chunk to each GPU. Each GPU runs the forward and backward pass on its chunk independently. Then the gradients are averaged across all GPUs (using an all-reduce operation), and each GPU updates its local copy of the model.

  • From the model's perspective, this is equivalent to training with a mini-batch that is \(N\) times larger. If each GPU processes a batch of size \(B\), the effective batch size is \(N \cdot B\).

Side-by-side comparison: data parallelism replicates model and splits data, model parallelism splits model and shares data

  • The gradient averaging can be done synchronously or asynchronously. Synchronous SGD waits for all GPUs to finish before averaging, ensuring mathematical equivalence to single-GPU training with a larger batch. The downside is that the slowest GPU (the "straggler") holds everyone up.

  • Asynchronous SGD lets each GPU update a shared parameter server independently, without waiting. This eliminates the straggler problem but introduces "stale gradients": a GPU might compute gradients based on slightly outdated parameters. Stale gradients add noise and can slow convergence. In practice, synchronous SGD with efficient communication is preferred.

  • Gradient accumulation is a software trick for simulating larger batch sizes on limited hardware. Instead of doing one update per mini-batch, you run several forward/backward passes and accumulate the gradients, then do one update. This gives the same result as a larger batch without needing more GPU memory for activations (only one mini-batch of activations is in memory at a time).
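
  • A minimal sketch of gradient accumulation on a toy linear-regression problem follows. The data, the number of micro-batches, and all variable names are illustrative assumptions; the point is only that averaging the gradients of equal-sized micro-batches reproduces the full-batch gradient of a mean loss.

```python
import jax
import jax.numpy as jnp

key_x, key_noise = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (64, 4))
y = X @ jnp.array([1.0, -2.0, 3.0, 0.5]) + 0.1 * jax.random.normal(key_noise, (64,))

def loss_fn(w, X, y):
    return jnp.mean((X @ w - y) ** 2)

grad_fn = jax.grad(loss_fn)
w = jnp.zeros(4)

# Reference: one backward pass over the full batch of 64 examples.
full_grad = grad_fn(w, X, y)

# Accumulation: 4 micro-batches of 16, one at a time, summing scaled gradients.
n_micro = 4
micro_size = X.shape[0] // n_micro
accum_grad = jnp.zeros_like(w)
for i in range(n_micro):
    rows = slice(i * micro_size, (i + 1) * micro_size)
    accum_grad = accum_grad + grad_fn(w, X[rows], y[rows]) / n_micro

print("Full-batch gradient: ", full_grad)
print("Accumulated gradient:", accum_grad)
print("Match:", bool(jnp.allclose(full_grad, accum_grad, atol=1e-5)))
```

  • Dividing each micro-batch gradient by `n_micro` is what keeps the result equal to the mean over the full batch; summing without that factor would scale the update by the number of micro-batches.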

  • When the model itself is too large to fit on a single GPU, you need model parallelism. There are two main flavours.

  • Tensor parallelism splits individual layers across GPUs. A large matrix multiply \(Y = XW\) can be split column-wise: partition \(W\) into \([W_1, W_2]\) across two GPUs, compute \(Y_1 = XW_1\) and \(Y_2 = XW_2\) in parallel, then concatenate. This works for attention projections and feed-forward layers. It requires fast communication between GPUs (typically NVLink within a node) because partial results must be combined at every layer.
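
  • The column-wise split can be checked numerically on a single device. In this sketch the two "GPUs" are just two slices of \(W\), and the shapes are arbitrary assumptions; on real hardware the concatenation would be a collective operation between devices.

```python
import jax
import jax.numpy as jnp

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (8, 16))   # (batch, d_in)
W = jax.random.normal(key_w, (16, 32))  # (d_in, d_out)

# Partition W column-wise into [W1, W2]; each half would live on its own GPU.
W1, W2 = jnp.split(W, 2, axis=1)

# Each "GPU" needs the full input X but only its own slice of the weights.
Y1 = X @ W1
Y2 = X @ W2

# Combining the partial outputs recovers the unsplit result.
Y_parallel = jnp.concatenate([Y1, Y2], axis=1)
Y_reference = X @ W

print("Shapes:", Y1.shape, Y2.shape, "->", Y_parallel.shape)
print("Match:", bool(jnp.allclose(Y_parallel, Y_reference, atol=1e-4)))
```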

  • Pipeline parallelism assigns different layers to different GPUs. GPU 0 runs layers 1-4, GPU 1 runs layers 5-8, and so on. Data flows through the pipeline like an assembly line. The naive approach has a "pipeline bubble": while GPU 0 processes the forward pass for micro-batch 1, GPUs 1-3 sit idle. Micro-batching mitigates this by splitting the mini-batch into smaller micro-batches that flow through the pipeline in sequence, keeping all GPUs busy most of the time.
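
  • How much micro-batching helps can be estimated with a toy schedule. The sketch below makes strong simplifying assumptions that are not from the text above: only the forward pass is modelled, every stage takes one time slot per micro-batch, and stage \(s\) starts micro-batch \(m\) at slot \(s + m\). The function name `bubble_fraction` is hypothetical.

```python
def bubble_fraction(n_stages, n_microbatches):
    """Fraction of (stage, time-slot) pairs that sit idle in a simple forward-only pipeline."""
    total_slots = n_stages + n_microbatches - 1
    busy = [[False] * total_slots for _ in range(n_stages)]
    for stage in range(n_stages):
        for mb in range(n_microbatches):
            busy[stage][stage + mb] = True  # stage works on micro-batch mb in slot stage + mb
    idle = sum(row.count(False) for row in busy)
    return idle / (n_stages * total_slots)

for m in [1, 4, 16, 64]:
    print(f"4 stages, {m:>2} micro-batches: {bubble_fraction(4, m):.1%} idle")
```

  • Under these assumptions the idle fraction works out to \((p - 1)/(m + p - 1)\) for \(p\) stages and \(m\) micro-batches, so it shrinks as the number of micro-batches grows relative to the pipeline depth.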

  • Hybrid parallelism combines data, tensor, and pipeline parallelism. A typical large-model setup might use tensor parallelism within a node (8 GPUs connected by fast NVLink), pipeline parallelism across nodes, and data parallelism across groups of nodes. This is the general recipe behind frontier-scale models: Meta describes such a setup for Llama 3, while the details of GPT-4's training have not been published.

  • The efficiency of distributed training depends heavily on communication. The key operation is all-reduce: given a value on each of \(N\) GPUs, compute the sum (or average) and distribute the result to all GPUs.

  • A naive all-reduce sends all data to one GPU, sums it, and broadcasts back. This is \(O(N)\) in communication and creates a bottleneck at the root.

  • Ring all-reduce is much more efficient. Arrange the \(N\) GPUs in a ring. Each GPU splits its data into \(N\) chunks. In the first \(N - 1\) steps (reduce-scatter), each GPU sends one chunk to its neighbour and receives a chunk from its other neighbour, accumulating partial sums, so that each GPU ends up holding the complete sum for one chunk. In another \(N - 1\) steps (all-gather), the completed chunks are passed around the ring until every GPU has the full sum. Total data sent per GPU: \(2(N-1)/N\) times the data size, which approaches \(2\times\) as \(N\) grows. Crucially, this never exceeds twice the data size however many GPUs are added, making it bandwidth-optimal.

Four GPUs arranged in a ring, each passing gradient chunks to its neighbour until all have the complete sum
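
  • The ring algorithm can be simulated on one machine with a list of arrays standing in for the GPUs. This is a sketch under stated assumptions: the function name `ring_all_reduce`, the chunk-indexing scheme, and the example data are illustrative, the vector length must be divisible by the number of workers, and real libraries such as NCCL overlap and pipeline these transfers.

```python
import jax.numpy as jnp

def ring_all_reduce(worker_data):
    """Simulate ring all-reduce (sum). Returns per-worker results and elements sent per worker."""
    n = len(worker_data)
    # chunks[r][c] is worker r's current copy of chunk c.
    chunks = [list(jnp.split(x, n)) for x in worker_data]
    sent = 0

    # Phase 1 (reduce-scatter): in each step worker r sends chunk (r - step) mod n
    # to worker r + 1, which adds it to its own copy of that chunk.
    for step in range(n - 1):
        msgs = [((r - step) % n, chunks[r][(r - step) % n]) for r in range(n)]
        for r, (c, payload) in enumerate(msgs):
            dst = (r + 1) % n
            chunks[dst][c] = chunks[dst][c] + payload
        sent += msgs[0][1].size

    # Phase 2 (all-gather): worker r now owns the finished chunk (r + 1) mod n and
    # forwards finished chunks around the ring, overwriting stale copies.
    for step in range(n - 1):
        msgs = [((r + 1 - step) % n, chunks[r][(r + 1 - step) % n]) for r in range(n)]
        for r, (c, payload) in enumerate(msgs):
            chunks[(r + 1) % n][c] = payload
        sent += msgs[0][1].size

    return [jnp.concatenate(c) for c in chunks], sent

# Four "GPUs", each holding a different gradient vector of length 8.
data = [jnp.arange(8.0) * (r + 1) for r in range(4)]
results, sent_per_worker = ring_all_reduce(data)
expected = sum(data)

print("All workers hold the sum:", all(bool(jnp.allclose(r, expected)) for r in results))
print(f"Elements sent per worker: {sent_per_worker} (data size 8, ratio {sent_per_worker / 8:.2f})")
```

  • Messages for each step are collected before any are applied, which mimics all GPUs sending simultaneously. The sent-to-data ratio matches \(2(N-1)/N\) for \(N = 4\).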

  • Parameter servers are an alternative architecture where dedicated server nodes hold the model parameters. Workers compute gradients and send them to the server, which updates parameters and sends them back. This is simpler but can create communication bottlenecks at the server.

  • NCCL (NVIDIA Collective Communications Library) is the standard library for GPU-to-GPU communication. It provides optimised implementations of all-reduce, all-gather, broadcast, and other collective operations, automatically choosing the best algorithm for the network topology.

  • Scaling laws describe how model performance improves with compute, data, and model size. The original Kaplan et al. (2020) scaling laws found that loss decreases as a power law with each:

\[L(N) \propto N^{-\alpha_N}, \quad L(D) \propto D^{-\alpha_D}, \quad L(C) \propto C^{-\alpha_C}\]
  • where \(N\) is the number of parameters, \(D\) is the dataset size, and \(C\) is the compute budget.

  • The Chinchilla scaling laws (Hoffmann et al., 2022) showed that most large models of the time were undertrained: for a given compute budget, you should train a smaller model on more data than previously thought. The optimal ratio is roughly 20 tokens per parameter. By this rule a 7B model should see about 140B tokens, whereas GPT-3 (175B parameters) was trained on only about 300B tokens, far short of the roughly 3.5T the rule suggests. This finding shifted the field toward "compute-optimal" training.
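
  • The 20-tokens-per-parameter rule of thumb is simple to apply. The sketch below also estimates training compute with the common approximation \(C \approx 6 \cdot N \cdot D\), which follows from the roughly 6 FLOPs per weight per example derived earlier for dense layers; it ignores attention FLOPs, and the function names are illustrative assumptions.

```python
def chinchilla_optimal_tokens(n_params, tokens_per_param=20):
    """Rule-of-thumb compute-optimal dataset size (in tokens)."""
    return tokens_per_param * n_params

def training_flops(n_params, n_tokens):
    """Approximate training compute: about 6 FLOPs per parameter per token."""
    return 6 * n_params * n_tokens

for n_params in [1e9, 7e9, 70e9]:
    tokens = chinchilla_optimal_tokens(n_params)
    flops = training_flops(n_params, tokens)
    print(f"{n_params / 1e9:>4.0f}B params -> {tokens / 1e9:>6.0f}B tokens, ~{flops:.2e} FLOPs")
```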

  • Mixture of Experts (MoE) is an architecture that scales model capacity without proportionally scaling compute. Instead of one feed-forward network per transformer layer, you have \(N\) "expert" networks (each a standard FFN). A gating network (router) examines each token and sends it to the top-\(K\) experts (typically \(K = 1\) or \(K = 2\)).

Tokens routed through a gating network to selected experts, with top-K sparse routing and weighted combination of outputs

  • The total parameter count is much larger (because you have \(N\) experts), but the FLOPs per token stay roughly constant (because only \(K\) experts activate per token). For example, Mixtral 8x7B has 47B total parameters but only uses about 13B per forward pass, giving the performance of a much larger model at the cost of a smaller one.

  • MoE introduces challenges. Load balancing: if the router sends most tokens to the same expert, the others are wasted. An auxiliary loss encourages uniform routing. Communication: different experts may live on different GPUs, so routing tokens requires all-to-all communication, which is expensive.
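
  • One common form of the auxiliary load-balancing loss, in the style of the Switch Transformer, multiplies the fraction of tokens routed to each expert by the mean router probability for that expert. The sketch below assumes top-1 routing and hand-built router logits; the function name and the toy inputs are illustrative, and other MoE implementations use different variants.

```python
import jax
import jax.numpy as jnp

def load_balancing_loss(gate_logits):
    """Switch-style auxiliary loss: n_experts * sum_i f_i * P_i (equals 1.0 when balanced)."""
    n_experts = gate_logits.shape[-1]
    probs = jax.nn.softmax(gate_logits, axis=-1)           # (tokens, n_experts)
    top1 = jnp.argmax(probs, axis=-1)                      # chosen expert per token
    f = jnp.mean(jax.nn.one_hot(top1, n_experts), axis=0)  # fraction of tokens per expert
    p = jnp.mean(probs, axis=0)                            # mean router probability per expert
    return n_experts * jnp.sum(f * p)

n_tokens, n_experts = 8, 4

# Balanced routing: token t strongly prefers expert t mod 4.
balanced_logits = 5.0 * jax.nn.one_hot(jnp.arange(n_tokens) % n_experts, n_experts)
# Collapsed routing: every token prefers expert 0.
collapsed_logits = 5.0 * jax.nn.one_hot(jnp.zeros(n_tokens, dtype=jnp.int32), n_experts)

balanced = load_balancing_loss(balanced_logits)
collapsed = load_balancing_loss(collapsed_logits)
print(f"Balanced routing loss:  {float(balanced):.3f}")
print(f"Collapsed routing loss: {float(collapsed):.3f}")
```

  • The token fractions `f` come from an argmax and carry no gradient; the gradient flows through the mean probabilities `p`, which is enough to push the router away from overloaded experts.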

  • Fault tolerance is critical when training runs last weeks or months on thousands of GPUs. If a single GPU fails, you do not want to lose all progress. Checkpointing periodically saves model weights, optimizer states, and the training state (learning rate, step count, data position) to disk. If a failure occurs, you restart from the last checkpoint.

  • Gradient checkpointing (also called activation recomputation) is a memory optimisation, not a fault-tolerance mechanism. During the forward pass, instead of saving all activations for the backward pass, you only save activations at certain checkpoints. During the backward pass, you recompute the missing activations from the checkpoints. This trades compute for memory: recomputation adds roughly one extra forward pass, which is about 33% more compute per training step (since the backward pass costs about twice the forward pass), and with checkpoints placed every \(\sqrt{L}\) layers it can reduce activation memory by a factor of about \(\sqrt{L}\) (where \(L\) is the number of layers).
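
  • In JAX, activation recomputation is available through `jax.checkpoint` (also exposed as `jax.remat`). The sketch below wraps each block of a toy network and confirms that the gradients are unchanged; the network, its sizes, and the `use_remat` flag are illustrative assumptions, and the example does not measure memory, which is where the actual benefit lies.

```python
import jax
import jax.numpy as jnp

def block(x, w):
    """One toy layer whose intermediate activations would normally be saved."""
    return jnp.tanh(x @ w)

def forward(x, weights, use_remat):
    # With jax.checkpoint, the block's internals are recomputed during the
    # backward pass instead of being stored during the forward pass.
    layer = jax.checkpoint(block) if use_remat else block
    for w in weights:
        x = layer(x, w)
    return jnp.sum(x ** 2)

keys = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(keys[0], (8, 16))
weights = [jax.random.normal(k, (16, 16)) * 0.1 for k in keys[1:]]

grads_plain = jax.grad(forward, argnums=1)(x, weights, False)
grads_remat = jax.grad(forward, argnums=1)(x, weights, True)

same = all(bool(jnp.allclose(a, b, atol=1e-6)) for a, b in zip(grads_plain, grads_remat))
print("Gradients identical with and without recomputation:", same)
```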

  • Putting it all together, training a frontier model combines all of these techniques: BF16 mixed precision, data parallelism across thousands of GPUs with ring all-reduce, tensor parallelism within nodes, pipeline parallelism across nodes, gradient checkpointing to reduce memory, MoE for parameter efficiency, and regular checkpointing for fault tolerance. The systems engineering is as challenging as the algorithm design.

  • To summarise the distributed training toolkit:

| Technique | What It Does | Tradeoff |
| --- | --- | --- |
| Mixed precision (BF16) | Halves memory for activations/grads | Slight numerical differences |
| Data parallelism | Scales batch size across GPUs | Communication overhead for gradient sync |
| Tensor parallelism | Splits layers across GPUs | Requires fast interconnect |
| Pipeline parallelism | Splits model stages across GPUs | Pipeline bubble (wasted compute) |
| Gradient accumulation | Simulates large batches | Slower (multiple forward/backward passes) |
| Gradient checkpointing | Reduces activation memory | ~33% more compute |
| Ring all-reduce | Efficient gradient averaging | Bandwidth-limited for large models |
| MoE | More capacity, same FLOPs | Load balancing, routing complexity |
| Scaling laws | Guides compute allocation | Empirical, may not hold at all scales |

Coding Tasks (use Colab or a notebook)

  1. Compute the FLOPs and memory requirements for a transformer layer. Given hidden dimension \(d\), sequence length \(n\), batch size \(B\), and number of layers, estimate the total training cost.

    import jax.numpy as jnp
    
    def transformer_layer_flops(d, n, B):
        """Approximate FLOPs for one transformer layer forward pass."""
        # QKV projections: 3 * (B * n * d * d) * 2 (multiply-add)
        qkv_flops = 3 * 2 * B * n * d * d
        # Attention: (B * n * n * d) * 2 for QK^T, (B * n * n * d) * 2 for attn*V
        attn_flops = 2 * 2 * B * n * n * d
        # Output projection: (B * n * d * d) * 2
        out_flops = 2 * B * n * d * d
        # FFN: two layers, d->4d and 4d->d: 2 * (B * n * d * 4d) * 2
        ffn_flops = 2 * 2 * B * n * d * 4 * d
        return qkv_flops + attn_flops + out_flops + ffn_flops
    
    def transformer_layer_memory(d, n, B, dtype_bytes=2):
        """Approximate activation memory (bytes) for one layer."""
        # QKV: 3 * B * n * d
        qkv_mem = 3 * B * n * d * dtype_bytes
        # Attention weights: B * heads * n * n (approx B * n * n * sizeof)
        attn_mem = B * n * n * dtype_bytes
        # FFN intermediate: B * n * 4d
        ffn_mem = B * n * 4 * d * dtype_bytes
        return qkv_mem + attn_mem + ffn_mem
    
    # Example: GPT-2 scale
    d, n, B, L = 1024, 1024, 8, 24
    fwd_flops = transformer_layer_flops(d, n, B)
    total_flops = 3 * L * fwd_flops  # 3x for forward + backward
    act_mem = L * transformer_layer_memory(d, n, B)
    param_count = L * (12 * d * d + 13 * d)  # approximate
    
    print(f"Model: d={d}, n={n}, B={B}, L={L}")
    print(f"Parameters: {param_count / 1e6:.0f}M")
    print(f"FLOPs per step: {total_flops / 1e12:.2f} TFLOPs")
    print(f"Activation memory: {act_mem / 1e9:.2f} GB (BF16)")
    print(f"Parameter memory (FP32): {param_count * 4 / 1e9:.2f} GB")
    print(f"Adam optimizer memory: {param_count * 8 / 1e9:.2f} GB")
    print(f"Total training memory: {(param_count * 16 + act_mem) / 1e9:.2f} GB")
    

  2. Simulate data-parallel training. Split a dataset across multiple "virtual GPUs," compute gradients independently, average them, and verify the result matches single-GPU training.

    import jax
    import jax.numpy as jnp
    
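    # Simple linear model: y = Xw (plus noise)
    key = jax.random.PRNGKey(0)
    key_x, key_noise = jax.random.split(key)  # separate keys so X and the noise are independent
    X = jax.random.normal(key_x, (64, 4))
    w_true = jnp.array([1.0, -2.0, 3.0, 0.5])
    y = X @ w_true + 0.1 * jax.random.normal(key_noise, (64,))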
    
    def loss_fn(w, X, y):
        return jnp.mean((X @ w - y) ** 2)
    
    grad_fn = jax.grad(loss_fn)
    
    # Single GPU: full batch gradient
    w = jnp.zeros(4)
    grad_single = grad_fn(w, X, y)
    
    # Data parallel: split across 4 "GPUs"
    n_gpus = 4
    chunk_size = len(X) // n_gpus
    grads = []
    for i in range(n_gpus):
        X_chunk = X[i*chunk_size:(i+1)*chunk_size]
        y_chunk = y[i*chunk_size:(i+1)*chunk_size]
        grads.append(grad_fn(w, X_chunk, y_chunk))
    
    # All-reduce: average gradients
    grad_parallel = jnp.mean(jnp.stack(grads), axis=0)
    
    print("Single-GPU gradient:", grad_single)
    print("Data-parallel gradient (avg):", grad_parallel)
    print(f"Match: {jnp.allclose(grad_single, grad_parallel, atol=1e-5)}")
    
    # Train both and compare
    w_single, w_parallel = jnp.zeros(4), jnp.zeros(4)
    lr = 0.1
    for step in range(100):
        w_single = w_single - lr * grad_fn(w_single, X, y)
    
        grads = [grad_fn(w_parallel, X[i*chunk_size:(i+1)*chunk_size],
                         y[i*chunk_size:(i+1)*chunk_size]) for i in range(n_gpus)]
        avg_grad = jnp.mean(jnp.stack(grads), axis=0)
        w_parallel = w_parallel - lr * avg_grad
    
    print(f"\nAfter 100 steps:")
    print(f"Single-GPU weights: {w_single}")
    print(f"Data-parallel weights: {w_parallel}")
    print(f"Max difference: {jnp.max(jnp.abs(w_single - w_parallel)):.2e}")
    

  3. Implement a simple Mixture of Experts layer. Create a gating network that routes tokens to top-K experts and combine their outputs.

    import jax
    import jax.numpy as jnp
    
    def expert_fn(x, W1, b1, W2, b2):
        """Simple 2-layer FFN expert."""
        h = jnp.maximum(0, x @ W1 + b1)  # ReLU
        return h @ W2 + b2
    
    def moe_layer(x, gate_W, experts_params, top_k=2):
        """
        MoE forward pass.
        x: (batch, d_model)
        gate_W: (d_model, n_experts)
        experts_params: list of (W1, b1, W2, b2) per expert
        """
        n_experts = len(experts_params)
    
        # Gating: compute routing scores
        gate_logits = x @ gate_W  # (batch, n_experts)
        gate_probs = jax.nn.softmax(gate_logits, axis=-1)
    
        # Top-K selection
        top_k_indices = jnp.argsort(-gate_probs, axis=-1)[:, :top_k]
        top_k_probs = jnp.take_along_axis(gate_probs, top_k_indices, axis=-1)
        # Renormalise
        top_k_probs = top_k_probs / jnp.sum(top_k_probs, axis=-1, keepdims=True)
    
        # Compute expert outputs (simplified: run all experts, mask later)
        expert_outputs = jnp.stack([
            expert_fn(x, *experts_params[i]) for i in range(n_experts)
        ], axis=1)  # (batch, n_experts, d_model)
    
        # Gather top-K expert outputs and weight them
        batch_idx = jnp.arange(x.shape[0])[:, None]
        selected_outputs = expert_outputs[batch_idx, top_k_indices]  # (batch, top_k, d_model)
        output = jnp.sum(selected_outputs * top_k_probs[:, :, None], axis=1)
    
        return output, gate_probs
    
    # Setup
    key = jax.random.PRNGKey(42)
    batch, d_model, d_ff, n_experts = 8, 16, 32, 4
    
    # Initialise experts
    experts_params = []
    for i in range(n_experts):
        k1, k2, key = jax.random.split(key, 3)
        experts_params.append((
            jax.random.normal(k1, (d_model, d_ff)) * 0.1,
            jnp.zeros(d_ff),
            jax.random.normal(k2, (d_ff, d_model)) * 0.1,
            jnp.zeros(d_model),
        ))
    
    key, subkey = jax.random.split(key)
    gate_W = jax.random.normal(subkey, (d_model, n_experts)) * 0.1
    x = jax.random.normal(key, (batch, d_model))
    
    output, gate_probs = moe_layer(x, gate_W, experts_params, top_k=2)
    
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Gate probabilities (first sample): {gate_probs[0]}")
    print("Mean gate probability per expert (avg across batch):")
    for i in range(n_experts):
        usage = jnp.mean(gate_probs[:, i])
        print(f"  Expert {i}: {usage:.3f}")