GPU Architecture and CUDA
GPUs transformed AI by providing thousands of cores for massive parallelism. This file covers GPU vs CPU design philosophy, the GPU memory hierarchy, CUDA programming in C++, the SIMT execution model, memory access patterns, synchronisation, streams, profiling, and NVIDIA GPU generations: the knowledge needed to write and understand GPU kernels.
- For hands-on CUDA tutorials with full working examples, see the companion repository: github.com/HenryNdubuaku/cuda-tutorials.
- A modern high-end NVIDIA GPU has over 10,000 CUDA cores; a CPU typically has 4-128 cores. GPU cores are far simpler, but this roughly 100-1000x difference in core count is why GPUs dominate ML: training a transformer requires trillions of multiply-add operations, and GPUs process them in parallel at a scale CPUs cannot match.
- Even if you never write CUDA kernels yourself, understanding GPU architecture explains why batch size matters (there must be enough work to saturate the GPU), why memory rather than compute is usually the bottleneck, and why certain operations (scatter, conditional branching) are slow on GPUs.
GPU vs CPU: Fundamentally Different Designs
- A CPU is designed for latency: minimise the time to complete one task. It devotes most of its transistor budget to caches, branch predictors, and out-of-order execution, all tricks to make one thread fast.
- A GPU is designed for throughput: maximise the number of tasks completed per second. It devotes most transistors to execution units (ALUs). Individual threads are slow, but there are thousands of them.
| | CPU | GPU |
|---|---|---|
| Cores | 4-128 (complex, fast) | 1,000-20,000 (simple, slow) |
| Clock speed | 3-5 GHz | 1-2.5 GHz |
| Cache | Large (32 MB+ L3) | Small relative to core count (per-SM L1/shared memory plus a chip-wide L2) |
| Branch prediction | Sophisticated | None (threads in a warp share one instruction stream; divergent branches are serialised) |
| Best for | Low-latency, complex control flow | High-throughput, data-parallel work |
| Typical FLOPS (FP32) | 1-5 TFLOPS | 30-80 TFLOPS |
| Typical memory bandwidth | 50-100 GB/s | 1-3 TB/s |
- The GPU's memory bandwidth advantage (10-30x) is often more important than its compute advantage. Many ML operations are memory-bound (element-wise ops, normalisation, attention), and the GPU's bandwidth is what keeps its cores supplied with data.
GPU Memory Hierarchy
- Understanding GPU memory is critical because memory access, not computation, is usually the primary bottleneck.
| Memory | Size | Latency | Bandwidth | Scope |
|---|---|---|---|---|
| Registers | ~256 KB per SM | ~0 cycles | Highest | Per thread |
| Shared memory | 48-228 KB per SM | ~20-30 cycles | ~20 TB/s (aggregate) | Per thread block |
| L1 cache | 128-256 KB per SM (same on-chip storage as shared memory) | ~30 cycles | Similar to shared memory | Per SM |
| L2 cache | 4-96 MB | ~200 cycles | ~6 TB/s | Global |
| Global memory (HBM or GDDR) | 24-192 GB | ~400+ cycles | 1-3.3 TB/s | Global |
- Registers are the fastest but most limited memory. Each thread has a private set of registers (at most 255 per thread). Using too many registers per thread reduces occupancy: the SM's register file is shared by all resident threads, so fewer threads can be resident at once.
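The link between register use and occupancy is simple arithmetic. The sketch below is a host-side C++ estimate, not a CUDA API: it assumes 65,536 32-bit registers and at most 2,048 resident threads per SM (the figures for A100 and H100; other GPUs differ), and it ignores register allocation granularity, shared-memory limits, and the per-SM block limit, so its result is an upper bound.

```cpp
#include <algorithm>
#include <cassert>

// Assumed per-SM limits (A100/H100-class GPUs; check your GPU's compute capability).
constexpr int kRegistersPerSM = 65536;
constexpr int kMaxThreadsPerSM = 2048;

// Upper-bound occupancy (resident threads / max resident threads) when registers
// and the thread limit are the only constraints. Only whole blocks can be resident.
double register_limited_occupancy(int regs_per_thread, int threads_per_block) {
    assert(regs_per_thread > 0 && threads_per_block > 0);
    int regs_per_block = regs_per_thread * threads_per_block;
    int blocks_by_regs = kRegistersPerSM / regs_per_block;
    int blocks_by_threads = kMaxThreadsPerSM / threads_per_block;
    int resident_blocks = std::min(blocks_by_regs, blocks_by_threads);
    return static_cast<double>(resident_blocks * threads_per_block) / kMaxThreadsPerSM;
}
```

For 256-thread blocks under these assumptions, 32 registers per thread allows full occupancy (1.0), 64 allows 0.5, and 128 allows 0.25.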
- Shared memory is a programmer-managed cache shared by all threads in a block. It is the key to writing fast CUDA kernels: load a tile of data from slow global memory into fast shared memory, then compute on it. This is the tiling pattern that dominates GPU programming.
- Global memory (HBM) is the main GPU memory (VRAM). It is large but slow (hundreds of cycles of latency). All data starts and ends here, so a central goal of kernel optimisation is to minimise global memory accesses.
CUDA Programming Model
- CUDA (Compute Unified Device Architecture) is NVIDIA's programming model for GPUs. You write kernels: functions that run on the GPU, executed by thousands of threads simultaneously.
The Hierarchy: Grids, Blocks, Threads
```text
Grid (the entire launch)
├── Block (0,0)
│   ├── Thread (0,0)
│   ├── Thread (1,0)
│   ├── Thread (2,0)
│   └── ... (up to 1024 threads per block)
├── Block (1,0)
│   ├── Thread (0,0)
│   └── ...
└── ... (millions of blocks possible)
```
- Thread: the smallest unit. Each thread has an ID (`threadIdx.x`) that is unique within its block.
- Block: a group of threads that can share memory and synchronise. Block ID: `blockIdx.x`. Block size: `blockDim.x` (up to 1024 threads).
- Grid: all blocks launched by a single kernel. Grids and blocks can be 1D, 2D, or 3D (`.x`, `.y`, `.z` components).
- Each thread computes its global index:
```cpp
int idx = blockIdx.x * blockDim.x + threadIdx.x;
```
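Why a ceiling-division grid plus an `idx < n` check is enough can be checked by emulating a 1D launch on the CPU. In this plain C++ sketch (the function name is illustrative), the nested loops stand in for the hardware running every (block, thread) pair, and the result records how often each element is visited.

```cpp
#include <cassert>
#include <vector>

// Emulates a 1D launch of grid_size blocks x block_size threads on the CPU and
// returns how many times each element index in [0, n) is visited.
std::vector<int> emulate_1d_launch(int n, int block_size) {
    assert(n > 0 && block_size > 0);
    int grid_size = (n + block_size - 1) / block_size;  // ceiling division
    std::vector<int> visits(n, 0);
    for (int blockIdx_x = 0; blockIdx_x < grid_size; ++blockIdx_x) {
        for (int threadIdx_x = 0; threadIdx_x < block_size; ++threadIdx_x) {
            int idx = blockIdx_x * block_size + threadIdx_x;  // blockDim.x == block_size
            if (idx < n) {  // the same bounds check a kernel would use
                ++visits[idx];
            }
        }
    }
    return visits;
}
```

For n = 1000 and 256 threads per block, the grid has 4 blocks (1,024 threads): every element is visited exactly once, and the last 24 threads fail the bounds check and do nothing.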
Your First CUDA Kernel
```cpp
// vector_add.cu: CUDA source file (.cu extension)
// Compile: nvcc -O3 -o vector_add vector_add.cu
#include <stdio.h>
#include <cuda_runtime.h>

// __global__ marks this as a GPU kernel (called from CPU, runs on GPU)
__global__ void vector_add(const float* a, const float* b, float* c, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {  // bounds check (grid may be larger than data)
        c[idx] = a[idx] + b[idx];
    }
}

int main() {
    int n = 1 << 20;  // ~1 million elements
    size_t bytes = n * sizeof(float);

    // Allocate host (CPU) memory
    float *h_a = new float[n];
    float *h_b = new float[n];
    float *h_c = new float[n];

    // Initialise
    for (int i = 0; i < n; i++) {
        h_a[i] = 1.0f;
        h_b[i] = 2.0f;
    }

    // Allocate device (GPU) memory
    float *d_a, *d_b, *d_c;
    cudaMalloc(&d_a, bytes);
    cudaMalloc(&d_b, bytes);
    cudaMalloc(&d_c, bytes);

    // Copy data from CPU to GPU
    cudaMemcpy(d_a, h_a, bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(d_b, h_b, bytes, cudaMemcpyHostToDevice);

    // Launch kernel: 256 threads per block, enough blocks to cover n elements
    int block_size = 256;
    int grid_size = (n + block_size - 1) / block_size;  // ceiling division
    vector_add<<<grid_size, block_size>>>(d_a, d_b, d_c, n);

    // Check for launch errors (e.g. invalid configuration); errors during
    // execution surface at the next synchronising call
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        printf("Launch failed: %s\n", cudaGetErrorString(err));
        return 1;
    }

    // Copy the result (d_c) from GPU to CPU; this waits for the kernel to finish
    cudaMemcpy(h_c, d_c, bytes, cudaMemcpyDeviceToHost);

    // Verify
    printf("c[0] = %f (expected 3.0)\n", h_c[0]);

    // Free memory
    cudaFree(d_a); cudaFree(d_b); cudaFree(d_c);
    delete[] h_a; delete[] h_b; delete[] h_c;
    return 0;
}
```
- Key CUDA extensions to C++:
  - `__global__`: a CUDA keyword marking a kernel function. Called from the CPU (host), runs on the GPU (device).
  - `<<<grid_size, block_size>>>`: kernel launch syntax. Specifies how many blocks, and how many threads per block, to use. The launch is asynchronous: control returns to the CPU immediately.
  - `cudaMalloc` / `cudaFree`: allocate/free GPU memory (like `new`/`delete`, but for the GPU).
  - `cudaMemcpy`: copy data between CPU and GPU. This is often the biggest bottleneck: PCIe 4.0 x16 provides roughly 32 GB/s per direction, while GPU memory bandwidth is ~3 TB/s.
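The cost of the copies relative to the kernel can be estimated from bandwidth alone. This C++ sketch takes the figures quoted above (32 GB/s over PCIe, 3 TB/s for GPU memory) as assumptions, ignores launch and transfer latency, and assumes the kernel moves each of the three arrays through global memory exactly once; the names are illustrative.

```cpp
#include <cassert>
#include <cstddef>

// Estimated seconds to move `bytes` at `bytes_per_second` (no latency or overhead).
double transfer_seconds(double bytes, double bytes_per_second) {
    assert(bytes_per_second > 0);
    return bytes / bytes_per_second;
}

struct VectorAddEstimate {
    double pcie_seconds;    // a and b copied host->device, c copied device->host
    double kernel_seconds;  // kernel reads a and b, writes c in global memory
};

VectorAddEstimate estimate_vector_add(std::size_t n, double pcie_bw, double gpu_mem_bw) {
    double bytes_per_array = static_cast<double>(n) * sizeof(float);
    double traffic = 3.0 * bytes_per_array;  // the same 3 arrays cross each link once
    return {transfer_seconds(traffic, pcie_bw), transfer_seconds(traffic, gpu_mem_bw)};
}
```

For the 2^20-element example, this gives about 0.39 ms of PCIe traffic against about 4 µs of kernel memory traffic: the copies take roughly 94 times longer than the kernel itself.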
Warps and SIMT
- The GPU executes threads in groups of 32 called warps. All 32 threads in a warp execute the same instruction at the same time (Single Instruction, Multiple Threads, or SIMT). This is the GPU's equivalent of SIMD, but expressed at the thread level: each thread has its own registers and index, while the hardware issues one instruction for the whole warp.
- Warp divergence occurs when threads in the same warp take different branches of an `if` statement. The warp cannot execute two different instructions at once, so it executes both branches one after the other, masking out the threads that should not participate in each. With two paths of similar length, this roughly halves throughput for that code; more distinct paths are worse.
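```cpp
// BAD: warp divergence (threads in same warp take different paths)
if (threadIdx.x % 2 == 0) {
    c[idx] = a[idx] + b[idx];  // even threads do this
} else {
    c[idx] = a[idx] - b[idx];  // odd threads do this (same warp, serialised)
}

// BETTER: branchless (no divergence)
float sign = (threadIdx.x % 2 == 0) ? 1.0f : -1.0f;
c[idx] = a[idx] + sign * b[idx];  // all threads execute the same instructions
```
For branches this short, `nvcc` often emits predicated or select instructions instead of a real jump, so the two versions may perform almost identically; divergence costs the most when each path contains substantial work.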
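A simple cost model makes the rule concrete: a warp pays for every branch path that at least one of its threads takes. The sketch below is a CPU model, not a hardware simulator (path costs are hypothetical instruction counts, and predication and reconvergence details are ignored). It compares branching on `threadIdx.x % 2` with branching on the warp index `threadIdx.x / 32`, which keeps all 32 threads of a warp on the same path.

```cpp
#include <cassert>
#include <functional>

constexpr int kWarpSize = 32;

// Cost (in instruction slots) for one warp executing an if/else. A path is
// executed if any thread in the warp takes it; the other threads are masked off.
int warp_branch_cost(int warp_id, const std::function<bool(int)>& takes_then,
                     int then_cost, int else_cost) {
    bool any_then = false, any_else = false;
    for (int lane = 0; lane < kWarpSize; ++lane) {
        int thread_idx = warp_id * kWarpSize + lane;
        if (takes_then(thread_idx)) any_then = true; else any_else = true;
    }
    return (any_then ? then_cost : 0) + (any_else ? else_cost : 0);
}

// Total cost over all warps of a block with block_size threads.
int block_branch_cost(int block_size, const std::function<bool(int)>& takes_then,
                      int then_cost, int else_cost) {
    assert(block_size % kWarpSize == 0);
    int total = 0;
    for (int w = 0; w < block_size / kWarpSize; ++w)
        total += warp_branch_cost(w, takes_then, then_cost, else_cost);
    return total;
}
```

With 10-instruction paths and 256-thread blocks, the even/odd branch costs 160 slots against 80 for the warp-aligned branch `(t / 32) % 2 == 0`: the 2x penalty described above.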
Memory Coalescing
- Coalesced access: when consecutive threads access consecutive memory addresses, the GPU combines them into a single memory transaction. This is critical for performance.
```cpp
// GOOD: coalesced (thread 0 reads a[0], thread 1 reads a[1], ...)
c[idx] = a[idx] + b[idx];

// BAD: strided (thread 0 reads a[0], thread 1 reads a[stride], ...)
c[idx] = a[idx * stride] + b[idx * stride];  // stride > 1 wastes bandwidth
```
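- For a warp of 32 threads reading float32 values, coalesced access requests 128 contiguous bytes (32 × 4 bytes), which the hardware serves with the minimum number of transactions: four 32-byte sectors, one 128-byte cache line when aligned. Strided access spreads the warp's requests over many sectors, each fetched in full but only partly used. Once the stride reaches 8 floats (32 bytes), every thread touches its own sector, so 1,024 bytes are fetched to deliver 128 useful bytes (12.5% efficiency). Counted in whole 128-byte cache lines, a stride of 32 floats is the worst case: each line delivers only 4 useful bytes (about 3% utilisation).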
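The transaction arithmetic can be checked with a small model that counts the distinct 32-byte sectors one warp touches for a given stride. It is a sketch under stated assumptions: 4-byte elements, an array starting on a sector boundary, 32-byte sectors as the fetch granularity, and no caching effects.

```cpp
#include <cassert>
#include <cstddef>
#include <set>

constexpr std::size_t kSectorBytes = 32;  // assumed global-memory fetch granularity

// Distinct sectors touched when the 32 threads of a warp each load one float
// at element index lane * stride.
std::size_t sectors_touched(std::size_t stride) {
    assert(stride >= 1);
    std::set<std::size_t> sectors;
    for (std::size_t lane = 0; lane < 32; ++lane) {
        std::size_t byte_addr = lane * stride * sizeof(float);
        sectors.insert(byte_addr / kSectorBytes);
    }
    return sectors.size();
}

// Fraction of fetched bytes that the warp actually uses.
double load_efficiency(std::size_t stride) {
    double useful = 32.0 * sizeof(float);  // 128 bytes requested by the warp
    double fetched = static_cast<double>(sectors_touched(stride)) * kSectorBytes;
    return useful / fetched;
}
```

Stride 1 touches 4 sectors (100% efficiency), stride 2 touches 8 (50%), and any stride of 8 or more touches 32 sectors (12.5%); the ~3% figure appears only when whole 128-byte lines are fetched.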
Shared Memory and Tiling
- The tiling pattern is the most important GPU optimisation technique. The idea: load a block of data from slow global memory into fast shared memory, compute on it, then write results back.
```cpp
// Matrix multiply with shared memory tiling (simplified)
// C (M x N) = A (M x K) * B (K x N), all row-major.
// Launch with dim3 block(TILE_SIZE, TILE_SIZE) and
// dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE).
#define TILE_SIZE 16

__global__ void matmul_tiled(const float* A, const float* B, float* C,
                             int M, int N, int K) {
    // Shared memory for one tile of A and one tile of B
    __shared__ float tile_A[TILE_SIZE][TILE_SIZE];
    __shared__ float tile_B[TILE_SIZE][TILE_SIZE];
    int row = blockIdx.y * TILE_SIZE + threadIdx.y;
    int col = blockIdx.x * TILE_SIZE + threadIdx.x;
    float sum = 0.0f;
    // Loop over tiles along the K dimension
    for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
        // Load one tile of A and B into shared memory (zero-pad out-of-range elements)
        if (row < M && t * TILE_SIZE + threadIdx.x < K)
            tile_A[threadIdx.y][threadIdx.x] = A[row * K + t * TILE_SIZE + threadIdx.x];
        else
            tile_A[threadIdx.y][threadIdx.x] = 0.0f;
        if (col < N && t * TILE_SIZE + threadIdx.y < K)
            tile_B[threadIdx.y][threadIdx.x] = B[(t * TILE_SIZE + threadIdx.y) * N + col];
        else
            tile_B[threadIdx.y][threadIdx.x] = 0.0f;
        __syncthreads();  // wait for all threads to finish loading
        // Compute partial dot product from this tile
        for (int k = 0; k < TILE_SIZE; k++) {
            sum += tile_A[threadIdx.y][k] * tile_B[k][threadIdx.x];
        }
        __syncthreads();  // wait before loading the next tile
    }
    if (row < M && col < N)
        C[row * N + col] = sum;
}
```
- `__shared__`: declares memory shared by all threads in the block (fast, on-chip).
- `__syncthreads()`: a barrier that waits until all threads in the block have reached this point. It is required after writing shared memory and before other threads read it (otherwise some threads read stale data), and again before the tile is overwritten in the next iteration. Every thread in the block must reach the same `__syncthreads()` call, so it should not sit inside a branch that only some threads take.
- Why tiling works: without it, each thread loads two values from global memory for every multiply-add. With tiling, each TILE_SIZE × TILE_SIZE tile of A and B is loaded once into shared memory and reused by all threads in the block. The reuse factor is TILE_SIZE, reducing global memory traffic by that factor.
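The reuse factor can be verified by counting global loads. This sketch counts element loads of A and B for a naive kernel (each output element reads a full row of A and a full column of B) and for the tiled kernel (each block loads one TILE_SIZE × TILE_SIZE tile of A and of B per step along K). It assumes M, N, and K are multiples of the tile size and ignores caches; the function names are illustrative.

```cpp
#include <cassert>
#include <cstdint>

// Global-memory element loads for C (M x N) = A (M x K) * B (K x N).
std::uint64_t naive_loads(std::uint64_t M, std::uint64_t N, std::uint64_t K) {
    return M * N * (2 * K);  // every output element reads K values of A and K of B
}

std::uint64_t tiled_loads(std::uint64_t M, std::uint64_t N, std::uint64_t K,
                          std::uint64_t tile) {
    assert(M % tile == 0 && N % tile == 0 && K % tile == 0);
    std::uint64_t blocks = (M / tile) * (N / tile);  // one block per output tile
    std::uint64_t steps = K / tile;                  // tiles along the K dimension
    return blocks * steps * (2 * tile * tile);       // one A tile + one B tile per step
}
```

For a 1024 × 1024 × 1024 multiply with 16 × 16 tiles, the naive count is 2^31 loads and the tiled count is 2^27: exactly the factor of 16 predicted above.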
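Streams and Concurrency
- By default, all work goes into a single default stream, where copies and kernels run one after another on the GPU: kernel launches return to the CPU immediately, but each operation waits for the previous one, and a plain `cudaMemcpy` generally blocks the CPU until it completes. Streams are independent queues whose operations may overlap:
```cpp
cudaStream_t stream1, stream2;
cudaStreamCreate(&stream1);
cudaStreamCreate(&stream2);
// Different streams may overlap; work within one stream stays ordered.
// h_a and h_b must be pinned (cudaMallocHost) for the copies to overlap.
cudaMemcpyAsync(d_a, h_a, bytes, cudaMemcpyHostToDevice, stream1);
cudaMemcpyAsync(d_b, h_b, bytes, cudaMemcpyHostToDevice, stream2);
kernel1<<<grid, block, 0, stream1>>>(d_a, d_c);  // 0 = dynamic shared memory bytes
kernel2<<<grid, block, 0, stream2>>>(d_b, d_d);
// Wait for both streams, then release them
cudaStreamSynchronize(stream1);
cudaStreamSynchronize(stream2);
cudaStreamDestroy(stream1);
cudaStreamDestroy(stream2);
```
- Streams overlap data transfer with computation: while one stream's kernel runs, another stream copies data. This hides the PCIe transfer time and keeps the GPU busy. Overlap requires pinned (page-locked) host memory; with ordinary pageable memory, `cudaMemcpyAsync` generally does not overlap with other work.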
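The benefit of splitting work across streams can be estimated with a two-stage pipeline model. It is a sketch with hypothetical timings: the data is split into equal chunks, one copy and one kernel can run at the same time, a chunk's kernel starts only after its copy finishes, and the device-to-host copy back is left out.

```cpp
#include <algorithm>
#include <cassert>

// Total time when each chunk is copied and then processed, one after another.
double serial_time(int chunks, double copy_per_chunk, double compute_per_chunk) {
    return chunks * (copy_per_chunk + compute_per_chunk);
}

// Two-stage pipeline: the copy of chunk i+1 overlaps the kernel on chunk i.
// Makespan = first copy + last kernel + (chunks - 1) * slower stage.
double pipelined_time(int chunks, double copy_per_chunk, double compute_per_chunk) {
    assert(chunks >= 1);
    return copy_per_chunk + compute_per_chunk +
           (chunks - 1) * std::max(copy_per_chunk, compute_per_chunk);
}
```

With four chunks whose copy and kernel each take 1 ms, the model gives 5 ms instead of 8 ms; as the chunk count grows, the total approaches the time of the slower stage alone.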
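Profiling CUDA Code
```shell
# NVIDIA Nsight Compute: kernel-level profiling
ncu --set full ./my_program
# NVIDIA Nsight Systems: system-level timeline
nsys profile ./my_program
# Quick metrics: SM and DRAM throughput as a percentage of peak
ncu --metrics sm__throughput.avg.pct_of_peak_sustained_elapsed,dram__throughput.avg.pct_of_peak_sustained_elapsed ./my_program
```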
- What to look for:
  - Occupancy: the ratio of active warps on an SM to the maximum the SM supports. Low occupancy (< 50%) can mean too few warps to hide memory latency. Common causes: too many registers per thread, too much shared memory per block, or too few blocks in the grid.
  - Memory throughput: compare to peak bandwidth. If a memory-bound kernel achieves < 50% of peak, its access patterns are likely inefficient (non-coalesced accesses, shared-memory bank conflicts).
  - Compute throughput: compare to peak FLOPS. If both memory and compute throughput are low, the kernel is latency-bound (not enough parallelism to keep the hardware busy).
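Shared-memory bank conflicts can be modelled the same way. Shared memory is divided into 32 banks of 4-byte words; when threads of a warp access different words in the same bank, those accesses are serialised, while threads reading the same word are served by a broadcast. The sketch below (a simplified host-side model) computes the conflict degree for a column read of a row-major `float` tile, and shows the common fix of padding each row by one element.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <set>

constexpr int kBanks = 32;  // 32 banks, each 4 bytes wide

// Conflict degree when lane i of a warp reads tile[i][col] from a row-major
// float array with rows of row_width elements: the largest number of distinct
// words mapped to one bank (1 = conflict-free). Identical words count once.
int column_read_conflict_degree(int row_width, int col) {
    assert(row_width > col && col >= 0);
    std::map<int, std::set<int>> words_per_bank;  // bank -> distinct word indices
    for (int lane = 0; lane < 32; ++lane) {
        int word = lane * row_width + col;  // 4-byte word index in shared memory
        words_per_bank[word % kBanks].insert(word);
    }
    int degree = 0;
    for (const auto& entry : words_per_bank)
        degree = std::max(degree, static_cast<int>(entry.second.size()));
    return degree;
}
```

Reading a column of `float tile[32][32]` puts all 32 lanes in one bank (a 32-way conflict), while padding the array to `float tile[32][33]` spreads them over all 32 banks.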
Advanced Optimisation Techniques
- Beyond the basics of coalescing and shared memory tiling, high-performance GPU (and CPU) code uses several advanced techniques:
Data Layout: AoS vs SoA
- Array of Structures (AoS): all fields of an element are stored together, e.g. `[{x,y,z}, {x,y,z}, {x,y,z}]`.
- Structure of Arrays (SoA): each field is stored in its own contiguous array, e.g. `{[x,x,x], [y,y,y], [z,z,z]}`.
```cpp
constexpr int N = 1024;  // example element count

// AoS: BAD for SIMD/GPU (accessing all x values touches non-contiguous memory)
struct Particle { float x, y, z, mass; };
Particle particles[N];
// particles[0].x and particles[1].x are 16 bytes apart

// SoA: GOOD for SIMD/GPU (all x values are contiguous)
struct Particles {
    float x[N], y[N], z[N], mass[N];
};
// x[0] and x[1] are 4 bytes apart: ideal for coalesced access and SIMD
```
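- SoA is almost always faster for data-parallel workloads (SIMD, GPU). AoS is better when you always access all fields of one element together (rare in numerical code). PyTorch tensors are strided arrays, so which layout you get depends on which dimension is contiguous: a row-major `[N, 3]` tensor of points stores each point's x, y, z together (AoS-like), while its transpose made contiguous, shape `[3, N]`, stores all x values together (SoA-like).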
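The stride difference can be checked directly in plain C++. This sketch uses the `Particle` layout from the example above, with a `std::vector`-based SoA type named here for illustration, and measures the byte distance between consecutive x values in each layout. Both sums produce the same result; only the memory access pattern differs.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct Particle { float x, y, z, mass; };  // AoS element (16 bytes)

struct ParticlesSoA {                      // SoA: one contiguous array per field
    std::vector<float> x, y, z, mass;
    explicit ParticlesSoA(std::size_t n) : x(n), y(n), z(n), mass(n) {}
};

// Byte distance between the x fields of consecutive elements.
std::ptrdiff_t aos_x_stride(const std::vector<Particle>& p) {
    assert(p.size() >= 2);
    return reinterpret_cast<const char*>(&p[1].x) - reinterpret_cast<const char*>(&p[0].x);
}
std::ptrdiff_t soa_x_stride(const ParticlesSoA& p) {
    assert(p.x.size() >= 2);
    return reinterpret_cast<const char*>(&p.x[1]) - reinterpret_cast<const char*>(&p.x[0]);
}

float sum_x_aos(const std::vector<Particle>& p) {
    float s = 0.0f;
    for (const Particle& q : p) s += q.x;  // one useful float per 16-byte struct
    return s;
}
float sum_x_soa(const ParticlesSoA& p) {
    float s = 0.0f;
    for (float v : p.x) s += v;            // consecutive floats, every byte used
    return s;
}
```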
Software Prefetching
- The CPU can be told to start loading data before it is needed, hiding memory latency:
```cpp
#include <xmmintrin.h>  // SSE intrinsics: _mm_prefetch, _mm_load_ps

// Assumes n is a multiple of 4 and a is 16-byte aligned (required by _mm_load_ps).
for (int i = 0; i < n; i += 4) {
    // Prefetch 64 floats (256 bytes) ahead into all cache levels
    _mm_prefetch((const char*)(a + i + 64), _MM_HINT_T0);
    // process a[i:i+4] with SIMD
    __m128 va = _mm_load_ps(a + i);
    // ...
}
```
- The prefetch instruction is a hint: if the data is already in cache, it is a no-op. If not, the CPU starts fetching it in the background while executing other instructions. The prefetch distance (64 elements ahead in this example) should be tuned to match the memory latency and loop iteration time.
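A portable variant of the same idea uses the GCC/Clang builtin `__builtin_prefetch` instead of the SSE intrinsic (MSVC would still need `_mm_prefetch`). In this sketch the prefetch distance is a parameter to tune, one prefetch is issued per 16 floats on the assumption of 64-byte cache lines, and a bounds check keeps the prefetched address inside the array. The result is identical to a plain loop: prefetching changes timing, never values.

```cpp
#include <cassert>
#include <cstddef>

// Sums a[0..n) while hinting the CPU to fetch a[i + distance] ahead of use.
// __builtin_prefetch(addr, 0, 3): 0 = prefetch for reading, 3 = high temporal locality.
float sum_with_prefetch(const float* a, std::size_t n, std::size_t distance) {
    assert(a != nullptr || n == 0);
    float sum = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        // Assumes 64-byte lines: one prefetch per 16 floats is enough
        if (i % 16 == 0 && i + distance < n) {
            __builtin_prefetch(a + i + distance, 0, 3);
        }
        sum += a[i];
    }
    return sum;
}
```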
Kernel Fusion
- Kernel fusion combines multiple operations into a single kernel to avoid writing intermediate results to memory. It is one of the most impactful GPU optimisations for ML:
```text
// UNFUSED (pseudocode): 3 kernel launches, 3 global memory round-trips
y = matmul(x, W)        // write y to global memory
z = y + bias            // read y, write z
out = relu(z)           // read z, write out

// FUSED: 1 kernel launch, 1 global memory write
out = fused_matmul_bias_relu(x, W, bias)  // y and z stay on-chip (registers/shared memory)
```
- For memory-bound operations (bias add, ReLU, layer norm), memory traffic dominates execution time. Fusing eliminates the intermediate reads and writes. PyTorch's `torch.compile` and Triton enable fusion automatically or with minimal effort.
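The saving can be reproduced on a CPU. This sketch applies bias add and ReLU to a matmul output `y` (an M × N row-major array) either as two passes with a materialised intermediate or as one fused pass, and counts large-array traffic under the simplifying assumption that every array access goes to main memory (bias reads are ignored as cached). The function names are illustrative.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Unfused: z = y + bias (pass 1), out = relu(z) (pass 2). z is written and re-read.
std::vector<float> bias_relu_unfused(const std::vector<float>& y,
                                     const std::vector<float>& bias, std::size_t N) {
    assert(N > 0 && y.size() % N == 0 && bias.size() == N);
    std::vector<float> z(y.size()), out(y.size());
    for (std::size_t i = 0; i < y.size(); ++i) z[i] = y[i] + bias[i % N];
    for (std::size_t i = 0; i < z.size(); ++i) out[i] = std::max(z[i], 0.0f);
    return out;
}

// Fused: one pass, the intermediate value never leaves a register.
std::vector<float> bias_relu_fused(const std::vector<float>& y,
                                   const std::vector<float>& bias, std::size_t N) {
    assert(N > 0 && y.size() % N == 0 && bias.size() == N);
    std::vector<float> out(y.size());
    for (std::size_t i = 0; i < y.size(); ++i) out[i] = std::max(y[i] + bias[i % N], 0.0f);
    return out;
}

// Large-array bytes moved: unfused reads y, writes z, reads z, writes out.
std::size_t unfused_bytes(std::size_t elems) { return 4 * elems * sizeof(float); }
// Fused reads y and writes out.
std::size_t fused_bytes(std::size_t elems) { return 2 * elems * sizeof(float); }
```

Fusing halves the traffic for this epilogue alone; folding it into the matmul kernel, as in the fused pseudocode above, also removes the write and re-read of `y`.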
Mixed-Precision Kernels
- Using lower precision (FP16, BF16, INT8) for computation and higher precision (FP32, or INT32 for INT8) for accumulation gives the best of both worlds:
```cpp
// Tensor Cores via the WMMA API (#include <mma.h>): one warp multiplies FP16 tiles
// and accumulates in FP32, D (FP32) = A (FP16) × B (FP16) + C (FP32).
// c_frag, a_frag, b_frag are nvcuda::wmma::fragment objects (e.g. 16×16×16 tiles).
nvcuda::wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
```
- FP16 values are half the size of FP32, so the same memory bandwidth delivers twice as many values per second and caches hold twice as many. Tensor Cores process FP16 at roughly 8-16x the rate of FP32 CUDA cores (e.g. 312 vs 19.5 TFLOPS on A100). This is why mixed-precision training (chapter 6) provides 2-3x speedup with minimal accuracy loss.
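Why accumulate in higher precision? The effect can be reproduced on a CPU as an analogy, with `float` standing in for the low-precision type and `double` for the wider accumulator. Once a `float` running sum of ones reaches 2^24 = 16,777,216, the spacing between representable values is 2, so adding 1.0f rounds back to the same value. FP16, with an 11-bit significand, hits the same wall at 2048.

```cpp
#include <cassert>

// Adds 1.0f `count` times using a float accumulator.
float sum_ones_float(long count) {
    assert(count >= 0);
    float acc = 0.0f;
    for (long i = 0; i < count; ++i) acc += 1.0f;  // stalls at 2^24
    return acc;
}

// Same float inputs, accumulated in double (the "FP32 accumulator" role here).
double sum_ones_double(long count) {
    assert(count >= 0);
    double acc = 0.0;
    for (long i = 0; i < count; ++i) acc += 1.0f;  // wider accumulator stays exact
    return acc;
}
```

Summing 20 million ones returns 16,777,216 with the float accumulator and exactly 20,000,000 with the double one, which is the failure mode FP32 accumulation protects FP16 matmuls from.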
Memory Pool Allocators
- `cudaMalloc` and `cudaFree` are expensive (from microseconds to milliseconds per call), and `cudaFree` implicitly synchronises the device. In a training loop that allocates temporary buffers every iteration, this adds up.
- Memory pools (PyTorch's caching allocator, CUDA memory pools) reserve large blocks of GPU memory and sub-allocate from them without calling into the CUDA driver:
```python
# PyTorch does this automatically, but understanding why matters.
# After warm-up, torch.empty() reuses a cached block from the pool instead of calling cudaMalloc.
import torch
temp = torch.empty(1024, 1024, device='cuda')  # microseconds, not milliseconds
```
- This is also why PyTorch reports several memory figures: `torch.cuda.memory_allocated()` is the memory currently occupied by tensors, `torch.cuda.max_memory_allocated()` is the peak of that figure, and `torch.cuda.memory_reserved()` is the total held by the caching allocator, which can exceed what tensors currently use.
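The core idea of a caching allocator fits in a few lines of C++. This is a toy, host-side sketch: the class and method names are invented for illustration, `std::malloc` stands in for `cudaMalloc`, and it is far simpler than PyTorch's allocator (no block splitting, coalescing, or stream tracking). Freed blocks go into per-size free lists and are handed back on the next request of the same size, so only cache misses reach the expensive underlying allocator.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <unordered_map>
#include <vector>

class CachingPool {
public:
    CachingPool() = default;
    CachingPool(const CachingPool&) = delete;             // owns raw blocks
    CachingPool& operator=(const CachingPool&) = delete;
    ~CachingPool() {
        for (auto& entry : free_lists_)
            for (void* p : entry.second) std::free(p);    // return cached blocks
    }
    void* allocate(std::size_t bytes) {
        assert(bytes > 0);
        std::vector<void*>& list = free_lists_[bytes];
        if (!list.empty()) {              // cache hit: reuse a freed block
            void* p = list.back();
            list.pop_back();
            return p;
        }
        ++backend_calls_;                 // cache miss: call the slow allocator
        return std::malloc(bytes);        // may return nullptr (not handled here)
    }
    void release(void* p, std::size_t bytes) {
        free_lists_[bytes].push_back(p);  // keep the block instead of freeing it
    }
    std::size_t backend_calls() const { return backend_calls_; }
private:
    std::unordered_map<std::size_t, std::vector<void*>> free_lists_;
    std::size_t backend_calls_ = 0;
};
```

In a loop that allocates and releases the same buffer size every iteration, only the first iteration reaches the backend allocator.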
Profile-Guided Optimisation
- Do not optimise blindly. Profile first, identify the bottleneck, optimise that, and re-profile. The roofline model (file 01) tells you whether the bottleneck is memory or compute:
  - Memory-bound (low arithmetic intensity): optimise data layout (SoA), fuse kernels, use lower precision, prefetch.
  - Compute-bound (high arithmetic intensity): use Tensor Cores, increase parallelism, use faster instructions (FMA).
  - Latency-bound (insufficient parallelism): increase occupancy, reduce register usage, launch more threads.
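A roofline check needs only two numbers per kernel (FLOPs and bytes moved) and two per GPU (peak FLOP/s and peak bandwidth). In this sketch the GPU peaks are hypothetical round figures (60 TFLOP/s FP32, 2 TB/s), and the byte counts assume each operand crosses global memory exactly once.

```cpp
#include <algorithm>
#include <cassert>

struct Roofline {
    double peak_flops;      // FLOP/s
    double peak_bandwidth;  // bytes/s

    // Arithmetic intensity (FLOP/byte) at which the memory and compute roofs meet.
    double ridge_point() const {
        assert(peak_bandwidth > 0);
        return peak_flops / peak_bandwidth;
    }
    // Best achievable FLOP/s for a kernel with the given intensity.
    double attainable(double intensity) const {
        return std::min(peak_flops, intensity * peak_bandwidth);
    }
    bool memory_bound(double intensity) const { return intensity < ridge_point(); }
};

// FP32 vector add: 1 FLOP per element, 12 bytes moved (read a, read b, write c).
double vector_add_intensity() { return 1.0 / 12.0; }

// FP32 square matmul of size n: 2n^3 FLOPs over 3n^2 floats moved (ideal reuse).
double matmul_intensity(double n) { return (2.0 * n * n * n) / (3.0 * n * n * 4.0); }
```

On this hypothetical GPU the ridge point is 30 FLOP/byte. Vector add can reach at most 2e12 / 12 ≈ 167 GFLOP/s, under 0.3% of peak whatever the FLOP rating, while a 4096 × 4096 matmul (about 683 FLOP/byte) sits far to the right of the ridge and is compute-bound.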
- Many ML workloads (element-wise ops, normalisation, small-batch inference) are memory-bound. The surprising implication: for these, a GPU with more FLOPS often helps little, while faster memory (HBM3 vs HBM2e) helps more. This is why the A100→H100 upgrade is not just about FLOPS: the H100 SXM also has roughly 1.6-2.2x the memory bandwidth (3.35 TB/s vs 1.6-2.0 TB/s).
NVIDIA GPU Generations
| Generation | Year | Key Innovation | AI Relevance |
|---|---|---|---|
| Pascal (P100) | 2016 | HBM2, NVLink | First serious deep learning GPU |
| Volta (V100) | 2017 | Tensor Cores (mixed-precision matmul) | Enabled fast FP16 training, 125 TFLOPS FP16 (Tensor Cores) |
| Ampere (A100) | 2020 | TF32, 2:4 structured sparsity, 3rd gen Tensor Cores | 312 TFLOPS FP16/BF16, 156 TFLOPS TF32 (dense) |
| Hopper (H100) | 2022 | Transformer Engine (FP8), HBM3 | ~990 TFLOPS FP16/BF16, ~1,980 TFLOPS FP8 (dense, SXM), dynamic precision switching |
| Blackwell (B200) | 2024 | 2nd gen Transformer Engine, NVLink 5 | FP4 support, multi-die design |
- Tensor Cores are specialised matrix multiply units. Each first-generation (Volta) Tensor Core performs a 4×4×4 matrix multiply-accumulate (D = A×B + C, 64 fused multiply-adds) per clock cycle, work that would take a regular CUDA core 64 FMA instructions; later generations process larger tiles per cycle. Tensor Cores are why mixed-precision training (float16 compute, float32 accumulation) is fast.
- The Transformer Engine (Hopper and later) chooses, layer by layer, between FP8 and 16-bit precision and manages the scaling factors that FP8 needs, using higher precision only where needed. The aim is to maximise throughput without sacrificing model quality. It is specifically designed for transformer architectures (attention + MLP), which dominate modern AI.
Coding Tasks (compile with nvcc)
- Write a CUDA kernel that applies ReLU to an array. Measure the time including memory transfers. This teaches kernel writing, cudaMalloc/cudaMemcpy, and the host↔device transfer bottleneck.
```cpp
// task1_relu.cu
// Compile: nvcc -O3 -o task1_relu task1_relu.cu
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>

__global__ void relu_kernel(const float* input, float* output, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        output[idx] = input[idx] > 0.0f ? input[idx] : 0.0f;
    }
}

int main() {
    const int N = 1 << 24;  // ~16M elements
    size_t bytes = N * sizeof(float);

    // Allocate host memory
    float* h_input = (float*)malloc(bytes);
    float* h_output = (float*)malloc(bytes);
    for (int i = 0; i < N; i++) {
        h_input[i] = (float)(i % 100) - 50.0f;  // mix of positive and negative
    }

    // Allocate device memory
    float *d_input, *d_output;
    cudaMalloc(&d_input, bytes);
    cudaMalloc(&d_output, bytes);

    // Time the full pipeline: copy to GPU, compute, copy back
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start);
    cudaMemcpy(d_input, h_input, bytes, cudaMemcpyHostToDevice);
    int block_size = 256;
    int grid_size = (N + block_size - 1) / block_size;
    relu_kernel<<<grid_size, block_size>>>(d_input, d_output, N);
    cudaMemcpy(h_output, d_output, bytes, cudaMemcpyDeviceToHost);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    float ms = 0;
    cudaEventElapsedTime(&ms, start, stop);

    // Verify
    int errors = 0;
    for (int i = 0; i < N; i++) {
        float expected = h_input[i] > 0.0f ? h_input[i] : 0.0f;
        if (h_output[i] != expected) errors++;
    }
    printf("Time (including transfers): %.2f ms\n", ms);
    // Bytes crossing PCIe (input in + output out) over the total time, in GB/s
    printf("Effective bandwidth: %.1f GB/s\n", 2.0 * bytes / ms / 1e6);
    printf("Errors: %d / %d\n", errors, N);

    cudaEventDestroy(start); cudaEventDestroy(stop);
    cudaFree(d_input); cudaFree(d_output);
    free(h_input); free(h_output);
    return 0;
}
```
- Write a tiled matrix multiply in CUDA using shared memory. Compare the performance against a naive (non-tiled) version. This teaches shared memory, `__syncthreads`, and why tiling matters.
```cpp
// task2_matmul.cu
// Compile: nvcc -O3 -o task2_matmul task2_matmul.cu
#include <stdio.h>
#include <cuda_runtime.h>

#define TILE 16

// Naive matmul: each thread computes one element of C
__global__ void matmul_naive(const float* A, const float* B, float* C, int N) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < N && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < N; k++) {
            sum += A[row * N + k] * B[k * N + col];
        }
        C[row * N + col] = sum;
    }
}

// Tiled matmul: use shared memory to reduce global memory accesses
__global__ void matmul_tiled(const float* A, const float* B, float* C, int N) {
    __shared__ float sA[TILE][TILE];
    __shared__ float sB[TILE][TILE];
    int row = blockIdx.y * TILE + threadIdx.y;
    int col = blockIdx.x * TILE + threadIdx.x;
    float sum = 0.0f;
    for (int t = 0; t < (N + TILE - 1) / TILE; t++) {
        sA[threadIdx.y][threadIdx.x] = (row < N && t*TILE + threadIdx.x < N)
            ? A[row * N + t*TILE + threadIdx.x] : 0.0f;
        sB[threadIdx.y][threadIdx.x] = (t*TILE + threadIdx.y < N && col < N)
            ? B[(t*TILE + threadIdx.y) * N + col] : 0.0f;
        __syncthreads();
        for (int k = 0; k < TILE; k++)
            sum += sA[threadIdx.y][k] * sB[k][threadIdx.x];
        __syncthreads();
    }
    if (row < N && col < N) C[row * N + col] = sum;
}

int main() {
    const int N = 1024;
    size_t bytes = N * N * sizeof(float);
    float *d_A, *d_B, *d_C;
    cudaMalloc(&d_A, bytes);
    cudaMalloc(&d_B, bytes);
    cudaMalloc(&d_C, bytes);

    // Initialise with ones (easy to verify: every element of C should equal N)
    float* h_A = new float[N*N];
    for (int i = 0; i < N*N; i++) h_A[i] = 1.0f;
    cudaMemcpy(d_A, h_A, bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(d_B, h_A, bytes, cudaMemcpyHostToDevice);

    dim3 block(TILE, TILE);
    dim3 grid((N+TILE-1)/TILE, (N+TILE-1)/TILE);

    // Benchmark naive (10 launches)
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start);
    for (int i = 0; i < 10; i++) matmul_naive<<<grid, block>>>(d_A, d_B, d_C, N);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    float naive_ms;
    cudaEventElapsedTime(&naive_ms, start, stop);

    // Benchmark tiled (10 launches)
    cudaEventRecord(start);
    for (int i = 0; i < 10; i++) matmul_tiled<<<grid, block>>>(d_A, d_B, d_C, N);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    float tiled_ms;
    cudaEventElapsedTime(&tiled_ms, start, stop);

    // Check one element of the tiled result (all-ones inputs, so C[0] == N)
    float c0 = 0.0f;
    cudaMemcpy(&c0, d_C, sizeof(float), cudaMemcpyDeviceToHost);
    printf("C[0] = %.0f (expected %d)\n", c0, N);

    // 2*N^3 FLOPs per launch, 10 launches; FLOP / ms / 1e6 = GFLOPS
    double gflops_naive = 2.0 * N * N * N * 10 / naive_ms / 1e6;
    double gflops_tiled = 2.0 * N * N * N * 10 / tiled_ms / 1e6;
    printf("Naive: %.2f ms, %.1f GFLOPS\n", naive_ms/10, gflops_naive);
    printf("Tiled: %.2f ms, %.1f GFLOPS\n", tiled_ms/10, gflops_tiled);
    printf("Speedup: %.1fx\n", naive_ms / tiled_ms);

    cudaFree(d_A); cudaFree(d_B); cudaFree(d_C);
    delete[] h_A;
    return 0;
}
```
- Demonstrate warp divergence. Write a kernel where threads in the same warp take different branches, and compare against a branchless version.
```cpp
// task3_divergence.cu
// Compile: nvcc -O3 -o task3_diverge task3_divergence.cu
#include <stdio.h>
#include <cuda_runtime.h>

// BAD: warp divergence (even/odd threads take different paths)
__global__ void divergent_kernel(float* data, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        if (idx % 2 == 0) {
            data[idx] = data[idx] * 2.0f + 1.0f;
        } else {
            data[idx] = data[idx] * 0.5f - 1.0f;
        }
    }
}

// GOOD: branchless (all threads execute the same instructions)
__global__ void branchless_kernel(float* data, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        float scale = (idx % 2 == 0) ? 2.0f : 0.5f;
        float offset = (idx % 2 == 0) ? 1.0f : -1.0f;
        data[idx] = data[idx] * scale + offset;
    }
}

int main() {
    const int N = 1 << 24;
    float* d_data;
    cudaMalloc(&d_data, N * sizeof(float));
    cudaMemset(d_data, 0, N * sizeof(float));
    int block = 256, grid = (N + block - 1) / block;
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    // Divergent (100 launches)
    cudaEventRecord(start);
    for (int i = 0; i < 100; i++) divergent_kernel<<<grid, block>>>(d_data, N);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    float div_ms;
    cudaEventElapsedTime(&div_ms, start, stop);

    // Branchless (100 launches)
    cudaEventRecord(start);
    for (int i = 0; i < 100; i++) branchless_kernel<<<grid, block>>>(d_data, N);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    float nodiv_ms;
    cudaEventElapsedTime(&nodiv_ms, start, stop);

    printf("Divergent: %.2f ms\n", div_ms / 100);
    printf("Branchless: %.2f ms\n", nodiv_ms / 100);
    printf("Speedup: %.2fx\n", div_ms / nodiv_ms);
    cudaFree(d_data);
    return 0;
}
```
If the two timings come out close, that is expected: this kernel is memory-bound, so the extra instructions from divergence are largely hidden behind memory latency, and the compiler may have turned the short branch into predicated instructions anyway (inspect the generated SASS with `cuobjdump -sass task3_diverge`).