Part III · Inference Internals & Production Serving
Chapter 25 Core ~18 min read

FlashAttention and the GPU memory hierarchy

“The bottleneck of attention is not multiplication. It is reads and writes.”

In Chapters 22–24 we built the data structure (KV cache) and the storage layout (PagedAttention) for the keys and values that attention needs. In this chapter we look at how to compute the attention itself efficiently. The technique is FlashAttention (Dao et al., 2022), and like most great ideas, it’s a memory-access optimization disguised as a math optimization.

The standard attention formulation we built in Chapter 6 — softmax(Q K^T / √d_k) V — has a hidden cost: it materializes the (s × s) attention matrix as an intermediate. For long sequences this matrix is enormous and its read/write traffic dominates the total runtime. FlashAttention rewrites the computation to never materialize the full matrix by computing it in tiles that fit in GPU SRAM. The result is a 2–4× speedup on attention with no quality loss, and it’s the kernel that every modern serving stack uses.

This chapter goes deep on the GPU memory hierarchy that makes FlashAttention work, the tiling math, and why the technique was such a breakthrough.

Outline:

  1. The GPU memory hierarchy.
  2. The arithmetic intensity of attention.
  3. The standard attention’s memory cost.
  4. The tiling idea.
  5. The online softmax trick.
  6. FlashAttention v1, v2, v3.
  7. Why FlashAttention isn’t faster in FLOPs but is faster in wall-clock.
  8. FlashAttention with PagedAttention.
  9. The state of attention kernels in 2025.

25.1 The GPU memory hierarchy

You can’t understand FlashAttention without understanding how GPU memory is organized. There are several levels:

HBM (High Bandwidth Memory). The main GPU DRAM. On an H100, ~80 GB. The bandwidth is high in absolute terms (~3 TB/s) but low relative to compute (the compute-bandwidth crossover is ~330 FLOPs/byte). This is where your tensors live.

L2 cache. On-chip cache. ~50 MB on H100. Faster than HBM (~5 TB/s effective) but smaller. Shared across all SMs. Mostly automatic.

SRAM (a.k.a. shared memory, scratchpad). On-chip, per-SM. ~228 KB per SM on H100, ~33 MB total across all SMs. Very fast (~19 TB/s effective). Programmer-controlled — you have to explicitly load data into it. This is where the speed lives.

Registers. Per-thread storage. Tiny (a few KB per thread) but instantaneous.

[Figure: the GPU memory hierarchy. HBM: 80 GB, 3 TB/s (large, slow, programmer-invisible). L2 cache: 50 MB, ~5 TB/s (medium, automatic caching). SRAM: 33 MB, ~19 TB/s (programmer-controlled; FlashAttention keeps intermediates here). Registers: per-thread, fastest, smallest.]
SRAM is 6× faster than HBM but 2000× smaller — FlashAttention's entire speedup comes from keeping the O(s²) attention intermediates in SRAM so they never touch the slow HBM path.

The key insight: the bandwidth gap between HBM and SRAM is roughly 6×. If you can keep data in SRAM for as much computation as possible — only loading from HBM occasionally — you can achieve much higher effective throughput than if you stream data through HBM constantly.

This is the memory hierarchy game: structure your computation so that data lives in the fastest memory tier for as long as possible. Modern matmul kernels (CUTLASS, cuBLAS) are masters of this: they tile the matmul into chunks small enough to fit in SRAM, do all the compute on the tile, write the result back to HBM, repeat.

FlashAttention applies the same idea to attention.

25.2 The arithmetic intensity of attention

Recall the attention computation:

S = Q K^T / √d_k     # (s, s)
P = softmax(S)       # (s, s)
O = P V              # (s, d)

The intermediates S and P are (s, s) tensors. For long sequences they’re enormous: at s = 8192 and bf16, each one is 8192 × 8192 × 2 = 134 MB. The kernel writes S to HBM, reads it back to compute P, then writes P and reads it back to compute O.
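To make the intermediates explicit, here is a minimal NumPy sketch of standard attention (illustrative only, not a production kernel): both S and P are materialized as full (s, s) arrays.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive attention: materializes the full (s, s) score matrix."""
    d_k = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d_k)              # (s, s) first intermediate
    S = S - S.max(axis=-1, keepdims=True)   # subtract row max for stability
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)   # (s, s) second intermediate
    return P @ V                            # (s, d) output

s, d = 16, 8
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((s, d)).astype(np.float32) for _ in range(3))
O = standard_attention(Q, K, V)
print(O.shape)  # (16, 8)
```

At toy sizes the intermediates are harmless; at s = 8192 each one is the 134 MB tensor discussed below.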

The total HBM traffic for standard attention:

  • Read Q: s × d × 2 bytes
  • Read K: s × d × 2 bytes
  • Compute S = Q K^T: write s × s × 2 bytes to HBM
  • Read S to compute P: s × s × 2 bytes
  • Compute P = softmax(S): write s × s × 2 bytes to HBM
  • Read P and V to compute O: (s × s + s × d) × 2 bytes
  • Write O: s × d × 2 bytes

The dominant cost for large s is the O(s²) traffic for S and P. The compute is O(s² × d), which for a typical head dimension d gives an arithmetic intensity of:

AI ≈ (s² × d) / (s² + s × d) ≈ d   (for large s)

So the arithmetic intensity of attention is roughly d_head ≈ 64 to 128, well below the compute-bandwidth crossover (~330 FLOPs/byte). Attention is memory-bound, just like decode in Chapter 21. The bottleneck is HBM bandwidth.

This is the hidden inefficiency. Even though the FLOPs of attention scale as s² × d, the wall-clock is dominated by the HBM traffic for the (s × s) intermediates, not by the compute.
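Plugging in numbers makes the estimate concrete. The constant factors below are rough (each multiply-add counted as 2 FLOPs, S and P each written and read once), but the conclusion — AI ≈ d — is insensitive to them:

```python
def attention_arithmetic_intensity(s: int, d: int) -> float:
    """Rough FLOPs per element of memory traffic for standard attention.
    Compute: two matmuls, Q K^T and P V, each ~2*s*s*d FLOPs.
    Traffic (elements): S and P written + read, plus Q, K, V, O once each."""
    flops = 2 * 2 * s * s * d
    elements_moved = 4 * s * s + 4 * s * d
    return flops / elements_moved

for s in (1024, 8192, 32768):
    # Tends toward d as s grows, regardless of the exact constants.
    print(s, round(attention_arithmetic_intensity(s, d=128), 1))
```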

25.3 The standard attention’s memory cost

To make this concrete: for s = 8192, d = 128, H = 32 heads, the per-head intermediate attention matrices are:

  • S: 8192 × 8192 × 2 bytes = 134 MB
  • P: 8192 × 8192 × 2 bytes = 134 MB
  • Total per head: 268 MB
  • Total per layer (32 heads): ~8.5 GB of HBM traffic just for S and P.

For 80 layers in Llama 70B, that’s 680 GB of HBM traffic per forward pass just for the attention intermediates. At 3 TB/s of HBM bandwidth, that’s:

680 / 3000 ≈ 0.23 seconds

…of pure HBM time spent on S and P per forward pass for an 8k-token prompt. The matmul compute itself is much faster than this. Attention is memory-bound, and the (s × s) intermediates are the cost.
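These numbers are easy to reproduce. A small script, using the same accounting as above (bf16, S and P counted once each per head):

```python
def standard_attn_intermediate_bytes(s: int, dtype_bytes: int = 2) -> int:
    """Bytes of HBM traffic for the S and P intermediates, per head."""
    return 2 * s * s * dtype_bytes  # S + P, one (s, s) tensor each

s, heads, layers = 8192, 32, 80
per_head = standard_attn_intermediate_bytes(s)
per_layer = per_head * heads
total = per_layer * layers
print(f"per head:  {per_head / 1e6:.0f} MB")      # ~268 MB
print(f"per layer: {per_layer / 1e9:.1f} GB")     # ~8.6 GB
print(f"model:     {total / 1e9:.0f} GB")
print(f"HBM time @ 3 TB/s: {total / 3e12:.2f} s") # ~0.23 s
```

The script gives ~687 GB for the full model; the chapter's ~680 GB figure comes from rounding the per-layer traffic down to 8.5 GB first.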

For shorter sequences, the cost is smaller, but it grows quadratically. At s = 32k, the intermediates are 16× larger per head, and attention can take seconds per forward pass.

This is the problem FlashAttention solves.

25.4 The tiling idea

The FlashAttention insight: you don’t have to materialize the (s × s) attention matrix in HBM. You can compute attention by processing the sequence in tiles that fit in SRAM, with the intermediate S and P living entirely in SRAM and never being written to HBM.

The structure:

  1. Split Q into row tiles (chunks of query positions). Each tile is small enough to fit in SRAM.
  2. Split K and V into column tiles (chunks of key/value positions). Each tile fits in SRAM.
  3. For each Q tile:
    • Load it into SRAM.
    • For each K, V tile:
      • Load K and V tiles into SRAM.
      • Compute the partial attention scores Q_tile @ K_tile^T in SRAM.
      • Apply the (incremental) softmax — see §25.5.
      • Compute the partial output softmax(scores) @ V_tile in SRAM.
      • Accumulate into the running output for this Q tile.
    • Write the final output for this Q tile to HBM.

The total HBM traffic is:

  • Read Q once: s × d × 2 bytes
  • Read K once: s × d × 2 bytes
  • Read V once: s × d × 2 bytes
  • Write O once: s × d × 2 bytes
[Figure: the FlashAttention tile loop. HBM holds the Q, K, V, O tensors, each read or written once. Inside SRAM: Q_tile @ K_tile^T → S_tile, online_softmax(S_tile) → P_tile, P_tile @ V_tile → O_partial. S and P are never written to HBM; standard attention's S → HBM → P → HBM → O round-trip (O(s²) traffic) is eliminated.]
FlashAttention tiles Q into row-chunks and K/V into column-chunks, keeping all intermediate attention scores in SRAM — the entire O(s²) HBM write/read round-trip from standard attention is eliminated.

That’s it. No O(s²) HBM traffic. The intermediates live entirely in SRAM and are computed and discarded as the kernel processes tiles.

For s = 8192, the HBM traffic is now 4 × 8192 × 128 × 2 ≈ 8.4 MB per head, instead of 268 MB. A 30× reduction in HBM traffic for the attention computation.

The compute is the same (you’re still doing the same number of FLOPs), but because attention was memory-bound, the wall-clock improvement is dramatic.
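How big can a tile be? A back-of-envelope check, assuming ~228 KB of shared memory per SM (H100), fp16, d = 128, and that a Q tile, K tile, V tile, and output accumulator (each B × d) plus the B × B score tile must all be resident at once. The budget in real kernels differs (register usage, double buffering), so treat this as an upper bound:

```python
def max_tile_rows(sram_bytes: int = 228 * 1024, d: int = 128,
                  dtype_bytes: int = 2) -> int:
    """Largest tile size B such that four (B, d) tiles plus the
    (B, B) score tile fit in one SM's shared memory."""
    B = 1
    while True:
        need = (4 * B * d + B * B) * dtype_bytes
        if need > sram_bytes:
            return B - 1
        B += 1

print(max_tile_rows())  # on the order of 100-200 rows at d = 128
```

Real kernels typically pick fixed power-of-two tile sizes (e.g., 64 or 128 rows) comfortably below this bound.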

25.5 The online softmax trick

The hard part is the softmax. Standard softmax requires the maximum of all logits and the sum of exponentiated logits, both of which are global properties of the row. If you process the row in tiles, you don’t have access to the global max or sum until you’ve seen all the tiles.

The fix is the online softmax algorithm. As you process tiles, you maintain a running estimate of the row max and the row exponential sum, and you adjust the partial output accordingly when you see a new tile that changes the max.

Concretely, for one query row and K/V split into tiles 1, 2, ..., T:

Initialize:
    m = -inf       # running max
    l = 0          # running sum of exp(logits - m)
    o = 0          # running output (vector of size d)

For each K, V tile t:
    s_t = Q_row @ K_t^T    # logits for this tile
    m_new = max(m, max(s_t))    # updated row max
    
    # Rescale the previous accumulator to reflect the new max
    l_new = exp(m - m_new) * l + sum(exp(s_t - m_new))
    o = exp(m - m_new) * o + exp(s_t - m_new) @ V_t
    
    m = m_new
    l = l_new

Final output for this row:
    o = o / l    # divide by the final exponential sum

The key is the exp(m - m_new) rescaling factor. Whenever you encounter a new max, you rescale the previously accumulated values to be consistent with the new max. By the end, you have the same result as if you’d computed the softmax in one pass over the full row.
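The rescaling factor falls directly out of the definitions. In the notation of the pseudocode above, before tile t the accumulators were computed relative to the old max m:

```latex
l = \sum_{j\,\text{seen}} e^{s_j - m}, \qquad
o = \sum_{j\,\text{seen}} e^{s_j - m}\, v_j .
```

With the new max $m_{\text{new}} \ge m$, each old term must be re-expressed relative to it:

```latex
e^{s_j - m_{\text{new}}} = e^{s_j - m}\, e^{m - m_{\text{new}}},
\quad\text{so}\quad
l_{\text{new}} = e^{m - m_{\text{new}}}\, l \;+\; \sum_{j \in t} e^{s_j - m_{\text{new}}} .
```

Because the correction is the same scalar for every old term, one multiply fixes the whole accumulator.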

This is mathematically exact (modulo floating point) — no approximation. The online softmax produces bit-identical (well, fp-equivalent) results to the standard softmax. The savings are entirely in memory access, not in math.

The online softmax was actually known before FlashAttention; the Milakov & Gimelshein paper Online normalizer calculation for softmax (2018) introduced it. FlashAttention’s contribution was applying it to attention in a fused kernel that also handled the matmul tiling.
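The pseudocode above translates directly into a runnable (deliberately tiny) NumPy version, which can be checked against a one-pass softmax:

```python
import numpy as np

def online_softmax_attention_row(q, K, V, tile=4):
    """One query row; K/V processed in tiles, per the pseudocode above."""
    d = V.shape[1]
    m, l, o = -np.inf, 0.0, np.zeros(d)
    for start in range(0, K.shape[0], tile):
        s_t = q @ K[start:start + tile].T      # logits for this tile
        m_new = max(m, s_t.max())              # updated running max
        p_t = np.exp(s_t - m_new)
        l = np.exp(m - m_new) * l + p_t.sum()  # rescale old sum, add new
        o = np.exp(m - m_new) * o + p_t @ V[start:start + tile]
        m = m_new
    return o / l                               # normalize once at the end

rng = np.random.default_rng(0)
q = rng.standard_normal(8)
K = rng.standard_normal((16, 8))
V = rng.standard_normal((16, 8))

tiled = online_softmax_attention_row(q, K, V)
logits = q @ K.T
p = np.exp(logits - logits.max())
full = (p / p.sum()) @ V
print(np.allclose(tiled, full))  # True
```

This is Practice exercise 5 worked in miniature; the tile size doesn't affect the result, only the access pattern.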

25.6 FlashAttention v1, v2, v3

The FlashAttention line of work has had three major versions, each with significant improvements.

FlashAttention v1 (Dao et al., 2022)

The original. Introduced the tiling + online softmax approach. Achieved ~2× speedup over PyTorch’s standard attention on long sequences. Was immediately adopted by HuggingFace, vLLM, and basically every other library that needed fast attention.

The v1 implementation was a single CUDA kernel that handled the forward pass. The backward pass was also fused but had some inefficiency in how it tiled over the query dimension.

FlashAttention v2 (Dao, 2023)

Significant refinement. The main improvements:

  • Better parallelism. v1 parallelized over the batch and head dimensions, but kept each sequence’s processing serialized. v2 also parallelizes over the sequence length by splitting the query rows across thread blocks. This is critical for small batches (where there isn’t enough batch × head parallelism to fill the GPU).
  • Better work partitioning. v1 had each warp compute a full block of attention; v2 splits the work more cleverly so that warps within a block share data via SRAM, reducing redundant loads.
  • 2× faster than v1 on most workloads.

FlashAttention v2 is the most widely deployed version. It’s the default in vLLM, PyTorch’s scaled_dot_product_attention (when the inputs are large enough), and most other frameworks.

FlashAttention v3 (Shah et al., 2024)

Hopper-specific (H100/H200). Uses Hopper’s new hardware features:

  • Asynchronous matmul (WGMMA). Hopper introduced warpgroup matrix-multiply-accumulate, a new instruction that can do matmul asynchronously with the rest of the kernel. v3 uses this to overlap compute with data loading.
  • Asynchronous data movement (TMA). Hopper’s Tensor Memory Accelerator allows hardware-accelerated tile loading from HBM to SRAM. v3 uses TMA to free up the threads that would otherwise be loading data.
  • fp8 support. v3 adds first-class fp8 attention, taking advantage of Hopper’s fp8 Tensor Cores.

v3 is another 1.5–2× faster than v2 on H100. It’s the state of the art for Hopper-based serving as of late 2025. Not yet supported on older GPUs.

25.7 Why FlashAttention isn’t faster in FLOPs but is faster in wall-clock

A subtle point worth being explicit about: FlashAttention does not change the FLOP count of attention. It computes exactly the same operations — Q K^T / √d, softmax, multiply by V. The total number of multiplies and adds is unchanged.

[Figure: HBM traffic, standard attention vs FlashAttention (s = 8192, d = 128). Standard: Q + K, write and read S (134 MB), write and read P (134 MB), V + O ≈ 268 MB/head. Flash: Q + K + V + O only ≈ 8.4 MB/head. A ~30× reduction in HBM traffic: same FLOPs, faster wall-clock.]
FlashAttention does not save any FLOPs but cuts HBM traffic by ~30× for long sequences — because attention was memory-bound, the wall-clock time falls by nearly the same factor.

What changes is the memory access pattern. Standard attention reads and writes the (s × s) intermediates to HBM; FlashAttention keeps them in SRAM. The HBM traffic difference is the entire wall-clock difference.

This is why FlashAttention is sometimes confusing: people see the speedup and think “it must be doing fewer operations,” but it’s not. It’s doing the same operations with better locality. The lesson is the lesson of Chapter 21: HBM bandwidth is the bottleneck, not compute.

This also tells you why FlashAttention can’t speed up tasks that are already compute-bound. If you’re running attention on a very short sequence (say, s = 64), the matrix is tiny and HBM traffic is already negligible. FlashAttention gives no speedup for short sequences; the benefit grows with s.

For decode attention (which has s_query = 1 and s_key = S_so_far), FlashAttention is essential because you’re scanning the entire KV cache for each query. The decode-attention variant is sometimes called FlashAttention-Decoding or FlashDecoding and is even more memory-aware than the prefill version.
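The same running (m, l, o) statistics from §25.5 make it possible to split the KV cache across parallel workers and merge the partial results afterwards, which is roughly the idea behind FlashDecoding. An illustrative NumPy sketch (the function names are hypothetical, not the real kernel API):

```python
import numpy as np

def partial_attn(q, K, V):
    """Attention over one KV split; returns (max, exp-sum, unnormalized out)."""
    s = q @ K.T
    m = s.max()
    p = np.exp(s - m)
    return m, p.sum(), p @ V

def merge(parts):
    """Combine per-split statistics into the exact full-softmax output."""
    m = max(pm for pm, _, _ in parts)
    l = sum(np.exp(pm - m) * pl for pm, pl, _ in parts)
    o = sum(np.exp(pm - m) * po for pm, _, po in parts)
    return o / l

rng = np.random.default_rng(1)
q = rng.standard_normal(8)
K = rng.standard_normal((32, 8))
V = rng.standard_normal((32, 8))

# Split the 32 cached positions into 4 independent chunks, then merge.
parts = [partial_attn(q, K[i:i + 8], V[i:i + 8]) for i in range(0, 32, 8)]
out = merge(parts)

logits = q @ K.T
p = np.exp(logits - logits.max())
print(np.allclose(out, (p / p.sum()) @ V))  # True
```

The merge is exact for the same reason the online softmax is: every partial term is rescaled to a common max before summing.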

25.8 FlashAttention with PagedAttention

The natural question: how does FlashAttention compose with PagedAttention (Chapter 24)? PagedAttention stores K and V in scattered physical blocks; FlashAttention assumes contiguous K and V tensors. How do they coexist?

The answer is that the FlashAttention kernel is rewritten to read from blocked KV cache. The kernel takes two extra inputs:

  • The block table (per-sequence list of physical block IDs).
  • The block size (number of tokens per block).

Inside the kernel, when it needs to read the K and V for a particular range of token positions, it uses the block table to find the right physical blocks, loads them into SRAM, and proceeds with the same tiled-softmax computation as the standard FlashAttention.

The performance hit from the block-table indirection is small (a few percent) because the block table itself is tiny and stays in cache.
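To see what the indirection involves, here is an illustrative gather of one K or V tile through a block table (hypothetical code, not vLLM's actual kernel, which does this per-tile inside the fused attention loop):

```python
import numpy as np

def gather_kv_tile(kv_pool, block_table, start, length, block_size=16):
    """Gather a contiguous range of token positions from a paged KV pool.
    kv_pool: (num_blocks, block_size, d) physical storage
    block_table: logical block index -> physical block id."""
    rows = []
    for pos in range(start, start + length):
        phys = block_table[pos // block_size]   # block-table indirection
        rows.append(kv_pool[phys, pos % block_size])
    return np.stack(rows)

d, block_size = 8, 16
kv_pool = np.arange(4 * block_size * d, dtype=np.float32).reshape(4, block_size, d)
block_table = [2, 0, 3]   # this sequence's logical blocks live scattered
tile = gather_kv_tile(kv_pool, block_table, start=12, length=8,
                      block_size=block_size)
print(tile.shape)  # (8, 8): spans the boundary between physical blocks 2 and 0
```

The lookup table is a handful of integers per sequence, which is why the indirection costs so little relative to the tile loads themselves.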

In vLLM, this is the kernel called flash_attn_with_paged_kv_cache. It’s the workhorse of vLLM’s decode attention. SGLang has a similar implementation. TensorRT-LLM has its own kernel that does the same thing.

The point: PagedAttention and FlashAttention are complementary, and modern serving stacks combine them. The FlashAttention paper’s tiling math + the PagedAttention paper’s block tables = the kernel that runs in production.

25.9 The state of attention kernels in 2025

The attention kernel landscape has matured significantly. As of late 2025, the main implementations:

FlashAttention 3 (flash-attn library). The reference implementation by Tri Dao. H100/H200 only. Used by vLLM, PyTorch SDPA, and many others. Open source, BSD license.

xFormers memory_efficient_attention. Meta’s implementation. Has some features FlashAttention doesn’t (e.g., custom attention biases). Often used as a fallback when FlashAttention doesn’t support a specific configuration.

PyTorch scaled_dot_product_attention. PyTorch’s official wrapper. Selects between FlashAttention, xFormers, and a fallback implementation based on the inputs. The recommended default for new code.

CUTLASS-based custom kernels. Some labs (especially TensorRT-LLM, SGLang) write their own attention kernels using NVIDIA’s CUTLASS template library. They achieve performance comparable to FlashAttention with more flexibility (e.g., custom data types, custom masking).

llm.c / native CUDA implementations. Karpathy’s llm.c and similar projects implement attention from scratch in raw CUDA for educational purposes. Slower than FlashAttention but readable.

Triton-based kernels. Triton (the Python DSL for GPU kernels) has reference implementations of FlashAttention that are easier to read and modify than the C++/CUDA original. The flash-attn library actually has a Triton implementation alongside the CUDA one.

For production: just use the FlashAttention 3 library or PyTorch’s SDPA. They’re highly optimized and free. Don’t write your own attention kernel unless you have a specific reason (custom data types, custom masking, exotic hardware).

25.10 The mental model

Eight points to take into Chapter 26:

  1. Attention is memory-bound. Its arithmetic intensity is ~d_head, far below the compute-bandwidth crossover.
  2. The (s × s) intermediate is the cost. It dominates HBM traffic in standard attention.
  3. FlashAttention tiles the computation so the intermediate lives in SRAM and never hits HBM.
  4. The online softmax allows tiled processing of the rows by maintaining running max and sum, with rescaling as new tiles arrive.
  5. FlashAttention does not reduce FLOPs, only HBM traffic. The wall-clock speedup is purely from memory access.
  6. v1 → v2 → v3 each added another 1.5–2×. v3 is Hopper-only and uses async matmul and TMA.
  7. PagedAttention + FlashAttention = the production attention kernel. The kernel reads from blocked KV via the block table.
  8. Use flash-attn v3 or PyTorch SDPA. Don’t write your own.

In Chapter 26 we look at the other major memory optimization for serving: quantization.


Read it yourself

  • Dao et al., FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (2022). The original paper.
  • Dao, FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (2023).
  • Shah et al., FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (2024).
  • Milakov & Gimelshein, Online normalizer calculation for softmax (2018). The online softmax paper.
  • The flash-attn GitHub repository. Read the Python wrapper and the CUDA kernel header for the v3 implementation.
  • Horace He’s blog post on FlashAttention internals.

Practice

  1. Compute the HBM traffic for standard attention vs FlashAttention on a sequence of length 16384 with d_head = 128. What’s the ratio?
  2. Why does the online softmax need a rescaling factor exp(m - m_new)? Derive it from the standard softmax formula.
  3. The FlashAttention kernel keeps tiles in SRAM. Look up the SRAM size of an H100 SM (~228 KB). How big can a tile be in fp16 before it doesn’t fit?
  4. Why doesn’t FlashAttention help when s is small (say, 64)? Compute the HBM traffic for standard attention at s=64 and check if it dominates anything.
  5. Implement a tiny “online softmax” in pure Python that processes a vector in chunks. Verify it gives the same result as torch.softmax on the full vector.
  6. Why does FlashAttention v3 need Hopper hardware? What specific instructions does it use that aren’t available on Ampere?
  7. Stretch: Read the FlashAttention CUDA kernel in flash-attn/csrc/flash_attn/. Identify the tile loop, the online softmax update, and the output accumulation.

Concept check

  1. Standard attention materializes the full (s × s) attention score matrix as an intermediate tensor. Why is this the dominant cost for long sequences?
  2. FlashAttention computes the same mathematical result as standard attention but without materializing the full attention matrix. What is the key algorithmic insight that makes this possible?
  3. FlashAttention achieves wall-clock speedup without reducing FLOPs. Why does reducing HBM reads and writes translate to faster wall-clock time on GPUs?
  4. FlashAttention-2 introduced work partitioning improvements. Why does distributing work across warps within a thread block improve performance specifically for the attention computation?