Triton and TPUs¶
CUDA C is powerful but verbose. Triton lets you write GPU kernels in Python. TPUs offer an alternative to GPUs with different tradeoffs. This file covers Triton kernel programming, Flash Attention as a case study, TPU architecture and JAX/Pallas, and how to choose the right tool. For Vulkan and cross-platform GPU compute, see file 07.
- The previous file taught GPU programming in CUDA C. This file climbs the abstraction ladder: Triton gives you 80% of CUDA's performance with 20% of the effort, in Python. TPUs and Vulkan provide alternative hardware targets for specific use cases.
Triton: GPU Kernels in Python¶
- Triton (OpenAI) is a Python-based language for writing GPU kernels. Instead of reasoning about individual threads (CUDA), you reason about blocks of data. Triton's compiler handles thread mapping, memory coalescing, shared memory management, and many optimisations automatically.
- Why Triton matters: CUDA C requires deep knowledge of warp scheduling, shared memory bank conflicts, register pressure, and coalescing patterns. Triton abstracts most of this away, making GPU kernel development accessible to ML researchers who know Python but not systems programming.
Your First Triton Kernel¶
import triton
import triton.language as tl
import torch
@triton.jit
def add_kernel(
    x_ptr, y_ptr, output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,  # compile-time constant
):
    # Each program instance processes one block of BLOCK_SIZE elements
    pid = tl.program_id(axis=0)  # which block am I?
    block_start = pid * BLOCK_SIZE
    # Offsets for this block
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask to handle the case where n_elements is not a multiple of BLOCK_SIZE
    mask = offsets < n_elements
    # Load data (masked: out-of-bounds reads return 0)
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    # Compute
    output = x + y
    # Store result
    tl.store(output_ptr + offsets, output, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = output.numel()
    # Launch: one program per block
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output
# Usage
x = torch.randn(1000000, device='cuda')
y = torch.randn(1000000, device='cuda')
z = add(x, y)
- Key differences from CUDA:
    - No explicit thread management: you think in blocks (programs), not threads.
    - `tl.arange(0, BLOCK_SIZE)` creates a vector of offsets for the entire block. All operations on this vector are implicitly vectorised.
    - `mask` handles boundary conditions (like AVX-512 mask registers, file 03). No scalar cleanup loop needed.
    - `tl.load` and `tl.store` handle coalesced access automatically.
    - `@triton.jit` compiles the function to PTX (GPU assembly) at first call, then caches the compiled kernel.
Triton Softmax Kernel¶
- Softmax is a great Triton example because it requires multiple passes over the data (max, subtract, exp, sum, divide) and benefits from keeping data in SRAM (shared memory) between passes:
@triton.jit
def softmax_kernel(
    output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program handles one row
    row_idx = tl.program_id(0)
    row_start = input_ptr + row_idx * input_row_stride
    # Load the row
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    row = tl.load(row_start + col_offsets, mask=mask, other=-float('inf'))
    # Softmax: max for numerical stability, then exp, then normalise
    row_max = tl.max(row, axis=0)
    numerator = tl.exp(row - row_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Store result
    output_start = output_ptr + row_idx * output_row_stride
    tl.store(output_start + col_offsets, softmax_output, mask=mask)
- In PyTorch, `F.softmax(x, dim=-1)` launches 3 separate kernels (max, exp-and-sum, divide), each reading from and writing to global memory. The Triton version does everything in one kernel, keeping data in registers/SRAM. This kernel fusion is why custom Triton kernels can be 2-4x faster than PyTorch's built-in operations.
Triton Auto-Tuning¶
- Triton supports auto-tuning: try multiple configurations and pick the fastest:
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}),
    ],
    key=['M', 'N', 'K'],  # re-tune when these change
)
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, ...):
    ...
- Triton benchmarks each configuration on the actual hardware and selects the fastest. The optimal tile sizes depend on the GPU architecture, matrix dimensions, and memory layout — auto-tuning finds them without manual experimentation.
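The benchmark-and-cache behaviour can be sketched in plain Python. Everything here — the `autotune` helper, the toy `chunked_sum` "kernel" — is a hypothetical illustration of the idea, not Triton's actual implementation:

```python
import time

def autotune(kernel, configs, key, cache={}):
    """Time each config for the given key and cache the winner.
    (A toy sketch of the idea behind @triton.autotune, not its API.)"""
    if key not in cache:
        timings = []
        for cfg in configs:
            start = time.perf_counter()
            for _ in range(10):  # repeat to reduce timer noise
                kernel(**cfg)
            timings.append((time.perf_counter() - start, cfg))
        cache[key] = min(timings, key=lambda t: t[0])[1]  # keep the fastest
    return cache[key]

# Toy "kernel": sum a list in chunks of BLOCK elements
data = list(range(100_000))
def chunked_sum(BLOCK):
    return sum(sum(data[i:i + BLOCK]) for i in range(0, len(data), BLOCK))

configs = [{'BLOCK': 256}, {'BLOCK': 1024}, {'BLOCK': 4096}]
best = autotune(chunked_sum, configs, key=len(data))
print(f"fastest config for n={len(data)}: {best}")
```

The `key` argument mirrors Triton's behaviour: the winner is cached per problem size, so re-tuning only happens when the shapes change.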
Triton vs CUDA: When to Use Each¶
| | Triton | CUDA C |
|---|---|---|
| Language | Python | C/C++ |
| Abstraction | Block-level | Thread-level |
| Development speed | Fast (10-50 lines per kernel) | Slow (100-500 lines) |
| Performance ceiling | ~80-95% of hand-tuned CUDA | 100% (full hardware control) |
| Shared memory | Automatic | Manual |
| Coalescing | Automatic | Manual |
| Warp-level primitives | Limited | Full (shuffle, vote, etc.) |
| Hardware support | NVIDIA only (AMD experimental) | NVIDIA only |
- Use Triton for: fused kernels, custom attention patterns, activation functions, most ML research kernel needs.
- Use CUDA C for: maximum performance (the last 5-20%), warp-level primitives, complex data-dependent parallelism, when Triton cannot express your pattern.
Case Study: Flash Attention¶
- Flash Attention (Dao et al., 2022) is the most impactful custom kernel in recent ML. It computes attention with \(O(n)\) memory instead of \(O(n^2)\), enabling much longer sequences.
- The problem: standard attention computes \(\text{softmax}(QK^T / \sqrt{d}) \cdot V\). The \(QK^T\) matrix is \(n \times n\) where \(n\) is the sequence length. For \(n = 128K\), this matrix is \(128K \times 128K \times 4\) bytes = 64 GB. It does not fit in GPU memory.
- The insight: you do not need to materialise the full \(n \times n\) matrix. Compute attention in tiles: load a block of \(Q\), a block of \(K\), compute their partial attention scores, accumulate, and move to the next block. The \(n \times n\) matrix is never fully materialised — only one tile at a time exists in SRAM.
- Online softmax: the tricky part is softmax, which requires knowing the maximum across the entire row (for numerical stability). Flash Attention uses the online softmax trick: maintain a running maximum and rescale previously computed values when a new maximum is found. This allows softmax to be computed incrementally, one tile at a time.
- The algorithm:
For each block of Q rows:
    Load Q_block from HBM to SRAM
    For each block of K columns:
        1. Load K_block and V_block from HBM to SRAM
        2. Compute S_block = Q_block @ K_block.T (in SRAM)
        3. Update running max, rescale previously accumulated results
        4. Compute exp(S_block - running_max)
        5. Update running sum; accumulate exp(...) @ V_block into the output
    Normalise the output accumulator by the running sum
    Write output block back to HBM
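The tiled loop can be checked end-to-end on the CPU. This NumPy sketch (hypothetical function names, no claim of matching the real kernel's memory layout) materialises only one Bq x Bk score tile at a time, yet matches reference attention:

```python
import numpy as np

def attention_ref(Q, K, V):
    """Standard attention: materialises the full n x n score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def attention_tiled(Q, K, V, Bq=16, Bk=16):
    """Flash-Attention-style tiling: only one Bq x Bk tile exists at a time."""
    n, d = Q.shape
    O = np.zeros_like(Q)
    for qs in range(0, n, Bq):
        q = Q[qs:qs + Bq]
        m = np.full(q.shape[0], -np.inf)  # running row max
        s = np.zeros(q.shape[0])          # running row sum
        acc = np.zeros((q.shape[0], d))   # output accumulator
        for ks in range(0, n, Bk):
            k, v = K[ks:ks + Bk], V[ks:ks + Bk]
            S = q @ k.T / np.sqrt(d)              # one tile of scores
            m_new = np.maximum(m, S.max(axis=1))
            scale = np.exp(m - m_new)             # rescale factor for old state
            P = np.exp(S - m_new[:, None])
            s = s * scale + P.sum(axis=1)
            acc = acc * scale[:, None] + P @ v    # rescale old output, add new
            m = m_new
        O[qs:qs + Bq] = acc / s[:, None]          # final normalisation
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(64, 8)) for _ in range(3))
assert np.allclose(attention_tiled(Q, K, V), attention_ref(Q, K, V))
```

Peak extra memory is one (Bq, Bk) tile plus per-row statistics — independent of n, which is the whole point.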
- Why it is fast: the inner loop operates entirely in SRAM (shared memory). Global memory (HBM) is only accessed to load blocks of Q, K, V and write the final output. The data reuse factor is proportional to the SRAM size, and SRAM is ~100x faster to access than HBM.
- Flash Attention is implemented in both Triton and CUDA C. The CUDA version is faster (~10% more efficient), but the Triton version is far more readable and modifiable, which matters for research on new attention variants.
TPU Architecture¶
- TPUs (Tensor Processing Units) are Google's custom ML accelerators. They take a radically different approach from GPUs:
- Systolic arrays: the TPU's core compute unit is a Matrix Multiply Unit (MXU), a 128×128 or 256×256 systolic array that computes matrix multiplications by flowing data through a grid of multiply-accumulate units. Data enters from the edges and propagates through the array, with each unit performing one multiply-add and passing results to the next.
- Unlike GPUs (which schedule thousands of independent threads), the systolic array is a single, deterministic data flow. There is no thread scheduling, no warp divergence, no branch prediction. This simplicity makes the MXU extremely energy-efficient for matrix multiplication.
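A toy cycle model makes the wavefront concrete: at cycle t, processing element (i, j) consumes the operand pair with k = t - i - j, so data enters at the edges and every PE fires on a fixed, predetermined schedule. This simulates only the timing of an output-stationary array, not real MXU hardware:

```python
import numpy as np

def systolic_matmul(A, B):
    """Toy cycle model of an output-stationary n x n systolic array.
    PE (i, j) accumulates C[i, j]; A rows flow right, B columns flow down,
    skewed so the right operands meet at the right PE on the right cycle."""
    n = A.shape[0]
    C = np.zeros((n, n))
    # The full product drains after 3n - 2 cycles (fill + compute + drain)
    for t in range(3 * n - 2):
        for i in range(n):
            for j in range(n):
                k = t - i - j  # which operand pair reaches PE (i, j) now
                if 0 <= k < n:
                    C[i, j] += A[i, k] * B[k, j]
    return C

A = np.arange(9.0).reshape(3, 3)
B = np.ones((3, 3))
assert np.allclose(systolic_matmul(A, B), A @ B)
```

Note there is no scheduler anywhere: every multiply-add's cycle is fixed by (i, j, k) alone, which is why the hardware needs no thread scheduling or branch prediction.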
- HBM: TPUs use the same High Bandwidth Memory as GPUs. TPU v5e has 16 GB HBM2e per chip; TPU v5p has 95 GB HBM2e.
- ICI (Inter-Chip Interconnect): TPU pods connect hundreds of TPUs with a custom high-speed network. Data parallelism and model parallelism (chapter 6) across TPU pods are supported natively by JAX.
- BFloat16: TPUs were the first to use bfloat16 (chapter 13, file 02). BF16 has the same exponent range as float32 (preventing overflow during training) with less mantissa precision. This tradeoff is ideal for ML, where gradient values span a wide range but do not need 23 bits of precision.
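Because bfloat16 is simply the top 16 bits of a float32, its behaviour can be reproduced with bit twiddling. This sketch truncates the mantissa; real hardware typically rounds to nearest even:

```python
import struct

def to_bf16(x):
    """Reduce a float32 value to bfloat16 by keeping its top 16 bits
    (sign, 8 exponent bits, 7 mantissa bits); truncates, no rounding."""
    bits = struct.unpack('>I', struct.pack('>f', x))[0]
    return struct.unpack('>f', struct.pack('>I', bits & 0xFFFF0000))[0]

# Same exponent range as fp32: huge values survive without overflow...
print(to_bf16(3.0e38))     # still ~3e38, not inf
# ...but only ~3 decimal digits of precision remain
print(to_bf16(1.2345678))  # -> 1.234375
```

This is exactly the fp32/bf16 tradeoff described above: the exponent field (range) is untouched, only the mantissa (precision) shrinks from 23 bits to 7.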
Programming TPUs: JAX and Pallas¶
- TPUs are programmed through JAX and XLA. You write Python/JAX code, `jax.jit` compiles it to XLA HLO, and XLA compiles HLO to TPU-specific instructions. No CUDA, no C++.
import jax
import jax.numpy as jnp
@jax.jit
def matmul(a, b):
    return jnp.dot(a, b)
# This runs on CPU, GPU, or TPU depending on the device
a = jnp.ones((1024, 1024))
b = jnp.ones((1024, 1024))
c = matmul(a, b)
- Pallas is JAX's kernel authoring API — the JAX equivalent of Triton. It lets you write low-level kernels that XLA compiles for GPU or TPU:
from jax.experimental import pallas as pl
import jax.numpy as jnp
def add_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]

def add_pallas(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(x.shape[0] // 128,),
        in_specs=[pl.BlockSpec((128,), lambda i: (i,)),
                  pl.BlockSpec((128,), lambda i: (i,))],
        out_specs=pl.BlockSpec((128,), lambda i: (i,)),
    )(x, y)
- Pallas is newer and less mature than Triton, but it is the only way to write custom kernels for TPUs (since TPUs do not support CUDA).
GPU vs TPU¶
| | GPU (NVIDIA) | TPU (Google) |
|---|---|---|
| Availability | Any cloud, on-premise | Google Cloud only |
| Programming | CUDA C, Triton, PyTorch | JAX/XLA, Pallas |
| Flexibility | General-purpose compute | Optimised for matrix-heavy ML |
| Peak matmul FLOPS | Very high (Tensor Cores) | Very high (MXU) |
| Non-matmul ops | Good | Slower (routed through vector unit, not MXU) |
| Multi-chip scaling | NVLink (8 GPUs), InfiniBand | ICI (thousands of TPUs, tighter integration) |
| Cost efficiency | Competitive | Often cheaper for large-scale training |
| Ecosystem | Largest (PyTorch, TensorFlow, JAX) | JAX-focused |
- Use GPUs for: most ML workloads, PyTorch-based research, inference serving, workloads with significant non-matmul computation.
- Use TPUs for: large-scale JAX training (thousands of chips), cost-sensitive training on Google Cloud, workloads that are dominated by matrix multiplies.
Choosing the Right Tool¶
| Workload | Best Tool | Why |
|---|---|---|
| ML training (PyTorch) | NVIDIA GPU + CUDA/Triton | Largest ecosystem, best tooling |
| ML training (JAX, large-scale) | TPU or NVIDIA GPU | TPU for cost at Google-scale, GPU for flexibility |
| Custom fused kernels | Triton (Python) or CUDA C | Triton for speed of development, CUDA for peak performance |
| JAX custom kernels | Pallas | Only option for TPU, works on GPU too |
| Cross-platform inference | Vulkan (file 07) or ONNX Runtime | Runs on any GPU vendor |
| Mobile/edge inference | Metal (Apple), Vulkan (Android), NNAPI | Platform-specific accelerators |
| Browser inference | WebGPU (file 07) | Only option in the browser |
| CPU-only inference | ONNX Runtime + AVX/NEON | No GPU needed, uses SIMD (files 02-03) |
| Novel hardware | Vendor-specific SDK | Each accelerator has its own toolchain |
Coding Tasks (use Colab with a GPU runtime)¶
- Write and run a Triton kernel for vector addition. Compare its performance against PyTorch's built-in addition.

import triton
import triton.language as tl
import torch
import time

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

n = 10_000_000
x = torch.randn(n, device='cuda')
y = torch.randn(n, device='cuda')

# Triton
out_triton = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)
add_kernel[grid](x, y, out_triton, n, BLOCK=1024)

# PyTorch
out_torch = x + y

# Verify correctness
assert torch.allclose(out_triton, out_torch, atol=1e-5)

# Benchmark
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
    add_kernel[grid](x, y, out_triton, n, BLOCK=1024)
torch.cuda.synchronize()
triton_time = (time.time() - start) / 1000

start = time.time()
for _ in range(1000):
    out_torch = x + y
torch.cuda.synchronize()
torch_time = (time.time() - start) / 1000

print(f"Triton:  {triton_time*1000:.3f} ms")
print(f"PyTorch: {torch_time*1000:.3f} ms")
print(f"Ratio:   {torch_time/triton_time:.2f}x")
- Write a Triton fused kernel that does multiply + add + ReLU in a single pass. Compare against three separate PyTorch operations.

import triton
import triton.language as tl
import torch
import time

@triton.jit
def fused_mul_add_relu_kernel(x_ptr, w_ptr, b_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    w = tl.load(w_ptr + offs, mask=mask)
    b = tl.load(b_ptr + offs, mask=mask)
    result = tl.maximum(x * w + b, 0.0)  # fused: mul + add + relu
    tl.store(out_ptr + offs, result, mask=mask)

n = 10_000_000
x = torch.randn(n, device='cuda')
w = torch.randn(n, device='cuda')
b = torch.randn(n, device='cuda')

# Fused (Triton)
out_fused = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)
fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024)

# Unfused (PyTorch)
out_unfused = torch.relu(x * w + b)

assert torch.allclose(out_fused, out_unfused, atol=1e-5)

# Benchmark
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
    fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024)
torch.cuda.synchronize()
fused_time = (time.time() - start) / 1000

start = time.time()
for _ in range(1000):
    out_unfused = torch.relu(x * w + b)
torch.cuda.synchronize()
unfused_time = (time.time() - start) / 1000

print(f"Fused (Triton):    {fused_time*1000:.3f} ms")
print(f"Unfused (PyTorch): {unfused_time*1000:.3f} ms")
print(f"Speedup: {unfused_time/fused_time:.2f}x")
- Measure how JAX's XLA compiler automatically fuses operations. Compare a chain of operations with and without jit.

import jax
import jax.numpy as jnp
import time

def chain_ops(x):
    x = x * 2.0
    x = x + 1.0
    x = jnp.maximum(x, 0.0)  # ReLU
    x = x / jnp.sum(x)
    return x

chain_jit = jax.jit(chain_ops)
x = jax.random.normal(jax.random.PRNGKey(0), (10000, 1000))

# Warm up
_ = chain_jit(x)
jax.block_until_ready(_)

# Eager (each op is a separate kernel launch)
start = time.time()
for _ in range(100):
    y = chain_ops(x)
jax.block_until_ready(y)
eager_time = (time.time() - start) / 100

# JIT (XLA fuses operations)
start = time.time()
for _ in range(100):
    y = chain_jit(x)
jax.block_until_ready(y)
jit_time = (time.time() - start) / 100

print(f"Eager: {eager_time*1000:.2f} ms")
print(f"JIT:   {jit_time*1000:.2f} ms")
print(f"Speedup: {eager_time/jit_time:.1f}x (XLA fuses the 4 operations into 1 kernel)")