Part I · ML Foundations
Chapter 3 · Core

The backward pass: autograd, gradients, the chain rule made mechanical

"Backprop is just the chain rule applied with sufficient bookkeeping"

In Chapter 2 we built a neural network as a pure function y = f(x; θ). In this chapter we make that function learnable. The mechanism is gradient descent: we compute the partial derivative of a scalar loss with respect to every parameter, and we nudge each parameter in the direction that decreases the loss. The non-trivial part is computing those derivatives efficiently. That is the job of backpropagation and, more generally, of automatic differentiation (autograd).

This chapter is the second piece of the foundation that everything else rests on. You will not implement autograd in your day job; the framework does it for you. But you will encounter mixed-precision underflow, gradient checkpointing decisions, “why does my activation memory grow with batch size,” tensor.detach() vs with torch.no_grad():, vanishing-gradient bugs, and a hundred other situations where understanding what autograd is actually doing is the difference between solving the bug in five minutes and staring at the loss curve for two days.

Outline:

  1. Why training is the same as gradient computation.
  2. The chain rule, mechanically.
  3. Computation graphs as DAGs.
  4. Reverse-mode autodiff and why it dominates ML.
  5. What autograd actually stores (the tape).
  6. The vector-Jacobian product and why “gradient” is slightly the wrong word.
  7. Common gradients you should know cold.
  8. The memory cost of the backward pass — and why activation memory grows with batch size.
  9. detach(), no_grad(), requires_grad, retain_graph — the four toggles you must understand.
  10. Why mixed precision needs loss scaling.
  11. Gradient checkpointing — trading compute for memory.

3.1 Training is gradient computation

Training a neural network is the search for parameters θ that minimize a scalar loss L(θ). The loss summarizes how wrong the model is on a batch of training data. Every parameter in the model is a coordinate in a high-dimensional space, and the loss is a function on that space whose minimum we want to find.

The standard tool for finding minima of differentiable functions is gradient descent. The gradient of L with respect to θ — written ∇_θ L — is a vector that points in the direction of steepest increase. If you take a small step in the opposite direction, you decrease the loss:

θ ← θ - η ∇_θ L

where η is the learning rate. Repeat this update many thousands to millions of times and you have trained a neural network. Every optimizer in Chapter 4 is a variant of this update rule with bells and whistles (momentum, per-parameter learning rates, weight decay), but the core remains: compute the gradient, take a step against it.
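The update rule can be seen end to end on a toy problem. Below is a minimal sketch that assumes a made-up quadratic loss L(θ) = ||θ − target||², whose minimum we know in advance. Autograd supplies ∇_θ L. The update is applied by hand rather than through a torch.optim optimizer, so that the three steps (gradient, step, reset) stay visible.

```python
import torch

# Toy problem (hypothetical): minimize L(theta) = ||theta - target||^2,
# whose minimum is at theta = target.
target = torch.tensor([3.0, -1.0])
theta = torch.zeros(2, requires_grad=True)
lr = 0.1  # eta, the learning rate

for step in range(100):
    loss = ((theta - target) ** 2).sum()
    loss.backward()               # autograd fills theta.grad with dL/dtheta
    with torch.no_grad():         # the update itself must not be recorded in the graph
        theta -= lr * theta.grad  # theta <- theta - eta * grad_theta L
    theta.grad.zero_()            # .grad accumulates, so reset it every step

print(theta.detach())             # very close to target after 100 steps
```

For this loss each step shrinks the distance to the minimum by a factor of 1 − 2η = 0.8, so 100 steps are plenty.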

The hard part is computing ∇_θ L. A modern LLM has tens or hundreds of billions of parameters. You need a partial derivative of the scalar loss with respect to every single one. Doing this by hand is inconceivable. Doing it numerically (finite differences) would require a forward pass per parameter — billions of passes per training step. Doing it symbolically (by manipulating expressions) explodes in expression size. The only practical answer is automatic differentiation, which we’ll get to in §3.4.
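The cost argument can be made concrete. This sketch compares central finite differences with autograd on a made-up five-parameter loss; `loss_fn` and `finite_difference_grad` are hypothetical helpers written for the comparison. Finite differences needs two loss evaluations per parameter. Autograd needs one forward and one backward pass however many parameters there are.

```python
import torch

def loss_fn(theta):
    # Hypothetical smooth scalar loss, used only for the comparison.
    return ((theta.sin() * theta).sum()) ** 2

def finite_difference_grad(f, theta, eps=1e-6):
    """Central differences: two evaluations of f for every parameter."""
    grad = torch.zeros_like(theta)
    calls = 0
    for i in range(theta.numel()):
        step = torch.zeros_like(theta)
        step[i] = eps
        grad[i] = (f(theta + step) - f(theta - step)) / (2 * eps)
        calls += 2
    return grad, calls

torch.manual_seed(0)
theta = torch.randn(5, dtype=torch.float64)
fd_grad, fd_calls = finite_difference_grad(loss_fn, theta)

theta_ad = theta.clone().requires_grad_(True)
loss_fn(theta_ad).backward()     # one forward + one backward, whatever the parameter count

print(fd_calls)                                           # 10: two per parameter
print(torch.allclose(fd_grad, theta_ad.grad, atol=1e-6))
```

With five parameters that is 10 evaluations versus one forward/backward pair. At billions of parameters the gap becomes the whole story.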

But first: the chain rule.

3.2 The chain rule, mechanically

The chain rule of calculus says that if y = f(g(x)), then

dy/dx = (df/dg)(g(x)) · (dg/dx)(x)

In words: the derivative of a composition is the product of the derivatives of the pieces, evaluated at the appropriate intermediate points. This is the entire mathematical content of backpropagation. Everything else is bookkeeping.

For a scalar function with vector inputs, the chain rule generalizes to:

∂L/∂x_i = Σ_j (∂L/∂y_j) · (∂y_j/∂x_i)

where y = f(x) is some intermediate vector and L is the final scalar. The right-hand side is a vector-matrix product: the row vector ∂L/∂y times the Jacobian ∂y/∂x (equivalently, the transposed Jacobian times ∂L/∂y as a column vector). We will come back to this in §3.6; it’s the reason backprop is fast.
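The sum over j can be checked numerically. In this sketch the intermediate map f (a tanh of a random linear map) and the loss g = Σ y² are arbitrary choices made for illustration. We build the full Jacobian with `torch.autograd.functional.jacobian`, multiply ∂L/∂y into it by hand, and compare the result with what `backward()` produces.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
W = torch.randn(3, 4, dtype=torch.float64)

def f(x):
    return torch.tanh(W @ x)          # hypothetical intermediate: y has 3 entries, x has 4

def g(y):
    return (y ** 2).sum()             # hypothetical scalar loss

x = torch.randn(4, dtype=torch.float64, requires_grad=True)
y = f(x)
L = g(y)
L.backward()                          # autograd's dL/dx lands in x.grad

J = jacobian(f, x.detach())           # dy/dx, shape (3, 4): J[j, i] = dy_j/dx_i
dL_dy = 2 * y.detach()                # dL/dy for L = sum(y^2)
manual = dL_dy @ J                    # sum_j dL/dy_j * dy_j/dx_i, shape (4,)

print(torch.allclose(manual, x.grad))
```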

For a deeper composition, you just apply the chain rule repeatedly. If L = f_n(f_{n-1}(... f_1(x) ...)), then

dL/dx = (df_n/df_{n-1}) · (df_{n-1}/df_{n-2}) · ... · (df_2/df_1) · (df_1/dx)

This is a chain of matrix multiplications. The order in which you evaluate it matters for efficiency (more on that in §3.4). When it is evaluated from the left, each partial product is the gradient of L with respect to the corresponding intermediate quantity.

Worked example

Let’s compute dL/dx for the function L = (W_2 · σ(W_1 · x))^2, where σ is an elementwise nonlinearity, x is a scalar (shape (1,)), W_1 is (2,1), and W_2 is (1,2), so dL/dx is a single number.

Define intermediates:

a = W_1 · x         # shape (2,)
h = σ(a)            # shape (2,)
y = W_2 · h         # shape (1,)
L = y^2             # shape () scalar

Now compute partial derivatives, working backward from L:

dL/dy = 2y                           # easy
dL/dh = (dL/dy) · (dy/dh) = 2y · W_2  # W_2 is the Jacobian of y w.r.t. h
dL/da = (dL/dh) · (dh/da) = 2y W_2 ⊙ σ'(a)   # ⊙ = element-wise; σ' is element-wise
dL/dx = (dL/da) · (da/dx) = (2y W_2 ⊙ σ'(a)) · W_1

That’s it. Every gradient was computed by multiplying the upstream gradient (the one we’d just calculated) by the local Jacobian of the next operation. The local Jacobians are determined entirely by the operation itself, and every framework has the rules for matmul, softmax, relu, and so on built in.

The pattern is universal: walk backward through the graph, multiplying each step’s gradient by the local Jacobian of that step. This is exactly what autograd does, with the bookkeeping handled for you.
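The hand derivation can be checked against PyTorch. This sketch picks σ = tanh, an arbitrary choice since any elementwise nonlinearity works, so that σ'(a) = 1 − tanh(a)². The weights are random. Each line of the manual backward pass mirrors one line of the derivation.

```python
import torch

torch.manual_seed(0)
x = torch.randn(1, dtype=torch.float64, requires_grad=True)   # scalar input, shape (1,)
W1 = torch.randn(2, 1, dtype=torch.float64)
W2 = torch.randn(1, 2, dtype=torch.float64)

# Forward pass (sigma = tanh is an assumption made for this check)
a = W1 @ x             # (2,)
h = torch.tanh(a)      # (2,)
y = W2 @ h             # (1,)
L = (y ** 2).sum()     # scalar
L.backward()

# Backward pass by hand, one line per step of the derivation
dL_dy = 2 * y.detach()                              # (1,)
dL_dh = dL_dy @ W2                                  # (2,)  upstream times Jacobian W2
dL_da = dL_dh * (1 - torch.tanh(a.detach()) ** 2)   # (2,)  elementwise sigma'(a)
dL_dx = dL_da @ W1                                  # (1,)

print(torch.allclose(dL_dx, x.grad))
```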

3.3 Computation graphs as DAGs

When you write a forward pass like

z = W1 @ x + b1
h = z.relu()
y = W2 @ h + b2
loss = (y - target).pow(2).mean()

…PyTorch builds a computation graph: a directed acyclic graph (DAG) where the nodes are tensors and operations, and the edges describe data flow. The graph for the snippet above has nodes for x, W1, b1, z, h, W2, b2, y, target, and loss, with edges showing which tensors feed which operations.

The graph is built dynamically as the forward pass executes. PyTorch’s autograd is “define-by-run” — there’s no static graph compiled ahead of time; the graph is constructed by the act of evaluating tensors. This is why PyTorch supports arbitrary Python control flow inside forward(): you can use if, for, while, list comprehensions, whatever — and the graph will reflect whichever code path actually ran.
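Define-by-run is easy to observe directly. In this sketch, `forward` is a hypothetical function that picks a branch based on the data. The node recorded just below the final sum is different for the two inputs. PyTorch names backward nodes like `ReluBackward0` and `SigmoidBackward0`; the numeric suffix is an implementation detail, so the check only looks for the substring.

```python
import torch

def forward(x, w):
    # Ordinary Python control flow: the branch taken depends on the data.
    if x.sum() > 0:
        return (w * x).relu().sum()
    else:
        return (w * x).sigmoid().sum()

w = torch.ones(3, requires_grad=True)
pos = forward(torch.tensor([1.0, 2.0, 3.0]), w)
neg = forward(torch.tensor([-1.0, -2.0, -3.0]), w)

# The graph records whichever branch actually ran.
pos_node = type(pos.grad_fn.next_functions[0][0]).__name__
neg_node = type(neg.grad_fn.next_functions[0][0]).__name__
print(pos_node, neg_node)
```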

(JAX takes a different position: it traces your function symbolically and produces a static graph that can be compiled and reused. Both styles work; PyTorch’s flexibility comes at the cost of some optimization opportunity.)

When you call loss.backward(), autograd walks the graph backward from the loss node, applying the chain rule at each step. This is reverse-mode automatic differentiation.

Figure: Computation graph (DAG) for z = W1@x + b1, h = relu(z), y = W2@h, loss = (y - t)². Forward edges are gray; the backward gradient flow is an accent arrow that ends at leaf W1, where dL/dW1 is accumulated.
PyTorch builds this DAG dynamically during the forward pass; backward() walks it right-to-left, multiplying each node's local Jacobian by the upstream gradient until every leaf parameter has a populated .grad.

3.4 Reverse-mode autodiff

There are two flavors of autodiff: forward mode and reverse mode. The difference is the order in which you compose the chain-rule matrix multiplies.

Suppose your function is L = f_n(f_{n-1}(... f_1(x) ...)), and you want to compute dL/dx. The chain rule gives you:

dL/dx = J_n · J_{n-1} · ... · J_2 · J_1

where each J_i is the Jacobian of the i-th operation. This product can be computed in two orders:

Forward mode: start from the right (J_1), multiply by J_2, then J_3, and so on. At each step, you’ve computed dy_i/dx, the derivative of the current intermediate with respect to the original input. Useful when there are few inputs and many outputs.

Reverse mode: start from the left (J_n), multiply by J_{n-1}, then J_{n-2}, and so on. At each step, you’ve computed dL/dy_i, the derivative of the loss with respect to the current intermediate. Useful when there are many inputs and one output.

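For a typical neural network, there are billions of parameters (inputs) and one scalar loss (output). Forward mode would need a separate pass for every parameter. Reverse mode gets all of them from a single backward pass that costs a small constant multiple of the forward pass. That’s why every ML framework uses reverse mode by default.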
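The two orders can be compared with the functional transforms in `torch.func` (available in PyTorch 2.0 and later). This sketch uses a hypothetical function with 1,000 inputs and one scalar output, the same shape as a training loss. One VJP (reverse mode) returns every partial derivative at once. A JVP (forward mode) returns one directional derivative per call, so recovering the full gradient would take 1,000 calls. Only the first five are checked here.

```python
import torch
from torch.func import jvp, vjp

n = 1000
torch.manual_seed(0)
w = torch.randn(n, dtype=torch.float64)

def loss_fn(theta):
    # Hypothetical function with n inputs and one scalar output, like a training loss.
    return torch.tanh(theta * w).sum()

theta = torch.randn(n, dtype=torch.float64)

# Reverse mode: a single VJP seeded with dL/dL = 1 yields all n partial derivatives.
loss, vjp_fn = vjp(loss_fn, theta)
(grad_reverse,) = vjp_fn(torch.ones((), dtype=torch.float64))

# Forward mode: one JVP along basis direction e_i yields a single partial derivative.
def partial_forward(i):
    e_i = torch.zeros(n, dtype=torch.float64)
    e_i[i] = 1.0
    _, d_i = jvp(loss_fn, (theta,), (e_i,))
    return d_i

print(grad_reverse.shape)   # torch.Size([1000]) from one reverse pass
print(all(torch.allclose(partial_forward(i), grad_reverse[i]) for i in range(5)))
```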

The cost of reverse mode is that you need to remember every intermediate value from the forward pass, because the backward pass needs them to compute Jacobians. (The gradient of relu(x) requires knowing the sign of x; the gradient of softmax(z) requires knowing the output softmax(z); etc.) This is the source of the “activation memory” cost we’ll discuss in §3.8.

Why reverse mode is sometimes called “backpropagation”

The two terms — backpropagation and reverse-mode autodiff — refer to the same algorithm. “Backpropagation” is the older term, popularized by Rumelhart, Hinton & Williams in 1986 in the context of training neural networks specifically. “Reverse-mode autodiff” is the more general term from the autodiff literature, which dates back to the 1960s and 70s. Today they are interchangeable.

3.5 What autograd actually stores

When you do a forward pass with requires_grad=True parameters, PyTorch builds the computation graph and stores, for every operation, enough information to compute its backward pass. Specifically:

  1. The input tensors of each operation (or just the ones that are needed for the gradient — many operations don’t need all of their inputs).
  2. A reference to the operation’s backward function (MulBackward, AddmmBackward, ReluBackward, etc.).
  3. Links (next_functions) to the backward nodes of its inputs, which connect this node to the rest of the graph. The output tensor points back at this node through its grad_fn attribute.

This is sometimes called “the tape” — a recording of every differentiable operation that happened during the forward pass, in order, ready to be played back in reverse. When you call loss.backward():

  1. PyTorch walks the tape backward, starting from loss.
  2. At each node, it calls the backward function with the upstream gradient (dL/d(this_node_output)) and computes the gradient with respect to each input.
  3. For the leaves of the graph that are parameters (requires_grad=True tensors with no grad_fn), it accumulates the result into the .grad attribute of that parameter.
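You can walk this structure yourself. `walk` below is a hypothetical helper, not a PyTorch API. It follows `grad_fn.next_functions` from the loss down to the leaves. A leaf parameter appears as an `AccumulateGrad` node, which is the step that writes into `.grad`. An input that does not require grad appears as `None` and is skipped. Node class names such as `SumBackward0` may carry different suffixes across PyTorch releases.

```python
import torch

def walk(fn, depth=0, out=None):
    """Hypothetical helper: list the backward graph rooted at a grad_fn node."""
    out = [] if out is None else out
    if fn is not None:
        out.append("  " * depth + type(fn).__name__)
        for next_fn, _ in fn.next_functions:   # (node, output index) pairs
            walk(next_fn, depth + 1, out)
    return out

torch.manual_seed(0)
W1 = torch.randn(4, 3, requires_grad=True)   # leaf parameter
x = torch.randn(3)                           # plain input: no graph node
loss = (W1 @ x).relu().sum()

names = walk(loss.grad_fn)
print("\n".join(names))   # sum -> relu -> matrix-vector product -> AccumulateGrad for W1
```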

After loss.backward() returns, every nn.Parameter in the model has a .grad attribute populated with the gradient of the loss with respect to that parameter, ready for the optimizer to apply.

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
x = torch.randn(8, 4)
y = torch.randn(8, 1)

loss = ((model(x) - y) ** 2).mean()
loss.backward()

model.weight.grad   # populated, shape (1, 4)
model.bias.grad     # populated, shape (1,)

Two important details:

  • Gradients accumulate. If you call backward() twice without zeroing gradients in between, the gradients add. This is why training loops call optimizer.zero_grad() (or model.zero_grad()) once per step, before the backward pass. Forgetting it means every step applies the running sum of all gradients computed since the last reset, which is a bug. The same mechanism used deliberately, zeroing only every k micro-batches, is “gradient accumulation”: it simulates a k-times larger batch on a memory-constrained GPU.
  • The graph is freed after backward. By default, after loss.backward() returns, PyTorch deletes the saved intermediates to free memory. If you need to call backward() again on the same graph, pass retain_graph=True. If you don’t, you’ll get a clear runtime error.
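The deliberate form of accumulation can be sketched as follows. The sketch assumes a hypothetical batch of 8 examples split into 4 micro-batches of 2. Dividing each micro-batch loss by the number of micro-batches makes the accumulated `.grad` match the gradient of the full-batch mean loss.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
accum_steps = 4

# Reference: gradient of the mean loss over the full batch of 8.
model.zero_grad()
((model(x) - y) ** 2).mean().backward()
full_grad = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 2, with no zero_grad in between.
model.zero_grad()
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    loss = ((model(xb) - yb) ** 2).mean() / accum_steps  # scale so the sum is a mean
    loss.backward()                                       # adds into .grad

print(torch.allclose(model.weight.grad, full_grad, atol=1e-6))
# A real loop would now call optimizer.step() and then optimizer.zero_grad().
```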

3.6 The vector-Jacobian product, and why “gradient” is slightly the wrong word

Here’s a precise statement that will save you confusion later. When PyTorch’s backward() walks the graph, what it actually computes at each node is a vector-Jacobian product (VJP):

out_grad = upstream_grad @ J   # shape: (1, in_dims) = (1, out_dims) @ (out_dims, in_dims)

The “upstream gradient” is a row vector (the gradient of the loss with respect to the output of this node), and J is the local Jacobian of this node’s operation. The product is a row vector that becomes the gradient with respect to this node’s input — and the upstream gradient for the next step backward.

Crucially, the framework never materializes the full Jacobian matrix. For a linear layer y = W x with W of shape (out, in), the Jacobian dy/dx is just W itself. The Jacobian dy/dW, however, is a 3-tensor of shape (out, out, in), and a 4-tensor once you add a batch dimension. For out = in = 4096 that is already about 7 × 10^10 entries for a single example, so materializing it would be a memory disaster. Instead, the framework knows the closed-form VJP for matmul:

upstream_grad shape = (out,)
dL/dx = W^T · upstream_grad        # shape (in,)
dL/dW = upstream_grad ⊗ x          # outer product, shape (out, in)

These are computed directly with matmul kernels, and no Jacobian is ever materialized. Every operation in PyTorch has a hand-coded VJP that does the equivalent trick. This is why the backward pass costs only a small multiple of the forward pass. The forward of a matmul layer is one matmul, and its backward is two matmuls of the same size (one for the input gradient, one for the weight gradient). So backward costs roughly 2× the forward, and a full training step roughly 3×.
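The same closed-form VJP can be written as a custom `torch.autograd.Function`. This is a sketch, not PyTorch’s own implementation. It uses the batched row-vector convention of `nn.Linear` (y = x Wᵀ with x of shape (N, in)), so the weight gradient is the sum of per-example outer products, grad_outᵀ x. `torch.autograd.gradcheck` then compares the hand-written backward against finite differences.

```python
import torch

class MyLinear(torch.autograd.Function):
    """y = x @ W^T for a batch x of shape (N, in) and W of shape (out, in)."""

    @staticmethod
    def forward(ctx, x, W):
        ctx.save_for_backward(x, W)      # the "tape" entry: exactly what the VJP needs
        return x @ W.T

    @staticmethod
    def backward(ctx, grad_out):         # grad_out = dL/dy, shape (N, out)
        x, W = ctx.saved_tensors
        grad_x = grad_out @ W            # (N, out) @ (out, in) -> (N, in)
        grad_W = grad_out.T @ x          # (out, N) @ (N, in) -> (out, in), summed over batch
        return grad_x, grad_W            # two matmuls; no Jacobian is ever built

torch.manual_seed(0)
x = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
W = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)

# gradcheck compares the analytic VJP with numerical finite differences.
print(torch.autograd.gradcheck(MyLinear.apply, (x, W)))
```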

The minor terminological wrinkle: what tensor.grad stores after backward() is the gradient of the loss with respect to the parameter, which is the result of all the VJPs accumulating along every path from that parameter to the loss. People call this “the gradient” and it’s fine. Just know that under the hood, no Jacobian was ever assembled.

3.7 Common gradients to know cold

You don’t have to memorize gradient rules — the framework knows them. But there are a handful that come up often enough that knowing them is useful.

Linear / matmul. For y = W x + b:

dL/dx = W^T (dL/dy)
dL/dW = (dL/dy) x^T
dL/db = dL/dy

ReLU. For y = max(0, x):

dy/dx = 1 if x > 0 else 0
dL/dx = dL/dy ⊙ (x > 0)   # element-wise mask

This is why dead ReLUs don’t recover: when x ≤ 0, the gradient through the ReLU is exactly zero, so no gradient ever reaches the upstream weights to nudge them toward making x positive.
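A dead unit is easy to reproduce. This sketch sets up a single hypothetical ReLU unit whose bias has drifted far negative, with made-up weights and random inputs. The pre-activation is negative for every example, so the ReLU mask is all zeros and neither the weights nor the bias receive any gradient.

```python
import torch

torch.manual_seed(0)
# A single hypothetical ReLU unit whose bias has drifted far negative.
w = torch.tensor([0.5, -0.3], requires_grad=True)
b = torch.tensor(-10.0, requires_grad=True)
x = torch.randn(64, 2)          # inputs of ordinary scale

pre = x @ w + b                 # pre-activation: negative for every row here
out = pre.relu()
out.sum().backward()            # mask (pre > 0) is all False

active = (pre > 0).sum().item()
print(active)                   # number of rows where the unit is active
print(w.grad, b.grad)           # all zeros: nothing pushes b back up
```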

Sigmoid. For y = σ(x) = 1/(1 + e^{-x}):

dy/dx = y(1 - y)

A delightful result: the derivative of sigmoid is computable from its output, no need to remember x. Saturates to ≈ 0 for large |x| — this is the vanishing gradient that killed sigmoid as a hidden activation.
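Both claims can be checked numerically: the derivative is recoverable from the output alone, and it saturates. This short sketch evaluates sigmoid on a grid of points from −10 to 10. The derivative peaks at 0.25 when x = 0 and falls to about 4.5 × 10^-5 at |x| = 10.

```python
import torch

x = torch.linspace(-10, 10, steps=21, requires_grad=True)
y = torch.sigmoid(x)
y.sum().backward()                          # x.grad = dy/dx elementwise

from_output = (y * (1 - y)).detach()        # the y(1 - y) formula uses only y
print(torch.allclose(x.grad, from_output))
print(x.grad.max().item())                  # the peak, at x = 0
print(x.grad[0].item(), x.grad[-1].item())  # tiny at |x| = 10: saturated
```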

Softmax. For y = softmax(z):

dy_i/dz_j = y_i (δ_ij - y_j)        # where δ_ij is 1 if i==j else 0

In matrix form: J = diag(y) - y y^T. For the special (but common) case where the loss is cross-entropy of the softmax against a one-hot target t, the entire gradient through both the loss and the softmax simplifies to:

dL/dz = y - t

Beautifully simple. This is why “softmax + cross-entropy” is usually implemented as a single fused op (F.cross_entropy in PyTorch, tf.nn.sparse_softmax_cross_entropy_with_logits in TF). The analytic simplification gives you a faster and more numerically stable backward pass than computing the two ops separately.
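Both formulas can be verified numerically. This sketch uses random logits and an arbitrarily chosen one-hot target for class 2. It checks the softmax Jacobian against diag(y) − y yᵀ, then checks that the gradient of `F.cross_entropy` with respect to the logits is y − t.

```python
import torch
import torch.nn.functional as F
from torch.autograd.functional import jacobian

torch.manual_seed(0)
z = torch.randn(5, dtype=torch.float64, requires_grad=True)   # logits
t = torch.zeros(5, dtype=torch.float64)
t[2] = 1.0                                                    # one-hot target: class 2 (arbitrary)

# Softmax alone: a full Jacobian, J = diag(y) - y y^T.
y = torch.softmax(z, dim=-1).detach()
J = jacobian(lambda v: torch.softmax(v, dim=-1), z.detach())
print(torch.allclose(J, torch.diag(y) - torch.outer(y, y)))

# Softmax + cross-entropy together: the gradient collapses to y - t.
loss = F.cross_entropy(z.unsqueeze(0), torch.tensor([2]))
loss.backward()
print(torch.allclose(z.grad, y - t))
```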

LayerNorm. Painful to derive. The framework handles it. You only need to know: the gradient through LayerNorm depends on the mean and variance computed in the forward pass, so those have to be saved as activations. This is one of the contributors to the activation memory cost of transformers.

3.8 The memory cost of backward — and why it scales with batch size

Reverse-mode autodiff has a fundamental cost: it requires you to remember the forward activations so the backward pass can use them.

Walk through a transformer block in your head. The forward pass computes a sequence of intermediate tensors:

Figure: Memory per transformer block (relative; S=2048, N=4, D=768, H=12). Norm ~N·S·D; QKV projection ~3·N·S·D; attention scores N·H·S·S (quadratic); FFN activation ~4·N·S·D. Activation memory grows linearly with batch, sequence, and depth, while the (N, H, S, S) attention score tensor is the quadratic term that dominates long contexts.
The attention score tensor (N·H·S·S) grows quadratically with sequence length and is the dominant activation memory cost — the entire motivation for FlashAttention's recompute-on-the-fly strategy.
x_in        (N, S, D)
x_norm = norm(x_in)                  (N, S, D)    # save mean, var
qkv = qkv_proj(x_norm)               (N, S, 3D)
q, k, v = split(qkv)                 each (N, H, S, D_h)
attn_scores = q @ k^T / sqrt(D_h)    (N, H, S, S)
attn_weights = softmax(attn_scores)  (N, H, S, S)  # save for backward
attn_out = attn_weights @ v          (N, H, S, D_h)
attn_proj = out_proj(attn_out)       (N, S, D)
x_after_attn = x_in + attn_proj      (N, S, D)
... and so on for the FFN

Every one of those intermediate tensors might be needed for the backward pass. In the worst case, you have to store all of them. The total activation memory for the forward pass of a transformer is proportional to:

activation_memory ≈ N × S × D × num_layers × constant

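The constant depends on which intermediates you save, and this simple form hides one term. The (N, H, S, S) attention scores, and the softmax weights computed from them, add on the order of N × H × S² values per layer, so they are quadratic in sequence length. At long context this term often dominates. It is the reason long-context training is hard, and the reason FlashAttention (Chapter 25) exists. FlashAttention never writes the full S × S score matrix to GPU memory: it computes attention in tiles and recomputes the scores during the backward pass instead of storing them.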

The practical implication: activation memory scales linearly with batch size, hidden dim, and depth, and at least linearly with sequence length (quadratically if the attention scores are stored). A 70B model with S=4096, N=4 has activation memory in the hundreds of GBs. The model weights are often not the largest thing in your training memory budget; the activations are.
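The scaling can be measured with `torch.autograd.graph.saved_tensors_hooks`, which lets you observe every tensor autograd saves for backward. In this sketch, `saved_bytes` is a hypothetical helper. It runs a bare attention computation (no projection weights, made-up small sizes) and counts saved bytes. Doubling the batch roughly doubles the total. Doubling the sequence length more than doubles it, because the S × S tensors grow fourfold.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

def saved_bytes(N, S, H=4, D_h=16):
    """Hypothetical helper: bytes autograd saves for one bare attention computation."""
    total = 0

    def pack(t):
        nonlocal total
        total += t.numel() * t.element_size()   # count every tensor saved for backward
        return t

    def unpack(t):
        return t

    torch.manual_seed(0)
    q, k, v = (torch.randn(N, H, S, D_h, requires_grad=True) for _ in range(3))
    with saved_tensors_hooks(pack, unpack):
        scores = (q @ k.transpose(-2, -1)) * D_h ** -0.5   # (N, H, S, S)
        weights = torch.softmax(scores, dim=-1)             # (N, H, S, S), saved by softmax
        out = weights @ v                                   # (N, H, S, D_h)
    return total

base = saved_bytes(N=2, S=64)
batch_ratio = saved_bytes(N=4, S=64) / base
seq_ratio = saved_bytes(N=2, S=128) / base
print(batch_ratio)   # about 2: linear in batch size
print(seq_ratio)     # between 2 and 4: the S x S tensors make it superlinear
```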

This is why gradient checkpointing exists. Instead of storing every activation, you store only some of them (typically one per transformer block), and during the backward pass you recompute the missing intermediates from the saved checkpoints. This trades roughly 33% more compute for 5–10× less activation memory. We will encounter it again in Chapter 12 (distributed training).

3.9 The four toggles you must understand

PyTorch gives you four levers for controlling autograd. You will see all four in real code.

requires_grad is a per-tensor boolean. If True, every operation involving this tensor builds graph nodes; if False, it doesn’t. By default, parameters in nn.Module have requires_grad=True; tensors you create yourself with torch.zeros(...) have requires_grad=False. You can set it manually:

x = torch.randn(10, requires_grad=True)

.detach() returns a new tensor that shares memory with the original but has requires_grad=False and no grad_fn. It “cuts the graph” at this point. Used when you want to use a tensor’s value as a constant in further computation, without backprop flowing back through it. Example: passing model output as a “fixed target” for a teacher-student loss.

fixed = teacher_model(x).detach()   # don't update the teacher
loss = (student_model(x) - fixed).pow(2).mean()
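The “shares memory” part matters in practice. This small sketch shows that the detached tensor has no `grad_fn` and points at the same storage as the original. As a result, an in-place edit made through the detached tensor is visible in the original, a frequent source of surprises.

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x * 2
d = y.detach()

print(d.requires_grad, d.grad_fn)      # False None: cut out of the graph
same_storage = d.data_ptr() == y.data_ptr()
print(same_storage)                    # True: same memory, no copy

# Because the storage is shared, an in-place edit through d shows up in y.
d.zero_()
print(y)                               # zeros, still carrying y's grad_fn
```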

with torch.no_grad(): is a context manager that disables graph construction for everything inside it: results computed inside the block have requires_grad=False even when their inputs require grad. Used during inference and validation to save memory and time:

with torch.no_grad():
    val_loss = compute_loss(model, val_data)

There’s also @torch.no_grad() as a decorator, and torch.inference_mode() which is even stricter (and slightly faster).
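The difference between the modes shows up on the outputs. This sketch uses a random stand-in parameter `w`. The last part relies on the documented restriction that inference tensors cannot be saved for backward. That restriction is what makes inference_mode stricter than no_grad, and it is checked here by catching the RuntimeError.

```python
import torch

w = torch.randn(3, requires_grad=True)   # stand-in for a model parameter

tracked = w * 2
print(tracked.requires_grad)       # True: a graph node was recorded

with torch.no_grad():
    untracked = w * 2
print(untracked.requires_grad)     # False: nothing recorded inside the block

with torch.inference_mode():
    out_inf = w * 2
print(out_inf.is_inference())      # True: an "inference tensor"

# Stricter: mul would have to save out_inf to compute w's gradient, which is not allowed.
try:
    (out_inf * w).sum().backward()
    raised = False
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```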

retain_graph=True is a flag to backward() that prevents the saved intermediates from being freed after the backward pass completes, so you can call backward() again. Needing it is often a sign of something subtle: most training loops do exactly one backward per forward and don’t need it. The two cases where you do: (1) calling backward() separately for multiple losses that share part of the same graph, and (2) higher-order derivatives (where you pass create_graph=True, which also retains the graph by default).

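A common bug pattern: a tensor from a previous iteration leaks into the current iteration’s graph. Two typical examples are a running total_loss = total_loss + loss that you call backward() on every iteration, and an RNN hidden state carried across batches without .detach(). The second backward() then tries to traverse the previous iteration’s nodes, whose saved tensors were already freed, and crashes. Adding retain_graph=True silences the error but keeps every old graph alive and backpropagates into stale iterations. The fix is to cut the link: accumulate a plain number for logging (total_loss += loss.item()) or detach the carried tensor (hidden = hidden.detach()).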
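The running-total version of the bug and its fix can be reproduced in a few lines. This sketch uses a hypothetical one-parameter-vector model and random toy data. The buggy loop fails on its second backward() call with PyTorch’s “backward through the graph a second time” error. The fixed loop keeps only a Python float.

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
data = [(torch.randn(3), torch.randn(())) for _ in range(3)]   # hypothetical toy data

def loss_fn(x, y):
    return ((w * x).sum() - y) ** 2

def buggy_loop():
    total_loss = 0.0
    for x, y in data:
        total_loss = total_loss + loss_fn(x, y)   # a tensor: drags earlier graphs along
        total_loss.backward()                     # 2nd call reaches iteration 1's freed graph
    return total_loss

def fixed_loop():
    total_loss = 0.0
    for x, y in data:
        loss = loss_fn(x, y)
        loss.backward()                           # one backward per fresh graph
        total_loss += loss.item()                 # plain float for logging: no graph attached
    return total_loss

try:
    buggy_loop()
    buggy_raised = False
except RuntimeError as err:
    buggy_raised = True
    print("buggy_loop failed:", err)

w.grad = None
print("fixed_loop:", fixed_loop())
```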

3.10 Why mixed precision needs loss scaling

You’ll learn this in Chapter 13 in full, but the connection to backprop is too clean not to introduce here.

In mixed-precision training, you store activations in fp16 or bf16 to halve their memory and to run matmuls on the much faster half-precision Tensor Core paths. This works well for forward, but it causes a problem for the backward pass: gradients in fp16 can underflow to zero. The smallest representable positive fp16 number is about 6 × 10^-8. Many gradients in a deep network are smaller than this. They underflow, become zero, and the corresponding parameter never updates.

The fix is loss scaling: multiply the loss by a large constant (typically 2^16 = 65536) before calling backward(). By the chain rule, this multiplies every gradient in the network by the same constant — pushing the small gradients above the underflow threshold. Just before applying the optimizer step, you divide every gradient back down by the same constant. The math is identical; the precision is preserved.

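scale = 65536.0                 # 2**16
loss_scaled = loss * scale      # scale up before backward
loss_scaled.backward()          # every gradient is now scale times larger
for p in model.parameters():
    if p.grad is not None:      # parameters unused in this step have no grad
        p.grad /= scale         # unscale before the optimizer sees the gradients
optimizer.step()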

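torch.cuda.amp.GradScaler (torch.amp.GradScaler("cuda") in recent PyTorch releases) automates this with a dynamic scale factor. When gradients overflow to inf or NaN, it skips that optimizer step and halves the scale. After a run of overflow-free steps (2000 by default), it doubles the scale.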

Bf16 doesn’t have this problem. Because bf16 has the same exponent range as fp32 (just less mantissa precision), bf16 gradients don’t underflow. This is why bf16 has overtaken fp16 as the default training dtype on hardware that supports it: you don’t need loss scaling, and the code is simpler.
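The underflow and its fix can be seen on a single number. This sketch assumes a hypothetical gradient value of 10^-8, which lies below fp16’s smallest subnormal, and the usual 2^16 scale. Cast directly to fp16, the value becomes zero. Scaled first, it survives and can be unscaled in fp32. In bf16 it survives without any scaling.

```python
import torch

tiny_grad = 1e-8      # hypothetical gradient value, below fp16's smallest subnormal (~6e-8)
scale = 65536.0       # 2**16, a typical initial loss scale

as_fp16 = torch.tensor(tiny_grad, dtype=torch.float16).item()
scaled = torch.tensor(tiny_grad * scale, dtype=torch.float16)   # about 6.6e-4: fits in fp16
unscaled = (scaled.float() / scale).item()                       # unscale in fp32
as_bf16 = torch.tensor(tiny_grad, dtype=torch.bfloat16).item()

print(as_fp16)    # 0.0: underflow, the update is lost
print(unscaled)   # close to 1e-8: recovered thanks to scaling
print(as_bf16)    # nonzero: bf16 shares fp32's exponent range
```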

3.11 Gradient checkpointing — trading compute for memory

Gradient checkpointing (a.k.a. activation recomputation) is the most important memory optimization in training. The idea: don’t save every intermediate from the forward pass. Save only some of them (typically one per transformer block, or one per N layers). During the backward pass, when you need an unsaved intermediate, recompute it on the fly by re-running the forward pass for that segment.

Figure: Without checkpointing, every intermediate of blocks 0 to 2 is saved, giving a large peak memory. With checkpointing, only the block inputs (x₁, x₂) are saved as checkpoints; intermediates are discarded and recomputed on demand during backward (+33% compute, −80% activation memory).
Checkpointing keeps only block boundary tensors (highlighted); discarded intermediates are recomputed during backward — the +33% compute cost is almost always worth the 5–10x memory savings.

The cost is roughly 33% more compute (one extra forward pass for the recomputed segments). The benefit is 5–10× less activation memory. This is almost always a worthwhile trade in training, and almost every large-scale training run uses it.

In PyTorch:

from torch.utils.checkpoint import checkpoint

def forward(self, x):
    for block in self.blocks:
        x = checkpoint(block, x, use_reentrant=False)
    return self.head(x)

The checkpoint wrapper tells autograd: don’t save the intermediate activations inside block. Instead, save only the input x, and during the backward pass, re-run block(x) to recompute the intermediates.
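The recomputation can be observed directly. This sketch uses a hypothetical `CountingBlock` (a small residual MLP) that counts how often its forward runs. Without checkpointing the forward runs once. With `checkpoint(..., use_reentrant=False)` it runs a second time during backward to rebuild the discarded intermediates. The gradients are the same either way.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CountingBlock(nn.Module):
    """Hypothetical residual MLP block that counts how often its forward runs."""
    def __init__(self, d):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))
        self.calls = 0

    def forward(self, x):
        self.calls += 1
        return x + self.ff(x)

torch.manual_seed(0)
block = CountingBlock(8)
x = torch.randn(4, 8)

# Plain: forward runs once and every intermediate is saved.
block(x).sum().backward()
plain_calls, plain_grad = block.calls, block.ff[0].weight.grad.clone()

# Checkpointed: only x is kept; forward runs again inside backward.
block.zero_grad()
block.calls = 0
checkpoint(block, x, use_reentrant=False).sum().backward()
ckpt_calls, ckpt_grad = block.calls, block.ff[0].weight.grad.clone()

print(plain_calls, ckpt_calls)                 # forward ran once vs. twice
print(torch.allclose(plain_grad, ckpt_grad))   # identical gradients
```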

This is one of the techniques that make training large models on commodity GPUs possible. It is also why a checkpointed training step typically takes about 1.3× as long as a non-checkpointed one rather than 2×. The recomputation is one extra forward pass, which is cheap next to a backward pass that already costs about twice the forward.

3.12 The mental model

Eight points to take into Chapter 4:

  1. Training is gradient descent. Compute ∇_θ L, take a step against it, repeat.
  2. Backprop is the chain rule with bookkeeping. Walk the computation graph backward, multiplying upstream gradients by local Jacobians.
  3. Reverse-mode autodiff wins because there are billions of parameters and one scalar loss.
  4. Autograd builds a tape during the forward pass and plays it back during backward(). The tape stores enough intermediate state to compute every VJP.
  5. VJP, not Jacobian. No framework ever materializes the full Jacobian — the closed-form VJP for each op is hand-coded into the framework.
  6. Activation memory is often the dominant memory cost in training. It scales with N × S × D × layers (plus an S² attention term if the scores are stored), and it is what gradient checkpointing trades against compute.
  7. Four toggles control autograd behavior: requires_grad, .detach(), with torch.no_grad():, retain_graph=True.
  8. Mixed precision in fp16 needs loss scaling; in bf16 it doesn’t. This is why bf16 is the modern default.

In Chapter 4 we use the gradient to actually optimize, and meet the optimizer family.


Read it yourself

  • The PyTorch docs page on Autograd mechanics — read it twice. Maybe three times.
  • Christopher Olah, Calculus on Computational Graphs: Backpropagation — the cleanest visual explanation of reverse-mode autodiff that exists.
  • Andrej Karpathy’s micrograd (GitHub: karpathy/micrograd) — a 100-line scalar-level autograd in Python. Reading it cover-to-cover will cement everything in this chapter.
  • Deep Learning, Goodfellow et al., chapter 6 (sections 6.5–6.6) — the formal derivation.
  • The gradient checkpointing paper: Chen et al., Training Deep Nets with Sublinear Memory Cost (2016).

Practice

  1. Compute by hand: dL/dW for L = (W x - y)^2, where W is (1, 3), x is (3,), y is (1,). Then verify in PyTorch.
  2. Why does torch.softmax(x, dim=-1).sum().backward() produce zero gradients for x? (Hint: think about the constraint that softmax outputs sum to 1.)
  3. Build a tiny PyTorch script that uses loss.backward() twice without optimizer.zero_grad() in between, and confirm that the second call accumulates into .grad rather than replacing.
  4. Read karpathy/micrograd. Add support for the tanh function (forward and backward). Run a tiny MLP on a toy classification problem.
  5. Estimate the activation memory for a forward pass of a 12-layer transformer with S=2048, D=768, H=12, N=8, dtype=fp16. Use the rough formula activation ≈ N × S × D × layers × ~20 values. (Answer: about 3 × 10^9 stored values, i.e. roughly 6 GB at 2 bytes per fp16 value. The “20” constant is empirical and accounts for the various intermediates per block.)
  6. Why is the gradient of softmax-followed-by-cross-entropy just y - t, but the gradient of softmax alone is a full Jacobian? Derive both and convince yourself the simplification only happens when you compose the two ops.
  7. Stretch: Implement a from-scratch reverse-mode autodiff engine in 200 lines of pure Python that supports +, *, relu, matmul, and mean. Use it to train a 2-layer MLP on a toy regression problem. The point is not to use it; it’s to internalize the algorithm.

Concept check


  1. Why does activation memory grow with batch size during the backward pass?
  2. What is the key reason reverse-mode autodiff (backprop) dominates ML rather than forward-mode autodiff?
  3. What does tensor.detach() do that torch.no_grad() does not?
  4. Gradient checkpointing trades compute for memory by