CUDA programming for ML engineers
"You don't need to write kernels to need the execution model"
This chapter fills a gap. Every performance conversation in this book — FlashAttention tiling in SRAM (Chapter 25), kernel frameworks and when to reach for Triton (Chapter 38), tensor parallelism with NCCL (Chapter 12), hardware reference numbers (Appendix D) — assumes you understand how a GPU actually runs code. This chapter provides that foundation.
You will not become a kernel author from this chapter. Kernel authorship is a multi-month skill. What you will get is the mental model that makes everything else make sense: why FlashAttention is fast, why continuous batching works, why tensor parallelism uses NCCL and not MPI, why torch.cuda.synchronize() exists. The model is the thing.
Outline:
- Why ML engineers need to understand CUDA.
- GPU architecture: SMs, warps, threads, blocks.
- The memory hierarchy: registers → shared memory → L2 → HBM.
- Writing a basic CUDA kernel.
- Memory coalescing.
- Shared memory and bank conflicts.
- Occupancy and launch configuration.
- CUDA streams and asynchronous execution.
- Kernel fusion: why fewer kernel launches = faster.
- Profiling with NVIDIA Nsight.
- How PyTorch talks to CUDA.
- NCCL: collective communication on GPUs.
- The mental model recap.
39.1 Why ML engineers need to understand CUDA
Most ML engineers never write a CUDA kernel. The kernels are already written — cuBLAS for matmul, FlashAttention for attention, NCCL for collective communication. The standard advice is “just use the libraries.” That advice is correct.
But the libraries make sense only if you understand what they’re abstracting. Consider:
FlashAttention. The claim is “2–4× faster than standard attention with no quality loss.” If you’ve read Chapter 25, you know the mechanism: SRAM tiling eliminates O(s²) HBM round-trips. But why does that work? Why is HBM the bottleneck and not compute? The answer requires understanding the GPU memory hierarchy.
Continuous batching. vLLM batches multiple in-flight requests into a single forward pass per step. Why does this help? Because the GPU’s throughput machine only hits peak utilization when there’s enough work to fill its SMs. A single decode request leaves most of the GPU’s compute idle. The answer requires understanding SM occupancy.
Tensor parallelism. A model shard computes its partial results, then all-reduces across GPUs via NCCL. Why not just use MPI? Because the all-reduce needs to happen inside a CUDA stream, synchronized with the kernel that produced the partial result, on the GPU memory directly — no CPU round-trip. The answer requires understanding CUDA streams and device memory.
torch.compile. It generates Triton kernels and fuses operations. Why does fusion help? Because every un-fused operation writes intermediate results to HBM and reads them back. The answer requires understanding kernel launch overhead and the memory-bound regime.
The execution model is the prerequisite. Once you have it, you can read FlashAttention v3’s design document, evaluate whether a new kernel library is worth adopting, and diagnose GPU utilization problems in production. This is what “understanding CUDA” buys you.
39.2 GPU architecture: SMs, warps, threads, blocks
A GPU is a throughput machine — designed to run millions of operations in parallel rather than to run any single operation fast. The architecture reflects that design priority at every level.
The hierarchy, from top to bottom:
The chip. An H100 SXM5 has 132 Streaming Multiprocessors (SMs). Each SM is a semi-independent compute unit with its own registers, shared memory, and scheduler. Think of an SM as a wide, shallow CPU core: it can run many threads simultaneously but each individual thread is slow relative to a CPU core.
The SM. Each SM can run up to 2,048 threads concurrently. Those threads are grouped into warps of 32 threads each, so each SM can have up to 64 warps active. The SM has 4 warp schedulers; each cycle, each scheduler issues one instruction from one warp. An SM has 256 KB of registers and 228 KB of shared memory (on H100).
The warp. A warp is 32 threads that execute in lockstep — they all run the same instruction at the same time. This is SIMT execution: Single Instruction, Multiple Threads. If the threads in a warp take different code paths (a branch divergence), the GPU serializes the two paths, running one with the others masked off. Divergence kills parallelism.
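A toy cost model makes the divergence rule concrete. It is a minimal sketch under a simplifying assumption: a warp pays, back to back, for every branch path that at least one of its lanes takes, and reconvergence overhead is ignored. `warp_branch_cost` is a hypothetical helper, not a CUDA API.

```python
def warp_branch_cost(predicates, cost_then, cost_else):
    """Cycles a warp spends on an if/else under a simple SIMT model.

    predicates: one bool per lane (True = lane takes the 'then' path).
    A path runs (with the other lanes masked off) if any lane takes it,
    so a divergent warp pays for both paths one after the other.
    """
    cost = 0
    if any(predicates):
        cost += cost_then
    if not all(predicates):
        cost += cost_else
    return cost


WARP = 32
# Branch on lane parity: every warp diverges and pays for both paths.
by_lane = [lane % 2 == 0 for lane in range(WARP)]
# Branch on the warp's index: all 32 lanes agree, so only one path runs.
warp_id = 3
by_warp = [warp_id % 2 == 0 for _ in range(WARP)]

print(warp_branch_cost(by_lane, 100, 100))  # 200
print(warp_branch_cost(by_warp, 100, 100))  # 100
```

The design point is the granularity of the condition: branches that are uniform across a warp (derived from the warp or block index) cost nothing extra, while branches that vary lane by lane pay for both paths.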
The thread. The atomic unit of GPU execution. Each thread has its own registers and its own program counter. Threads are identified by threadIdx (within a block) and blockIdx (the block’s position in the grid). A thread knows its global index as blockIdx.x * blockDim.x + threadIdx.x.
When you launch a kernel, you specify a grid of blocks. The grid can be 1D, 2D, or 3D. Blocks are assigned to SMs by the runtime scheduler. An SM can run multiple blocks simultaneously (subject to shared memory and register constraints). Blocks are guaranteed to complete before the next kernel launch in the same stream; there is no guarantee about the order blocks run relative to each other.
The key insight from the hierarchy: the GPU hides latency through parallelism, not through caching. When a warp stalls waiting for memory, the warp scheduler switches to another warp. If you have enough warps active, the compute units stay busy during memory stalls. This is called latency hiding through occupancy, and it’s the central design principle of GPU programming.
Concrete H100 numbers worth memorizing:
- 132 SMs, 2,048 threads per SM = 270,336 concurrent threads
- 989 TFLOP/s dense BF16 (about 1,979 with 2:4 sparsity), 3.35 TB/s HBM bandwidth
- Compute-bandwidth crossover: ~295 FLOPs/byte (ops the GPU can do per byte of HBM traffic)
- Below 295 FLOPs/byte: memory-bound. Above: compute-bound.
Many attention operations land at ~64–128 FLOPs/byte (arithmetic intensity on the order of d_head), well below the crossover. Implemented the standard, unfused way, attention is memory-bound on every modern GPU.
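The crossover is simply peak compute divided by peak bandwidth. The sketch below reproduces that arithmetic with the H100 figures above and classifies two operations by arithmetic intensity. It assumes BF16 (2 bytes per element) and counts only the dominant HBM traffic; `regime` is a hypothetical helper.

```python
# H100 SXM figures quoted above: dense BF16 tensor-core peak and HBM3 bandwidth
PEAK_FLOPS = 989e12   # FLOP/s
HBM_BW = 3.35e12      # bytes/s
SMS, THREADS_PER_SM = 132, 2048

crossover = PEAK_FLOPS / HBM_BW      # FLOPs per byte where the two limits meet
resident_threads = SMS * THREADS_PER_SM


def regime(flops, bytes_moved):
    """Classify an operation by arithmetic intensity against the crossover."""
    return "memory-bound" if flops / bytes_moved < crossover else "compute-bound"


n = 4096
print(round(crossover), resident_threads)   # 295 270336
# BF16 matrix-vector product (decode-style): ~1 FLOP per byte of weights read
print(regime(2 * n * n, 2 * n * n))         # memory-bound
# BF16 square matmul: reads A and B, writes C, performs 2n^3 FLOPs
print(regime(2 * n**3, 3 * n * n * 2))      # compute-bound
```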
39.3 The memory hierarchy: registers → shared memory → L2 → HBM
The single most important thing to internalize about GPU programming is the memory hierarchy. The numbers are what drive every optimization decision.
| Level | Size (H100) | Bandwidth | Latency | Scope | Managed by |
|---|---|---|---|---|---|
| Registers | 256 KB / SM | ~19 TB/s | 0 cycles | Per-thread | Compiler |
| Shared memory (SRAM) | 228 KB / SM | ~19 TB/s | ~20 cycles | Per-block | Programmer |
| L2 cache | 50 MB total | ~12 TB/s | ~200 cycles | Chip-wide | Hardware |
| HBM | 80 GB total | 3.35 TB/s | ~400 cycles | Chip-wide | Programmer |
Registers are per-thread private storage. They’re allocated at compile time. Zero latency — register reads are free. The compiler automatically uses registers for local variables. The risk: using too many registers per thread limits how many threads can run simultaneously, hurting occupancy (§39.7).
Shared memory (SRAM) is the programmable L1 cache. It’s shared among all threads in a block. You explicitly load data from HBM into SRAM, operate on it, and write results back. Because it’s on-chip and close to the compute units, the latency is ~20 cycles and bandwidth is ~19 TB/s — nearly 6× higher throughput than HBM. This is why FlashAttention tiles in SRAM: the attention intermediates live here, never touching HBM.
L2 cache is chip-wide, automatic, and managed by hardware. It sits between SRAM and HBM and catches repeated accesses to the same HBM addresses. Useful for read-heavy patterns (e.g., weights in a decode pass), but you can’t control it directly.
HBM is the main GPU memory. Large (80 GB on H100), persistent between kernel launches, but slow — ~400 cycles latency and 3.35 TB/s bandwidth. Every time you call torch.matmul, the inputs come from HBM and the outputs go to HBM. Kernel fusion (§39.9) is profitable precisely because it eliminates intermediate HBM writes and reads between operations.
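Why this explains FlashAttention. Standard attention writes the (s × s) score matrix to HBM after computing it, reads it back to compute softmax, writes the softmax result to HBM, and reads it back to multiply with V. For s = 8192 in BF16, each pass over the score matrix moves 8192² × 2 bytes ≈ 134 MB, so those four passes add up to ≈ 537 MB of HBM traffic per head per layer, just for intermediates. FlashAttention keeps those intermediates in SRAM, one tile at a time. With d = 128, HBM traffic drops to roughly the ≈ 8 MB needed to read Q, K, V and write O. Same FLOPs, about 64× less HBM traffic (the ratio works out to s/d). Because attention is memory-bound, that traffic reduction is where FlashAttention’s 2–4× wall-clock speedup comes from; the speedup is smaller than the traffic ratio because compute and other overheads remain.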
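The traffic arithmetic is easy to check. This back-of-the-envelope sketch assumes BF16 throughout, counts four passes over the (s × s) score matrix for standard attention (write scores, read them, write probabilities, read them), and counts only the reads of Q, K, V and the write of O for the fused kernel. `attention_hbm_traffic` is a hypothetical helper.

```python
def attention_hbm_traffic(s, d, bytes_per_elem=2):
    """Rough per-head HBM traffic (bytes) for one attention forward pass."""
    io = 4 * s * d * bytes_per_elem             # read Q, K, V and write O (s x d each)
    intermediates = 4 * s * s * bytes_per_elem  # four passes over the s x s matrix
    return {"standard": io + intermediates, "flash": io,
            "intermediates": intermediates}


t = attention_hbm_traffic(s=8192, d=128)
print(round(t["intermediates"] / 1e6))   # 537 (MB)
print(round(t["flash"] / 1e6))           # 8 (MB)
print(t["intermediates"] / t["flash"])   # 64.0, i.e. s / d
```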
39.4 Writing a basic CUDA kernel
A CUDA kernel is a C++ function decorated with __global__. When launched, it runs simultaneously on many threads. Each thread computes its own piece of the output.
Vector addition in CUDA C:
```cuda
// kernel: each thread adds one pair of elements
__global__ void add_vectors(const float* a, const float* b, float* c, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        c[idx] = a[idx] + b[idx];
    }
}

// host code: launch the kernel
void launch_add(const float* a, const float* b, float* c, int n) {
    int threads_per_block = 256;
    int blocks = (n + threads_per_block - 1) / threads_per_block;
    add_vectors<<<blocks, threads_per_block>>>(a, b, c, n);
    // a, b, c must be device (HBM) pointers, e.g. allocated with cudaMalloc
}
```
The <<<blocks, threads_per_block>>> syntax is CUDA’s kernel launch notation. It specifies the grid dimensions (how many blocks) and block dimensions (how many threads per block). The runtime distributes the blocks across SMs.
Thread indexing: each thread computes idx = blockIdx.x * blockDim.x + threadIdx.x. If you launch 16 blocks of 256 threads, you get thread indices 0–4095. The if (idx < n) guard handles the case where n isn’t a multiple of the block size — the last block may have threads with idx >= n that should do nothing.
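The host-side index arithmetic can be checked without a GPU. This is a plain-Python mirror of the launch math in `launch_add`, assuming 256 threads per block; `launch_geometry` and `global_index` are hypothetical helper names.

```python
def launch_geometry(n, threads_per_block=256):
    """Mirror the host-side launch math for add_vectors."""
    blocks = (n + threads_per_block - 1) // threads_per_block  # ceil(n / tpb)
    total_threads = blocks * threads_per_block
    idle = total_threads - n  # threads whose idx >= n and hit the guard
    return blocks, total_threads, idle


def global_index(block_idx, thread_idx, block_dim=256):
    return block_idx * block_dim + thread_idx  # blockIdx.x * blockDim.x + threadIdx.x


print(launch_geometry(4096))   # (16, 4096, 0)
print(launch_geometry(1000))   # (4, 1024, 24)
print(global_index(15, 255))   # 4095, the last thread of the 16-block launch
```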
The equivalent in Triton (Python):
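```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(a_ptr, b_ptr, c_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)                    # which tile this program handles
    offs = pid * BLOCK + tl.arange(0, BLOCK)  # element offsets for this tile
    mask = offs < n                           # guard the ragged last tile
    a = tl.load(a_ptr + offs, mask=mask)
    b = tl.load(b_ptr + offs, mask=mask)
    tl.store(c_ptr + offs, a + b, mask=mask)

# launch: one program per tile of BLOCK elements
n = 1 << 20
a = torch.randn(n, device="cuda")
b = torch.randn(n, device="cuda")
c = torch.empty_like(a)
grid = lambda meta: (triton.cdiv(n, meta["BLOCK"]),)
add_kernel[grid](a, b, c, n, BLOCK=1024)  # tensors are passed as device pointers
```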
The Triton version operates on tiles (blocks of elements), not individual elements. tl.program_id(0) is the Triton equivalent of blockIdx.x. tl.arange(0, BLOCK) creates a vector of offsets — Triton’s model is explicitly tile-based, which maps well to how GPU hardware operates. The Triton compiler handles memory coalescing, shared memory allocation, and warp-level details automatically.
Choosing between CUDA C and Triton: for new kernel development, start with Triton. It produces kernels within 80–95% of hand-tuned CUDA performance at roughly 20% of the code volume. Drop to raw CUDA only when you need the last 5–10% of performance, require hardware features Triton exposes incompletely or not at all (Hopper features such as TMA and warp specialization arrived in Triton later than in CUDA), or are doing production matmul where CUTLASS templates outperform everything else.
39.5 Memory coalescing
One CUDA optimization dominates all others in practice: memory coalescing. Understanding it changes how you read kernel code and design data layouts.
A warp issues a load instruction with 32 memory addresses, one per thread. If those 32 addresses are contiguous (adjacent 4-byte words), the hardware satisfies the load with one 128-byte request, which on recent GPUs is four 32-byte sectors. If the addresses are scattered, each one can land in a different sector: up to 32 separate sectors (1,024 bytes) fetched to deliver the same 128 useful bytes. That is 8× the memory traffic, on top of the extra transactions.
The coalescing rule: thread N in a warp should access address base + N, not base + N * stride. The addresses across a warp should form a contiguous block.
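To see what scattered access costs, the sketch below counts the memory sectors one warp load touches for different strides. It assumes 32-byte sectors, 4-byte floats, and no help from caches, which simplifies real hardware; `sectors_touched` is a hypothetical helper.

```python
SECTOR_BYTES = 32  # assumed memory-access granularity (one sector)
WARP = 32


def sectors_touched(stride_elems, elem_bytes=4):
    """Distinct 32-byte sectors a warp's load touches when lane i reads
    element i * stride_elems (base address 0, caching ignored)."""
    addresses = [lane * stride_elems * elem_bytes for lane in range(WARP)]
    return len({addr // SECTOR_BYTES for addr in addresses})


useful_bytes = WARP * 4  # 32 lanes x one 4-byte float each
for stride in (1, 2, 8, 32):
    fetched = sectors_touched(stride) * SECTOR_BYTES
    print(stride, fetched, fetched / useful_bytes)
# 1 128 1.0
# 2 256 2.0
# 8 1024 8.0
# 32 1024 8.0
```

Once the stride reaches one sector (8 floats), every lane lands in its own sector and the waste stops growing; what remains is the per-request overhead of 32 separate transactions.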
Practical implications:
Row-major vs column-major for matmul. A matrix A[M][K] stored row-major means row i occupies addresses i*K to i*K+K-1. If each thread computes one row of the output, threads in a warp all access different rows — strided access into A. The standard fix is to tile: load a tile of A into SRAM (where strided access is cheap), then compute from SRAM.
Attention score layout. In the attention forward pass, Q[h, i, d] — head h, query position i, dimension d. If threads in a warp compute different positions i but the same dimension d, they access Q[h, i_warp_lane, d], which is strided by d_head. Kernels like FlashAttention carefully tile and transpose to ensure inner-loop access is coalesced.
Transposed kernels. When you need A^T @ B, naively accessing A in column order is uncoalesced. The standard fix is to load A into a transposed shared memory tile, then access from SRAM coalesced.
The practical takeaway: data layout is a first-class performance concern. When you see a kernel using __shared__ memory for intermediate tiles, coalescing is usually a primary motivation alongside SRAM bandwidth.
39.6 Shared memory and bank conflicts
SRAM is organized into 32 banks, each 4 bytes wide. In one clock cycle, the SRAM can satisfy 32 simultaneous accesses — one from each thread in a warp — as long as each access hits a different bank. If two threads in the same warp access the same bank (but different addresses within it), the accesses serialize: a bank conflict.
Bank mapping: address A maps to bank (A / 4) % 32. So addresses 0, 128, 256, … all map to bank 0; addresses 4, 132, 260, … all map to bank 1; and so on.
Tiled matrix multiplication with shared memory:
The canonical shared memory example is tiled matmul. To compute C = A @ B:
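```cuda
// Launch with dim3 block(32, 32) and
// dim3 grid((N + 31) / 32, (M + 31) / 32), so blockDim matches TILE.
__global__ void tiled_matmul(const float* A, const float* B, float* C,
                             int M, int K, int N) {
    constexpr int TILE = 32;
    __shared__ float sA[TILE][TILE];
    __shared__ float sB[TILE][TILE];
    int row = blockIdx.y * TILE + threadIdx.y;  // output row this thread owns
    int col = blockIdx.x * TILE + threadIdx.x;  // output column this thread owns
    float acc = 0.0f;
    for (int t = 0; t < (K + TILE - 1) / TILE; ++t) {  // round up: K need not divide evenly
        // Load one tile of A and one of B into SRAM. threadIdx.x varies fastest
        // within a warp, so consecutive threads read consecutive addresses (coalesced).
        // Out-of-range elements are zero-filled so they add nothing to acc.
        int a_col = t * TILE + threadIdx.x;
        int b_row = t * TILE + threadIdx.y;
        sA[threadIdx.y][threadIdx.x] = (row < M && a_col < K) ? A[row * K + a_col] : 0.0f;
        sB[threadIdx.y][threadIdx.x] = (b_row < K && col < N) ? B[b_row * N + col] : 0.0f;
        __syncthreads();  // wait for all threads to finish loading
        // Accumulate dot product from SRAM: no HBM access inside this loop
        for (int k = 0; k < TILE; ++k)
            acc += sA[threadIdx.y][k] * sB[k][threadIdx.x];
        __syncthreads();  // wait before loading next tile
    }
    if (row < M && col < N) C[row * N + col] = acc;
}
```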
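Each thread block loads a TILE × TILE tile of A and B into SRAM, computes the partial dot products from SRAM, then loads the next tile. The inner loop runs entirely from SRAM (~19 TB/s) instead of HBM (3.35 TB/s). The bigger win is reuse: each element brought into SRAM is read TILE times by the block, so with TILE = 32 the kernel issues 32× fewer global-memory loads than a naive kernel that fetches every operand from global memory. The measured speedup is usually much smaller than 32×, because L1 and L2 caches already absorb part of the naive kernel’s redundant traffic.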
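The reuse argument can be quantified by counting global-memory loads. This sketch counts float loads for a naive kernel (every output element reads a full row of A and a full column of B) and for the tiled kernel, ignoring L1/L2 caching; `global_loads` is a hypothetical helper.

```python
def global_loads(M, K, N, tile=None):
    """Float loads from global memory for C = A @ B (A is M x K, B is K x N).

    Naive: each of the M*N threads reads 2*K operands.
    Tiled: each element of A is loaded once per tile-column of C,
    and each element of B once per tile-row of C.
    """
    if tile is None:
        return M * N * 2 * K
    tiles_n = -(-N // tile)  # ceil(N / tile)
    tiles_m = -(-M // tile)  # ceil(M / tile)
    return M * K * tiles_n + K * N * tiles_m


naive = global_loads(4096, 4096, 4096)
tiled = global_loads(4096, 4096, 4096, tile=32)
print(naive // tiled)  # 32
```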
Bank conflicts in this kernel. With a 32 × 32 block, each warp is one row of threads: all 32 lanes share threadIdx.y and have threadIdx.x = 0…31. When they read sA[threadIdx.y][k] in the inner loop, they all read the same address. That’s a broadcast, not a conflict: the hardware serves a same-address access to the whole warp at once. Reading sB[k][threadIdx.x], the lanes access columns 0, 1, 2, …, 31 of the same row: consecutive 4-byte words, which fall in 32 different banks. No conflict.
When bank conflicts do occur. Consider transposing a matrix via SRAM:
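```cuda
// Assumes input is M x N, output is N x M, both multiples of 32; block is 32 x 32.
__shared__ float tile[32][32];
// Load: thread (ty, tx) reads row-major from HBM (coalesced)
int row = blockIdx.y * 32 + threadIdx.y, col = blockIdx.x * 32 + threadIdx.x;
tile[threadIdx.y][threadIdx.x] = input[row * N + col];
__syncthreads();
// Store: swap the block indices so threadIdx.x still walks consecutive output addresses (coalesced)
int out_row = blockIdx.x * 32 + threadIdx.y, out_col = blockIdx.y * 32 + threadIdx.x;
output[out_row * M + out_col] = tile[threadIdx.x][threadIdx.y];  // TRANSPOSED access
```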
The load from SRAM uses tile[threadIdx.x][threadIdx.y]. All 32 threads in a warp have the same threadIdx.y but different threadIdx.x. They’re accessing tile[0][ty], tile[1][ty], …, tile[31][ty] — elements in the same column, which are 32 floats (128 bytes) apart. That maps to (address/4) % 32 = ty % 32 — all 32 threads hit bank ty. 32-way bank conflict.
The standard fix: pad the SRAM allocation by one element per row.
```cuda
__shared__ float tile[32][33];  // 33 instead of 32: one float of padding per row
```
Now the column-access addresses are 33 * i + ty for thread i. The bank is (33i + ty) % 32. Since gcd(33, 32) = 1, all 32 threads hit different banks. Conflict eliminated. One extra float per row is a small price for full bandwidth.
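The bank arithmetic for both tile shapes fits in a few lines. This sketch assumes 32 banks of 4-byte words and a warp in which lane i reads tile[i][col], as in the transposed read above; `conflict_degree` is a hypothetical helper that reports the largest number of lanes sharing one bank.

```python
from collections import Counter

BANKS = 32


def conflict_degree(row_stride, col=0, warp=32):
    """Worst-case lanes per bank when lane i reads tile[i][col] from a float
    tile whose rows are row_stride floats long (all addresses distinct)."""
    words = [i * row_stride + col for i in range(warp)]  # 4-byte word index = byte address / 4
    banks = Counter(w % BANKS for w in words)            # bank = word index % 32
    return max(banks.values())


print(conflict_degree(32))  # 32: every lane lands in the same bank
print(conflict_degree(33))  # 1: the padding spreads the column over all banks
```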
39.7 Occupancy and launch configuration
Occupancy is the ratio of active warps to the maximum warps that could run on an SM. An H100 SM supports 64 warps. If your kernel runs with 32 active warps per SM, occupancy is 50%.
Why does occupancy matter? Because occupancy enables latency hiding. When one warp stalls on a memory load (takes ~400 cycles for HBM), the SM switches to another ready warp. If there are no other ready warps, the SM stalls, and you’re leaving compute cycles on the table.
The limiting resources for occupancy are:
Registers per thread. Each SM has 65,536 32-bit registers (256 KB). If your kernel uses 64 registers per thread, and a block has 256 threads, that’s 16,384 registers per block. You can fit 4 blocks per SM (4 × 16,384 = 65,536), each with 8 warps → 32 active warps → 50% occupancy. Using fewer registers per thread allows more threads to run simultaneously.
Shared memory per block. Each SM has 228 KB of shared memory on H100 (configurable between L1 and SRAM). If your kernel allocates 64 KB of shared memory per block, you can fit at most 3 blocks per SM (3 × 64 = 192 KB < 228 KB). Smaller shared memory allocations allow more blocks.
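Threads per block. Blocks must have at most 1,024 threads. Too few threads per block (say 32, one warp) caps occupancy through the per-SM block limit: an H100 SM holds at most 32 resident blocks, so 32-thread blocks top out at 32 warps, or 50% occupancy.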
Computing theoretical occupancy. Given a kernel with R registers/thread, S bytes of SRAM/block, B threads/block:
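```python
def theoretical_occupancy(R, S, B, regs_per_sm=65536, sram_per_sm=228 * 1024,
                          max_threads_per_sm=2048, max_blocks_per_sm=32):
    """H100 defaults. Ignores register and SRAM allocation granularity."""
    warps_per_block = B // 32
    blocks_limited_by_registers = regs_per_sm // (R * B)
    blocks_limited_by_sram = sram_per_sm // S if S > 0 else max_blocks_per_sm
    blocks_limited_by_threads = max_threads_per_sm // B
    active_blocks_per_SM = min(blocks_limited_by_registers, blocks_limited_by_sram,
                               blocks_limited_by_threads, max_blocks_per_sm)
    active_warps = active_blocks_per_SM * warps_per_block
    return active_warps / 64  # occupancy = active warps / 64 max warps per SM

print(theoretical_occupancy(R=64, S=0, B=256))  # 0.5, the register-limited example above
```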
The occupancy trap. More occupancy is not always better. A kernel that uses many registers per thread runs fewer threads per SM but each thread does more useful work per cycle. A register-heavy kernel with 50% occupancy can outperform a register-sparse kernel with 100% occupancy if the former makes better use of registers to avoid redundant memory loads. Profile, don’t theorize.
The practical heuristic: start with 256 threads per block. It’s a multiple of 32 (full warps), fits in any reasonable register budget, and gives the compiler flexibility. If profiling shows the schedulers often have no eligible warp to issue from, increase occupancy. If warps are mostly stalled on memory dependencies while achieved HBM bandwidth is already near peak, you’re bandwidth-bound and more occupancy won’t help: fix coalescing or add SRAM tiling instead.
39.8 CUDA streams and asynchronous execution
By default, CUDA kernels are asynchronous from the CPU’s perspective: kernel<<<...>>>() returns before the kernel finishes. The GPU runs the kernel while the CPU continues executing. Operations dispatched in the same CUDA stream execute in order; operations in different streams may overlap.
Default stream semantics in PyTorch:
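```python
import torch

model = torch.nn.Linear(1024, 1024).cuda()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(64, 1024, device="cuda")
target = torch.randn(64, 1024, device="cuda")

# All of this dispatches to the default CUDA stream
y = model(x)                 # kernel launched, CPU returns immediately
loss = criterion(y, target)  # depends on y; the stream runs it after y's kernels
loss.backward()              # backward kernels queued after the loss
optimizer.step()             # queued after the gradients
# The CPU is now free to run ahead (e.g. prepare the next batch)
# while the GPU is still working through the queued kernels.

# Synchronize: CPU blocks until all queued GPU work on the device is done
torch.cuda.synchronize()
# After this returns, y, loss, and the gradients are all computed.
```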
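torch.cuda.synchronize() exists because operations like timing need to know that the GPU has finished. Without it, a CPU timer around the code above measures only how long it took to enqueue the kernels, not how long they ran. Copies back to the CPU such as .item() or .cpu() synchronize implicitly; the risk of reading a tensor the GPU hasn’t finished writing arises when you bypass that, for example with non_blocking=True copies.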
Non-default streams for overlap:
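```python
import torch
def heavy_kernel_a(t):  # hypothetical stand-in for any large GPU op
    return t @ t
heavy_kernel_b = heavy_kernel_a  # same stand-in for the second stream
x = torch.randn(4096, 4096, device="cuda")
y = torch.randn(4096, 4096, device="cuda")
stream_a = torch.cuda.Stream()
stream_b = torch.cuda.Stream()
# x and y were created on the default stream: make the side streams wait for them
stream_a.wait_stream(torch.cuda.current_stream())
stream_b.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream_a):
    result_a = heavy_kernel_a(x)  # dispatched to stream_a
with torch.cuda.stream(stream_b):
    result_b = heavy_kernel_b(y)  # dispatched to stream_b; may overlap with a
# Synchronize both streams before using results on CPU
stream_a.synchronize()
stream_b.synchronize()
```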
If the GPU has enough resources (SMs), heavy_kernel_a and heavy_kernel_b run simultaneously. This is the mechanism behind overlapping communication with compute: PyTorch DDP, for example, runs its NCCL gradient all-reduces on a separate stream while backward keeps computing gradients for earlier layers, and tensor-parallel systems apply the same idea.
Practical relevance for serving. In a tensor-parallel serving setup, each GPU computes its shard, NCCL all-reduces, then the result feeds the next layer. The naive approach serializes: compute → communicate → compute → communicate. With streams, you can start the next layer’s matmul on the already-finished portion of the all-reduce output while the all-reduce for the last portion is still running. This overlap is called pipelining communication and compute, and it’s a standard technique in production serving systems.
39.9 Kernel fusion: why fewer kernel launches = faster
Every kernel launch has overhead: ~5–10 microseconds to dispatch from the CPU, plus the cost of reading all inputs from HBM and writing all outputs back to HBM. For operations that are individually cheap, this overhead dominates.
Consider three consecutive elementwise operations:
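```python
import torch

a, b = torch.randn(2, 1 << 24, device="cuda")  # two 16M-element fp32 vectors
scale = 0.5
# Unfused: 3 kernel launches, 3 HBM read/write round-trips
x = a + b           # kernel 1: reads a, b from HBM, writes x to HBM
y = torch.relu(x)   # kernel 2: reads x from HBM, writes y to HBM
z = y * scale       # kernel 3: reads y from HBM, writes z to HBM
```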
Fusing these three operations into one kernel: the kernel loads a and b from HBM into registers, computes a + b, relu, and * scale in registers, and writes only the final result to HBM. The intermediates x and y live in registers and never touch HBM. HBM traffic drops from 4 reads + 3 writes (7 tensor-sized transfers) to 2 reads + 1 write (3 transfers).
For elementwise ops on large tensors, the speedup is roughly proportional to that reduction in traffic: here about 7/3 ≈ 2.3× less time in the memory-bound case, plus the two kernel launches saved.
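The read/write bookkeeping generalizes to any chain of elementwise ops. This sketch counts tensor-sized HBM transfers, assuming each op after the first consumes only the previous result and that scalars such as `scale` cost nothing; `chain_traffic` is a hypothetical helper.

```python
def chain_traffic(num_inputs, num_ops, fused):
    """Tensor-sized HBM transfers (reads, writes) for a chain of elementwise ops.

    The first op reads num_inputs tensors; every later op reads only the
    previous result. Unfused, each op writes its output back to HBM; fused,
    only the last result is written.
    """
    if fused:
        return num_inputs, 1
    return num_inputs + (num_ops - 1), num_ops


unfused = chain_traffic(num_inputs=2, num_ops=3, fused=False)  # a + b, relu, * scale
fused = chain_traffic(num_inputs=2, num_ops=3, fused=True)
print(unfused, fused)                       # (4, 3) (2, 1)
print(round(sum(unfused) / sum(fused), 2))  # 2.33
```

Longer chains widen the gap: each extra unfused op adds one read and one write, while the fused version stays at the same cost.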
FlashAttention as the canonical fusion example. Standard attention is three separate operations: Q @ K^T, softmax, attn @ V. Each writes to HBM and the next reads from HBM. FlashAttention fuses all three into one kernel that computes the (s × s) scores tile by tile in SRAM, using an online softmax so the full score matrix is never written to HBM. The fusion is the speedup; the result is mathematically the same.
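The trick that lets a fused kernel avoid materializing the scores is an online softmax: process keys in tiles, keep a running maximum and running sums, and rescale the partial sums whenever the maximum grows. The sketch below shows that rescaling for a single query with scalar values (d = 1) in plain Python. It illustrates the technique FlashAttention builds on, not FlashAttention’s actual kernel, and the function names are hypothetical.

```python
import math


def attention_row_reference(scores, values):
    """softmax(scores) @ values for one query, materializing every probability."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def attention_row_tiled(scores, values, tile=4):
    """Same quantity, visiting keys one tile at a time with an online softmax.

    Only a running max, running denominator and running weighted sum survive
    between tiles, so the full row of probabilities is never stored.
    """
    running_max, denom, acc = float("-inf"), 0.0, 0.0
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        new_max = max(running_max, max(s_tile))
        rescale = math.exp(running_max - new_max)  # shrink earlier partial sums
        weights = [math.exp(s - new_max) for s in s_tile]
        denom = denom * rescale + sum(weights)
        acc = acc * rescale + sum(w * v for w, v in zip(weights, v_tile))
        running_max = new_max
    return acc / denom


scores = [0.3, 2.1, -1.0, 0.7, 1.5, -0.2, 0.9, 3.0, -2.2]
values = [1.0, -0.5, 2.0, 0.0, 1.5, -1.0, 0.25, 0.75, 3.0]
print(math.isclose(attention_row_reference(scores, values),
                   attention_row_tiled(scores, values)))  # True
```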
torch.compile generating fused kernels. When you run model = torch.compile(model), PyTorch’s TorchInductor backend analyzes the operation graph and fuses compatible elementwise operations into single Triton kernels. A typical transformer forward pass might go from ~300 kernel launches to ~50. The speedup varies: 1.5–2× is common, 3× is possible for models with many small elementwise ops.
When fusion matters most. Fusion is most valuable when:
- Operations are memory-bound (arithmetic intensity below the ~295 FLOPs/byte crossover on H100).
- Intermediate tensors are large (more HBM traffic saved).
- Operation granularity is small (launch overhead is a non-trivial fraction of runtime).
Matmul (typically compute-bound for large shapes) benefits less from fusion with adjacent elementwise ops because the matmul dominates.
39.10 Profiling with NVIDIA Nsight
Theory without measurement is guessing. NVIDIA provides two profilers:
Nsight Systems (nsys) for timeline analysis: what kernels ran, in what order, what the SM utilization was, whether NCCL and compute overlapped, whether there were idle gaps between kernels.
Nsight Compute (ncu) for kernel-level analysis: inside a single kernel, what were the bottlenecks? Memory throughput vs compute throughput, warp stall reasons, cache hit rates.
A 5-minute profiling workflow:
Step 1: Profile with Nsight Systems to find the hot kernels.
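```shell
nsys profile --trace=cuda,nvtx -o report python my_script.py
nsys-ui report.nsys-rep      # open the timeline in the GUI
nsys stats report.nsys-rep   # or print summary tables, including per-kernel GPU time
```
Look at the GPU kernel rows of the timeline (or the kernel summary from nsys stats). Find the top 3 kernels by total GPU time. Those are your targets.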
Step 2: Profile the target kernel with Nsight Compute.
```shell
ncu --set full -k target_kernel_name python my_script.py
```
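Step 3: Read the roofline. Nsight Compute shows a roofline model: achieved performance (FLOP/s) on the y-axis against arithmetic intensity (FLOPs/byte) on the x-axis. The roof has a sloped memory-bandwidth segment and a flat compute segment that meet at the ridge point (~295 FLOPs/byte for dense BF16 on H100). Your kernel appears as a point. If its arithmetic intensity is left of the ridge, it’s memory-bound and sits under the bandwidth slope; if it’s right of the ridge, it’s compute-bound. The vertical gap between the point and the roof is the performance you’re leaving on the table.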
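The roofline placement itself is a few lines of arithmetic. This sketch assumes the H100 figures quoted earlier and a hypothetical measured runtime; `roofline` is an illustrative helper, not part of Nsight Compute.

```python
def roofline(flops, bytes_moved, seconds, peak_flops=989e12, peak_bw=3.35e12):
    """Place one kernel on a roofline. Defaults: H100 SXM dense BF16 peak and
    HBM3 bandwidth. Returns (intensity, fraction of the roof reached, regime)."""
    intensity = flops / bytes_moved              # x-axis: FLOPs per byte
    achieved = flops / seconds                   # y-axis: FLOP/s
    roof = min(peak_flops, intensity * peak_bw)  # attainable at this intensity
    ridge = peak_flops / peak_bw                 # where the two roofs meet
    regime = "memory-bound" if intensity < ridge else "compute-bound"
    return intensity, achieved / roof, regime


# Hypothetical measurement: a BF16 matrix-vector product with an 8192 x 8192
# matrix (2 FLOPs and ~2 bytes of matrix traffic per element) taking 45 us.
n = 8192
ai, frac, regime = roofline(2 * n * n, 2 * n * n, 45e-6)
print(ai, round(frac, 2), regime)  # 1.0 0.89 memory-bound
```

Here the kernel sits at 89% of the bandwidth roof, so the remaining headroom is small; a faster version would need to move fewer bytes, not compute faster.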
Step 4: Check warp stall reasons. The “Warp State Statistics” section shows why warps weren’t issuing instructions:
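- Stall Long Scoreboard: waiting on a global or local memory load (through L1TEX), usually an HBM load that hasn’t returned yet. Fix: add more warps (increase occupancy) or issue loads earlier (prefetch).
- Stall MIO Throttle: the memory input/output instruction queue is full, often because of heavy shared-memory traffic. Fix: issue fewer, wider shared-memory accesses.
- Stall Barrier: waiting at __syncthreads(). Fix: restructure the synchronization points or balance the work between them.
- Stall No Instruction: waiting for an instruction to be fetched, for example after an instruction-cache miss. Fix: reduce code size (e.g., less aggressive loop unrolling).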
Step 5: Check memory throughput. If achieved bandwidth is 2.8 TB/s (on an H100 with 3.35 TB/s peak), you’re getting 84% of peak: good. If a memory-bound kernel reaches only 1.0 TB/s, suspect a coalescing or access-pattern problem first.
Key metrics at a glance:
| Metric | Good | Bad | Fix |
|---|---|---|---|
| Achieved occupancy | > 60% | < 30% | Reduce registers or SRAM per block |
| Memory throughput | > 80% of peak | < 50% | Fix coalescing, increase locality |
| Compute throughput | > 80% of peak | < 50% | Increase arithmetic intensity |
| Warp stall: long scoreboard | < 10% | > 50% | Increase occupancy, prefetch |
The most common pattern for ML kernels: memory-bound, moderate occupancy, high long-scoreboard stalls. The fix is usually SRAM tiling or fusion, not occupancy tuning.
39.11 How PyTorch talks to CUDA
When you write torch.matmul(A, B) in Python, the call chain is:
- Python dispatch. The Python torch.matmul call enters the ATen dispatcher.
- ATen dispatcher. Selects the appropriate kernel based on dtype, device, and shape. For CUDA float/bfloat16 tensors, it dispatches to cuBLAS (or cuBLASLt, depending on dtype and shape).
- cuBLAS GEMM. cuBLAS selects the best kernel heuristically from its library of pre-compiled kernels, based on matrix shapes, alignment, and GPU architecture. This kernel runs on the GPU.
- Return. Control returns to Python. The output tensor wraps a device pointer, and the CUDA kernel may still be running asynchronously.
```text
torch.matmul(A, B)
  → ATen dispatcher
    → at::cuda::blas::gemm(...)        (C++ ATen op)
      → cublasGemmEx(...)              (cuBLAS GEMM API)
        → <cuBLAS-internal kernel><<<...>>>()   (the actual GPU kernel; name varies by GPU and cuBLAS version)
```
For convolutions: torch.nn.functional.conv2d → ATen dispatcher → cuDNN cudnnConvolutionForward.
For custom operations not in cuBLAS/cuDNN, PyTorch provides two paths:
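Custom CUDA extensions. Write a __global__ kernel, compile it with torch.utils.cpp_extension (load for just-in-time builds, CUDAExtension for setup.py builds), and call it from Python. Custom kernels such as vLLM’s paged attention reach PyTorch as compiled extensions of this kind.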
Triton via torch.compile. Write a Triton kernel decorated with @triton.jit, call it from Python. Or let torch.compile (TorchInductor backend) auto-generate Triton kernels from your PyTorch code.
The memory model. PyTorch tensors on CUDA are backed by device memory that PyTorch’s caching allocator obtains with cudaMalloc and reuses across tensors. The tensor’s .data_ptr() is the raw device (HBM) address. Operations on CUDA tensors are dispatched to kernels that access these HBM addresses. No data moves between CPU and GPU unless you ask for it: .cpu(), .cuda(), .to(device), .item(), or async copies from pinned memory.
The gradient tape. PyTorch’s autograd tracks operations in a computation graph. For CUDA operations, the backward pass calls corresponding CUDA kernels. loss.backward() dispatches a sequence of CUDA kernels that compute gradients via the chain rule. These run in the same CUDA stream as the forward pass by default.
39.12 NCCL: collective communication on GPUs
When a model is sharded across multiple GPUs (tensor parallelism, data parallelism), those GPUs need to exchange tensors. NCCL (NVIDIA Collective Communications Library) is the standard library for this.
The four collectives you need to know:
All-reduce. Every GPU has a local tensor; after all-reduce, every GPU has the sum (or mean) across all GPUs. Used in data parallelism for gradient averaging: each GPU computes its own batch gradients, all-reduce sums them, every GPU updates with the same aggregated gradient.
All-gather. Each GPU has a shard; after all-gather, every GPU has all the shards concatenated. Used in tensor parallelism when you need the full tensor after computing on a shard (e.g., gathering embedding shards).
Reduce-scatter. Each GPU has a full tensor; after reduce-scatter, GPU i has the reduced (summed) shard i. Used in tensor parallelism during the forward pass: compute on shards, reduce-scatter the partial results so each GPU ends up with a different portion of the output.
Broadcast. One GPU sends its tensor to all others. Used during model loading and weight initialization.
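Ring topology. One of the algorithms NCCL uses for all-reduce is a ring (it also has tree-based algorithms and picks one based on topology and message size). The N GPUs are arranged in a ring, and each GPU’s tensor is split into N shards. Data circulates in two phases. In the reduce-scatter phase, each GPU sends one shard to its neighbor and adds the shard it receives into its own copy; after N-1 steps, each GPU holds one fully summed shard. In the all-gather phase, those summed shards circulate for another N-1 steps until every GPU has all of them. The communication volume is 2(N-1)/N × tensor_size per GPU, which is nearly optimal.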
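The two phases are easier to follow in a simulation. This plain-Python sketch models each GPU’s tensor as N one-number chunks and assumes all sends in a step happen simultaneously. It illustrates the classic ring algorithm, not NCCL’s actual implementation, and `ring_all_reduce` is a hypothetical name.

```python
def ring_all_reduce(buffers):
    """Simulate a ring all-reduce (sum) across len(buffers) GPUs.

    buffers[r] is GPU r's tensor split into N chunks; each chunk is a single
    number here to keep the sketch small. Returns the final per-GPU buffers
    and how many chunks each GPU sent.
    """
    n = len(buffers)
    bufs = [list(b) for b in buffers]
    sent = [0] * n

    # Reduce-scatter: at step k, GPU r sends chunk (r - k) % n to GPU r + 1,
    # which adds it into its own copy. Messages are captured before any GPU
    # updates, since all sends in a step happen simultaneously.
    for step in range(n - 1):
        msgs = [(r, (r - step) % n, bufs[r][(r - step) % n]) for r in range(n)]
        for r, chunk, value in msgs:
            bufs[(r + 1) % n][chunk] += value
            sent[r] += 1
    # GPU r now holds the fully summed chunk (r + 1) % n.

    # All-gather: circulate the summed chunks for another N - 1 steps,
    # overwriting instead of adding.
    for step in range(n - 1):
        msgs = [(r, (r + 1 - step) % n, bufs[r][(r + 1 - step) % n]) for r in range(n)]
        for r, chunk, value in msgs:
            bufs[(r + 1) % n][chunk] = value
            sent[r] += 1
    return bufs, sent


n_gpus = 4
buffers = [[float(10 * r + c) for c in range(n_gpus)] for r in range(n_gpus)]
result, sent = ring_all_reduce(buffers)
expected = [sum(buffers[r][c] for r in range(n_gpus)) for c in range(n_gpus)]
print(all(row == expected for row in result))  # True
print(sent)  # [6, 6, 6, 6]: 2(N-1) chunks, each 1/N of the tensor
```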
Why NVLink matters. NVLink is NVIDIA’s high-bandwidth GPU-to-GPU interconnect. An H100 SXM has NVLink bandwidth of 900 GB/s bidirectional per GPU in an 8-GPU NVLink switch configuration. PCIe 5.0 x16 offers ~128 GB/s bidirectional. For a 1 GB tensor in a TP=8 setup, a rough estimate (about 2 × tensor size moved per GPU, divided by link bandwidth) gives:
- NVLink: ~2 ms (1 GB × 2 / 900 GB/s)
- PCIe: ~16 ms (1 GB × 2 / 128 GB/s)
The NVLink all-reduce completes in the time it takes the matrix multiply to run on the partial shard. The PCIe version does not. This is why TP > 4 effectively requires NVLink — with PCIe, communication dominates compute and efficiency collapses.
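The estimate above can be reproduced in a few lines. This sketch uses the same rough model (bytes moved per GPU divided by quoted link bandwidth) and can optionally swap the factor of 2 for the ring’s 2(N-1)/N. It ignores latency, protocol overhead, and whether the quoted bandwidth is per direction; `allreduce_seconds` is a hypothetical helper.

```python
def allreduce_seconds(tensor_bytes, link_bytes_per_s, n_gpus=None):
    """Rough all-reduce time: bytes moved per GPU / link bandwidth.

    With n_gpus=None this uses the shortcut of 2 x tensor size; passing
    n_gpus uses the ring's 2(N-1)/N factor instead.
    """
    factor = 2.0 if n_gpus is None else 2 * (n_gpus - 1) / n_gpus
    return factor * tensor_bytes / link_bytes_per_s


GB = 1e9
for name, bw in (("NVLink", 900 * GB), ("PCIe 5.0", 128 * GB)):
    print(name, round(allreduce_seconds(1 * GB, bw) * 1e3, 1), "ms")
# NVLink 2.2 ms
# PCIe 5.0 15.6 ms
```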
NCCL in PyTorch. PyTorch’s torch.distributed wraps NCCL. The idioms:
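```python
import os
import torch
import torch.distributed as dist

# One process per GPU, launched with e.g. `torchrun --nproc_per_node=8 script.py`
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
torch.cuda.set_device(local_rank)
# Initialize the process group (one call at startup)
dist.init_process_group(backend="nccl")
world_size = dist.get_world_size()
tensor = torch.full((1024,), float(dist.get_rank()), device="cuda")

# Gradient all-reduce (called automatically by DDP during backward)
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

# Manual all-gather
output_list = [torch.zeros_like(tensor) for _ in range(world_size)]
dist.all_gather(output_list, tensor)
```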
In practice, you rarely call NCCL directly. PyTorch DDP calls all_reduce during backward(), running it on a separate NCCL stream so communication overlaps with the rest of the backward pass. vLLM’s tensor-parallel layers call an all-reduce (through NCCL or vLLM’s own custom all-reduce kernel) on the output of each row-parallel linear layer.
39.13 The mental model recap
Five things every ML engineer should remember from this chapter:
(1) GPU is a throughput machine, not a latency machine. A single operation is slow on a GPU: it takes ~400 cycles just to load from HBM. The GPU wins by keeping up to 270,336 threads resident at once, hiding the latency of any one thread behind the work of thousands of others. Design for throughput: feed the GPU enough work to keep all its SMs busy.
(2) HBM bandwidth is the bottleneck, not compute. The H100 can do 989 TFLOP/s of BF16 math, but its HBM delivers only 3.35 TB/s. Most ML operations (attention, elementwise ops, layer norm, sampling) have arithmetic intensity below the ~295 FLOPs/byte crossover. They are memory-bound. Adding compute hardware does not help. Only reducing HBM traffic helps. This is why FlashAttention matters, why quantization matters (fewer bytes = less HBM traffic), and why batching matters (amortizes weight reads across more tokens).
(3) Shared memory (SRAM) is your secret weapon. It’s 6× faster than HBM, programmer-controlled, and the basis for every major ML kernel optimization. Tiling, caching weight tiles, keeping attention intermediates in SRAM — all of these exploit SRAM bandwidth. When a kernel is unusually fast, SRAM is usually why.
(4) Fuse kernels to avoid HBM round-trips. Every unfused kernel reads its inputs from HBM and writes its outputs to HBM. For memory-bound operations, this overhead dominates. FlashAttention fuses three operations, torch.compile fuses tens of them. The principle is the same: keep data in registers and SRAM, minimize HBM writes until you have a final result.
(5) Profile before optimizing. The above principles are true on average but not universal. A kernel might be compute-bound, not memory-bound. A kernel might already have enough occupancy to hide latency, so raising it buys nothing. Optimization based on theory without profiling routinely produces surprises. ncu and nsys take 5 minutes to run and give you ground truth. Use them.
Read it yourself
- NVIDIA CUDA C++ Programming Guide. Canonical reference. Read the programming model, memory hierarchy, and performance guidelines sections first.
- Triton documentation and tutorials at triton-lang.org. The “vector addition” and “matrix multiplication” tutorials cover everything in §39.4–§39.6.
- NVIDIA Nsight Compute documentation. Especially the “Roofline Analysis” and “Memory Workload Analysis” sections.
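- Hwu, Kirk, and El Hajj, Programming Massively Parallel Processors (4th edition, 2022). The definitive textbook. The chapters on compute architecture, memory architecture and data locality, and performance considerations are essential, and the deep learning chapter connects them to ML workloads.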
- Simon Boehm’s blog post “How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance.” A worked example taking a naive matmul to 90% of cuBLAS performance step by step — each step corresponds to a concept in this chapter.
- The NCCL GitHub repository README. Clear explanation of the ring algorithm and topology selection.
Practice
- An H100 kernel processes a tensor of shape (4096, 4096) in BF16 with one elementwise operation per element. Compute the HBM traffic (bytes read + written). At 3.35 TB/s, how long should it take? How many FLOPs does it do? Is it compute-bound or memory-bound?
- A kernel uses 64 registers per thread and 48 KB of shared memory per block, with 256 threads per block. Compute the theoretical occupancy on an H100 SM (65,536 registers, 228 KB SRAM, max 2,048 threads, max 32 blocks per SM).
- Explain why declaring the transpose tile as float tile[32][32] causes a 32-way bank conflict on column reads, but float tile[32][33] does not. What is the relationship between the row stride and the number of banks?
- A Triton kernel processes a row of 8192 elements in tiles of 128. How many tl.program_id(0) values does the kernel run with? How does this map to warp count on the GPU?
- Trace the call path from torch.nn.Linear.forward(x) to an actual cuBLAS GEMM kernel. Identify each layer: Python, ATen, cuBLAS API, kernel. Use torch.profiler or nsys to confirm.
- In a TP=8 configuration on an 8-GPU node connected via NVLink (900 GB/s), each GPU shards a weight matrix of total size 2 GB. After multiplying by a local input shard, each GPU has a partial output of size 512 MB. How long does the all-reduce take? Compare to a PCIe-connected node at 128 GB/s. At what model layer size does the communication exceed the compute time (assume 1 ms per matmul step)?
- Stretch: Implement the tiled matrix multiplication kernel from §39.6 in Triton. Compare performance to torch.matmul on shapes (1024, 1024) @ (1024, 1024) and (4096, 4096) @ (4096, 4096). Profile with ncu and identify whether your kernel is memory-bound or compute-bound. Check ncu’s shared-memory bank-conflict metrics; Triton chooses shared-memory layouts itself, so to experiment with padding, apply it to the CUDA C version from §39.6 and measure the improvement.