
Why C++ and How ML Frameworks Work

Every jnp.matmul, every torch.nn.Linear, every np.dot call in this book has been executing compiled C, C++, or CUDA code underneath. This file pulls back the curtain: why ML frameworks are built this way, quick C++ fundamentals for Python engineers, when to write custom C++ kernels, and how to bind them into Python. Together, these layers form the bridge between the code you write and the hardware it runs on.

  • You have spent 15 chapters writing Python. You imported JAX, called jax.grad, ran training loops, and built models. It all felt like Python. But here is the truth: almost none of the actual computation happened in Python.

  • When you write output = model(input) in PyTorch or output = jnp.matmul(W, x) in JAX, Python does almost none of the arithmetic. In eager mode, each call is dispatched straight to a precompiled C++/CUDA kernel. Under a compiler such as jax.jit or torch.compile, Python first builds a description of the computation (a graph of operations) and hands the whole graph to the backend. Either way, Python is the steering wheel; C++ is the engine.

Why Python Frontend, C++ Backend

  • This two-language architecture exists because Python and C++ are good at opposite things:
| | Python | C++ |
|---|---|---|
| Development speed | Fast (dynamic typing, REPL, no compilation) | Slow (static typing, headers, compile times) |
| Execution speed | Often 10-100x slower than C for loop-heavy code (interpreted; the GIL also blocks multithreaded CPU parallelism) | Near-hardware speed (compiled, no interpreter overhead) |
| Memory control | Automatic (reference counting + GC), no control over layout | Manual, precise control over every byte |
| Hardware access | None from pure Python (no SIMD, no GPU, no custom memory) | Full (intrinsics, CUDA, inline assembly) |
| Ecosystem | Rich for ML (notebooks, visualisation, data) | Rich for systems (OS, drivers, engines) |
  • The insight: use each language for what it is good at. Python handles the parts where human productivity matters (experiment design, hyperparameter tuning, data exploration). C++ handles the parts where machine performance matters (matrix multiplication, convolution, attention kernels).

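  • A single matrix multiplication jnp.matmul(A, B) where \(A\) and \(B\) are both \(4096 \times 4096\) performs \(2 \times 4096^3 \approx 137\) billion floating-point operations. In pure Python (nested loops), each inner iteration costs tens of nanoseconds of interpreter overhead, so this takes tens of minutes or more. In optimised C++ with SIMD (e.g. AVX-512) and multithreading, it takes roughly 10-100 milliseconds, depending on the CPU. That is a gap of four to five orders of magnitude. No amount of pure-Python cleverness closes it.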
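To make the FLOP count concrete, here is a minimal C++ sketch; the names matmul_flops and naive_matmul are illustrative, not from any library. The count follows from each of the \(n^2\) outputs needing \(n\) multiplies and \(n\) adds. The triple loop is the compiled equivalent of the pure-Python version: much faster than the interpreter, but still without the SIMD, cache blocking, and threading that optimised libraries add.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// FLOPs for C = A * B with A (m x k) and B (k x n):
// each of the m*n outputs needs k multiplies and k adds, so 2*m*n*k in total.
std::uint64_t matmul_flops(std::uint64_t m, std::uint64_t n, std::uint64_t k) {
    return 2 * m * n * k;
}

// Naive triple-loop matmul on row-major square matrices (n x n).
// Correct, but with no SIMD, no cache blocking and no threading.
std::vector<float> naive_matmul(const std::vector<float>& A,
                                const std::vector<float>& B,
                                std::size_t n) {
    assert(A.size() == n * n && B.size() == n * n);
    std::vector<float> C(n * n, 0.0f);
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = 0; j < n; ++j) {
            float acc = 0.0f;
            for (std::size_t p = 0; p < n; ++p) {
                acc += A[i * n + p] * B[p * n + j];  // row i of A times column j of B
            }
            C[i * n + j] = acc;
        }
    }
    return C;
}
```

For n = 4096, matmul_flops(4096, 4096, 4096) returns 137,438,953,472 (that is, \(2^{37}\)), the ~137 billion quoted above.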

How ML Frameworks Are Structured

  • Every major ML framework follows a similar layered architecture, from top to bottom:
User code (Python)
  ↓
Python API layer (torch.nn, jax.numpy, numpy)
  ↓
Dispatch / JIT compiler (torch.compile, XLA, NumPy dispatch)
  ↓
C++ kernel library (ATen/PyTorch, XLA, BLAS/LAPACK)
  ↓
Hardware-specific backends (CUDA, cuDNN, MKL, oneDNN, Metal)
  ↓
Hardware (CPU SIMD units, GPU cores, TPU MXUs)

NumPy

  • NumPy's core is written in C. When you call np.dot(A, B) on floating-point arrays, Python calls a C function that calls BLAS (Basic Linear Algebra Subprograms), typically OpenBLAS or Intel MKL depending on how NumPy was built. Optimised BLAS libraries are hand-tuned C and assembly (the original reference BLAS is Fortran) that use SIMD instructions, cache-aware memory access patterns, and multithreading. Decades of optimisation went into making matrix multiplication fast.

  • NumPy is CPU-only. It does not use GPUs. But on CPU, it is extremely fast because it delegates to the best available BLAS implementation.
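The "cache-aware memory access patterns" mentioned above can be illustrated with loop tiling. This is a sketch of the general technique under my own assumptions, not code from MKL or OpenBLAS; the function name and the default tile size of 64 are arbitrary. It computes the same result as a naive triple loop, but works through the matrices tile by tile so that values loaded into cache are reused before they are evicted.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Cache-blocked (tiled) matmul, C = A * B, all n x n and row-major.
// Teaching sketch only: real BLAS kernels add data packing,
// SIMD micro-kernels and multithreading on top of this idea.
std::vector<float> blocked_matmul(const std::vector<float>& A,
                                  const std::vector<float>& B,
                                  std::size_t n,
                                  std::size_t tile = 64) {
    assert(A.size() == n * n && B.size() == n * n && tile > 0);
    std::vector<float> C(n * n, 0.0f);
    for (std::size_t ii = 0; ii < n; ii += tile) {
        for (std::size_t pp = 0; pp < n; pp += tile) {
            for (std::size_t jj = 0; jj < n; jj += tile) {
                // Clamp tile bounds so n need not be a multiple of tile.
                const std::size_t i_end = std::min(ii + tile, n);
                const std::size_t p_end = std::min(pp + tile, n);
                const std::size_t j_end = std::min(jj + tile, n);
                for (std::size_t i = ii; i < i_end; ++i) {
                    for (std::size_t p = pp; p < p_end; ++p) {
                        const float a = A[i * n + p];  // reused across the whole j loop
                        for (std::size_t j = jj; j < j_end; ++j) {
                            C[i * n + j] += a * B[p * n + j];  // stride-1 over B and C
                        }
                    }
                }
            }
        }
    }
    return C;
}
```

Production BLAS goes much further than this, which is why delegating to it beats hand-rolled loops.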

PyTorch

  • PyTorch's computation engine is ATen (A Tensor Library), written in C++. ATen implements roughly 2000 tensor operations (add, matmul, conv2d, softmax, ...), most with both CPU and CUDA kernels.

  • When you call torch.matmul(A, B):

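    1. The Python binding calls into ATen's C++ dispatcher.
    2. The dispatcher selects a kernel based on the tensor's device (CPU or CUDA) and dtype, along with other keys such as whether autograd is tracking the tensors.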
    3. On CPU: calls MKL/OpenBLAS. On GPU: calls cuBLAS (NVIDIA's GPU-optimised BLAS).
    4. The result is wrapped in a Python tensor object and returned.
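  • torch.compile (PyTorch 2.0+) takes this further: it captures your Python code as a computation graph (via TorchDynamo) and compiles it with the TorchInductor backend, which generates Triton kernels for GPU and C++/OpenMP code for CPU. The compiled code fuses operations and removes per-operation Python overhead. Speedups over eager mode vary by model and tend to be largest for workloads dominated by small, memory-bound operations.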
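Steps 2 and 3 of the dispatch sequence above can be pictured as a lookup table from (device, dtype) to a kernel. This is a deliberately tiny sketch: ToyTensor, ToyDispatcher, and the kernel labels are hypothetical and are not ATen's real API, and PyTorch's actual dispatcher handles many more keys. The registered kernels here just return a label naming the backend that would run.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical toy tensor: only the two keys this dispatcher looks at.
struct ToyTensor {
    std::string device;  // e.g. "cpu" or "cuda"
    std::string dtype;   // e.g. "float32" or "float64"
};

using Kernel = std::function<std::string(const ToyTensor&, const ToyTensor&)>;

// Hypothetical registry keyed by (device, dtype). NOT ATen's API.
class ToyDispatcher {
public:
    void register_kernel(const std::string& device, const std::string& dtype, Kernel k) {
        table_[{device, dtype}] = std::move(k);
    }

    std::string matmul(const ToyTensor& a, const ToyTensor& b) const {
        if (a.device != b.device || a.dtype != b.dtype) {
            throw std::invalid_argument("device/dtype mismatch");
        }
        // Select the kernel registered for this device and dtype.
        auto it = table_.find({a.device, a.dtype});
        if (it == table_.end()) {
            throw std::runtime_error("no kernel for " + a.device + "/" + a.dtype);
        }
        return it->second(a, b);  // call the backend-specific kernel
    }

private:
    std::map<std::pair<std::string, std::string>, Kernel> table_;
};
```

In this design, adding a backend means registering one more table entry; callers of matmul never change.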

JAX

  • JAX compiles Python functions to XLA (Accelerated Linear Algebra), Google's compiler for ML workloads. When you jax.jit a function:

    1. JAX traces the function, recording its operations as an intermediate representation (a jaxpr), then lowers it to an XLA computation graph in HLO (High Level Operations) form.
    2. XLA optimises the graph: fuses operations, eliminates redundant computation, optimises memory layout.
    3. XLA compiles to the target backend: CPU (via LLVM), GPU (via LLVM to PTX, plus library calls such as cuBLAS), or TPU (via TPU-specific instructions).
    4. The compiled executable runs on the hardware with no Python involvement inside the function; Python only launches it and receives the results.
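  • This is why jax.jit is so important: without it, every operation is dispatched separately, with its own Python→C++ round trip and its own kernel launch. With it, the whole function becomes one compiled XLA executable, in which XLA can fuse many operations into fewer kernels.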

Quick C++ Fundamentals for Python Engineers

  • You do not need to become a C++ expert. You need to understand enough to read kernel code, write simple extensions, and understand performance discussions. Here are the essentials.

Types and Variables

// C++ requires explicit types (unlike Python)
int count = 0;           // integer (32-bit on mainstream platforms)
float loss = 0.5f;       // 32-bit float
double lr = 3e-4;        // 64-bit float
bool training = true;    // boolean

// Arrays (fixed size, stack-allocated)
float weights[1024];     // 1024 floats, contiguous in memory

// Pointers: a variable that holds a memory address
float* ptr = weights;    // ptr points to the first element of weights
float val = ptr[42];     // access element 42 via pointer arithmetic
// ptr[42] is equivalent to *(ptr + 42)
  • Pointers are the biggest conceptual difference from Python. In Python, everything is a reference and you never think about memory addresses. In C++, pointers give you direct access to memory — powerful but dangerous (dangling pointers, buffer overflows).
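A small self-contained sketch of the equivalence ptr[i] == *(ptr + i), and of the fact that pointer arithmetic moves in units of the element type rather than bytes. The function names are illustrative.

```cpp
#include <cassert>
#include <cstddef>

// ptr[i] is defined as *(ptr + i); both read the same element.
float read_via_index(const float* ptr, std::size_t i) { return ptr[i]; }
float read_via_offset(const float* ptr, std::size_t i) { return *(ptr + i); }

// Byte distance between element i and element 0: (base + i) advances
// by i * sizeof(float) bytes, not by i bytes.
std::ptrdiff_t byte_offset(const float* base, std::size_t i) {
    const char* start = reinterpret_cast<const char*>(base);
    const char* elem = reinterpret_cast<const char*>(base + i);
    return elem - start;
}
```

With float weights[1024], byte_offset(weights, 42) is 42 * sizeof(float), i.e. 168 bytes on platforms with 4-byte floats.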

Functions

#include <cstddef>  // size_t
#include <vector>   // std::vector

// Function declaration: return_type name(param_type param_name)
float relu(float x) {
    return x > 0.0f ? x : 0.0f;
}

// Passing by reference (avoids copying large objects)
void scale_vector(std::vector<float>& vec, float factor) {
    for (size_t i = 0; i < vec.size(); i++) {
        vec[i] *= factor;
    }
}

// const reference: read-only, no copy
float sum(const std::vector<float>& vec) {
    float total = 0.0f;
    for (float x : vec) {  // range-based for loop (like Python's for x in vec)
        total += x;
    }
    return total;
}
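To see why the & matters, the sketch below contrasts pass-by-value with pass-by-reference. The functions are illustrative variants of scale_vector and sum above, renamed so they can sit side by side: the by-value version modifies only its own copy, so the caller's data is unchanged.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Pass by value: the function receives a full copy, so the caller's vector
// is untouched (and copying a large buffer costs time and memory).
void scale_copy(std::vector<float> vec, float factor) {
    for (std::size_t i = 0; i < vec.size(); i++) {
        vec[i] *= factor;  // modifies the local copy only
    }
}

// Pass by reference: the function works on the caller's vector directly.
void scale_in_place(std::vector<float>& vec, float factor) {
    for (float& x : vec) {  // float& so each write lands in vec itself
        x *= factor;
    }
}

// const reference: no copy, and the compiler forbids modification.
float sum_of(const std::vector<float>& vec) {
    float total = 0.0f;
    for (float x : vec) {
        total += x;
    }
    return total;
}
```

Calling scale_copy(v, 10.0f) leaves v as it was; scale_in_place(v, 10.0f) changes it.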

Memory: Stack vs Heap

#include <cstddef>
#include <memory>

void example(std::size_t n) {
    // Stack allocation: fast, automatic lifetime (freed when function returns)
    float buffer[256];   // 256 floats on the stack

    // Heap allocation: manual, can outlive the function if the pointer is returned
    float* data = new float[n];   // allocate n floats on the heap
    // ... use data ...
    delete[] data;                // YOU must free it (no garbage collector)

    // Modern C++: smart pointers (automatic cleanup). std::unique_ptr frees
    // when it goes out of scope; std::shared_ptr reference-counts, like Python.
    auto smart_data = std::make_unique<float[]>(n);  // freed automatically at scope exit
}
  • The key rule: stack is fast but limited (typically 1-8 MB). Large arrays (tensors, feature maps) must go on the heap. In Python, everything is on the heap and the GC handles cleanup. In C++, you manage it yourself (or use smart pointers).
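The automatic cleanup of smart pointers can be observed directly. In this sketch, TrackedBuffer is a hypothetical helper that owns its heap array through std::unique_ptr and counts live instances, so you can check that memory is released when the object goes out of scope, with no delete[] written anywhere. This scope-based ownership pattern is known as RAII (Resource Acquisition Is Initialization).

```cpp
#include <cassert>
#include <cstddef>
#include <memory>

// Hypothetical helper: owns a heap array and counts live instances.
struct TrackedBuffer {
    static int live;                 // number of TrackedBuffer objects alive
    std::unique_ptr<float[]> data;   // owns the heap allocation
    std::size_t size;

    explicit TrackedBuffer(std::size_t n)
        : data(std::make_unique<float[]>(n)), size(n) { ++live; }
    ~TrackedBuffer() { --live; }     // then data's destructor frees the array
};
int TrackedBuffer::live = 0;

// Allocates on the heap inside a scope; returns how many buffers were alive.
int use_buffer_in_scope(std::size_t n) {
    TrackedBuffer buf(n);   // make_unique<float[]> value-initialises to 0.0f
    buf.data[0] = 1.0f;
    return TrackedBuffer::live;
}   // buf goes out of scope here: destructor runs, memory is freed
```

After use_buffer_in_scope returns, TrackedBuffer::live is back to 0: the heap memory was released without any explicit delete[].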

Templates (Generics)

// A function that works with any numeric type
template <typename T>
T add(T a, T b) {
    return a + b;
}

add<float>(1.5f, 2.5f);   // returns 4.0f
add<int>(3, 4);             // returns 7
  • Templates are how C++ libraries (like ATen) write code that works with float16, float32, float64, etc. without duplicating the implementation.
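As a miniature of that idea, here is a templated element-wise ReLU written once and instantiated for several types. This is an illustrative sketch, not ATen code. ATen additionally maps a runtime dtype to the right template instantiation, which this example leaves out.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One templated element-wise kernel; the compiler generates a separate,
// fully typed copy for each T it is used with (float, double, int, ...).
template <typename T>
void relu_kernel(const T* in, T* out, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        out[i] = in[i] > T(0) ? in[i] : T(0);
    }
}

// Convenience wrapper over std::vector<T>; T is deduced from the argument.
template <typename T>
std::vector<T> relu(const std::vector<T>& x) {
    std::vector<T> y(x.size());
    relu_kernel(x.data(), y.data(), x.size());
    return y;
}
```

relu(std::vector<float>{...}) and relu(std::vector<double>{...}) compile to two distinct functions from one source definition.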

The Standard Library Essentials

#include <vector>      // dynamic array (like Python list)
#include <string>      // string type
#include <unordered_map>  // hash map (like Python dict)
#include <algorithm>   // sort, find, transform, etc.
#include <cmath>       // math functions

std::vector<float> vec = {1.0f, 2.0f, 3.0f};
vec.push_back(4.0f);            // append
float first = vec[0];           // index
size_t len = vec.size();        // length

std::unordered_map<std::string, int> counts;
counts["hello"] = 5;            // insert
if (counts.count("hello")) { }  // check existence
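Putting several of these pieces together, here is a sketch of token counting, the C++ analogue of Python's collections.Counter, using std::unordered_map for the counts and std::sort with a lambda for the ordering. The function name count_tokens is illustrative.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Count tokens and return them sorted by descending count,
// with ties broken alphabetically so the order is deterministic.
std::vector<std::pair<std::string, int>>
count_tokens(const std::vector<std::string>& tokens) {
    std::unordered_map<std::string, int> counts;
    for (const auto& t : tokens) {
        ++counts[t];  // operator[] inserts 0 for a new key, like defaultdict(int)
    }
    // Copy into a vector, since unordered_map has no defined order.
    std::vector<std::pair<std::string, int>> items(counts.begin(), counts.end());
    std::sort(items.begin(), items.end(), [](const auto& a, const auto& b) {
        if (a.second != b.second) return a.second > b.second;
        return a.first < b.first;
    });
    return items;
}
```

count_tokens({"the", "cat", "the", "mat", "the", "cat"}) yields the pairs (the, 3), (cat, 2), (mat, 1).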

When to Write Custom C++ Kernels

  • Most ML engineers never need to write C++. The framework's built-in operations cover the vast majority of use cases. You should consider custom C++ only when:

  • Your operation does not exist in the framework: a novel activation function, a custom attention pattern, a specialised loss function that cannot be expressed as a composition of existing ops.

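  • Fusing operations for performance: your model does relu(layernorm(matmul(x, W) + b)). Run eagerly, each operation launches a separate kernel that reads its inputs from memory and writes its result back. A fused kernel does it all in one pass, keeping intermediates in registers or on-chip memory and avoiding those memory round-trips. Because the steps after the matmul are memory-bound, fusing them can be 2-5x faster.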

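  • Reducing memory usage: a custom kernel can recompute intermediates during the backward pass instead of storing all of them (gradient checkpointing at the kernel level). FlashAttention, for example, recomputes attention scores in the backward pass rather than storing the full attention matrix.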

  • Targeting novel hardware: a new accelerator (e.g., Cerebras, Groq) may not have framework support. You write kernels directly.

  • For the first two cases, Triton (chapter 16, file 05) is often sufficient and much easier than writing CUDA C directly. Only drop to CUDA C when Triton cannot express what you need.
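The fusion argument can be seen in a CPU sketch. For simplicity it fuses only a bias add and ReLU and assumes a scalar bias; the function names are illustrative. The unfused loops read or write four floats per element (x, t, t, y), while the fused loop touches two (x, y), which matters once the arrays are too large to stay in cache.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Unfused: y = relu(x + b) as two separate passes ("kernels").
// The temporary t is written to memory by pass 1 and read back by pass 2.
std::vector<float> bias_relu_unfused(const std::vector<float>& x, float b) {
    std::vector<float> t(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        t[i] = x[i] + b;                       // pass 1: add bias
    }
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < t.size(); ++i) {
        y[i] = t[i] > 0.0f ? t[i] : 0.0f;      // pass 2: relu
    }
    return y;
}

// Fused: one pass; the intermediate lives in a register, never in memory.
std::vector<float> bias_relu_fused(const std::vector<float>& x, float b) {
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        const float t = x[i] + b;
        y[i] = t > 0.0f ? t : 0.0f;
    }
    return y;
}
```

Both versions produce identical results. GPU kernel fusion applies the same idea to global-memory traffic.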

How to Bind C++ to Python

  • Writing C++ is half the job. You also need to call it from Python.

pybind11 (General Purpose)

  • pybind11 creates Python bindings for C++ functions with minimal boilerplate:
// my_ops.cpp
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
namespace py = pybind11;

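// A simple custom operation.
// c_style | forcecast makes pybind11 pass a C-contiguous float32 array,
// converting (copying) the input if needed, so flat indexing below is safe.
py::array_t<float> custom_relu(
        py::array_t<float, py::array::c_style | py::array::forcecast> input) {
    auto buf = input.request();
    const float* ptr = static_cast<const float*>(buf.ptr);
    py::ssize_t n = buf.size;

    auto result = py::array_t<float>(buf.shape);  // output keeps the input's shape
    float* out = static_cast<float*>(result.request().ptr);

    for (py::ssize_t i = 0; i < n; i++) {
        out[i] = ptr[i] > 0.0f ? ptr[i] : 0.0f;
    }
    return result;
}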

PYBIND11_MODULE(my_ops, m) {
    m.def("custom_relu", &custom_relu, "Custom ReLU operation");
}
# Compile
pip install pybind11
c++ -O3 -shared -std=c++17 -fPIC $(python3 -m pybind11 --includes) my_ops.cpp -o my_ops$(python3-config --extension-suffix)
# Use from Python
import my_ops
import numpy as np

x = np.array([-1.0, 2.0, -3.0, 4.0], dtype=np.float32)
y = my_ops.custom_relu(x)
print(y)  # [0. 2. 0. 4.]

PyTorch C++ Extensions

  • PyTorch provides a streamlined way to add custom ops:
// custom_op.cpp
#include <torch/extension.h>
#include <cmath>  // std::sqrt

torch::Tensor custom_gelu(torch::Tensor x) {
    return x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("custom_gelu", &custom_gelu, "Custom GELU activation");
}
# Load and compile on-the-fly (needs a C++ compiler and ninja installed)
import torch
from torch.utils.cpp_extension import load

custom_ops = load(
    name="custom_ops",
    sources=["custom_op.cpp"],
    extra_cflags=["-O3"],
)

x = torch.randn(1000)
y = custom_ops.custom_gelu(x)
  • torch.utils.cpp_extension.load compiles the C++ code, creates a shared library, and loads it as a Python module, all in one call. This is the easiest way to experiment with custom C++ ops in PyTorch.

JAX Custom Calls

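  • JAX uses XLA custom calls, which recent JAX releases expose through the XLA FFI (the jax.ffi module). The process is more involved (you register a C++ handler with XLA and declare its output shapes and dtypes from Python), but the concept is the same: write C/C++, bind it, call it from Python.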

  • For most JAX users, Pallas (covered in file 05) is the better choice: it lets you write GPU and TPU kernels in Python using JAX-style array operations, which are lowered through Triton or Mosaic, without leaving the JAX ecosystem.

The Big Picture

  • This file explained the layer between Python and hardware. The remaining files in this chapter go deeper:

    • File 01: the hardware itself (CPU architecture, GPU architecture, memory systems)
    • Files 02-03: SIMD programming on CPU (ARM NEON, x86 AVX) — where you write C++ that uses the CPU's vector units
    • File 04: GPU programming with CUDA — where you write C++ that runs on thousands of GPU cores
    • File 05: Triton, Pallas, and higher-level GPU programming — where you write Python that compiles to GPU kernels
  • The progression mirrors the abstraction ladder: C++ intrinsics (lowest, most control) → CUDA (GPU-specific) → Triton/Pallas (Pythonic, compiled) → JAX/PyTorch (highest, automatic). Each level trades control for convenience. Understanding the lower levels makes you a better user of the higher ones.

Coding Tasks (compile with g++ or clang++)

  1. Write your first C++ program. Allocate an array, fill it, compute the sum, and measure the time. This introduces compilation, std::vector, and timing with std::chrono.

    // task1_basics.cpp
    // Compile: g++ -O3 -o task1 task1_basics.cpp
    // Run: ./task1
    
    #include <iostream>
    #include <chrono>
    #include <vector>
    
    int main() {
        const int N = 10'000'000;  // C++ allows ' as digit separator
        std::vector<float> data(N);
    
        // Fill the array
        for (int i = 0; i < N; i++) {
            data[i] = static_cast<float>(i) * 0.001f;
        }
    
        // Compute sum
        auto start = std::chrono::high_resolution_clock::now();
        double sum = 0.0;  // double accumulator: a float loses precision over 10M additions
        for (int i = 0; i < N; i++) {
            sum += data[i];
        }
        auto end = std::chrono::high_resolution_clock::now();
        double elapsed = std::chrono::duration<double, std::milli>(end - start).count();
    
        std::cout << "Sum: " << sum << std::endl;
        std::cout << "Time: " << elapsed << " ms" << std::endl;
        std::cout << "Elements: " << N << std::endl;
        std::cout << "Throughput: " << (N * sizeof(float)) / elapsed / 1e6 << " GB/s" << std::endl;
    
        return 0;
    }
    

  2. Write a C++ function that computes ReLU on an array, then build a Python binding using pybind11. Call it from Python and compare speed against NumPy.

    // task2_relu.cpp
    // Compile: c++ -O3 -shared -std=c++17 -fPIC $(python3 -m pybind11 --includes) \
    //          task2_relu.cpp -o my_relu$(python3-config --extension-suffix)
    
    #include <pybind11/pybind11.h>
    #include <pybind11/numpy.h>
    namespace py = pybind11;
    
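    // c_style | forcecast guarantees a contiguous float32 buffer (copying if needed)
    py::array_t<float> cpp_relu(
            py::array_t<float, py::array::c_style | py::array::forcecast> input) {
        auto buf = input.request();
        const float* ptr = static_cast<const float*>(buf.ptr);
        py::ssize_t n = buf.size;  // py::ssize_t avoids narrowing to int
    
        auto result = py::array_t<float>(buf.shape);  // same shape as the input
        float* out = static_cast<float*>(result.request().ptr);
    
        for (py::ssize_t i = 0; i < n; i++) {
            out[i] = ptr[i] > 0.0f ? ptr[i] : 0.0f;
        }
        return result;
    }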
    
    PYBIND11_MODULE(my_relu, m) {
        m.def("relu", &cpp_relu, "C++ ReLU");
    }
    
    # test_relu.py — run after compiling the C++ module above
    import numpy as np
    import time
    import my_relu  # the compiled C++ module
    
    x = np.random.randn(10_000_000).astype(np.float32)
    
    # C++ ReLU (perf_counter is a monotonic, high-resolution timer)
    start = time.perf_counter()
    for _ in range(100):
        y_cpp = my_relu.relu(x)
    cpp_time = (time.perf_counter() - start) / 100
    
    # NumPy ReLU
    start = time.perf_counter()
    for _ in range(100):
        y_np = np.maximum(x, 0)
    np_time = (time.perf_counter() - start) / 100
    
    print(f"C++ ReLU:   {cpp_time*1000:.2f} ms")
    print(f"NumPy ReLU: {np_time*1000:.2f} ms")
    print(f"Match: {np.allclose(y_cpp, y_np)}")
    

  3. Write a C++ program that demonstrates why memory layout matters. Traverse the same row-major matrix row by row (stride-1) and column by column (stride-N), and measure the performance difference.

    // task3_layout.cpp
    // Compile: g++ -O3 -o task3 task3_layout.cpp
    
    #include <iostream>
    #include <chrono>
    #include <vector>
    
    int main() {
        const int N = 4096;
        std::vector<float> matrix(N * N, 1.0f);
    
        // Row-major access: sequential memory addresses (cache-friendly)
        auto start = std::chrono::high_resolution_clock::now();
        float sum_row = 0.0f;
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
                sum_row += matrix[i * N + j];  // stride-1 access
            }
        }
        auto end = std::chrono::high_resolution_clock::now();
        double row_ms = std::chrono::duration<double, std::milli>(end - start).count();
    
        // Column-order traversal of the same row-major matrix: stride-N access (cache-unfriendly)
        start = std::chrono::high_resolution_clock::now();
        float sum_col = 0.0f;
        for (int j = 0; j < N; j++) {
            for (int i = 0; i < N; i++) {
                sum_col += matrix[i * N + j];  // stride-N access (cache misses!)
            }
        }
        end = std::chrono::high_resolution_clock::now();
        double col_ms = std::chrono::duration<double, std::milli>(end - start).count();
    
        std::cout << "Row-major (cache-friendly): " << row_ms << " ms" << std::endl;
        std::cout << "Col-major (cache-hostile):  " << col_ms << " ms" << std::endl;
        std::cout << "Slowdown: " << col_ms / row_ms << "x" << std::endl;
        std::cout << "(Both sums: " << sum_row << ", " << sum_col << ")" << std::endl;
    
        return 0;
    }