Why C++ and How ML Frameworks Work¶
Every `jnp.matmul`, every `torch.nn.Linear`, every `np.dot` call in this book has been executing C++ and CUDA code underneath. This file pulls back the curtain: why ML frameworks are built this way, quick C++ fundamentals for Python engineers, when to write custom C++ kernels, and how to bind them into Python. It is the bridge between the code you write and the hardware it runs on.
You have spent 15 chapters writing Python. You imported JAX, called `jax.grad`, ran training loops, and built models. It all felt like Python. But here is the truth: almost none of the actual computation happened in Python.
When you write `output = model(input)` in PyTorch or `output = jnp.matmul(W, x)` in JAX, Python does almost nothing. It constructs a description of the computation (a graph of operations), then hands it off to a C++/CUDA backend that does the real work. Python is the steering wheel; C++ is the engine.
Why Python Frontend, C++ Backend¶
This two-language architecture exists because Python and C++ are good at opposite things:
| | Python | C++ |
|---|---|---|
| Development speed | Fast (dynamic typing, REPL, no compilation) | Slow (static typing, headers, compile times) |
| Execution speed | ~100x slower than C (interpreted, GIL) | Near-hardware speed (compiled, no overhead) |
| Memory control | Automatic (GC), no control over layout | Manual, precise control over every byte |
| Hardware access | None (no SIMD, no GPU, no custom memory) | Full (intrinsics, CUDA, inline assembly) |
| Ecosystem | Rich for ML (notebooks, visualisation, data) | Rich for systems (OS, drivers, engines) |
The insight: use each language for what it is good at. Python handles the parts where human productivity matters (experiment design, hyperparameter tuning, data exploration). C++ handles the parts where machine performance matters (matrix multiplication, convolution, attention kernels).
A single matrix multiplication `jnp.matmul(A, B)` where \(A\) and \(B\) are \(4096 \times 4096\) performs ~137 billion floating-point operations. In pure Python (nested loops), this takes roughly 30 minutes. In optimised C++ with AVX-512 SIMD and multithreading, it takes roughly 10 milliseconds: a ~180,000x difference. No amount of Python cleverness closes this gap.
How ML Frameworks Are Structured¶
Every major ML framework follows the same architecture:
User code (Python)
↓
Python API layer (torch.nn, jax.numpy, numpy)
↓
Dispatch / JIT compiler (torch.compile, XLA, NumPy dispatch)
↓
C++ kernel library (ATen/PyTorch, XLA, BLAS/LAPACK)
↓
Hardware-specific backends (CUDA, cuDNN, MKL, oneDNN, Metal)
↓
Hardware (CPU SIMD units, GPU cores, TPU MXUs)
NumPy¶
NumPy's core is written in C. When you call `np.dot(A, B)`, Python calls a C function that calls BLAS (Basic Linear Algebra Subprograms), typically Intel MKL or OpenBLAS. BLAS is hand-optimised C and Fortran code that uses SIMD instructions, cache-aware memory access patterns, and multithreading. Decades of optimisation went into making matrix multiplication fast.

NumPy is CPU-only: it does not use GPUs. But on CPU it is extremely fast, because it delegates to the best available BLAS implementation.
PyTorch¶
PyTorch's computation engine is ATen (A Tensor Library), written in C++. ATen implements ~2000 tensor operations (add, matmul, conv2d, softmax, ...), each with CPU and CUDA backends.
When you call `torch.matmul(A, B)`:

- Python dispatches to the ATen C++ function.
- ATen checks the device (CPU or CUDA) and dtype.
- On CPU, it calls MKL/OpenBLAS; on GPU, cuBLAS (NVIDIA's GPU-optimised BLAS).
- The result is wrapped in a Python tensor object and returned.
`torch.compile` (PyTorch 2.0+) takes this further: it traces your Python code, builds a computation graph, and compiles it using Triton (for GPU) or C++/OpenMP (for CPU). The compiled code fuses operations, eliminates Python overhead, and can be 2-5x faster than eager mode.
JAX¶
JAX compiles Python functions to XLA (Accelerated Linear Algebra), Google's compiler for ML workloads. When you `jax.jit` a function:

- JAX traces the function, capturing the operations as an XLA computation graph (HLO, High Level Operations).
- XLA optimises the graph: fuses operations, eliminates redundant computation, optimises memory layout.
- XLA compiles to the target backend: CPU (via LLVM), GPU (via CUDA/PTX), or TPU (via TPU-specific instructions).
- The compiled code runs directly on hardware with no Python involvement.
This is why `jax.jit` is so important: without it, every operation is a separate Python→C++ round trip. With it, the entire function is a single compiled kernel.
Quick C++ Fundamentals for Python Engineers¶
You do not need to become a C++ expert. You need to understand enough to read kernel code, write simple extensions, and follow performance discussions. Here are the essentials.
Types and Variables¶
// C++ requires explicit types (unlike Python)
int count = 0; // 32-bit integer
float loss = 0.5f; // 32-bit float
double lr = 3e-4; // 64-bit float
bool training = true; // boolean
// Arrays (fixed size, stack-allocated)
float weights[1024]; // 1024 floats, contiguous in memory
// Pointers: a variable that holds a memory address
float* ptr = weights; // ptr points to the first element of weights
float val = ptr[42]; // access element 42 via pointer arithmetic
// ptr[42] is equivalent to *(ptr + 42)
Pointers are the biggest conceptual difference from Python. In Python, everything is a reference and you never think about memory addresses. In C++, pointers give you direct access to memory — powerful but dangerous (dangling pointers, buffer overflows).
Functions¶
#include <vector>  // needed for std::vector below

// Function declaration: return_type name(param_type param_name)
float relu(float x) {
return x > 0.0f ? x : 0.0f;
}
// Passing by reference (avoids copying large objects)
void scale_vector(std::vector<float>& vec, float factor) {
for (size_t i = 0; i < vec.size(); i++) {
vec[i] *= factor;
}
}
// const reference: read-only, no copy
float sum(const std::vector<float>& vec) {
float total = 0.0f;
for (float x : vec) { // range-based for loop (like Python's for x in vec)
total += x;
}
return total;
}
Memory: Stack vs Heap¶
// Stack allocation: fast, automatic lifetime (freed when function returns)
float buffer[256]; // 256 floats on the stack
// Heap allocation: manual, survives beyond the function
float* data = new float[n]; // allocate n floats on the heap
// ... use data ...
delete[] data; // YOU must free it (no garbage collector)
// Modern C++: smart pointers (automatic cleanup, like Python references)
#include <memory>
auto data = std::make_unique<float[]>(n); // freed automatically when out of scope
The key rule: stack is fast but limited (typically 1-8 MB). Large arrays (tensors, feature maps) must go on the heap. In Python, everything is on the heap and the GC handles cleanup. In C++, you manage it yourself (or use smart pointers).
Templates (Generics)¶
// A function that works with any numeric type
template <typename T>
T add(T a, T b) {
return a + b;
}
add<float>(1.5f, 2.5f); // returns 4.0f
add<int>(3, 4); // returns 7
Templates are how C++ libraries (like ATen) write code that works with float16, float32, float64, etc. without duplicating the implementation.
The Standard Library Essentials¶
#include <vector> // dynamic array (like Python list)
#include <string> // string type
#include <unordered_map> // hash map (like Python dict)
#include <algorithm> // sort, find, transform, etc.
#include <cmath> // math functions
std::vector<float> vec = {1.0f, 2.0f, 3.0f};
vec.push_back(4.0f); // append
float first = vec[0]; // index
size_t len = vec.size(); // length
std::unordered_map<std::string, int> counts;
counts["hello"] = 5; // insert
if (counts.count("hello")) { } // check existence
When to Write Custom C++ Kernels¶
Most ML engineers never need to write C++. The framework's built-in operations cover 99% of use cases. You should consider custom C++ only when:
1. Your operation does not exist in the framework: a novel activation function, a custom attention pattern, a specialised loss function that cannot be expressed as a composition of existing ops.

2. Fusing operations for performance: your model computes `relu(layernorm(matmul(x, W) + b))`. Each operation launches a separate kernel, reads and writes memory, and synchronises. A fused kernel does it all in one pass, avoiding memory round-trips. This can be 2-5x faster.

3. Reducing memory usage: a custom kernel can compute gradients without storing all intermediate activations (gradient checkpointing at the kernel level).

4. Targeting novel hardware: a new accelerator (e.g., Cerebras, Groq) may not have framework support. You write kernels directly.

For cases 1-2, Triton (chapter 16, file 05) is often sufficient and much easier than writing CUDA C directly. Only drop to CUDA C when Triton cannot express what you need.
How to Bind C++ to Python¶
Writing C++ is half the job. You also need to call it from Python.
pybind11 (General Purpose)¶
pybind11 creates Python bindings for C++ functions with minimal boilerplate:
// my_ops.cpp
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
namespace py = pybind11;
// A simple custom operation
py::array_t<float> custom_relu(py::array_t<float> input) {
auto buf = input.request();
float* ptr = static_cast<float*>(buf.ptr);
size_t n = buf.size;
auto result = py::array_t<float>(n);
float* out = static_cast<float*>(result.request().ptr);
for (size_t i = 0; i < n; i++) {
out[i] = ptr[i] > 0 ? ptr[i] : 0;
}
return result;
}
PYBIND11_MODULE(my_ops, m) {
m.def("custom_relu", &custom_relu, "Custom ReLU operation");
}
# Compile
pip install pybind11
c++ -O3 -shared -std=c++17 -fPIC $(python3 -m pybind11 --includes) my_ops.cpp -o my_ops$(python3-config --extension-suffix)
# Use from Python
import my_ops
import numpy as np
x = np.array([-1.0, 2.0, -3.0, 4.0], dtype=np.float32)
y = my_ops.custom_relu(x)
print(y) # [0. 2. 0. 4.]
PyTorch C++ Extensions¶
PyTorch provides a streamlined way to add custom ops:
// custom_op.cpp
#include <torch/extension.h>
torch::Tensor custom_gelu(torch::Tensor x) {
return x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("custom_gelu", &custom_gelu, "Custom GELU activation");
}
# Load and compile on-the-fly
import torch
from torch.utils.cpp_extension import load
custom_ops = load(
name="custom_ops",
sources=["custom_op.cpp"],
extra_cflags=["-O3"],
)
x = torch.randn(1000)
y = custom_ops.custom_gelu(x)
`torch.utils.cpp_extension.load` compiles the C++ code, creates a shared library, and loads it as a Python module, all in one call. This is the easiest way to experiment with custom C++ ops in PyTorch.
JAX Custom Calls¶
JAX uses XLA custom calls. The process is more involved (you register a C function with XLA), but the concept is the same: write C/C++, bind it, call it from Python.

For most JAX users, Pallas (covered in file 05) is the better choice: it lets you write GPU kernels in a Python-like syntax that XLA compiles, without leaving the JAX ecosystem.
The Big Picture¶
This file explained the layer between Python and hardware. The remaining files in this chapter go deeper:
- File 01: the hardware itself (CPU architecture, GPU architecture, memory systems)
- Files 02-03: SIMD programming on CPU (ARM NEON, x86 AVX) — where you write C++ that uses the CPU's vector units
- File 04: GPU programming with CUDA — where you write C++ that runs on thousands of GPU cores
- File 05: Triton, Pallas, and higher-level GPU programming — where you write Python that compiles to GPU kernels
The progression mirrors the abstraction ladder: C++ intrinsics (lowest, most control) → CUDA (GPU-specific) → Triton/Pallas (Pythonic, compiled) → JAX/PyTorch (highest, automatic). Each level trades control for convenience. Understanding the lower levels makes you a better user of the higher ones.
Coding Tasks (compile with g++ or clang++)¶
1. Write your first C++ program. Allocate an array, fill it, compute the sum, and measure the time. This introduces compilation, arrays, pointers, and timing.

// task1_basics.cpp
// Compile: g++ -O3 -o task1 task1_basics.cpp
// Run: ./task1
#include <iostream>
#include <chrono>
#include <vector>

int main() {
    const int N = 10'000'000;  // C++ allows ' as digit separator
    std::vector<float> data(N);

    // Fill the array
    for (int i = 0; i < N; i++) {
        data[i] = static_cast<float>(i) * 0.001f;
    }

    // Compute sum
    auto start = std::chrono::high_resolution_clock::now();
    float sum = 0.0f;
    for (int i = 0; i < N; i++) {
        sum += data[i];
    }
    auto end = std::chrono::high_resolution_clock::now();
    double elapsed = std::chrono::duration<double, std::milli>(end - start).count();

    std::cout << "Sum: " << sum << std::endl;
    std::cout << "Time: " << elapsed << " ms" << std::endl;
    std::cout << "Elements: " << N << std::endl;
    std::cout << "Throughput: " << (N * sizeof(float)) / elapsed / 1e6 << " GB/s" << std::endl;
    return 0;
}
2. Write a C++ function that computes ReLU on an array, then build a Python binding using pybind11. Call it from Python and compare speed against NumPy.

// task2_relu.cpp
// Compile: c++ -O3 -shared -std=c++17 -fPIC $(python3 -m pybind11 --includes) \
//          task2_relu.cpp -o my_relu$(python3-config --extension-suffix)
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

namespace py = pybind11;

py::array_t<float> cpp_relu(py::array_t<float> input) {
    auto buf = input.request();
    float* ptr = static_cast<float*>(buf.ptr);
    int n = buf.size;
    auto result = py::array_t<float>(n);
    float* out = static_cast<float*>(result.request().ptr);
    for (int i = 0; i < n; i++) {
        out[i] = ptr[i] > 0.0f ? ptr[i] : 0.0f;
    }
    return result;
}

PYBIND11_MODULE(my_relu, m) {
    m.def("relu", &cpp_relu, "C++ ReLU");
}

# test_relu.py — run after compiling the C++ module above
import numpy as np
import time
import my_relu  # the compiled C++ module

x = np.random.randn(10_000_000).astype(np.float32)

# C++ ReLU
start = time.time()
for _ in range(100):
    y_cpp = my_relu.relu(x)
cpp_time = (time.time() - start) / 100

# NumPy ReLU
start = time.time()
for _ in range(100):
    y_np = np.maximum(x, 0)
np_time = (time.time() - start) / 100

print(f"C++ ReLU:   {cpp_time*1000:.2f} ms")
print(f"NumPy ReLU: {np_time*1000:.2f} ms")
print(f"Match: {np.allclose(y_cpp, y_np)}")
3. Write a C++ program that demonstrates why memory layout matters. Compare row-major vs column-major access patterns and measure the performance difference.

// task3_layout.cpp
// Compile: g++ -O3 -o task3 task3_layout.cpp
#include <iostream>
#include <chrono>
#include <vector>

int main() {
    const int N = 4096;
    std::vector<float> matrix(N * N, 1.0f);

    // Row-major access: sequential memory addresses (cache-friendly)
    auto start = std::chrono::high_resolution_clock::now();
    float sum_row = 0.0f;
    for (int i = 0; i < N; i++) {
        for (int j = 0; j < N; j++) {
            sum_row += matrix[i * N + j];  // stride-1 access
        }
    }
    auto end = std::chrono::high_resolution_clock::now();
    double row_ms = std::chrono::duration<double, std::milli>(end - start).count();

    // Column-major access: stride-N access (cache-unfriendly)
    start = std::chrono::high_resolution_clock::now();
    float sum_col = 0.0f;
    for (int j = 0; j < N; j++) {
        for (int i = 0; i < N; i++) {
            sum_col += matrix[i * N + j];  // stride-N access (cache misses!)
        }
    }
    end = std::chrono::high_resolution_clock::now();
    double col_ms = std::chrono::duration<double, std::milli>(end - start).count();

    std::cout << "Row-major (cache-friendly): " << row_ms << " ms" << std::endl;
    std::cout << "Col-major (cache-hostile):  " << col_ms << " ms" << std::endl;
    std::cout << "Slowdown: " << col_ms / row_ms << "x" << std::endl;
    std::cout << "(Both sums: " << sum_row << ", " << sum_col << ")" << std::endl;
    return 0;
}