Loss functions and optimization, in just enough depth
"You do not rise to the level of your model. You fall to the level of your loss function"
In Chapter 3 we made the network differentiable. In this chapter we use those derivatives to actually move the parameters toward something useful. There are two questions:
- What scalar are we trying to minimize? That’s the loss function.
- How do we use the gradient to take a good step? That’s the optimizer.
We will go just deep enough that you can read any modern training script and know what every line is doing, justify every choice in code review, and answer the loss/optimizer questions that come up in interviews. We are not going to derive convergence rates or prove anything about strong convexity. That’s not the job here.
Outline:
- The minimization view of training.
- Cross-entropy from first principles (negative log-likelihood of a categorical distribution).
- Why softmax + cross-entropy is one fused op.
- Mean squared error and where it fits.
- Contrastive losses (a preview — Chapter 9 has the full story).
- The problem with vanilla SGD.
- Momentum.
- Adam — the per-parameter learning-rate adaptation.
- AdamW — decoupled weight decay, and why “Adam” usually means “AdamW” now.
- Lion — the modern challenger, with half the optimizer state.
- The optimizer-state memory math (the part interviews care about).
- Learning rate schedules — warmup and cosine decay.
- Gradient clipping.
4.1 Minimization is the framing
Every neural network training run, at any scale, is the same problem in disguise:
find θ* = argmin_θ L(θ)
…where θ are the parameters and L is a scalar loss that summarizes how bad the model is on a chunk of training data. The argmin doesn’t have a closed form for any interesting L, so we approach it iteratively: start at some random θ, compute ∇_θ L, take a step against the gradient, repeat. That’s the entire game.
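Here is a minimal sketch of that loop: plain gradient descent on a one-parameter toy loss L(θ) = (θ - 3)^2, whose minimizer θ* = 3 we know in advance. The loss, the fixed starting point (used instead of a random one, for reproducibility), the learning rate, and the step count are illustrative choices, not taken from any real training run.

```python
def loss(theta):
    # Toy loss with a known minimizer at theta = 3.
    return (theta - 3.0) ** 2

def grad(theta):
    # dL/dtheta, derived by hand for this toy loss.
    return 2.0 * (theta - 3.0)

def gradient_descent(theta0, lr=0.1, steps=100):
    theta = theta0
    for _ in range(steps):
        theta = theta - lr * grad(theta)  # step against the gradient, repeat
    return theta

theta_star = gradient_descent(theta0=-5.0)
print(theta_star, loss(theta_star))
```

Each step multiplies the distance to the minimum by (1 - 2 · 0.1) = 0.8, so after 100 steps the error is negligible. Real losses have no such closed-form gradient; that is what backprop supplies.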
The two design decisions that separate “this trains” from “this doesn’t” are the loss (what scalar we minimize) and the optimizer (how the gradient becomes a step). Get either one wrong and the other can’t save you.
4.2 Cross-entropy from first principles
For classification — and “next-token prediction” is the world’s biggest classification problem — the loss is cross-entropy.
The story goes like this. The model emits logits z of shape (num_classes,). We turn the logits into a probability distribution with softmax:
p_i = exp(z_i) / Σ_j exp(z_j)
Now p is a probability vector that sums to 1. We have a target class t (an integer in [0, num_classes)). We want p_t — the probability the model assigned to the correct class — to be as close to 1 as possible.
A natural objective is to maximize p_t. Equivalently, we can maximize log p_t (logarithm is monotonic, and the log makes the math nicer). Equivalently, we can minimize the negative log-likelihood:
L = -log p_t
That’s the loss for a single example. For a batch of N examples:
L = -(1/N) Σ_n log p_{t_n}^{(n)}
This is negative log-likelihood (NLL). It is also called cross-entropy because of an information-theoretic interpretation: it’s the cross-entropy of the one-hot target distribution against the model’s predicted distribution. The two names refer to the same loss; ML people use them interchangeably.
Why log?
Three reasons.
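(1) It separates products into sums. The probability of several observations together is a product, and log(p_1 · p_2 ⋯ p_n) = log p_1 + log p_2 + ⋯ + log p_n. This is critical for autoregressive models: by the chain rule of probability, the probability of a whole sequence is the product of each token’s probability given the tokens before it, so the log-likelihood of the sequence is the sum of the per-token conditional log-likelihoods, which is what we want as a loss. Sums are also numerically safe where long products of small probabilities underflow to zero.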
(2) It punishes confidently wrong predictions much harder than mildly wrong ones. If the model assigns probability 0.001 to the correct class, the loss is about 6.9. If it assigns 0.5, the loss is about 0.69. If it assigns 0.99, the loss is about 0.01. The asymmetry — huge penalty for “very confident and wrong” — is a feature: it tells the model “your overconfident mistakes are the ones I want you to fix first.”
(3) It connects to information theory. -log p_t is the number of bits (with log_2) or nats (with log_e) of “surprise” the model assigns to the correct outcome. Cross-entropy is the average surprise. A model with cross-entropy 2.3 nats per token assigns, in geometric mean, e^{-2.3} ≈ 10% probability to the correct token at each step. This mental model is the basis for the perplexity metric you’ll see everywhere in language modeling: perplexity = exp(cross_entropy_per_token), so 2.3 nats corresponds to a perplexity of about 10, as uncertain as a uniform guess among 10 tokens.
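The three reasons can be checked numerically with a few lines of standard-library Python. The per-token losses in the last part are made-up numbers, chosen only to average 2.3 nats.

```python
import math

# (2) Confidently wrong predictions are punished much harder.
for p in (0.001, 0.5, 0.99):
    print(f"p_t = {p:<5}  loss = -ln p_t = {-math.log(p):.3f}")

# (1) Logs turn a product of many probabilities into a sum. The product of
# 1000 per-token probabilities of 0.1 underflows float64 to 0.0, while the
# sum of logs stays perfectly representable.
probs = [0.1] * 1000
product = math.prod(probs)                      # underflows to 0.0
log_likelihood = sum(math.log(p) for p in probs)
print(product, log_likelihood)

# (3) Perplexity is exp(mean cross-entropy per token, in nats).
per_token_nll = [2.1, 2.5, 2.3]                 # illustrative per-token losses
mean_ce = sum(per_token_nll) / len(per_token_nll)
perplexity = math.exp(mean_ce)
print(f"mean CE = {mean_ce:.2f} nats, perplexity = {perplexity:.2f}")
```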
Why “softmax + cross-entropy” is always one op
We saw in §3.7 that the gradient of softmax through cross-entropy is just y - t (where y is the softmax output and t is the one-hot target). This is the cleanest gradient in all of deep learning. It’s the result of the algebraic simplification of d/dz (-log softmax(z)_t).
Computing softmax and cross-entropy as two separate operations doesn’t get this simplification. You’d compute softmax, take its log, then negate and select. The intermediate softmax output has to be stored for the backward pass, the log is numerically unstable for small probabilities, and the chain rule is more work.
Computing them as one fused op (F.cross_entropy in PyTorch) gives you:
- Better numerical stability (log-sum-exp trick built in).
- A simpler backward pass (the gradient is just y - t).
- Lower memory (no separate softmax output to save).
Always use F.cross_entropy(logits, targets). Never write (-target * log(softmax(logits))).sum() by hand. The pattern is universal: when an op composition has a clean fused form, the framework will give you that fused form, and you should use it.
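To make the fused form concrete, here is a pure-Python sketch (no PyTorch; the function names are our own, not a library API) of cross-entropy computed directly from logits via log-sum-exp, next to the naive softmax-then-log version, plus a central finite-difference check that the gradient with respect to the logits is y - t.

```python
import math

def cross_entropy_from_logits(logits, target):
    """Fused form: -log softmax(z)_t = logsumexp(z) - z_t.
    Subtracting max(z) before exponentiating (the log-sum-exp trick)
    keeps exp() from overflowing."""
    z_max = max(logits)
    log_sum_exp = z_max + math.log(sum(math.exp(z - z_max) for z in logits))
    return log_sum_exp - logits[target]

def naive_cross_entropy(logits, target):
    """Two separate ops: softmax first, then -log. Overflows on large logits."""
    exps = [math.exp(z) for z in logits]
    return -math.log(exps[target] / sum(exps))

def softmax(logits):
    z_max = max(logits)
    exps = [math.exp(z - z_max) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

big_logits = [1000.0, 999.0, 0.0]
print(cross_entropy_from_logits(big_logits, 1))   # finite, about 1.3133
try:
    naive_cross_entropy(big_logits, 1)
except OverflowError as err:
    print("naive version failed:", err)          # math.exp(1000) overflows

# Finite-difference check that d(loss)/d(logits) = y - t (t one-hot).
logits, target, h = [2.0, -1.0, 0.5], 0, 1e-6
y = softmax(logits)
analytic = [y_i - (1.0 if i == target else 0.0) for i, y_i in enumerate(y)]
numeric = []
for i in range(len(logits)):
    up, down = list(logits), list(logits)
    up[i] += h
    down[i] -= h
    numeric.append((cross_entropy_from_logits(up, target)
                    - cross_entropy_from_logits(down, target)) / (2 * h))
print(analytic)
print(numeric)
```

On moderate logits the two versions agree; the naive one only fails once exp() leaves floating-point range, which is exactly the regime a fused kernel is built to survive.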
4.3 Mean squared error
For regression — predicting a continuous value, like the price of a house or the depth of a pixel — the loss is mean squared error (MSE):
L = (1/N) Σ_n (y_n - y_hat_n)^2
The derivative of each squared term with respect to the prediction is 2 (y_hat_n - y_n), times the 1/N from the mean. Same logic as cross-entropy: the model gets pushed harder when it’s more wrong, and the gradient direction is “predict closer to the target.”
MSE comes from a probabilistic interpretation too: up to a positive scale factor and an additive constant, it’s the negative log-likelihood under the assumption that the target is Gaussian-distributed around the model’s prediction with a fixed variance. This is the only loss in this chapter that has a clean Gaussian interpretation; cross-entropy is the analogous quantity under a categorical distribution. Other distributions give you other losses (Poisson, Bernoulli, von Mises, etc.) — the framework here is called “negative log-likelihood under a chosen output distribution,” and once you see it that way, every loss feels like a special case.
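A quick numeric check of the Gaussian connection, assuming a fixed σ (σ = 2 here, an arbitrary choice) and made-up targets and predictions: the mean Gaussian NLL equals MSE / (2σ²) plus a constant that does not depend on the predictions, so the two losses share the same minimizer.

```python
import math

def mse(y, y_hat):
    return sum((a - b) ** 2 for a, b in zip(y, y_hat)) / len(y)

def gaussian_nll(y, y_hat, sigma=1.0):
    """Mean negative log-likelihood of targets y under N(y_hat, sigma^2)."""
    return sum(0.5 * math.log(2 * math.pi * sigma ** 2)
               + (a - b) ** 2 / (2 * sigma ** 2)
               for a, b in zip(y, y_hat)) / len(y)

y     = [1.0, 2.0, 3.0]
y_hat = [1.5, 1.0, 3.2]
sigma = 2.0
const = 0.5 * math.log(2 * math.pi * sigma ** 2)   # independent of y_hat
print(gaussian_nll(y, y_hat, sigma), mse(y, y_hat) / (2 * sigma ** 2) + const)
```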
Variants to know:
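- L1 loss (mean absolute error): |y - y_hat|. More robust to outliers because it doesn’t square them. It is not differentiable at zero, and its gradient has the same magnitude no matter how close you are to the target, which can hurt convergence near the optimum.
- Huber loss (closely related to smooth L1): quadratic like MSE near zero, linear like L1 far from zero. You get MSE’s well-behaved gradient near the target and L1’s robustness to outliers. Used in object detection bounding-box regression.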
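Working with the residual r = y_hat - y, a small sketch comparing per-example gradients shows the trade-off. It assumes δ = 1 for Huber and uses the common ½r² form of the quadratic branch, so near zero Huber’s gradient is r rather than MSE’s 2r. The residuals are made up, with one deliberate outlier.

```python
import math

def huber(r, delta=1.0):
    """Huber loss of a residual r: 0.5 * r^2 near zero, linear beyond delta."""
    a = abs(r)
    return 0.5 * a * a if a <= delta else delta * (a - 0.5 * delta)

def mse_grad(r):
    return 2.0 * r                       # grows without bound with the residual

def l1_grad(r):
    # Same magnitude everywhere; undefined at exactly 0, where we return 0.
    return 0.0 if r == 0 else math.copysign(1.0, r)

def huber_grad(r, delta=1.0):
    # Equals r near zero (MSE-like), capped at +/- delta far away (L1-like).
    return max(-delta, min(delta, r))

residuals = [0.1, -0.5, 0.3, 20.0]      # last entry is an outlier
for name, g in (("MSE", mse_grad), ("L1", l1_grad), ("Huber", huber_grad)):
    print(f"{name:5s}", [g(r) for r in residuals])
print("Huber losses:", [huber(r) for r in residuals])
```

Under MSE the outlier’s gradient (40) dwarfs everyone else’s; under Huber it is capped at 1, while the small residuals keep gradients that shrink as they approach zero.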
For classification, you almost always want cross-entropy. For regression, you almost always start with MSE and only reach for the variants if you have outliers.
4.4 Contrastive losses — a preview
A third family of losses, central to embedding models and self-supervised learning: contrastive losses. The setup is “given an anchor, pull positive examples close in embedding space and push negatives far.” The canonical version is InfoNCE:
L = -log [ exp(sim(anchor, positive) / τ) / Σ_j exp(sim(anchor, j) / τ) ]
Look familiar? It’s softmax cross-entropy in disguise, where the “classes” are the candidate examples and the “logits” are the similarities scaled by a temperature τ. The model is being asked to assign the highest probability to the positive among a slate of candidates.
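A minimal sketch of InfoNCE written literally as softmax cross-entropy over candidates. It assumes cosine similarity for sim(·,·) and τ = 0.1; both are illustrative choices, and the 2-D embeddings are made up.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def info_nce(anchor, candidates, positive_index, tau=0.1):
    """Logits are similarities / tau; the 'correct class' is the positive."""
    logits = [cosine(anchor, c) / tau for c in candidates]
    z_max = max(logits)
    log_sum_exp = z_max + math.log(sum(math.exp(z - z_max) for z in logits))
    return log_sum_exp - logits[positive_index]   # = -log softmax(logits)[pos]

anchor = [1.0, 0.0]
candidates = [[0.9, 0.1],    # points roughly the same way as the anchor
              [0.0, 1.0],    # orthogonal
              [-1.0, 0.0]]   # opposite
good = info_nce(anchor, candidates, positive_index=0)
bad = info_nce(anchor, candidates, positive_index=2)
print(good, bad)
```

When the true positive is the most similar candidate the loss is near zero; labeling the opposite vector as the positive gives a large loss, exactly as a misclassified example would under ordinary cross-entropy.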
We’ll cover contrastive losses in full in Chapter 9 (embeddings and rerankers). For now, just lodge this in your head: most losses in modern ML are some flavor of cross-entropy. The differences are in how the “logits” are produced and how the “classes” are chosen.
4.5 The problem with vanilla SGD
The simplest optimizer is stochastic gradient descent:
θ ← θ - η · ∇_θ L
That’s it. Compute the gradient on a mini-batch, scale by the learning rate, subtract. It’s the optimizer everyone learns first. It is also bad enough at scale that nobody uses it anymore for big neural networks.
What’s wrong with it?
(1) It treats every parameter the same. A single global learning rate η is applied to every parameter, regardless of how steep or flat the loss is in that direction. Some directions are very narrow (high curvature), some are very wide (low curvature). The same step is too big in one and too small in the other.
(2) It oscillates in narrow valleys. If the loss surface is shaped like a long thin valley (which it usually is, especially in deep networks), SGD bounces back and forth across the steep sides while making slow progress along the flat axis. You can see this on toy problems and it’s not subtle.
(3) It is noisy. The gradient is computed on a random mini-batch, so it’s a noisy estimate of the true gradient over the whole dataset. SGD acts on the noisy estimate directly, which makes the trajectory jagged.
The fixes come in two flavors: momentum (smooth out the noise) and adaptive per-parameter learning rates (give each parameter its own effective η). Modern optimizers do both.
4.6 Momentum
Momentum is a tiny modification to SGD:
v ← β · v + ∇_θ L
θ ← θ - η · v
You maintain a velocity vector v that accumulates past gradients with exponentially decaying weights (decay β, typically 0.9); up to a constant factor of 1/(1 - β), it is an exponential moving average of past gradients. Instead of stepping in the current gradient’s direction, you step in the velocity’s direction.
The intuition: a heavy ball rolling down the loss surface has inertia. It doesn’t get knocked sideways every time the gradient wobbles. It builds up speed in directions where the gradient is consistently pointing. The valley-oscillation problem disappears: the across-valley component of the gradient cancels out across steps (because it flips sign), while the along-valley component reinforces.
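A small sketch on a two-dimensional "long thin valley" makes the difference concrete. The valley is our choice: L(x, y) = ½(x² + 100y²), flat along x and steep along y. Both runs use the same learning rate and the heavy-ball rule above; only β changes.

```python
def grad(x, y, a=1.0, b=100.0):
    # Gradient of L(x, y) = 0.5 * (a*x^2 + b*y^2): flat along x, steep along y.
    return a * x, b * y

def loss(x, y, a=1.0, b=100.0):
    return 0.5 * (a * x * x + b * y * y)

def run(lr, beta, steps=100):
    """beta = 0 is plain SGD; beta > 0 is heavy-ball momentum:
    v <- beta * v + grad;  theta <- theta - lr * v."""
    x, y = 1.0, 1.0
    vx, vy = 0.0, 0.0
    for _ in range(steps):
        gx, gy = grad(x, y)
        vx, vy = beta * vx + gx, beta * vy + gy
        x, y = x - lr * vx, y - lr * vy
    return loss(x, y)

lr = 0.015   # plain SGD diverges along the steep axis once lr > 2 / b = 0.02
print("SGD      final loss:", run(lr, beta=0.0))
print("momentum final loss:", run(lr, beta=0.9))
```

The learning rate is capped by the steep y direction, so plain SGD crawls along x; momentum builds up speed along x and ends at a loss roughly an order of magnitude lower after the same 100 steps.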
Cost: you have to store one extra tensor (v) the same size as each parameter. Plain SGD keeps no optimizer state at all, so this is the first optimizer-state memory you pay for. We’ll come back to this in §4.11.
Variants you’ll see:
- Heavy ball momentum — the formulation above, due to Polyak.
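- Nesterov momentum — a slightly more clever version that “looks ahead” by evaluating the gradient at θ - η · β · v (roughly where the momentum is about to carry you) rather than at θ. Theoretically better convergence, marginally better in practice.
Adam builds the heavy-ball idea in as its first moment; NAdam is the Adam variant that uses Nesterov momentum instead.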
4.7 Adam — the per-parameter learning rate adaptation
Adam (Kingma & Ba, 2014) was the optimizer that made training deep networks feel routine. It combines momentum with per-parameter adaptive learning rates based on a running average of past squared gradients (the uncentered second moment, often loosely called the variance).
m ← β_1 · m + (1 - β_1) · g # 1st moment (momentum)
v ← β_2 · v + (1 - β_2) · g^2 # 2nd moment (running mean of squared gradients)
m_hat ← m / (1 - β_1^t) # bias correction
v_hat ← v / (1 - β_2^t) # bias correction
θ ← θ - η · m_hat / (sqrt(v_hat) + ε)
Read that update rule. The numerator m_hat is just momentum. The denominator sqrt(v_hat) + ε is the per-parameter scale: parameters whose gradients have been historically large get divided by a large number (small effective learning rate); parameters whose gradients have been historically small get divided by a small number (large effective learning rate). The scaling is automatic and per-parameter.
Default hyperparameters: β_1 = 0.9, β_2 = 0.999, ε = 1e-8. These work well across most tasks and rarely need tuning; the main exception is large-scale LLM pretraining, where β_2 = 0.95 is a common choice for stability.
Cost: you have to store two extra tensors per parameter (m and v), so Adam’s optimizer state is 2× the parameter count. For a 70B model in fp32 optimizer state, that’s an extra 560 GB of GPU memory just for the optimizer. This is a real number you have to plan around.
The bias correction
The (1 - β^t) divisions look weird. Here’s why: m and v are initialized to zero, so for the first few steps they’re biased toward zero (the running average still carries the weight of those starting zeros). The bias correction divides by (1 - β^t) to undo this: at t=1, dividing by (1 - β) gives you back the raw gradient (for m) and the raw squared gradient (for v), not a tiny fraction of them. As t grows, β^t goes to zero and the correction fades out, but at different speeds: with β_1 = 0.9 it is negligible within a few dozen steps, while with β_2 = 0.999 it takes a few thousand (0.999^1000 ≈ 0.37).
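A scalar sketch of the Adam update rule above. The class name ScalarAdam is ours; real implementations such as torch.optim.Adam operate on whole tensors, and this is only meant to mirror the equations element by element.

```python
import math

class ScalarAdam:
    """Adam for a single scalar parameter, following the update rule above."""

    def __init__(self, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = 0.0   # first moment, starts at zero
        self.v = 0.0   # second moment, starts at zero
        self.t = 0     # step counter for bias correction

    def step(self, theta, g):
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * g
        self.v = self.beta2 * self.v + (1 - self.beta2) * g * g
        m_hat = self.m / (1 - self.beta1 ** self.t)   # undo zero-init bias
        v_hat = self.v / (1 - self.beta2 ** self.t)
        return theta - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)

# Bias correction makes the very first step have size ~lr whatever the
# gradient scale: m_hat = g and v_hat = g^2, so the update is ~lr * sign(g).
for g in (1e-3, 1.0, 1e3):
    first_update = 0.0 - ScalarAdam(lr=1e-3).step(0.0, g)
    print(f"g = {g:g}: first update = {first_update:.6f}")

# How long the beta^t terms take to fade: fast for beta1, slow for beta2.
for t in (10, 100, 1000, 5000):
    print(t, 0.9 ** t, 0.999 ** t)
```

The first loop also shows the per-parameter scaling at work: gradients spanning six orders of magnitude all produce a first step of about η.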
When Adam is wrong
Adam is the default for a reason — it works for almost every supervised learning task. The exceptions:
- Very large-scale image classification. SGD with momentum and a good schedule can sometimes outperform Adam at the end of training on ImageNet-scale classification. This is well-documented but not fully understood.
- Embedding models with sparse gradients. Adam doesn’t handle sparsity well. Alternatives such as Adagrad, or sparse variants like PyTorch’s SparseAdam, exist for this.
- Anything where you can’t afford 2× optimizer state memory. This is where Lion comes in.
4.8 AdamW — the decoupled weight decay fix
You almost never use plain Adam. You use AdamW, which fixes a subtle but important bug.
Weight decay is a regularization technique that shrinks the weights toward zero. The classic way to implement it is L2 regularization: you add a term λ ||θ||^2 to the loss, which encourages the weights to stay small. The gradient of this term is 2 λ θ, which gets added to the gradient before the optimizer step.
In plain Adam, the weight decay term goes through the same per-parameter scaling as the rest of the gradient. This means parameters with large running variance get less weight decay (because they’re divided by sqrt(v_hat)), and parameters with small running variance get more. This is not what you want — you want all parameters to be regularized at the same rate.
AdamW (Loshchilov & Hutter, 2017) decouples weight decay from the gradient: instead of adding the decay term to the gradient, it applies it directly to the weights, outside the per-parameter scaling, so every parameter decays at the same rate.
θ ← θ - η · (m_hat / (sqrt(v_hat) + ε) + λ · θ)
That’s the only difference. It’s a one-line change. It matters more than it should, especially for transformers, where the difference between Adam and AdamW can be the difference between converging and not.
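To see the coupling numerically, here is a scalar sketch (helper names are ours) that runs the same zero-mean, alternating-sign gradient stream twice, with and without weight decay, and measures how much extra the weight shrinks. Gradient scales of 10 and 0.01 stand in for a "high-variance" and a "low-variance" parameter. For the coupled case it uses the convention where the L2 term adds wd · θ to the gradient, which is how PyTorch’s Adam applies its weight_decay argument; the decoupled case follows the AdamW formula above.

```python
import math

def adam_step(theta, grad, m, v, t, lr, wd, decoupled,
              beta1=0.9, beta2=0.999, eps=1e-8):
    """One scalar step. decoupled=False: Adam with an L2 term folded into the
    gradient. decoupled=True: AdamW-style decay applied outside the scaling."""
    if not decoupled:
        grad = grad + wd * theta            # L2 term goes through m and v
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    step = m_hat / (math.sqrt(v_hat) + eps)
    if decoupled:
        step = step + wd * theta            # decay bypasses sqrt(v_hat)
    return theta - lr * step, m, v

def decay_effect(grad_scale, decoupled, steps=100, lr=1e-3, wd=1e-2):
    """Extra shrinkage from weight decay: final theta without decay minus
    final theta with decay, on the same zero-mean gradient stream
    (+grad_scale, -grad_scale, +grad_scale, ...)."""
    finals = []
    for wd_value in (0.0, wd):
        theta, m, v = 1.0, 0.0, 0.0
        for t in range(1, steps + 1):
            g = grad_scale if t % 2 else -grad_scale
            theta, m, v = adam_step(theta, g, m, v, t, lr, wd_value, decoupled)
        finals.append(theta)
    return finals[0] - finals[1]

for decoupled in (False, True):
    label = "AdamW (decoupled)" if decoupled else "Adam + L2 (coupled)"
    print(f"{label:20s} high-variance param: {decay_effect(10.0, decoupled):.2e}"
          f"   low-variance param: {decay_effect(0.01, decoupled):.2e}")
```

With coupled decay, the low-variance parameter shrinks far more than the high-variance one; with decoupled decay, the two numbers match.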
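Modern code says “Adam” and means AdamW. PyTorch’s torch.optim.AdamW is what you almost always want. Plain torch.optim.Adam produces the same updates when weight decay is zero; with weight decay, it implements the coupled L2 version described above, which is why it survives mostly for legacy reproducibility.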
4.9 Lion — half the memory, similar quality
Lion (Chen et al., 2023) is the most interesting recent challenger to Adam. The update rule:
update ← sign(β_1 · m + (1 - β_1) · g) # sign function! (interpolates with β_1)
θ ← θ - η · (update + λ · θ)
m ← β_2 · m + (1 - β_2) · g # the stored momentum is tracked with β_2
The trick is the sign(). The update direction is the sign of a momentum-smoothed gradient, with magnitude controlled entirely by the learning rate. There’s no per-parameter variance to track, so there’s only one extra tensor per parameter (m) — half the optimizer-state memory of AdamW.
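A pure-Python sketch of one Lion step, following the update rule from the Lion paper: interpolate the stored momentum and the gradient with β_1, take the sign, apply decoupled weight decay, then update the momentum with β_2. The function name is ours; β_1 = 0.9 and β_2 = 0.99 are the values commonly used for Lion, and the gradients are made up to span very different scales.

```python
def lion_step(params, grads, momentum, lr=1e-4, beta1=0.9, beta2=0.99, wd=0.0):
    """One Lion step over lists of floats. Returns (new_params, new_momentum)."""
    def sign(x):
        return (x > 0) - (x < 0)

    new_params, new_momentum = [], []
    for theta, g, m in zip(params, grads, momentum):
        # Interpolate with beta1 and keep only the sign.
        update = sign(beta1 * m + (1 - beta1) * g)
        new_params.append(theta - lr * (update + wd * theta))
        # The stored momentum uses beta2: one state value per parameter.
        new_momentum.append(beta2 * m + (1 - beta2) * g)
    return new_params, new_momentum

params, momentum = [0.5, -0.2, 1.0], [0.0, 0.0, 0.0]
grads = [1e-6, -3.0, 250.0]   # wildly different gradient scales...
new_params, momentum = lion_step(params, grads, momentum, lr=1e-4)
print([p - q for p, q in zip(params, new_params)])   # ...same step size, lr
```

Every coordinate moves by exactly ±η (with wd = 0), which is why Lion’s learning rate has to be chosen differently from AdamW’s.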
The Lion paper reports that it matches AdamW on most of the workloads it tested at large scale, and sometimes wins. The community is still cautious about adopting it because Adam’s gravity is enormous, but it has a real chance of becoming the new default. For now, the practical answer is: AdamW for almost everything, Lion if you’re memory-constrained and willing to retune (because every coordinate moves by a full ±η, Lion is typically run with a learning rate several times smaller than AdamW’s).
4.10 Other optimizers worth knowing exist
A few names you should be able to recognize and place:
- Adafactor — Adam with factored second moments to save memory. Used to train T5. Lower memory, somewhat trickier to tune.
- Shampoo — second-order method, expensive but powerful. Used in some Google internal training runs.
- Sophia — diagonal approximation to a second-order method. Recent, promising for LLMs.
- AdamW8bit — quantized optimizer state from bitsandbytes. Same algorithm as AdamW but the moments are stored in 8-bit, which cuts optimizer-state memory by about 4× relative to fp32. Standard for fine-tuning on consumer GPUs.
The pattern with optimizer research: most papers claim improvements that don’t survive replication. AdamW is the floor everyone should compare against. If a new optimizer doesn’t beat AdamW on your task in your hands, it doesn’t beat AdamW.
4.11 The optimizer-state memory math (the part interviews care about)
Here is the table to memorize:
| Optimizer | State per parameter | Total state for a 70B model |
|---|---|---|
| Plain SGD | 0 | 0 GB |
| SGD + momentum | 1 tensor (v) | 280 GB (fp32) / 140 GB (bf16) |
| AdamW | 2 tensors (m, v) | 560 GB (fp32) / 280 GB (bf16) |
| Lion | 1 tensor (m) | 280 GB (fp32) / 140 GB (bf16) |
| Adafactor | ~factored, much less | ~70 GB |
This is one of the most quoted facts in distributed training interviews. Here’s the back-of-envelope math for AdamW on a 70B model in fp32 optimizer state:
70B params × 2 (m, v) × 4 bytes (fp32) = 560 GB
That’s the optimizer state alone, not the model weights, not the activations, not the gradients. And that’s why training a 70B model on a single GPU is impossible (an H100 has 80 GB of HBM), and why the entire field of distributed training (Chapter 12) exists. The memory budget for training looks like:
- Weights: 70B × 2 bytes (bf16) = 140 GB
- Gradients: 70B × 2 bytes (bf16) = 140 GB
- Optimizer state (AdamW, fp32): 560 GB
- Activations: depends heavily on batch size, sequence length, and gradient checkpointing — easily another few hundred GB
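Total: 840 GB for weights, gradients, and optimizer state before a single activation is stored, and roughly 1 TB or more for a 70B model once activations are added. (Standard mixed-precision recipes usually also keep an fp32 master copy of the weights, another 280 GB.) ZeRO and FSDP are the techniques that shard this across many GPUs so each one only holds a slice. We’ll see how in Chapter 12.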
For inference, none of the training state applies: you need the weights (140 GB in bf16, 35–40 GB in INT4) plus the KV cache and some activation memory, but no optimizer or gradient state. This is why a 70B model is much easier to serve than to train. You’ll get this question in interviews: explain why a model that needs 8 GPUs to train can sometimes serve on 2.
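The arithmetic is simple enough to script. Here is a back-of-envelope calculator (the helper names are hypothetical; it uses decimal gigabytes and leaves Adafactor out because its factored state depends on the shapes of the weight matrices) that reproduces the optimizer-state table, the weights + gradients + AdamW state budget before activations, and the inference weight sizes.

```python
GB = 1e9  # decimal gigabytes, matching the numbers in the text

# Optimizer-state tensors per parameter.
STATE_TENSORS = {"SGD": 0, "SGD + momentum": 1, "AdamW": 2, "Lion": 1}

def optimizer_state_gb(n_params, optimizer, bytes_per_value=4):
    return n_params * STATE_TENSORS[optimizer] * bytes_per_value / GB

def training_memory_gb(n_params, optimizer="AdamW"):
    """Weights and gradients in bf16 (2 bytes), optimizer state in fp32
    (4 bytes). Activations are excluded: they depend on batch size,
    sequence length, and checkpointing."""
    weights = n_params * 2 / GB
    grads = n_params * 2 / GB
    return weights + grads + optimizer_state_gb(n_params, optimizer, 4)

def inference_weights_gb(n_params, bits_per_weight):
    # Weights only; the KV cache and runtime buffers come on top of this.
    return n_params * bits_per_weight / 8 / GB

n = 70e9
for name in STATE_TENSORS:
    print(f"{name:15s} fp32: {optimizer_state_gb(n, name):5.0f} GB   "
          f"bf16: {optimizer_state_gb(n, name, 2):5.0f} GB")
print("training (no activations):", training_memory_gb(n), "GB")
print("inference, bf16 weights:", inference_weights_gb(n, 16), "GB")
print("inference, INT4 weights:", inference_weights_gb(n, 4), "GB")
```

Swapping n for another parameter count gives the same breakdown for any model size.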
4.12 Learning rate schedules
The learning rate η is the most important hyperparameter in training. Too high and the loss diverges; too low and training takes forever. Worse, the best learning rate changes over the course of training: high at the start (when you want to move fast across the rough loss surface) and low at the end (when you want to fine-tune into a minimum).
The two essential schedules:
Warmup
For the first few thousand steps, start at zero learning rate and linearly ramp up to your target. This is warmup, and it is critical for transformers.
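Why? Because at the very start of training, the model’s predictions are random, the gradients can be large and erratic, and the Adam moments (m, v) have seen only a handful of gradients. Bias correction removes their zero-initialization bias but not their noise: at step 1, m_hat / sqrt(v_hat) is exactly ±1 for every parameter (ignoring ε), so every coordinate takes a full step of size η, even where the gradient is mostly noise. Without warmup, these first full-size updates can point in directions that turn out to be wrong and knock the model into a bad region of parameter space that it never recovers from. Warmup gives the moments time to stabilize before letting the learning rate actually do anything.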
Typical warmup is 1–10% of total training steps. For LLMs, “first 2000 steps” or “first 5% of total” are common values.
Cosine decay
After warmup, you decay the learning rate. The standard schedule is cosine decay: the learning rate follows a half-cosine from the warmup end value down to (typically) 10% or 0% of the peak.
η(t) = η_min + 0.5 × (η_max - η_min) × (1 + cos(π × (t - t_warmup) / (T - t_warmup)))
At t = t_warmup, the cosine is 1 and you get η_max. At t = T, the cosine is -1 and you get η_min. In between, the decay is smooth and faster in the middle than at the edges.
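Here is a sketch of linear warmup followed by the cosine formula above. The function name lr_at is ours; clamping steps past T to η_min is an assumption about what you want after the schedule ends. The example values (peak 1e-4, 2000 warmup steps, 100000 total, decay to 10% of the peak) are illustrative.

```python
import math

def lr_at(step, warmup_steps, total_steps, peak_lr, min_lr):
    """Linear warmup from 0 to peak_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    # Fraction of the decay phase completed; clamped so later steps stay at min_lr.
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

peak, floor = 1e-4, 1e-5   # decay to 10% of the peak
for step in (0, 1000, 2000, 51000, 100000):
    print(step, lr_at(step, 2000, 100000, peak, floor))
```

In PyTorch, one common way to use a function like this is to wrap it in torch.optim.lr_scheduler.LambdaLR, which multiplies the optimizer’s base learning rate by the factor your lambda returns (so you would return lr_at(...) / peak_lr).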
Why cosine and not linear or exponential? Mostly empirical — it consistently outperforms the alternatives by a hair on big models, and the convention has stuck. Some papers use linear decay; the difference is small.
Other schedules to know
- Constant after warmup: for small experiments and fine-tuning, you sometimes just hold the learning rate constant after warmup.
- Step decay: drop the learning rate by 10× at fixed milestones. Old-school, used in classical computer vision (ImageNet ResNet training).
- One-cycle (Smith): triangular schedule that goes up and back down. Some practitioners swear by it; it’s not standard for LLMs.
4.13 Gradient clipping
The other safety mechanism in modern training: gradient clipping. After computing the gradient and before applying the optimizer step, you cap the L2 norm of the gradient at some maximum value (typically 1.0 for transformers).
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
The motivation: occasionally, a batch of unusually difficult examples produces a gradient that’s much bigger than typical. Without clipping, the resulting huge step can land the model in a bad region of parameter space. Clipping makes training robust to these spikes.
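Here is a pure-Python sketch of clipping by global norm, the behavior clip_grad_norm_ is documented to have: one L2 norm computed over all parameters’ gradients together, and one shared scale factor, so the gradient’s direction is preserved. The function name is ours, and the small eps in the denominator (to avoid dividing by zero) mirrors what we understand PyTorch to do; treat that detail as an assumption.

```python
import math

def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale every gradient by one shared factor so the global L2 norm
    (over all parameters together) is at most max_norm.
    Returns (clipped_grads, norm_before_clipping)."""
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in tensor] for tensor in grads]
    return clipped, total_norm

# Two "parameter tensors" whose combined gradient has norm 5 (3-4-5 triangle).
spike = [[3.0, 0.0], [4.0]]
clipped, norm_before = clip_by_global_norm(spike, max_norm=1.0)
print(norm_before, clipped)   # direction preserved, norm scaled down to ~1

# Gradients already inside the budget are left untouched.
print(clip_by_global_norm([[0.3], [0.4]], max_norm=1.0)[0])
```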
For transformers, clipping at 1.0 is the universal default. You’ll see it in every modern training script. It’s almost free and only helps.
4.14 The mental model
Eight points to take into Chapter 5:
- Training is minimization. Pick a loss, take gradient steps against it.
- Cross-entropy is the loss for classification. For LLMs, “classification over the vocabulary at each position” is exactly cross-entropy.
- Softmax + cross-entropy is always one fused op. The gradient is y - t, and the simplification gives you both speed and stability.
- MSE is the loss for regression, and the derivation comes from negative log-likelihood under a Gaussian.
- AdamW is the default optimizer. Plain Adam is a legacy spelling.
- The memory math: AdamW state is 2× the parameters in fp32. A 70B model has ~560 GB of optimizer state alone, which is why distributed training exists.
- Schedule = warmup + cosine. Warmup is non-negotiable for transformers; the cosine part is convention with mild empirical support.
- Clip gradient norm at 1.0. Almost free, only helps.
In Chapter 5 we look at the input side: what tokens actually are, why subword tokenizers won, and the bugs the tokenizer secretly causes.
Read it yourself
- The original Adam paper: Kingma & Ba, Adam: A Method for Stochastic Optimization (2014). Short and clear.
- The AdamW paper: Loshchilov & Hutter, Decoupled Weight Decay Regularization (2017). Equally short.
- The Lion paper: Chen et al., Symbolic Discovery of Optimization Algorithms (2023). The interesting bit is how they found Lion (search) more than the algorithm.
- Sebastian Ruder’s An Overview of Gradient Descent Optimization Algorithms (blog post, 2016) — still the best overview of the optimizer family.
- Deep Learning Tuning Playbook by Google Brain (open source on GitHub) — practical hyperparameter tuning advice for the AdamW + cosine schedule combo.
Practice
- Compute the cross-entropy loss by hand for logits [2.0, 1.0, 0.1] and target class 0. Verify with F.cross_entropy(torch.tensor([[2.0, 1.0, 0.1]]), torch.tensor([0])).
- A model has 7 billion parameters. Compute the optimizer-state memory for SGD, SGD+momentum, AdamW, and Lion in both fp32 and bf16. Tabulate the answer.
- Why does plain Adam (not AdamW) couple weight decay with the per-parameter scaling? Write down the math for one parameter and show that the effective decay differs between high-variance and low-variance parameters.
- Implement the cosine schedule with warmup as a function lr(step, warmup_steps, total_steps, peak_lr, min_lr). Plot it for peak_lr=1e-4, warmup=2000, total=100000.
- Why does training a transformer without warmup often diverge in the first 100 steps? Trace through what happens to the Adam v moment in step 1 vs step 1000.
- Estimate the total training memory (weights + grads + optimizer state in AdamW + ~50% activation overhead) for a 13B model in bf16 with fp32 optimizer state. Will it fit on a single H100 (80 GB)? On a single B200 (192 GB)?
- Stretch: Take a small open-source LLM (e.g., GPT-2 small) and try training it on a small dataset with (a) plain SGD, (b) Adam, (c) AdamW. Compare loss curves. The point is to see the AdamW improvement empirically and to feel how brittle SGD is on transformers.
Concept check
- 1. AdamW differs from Adam by decoupling weight decay. Concretely, what does this mean?
- 2. For a model trained with AdamW in full fp32, how many bytes of optimizer state are stored per parameter?
- 3. Cross-entropy loss for a softmax classifier is often called 'negative log-likelihood.' Why negative?
- 4. A learning rate warmup schedule starts lr near zero and ramps it up to the target over the first N steps. What is the main reason for warmup?