Attention from first principles
"Attention is a soft, differentiable, learnable lookup table"
This is the chapter the rest of the book has been building toward. Attention is the operation that made transformers work, that made LLMs work, that made the modern wave of generative AI work. Every later chapter — KV cache, PagedAttention, FlashAttention, GQA, MLA, prefix caching, multimodal, MoE — is either an optimization of attention, a generalization of attention, or a workaround for a limitation of attention.
We will derive attention from scratch. By the end of this chapter you will be able to:
- Write the attention formula on a whiteboard from first principles, with no notes.
- Explain why every term is the way it is, including the √d_k divisor that everyone forgets.
- Implement scaled dot-product attention in 20 lines of PyTorch.
- Explain why attention is O(s²) and what consequences that has.
- Place every later optimization (KV cache, PagedAttention, FlashAttention, GQA, MLA) in its proper place as a refinement of the operation built here.
Outline:
- The motivating problem: contextual representations.
- The intuition: a learnable, soft lookup table.
- Queries, keys, and values, derived.
- The dot-product similarity.
- Why √d_k? (variance preservation, derived.)
- Softmax over the sequence axis.
- Multi-head attention — why splitting helps.
- Causal masking for autoregressive models.
- Padding masks.
- The complexity story: why attention is O(s²) and where it bites.
- Naive attention in 20 lines.
- Forward pointers to every later optimization.
6.1 The motivating problem: contextual representations
A neural language model wants to take a sequence of tokens and produce, for every position, a representation that encodes everything the model needs to know about that token in context. The representation of bank in "the river bank" should differ from the representation of bank in "the savings bank" because the surrounding words give the word a different meaning.
How do you get a contextual representation? You let each token “look at” the other tokens in the sequence and aggregate information from them. The aggregation has to be:
- Differentiable, so gradients can flow back through it.
- Position-aware, in the sense that order matters. (Attention on its own is blind to order, so we will inject order via position encodings; see Chapter 7.)
- Learnable in the sense that the model can decide which other positions to attend to, based on the content of those positions.
- Parallelizable so the whole sequence can be processed at once during training.
The pre-transformer answers were:
- Recurrent networks (LSTM, GRU): each position only sees its left neighbors via a hidden state passed forward in time. Differentiable, position-aware, learnable. Not parallelizable: you have to wait for position t-1 before computing position t. This made RNNs slow to train and limited the maximum sequence length.
- Convolutional networks: each position sees a fixed window of neighbors via a learned filter. Parallelizable, but the receptive field grows linearly with depth, so capturing long-range dependencies needs many layers.
Both approaches have been largely displaced in modern language modeling. The thing that displaced them is attention.
6.2 The intuition: a learnable, soft lookup table
The cleanest mental model for attention is a soft, differentiable lookup table.
Suppose you have a Python dict and a query string. You look up the query, get a value back. That’s a hard lookup: exactly one key matches, you get exactly one value. It’s not differentiable (the gradient of “did this key exactly match” is zero almost everywhere).
Now suppose you have a “soft” version: every key has a similarity score to your query, and the returned value is a weighted average of all the values, weighted by similarity. This is differentiable: small changes in the query produce small changes in the similarity scores, which produce small changes in the weighted average.
Attention is exactly this. The “dictionary” is the sequence of tokens (each token contributes a key and a value). The “query” is the token we’re computing the contextual representation for. The output is a weighted sum of the values, weighted by how similar each key is to the query.
That’s it. The whole operation is “compute similarities, softmax-weight them, weighted-sum the values.” Everything else in this chapter is filling in the details.
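A minimal sketch of the contrast, in PyTorch. The toy keys, values, and the helper name `soft_lookup` are illustrative assumptions, not anything from a library; the similarity is a plain dot product, the choice §6.4 motivates.

```python
import torch

# Hard lookup: exactly one key matches; there is no gradient to follow.
table = {"cat": 1.0, "dog": 2.0, "car": 4.0}
hard_result = table["dog"]

# Soft lookup over the same three entries: keys are vectors (toy data).
keys = torch.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])  # one row per key
values = torch.tensor([1.0, 2.0, 4.0])


def soft_lookup(query: torch.Tensor) -> torch.Tensor:
    """Return a similarity-weighted average of all values (hypothetical helper)."""
    scores = keys @ query                    # (3,) dot-product similarities
    weights = torch.softmax(scores, dim=0)   # non-negative, sums to 1
    return weights @ values                  # weighted sum of the values


query = torch.tensor([0.0, 5.0], requires_grad=True)  # points toward the second ("dog") key
soft_result = soft_lookup(query)  # close to 2.0, but every value contributes a little
soft_result.backward()            # gradients flow back into the query
```

The soft result lands near the hard one because the query is closest to the second key, but unlike the dict lookup it has a gradient with respect to the query.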
6.3 Queries, keys, and values, derived
Concretely. We have a sequence of token representations, each of dimension d_model:
X ∈ R^(s × d_model)
where s is the sequence length. From X, we want to produce three things:
- A query for each position — what this position is “looking for.”
- A key for each position — what this position “advertises about itself” to others.
- A value for each position — what this position “contributes” if attended to.
We produce them with three independent learned linear projections:
Q = X W_Q # shape (s, d_k)
K = X W_K # shape (s, d_k)
V = X W_V # shape (s, d_v)
W_Q, W_K, W_V are learned parameter matrices. d_k is the dimension of queries and keys (they have to match because we’re going to take their dot product). d_v is the dimension of values; in practice it’s almost always equal to d_k. The whole point of these projections is to give the model the ability to project the same input into three different spaces — one for asking, one for being asked, one for being read.
Why three different projections? Because the same token might play three different roles. As a query, it might “ask” about its semantic context. As a key, it might “advertise” its syntactic role. As a value, it might “contribute” its lexical content. Letting the model learn three different views of the same input is strictly more expressive than forcing them all to be the same.
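The three projections in PyTorch, as a minimal sketch with toy sizes (s = 5, d_model = 16, d_k = d_v = 8 are arbitrary). Using nn.Linear without bias matches Q = X W_Q, with the detail that nn.Linear stores its weight as (out_features, in_features), so its transpose plays the role of W_Q.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
s, d_model, d_k, d_v = 5, 16, 8, 8  # toy sizes chosen for illustration

X = torch.randn(s, d_model)  # one sequence of s token representations

# Three independent learned projections.
proj_q = nn.Linear(d_model, d_k, bias=False)
proj_k = nn.Linear(d_model, d_k, bias=False)
proj_v = nn.Linear(d_model, d_v, bias=False)

Q = proj_q(X)  # (s, d_k): what each position is looking for
K = proj_k(X)  # (s, d_k): what each position advertises about itself
V = proj_v(X)  # (s, d_v): what each position contributes if attended to

W_Q = proj_q.weight.T  # (d_model, d_k), so Q = X @ W_Q as written in the text
```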
6.4 The dot-product similarity
Now we have queries and keys. We need a similarity function between a query vector and a key vector. The simplest reasonable choice is the dot product:
similarity(q, k) = q · k = Σ_i q_i k_i
The dot product is large and positive when q and k point in the same direction, large and negative when they point in opposite directions, and zero when they’re orthogonal. It’s bilinear (linear in each argument separately), so it’s trivially differentiable, and it’s cheap: all query-key pairs at once take a single matmul.
For the whole sequence at once, we compute the matrix of all query-key dot products:
QK^T ∈ R^(s × s)
Entry (i, j) is the dot product of the i-th query and the j-th key. This is the attention score matrix before normalization. Reading position i’s row tells you “how much position i should attend to each other position.”
The cost of this matmul is O(s² · d_k), and the result is a matrix of size s × s. This is the source of the famous O(s²) cost of attention — both compute and memory are quadratic in sequence length. We’ll come back to this in §6.10.
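A sketch checking that one matmul gives every pairwise dot product (toy sizes, random Q and K):

```python
import torch

torch.manual_seed(0)
s, d_k = 4, 8
Q = torch.randn(s, d_k)
K = torch.randn(s, d_k)

scores = Q @ K.T  # (s, s): all query-key dot products in one matmul

# The same matrix as an explicit double loop: s * s dot products of length d_k,
# which is where the O(s^2 * d_k) cost comes from.
loop_scores = torch.stack(
    [torch.stack([Q[a] @ K[b] for b in range(s)]) for a in range(s)]
)

# Entry (i, j) is the dot product of query i with key j.
i, j = 1, 3
entry_matches = torch.allclose(scores[i, j], torch.dot(Q[i], K[j]))
```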
6.5 Why √d_k? — the variance argument, derived
Here is the part that everyone gets wrong on the whiteboard. The actual attention formula divides the dot products by √d_k:
scores = Q K^T / √d_k
Why? It’s a numerical-stability argument. Suppose q and k are vectors of dimension d_k whose entries are independent random variables with mean 0 and variance 1. (This is roughly what learned projections produce, especially early in training.) The dot product is:
q · k = Σ_i q_i k_i
This is a sum of d_k independent products of mean-0 unit-variance random variables. Each product has mean 0 (since q_i and k_i are independent) and variance 1 (since Var(XY) = E[X²]E[Y²] = 1 for independent unit-variance variables). The sum of d_k such products has mean 0 and variance d_k, which means standard deviation √d_k.
So as d_k grows, the typical magnitude of the dot products grows as √d_k. For typical d_k = 64 or 128, their standard deviation is about 8 to 11. This is bad for softmax. Softmax is very sensitive at this scale: when the gaps between its inputs are large, almost all the probability mass concentrates on the single largest input. The gradient through softmax in this regime is very small: the softmax Jacobian is ∂softmax(z)_i/∂z_j = softmax(z)_i (δ_ij − softmax(z)_j), and when one probability is near 1 and the rest are near 0, every entry of it is near 0. Training stalls.
The fix is to divide by √d_k, which exactly cancels the variance growth:
Var(q · k / √d_k) = d_k / d_k = 1
Now the scores have variance 1 regardless of d_k, softmax stays in a healthy regime, and gradients flow.
This is one of the cleanest “small detail that makes the whole thing work” stories in deep learning. It’s also a favorite interview question: “Why does scaled dot-product attention divide by √d_k?” — and “to keep the variance of the dot products at 1 so softmax doesn’t saturate” is the right answer.
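The derivation can be checked numerically. This sketch, assuming i.i.d. standard-normal query and key entries as in the argument above, measures the spread of q · k with and without the 1/√d_k factor, and the average largest softmax weight for a row of 16 keys (a value near 1 means the row has saturated to nearly one-hot). The function names and sample sizes are arbitrary choices for illustration.

```python
import math
import torch


def dot_product_std(d_k: int, scale: bool, n: int = 20_000, seed: int = 0) -> float:
    """Empirical std of q.k for q, k with i.i.d. N(0, 1) entries."""
    g = torch.Generator().manual_seed(seed)
    q = torch.randn(n, d_k, generator=g)
    k = torch.randn(n, d_k, generator=g)
    dots = (q * k).sum(dim=-1)  # one dot product per sample
    if scale:
        dots = dots / math.sqrt(d_k)
    return dots.std().item()


def mean_max_weight(d_k: int, scale: bool, s: int = 16, n_rows: int = 2_000, seed: int = 0) -> float:
    """Average of the largest softmax weight per row; near 1 means saturated."""
    g = torch.Generator().manual_seed(seed)
    q = torch.randn(n_rows, 1, d_k, generator=g)   # one query per row
    k = torch.randn(n_rows, s, d_k, generator=g)   # s keys per row
    scores = (q @ k.transpose(-2, -1)).squeeze(1)  # (n_rows, s)
    if scale:
        scores = scores / math.sqrt(d_k)
    return scores.softmax(dim=-1).max(dim=-1).values.mean().item()


for d_k in (16, 64, 128):
    print(f"d_k={d_k:3d}  std raw={dot_product_std(d_k, False):5.2f}  "
          f"std scaled={dot_product_std(d_k, True):4.2f}  "
          f"max weight raw={mean_max_weight(d_k, False):.2f}  "
          f"scaled={mean_max_weight(d_k, True):.2f}")
```

The raw standard deviation tracks √d_k, the scaled one stays near 1, and without scaling the largest weight per row drifts toward 1 as d_k grows.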
6.6 Softmax over the sequence axis
We have raw scores Q K^T / √d_k of shape (s, s). We want to turn each row into a probability distribution over positions, so that each query “votes” for which positions to attend to in a way that sums to 1.
A = softmax(QK^T / √d_k)
The softmax is applied along the last axis (the key axis), so each row of A is a probability distribution. Entry A[i, j] is “the fraction of position i’s attention that goes to position j.”
Two consequences of softmax that come up later:
- Each row sums to 1. The total amount of attention is conserved. If a query attends more to one position, it must attend less to others.
- It’s smooth but peaked. Small differences in scores produce small differences in attention weights, but if one score is much larger than the others, almost all the weight goes to that one. This is what makes attention act like a soft lookup: usually it focuses on a few positions, but the focus is differentiable.
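A sketch of both properties on random scores (toy shapes): every row of A sums to 1, and multiplying the scores by 10, the kind of growth unscaled dot products show, sharpens each row toward the same winner.

```python
import torch

torch.manual_seed(0)
scores = torch.randn(4, 6)       # raw scores: 4 queries, 6 keys
A = scores.softmax(dim=-1)       # normalize along the key axis

row_sums = A.sum(dim=-1)         # each row is a distribution over keys

# Smooth but peaked: scaling the scores up concentrates each row on its
# argmax without changing which key wins.
sharp = (10 * scores).softmax(dim=-1)
```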
6.7 Multi-head attention
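Now we have one learned attention operation. We could just use it. But the original transformer paper found that running several independent attention operations in parallel, each in a smaller subspace, and concatenating their outputs works better than a single full-width head, at roughly the same parameter count and compute. With a single head, one softmax has to average every kind of relationship into one set of weights.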
The construction:
- Pick a number of heads H. Typical values: 8, 16, 32, 64.
- Split the model dimension d_model into H chunks of size d_h = d_model / H.
- Run H independent attention operations, each with its own W_Q^h, W_K^h, W_V^h, each producing an output of dimension d_h.
- Concatenate the outputs from all heads: (s, H × d_h) = (s, d_model).
- Apply one more linear projection W_O to mix the heads: (s, d_model) → (s, d_model).
The intuition: different heads can learn to attend to different things. Some heads pick up syntactic relationships (“this verb’s subject is over there”), some pick up coreference (“this pronoun refers to that noun”), some pick up positional patterns (“look at the previous token”), some attend uniformly. The multi-head construction gives the model the freedom to specialize.
In practice, you don’t run H separate matmuls. You compute one big matmul that produces all heads at once and reshape:
# x shape: (N, S, D), with D = H * D_h and qkv_proj = nn.Linear(D, 3 * D)
qkv = qkv_proj(x) # (N, S, 3 * D)
qkv = qkv.view(N, S, 3, H, D_h) # (N, S, 3, H, D_h)
q, k, v = qkv.unbind(dim=2) # each (N, S, H, D_h)
q = q.transpose(1, 2) # (N, H, S, D_h)
k = k.transpose(1, 2) # (N, H, S, D_h)
v = v.transpose(1, 2) # (N, H, S, D_h)
This is where the canonical attention shape (N, H, S, D_h) from Chapter 1 comes from. The H dimension exists so that all heads run as one batched operation.
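To see that the fused projection really is H separate per-head projections, here is a sketch assuming the fused output is laid out as [all Q | all K | all V], with each third split into H contiguous chunks of D_h, which is what view(N, S, 3, H, D_h) implies. The helper per_head_query is hypothetical, written only for this check.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, S, H, D_h = 2, 5, 4, 8
D = H * D_h

x = torch.randn(N, S, D)
qkv_proj = nn.Linear(D, 3 * D, bias=False)  # every head's W_Q, W_K, W_V in one matrix

# Fused path: one matmul, then a reshape into (N, H, S, D_h).
qkv = qkv_proj(x).view(N, S, 3, H, D_h)
q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))  # each (N, H, S, D_h)


def per_head_query(h: int) -> torch.Tensor:
    """Apply head h's own W_Q^h, sliced out of the fused weight, on its own."""
    # nn.Linear weight has shape (3 * D, D); rows [0, D) hold W_Q for all heads.
    W_Q_h = qkv_proj.weight[h * D_h:(h + 1) * D_h]  # (D_h, D)
    return x @ W_Q_h.T                               # (N, S, D_h)
```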
6.8 Causal masking — autoregressive attention
For an autoregressive language model — a model that generates one token at a time, conditioned on all previous tokens — there’s a constraint: position i can only attend to positions j ≤ i. It can’t see the future, because at inference time the future doesn’t exist yet, and at training time we’re trying to teach the model to predict the future from the past.
We enforce this with a causal mask: before the softmax, we set every entry of the score matrix where j > i to -∞. After softmax, those entries become exactly 0, so position i puts zero attention weight on any position to its right.
mask = torch.triu(torch.ones(S, S), diagonal=1).bool() # upper triangle, excluding diagonal
scores = scores.masked_fill(mask, float('-inf'))
attn = scores.softmax(dim=-1)
The mask is the same for every batch element and every head, so it’s broadcast across the leading dims. What grows with sequence length is the (N, H, S, S) score matrix the mask is applied to; for very long sequences that matrix is the dominant memory cost, which is one of the reasons FlashAttention (Chapter 25) doesn’t materialize the full attention matrix at all.
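A quick check of what the mask does to the weights, as a sketch on random scores standing in for Q K^T / √d_k:

```python
import torch

torch.manual_seed(0)
S = 5
scores = torch.randn(S, S)  # stand-in for Q K^T / sqrt(d_k)

# True strictly above the diagonal, i.e. at the "future" keys j > i.
mask = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)
attn = scores.masked_fill(mask, float("-inf")).softmax(dim=-1)

# Row 0 can only see key 0, so all of its weight lands there.
print(attn[0])  # tensor([1., 0., 0., 0., 0.])

# The same (S, S) mask broadcasts over the batch and head dims of (N, H, S, S) scores.
batched = torch.randn(2, 4, S, S).masked_fill(mask, float("-inf")).softmax(dim=-1)
```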
A subtle but important point: the causal mask is what makes prefix caching possible. Because each token only attends to its leftward context, the K and V vectors for past tokens never need to be recomputed when a new token is added. This is the foundation of the KV cache, which is the foundation of efficient autoregressive serving. We’ll see this in Chapter 22.
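A small experiment makes this concrete. The sketch below is a single attention layer with random matrices standing in for learned weights (an assumption for illustration). With the causal mask, appending a token leaves the outputs at every earlier position unchanged; without it, they change. Because each layer’s outputs at earlier positions are unchanged, the next layer’s K and V for those positions are unchanged too, so the argument stacks layer by layer.

```python
import math
import torch

torch.manual_seed(0)
d = 8
# Random stand-ins for learned projection weights.
W_Q, W_K, W_V = (torch.randn(d, d) / math.sqrt(d) for _ in range(3))


def self_attention(x: torch.Tensor, causal: bool = True) -> torch.Tensor:
    """Single-head attention over x of shape (S, d)."""
    S = x.shape[0]
    q, k, v = x @ W_Q, x @ W_K, x @ W_V
    scores = q @ k.T / math.sqrt(d)
    if causal:
        mask = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    return scores.softmax(dim=-1) @ v


prefix = torch.randn(6, d)                          # 6 tokens already processed
extended = torch.cat([prefix, torch.randn(1, d)])   # append one new token

out_prefix = self_attention(prefix)
out_extended = self_attention(extended)
# The first 6 rows agree: no earlier position can see the new token.
```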
6.9 Padding masks
When you batch multiple sequences of different lengths together, you pad the shorter ones to the length of the longest. You don’t want the model to attend to those padding positions — they have no meaning. The fix is a padding mask: a per-sequence boolean mask that marks the padding positions, and the same masked_fill to -∞ trick before softmax.
# pad_mask shape: (N, S), True where padding
# scores shape: (N, H, S, S)
scores = scores.masked_fill(pad_mask[:, None, None, :], float('-inf'))
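The mask is applied along the key axis (the last S), so padding positions can’t be attended to. Whether a query at a padding position “produces” attention is moot, because the loss is masked at those positions too, with one caveat: every row must keep at least one unmasked key. A row that is entirely -∞ (for example, left padding combined with a causal mask) makes softmax return NaN, and that NaN can leak into real positions through later layers, since 0 × NaN is still NaN.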
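In practice, padding masks and causal masks are often combined into a single mask before being applied to the scores. Fused kernels handle them in their own ways: FlashAttention takes a causal flag and handles padding by packing variable-length sequences, while PyTorch’s scaled_dot_product_attention accepts either an explicit attn_mask or is_causal=True (it raises an error if both are set), so a combined padding-plus-causal mask is passed as attn_mask.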
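A sketch of building and combining the two masks for a right-padded batch, then checking the result against PyTorch’s scaled_dot_product_attention. Two details matter there: its boolean attn_mask uses the opposite convention (True means the position may be attended to), and because it rejects attn_mask together with is_causal=True, the combined mask goes in attn_mask alone. The batch sizes and lengths are toy assumptions.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, H, S, D_h = 2, 2, 4, 8
q, k, v = (torch.randn(N, H, S, D_h) for _ in range(3))

# Right padding: sequence 0 has 4 real tokens, sequence 1 has 3 plus 1 pad.
lengths = torch.tensor([4, 3])
pad_mask = torch.arange(S)[None, :] >= lengths[:, None]              # (N, S), True = padding
causal = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)  # (S, S), True = future

# One combined "blocked" mask of shape (N, 1, S, S): future key OR padding key.
blocked = causal[None, None, :, :] | pad_mask[:, None, None, :]

scores = q @ k.transpose(-2, -1) / math.sqrt(D_h)
manual = scores.masked_fill(blocked, float("-inf")).softmax(dim=-1) @ v

# Opposite convention in SDPA (True = allowed), so pass the negation.
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=~blocked)
```

With right padding every row keeps at least key 0, so no row is fully masked and no NaN appears.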
6.10 The complexity story — O(s²) and where it bites
The cost of attention is dominated by the attention score matmul:
QK^T: (s, d_k) @ (d_k, s) → (s, s) cost: O(s² · d_k) per head
softmax: (s, s) cost: O(s²)
attention @ V: (s, s) @ (s, d_v) → (s, d_v) cost: O(s² · d_v) per head
Total compute per head: O(s² · d_k + s² · d_v) = O(s² · d_k) (since d_k ≈ d_v).
Total memory per head: O(s² + s · d_k). The dominant term is the s × s score matrix.
This is the famous O(s²) complexity of attention. Both compute and memory grow quadratically in sequence length. It’s why long contexts are expensive, why extending context windows has taken so much engineering, and why every research direction in efficient attention is some attempt to break this s².
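For comparison, the linear-in-s parts of a transformer block (the projections and the MLP) cost O(s · d²). At short sequence lengths the linear parts dominate; at long ones the attention part dominates, and the model becomes “attention-bound.” Dropping constant factors, the two are equal around s ≈ d_model. With the constants in, for a standard block (Q, K, V, and output projections plus an MLP with 4× expansion) the linear parts cost about 24 · s · d_model² FLOPs and the two attention matmuls about 4 · s² · d_model, which puts the crossover nearer s ≈ 6 · d_model: roughly 24k tokens for d_model = 4096. Above that, attention starts to eat the budget.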
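The s ≈ d_model rule of thumb drops constant factors. A back-of-envelope sketch that keeps them, under stated assumptions (a dense block with separate Q, K, V, and output projections, a 4× MLP, 2 FLOPs per multiply-add, full s × s attention with no causal skipping, softmax and norms ignored), puts the crossover at s = 6 · d_model:

```python
def linear_flops(s: int, d: int, mlp_ratio: int = 4) -> int:
    """Forward FLOPs of the per-token matmuls in one block (2 FLOPs per multiply-add)."""
    qkv_and_out = 2 * s * d * (4 * d)       # W_Q, W_K, W_V, W_O: four d x d matmuls
    mlp = 2 * s * d * (2 * mlp_ratio * d)   # up- and down-projection
    return qkv_and_out + mlp                # = 24 * s * d^2 when mlp_ratio = 4


def attention_flops(s: int, d: int) -> int:
    """Forward FLOPs of QK^T and attn @ V, summed over heads (full s x s matrix)."""
    return 2 * (2 * s * s * d)              # = 4 * s^2 * d


d_model = 4096
for s in (1024, 4096, 6 * 4096, 65536):
    ratio = attention_flops(s, d_model) / linear_flops(s, d_model)
    print(f"s={s:6d}  attention / linear = {ratio:.2f}")
```

The ratio is s / (6 · d_model), so at s = d_model attention is only about a sixth of the block’s matmul FLOPs.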
This is why so many papers chase alternatives to full quadratic attention: sliding-window and sparse attention (Longformer, BigBird), linear attention, and state-space models (Mamba, Chapter 41), plus ring attention, which keeps exact O(s²) attention but shards the sequence across devices so long contexts fit. None of them have completely replaced full softmax attention for the highest-quality models, but they all have niches.
The other reason O(s²) matters: memory. The attention matrix is the largest single tensor in a transformer forward pass at long sequence lengths. For N=1, H=32, S=8192, the score matrix in fp16 is 1 × 32 × 8192 × 8192 × 2 bytes ≈ 4.3 GB. For S=32768 it’s ≈ 69 GB. Materializing this tensor on every layer is what makes long-context inference expensive, and what FlashAttention solved by not materializing it at all.
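The same arithmetic as a tiny helper, a sketch assuming one (N, H, S, S) score tensor per layer at 2 bytes per element (fp16) and nothing else:

```python
def score_matrix_bytes(N: int, H: int, S: int, bytes_per_elem: int = 2) -> int:
    """Size of one (N, H, S, S) attention score tensor in bytes."""
    return N * H * S * S * bytes_per_elem


for S in (8192, 32768):
    b = score_matrix_bytes(N=1, H=32, S=S)
    print(f"S={S:6d}: {b / 1e9:5.1f} GB ({b / 2**30:.0f} GiB) per layer")
```

Quadrupling S multiplies the size by 16: 4 GiB becomes 64 GiB.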
6.11 Naive attention in 20 lines
Putting it all together, here is scaled dot-product multi-head attention in PyTorch, with no optimizations:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class NaiveAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_h = d_model // num_heads
        self.h = num_heads
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, causal=True):
        # x: (N, S, D)
        N, S, D = x.shape
        H, D_h = self.h, self.d_h
        qkv = self.qkv_proj(x)                                   # (N, S, 3D)
        # permute (not transpose(2, 0)): move the q/k/v axis to the front
        # while keeping N before S
        qkv = qkv.view(N, S, 3, H, D_h).permute(2, 0, 1, 3, 4)   # (3, N, S, H, D_h)
        q, k, v = qkv[0], qkv[1], qkv[2]                         # each (N, S, H, D_h)
        q = q.transpose(1, 2)                                    # (N, H, S, D_h)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        scores = q @ k.transpose(-2, -1) / math.sqrt(D_h)        # (N, H, S, S)
        if causal:
            mask = torch.triu(torch.ones(S, S, device=x.device), diagonal=1).bool()
            scores = scores.masked_fill(mask, float('-inf'))
        attn = scores.softmax(dim=-1)                            # (N, H, S, S)
        out = attn @ v                                           # (N, H, S, D_h)
        out = out.transpose(1, 2).contiguous().view(N, S, D)     # (N, S, D)
        return self.out_proj(out)
Read this code line by line. Every modern attention implementation, no matter how heavily optimized, is doing the same thing as these 20 lines. FlashAttention does it without materializing the (N, H, S, S) tensor. PagedAttention does it with a memory layout that supports sharing prefixes. GQA does it with fewer key/value heads than query heads. MLA does it with K and V compressed into a low-rank latent. They are all variations of this snippet.
If you can write this snippet from memory, with the √d_h divisor in the right place, you can pass any “explain attention” interview question.
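One way to check a hand-written implementation is to compare it with PyTorch’s fused torch.nn.functional.scaled_dot_product_attention (available since PyTorch 2.0). The sketch below does that for a function-style version of the core of §6.11; the name naive_sdpa and the tensor sizes are illustrative.

```python
import math
import torch
import torch.nn.functional as F


def naive_sdpa(q, k, v, causal: bool = True):
    """Scaled dot-product attention on (N, H, S, D_h) tensors, materializing (N, H, S, S)."""
    S, D_h = q.shape[-2], q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(D_h)
    if causal:
        mask = torch.triu(torch.ones(S, S, dtype=torch.bool, device=q.device), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    return scores.softmax(dim=-1) @ v


torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))  # (N, H, S, D_h)
reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
mine = naive_sdpa(q, k, v, causal=True)
print((mine - reference).abs().max())  # only float rounding differences expected
```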
6.12 Forward pointers
This is the spine of the rest of the book. Every later chapter is some refinement of the operation we just built:
graph TD
Naive["Naive attention<br/>(§6.11)"]
Naive --> KV["Chapter 22<br/>KV cache"]
Naive --> Flash["Chapter 25<br/>FlashAttention<br/>no s×s matrix"]
Naive --> GQA["Chapter 33<br/>GQA / MQA / MLA<br/>fewer K,V heads"]
Naive --> SSM["Chapter 41<br/>State-space models<br/>rejects O(s²)"]
KV --> Paged["Chapter 24<br/>PagedAttention<br/>virtual memory"]
Paged --> Prefix["Chapter 29<br/>Prefix caching"]
KV --> Disagg["Chapter 36<br/>Disaggregated<br/>prefill/decode"]
style Naive fill:#fdf5ef,stroke:#d14f1a,stroke-width:2px
style KV fill:#f4f2ed
style Flash fill:#f4f2ed
style GQA fill:#f4f2ed
style SSM fill:#f4f2ed
style Paged fill:#f4f2ed
style Prefix fill:#f4f2ed
style Disagg fill:#f4f2ed
Every later optimization is the same naive operation with one specific cost removed.
- Chapter 7 wraps the attention block in residuals, normalization, and an FFN to make a full transformer.
- Chapter 22 introduces the KV cache — the realization that during autoregressive decoding, the K and V for past tokens never change, so we can store them and only compute one new K/V per step. This is the foundation of efficient inference.
- Chapter 24 (PagedAttention) stores the KV cache in fixed-size blocks like a virtual memory system, enabling prefix sharing and efficient batching.
- Chapter 25 (FlashAttention) rewrites the attention kernel to fuse QK^T → softmax → attention @ V into a single tile-based operation that never materializes the s × s matrix. This single optimization is the most important kernel-level improvement in LLM inference.
- Chapter 33 introduces GQA, MQA, and MLA: three ways to compress the K and V sides so the KV cache is smaller.
- Chapter 36 introduces disaggregated prefill/decode, the realization that the prefill and decode phases of attention have such different computational profiles that they should run on different GPU pools.
- Chapter 41 (state-space models) rejects attention’s O(s²) entirely and replaces it with a different sequence operator.
Every one of these is “the same operation as in §6.11, but with one specific cost optimized away.” Hold the naive implementation in your head as the baseline.
6.13 The mental model
Eight points to take into Chapter 7:
- Attention is a soft, differentiable, learnable lookup table.
- Three projections — query, key, value — give the model three views of the same input.
- Dot products measure similarity; the √d_k divisor keeps softmax in a healthy regime.
- Softmax over the sequence axis turns scores into a probability distribution.
- Multi-head lets the model learn multiple attention patterns in parallel at no extra parameter cost.
- Causal masking is what makes autoregressive generation work and what makes the KV cache possible.
- O(s²) in both compute and memory is the cost. Every later optimization is trying to dodge it.
- Naive attention is 20 lines. Every “fancy” implementation is the same operation with one cost removed.
In Chapter 7 we wrap this into a full transformer.
Read it yourself
- The original paper: Vaswani et al., Attention Is All You Need (2017). Read it cover to cover. It’s only ten pages and it’s the most important paper of the last decade in ML.
- Jay Alammar, The Illustrated Transformer — the visual companion to the original paper. The diagrams alone are worth your time.
- Lilian Weng’s blog post Attention? Attention! — a long, rigorous walk through every attention variant pre-2018.
- Andrej Karpathy’s Let’s build GPT YouTube video — the “build attention from scratch” version, with code.
- The PyTorch source for torch.nn.functional.scaled_dot_product_attention: read the docstring, then the implementation in aten/src/ATen/native/transformers/.
Practice
- Write scaled dot-product attention in PyTorch from memory. Compare to §6.11. Don’t peek.
- Why does softmax saturate when its inputs are large? Compute softmax([10, 11]) and softmax([100, 101]) with a naive exp(z) / Σ exp(z) in fp32: they should be the same in theory (softmax is shift-invariant) but different in practice. Why?
- Derive the gradient of softmax(z)_i with respect to z_j. (Answer: softmax(z)_i (δ_ij - softmax(z)_j).) You will find this useful when reading FlashAttention later.
- For d_model = 4096, S = 8192, N = 1, and H = 32, compute the size of the Q K^T attention score tensor in fp16. (Answer: ~4.3 GB.)
- Why does causal masking enable a KV cache during decoding? Walk through a step of generation in your head and identify exactly what doesn’t change.
- The naive attention in §6.11 has an if causal: branch that allocates the mask every forward pass. Why is this inefficient, and how would you fix it?
- Stretch: Implement multi-head attention from scratch in NumPy (no PyTorch). Run it on a small toy input and verify the output shape and the row-sums-to-1 property.
Concept check
1. Why do we divide the attention scores by √d_k before the softmax?
2. What is the memory complexity of a naive attention implementation in the sequence length s?
3. In multi-head attention with model dimension d_model and H heads, what is the per-head dimension d_h?
4. Why does the causal mask only matter for decoder-style (autoregressive) attention, not encoder attention?