Part III · Inference Internals & Production Serving
Chapter 22 · Core

The KV cache: the single most important optimization in LLM inference

"The KV cache is the data structure. PagedAttention is how you store it. FlashAttention is how you compute it. Prefix caching is how you share it. Disaggregation is how you transfer it. Everything in modern LLM serving is about the KV cache"

In Chapter 21 we established that decode is memory-bandwidth-bound and that the only way to make it efficient is to avoid recomputation. The technique that does this is the KV cache, and it is the most important optimization in LLM inference. Period. Every later chapter in Part III is either an optimization of the KV cache, a way to share the KV cache, a way to compute attention against the KV cache faster, or a way to reduce its memory footprint.

This chapter goes deep on the KV cache: what it stores, how big it gets, why it dominates GPU memory at scale, and what its existence implies for everything else.

Outline:

  1. The recomputation problem.
  2. What the KV cache stores.
  3. The KV cache size formula — memorize this.
  4. Why the cache dominates memory at scale.
  5. The decode step with a KV cache.
  6. The prefill step and how it populates the cache.
  7. The trade-off: memory vs latency.
  8. Per-request KV cache vs shared KV cache.
  9. KV cache management strategies.
  10. Forward pointers.

22.1 The recomputation problem

Recall the naive decoding loop:

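import torch

tokens = prompt_ids                    # 1-D LongTensor of prompt token ids
for step in range(max_new_tokens):
    logits = model(tokens)             # forward over the WHOLE sequence: (len(tokens), vocab)
    next_token = sample(logits[-1])    # sample() assumed to return a one-element tensor
    tokens = torch.cat([tokens, next_token.view(1)])   # sequence grows by one every step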

At step t (counting from 0), the model runs a forward pass over the entire sequence of S_prompt + t tokens, even though S_prompt + t - 1 of those tokens were already processed in the previous step. The model is doing the same computation repeatedly, just with one extra token at the end each time.

The wasted work is large. For a generation of T output tokens after a prompt of length S, the naive approach does:

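Total tokens processed = S + (S+1) + (S+2) + ... + (S+T-1)
                       = S × T + T(T-1)/2

For S = 1000, T = 200, that’s 219,900 token-passes (about 220,000) to generate 200 tokens. With a KV cache, the same generation processes only S + T - 1 = 1,199 token positions: the prompt once, then one new token per decode step. The naive loop does roughly 180× more work than necessary.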
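The arithmetic can be checked in a few lines of Python. This sketch counts token positions pushed through the model (not FLOPs), assuming the cached version prefills the prompt once and then processes one new token per step; the function names are illustrative:

```python
def naive_token_passes(S, T):
    # Step t re-runs the whole sequence of S + t tokens, for t = 0 .. T-1
    return sum(S + t for t in range(T))

def cached_token_passes(S, T):
    # Prefill processes the S prompt tokens once; each of the remaining
    # T - 1 decode steps processes only the single newest token
    return S + (T - 1)

S, T = 1000, 200
naive = naive_token_passes(S, T)
cached = cached_token_passes(S, T)
assert naive == S * T + T * (T - 1) // 2      # closed form from the text
print(naive, cached, round(naive / cached))   # 219900 1199 183
```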

The waste is even worse than the token count suggests because attention scales as O(S²). The attention-score work in a 1200-token forward pass is about 4× that of a 600-token pass (the linear layers only double), so re-running ever-longer sequences compounds the cost as the context grows.
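A rough FLOP model shows how the two terms scale. This sketch assumes about 2 FLOPs per parameter per token for the linear layers, plus 4 × n_layers × S² × d_model FLOPs for the two attention matmuls (QKᵀ and attn @ V, with no causal-mask savings), using Llama-3-70B-like sizes (70e9 parameters, 80 layers, d_model = 8192). It is a back-of-envelope model, not a profiler measurement:

```python
def forward_flops(S, n_params=70e9, n_layers=80, d_model=8192):
    # Linear layers (QKV/O projections, FFN): ~2 FLOPs per parameter per token
    linear = 2 * n_params * S
    # Attention matmuls QK^T and attn @ V: 2 * S^2 * d_model FLOPs each, per layer
    attention = 4 * n_layers * S * S * d_model
    return linear, attention

lin_600, att_600 = forward_flops(600)
lin_1200, att_1200 = forward_flops(1200)
print(att_1200 / att_600)                              # 4.0: the quadratic term
print((lin_1200 + att_1200) / (lin_600 + att_600))     # about 2.02 for the whole pass
```

Under this model the attention term only matches the linear term at S = 2 × n_params / (4 × n_layers × d_model), roughly 53,000 tokens, which is why the quadratic cost dominates mainly at long contexts.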

The fix is the KV cache: store the K and V vectors from past tokens and never recompute them.

[Figure: decode step t=3 without a KV cache (all four tokens t₀ to t₃ forwarded again, redundant recomputation) vs with a KV cache (K and V for t₀ to t₂ read from the cache; only t₃ forwarded).]
Without a KV cache, each new token forces a full-sequence forward pass; with it, only the single new token runs through the projections and FFN, and only the attention read over the cached K and V grows with sequence length.

22.2 What the KV cache stores

For each transformer block in the model, attention computes Q, K, V projections of the input. Recall the shapes (Chapter 6):

  • Q: (N, H, S, d_h)
  • K: (N, H, S, d_h)
  • V: (N, H, S, d_h)

The output of attention depends on:

  • The current Q (which is computed fresh from the current token’s input).
  • The current and past K and V (which encode “what the past tokens are saying”).

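Crucial observation: in a causal model, the K and V vectors for past tokens never change. At the first layer, token 5’s K and V are determined by token 5’s own input, which is fixed. At every deeper layer, token 5’s input depends only on tokens 0 through 5, because the causal mask stops it from attending to later positions. So adding token 6 to the sequence doesn’t modify anything computed for token 5.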

So we can:

  1. Compute K and V once, when the token is first seen.
  2. Save them in a cache, indexed by (layer, head, position).
  3. At every subsequent step, only compute K and V for the new token, append to the cache, and use the full cached K and V to compute attention.

The cache is the K and V tensors from every token that’s been seen so far. The Q vectors are not cached because they’re only needed at the position they’re computed for — once you’ve used a Q to compute its attention output, you don’t need it again.
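A tiny pure-Python sketch makes the "compute once, append, reuse" recipe concrete. It uses a single layer and a single head with random weights and no positional encoding (all simplifications for illustration), and computes every token’s attention output twice: once by recomputing K and V for the whole prefix at each step, once by appending to a cache. Q is used once per step and never stored.

```python
import math
import random

random.seed(0)
d = 4  # head dimension for this toy

def rand_matrix(n):
    return [[random.gauss(0, 1) for _ in range(n)] for _ in range(n)]

W_q, W_k, W_v = rand_matrix(d), rand_matrix(d), rand_matrix(d)

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attend(q, keys, values):
    # softmax(q . k_i / sqrt(d)), then the weighted sum of the v_i
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(d)]

xs = [[random.gauss(0, 1) for _ in range(d)] for _ in range(6)]  # six token inputs

# Without a cache: at every step, recompute K and V for every token seen so far
no_cache_outputs = []
for t in range(len(xs)):
    keys = [matvec(W_k, x) for x in xs[: t + 1]]
    values = [matvec(W_v, x) for x in xs[: t + 1]]
    no_cache_outputs.append(attend(matvec(W_q, xs[t]), keys, values))

# With a cache: compute K and V once per token and append; Q is used once, never stored
k_cache, v_cache, cached_outputs = [], [], []
for x in xs:
    k_cache.append(matvec(W_k, x))
    v_cache.append(matvec(W_v, x))
    cached_outputs.append(attend(matvec(W_q, x), k_cache, v_cache))

max_diff = max(abs(a - b)
               for o1, o2 in zip(no_cache_outputs, cached_outputs)
               for a, b in zip(o1, o2))
print(max_diff)  # 0.0
```

The two loops produce identical outputs, but the cached loop computes each K and V exactly once instead of once per later step.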

[Figure: KV cache for one transformer layer. The K cache and V cache (each n_kv_heads × S × d_h) gain one row, kₜ and vₜ, per decode step; qₜ is computed fresh, used once, then discarded. scores = qₜ @ K_fullᵀ → softmax → output = attn @ V_full.]
K and V are stored permanently and grow by one row per decode step; Q is ephemeral — it is computed, used to attend over the full K and V cache, then discarded.

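The KV cache replaces O(S²)-per-step recomputation with O(S) per-step work: the projections and FFN run on a single token, and only the attention read over the cached K and V grows with S. This is the optimization that makes long-context inference practically tractable.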

22.3 The KV cache size formula

This is the formula you must memorize for any LLM systems interview:

KV cache size per token = 2 × n_layers × n_kv_heads × d_h × bytes_per_element

The factor of 2 is because we store both K and V. The other factors are:

  • n_layers: the number of transformer blocks.
  • n_kv_heads: the number of distinct K/V heads (with GQA, this can be smaller than the number of attention heads — see Chapter 33).
  • d_h: the per-head dimension (= d_model / n_heads).
  • bytes_per_element: 2 for bf16/fp16, 1 for fp8/int8, 0.5 for int4.

This is per token. Multiply by sequence length to get total cache size for a sequence:

KV cache size per sequence = 2 × n_layers × n_kv_heads × d_h × S × bytes_per_element

Now plug in real numbers for Llama 3 70B, which has n_layers = 80, n_kv_heads = 8 (GQA with 8 KV heads, 64 query heads), d_h = 128, in bf16:

Per token: 2 × 80 × 8 × 128 × 2 = 327,680 bytes ≈ 320 KB

So one token of context for Llama 3 70B occupies 320 KB of GPU memory in the KV cache. For a sequence of S = 4096:

Per sequence: 320 KB × 4096 ≈ 1.3 GB

For a long-context request of S = 32,768:

Per sequence: 320 KB × 32,768 ≈ 10.7 GB (exactly 10 GiB)

For 100 concurrent sequences at that length: about 1.07 TB. That doesn’t fit on any single GPU, or even on an 8-GPU H100 node (640 GB of HBM). At this scale the KV cache, not the model weights, is the largest consumer of serving memory.
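The formula and the Llama 3 70B numbers above transcribe directly into Python. A minimal sketch; the function name is illustrative:

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_h, bytes_per_element, seq_len=1):
    # 2 = one K vector and one V vector per layer, per KV head, per token
    return 2 * n_layers * n_kv_heads * d_h * bytes_per_element * seq_len

# Llama 3 70B in bf16: 80 layers, 8 KV heads (GQA), d_h = 128, 2 bytes per element
per_token = kv_cache_bytes(80, 8, 128, 2)
per_4k = kv_cache_bytes(80, 8, 128, 2, seq_len=4096)
per_32k = kv_cache_bytes(80, 8, 128, 2, seq_len=32_768)

print(per_token)              # 327680 bytes (320 KiB)
print(per_4k / 2**30)         # 1.25 GiB
print(per_32k / 2**30)        # 10.0 GiB
print(100 * per_32k / 1e12)   # about 1.07 TB for 100 such sequences
```

Dividing by 2**30 gives binary gigabytes (GiB) and dividing by 1e12 gives decimal terabytes; the "KB" figures in this chapter are binary.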

[Figure: the per-token formula decomposed for Llama 3 70B: 2 (K and V) × 80 layers × 8 KV heads (GQA) × 128 (d_h) × 2 bytes (bf16) = 327,680 B ≈ 320 KB per token.]
GQA shrinks n_kv_heads from 64 (full MHA) to 8 for Llama 3 70B, making the per-token cache 8× smaller, which is the primary reason Llama 3 adopted GQA.

A few worked numbers for comparison:
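| Model | n_layers | n_kv_heads | d_h | dtype | Per-token size | At 2k context (2,048 tokens) |
|---|---|---|---|---|---|---|
| Llama 3 8B (GQA 8) | 32 | 8 | 128 | bf16 | 128 KB | 256 MB |
| Llama 3 70B (GQA 8) | 80 | 8 | 128 | bf16 | 320 KB | 640 MB |
| Llama 3 405B (GQA 8) | 126 | 8 | 128 | bf16 | 504 KB | 1.0 GB |
| Mistral 7B (GQA 8) | 32 | 8 | 128 | bf16 | 128 KB | 256 MB |
| GPT-3 175B (MHA, no GQA) | 96 | 96 | 128 | bf16 | 4.5 MB (4,718,592 B) | 9.0 GB |

Sizes use binary units (1 KB = 1,024 bytes), matching the 320 KB figure above.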

The GPT-3 row is striking. With full multi-head attention (no GQA), the per-token KV cache is about 4.7 million bytes (4.5 MB in binary units), 14× larger than Llama 3 70B with GQA, even though GPT-3 has only 2.5× as many parameters. This is one of the reasons GQA matters so much: a model with full MHA has an unaffordable KV cache at scale, and the move to GQA (starting with Llama 2 70B) was driven primarily by KV cache size, not by training quality.
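The table rows, and the 14× gap, follow directly from the formula. A sketch that regenerates them from the configurations listed in the table (all bf16, binary units):

```python
def kv_bytes_per_token(n_layers, n_kv_heads, d_h, bytes_per_element=2):
    return 2 * n_layers * n_kv_heads * d_h * bytes_per_element

# (n_layers, n_kv_heads, d_h) for each row of the table
configs = {
    "Llama 3 8B": (32, 8, 128),
    "Llama 3 70B": (80, 8, 128),
    "Llama 3 405B": (126, 8, 128),
    "Mistral 7B": (32, 8, 128),
    "GPT-3 175B": (96, 96, 128),
}
for name, (layers, kv_heads, d_h) in configs.items():
    per_token = kv_bytes_per_token(layers, kv_heads, d_h)
    at_2k = per_token * 2048
    print(f"{name:13s} {per_token / 1024:6.0f} KiB/token  {at_2k / 2**20:6.0f} MiB at 2,048 tokens")

ratio = kv_bytes_per_token(96, 96, 128) / kv_bytes_per_token(80, 8, 128)
print(f"GPT-3 / Llama 3 70B: {ratio:.1f}x")   # 14.4x
```

The gap decomposes as (96 / 80 layers) × (96 / 8 KV heads) = 1.2 × 12 = 14.4; almost all of it comes from the number of KV heads.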

We’ll cover GQA in detail in Chapter 33.

22.4 Why the cache dominates memory at scale

To put the numbers in serving terms, suppose you have an H100 (80 GB) and you want to serve Llama 3 70B (140 GB in bf16). The model itself doesn’t fit on one GPU; you need at least two H100s with tensor parallelism (TP=2). Each GPU holds half the model: 70 GB.

That leaves 10 GB per GPU for everything else: KV cache, activations, framework overhead. Out of those 10 GB:

  • ~2 GB for activations and framework overhead
  • ~8 GB for the KV cache

Each token of context takes 320 KB of KV cache. Tensor parallelism splits the 8 KV heads across the 2 GPUs (4 heads each), so each GPU stores 160 KB per token. Per GPU, you can hold:

8 GB / 160 KB ≈ 50,000 tokens of total context

That budget is shared across all concurrent users. With 10 users each holding a 5000-token context, the cache is full; adding an 11th user requires evicting an existing request or rejecting the new one.
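The capacity arithmetic fits in a small helper. This sketch assumes tensor parallelism shards the KV heads evenly (so tp_degree should divide n_kv_heads) and uses binary gigabytes; the function and parameter names are illustrative:

```python
def max_context_tokens(kv_budget_bytes, kv_bytes_per_token, tp_degree):
    # Tensor parallelism shards the KV heads, so each GPU stores 1/tp of every token's cache
    per_gpu_bytes_per_token = kv_bytes_per_token / tp_degree
    return int(kv_budget_bytes // per_gpu_bytes_per_token)

GiB = 2**30
tokens = max_context_tokens(kv_budget_bytes=8 * GiB, kv_bytes_per_token=327_680, tp_degree=2)
users = tokens // 5_000
print(tokens, users)   # 52428 10
```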

This is why the KV cache is the dominant constraint on serving capacity for any non-trivial LLM. It’s not the model weights (those are static and shared across all requests). It’s not the activations (those are per-token and small). It’s the KV cache, which scales as users × context_length.

The implications:

  • Concurrency is bounded by KV cache memory. You can’t have more concurrent users than fit in the KV budget.
  • Long contexts are expensive. A 32k-token request takes 16× the KV cache of a 2k-token request.
  • KV cache management is its own subfield. Chapter 24 (PagedAttention) is fundamentally about how to manage KV cache memory efficiently.

22.5 The decode step with a KV cache

Concretely, what does a decode step look like with the KV cache? Pseudocode:

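import math
import torch

def decode_step(model, kv_cache, last_token):
    # last_token is just the most recently sampled token, shape (N, 1)
    N = last_token.shape[0]
    # Head counts and head dim; these config attribute names are illustrative.
    # H_kv == H for plain MHA; H_kv < H for GQA (e.g. 8 vs 64 for Llama 3 70B).
    H, H_kv, D_h = model.n_heads, model.n_kv_heads, model.head_dim
    D = H * D_h

    # Embed the new token
    x = model.embed(last_token)        # (N, 1, D)

    for layer_idx, block in enumerate(model.blocks):
        # RMSNorm (rmsnorm is an assumed helper)
        x_norm = rmsnorm(x, block.norm1)

        # Compute Q, K, V for the NEW token only
        q_new = block.q_proj(x_norm)    # (N, 1, H * D_h)
        k_new = block.k_proj(x_norm)    # (N, 1, H_kv * D_h)
        v_new = block.v_proj(x_norm)    # (N, 1, H_kv * D_h)

        # Reshape for multi-head
        q_new = q_new.view(N, 1, H, D_h).transpose(1, 2)      # (N, H, 1, D_h)
        k_new = k_new.view(N, 1, H_kv, D_h).transpose(1, 2)   # (N, H_kv, 1, D_h)
        v_new = v_new.view(N, 1, H_kv, D_h).transpose(1, 2)   # (N, H_kv, 1, D_h)
        # (RoPE omitted: Llama-style models rotate q_new and k_new by the new position here.)

        # APPEND new K, V to the cache; only H_kv heads are stored
        k_full = torch.cat([kv_cache[layer_idx]['k'], k_new], dim=2)  # (N, H_kv, S+1, D_h)
        v_full = torch.cat([kv_cache[layer_idx]['v'], v_new], dim=2)  # (N, H_kv, S+1, D_h)
        kv_cache[layer_idx]['k'] = k_full
        kv_cache[layer_idx]['v'] = v_full

        # GQA: each cached KV head serves H // H_kv query heads (a no-op for MHA)
        k_att = k_full.repeat_interleave(H // H_kv, dim=1)            # (N, H, S+1, D_h)
        v_att = v_full.repeat_interleave(H // H_kv, dim=1)            # (N, H, S+1, D_h)

        # New Q attends to ALL past K, V; no mask needed, every cached position is in the past
        scores = q_new @ k_att.transpose(-2, -1) / math.sqrt(D_h)     # (N, H, 1, S+1)
        attn = scores.softmax(dim=-1)
        attn_out = attn @ v_att                                       # (N, H, 1, D_h)

        # Merge heads, project, add residual
        attn_out = attn_out.transpose(1, 2).reshape(N, 1, D)
        x = x + block.o_proj(attn_out)

        # FFN with its own pre-norm and residual
        x = x + block.ffn(rmsnorm(x, block.norm2))

    # Final norm + LM head produce next-token logits
    logits = model.lm_head(rmsnorm(x, model.norm_final))   # (N, 1, vocab)
    return logits[:, -1, :], kv_cache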

The key things to notice:

  1. The input is just the last sampled token (shape (N, 1)), not the whole sequence.
  2. The new K and V are computed and appended to the cache.
  3. The new Q attends to the full K and V from the cache, not just the new K and V.
  4. Every matmul has M = 1 along the sequence dimension. Only the two attention matmuls touch the cached length: scores = q @ Kᵀ has shape (N, H, 1, S+1), and attn @ V then reduces over those S+1 positions.

This is, in essence, the computation vLLM, SGLang, TensorRT-LLM, and other production serving stacks execute, with fused attention kernels and paged cache storage in place of torch.cat. The KV cache is the central data structure; the model code reads from it and writes to it on every step.
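The torch.cat in decode_step copies every cached row each step just to append one. A common alternative, sketched here with NumPy at illustrative sizes, is to preallocate a buffer for the maximum length and write each new K and V row in place. The class and method names are hypothetical, and production stacks go further with paged blocks (Chapter 24). The sketch also runs a decode-time attention read with GQA so the shapes listed above can be checked.

```python
import numpy as np

class PreallocatedKVCache:
    """Hypothetical per-request cache: one fixed buffer, filled left to right."""

    def __init__(self, n_layers, n_kv_heads, max_len, d_h, dtype=np.float32):
        # Axes: [layer, 0 = K / 1 = V, kv_head, position, d_h]
        self.buf = np.zeros((n_layers, 2, n_kv_heads, max_len, d_h), dtype=dtype)

    def write(self, layer, pos, k_new, v_new):
        # k_new, v_new: (n_kv_heads, d_h) for the one new token; old rows are not copied
        self.buf[layer, 0, :, pos] = k_new
        self.buf[layer, 1, :, pos] = v_new

    def view(self, layer, length):
        # Views (not copies) of the filled prefix, each (n_kv_heads, length, d_h)
        return self.buf[layer, 0, :, :length], self.buf[layer, 1, :, :length]


def decode_attention(q_new, k_full, v_full):
    # q_new: (H, d_h); k_full, v_full: (H_kv, S+1, d_h), with H_kv dividing H (GQA)
    H, d_h = q_new.shape
    group = H // k_full.shape[0]
    # Query head h uses KV head h // group (np.repeat copies; a real kernel would index)
    k = np.repeat(k_full, group, axis=0)                          # (H, S+1, d_h)
    v = np.repeat(v_full, group, axis=0)                          # (H, S+1, d_h)
    scores = np.einsum("hd,hsd->hs", q_new, k) / np.sqrt(d_h)     # (H, S+1)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)                # softmax over positions
    return np.einsum("hs,hsd->hd", weights, v)                    # (H, d_h)


rng = np.random.default_rng(0)
n_layers, H, H_kv, d_h, max_len = 2, 8, 2, 16, 32
cache = PreallocatedKVCache(n_layers, H_kv, max_len, d_h)

for pos in range(5):                                   # five decode steps
    for layer in range(n_layers):
        k_new = rng.standard_normal((H_kv, d_h)).astype(np.float32)
        v_new = rng.standard_normal((H_kv, d_h)).astype(np.float32)
        cache.write(layer, pos, k_new, v_new)
        k_full, v_full = cache.view(layer, pos + 1)
        q_new = rng.standard_normal((H, d_h)).astype(np.float32)
        out = decode_attention(q_new, k_full, v_full)

print(k_full.shape, out.shape)   # (2, 5, 16) (8, 16)
```

The trade-off is reserving max_len rows per request up front, which wastes memory on requests that finish early; that fragmentation is the problem PagedAttention addresses.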

22.6 The prefill step

Prefill is where the cache gets populated. The model runs a forward pass on the entire prompt at once and stores all the K and V tensors as it goes:

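import math
import torch

def prefill(model, kv_cache, prompt_tokens):
    # prompt_tokens: (N, S_prompt) token ids; same illustrative config names as decode_step
    N, S_prompt = prompt_tokens.shape
    H, H_kv, D_h = model.n_heads, model.n_kv_heads, model.head_dim
    D = H * D_h

    x = model.embed(prompt_tokens)        # (N, S_prompt, D)

    # Additive causal mask: -inf above the diagonal, so position i only sees positions <= i
    causal_mask = torch.full((S_prompt, S_prompt), float('-inf'),
                             device=x.device, dtype=x.dtype).triu(1)

    for layer_idx, block in enumerate(model.blocks):
        x_norm = rmsnorm(x, block.norm1)
        q = block.q_proj(x_norm)
        k = block.k_proj(x_norm)
        v = block.v_proj(x_norm)

        # Reshape for heads
        q = q.view(N, S_prompt, H, D_h).transpose(1, 2)       # (N, H, S_prompt, D_h)
        k = k.view(N, S_prompt, H_kv, D_h).transpose(1, 2)    # (N, H_kv, S_prompt, D_h)
        v = v.view(N, S_prompt, H_kv, D_h).transpose(1, 2)    # (N, H_kv, S_prompt, D_h)
        # (RoPE omitted, as in decode_step.)

        # SAVE the K and V for ALL prompt positions
        kv_cache[layer_idx] = {'k': k, 'v': v}

        # Standard causal attention (K, V expanded to H heads for GQA)
        k_att = k.repeat_interleave(H // H_kv, dim=1)
        v_att = v.repeat_interleave(H // H_kv, dim=1)
        scores = q @ k_att.transpose(-2, -1) / math.sqrt(D_h)  # (N, H, S_prompt, S_prompt)
        scores = scores + causal_mask                           # broadcasts over N and H
        attn = scores.softmax(dim=-1)
        attn_out = attn @ v_att                                 # (N, H, S_prompt, D_h)

        attn_out = attn_out.transpose(1, 2).reshape(N, S_prompt, D)
        x = x + block.o_proj(attn_out)
        x = x + block.ffn(rmsnorm(x, block.norm2))

    # Final logits; only the last position is needed to sample the first output token
    logits = model.lm_head(rmsnorm(x, model.norm_final))
    return logits[:, -1, :], kv_cache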

Prefill is a single forward pass. After it completes, the cache is populated for all prompt positions, and decode can begin.
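The prefill and decode_step pseudocode depend on a real model. A self-contained NumPy toy runs the same split end to end and checks it against the naive full-sequence loop from §22.1. Everything here is hypothetical and simplified: two layers, one head, random weights, learned positional embeddings instead of RoPE, no norms, and tanh standing in for the FFN.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB, D, N_LAYERS, MAX_POS = 50, 16, 2, 64
emb = rng.standard_normal((VOCAB, D)) * 0.5
pos_emb = rng.standard_normal((MAX_POS, D)) * 0.5           # learned positions instead of RoPE
layers = [{name: rng.standard_normal((D, D)) / np.sqrt(D) for name in ("q", "k", "v", "o", "ffn")}
          for _ in range(N_LAYERS)]
lm_head = rng.standard_normal((D, VOCAB)) / np.sqrt(D)

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def embed(tokens, start_pos):
    return emb[tokens] + pos_emb[start_pos:start_pos + len(tokens)]

def toy_prefill(tokens):
    # One causal pass over the whole prompt; saves K and V for every position in every layer
    S = len(tokens)
    x = embed(tokens, 0)                                    # (S, D)
    mask = np.triu(np.full((S, S), -np.inf), k=1)           # position i sees positions <= i
    cache = []
    for W in layers:
        q, k, v = x @ W["q"], x @ W["k"], x @ W["v"]
        cache.append({"k": k, "v": v})
        x = x + (softmax(q @ k.T / np.sqrt(D) + mask) @ v) @ W["o"]
        x = x + np.tanh(x @ W["ffn"])                       # stand-in for the FFN
    return x @ lm_head, cache                               # logits for every position

def toy_decode_step(token, pos, cache):
    # Only the newest token runs through the model; each layer appends one K and V row
    x = embed([token], pos)                                 # (1, D)
    for W, layer_cache in zip(layers, cache):
        q, k, v = x @ W["q"], x @ W["k"], x @ W["v"]
        layer_cache["k"] = np.concatenate([layer_cache["k"], k])
        layer_cache["v"] = np.concatenate([layer_cache["v"], v])
        # No mask needed: every cached position is in the past
        x = x + (softmax(q @ layer_cache["k"].T / np.sqrt(D)) @ layer_cache["v"]) @ W["o"]
        x = x + np.tanh(x @ W["ffn"])
    return x @ lm_head                                      # (1, VOCAB)

prompt = [3, 14, 15, 9, 26]

# Naive: re-run the full sequence every step (toy_prefill doubles as a plain forward pass)
naive = list(prompt)
for _ in range(8):
    logits, _ = toy_prefill(naive)
    naive.append(int(logits[-1].argmax()))

# Cached: prefill once, then feed one token per step
cached = list(prompt)
logits, cache = toy_prefill(cached)
cached.append(int(logits[-1].argmax()))
for _ in range(7):
    logits = toy_decode_step(cached[-1], len(cached) - 1, cache)
    cached.append(int(logits[-1].argmax()))

print(naive == cached)
```

Both loops choose the same greedy tokens, and the cached path’s final logits match a fresh full pass up to floating-point rounding. Because the toy has two layers, the check also exercises the §22.2 claim that past tokens’ K and V stay fixed at every depth, not just at the first layer.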

The split between prefill and decode is what allows the prefill/decode asymmetry from Chapter 21 to be exploited operationally. Prefill is one big forward pass; decode is many small ones. Different optimization techniques apply to each.

22.7 The trade-off: memory vs latency

The KV cache trades memory for latency. You spend a lot of GPU memory storing past K and V vectors, in exchange for not having to recompute them on every step.

The trade is enormously favorable. HBM is expensive, but recomputation is worse: over a long generation it costs hundreds to thousands of times more compute. Without the KV cache, no production LLM serving system could exist.

But the memory cost is real, and it shapes everything:

  • You can’t serve more concurrent users than fit in the KV budget.
  • You can’t serve longer contexts than fit in the KV budget per user.
  • You can’t serve a model bigger than your hardware can fit alongside the KV budget.

The next several chapters are about how to make the cache budget go further:

  • Chapter 24 (PagedAttention): store the cache in fixed-size blocks like virtual memory, enabling sharing and reducing fragmentation.
  • Chapter 26 (quantization): store the cache in INT8 or FP8 instead of bf16 to halve its size, or in a 4-bit format to quarter it.
  • Chapter 29 (prefix caching): when many users share a prompt prefix (system prompt, RAG context), share the KV cache for that prefix instead of duplicating it.
  • Chapter 33 (GQA, MLA): shrink the per-token cache, either by reducing n_kv_heads (GQA) or by compressing K and V into a small latent vector (MLA).
  • Chapter 37 (KV cache offload): store rarely-accessed cache in slower memory (CPU RAM, NVMe, remote nodes).

These are the techniques. The KV cache is the data they all operate on.

22.8 Per-request vs shared KV cache

Two different conceptual models for the cache:

[Figure: GPU memory budget for Llama 3 70B on TP=2 (2× H100 80 GB = 160 GB total). Model weights take 140 GB (87.5% of HBM, 70 GB per GPU, frozen and shared across requests); each GPU keeps ~8 GB for the KV cache and ~2 GB for activations. 8 GB ÷ 160 KB per token per GPU ≈ 50,000 tokens of concurrent context, or 10 users at 5,000 tokens each.]
On a two-GPU H100 deployment the model weights consume nearly all HBM, leaving only ~8 GB per GPU for the KV cache, so 10 simultaneous users with 5k-token contexts fill the serving capacity completely.

Per-request KV cache

Each in-flight request has its own KV cache. The cache lives for the duration of the request and is freed when the request completes (the model emits EOS or is canceled). This is the simplest model and what naive serving stacks do.

The downside: every request that shares a prompt prefix duplicates the KV cache for that prefix. If 100 users all start with the same 1000-token system prompt, that’s 100 independent copies of the prefix cache. On Llama 3 70B in bf16 each copy is about 0.33 GB (1000 × 320 KB), so over 30 GB of HBM holds the same prompt over and over.

Shared KV cache

The KV cache is treated as shared content-addressed storage. Each block of the cache is keyed by the full token prefix that produced it, not just the tokens inside the block, because a block’s K and V depend on every earlier token. When a new request arrives with a prefix that matches existing blocks, the cache is reused instead of recomputed.

This is prefix caching (Chapter 29). It’s the killer feature of modern LLM serving for any workload with shared prefixes — and almost every production workload has shared prefixes (system prompts, few-shot examples, RAG context).

The implementation requires:

  • A way to identify when prefixes match (hash each block’s tokens together with the prefix before them).
  • A way to store the cache in granular blocks (PagedAttention, Chapter 24).
  • A way to evict stale blocks when memory is full (LRU, LFU, etc.).

This is all coming. The point for this chapter is just: the KV cache can be shared, not just stored per-request, and modern serving stacks exploit this aggressively.
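The matching rule has one subtlety worth seeing in code: a block’s key must cover the whole prefix, not just the block’s own tokens. A minimal sketch, assuming SHA-256 over a chained parent hash and an illustrative block size of 16. vLLM’s automatic prefix caching documents a similar idea (hashing each block together with the prefix before it), but this is not its implementation, and the function names are hypothetical.

```python
import hashlib

BLOCK_SIZE = 16  # tokens per cache block (illustrative)

def block_hashes(token_ids, block_size=BLOCK_SIZE):
    """Chain-hash full blocks: each key covers the block AND everything before it."""
    hashes, parent = [], b""
    n_full = len(token_ids) // block_size          # this sketch never shares a partial last block
    for i in range(n_full):
        block = token_ids[i * block_size:(i + 1) * block_size]
        h = hashlib.sha256(parent + repr(block).encode()).digest()
        hashes.append(h)
        parent = h
    return hashes

def shared_prefix_blocks(a, b):
    # Number of leading blocks whose cached K and V two requests could share
    n = 0
    for ha, hb in zip(block_hashes(a), block_hashes(b)):
        if ha != hb:
            break
        n += 1
    return n

system_prompt = list(range(100, 140))              # 40 shared tokens: 2 full blocks plus 8
req_a = system_prompt + [1, 2, 3] * 10
req_b = system_prompt + [7, 8, 9] * 10
print(shared_prefix_blocks(req_a, req_b))          # 2
```

Chaining the parent hash is what makes identical 16-token blocks behind different prefixes get different keys, which is exactly the condition under which their K and V differ.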

22.9 KV cache management strategies

A serving stack has to decide what to do when the KV cache is full and a new request comes in. The options:

(1) Reject the new request. Send back a 503 or 429. Simple but bad for user experience.

(2) Queue the new request. Wait for an existing one to complete and free its cache. Good when the queue is short; bad when it grows.

(3) Evict an existing request mid-flight. Kill an in-progress generation and free its cache. Aggressive; the user gets an error.

(4) Recompute (drop and re-prefill). Free the KV cache for an in-progress request, but remember the request’s token sequence. When it gets re-scheduled, re-prefill from scratch. This is the standard vLLM behavior under heavy load.

(5) Offload to slower memory. Move the cache for some requests to CPU RAM or NVMe. The request can resume later by reading the cache back. This is what KV cache offloading (Chapter 37) does.

Real serving stacks combine these. vLLM, for example, uses:

  • Continuous batching to keep many requests in flight.
  • Drop-and-recompute for low-priority requests under memory pressure.
  • Optional CPU offloading for KV cache that doesn’t fit in HBM.

The right strategy depends on your latency vs throughput vs success-rate trade-offs. There is no universal answer.
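Options (2) and (4) combine naturally. The toy below is hypothetical and token-granular (real stacks allocate in blocks and add priorities, chunked prefill, and offload). It queues requests whose context doesn’t fit and, when a decode step runs out of room, drops the newest running request while remembering how many tokens it must re-prefill:

```python
from collections import deque

class ToyKVScheduler:
    """Hypothetical, token-granular KV admission and preemption policy (illustration only)."""

    def __init__(self, capacity_tokens):
        self.free = capacity_tokens
        self.running = {}        # request_id -> tokens held in the cache; dict keeps admission order
        self.waiting = deque()   # (request_id, tokens that must be (re-)prefilled)

    def submit(self, request_id, prompt_len):
        self.waiting.append((request_id, prompt_len))   # option (2): queue rather than reject
        self._admit()

    def _admit(self):
        # FIFO: prefill the head of the queue only if its whole context fits
        while self.waiting and self.waiting[0][1] <= self.free:
            request_id, n_tokens = self.waiting.popleft()
            self.running[request_id] = n_tokens
            self.free -= n_tokens

    def _preempt_newest(self):
        # Option (4): free the newest request's cache but keep its token count so it
        # can be re-prefilled later; it goes to the front of the queue
        request_id, n_tokens = self.running.popitem()   # popitem() removes the last-inserted key
        self.free += n_tokens
        self.waiting.appendleft((request_id, n_tokens))

    def decode_step(self):
        # Each running request appends one token's K and V, oldest request first
        for request_id in list(self.running):
            while self.free == 0 and request_id in self.running:
                self._preempt_newest()                  # may evict request_id itself if newest
            if request_id not in self.running:          # preempted during this step
                continue
            self.running[request_id] += 1
            self.free -= 1
        self._admit()


sched = ToyKVScheduler(capacity_tokens=100)
for rid in ("a", "b", "c"):
    sched.submit(rid, prompt_len=40)    # "c" doesn't fit and is queued
for _ in range(11):                     # 10 steps fill the cache; the 11th forces a preemption
    sched.decode_step()
print(sched.running, list(sched.waiting), sched.free)
# {'a': 51} [('b', 50), ('c', 40)] 49
```

Evicting the newest request first loses the least decode progress, but the re-prefill of b’s 50 tokens is pure overhead, and strict FIFO admission keeps c waiting behind b even though c alone would fit.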

22.10 Forward pointers

Every later chapter in Part III builds on the KV cache:

  • Chapter 23 (Continuous batching): how to schedule decode steps from many requests so they share the same forward pass and amortize the weight read.
  • Chapter 24 (PagedAttention): how to store the KV cache in a way that enables sharing, eviction, and batching across requests with different sequence lengths.
  • Chapter 25 (FlashAttention): how to compute attention against the KV cache without materializing the full attention matrix.
  • Chapter 26 (Quantization): how to quantize the KV cache (and the weights) to reduce memory pressure.
  • Chapter 29 (Prefix caching): how to share the KV cache across requests with the same prefix.
  • Chapter 33 (GQA, MLA): how to architect the model so the cache is smaller in the first place.
  • Chapter 36 (Disaggregation): how to physically move the KV cache between prefill GPUs and decode GPUs.
  • Chapter 37 (LMCache): how to share the KV cache across replicas via a fast external cache.

The KV cache is the foundation. Hold the formula in your head:

KV cache per token = 2 × n_layers × n_kv_heads × d_h × bytes_per_element

You will use it constantly. Every interview, every architecture decision, every capacity planning exercise traces back to this formula.

22.11 The mental model

Eight points to take into Chapter 23:

  1. Without a KV cache, autoregressive generation is O(S²) per step and hopelessly slow.
  2. With a KV cache, generation is O(S) per step — and the K, V for past tokens are reused, not recomputed.
  3. The KV cache stores K and V vectors for every token at every layer at every head. Q is not cached.
  4. The size formula is 2 × n_layers × n_kv_heads × d_h × bytes per token. Memorize it.
  5. The KV cache dominates serving memory for any non-trivial model. Concurrency is bounded by it.
  6. GQA and MLA matter because they shrink the per-token cache: GQA reduces n_kv_heads, and MLA compresses K and V into a small latent vector.
  7. The cache can be shared across requests when they have a common prefix — this is the foundation of prefix caching.
  8. Cache management strategies (queue, evict, drop-and-recompute, offload) are the levers serving stacks use under memory pressure.

In Chapter 23 we look at how multiple concurrent requests share a single forward pass: continuous batching.


Read it yourself

  • Pope et al., Efficiently Scaling Transformer Inference (2022). The foundational paper on inference optimization at scale, including the KV cache analysis.
  • The vLLM blog post on PagedAttention — gives concrete numbers on how the KV cache memory budget is the dominant constraint.
  • The Llama 3 paper, section on attention architecture — explains the GQA choice in terms of KV cache size.
  • Anthropic’s blog post on prompt caching — gives the production-side view of prefix sharing.
  • Horace He’s blog series on inference performance.

Practice

  1. Compute the per-token KV cache size for: Llama 3 8B, Llama 3 70B, Mistral 7B, Qwen 2.5 72B. All in bf16. Verify against the table in §22.3.
  2. A production serving fleet runs Llama 3 70B on H100 nodes with TP=2. Each node has 160 GB of GPU memory, and the model weights take 140 GB. After setting aside about 2 GB per GPU for activations and framework overhead, how many tokens of total KV cache fit in the remaining ~16 GB? At 2k tokens of context per user, how many concurrent users can the node serve?
  3. Why does GPT-3 (no GQA) have a 14× larger per-token KV cache than Llama 3 70B (GQA 8), even though it has only 2.5× as many parameters? Walk through the formula.
  4. A request has a 32k-token context. Compute the KV cache size for it on Llama 3 70B in bf16, then in int8. What’s the savings?
  5. Write the decode step pseudocode from §22.5 from memory. Verify the matmul shapes line up.
  6. Why isn’t the Q vector cached? Trace through one decode step and identify exactly when each Q is needed and when it can be discarded.
  7. Stretch: Implement a tiny KV-cached generation loop in PyTorch from scratch on a small open model (e.g., GPT-2 small). Compare wall-clock per token to a non-cached version. Verify the cached version is dramatically faster.

Concept check

  1. Without a KV cache, how does the compute cost scale with sequence length during autoregressive generation of T output tokens after a prompt of length S?
  2. What does the KV cache store, and why does it only store K and V and not Q?
  3. How do you calculate the KV cache memory per token for a transformer layer with H heads, D head dimension, and precision P bytes, summed over L layers?
  4. At a batch of 128 concurrent requests each generating up to 2048 tokens, a 70B model's KV cache exceeds GPU HBM capacity. Which architectural decision made this problem worse compared to an older 13B model?