Part II · Training, Fine-Tuning, Alignment
Chapter 13

Mixed precision training: fp16, bf16, fp8, loss scaling

"In bf16 your gradients survive. In fp16 they die quietly. In fp8 they sometimes die loudly"

This is the chapter about precision: which numeric format you store your tensors in during training, why fp32 wastes silicon, why fp16 is dangerous, why bf16 became the default, and what fp8 buys you on H100. You will not implement mixed precision from scratch in your day job (torch.cuda.amp does it for you), but you will diagnose mixed-precision bugs: NaN losses, exploding gradients, and the mysterious “the model trained fine in fp32 but diverges in fp16.” The root cause is almost always one of three things: values that underflow to zero, values that overflow to inf, or small updates lost to rounding. This chapter covers all three.

Outline:

  1. Why fp32 wastes silicon.
  2. The IEEE 754 layout: sign, exponent, mantissa.
  3. fp16 — what dies and why.
  4. Loss scaling — the fix for fp16 underflow.
  5. bf16 — same range as fp32, less mantissa, no need for loss scaling.
  6. TF32 — the in-between for matmul on A100/H100.
  7. fp8 — e4m3 vs e5m2, calibration, what H100 gives you.
  8. Master weights and the fp32 copy you still need.
  9. torch.cuda.amp.autocast — what it actually does.
  10. Where mixed precision goes wrong.
  11. The mental model.

13.1 Why fp32 wastes silicon

A 70B parameter model in fp32 is 70 × 10⁹ × 4 bytes ≈ 280 GB. In bf16 it’s 140 GB. The memory savings alone are 2×, which is large but not the main story.

The main story is throughput. Modern NVIDIA Tensor Cores (introduced in Volta, expanded in Ampere, doubled in Hopper) are matrix-multiply units that operate on lower-precision inputs and (sometimes) accumulate to a higher-precision output. Peak dense (no sparsity) throughput for an H100 SXM5:

| Precision | Dense peak (TFLOP/s) |
|---|---|
| FP64 (Tensor Core) | 67 |
| FP32 (CUDA cores, no Tensor Core) | 67 |
| TF32 (Tensor Core) | 495 |
| FP16 / BF16 (Tensor Core) | 989 |
| FP8 (e4m3 / e5m2, Tensor Core) | 1979 |
| INT8 (Tensor Core) | 1979 (TOPS) |

The pattern: every halving of precision roughly doubles the achievable FLOPs. The reason is silicon — Tensor Cores are special-purpose units and they can pack more low-precision multiply-add ops into the same area than they can pack high-precision ones. There is no software trick that closes this gap; it’s a hardware property.

Figure: H100 SXM5 dense Tensor Core throughput by precision, in TFLOP/s: fp64 67, TF32 495, fp16/bf16 989, fp8/int8 1979. Each halving of precision roughly doubles achievable throughput.
Plain fp32 (TF32 off) runs on the CUDA cores at about 67 TFLOP/s, the same height as the fp64 bar and roughly 1/15 of bf16’s peak; switching to bf16 reclaims most of that gap, at the cost of the precision issues covered in the rest of this chapter.

Two consequences:

  1. If you train in plain fp32 (with TF32 disabled), you are running at under 7% of the GPU’s bf16 throughput: the H100’s 67 TFLOP/s of fp32 is roughly 1/15 of its 989 TFLOP/s of bf16. Training in fp32 leaves an order of magnitude on the floor.
  2. The gain from going to lower precision is not just memory; it’s compute. Halving the precision halves the bytes every tensor read or write moves through HBM (so memory-bound kernels, such as element-wise ops, run up to twice as fast) and doubles the Tensor Core FLOPs available per second. For matmul-bound workloads (which large transformers mostly are), going to a lower precision is the largest single throughput optimization available.
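Both consequences are simple arithmetic on the numbers above. A small sketch, assuming the dense H100 SXM5 peaks quoted in this section (with plain fp32 on the CUDA cores at 67 TFLOP/s) and counting weight storage only:

```python
# Dense peak throughput on H100 SXM5 (TFLOP/s), as quoted in this chapter.
PEAK_TFLOPS = {"fp32 (CUDA cores)": 67, "tf32": 495, "bf16": 989, "fp8": 1979}

# Storage size per parameter, in bytes.
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "fp8": 1}


def weight_memory_gb(n_params: float, dtype: str) -> float:
    """Memory for the weights alone, in GB (10**9 bytes)."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


def fraction_of_peak(dtype: str, reference: str = "bf16") -> float:
    """Peak throughput of `dtype` as a fraction of the reference precision's peak."""
    return PEAK_TFLOPS[dtype] / PEAK_TFLOPS[reference]


for dtype in BYTES_PER_PARAM:
    print(f"70B weights in {dtype}: {weight_memory_gb(70e9, dtype):.0f} GB")
print(f"plain fp32 reaches {fraction_of_peak('fp32 (CUDA cores)'):.1%} of bf16 peak")
```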

This is why every modern training pipeline uses mixed precision. The question is not “should I” but “which precision and how do I avoid the precision-related bugs.”

13.2 IEEE 754 — sign, exponent, mantissa

A normal floating-point number (not a subnormal, zero, inf, or NaN) is (-1)^s × 2^e × (1 + m/2^p), where p is the number of mantissa bits and:

  • s is the sign bit (1 bit).
  • e is the exponent, stored with a fixed bias that is subtracted to make it signed. The number of exponent bits determines the range: how big and how small numbers can get.
  • m is the mantissa (or “significand”), stored as a p-bit integer. The number of mantissa bits determines the precision: how many distinct values fit between consecutive powers of two.

The standard formats:

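| Format | Total bits | Sign | Exponent | Mantissa | Range (max) | Smallest normal |
|---|---|---|---|---|---|---|
| FP64 | 64 | 1 | 11 | 52 | ~1.8 × 10³⁰⁸ | ~2.2 × 10⁻³⁰⁸ |
| FP32 | 32 | 1 | 8 | 23 | ~3.4 × 10³⁸ | ~1.2 × 10⁻³⁸ |
| FP16 | 16 | 1 | 5 | 10 | 65 504 | ~6.1 × 10⁻⁵ |
| bf16 | 16 | 1 | 8 | 7 | ~3.4 × 10³⁸ | ~1.2 × 10⁻³⁸ |
| FP8 (e4m3) | 8 | 1 | 4 | 3 | 448 | ~1.56 × 10⁻² |
| FP8 (e5m2) | 8 | 1 | 5 | 2 | 57 344 | ~6.1 × 10⁻⁵ |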

Read this table carefully. The key relationships:

  • fp32 and bf16 have the same exponent (8 bits). They have the same range. The difference is purely precision.
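  • fp16 and fp32 have very different ranges. fp16 maxes out at 65 504; below ~6 × 10⁻⁵ it falls into subnormals that lose precision, and values under ~3 × 10⁻⁸ flush to zero. The fp32-vs-fp16 range gap is 10³⁸ vs 10⁵ at the top end alone: more than 30 orders of magnitude.
  • fp8 e4m3 has a tiny range. It only goes up to 448 (rather than 240, because the fp8 variant used on H100 gives up infinities and reuses those bit patterns for finite values). Its smallest normal number is ~1.6 × 10⁻² and its smallest subnormal ~2 × 10⁻³. This is a very tight range for storing activations and gradients of a deep network without careful scaling.
  • fp8 e5m2 has roughly the same range as fp16 (the same 5-bit exponent; max 57 344 vs 65 504). Compared with e4m3, it trades 1 mantissa bit for 1 exponent bit, giving up precision in exchange for not underflowing.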
Figure: Bit layouts of fp32 (8 exponent / 23 mantissa bits), bf16 (8 / 7), fp16 (5 / 10), fp8 e4m3 (4 / 3), and fp8 e5m2 (5 / 2), each with one sign bit.
fp32 and bf16 share an 8-bit exponent (same range); fp16 cuts the exponent to 5 bits (smaller range, underflow risk); fp8 has only 8 total bits, making per-tensor scaling mandatory.
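The range columns of the table follow directly from the bit counts. The sketch below recomputes them, assuming IEEE-style conventions (the all-ones exponent reserved for inf/NaN) for every format except fp8 e4m3, whose H100 variant keeps finite values in the top exponent code and reserves only one mantissa pattern there for NaN. That single difference is why e4m3 tops out at 448 instead of 240. The function name and the `ieee_specials` flag are illustrative.

```python
def format_limits(exp_bits: int, mant_bits: int, ieee_specials: bool = True):
    """Return (max finite value, smallest normal, smallest subnormal) for a binary float.

    ieee_specials=True: the all-ones exponent is reserved for inf/NaN (fp32, bf16, fp16, e5m2).
    ieee_specials=False: e4m3-style; the all-ones exponent holds finite values and only its
    all-ones mantissa is NaN, so the largest mantissa in that exponent is unavailable.
    """
    bias = 2 ** (exp_bits - 1) - 1
    min_normal = 2.0 ** (1 - bias)
    min_subnormal = 2.0 ** (1 - bias - mant_bits)
    if ieee_specials:
        max_exp = (2 ** exp_bits - 2) - bias
        max_finite = (2 - 2.0 ** -mant_bits) * 2.0 ** max_exp
    else:
        max_exp = (2 ** exp_bits - 1) - bias
        max_finite = (2 - 2.0 ** (1 - mant_bits)) * 2.0 ** max_exp
    return max_finite, min_normal, min_subnormal


FORMATS = {  # name: (exponent bits, mantissa bits, IEEE-style specials?)
    "fp32": (8, 23, True),
    "bf16": (8, 7, True),
    "fp16": (5, 10, True),
    "fp8 e4m3": (4, 3, False),
    "fp8 e5m2": (5, 2, True),
}

for name, (e, m, ieee) in FORMATS.items():
    mx, mn, sub = format_limits(e, m, ieee)
    print(f"{name:9s} max={mx:.4g} min_normal={mn:.3g} min_subnormal={sub:.3g}")
```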

The format you pick determines what your gradients can and cannot represent. A 1e-8 gradient stored in fp16 rounds to zero, so that parameter never updates; even 1e-7 survives only as a subnormal (it rounds to ~1.19 × 10⁻⁷, a 19% error). In bf16 or fp32, both values are ordinary normal numbers.

13.3 fp16 — what dies and why

When you train a deep network, the gradients are typically much smaller than the activations. The reason is the chain rule: the gradient at layer L is the product of L Jacobian factors going back to the loss, and most of those factors have norms less than 1. The gradient at layer 1 of a 32-layer transformer can easily be 1e-6 or smaller in magnitude.

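1e-6 is below fp16’s smallest normal number (~6e-5), so it is stored as a subnormal with only a few bits of precision, and anything below ~3e-8 (half the smallest subnormal, 2⁻²⁴ ≈ 6e-8) rounds to exactly zero. The smallest gradient components fall off that cliff, the corresponding parameters receive no update, and training stalls.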
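numpy’s float16 is the same IEEE half-precision format, so both the underflow and the loss-scaling fix can be reproduced on a CPU. The gradient values and the scale 2¹⁶ are illustrative; the unscale step runs in fp32, as it would on the master copy.

```python
import numpy as np

SCALE = np.float32(2.0 ** 16)  # illustrative loss scale


def to_fp16(x: float) -> np.float16:
    """Round a value to IEEE half precision (what storing a gradient in fp16 does)."""
    return np.float16(x)


# Values below half the smallest fp16 subnormal (2**-25, about 3e-8) flush to zero.
print(to_fp16(1e-8))  # 0.0
# 1e-7 survives only as a subnormal: it rounds to 2 * 2**-24, about 1.19e-7.
print(float(to_fp16(1e-7)))

# Loss scaling: scale in fp32, store in fp16, unscale in fp32.
tiny_grad = np.float32(1e-8)
scaled = np.float16(tiny_grad * SCALE)   # about 6.55e-4, a normal fp16 number
recovered = np.float32(scaled) / SCALE   # back to about 1e-8
print(float(recovered))
```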

This is the fp16 underflow problem, and it is the main reason “I trained the same model in fp32 and it worked, but in fp16 the loss flatlines after step 100” is a common bug. The fix is loss scaling.

13.4 Loss scaling

The trick: multiply the loss by a large constant before calling loss.backward(). By the chain rule, this multiplies every gradient in the network by the same constant, pushing the small gradients above the underflow threshold and making them representable in fp16.

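scale = 65536.0                      # 2**16, a fixed (static) loss scale
loss_scaled = loss * scale
loss_scaled.backward()               # gradients are scale × the true gradients
for p in model.parameters():
    if p.grad is not None:           # frozen or unused parameters have no gradient
        p.grad.div_(scale)           # unscale before applying the optimizer
optimizer.step()
optimizer.zero_grad()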

The math is identical — you scale up by S and scale back down by S — but the numerical representation is different. A 1e-7 true gradient becomes a 6.5e-3 representable fp16 number, then gets unscaled back to 1e-7 (in fp32, where the optimizer step happens — see §13.8). Underflow avoided.

The catch is that scaling up too much causes overflow: if a gradient was already 1.0 and you multiply by 65536, you get 65 536, which exceeds fp16’s max (65 504) and becomes inf. A single inf (or the NaN it produces downstream) corrupts the optimizer step.

The fix is dynamic loss scaling: start with a large scale and watch for inf/NaN gradients. When overflow happens, skip that optimizer step and halve the scale; after a window of clean steps, double it again, to keep finding the largest scale that works.

import torch

scaler = torch.cuda.amp.GradScaler()   # dynamic loss scaling, initial scale 2**16

for batch in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast(dtype=torch.float16):
        loss = model(batch).loss
    scaler.scale(loss).backward()      # backward on the scaled loss
    scaler.unscale_(optimizer)         # divide grads by the scale in place, before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)             # skips optimizer.step() if any grad is inf/NaN
    scaler.update()                    # halve scale after overflow; grow it after a run of clean steps

GradScaler handles the dynamic scaling, the inf/NaN detection, and the skip-if-overflow logic. You almost never write the loss-scaling code by hand — you use GradScaler.
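The policy GradScaler implements can be sketched in a small class. This is a hypothetical minimal version, not GradScaler’s code; its defaults mirror GradScaler’s documented defaults (initial scale 2¹⁶, growth factor 2, backoff factor 0.5, growth interval 2000 steps), and numpy’s float16 stands in for fp16 gradient storage.

```python
import numpy as np


class DynamicLossScaler:
    """Hypothetical minimal dynamic loss scaler (illustration only, not torch's GradScaler)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def update(self, found_inf: bool) -> bool:
        """Adjust the scale after a backward pass; return True if the step should be applied."""
        if found_inf:
            self.scale *= self.backoff_factor   # overflow: shrink the scale and skip this step
            self._clean_steps = 0
            return False
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            self.scale *= self.growth_factor    # long clean run: probe a larger scale
            self._clean_steps = 0
        return True


def scaled_fp16_grads(true_grads: np.ndarray, scale: float) -> np.ndarray:
    """Simulate backward on a scaled loss with the gradients stored in fp16."""
    with np.errstate(over="ignore"):
        return (true_grads.astype(np.float32) * np.float32(scale)).astype(np.float16)


scaler = DynamicLossScaler()
true_grads = np.array([1.0, 1e-8], dtype=np.float32)
applied = []
for step in range(2):
    scale_used = scaler.scale
    g16 = scaled_fp16_grads(true_grads, scale_used)
    if scaler.update(found_inf=not np.isfinite(g16).all()):
        applied.append(g16.astype(np.float32) / np.float32(scale_used))  # unscale in fp32

print(scaler.scale, applied)
```

On the first step, 1.0 × 2¹⁶ overflows fp16, so the step is skipped and the scale drops to 2¹⁵; the second step fits, and both the large and the tiny gradient come back after unscaling.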

This entire mechanism exists because of fp16’s range limitations. In bf16, you don’t need any of it.

13.5 bf16 — the modern default

Brain Float 16 (bf16) was designed by Google for TPUs and adopted everywhere once Ampere added hardware support. It has the same exponent width as fp32 (8 bits) but only 7 mantissa bits. Compared to fp16:

  • Same range as fp32 (~10³⁸). No underflow problem at training scale.
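  • Less precision than fp16: 8 significant bits (7 stored plus 1 implicit), about 2.4 decimal digits, versus 11 bits (~3.3 digits) for fp16. Machine epsilon is 2⁻⁷ ≈ 0.0078 versus 2⁻¹⁰ ≈ 0.001. The relative error is larger, but for training neural networks this rarely matters.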
  • Same memory and Tensor Core throughput as fp16. No performance penalty for the larger range.
  • No loss scaling required. This is the killer feature.
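The range-versus-precision trade is easy to see numerically. numpy has no bfloat16 type, but bf16 is simply the top 16 bits of an fp32 number, so it can be emulated by rounding fp32 bit patterns (round-to-nearest-even, no special NaN handling). The helper below is illustrative, and numpy’s float16 stands in for fp16.

```python
import numpy as np


def round_to_bf16(x) -> np.ndarray:
    """Emulate storing fp32 values in bf16: keep the top 16 bits, rounding to nearest even.

    Illustrative helper; NaN payloads are not handled specially.
    """
    bits = np.asarray(x, dtype=np.float32).reshape(-1).view(np.uint32)
    lsb = (bits >> 16) & np.uint32(1)
    rounded = (bits + np.uint32(0x7FFF) + lsb) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)


tiny, close_to_one = 1e-30, 1.001

# Range: bf16 keeps fp32's 8 exponent bits, fp16 does not.
print(round_to_bf16(tiny)[0], np.float16(tiny))                  # about 1e-30 vs 0.0
# Precision: bf16 has 7 mantissa bits (eps 2**-7), fp16 has 10 (eps 2**-10).
print(round_to_bf16(close_to_one)[0], np.float16(close_to_one))  # 1.0 vs about 1.001
```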

The “no loss scaling” is the reason bf16 won. Loss scaling adds complexity (the GradScaler), occasional bugs (incorrectly handled inf gradients), and the constant risk that someone will turn it off and not notice. With bf16, you just use the format directly. The training script becomes simpler and the failure modes go away.

import torch

with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    loss = model(batch).loss
loss.backward()       # no GradScaler: bf16 has fp32's range; grads land in the fp32 params' .grad
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()

This is the modern training loop. As of late 2025, bf16 is the default training precision for almost every open LLM. fp16 survives in some legacy code and on hardware without bf16 support (V100 and earlier), but if you’re starting a new training run, you should use bf16.

The one place bf16 falls short is when the relative precision really matters — for example, when accumulating large numbers of gradient updates over a long horizon, or in some specialized linear algebra. For these, the standard answer is “compute in bf16, accumulate in fp32.” We’ll see this in the master-weights section (§13.8).

13.6 TF32 — the in-between for matmul

NVIDIA introduced TensorFloat-32 (TF32) with the Ampere architecture (A100). It’s a 19-bit format used internally by Tensor Cores: 1 sign bit, an 8-bit exponent (same as fp32), and a 10-bit mantissa (same as fp16). It’s a compute-only format: you don’t store tensors in TF32. When a matmul’s inputs are fp32, the Tensor Cores round them to TF32, multiply, and accumulate the products in fp32.

The motivation: even when your model is in fp32, matmuls can run on the Tensor Cores by silently using TF32 for the multiplies. On A100 this raises peak dense matmul throughput from 19.5 to 156 TFLOP/s, about 8×. The output is still fp32 and the inputs are still stored in fp32; only the internal multiply uses reduced precision. The speedup is nearly free, at a small precision cost: the inputs are effectively rounded to 10 mantissa bits, the same as fp16.
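The size of that precision cost can be estimated without an Ampere GPU by emulating TF32 on the CPU: round both fp32 inputs to 10 mantissa bits, then multiply with ordinary fp32 arithmetic. This is an approximation assumed for illustration (round-to-nearest-even on the inputs, fp32 accumulation via numpy); real Tensor Core rounding and accumulation order differ.

```python
import numpy as np


def round_to_tf32(x: np.ndarray) -> np.ndarray:
    """Emulate TF32 inputs: keep fp32's 8-bit exponent, round the 23-bit mantissa to 10 bits."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> 13) & np.uint32(1)                 # lowest kept bit, for ties-to-even
    rounded = (bits + np.uint32(0x0FFF) + lsb) & np.uint32(0xFFFFE000)
    return rounded.view(np.float32)


def rel_error(approx: np.ndarray, exact: np.ndarray) -> float:
    return float(np.linalg.norm(approx - exact) / np.linalg.norm(exact))


rng = np.random.default_rng(0)
a = rng.standard_normal((256, 256)).astype(np.float32)
b = rng.standard_normal((256, 256)).astype(np.float32)

exact = a.astype(np.float64) @ b.astype(np.float64)  # float64 reference
fp32_out = a @ b                                      # plain fp32 matmul
tf32_out = round_to_tf32(a) @ round_to_tf32(b)        # TF32-rounded inputs, fp32 accumulation

print(f"fp32 relative error: {rel_error(fp32_out, exact):.1e}")
print(f"tf32 relative error: {rel_error(tf32_out, exact):.1e}")
```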

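In PyTorch, TF32 is controlled by two flags. Since PyTorch 1.12 it is off by default for matmuls and on by default for cuDNN convolutions, so on Ampere and newer GPUs you enable it for matmuls explicitly:

import torch

torch.backends.cuda.matmul.allow_tf32 = True   # default False since PyTorch 1.12
torch.backends.cudnn.allow_tf32 = True         # default True (convolutions)
# Equivalent switch for matmuls: torch.set_float32_matmul_precision("high")

Once enabled, you don’t have to think about it: it just makes fp32 training faster. If you need bit-exact fp32 (for comparison testing, mostly), leave both flags off.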

TF32 is a transitional format, mostly used by people who have legacy fp32 training scripts and want some speedup without converting to bf16. For new training, just use bf16.

13.7 fp8 — what H100 gives you

Hopper (H100, H200) introduced hardware fp8 support, with two variants:

  • e4m3 — 4-bit exponent, 3-bit mantissa, max value 448. Optimized for forward activations and weights.
  • e5m2 — 5-bit exponent, 2-bit mantissa, max value 57 344. Optimized for gradients during the backward pass (gradients have a wider dynamic range than activations, so they need more exponent bits).

The pitch: fp8 doubles Tensor Core throughput vs bf16. On H100, that’s ~1979 TFLOP/s for fp8 dense matmul vs 989 TFLOP/s for bf16. For very large training runs, this is a meaningful speedup. DeepSeek-V3, a 671B-parameter MoE, ran most of its matmuls in fp8 and reported a training cost of about $5.6M (roughly 2.8M H800 GPU-hours); its report attributes that figure to several co-designed efficiency choices, fp8 training among them.

The catch is that fp8’s range is so tight that you need per-tensor scaling, similar to (but more aggressive than) fp16 loss scaling. Every tensor carries a scale factor chosen so that its values, once multiplied by the scale, fill fp8’s representable range; the true value is the fp8 value divided by that scale. The fp8 matmul kernel takes the fp8 operands together with their scales, multiplies in fp8, accumulates in higher precision, and applies the combined scale to the output. The most widely used library for this in PyTorch is NVIDIA’s TransformerEngine, built for Hopper fp8 training.

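import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# A drop-in Linear layer: inside fp8_autocast its GEMM runs in fp8,
# while the weights themselves stay in higher precision.
layer = te.Linear(d_in, d_out, bias=False)

with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(...)):   # recipe options elided
    out = layer(x)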

Two things make fp8 training hard in practice:

  • Calibration is fragile. The per-tensor scales have to be re-estimated periodically based on the running max of the tensor. If the max changes too quickly, the scales are wrong and you get overflow or underflow. A common recipe is “delayed scaling,” where the scale for the current step is computed from a window of past max values.
  • Not every operation supports fp8. Non-matmul ops (LayerNorm, softmax, GELU) typically still run in bf16 or fp32, with fp8 used only for the matmuls. The conversions between formats add overhead.
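Delayed scaling can be sketched in a few lines: the scale applied at a given step comes from the largest absolute value (amax) seen over a window of earlier steps, so a sudden jump in the current tensor’s magnitude saturates. The class, the window length of 16, and the margin are hypothetical, loosely following the scale = fp8_max / amax idea described above; only e4m3’s range is simulated, not its 3-bit rounding.

```python
from collections import deque

import numpy as np

E4M3_MAX = 448.0  # largest finite fp8 e4m3 value


class DelayedScaler:
    """Hypothetical sketch of delayed per-tensor fp8 scaling (range only, no rounding)."""

    def __init__(self, history_len: int = 16, margin: int = 0):
        self.amax_history = deque(maxlen=history_len)
        self.margin = margin
        self.scale = 1.0

    def quantize(self, x: np.ndarray):
        """Scale with the scale derived from *past* steps, then saturate to e4m3's range."""
        scale_used = self.scale
        q = np.clip(x * scale_used, -E4M3_MAX, E4M3_MAX)       # saturating cast to the fp8 range
        self.amax_history.append(float(np.abs(x).max()))      # record this step's amax
        amax = max(max(self.amax_history), 1e-12)             # avoid dividing by zero
        self.scale = E4M3_MAX / (amax * 2.0 ** self.margin)    # used from the next step on
        return q, scale_used


scaler = DelayedScaler()
steady = np.array([0.5, -1.0, 0.25])
for _ in range(4):
    scaler.quantize(steady)               # amax 1.0 every step, so the scale settles at 448
spike = np.array([0.5, -10.0, 0.25])      # amax jumps 10x on this step
q, scale_used = scaler.quantize(spike)
print(q / scale_used)                     # the -10.0 entry saturates and comes back as -1.0
```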

For most training, bf16 is still the right choice. fp8 is for the largest runs where every percent of throughput matters. As tooling matures (TransformerEngine v2, FP8 recipes in vLLM and SGLang for inference, Megatron-LM fp8 support), we’ll see more frontier runs go fp8.

For inference, fp8 is even more tractable because there’s no backward pass to worry about. We’ll see fp8 inference quantization in Chapter 26.

13.8 Master weights and the fp32 copy

One subtlety catches everyone the first time: in a standard mixed-precision training loop, the optimizer operates on fp32 weights even when the rest of the training is in bf16. The flow:

  1. Master weights in fp32: this is the “true” copy of the parameters. AdamW updates these.
  2. Before each forward pass, the master fp32 weights are cast down to bf16, and the bf16 copy is what the model uses for the forward and backward passes.
  3. The backward pass produces gradients in bf16 (or fp16 with loss scaling).
  4. Gradients are cast up to fp32 and applied to the master weights via the optimizer step.
  5. The bf16 copy is refreshed for the next forward.

Figure: Mixed-precision training data flow. fp32 master weights (4 B/param) are cast down to bf16 weights (2 B/param) for the forward and backward passes; bf16 gradients (2 B/param) are cast up for the fp32 AdamW step, which updates the master weights before the bf16 copy is refreshed.
The optimizer always runs in fp32 on the master weights; the bf16 working copy is used only for the compute-bound forward and backward passes. This two-copy design lets tiny per-step updates accumulate without rounding to zero.

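Why? Because fp32 is the right precision for accumulating tiny optimizer updates. AdamW accumulates two moments (m and v) over time, and the per-step updates to the weights are very small (sometimes 1e-7 or smaller). In bf16, consecutive representable numbers near 1.0 are 2⁻⁷ ≈ 0.0078 apart, so adding any update smaller than about 0.004 to a weight near 1.0 rounds straight back to the old value, and the optimizer silently fails to learn. In fp32, with 16 more mantissa bits, such updates accumulate cleanly.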
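The rounding argument can be checked numerically. Assuming a weight of 1.0 and a constant per-step update of 10⁻⁴ (both illustrative), the sketch applies 1000 updates once with the weight stored in emulated bf16 and once in fp32; bf16 is emulated by rounding fp32 bit patterns, since numpy has no bfloat16 type.

```python
import numpy as np


def round_to_bf16(x) -> np.ndarray:
    """Emulate bf16 storage: round fp32 values to their top 16 bits (nearest even)."""
    bits = np.asarray(x, dtype=np.float32).reshape(-1).view(np.uint32)
    lsb = (bits >> 16) & np.uint32(1)
    return ((bits + np.uint32(0x7FFF) + lsb) & np.uint32(0xFFFF0000)).view(np.float32)


update = np.float32(1e-4)                   # illustrative update, far below bf16's half-gap at 1.0 (2**-8)
w_bf16 = round_to_bf16(1.0)                 # weight kept only in bf16
w_fp32 = np.array([1.0], dtype=np.float32)  # fp32 master weight

for _ in range(1000):
    w_bf16 = round_to_bf16(w_bf16 + update)  # every update rounds straight back to 1.0
    w_fp32 = w_fp32 + update                 # updates accumulate

print(w_bf16[0], w_fp32[0])  # 1.0 vs about 1.1
```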

The cost is that you store two copies of the weights: a bf16 copy for compute and an fp32 master copy for the optimizer. Plus the fp32 optimizer state (m, v). Plus the bf16 gradients. So a 70B model in mixed precision actually uses:

  • Master fp32 weights: 280 GB
  • bf16 working weights: 140 GB
  • bf16 gradients: 140 GB
  • fp32 AdamW m: 280 GB
  • fp32 AdamW v: 280 GB
  • Total: ~1.1 TB before activations
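The list reduces to 4 + 2 + 2 + 4 + 4 = 16 bytes per parameter. A minimal calculator under the same assumptions (fp32 master weights and AdamW moments, bf16 working weights and gradients, activations excluded); the component names are ours.

```python
# Bytes per parameter for the mixed-precision AdamW setup described above.
BYTES_PER_PARAM = {
    "fp32 master weights": 4,
    "bf16 working weights": 2,
    "bf16 gradients": 2,
    "fp32 AdamW m": 4,
    "fp32 AdamW v": 4,
}


def training_memory_gb(n_params: float) -> dict:
    """Per-component memory in GB (10**9 bytes), activations excluded."""
    parts = {name: n_params * b / 1e9 for name, b in BYTES_PER_PARAM.items()}
    parts["total"] = sum(parts.values())
    return parts


for name, gb in training_memory_gb(70e9).items():
    print(f"{name:22s} {gb:8.0f} GB")
```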

This is the memory total we used in Chapter 12. The “1 TB” number for training a 70B model assumes mixed precision with fp32 master weights and fp32 optimizer state. Lower-precision variants (fp8 weights and gradients, lower-precision optimizer state) can shrink it, but the master weights and the AdamW moments m and v are usually kept in fp32; they’re what AdamW’s accumulation is most sensitive to.

There are clever variants. 8-bit AdamW (from bitsandbytes) stores the optimizer state quantized to 8 bits, reducing the optimizer state by 4×. Quality is essentially preserved. This is the standard technique for fine-tuning on consumer GPUs (Chapter 15).

13.9 torch.cuda.amp.autocast — what it actually does

autocast is a Python context manager that chooses a dtype per operation inside the context, rather than changing the dtype of the tensors you pass in. When you wrap a forward pass in autocast(dtype=torch.bfloat16), the matmuls and convolutions inside it run in bf16, but numerically sensitive operations (LayerNorm, softmax, loss functions) still run in fp32 because those benefit from the precision.

Specifically, autocast maintains a list of operations that are:

  • Always cast down (matmul, conv, RNN cells, attention): the speed wins are large and the precision loss is small.
  • Always kept in fp32 (LayerNorm, softmax, log, exp, sum reductions, loss functions): the precision matters for stability.
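  • Promoted to the widest input dtype (a short list of multi-input ops such as addcmul, index_put, and scatter_add): if any input is fp32, all inputs are cast to fp32. Ops on none of these lists are left alone and run in whatever dtype their inputs already have.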
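Why exp and softmax sit on the fp32 list is easy to see with numpy’s float16 standing in for fp16: exp overflows fp16’s 65 504 ceiling once its input passes about 11.1, and a naive softmax then turns inf / inf into NaN. The logits are illustrative, and the naive softmax (no max-subtraction) is deliberately simplistic.

```python
import numpy as np


def naive_softmax(logits: np.ndarray) -> np.ndarray:
    """Softmax without max-subtraction, computed entirely in the input dtype."""
    e = np.exp(logits)
    return e / e.sum()


logits = np.array([12.0, 11.0, 0.0])

with np.errstate(over="ignore", invalid="ignore"):
    p16 = naive_softmax(logits.astype(np.float16))  # exp(12) ~ 162 755 > 65 504 -> inf -> NaN
p32 = naive_softmax(logits.astype(np.float32))

print(p16)  # contains nan
print(p32)  # finite, sums to 1
```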

This per-op handling is the actual content of “mixed precision.” It’s a recipe for which operations to run at what precision, designed empirically to maximize speed without breaking training. You don’t have to think about it — autocast knows.

import torch

with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    output = model(inputs)              # matmuls in bf16, softmax in fp32, etc.
    loss = loss_fn(output, target)      # loss functions on autocast's fp32 list run in fp32

A few things to know:

  • autocast is per-thread. It only affects the current thread, not background data loaders or other workers.
  • autocast does not change the storage dtype of your parameters. Your model can still hold fp32 parameters; autocast just casts them down at the matmul boundary. This is what enables the master-weights pattern in §13.8.
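  • autocast is composable with FSDP and DDP. Use them together. FSDP also has its own mixed_precision configuration that handles parameter dtype, gradient dtype, and reduction dtype separately. The two are designed to work together, but the exact interactions can be subtle.
  • Recent PyTorch releases (2.4 and later) deprecate the torch.cuda.amp.autocast(...) and torch.cuda.amp.GradScaler(...) spellings used in this chapter in favor of torch.amp.autocast("cuda", ...) and torch.amp.GradScaler("cuda", ...); the behavior is the same.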

13.10 Where mixed precision goes wrong

The top mixed-precision bugs, in rough order of frequency:

(1) NaN loss after a few steps. Almost always overflow somewhere — usually a softmax with very large logits, an exp() in the loss function, or a normalization layer with a tiny epsilon producing inf. The fix is to make sure the offending op runs in fp32 (autocast usually does this; make sure you haven’t disabled it).

(2) Loss flatlines at a high value. If you’re using fp16, this is almost always loss-scaling failure: gradients are underflowing and the model isn’t learning. Switch to bf16 or check that GradScaler is enabled and not in a degenerate state.

(3) Loss looks fine but eval is much worse than expected. Sometimes mixed precision introduces small biases that don’t show up in the loss but degrade specific evals. The diagnostic is to run the same model in pure fp32 and compare. If the gap is large, something in your mixed-precision setup is broken.

(4) “Why is the model slower in mixed precision than fp32?” Usually you’re fighting autocast: some op in your model isn’t on the cast-down list, so repeated casts between bf16 and fp32 dominate the cost. Check a profiler trace (torch.profiler) for the cast kernels.

(5) FSDP + autocast gradient dtype mismatch. FSDP wants to all-reduce gradients in a specific dtype. If autocast produces gradients in a different dtype, you get weird errors. The fix is to set FSDP’s mixed_precision config to match what autocast is producing.

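(6) Optimizer state corruption from fp16. If you store the AdamW moments in fp16, the second moment v, which holds squared gradients, underflows first: a 1e-4 gradient squares to 1e-8, which rounds to zero in fp16, and the update m / (√v + ε) then blows up. The diagnostic is often “loss starts going up after a few thousand steps.” The fix is to keep the optimizer state in fp32, which is what you get with torch.cuda.amp, since the parameters (and therefore the optimizer state) stay in fp32.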

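(7) Inf/NaN or wrong clipping with GradScaler. clip_grad_norm_ computes the global L2 norm of all gradients. If you call it on still-scaled gradients, that norm is S times the true norm (and can itself overflow), so the threshold effectively shrinks by a factor of S and nearly every step gets clipped to a tiny update. Always unscale gradients before clipping when using GradScaler (scaler.unscale_(optimizer) before clip_grad_norm_).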
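Bugs (6) and (7) both come down to a few lines of arithmetic. The sketch uses numpy’s float16 for fp16 storage and plain numpy elsewhere; the gradient values, Adam’s ε = 10⁻⁸, the scale S = 2¹⁶, and the clip_by_norm helper (a hypothetical stand-in for what clip_grad_norm_ does: rescale so the global L2 norm is at most max_norm) are illustrative assumptions. Adam’s β factors and bias correction are left out to keep the numbers visible.

```python
import numpy as np

# --- Bug (6): the second moment underflows in fp16 ---
g, eps = 1e-4, 1e-8                       # illustrative gradient and Adam epsilon
v_fp16 = np.float16(g) * np.float16(g)    # ~1e-8 is below half of fp16's smallest subnormal (~3e-8): 0
v_fp32 = np.float32(g) * np.float32(g)    # ~1e-8 is an ordinary fp32 number


def adam_direction(m: float, v: float) -> float:
    """Normalized Adam step m / (sqrt(v) + eps), evaluated in float64."""
    return m / (np.sqrt(float(v)) + eps)


step_fp16 = adam_direction(g, v_fp16)     # ~1e4: a huge step driven by eps alone
step_fp32 = adam_direction(g, v_fp32)     # ~1.0


# --- Bug (7): clipping before unscaling ---
def clip_by_norm(grads: np.ndarray, max_norm: float) -> np.ndarray:
    """Hypothetical helper: rescale grads so their global L2 norm is at most max_norm."""
    norm = np.linalg.norm(grads)
    return grads * min(1.0, max_norm / norm) if norm > 0 else grads


S = 2.0 ** 16
true_grads = np.array([0.3, 0.4])         # true norm 0.5, under the threshold
scaled = true_grads * S                   # what the grads hold after backward on the scaled loss

right = clip_by_norm(scaled / S, 1.0)     # unscale, then clip: unchanged
wrong = clip_by_norm(scaled, 1.0) / S     # clip, then unscale: norm 1/S, about 1.5e-5

print(step_fp16, step_fp32)
print(np.linalg.norm(right), np.linalg.norm(wrong))
```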

Most of these fail loudly. The dangerous ones (where the loss looks fine but quality is silently worse) are the reason production training pipelines run periodic full-fp32 control runs as a regression check.

13.11 The mental model

Eight points to take into Chapter 14:

  1. Precision is a throughput lever, not just a memory lever. Halving precision roughly doubles Tensor Core throughput.
  2. fp32 wastes 90%+ of an H100’s silicon. Modern training is mixed precision.
  3. fp16 underflows because its range is too small. Loss scaling fixes it; it’s fragile.
  4. bf16 has the same range as fp32 with less mantissa. No loss scaling needed. Default for modern training.
  5. fp8 doubles throughput again but requires per-tensor scaling and a calibration recipe. Frontier-only for now.
  6. Master weights in fp32 are kept for the optimizer’s accumulation. The bf16 copy is for compute.
  7. autocast is the per-op precision recipe — matmul down, softmax/loss/normalization up.
  8. Mixed-precision bugs come in seven flavors. The dangerous ones are silent quality regressions; control runs in fp32 catch them.

In Chapter 14 we look at one final piece of the pretraining infrastructure: how the tokenizer itself is trained.


Read it yourself

  • Micikevicius et al., Mixed Precision Training (2017) — the original NVIDIA paper introducing fp16 + loss scaling.
  • The Brain Float 16 Wikipedia page — the cleanest one-pager on the format.
  • The NVIDIA Hopper FP8 white paper — read the recipe section for the per-tensor scaling story.
  • The TransformerEngine README and the te.fp8_autocast API docs.
  • The PyTorch torch.cuda.amp documentation — the practical reference for autocast and GradScaler.

Practice

  1. Compute the smallest representable positive number for fp16, bf16, fp8 e4m3, and fp8 e5m2. Compare to the typical magnitudes of LLM gradients (~1e-6 to 1e-3). Which format will underflow which gradients?
  2. Why does loss scaling have to be applied before backward(), not after? Trace the chain rule through one step.
  3. A training script using fp16 + GradScaler is showing dynamic scale = 1.0 after 100 steps. What does this mean and how do you debug it?
  4. Why is the optimizer state kept in fp32 even when the rest of training is in bf16? Construct a one-step example where keeping the state in bf16 silently loses information.
  5. Estimate the H100-hour speedup of training a 70B model in fp8 vs bf16, assuming both achieve similar model FLOPs utilization. What’s the realistic speedup, and what’s the cost?
  6. Write a small PyTorch script that shows fp16 gradient underflow on a 50-layer linear network with a small input. Then add GradScaler and show that loss scaling fixes it.
  7. Stretch: Take a small transformer and train it for 1000 steps in (a) pure fp32, (b) fp16 with autocast and GradScaler, (c) bf16 with autocast. Plot the three loss curves on the same axes. Identify the differences and explain them.

Concept check

  1. bf16 is safer than fp16 for training despite having fewer mantissa bits. The reason is
  2. Loss scaling is required for fp16 training because
  3. A training run uses ‘master weights’, an fp32 copy of parameters kept alongside bf16 weights. What is the purpose of the fp32 copy?
  4. fp8 training on H100 offers roughly 2× throughput vs bf16 but requires calibration. The key challenge calibration solves is