Distributed training: DDP, FSDP, ZeRO, tensor/pipeline/sequence parallel
"A 70B model does not fit on one GPU. Everything else is a consequence of that fact"
This is the chapter that earns its keep in interviews. “How would you train a 70B model?” is the most asked ML systems design question after KV cache, and the answer requires knowing the four parallelism dimensions, the ZeRO stages, the FSDP API, and the communication patterns each one demands.
We are going to build the picture from first principles: why distributed training exists (the memory math from Chapter 4 makes it inevitable), what the four parallelism strategies are, how they compose, and which combinations the major frameworks (DeepSpeed, Megatron-LM, FSDP, MLX) actually use. By the end you’ll be able to size a training run on any cluster.
Outline:
- The memory wall — why one GPU is hopeless past ~7B parameters.
- Data parallelism (DDP) — the simplest answer.
- The all-reduce primitive and the communication patterns you need to know.
- ZeRO Stage 1, 2, 3 — sharding the optimizer state, then the gradients, then the parameters.
- FSDP — PyTorch’s blessing of ZeRO-3.
- Tensor parallelism — splitting individual matmuls across GPUs.
- Pipeline parallelism — splitting layers across GPUs.
- Sequence parallelism — splitting the sequence dimension across GPUs.
- 3D and 4D parallelism — composing the strategies.
- The collective communication library: NCCL, NVLink, InfiniBand.
- How the major frameworks pick combinations.
12.1 The memory wall
Recall the memory math from Chapter 4. For a 70B-parameter model trained with AdamW in mixed precision, the memory budget is roughly:
| Item | Per-parameter cost | 70B total |
|---|---|---|
| Model weights (bf16) | 2 bytes | 140 GB |
| Gradients (bf16) | 2 bytes | 140 GB |
| AdamW first moment (fp32) | 4 bytes | 280 GB |
| AdamW second moment (fp32) | 4 bytes | 280 GB |
| Master weight copy (fp32) | 4 bytes | 280 GB |
| Optimizer + grads + master | 14 bytes/param | ~1 TB |
That’s 1 TB of GPU memory before you’ve stored a single activation. Plus activations: depending on batch size, sequence length, depth, and gradient checkpointing strategy, another few hundred GB on top.
An H100 has 80 GB of HBM. A single H100 cannot hold any of the above. You cannot train a 70B model on one GPU. You cannot even hold its optimizer state on one GPU. You need to split the work across many GPUs, and the question is how.
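The budget above is simple enough to script. A minimal sketch using the usual 2/2/4/4/4-byte accounting and decimal GB (`training_memory_gb` is a hypothetical helper, not a library function):

```python
# Rough per-parameter memory budget for mixed-precision AdamW training.
# Reproduces the table above: bf16 weights/grads, fp32 moments and master copy.

def training_memory_gb(params_b: float) -> dict:
    """GB of training state for a model with params_b billion parameters."""
    p = params_b * 1e9
    gb = 1e9  # decimal GB, as in the table
    return {
        "weights_bf16": 2 * p / gb,
        "grads_bf16": 2 * p / gb,
        "adamw_m_fp32": 4 * p / gb,
        "adamw_v_fp32": 4 * p / gb,
        "master_fp32": 4 * p / gb,
    }

budget = training_memory_gb(70)
print(budget["weights_bf16"])     # 140.0
print(sum(budget.values()))       # 1120.0 -- ~1 TB before activations
```

Sixteen bytes per parameter in total, of which 14 are optimizer state, gradients, and the master copy.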
The four answers are the four parallelism dimensions:
- Data parallelism (DP): replicate the model on every GPU, give each GPU a different chunk of the batch.
- Model parallelism along the parameter axis (TP, tensor parallelism): split each weight matrix across GPUs so each GPU holds a slice.
- Model parallelism along the layer axis (PP, pipeline parallelism): put different layers on different GPUs.
- Sequence parallelism (SP): split the sequence dimension across GPUs so each GPU sees only some positions.
The three “model parallelism” axes (TP, PP, SP) reduce per-GPU memory by splitting the model itself. Data parallelism replicates the model and only saves on activations (because each GPU only computes activations for its data chunk). In practice you combine multiple axes — “3D parallelism” (DP × TP × PP) is the canonical setup for huge models.
There is also sharded data parallelism (ZeRO / FSDP), which is in some sense the cleanest of all of these — it’s data parallelism that shards the optimizer/gradient/parameter state across the data-parallel replicas instead of replicating it. We’ll get to it after DDP.
12.2 Distributed Data Parallelism (DDP)
The simplest case. You have K GPUs. You replicate the entire model on every GPU. You take a batch of size K × B and split it: GPU 0 gets the first B examples, GPU 1 gets the next B, and so on. Every GPU computes its own forward and backward independently and ends up with its own gradient tensor.
The catch: the gradients on different GPUs are different (because they came from different data). To do an optimizer step on the full batch, you need to average the gradients across all GPUs. After averaging, every GPU has the same averaged gradient, applies the same optimizer step, and ends up with the same updated weights — so the replicas stay in sync.
The averaging operation is called an all-reduce (next section). It’s the only collective communication required for plain DDP, and it’s bidirectional: every GPU sends its gradient and every GPU receives the average.
```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

model = MyModel().cuda(local_rank)
model = DDP(model, device_ids=[local_rank])
optimizer = torch.optim.AdamW(model.parameters())

for batch in data_loader:
    loss = model(batch).loss
    loss.backward()        # gradients are all-reduced inside this call
    optimizer.step()
    optimizer.zero_grad()
```
That’s all there is to plain DDP. The DistributedDataParallel wrapper hooks into the autograd graph and triggers an all-reduce of the gradient buckets as the backward pass completes. Every framework (PyTorch DDP, JAX pmap, TF MirroredStrategy) does the same thing under the hood.
DDP is the right answer when the model fits on a single GPU. For models up to ~7B in bf16 (with reasonable batch size and sequence length), DDP plus a sensible optimizer is enough. Above 7B, the per-GPU memory budget breaks and you need something fancier.
12.3 The all-reduce primitive (and friends)
Distributed training is built on collective communication operations. You should know these by heart.
All-reduce. Every rank contributes a tensor; every rank receives the elementwise sum (or average, or other reduction) of all contributed tensors. Used by DDP for gradient averaging.
All-gather. Every rank contributes a tensor of size s; every rank receives a tensor of size K × s containing every other rank’s contribution. Used by ZeRO/FSDP to reconstruct full parameters from sharded ones.
Reduce-scatter. The “inverse” of all-gather. Every rank contributes a tensor of size K × s; every rank receives a slice of size s containing the reduction of one chunk. The combination “reduce-scatter then all-gather” is exactly equivalent to “all-reduce.”
Broadcast. One rank sends a tensor; every other rank receives a copy. Used at initialization to make sure all ranks start with the same weights.
All-to-all. Every rank contributes K chunks; every rank receives one chunk from every other rank. Used by expert parallelism in MoE models, and by some sequence-parallelism schemes.
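The semantics are easy to pin down in a single-process sketch, with each rank's tensor modeled as a plain Python list. It also checks the identity stated above: reduce-scatter followed by all-gather equals all-reduce.

```python
# Single-process simulation of three collectives. "ranks" is a list with one
# tensor (list of floats) per rank; each function returns what every rank
# ends up holding.

def all_reduce(tensors):
    """Every rank gets the elementwise sum of all ranks' tensors."""
    summed = [sum(vals) for vals in zip(*tensors)]
    return [list(summed) for _ in tensors]

def reduce_scatter(tensors):
    """Rank i gets the reduced i-th chunk (chunk size = len // K)."""
    k = len(tensors)
    s = len(tensors[0]) // k
    summed = [sum(vals) for vals in zip(*tensors)]
    return [summed[i * s:(i + 1) * s] for i in range(k)]

def all_gather(shards):
    """Every rank gets the concatenation of all ranks' shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]

ranks = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]  # K=2 ranks
assert all_gather(reduce_scatter(ranks)) == all_reduce(ranks)
```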
The cost of each is a function of the cluster topology and the message size. The two relevant numbers:
- Latency — fixed cost per message, in microseconds. Limited by the slowest link in the topology.
- Bandwidth — per-byte cost. Limited by the slowest link’s bandwidth.
For all-reduce on K GPUs with bandwidth B per link, the ring all-reduce algorithm achieves a per-GPU communication cost of 2(K-1)/K × tensor_size / B, which approaches 2 × tensor_size / B for large K. This is bandwidth-optimal; you can't do better. The factor of 2 is because every byte traverses the ring twice: once during the reduce-scatter half and once during the all-gather half.
What this means in practice: the communication cost of DDP is dominated by the gradient size, not the GPU count. A 70B model has ~140 GB of bf16 gradients. At 50 GB/s of effective inter-node bandwidth, an all-reduce takes ~5.5 seconds. A training step that does ~200 ms of compute while spending ~5 seconds on communication is a disaster: the GPUs are idle ~96% of the time. This is the fundamental reason plain DDP doesn't scale to large models.
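The same arithmetic as a throwaway helper, using the 2(K-1)/K ring factor from above (`ring_all_reduce_seconds` is an illustrative name, not a library call):

```python
# Back-of-envelope ring all-reduce time. Reproduces the 70B example:
# ~140 GB of bf16 gradients over a ~50 GB/s effective link, K=64 GPUs.

def ring_all_reduce_seconds(tensor_gb: float, bandwidth_gbps: float, k: int) -> float:
    return 2 * (k - 1) / k * tensor_gb / bandwidth_gbps

t = ring_all_reduce_seconds(140, 50, 64)
print(t)  # ~5.5 s per step, vs ~0.2 s of compute: GPUs idle ~96% of the time
```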
12.4 ZeRO — sharding the optimizer state
ZeRO (Zero Redundancy Optimizer, Rajbhandari et al., 2020) is the most influential idea in distributed training of the last five years. The observation: in plain DDP, every GPU stores the same copy of the optimizer state, gradients, and parameters. This is wasteful. If you have 64 GPUs, you’re storing 64 copies of a 1 TB state, or 64 TB total. You only need one copy distributed.
ZeRO progressively shards three things across the data-parallel ranks:
ZeRO Stage 1 — shard the optimizer state
Split the AdamW first and second moments (and the master fp32 copy) into K shards, one per data-parallel rank. Each rank holds only its 1/K slice of the optimizer state.
The forward and backward pass are unchanged — every rank still computes its own forward and backward on its own data chunk, and ends up with its own gradient tensor. The change is in the optimizer step:
- Each rank does reduce-scatter on the gradients: every rank ends up with the reduced gradient for its 1/K shard of the parameters.
- Each rank applies the optimizer update to its 1/K shard of the parameters, using its 1/K shard of the optimizer state.
- Each rank does all-gather on its updated parameters so every rank has the full updated weights for the next forward pass.
The total bytes communicated per step is the same as DDP (reduce-scatter + all-gather is equivalent to all-reduce), but the memory savings are enormous: optimizer state goes from 12 × P bytes (per rank) to 12 × P / K bytes. For K=64, that’s a 64× reduction in optimizer-state memory.
This is the easiest “free” memory reduction in distributed training. ZeRO Stage 1 is on by default in modern training frameworks.
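The Stage 1 step can be sketched in a single process, with plain SGD standing in for AdamW (the communication pattern, not the optimizer, is the point) and lists standing in for the two ranks' tensors:

```python
# Toy ZeRO-1 optimizer step with K=2 ranks. Each rank holds the full
# parameters but updates only its 1/K shard, then all-gathers the result.

K = 2
params = [1.0, 2.0, 3.0, 4.0]                 # replicated on every rank
local_grads = [[0.0, 0.5, 1.0, 1.5],          # rank 0's gradients
               [0.5, 0.5, 1.0, 0.5]]          # rank 1's gradients
lr = 1.0
shard = len(params) // K

# 1. reduce-scatter: rank i ends up with the averaged grad for its shard
avg = [sum(g) / K for g in zip(*local_grads)]
grad_shards = [avg[i * shard:(i + 1) * shard] for i in range(K)]

# 2. each rank applies the update to its 1/K parameter shard only
new_shards = []
for i in range(K):
    p = params[i * shard:(i + 1) * shard]
    new_shards.append([w - lr * g for w, g in zip(p, grad_shards[i])])

# 3. all-gather: every rank reassembles the full updated parameter vector
updated = [w for s in new_shards for w in s]
print(updated)  # [0.75, 1.5, 2.0, 3.0]
```

Reduce-scatter plus all-gather moves the same bytes as the all-reduce DDP would have done; only the memory layout changed.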
ZeRO Stage 2 — also shard the gradients
Stage 1 still has each rank holding the full gradient tensor (2 × P bytes in bf16). ZeRO Stage 2 shards the gradient too: each rank, after computing its local gradient, immediately reduce-scatters it so it only holds its 1/K shard of the reduced gradient. The full gradient is never simultaneously materialized on any rank.
This requires a slightly more complex backward pass — gradients have to be reduced as soon as they’re computed, in groups, rather than all at once at the end — but the memory savings are real: gradient memory goes from 2 × P per rank to 2 × P / K per rank.
ZeRO Stage 3 — also shard the parameters
The most aggressive version. The parameters themselves are sharded: each rank only holds 1/K of the model weights at any given time. To do a forward pass, the rank needs the full weights — so before each layer’s forward, the rank performs an all-gather to materialize the layer’s weights, runs the forward, and then immediately discards the gathered weights (keeping only its shard).
This is much more communication: one all-gather per layer per forward pass, plus the gradient communication during backward. Per-rank communication volume goes from roughly 2 × P per step (Stage 1, equivalent to an all-reduce) to roughly 3 × P per step (Stage 3: full parameter all-gathers in both forward and backward, plus the gradient reduce-scatter). On a fast interconnect (NVLink within a node, InfiniBand across nodes) the extra communication can be hidden by overlapping it with compute. On slower interconnects (cross-datacenter links, under-provisioned Ethernet) it kills throughput.
The benefit: per-rank parameter memory drops to ~2 × P / K (and the full training state to ~16 × P / K), plus activation memory. For a 70B model with K=64, that's ~140 GB / 64 ≈ 2 GB of bf16 weights and ~17.5 GB of total sharded state per GPU. A 70B model can be trained on 64 GPUs with ~20 GB used per GPU, leaving room for larger batches and longer sequences.
ZeRO Stage 3 is the technique that makes large-model training on commodity clusters possible.
12.5 FSDP — PyTorch’s official ZeRO-3
Fully Sharded Data Parallel (FSDP) is PyTorch’s first-class implementation of ZeRO Stage 3. It shipped in PyTorch 1.11 and is now the default approach for training models that don’t fit on one GPU using PyTorch.
The API is straightforward:
```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
    CPUOffload,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

dist.init_process_group(backend='nccl')
model = MyTransformer().cuda(local_rank)
model = FSDP(
    model,
    auto_wrap_policy=transformer_auto_wrap_policy,
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, ...),
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3
    cpu_offload=CPUOffload(offload_params=False),
)
```
Three knobs that matter:
- `auto_wrap_policy`: tells FSDP how to chunk the model into "wrap units." The natural choice for a transformer is "wrap each transformer block individually," so FSDP treats each block as a unit, all-gathers its parameters before the forward, and frees them after.
- `mixed_precision`: configures bf16/fp16 storage and fp32 computation in critical places. FSDP integrates with mixed precision rather than fighting it.
- `sharding_strategy`: `FULL_SHARD` is ZeRO-3, `SHARD_GRAD_OP` is ZeRO-2, `NO_SHARD` is plain DDP. There's also `HYBRID_SHARD`, which shards within a node and replicates across nodes (good for clusters where intra-node bandwidth is much higher than inter-node).
FSDP is the modern PyTorch default. Earlier (pre-FSDP) versions of ZeRO were implemented in DeepSpeed, which is still actively developed but is no longer the only option. For most workloads, FSDP is enough.
The thing FSDP doesn’t do alone: tensor parallelism inside a single layer. For very large layers (the FFN of a 70B model has 8192 → 28672 projections — 230M parameters per layer, hard to fit even after sharding) you may need to combine FSDP with tensor parallelism. PyTorch 2.x supports this combination through the device_mesh API.
12.6 Tensor parallelism — splitting matmuls
Tensor parallelism (TP) splits individual matrix multiplications across GPUs. It’s the most aggressive form of model parallelism and the only one that can split a single linear layer that doesn’t fit on one GPU.
The canonical example is a linear layer Y = X W, where W is (d_in, d_out):
Column parallelism. Split W along the output dimension:
W = [W_0, W_1, ..., W_{K-1}] # each shape (d_in, d_out / K)
GPU i holds W_i. Each GPU computes its own slice of the output: Y_i = X W_i, of shape (N, d_out / K). The full output is the concatenation of the slices: Y = concat(Y_0, ..., Y_{K-1}). No communication is needed for the matmul itself — every GPU has the full input X and computes its own slice independently.
Row parallelism. Split W along the input dimension:
W = [W_0; W_1; ...; W_{K-1}]   # stacked along rows, each shape (d_in / K, d_out)
GPU i holds W_i. The input X has to be split too: X_i is the slice of X along the input axis that pairs with W_i. Each GPU computes a partial result: Y_i_partial = X_i W_i. The full output is the sum of the partial results: Y = Σ_i Y_i_partial. This sum requires an all-reduce across the GPUs.
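Both splits can be checked numerically with tiny matrices and a plain-Python matmul (K = 2 hypothetical GPUs):

```python
# Column parallelism concatenates per-GPU outputs; row parallelism sums
# per-GPU partial outputs (the sum is the all-reduce).

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

X = [[1.0, 2.0], [3.0, 4.0]]                      # (N=2, d_in=2)
W = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]  # (d_in=2, d_out=4)
Y = matmul(X, W)                                  # unsharded reference

# Column parallel: each GPU holds half the output columns of W.
W0 = [row[:2] for row in W]; W1 = [row[2:] for row in W]
Y_col = [y0 + y1 for y0, y1 in zip(matmul(X, W0), matmul(X, W1))]  # concat

# Row parallel: each GPU holds half the input rows of W and a slice of X.
X0 = [row[:1] for row in X]; X1 = [row[1:] for row in X]
P0, P1 = matmul(X0, W[:1]), matmul(X1, W[1:])
Y_row = [[a + b for a, b in zip(r0, r1)] for r0, r1 in zip(P0, P1)]  # all-reduce

assert Y_col == Y and Y_row == Y
```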
The Megatron-LM paper (Shoeybi et al., 2019) introduced the canonical pattern: column parallelism on the first matmul of an FFN, row parallelism on the second. This way the intermediate activation lives in column-parallel form (no communication needed), and only one all-reduce is needed at the end of the FFN. Same pattern for the attention block: column-parallel on the QKV projection, row-parallel on the output projection. One all-reduce per sub-block — one for attention, one for FFN, two per transformer block.
```python
# Megatron-style transformer block (pseudocode)
def block(x):
    # Attention sub-block -- TP across heads
    qkv = column_parallel_linear(x)             # no comm
    q, k, v = split(qkv)
    attn_out = attention(q, k, v)               # local
    attn_proj = row_parallel_linear(attn_out)   # all-reduce
    x = x + attn_proj
    # FFN sub-block -- TP across the FFN hidden dim
    h = column_parallel_linear(x)               # no comm
    h = silu(h)
    ffn_out = row_parallel_linear(h)            # all-reduce
    x = x + ffn_out
    return x
```
Two all-reduces per block. For a 32-block model trained on 8-way TP, that’s 64 all-reduces per forward pass plus 64 more per backward. The all-reduce cost dominates if the GPUs aren’t on the same NVLink fabric — TP is the form of parallelism that most demands fast intra-node interconnect.
Practical note: TP is almost always configured at the node level, not across nodes. A typical 8-GPU node has NVLink between all 8 GPUs (~900 GB/s of per-GPU bandwidth on H100), and TP can use that fast fabric. Cross-node TP over Ethernet or InfiniBand is much slower and rarely worth it.
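The column-then-row FFN pattern is easy to verify numerically: the intermediate stays sharded along the FFN hidden dimension, so the only communication is one final sum. A sketch with K = 2 and tiny matrices (the elementwise silu is omitted since it acts independently on each shard):

```python
# Column split on W_up, row split on W_down: GPU i's column shard of W_up
# produces exactly the input its row shard of W_down expects, so no
# communication is needed in between -- just one final all-reduce (the sum).

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

X = [[1.0, 2.0]]                                       # (N=1, d=2)
W_up = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]    # (d=2, ffn=4), column-split
W_down = [[1.0], [2.0], [3.0], [4.0]]                  # (ffn=4, d=1), row-split

reference = matmul(matmul(X, W_up), W_down)            # unsharded FFN

h0 = matmul(X, [r[:2] for r in W_up])                  # GPU 0, local
h1 = matmul(X, [r[2:] for r in W_up])                  # GPU 1, local
p0, p1 = matmul(h0, W_down[:2]), matmul(h1, W_down[2:])
out = [[a + b for a, b in zip(r0, r1)] for r0, r1 in zip(p0, p1)]  # one all-reduce

assert out == reference
```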
12.7 Pipeline parallelism — splitting layers
Pipeline parallelism (PP) splits the model along the layer axis: GPUs 0..K-1 each hold a contiguous chunk of layers. For a 32-layer model on 4 GPUs with PP=4, GPU 0 holds layers 0-7, GPU 1 holds layers 8-15, GPU 2 holds layers 16-23, GPU 3 holds layers 24-31.
The forward pass is sequential across pipeline stages: GPU 0 processes the input through its layers and sends the activations to GPU 1, which processes through its layers and sends to GPU 2, and so on. The backward pass goes in reverse: gradients flow from GPU 3 back to GPU 0.
This works, but with one batch in flight at a time, most GPUs are idle most of the time. GPU 1 is doing nothing while GPU 0 processes the first quarter of layers. The “pipeline bubble” — the idle time at the start and end of each batch — kills throughput.
The fix is micro-batching: split the batch into many micro-batches, and feed them through the pipeline in a staggered way so multiple micro-batches are in flight simultaneously. With enough micro-batches, the pipeline stays full and the bubble overhead amortizes to small.
Two scheduling variants:
- GPipe (Huang et al., 2019): all forward passes for all micro-batches happen first, then all backward passes. Simple but has a large activation-memory cost (all forward activations of all micro-batches must be retained until backward).
- 1F1B (PipeDream, then Megatron’s variant): interleave forward and backward passes so each pipeline stage works on different micro-batches simultaneously. Lower peak activation memory.
PP’s advantage over TP is that the communication is much smaller: only the activations between layer boundaries are sent across the network, not full gradient tensors. PP works well across nodes over slower interconnects.
PP’s disadvantage: the bubble. Even with micro-batching, you lose ~10–20% of throughput to idle time.
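The bubble overhead follows from the standard GPipe bubble fraction, (PP - 1) / (PP + microbatches - 1). A quick sketch:

```python
# Pipeline bubble fraction: idle time grows with stages, shrinks with
# micro-batches. With one batch in flight a PP=4 pipeline is 75% idle.

def bubble_fraction(pp: int, microbatches: int) -> float:
    return (pp - 1) / (pp + microbatches - 1)

print(bubble_fraction(4, 1))           # 0.75 -- no micro-batching
print(round(bubble_fraction(4, 32), 3))  # 0.086 -- bubble amortized away
```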
PP is the right tool when the model is so deep that even with TP and FSDP it doesn’t fit. For 405B+ models, PP is essentially mandatory.
12.8 Sequence parallelism
Sequence parallelism (SP) splits the sequence dimension across GPUs. The motivation: during training of long-context models, the dominant memory cost is the activation tensors (N, S, D), which scale linearly with sequence length. If S is 32k or 128k, those tensors are enormous.
Sequence parallelism is used in conjunction with TP. The trick: while the FFN and attention sub-blocks (when in TP) require all-reduce, the LayerNorm/RMSNorm and the dropout layers in between do not require communication. The activations between sub-blocks live in (N, S, D) form on every GPU, replicated. SP splits this S axis across the same TP ranks: each rank holds only its slice of the sequence.
The all-reduce in TP becomes a slightly different pattern (an all-to-all or a reduce-scatter + all-gather pair) that simultaneously reduces and re-shards along the S axis. The extra communication is small; the activation memory savings are large (a factor of K).
Sequence parallelism is the technique that makes 1M+ context training feasible.
A different and more recent variant: Ring Attention (Liu et al., 2023) does SP in a way that distributes the attention computation itself across GPUs. Each GPU computes attention only for its local sequence chunk, but the K and V blocks are passed around a ring so every chunk eventually attends to every other. Techniques in this family are widely believed to underpin million-token-context training of the Gemini kind, though the production details aren't public.
12.9 3D parallelism
The big training runs use all of these together:
Total GPUs = DP × TP × PP
A typical 70B training run on 256 GPUs might use:
- TP = 8 (within each node, fast NVLink)
- PP = 2 (across nodes, slower interconnect)
- DP = 16 (256 / 8 / 2)
So the model is split 8 ways within a node, 2 ways across pipeline stages, and replicated 16 times for data parallelism. Each replica of the (TP × PP) shard sees a different chunk of the global batch.
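The sizing arithmetic for a layout like this is a one-liner each. A sketch of the 256-GPU example above (helper names are illustrative):

```python
# Total GPUs factor as DP x TP x PP; sharded bf16 parameter memory per GPU
# shrinks with the model-parallel degrees and, with FSDP layered on, with
# the data-parallel degree too.

def dp_degree(total_gpus: int, tp: int, pp: int) -> int:
    assert total_gpus % (tp * pp) == 0, "parallelism grid must tile the cluster"
    return total_gpus // (tp * pp)

def param_gb_per_gpu(params_b: float, tp: int, pp: int, fsdp_shards: int = 1) -> float:
    """GB of bf16 weights per GPU (2 bytes/param), before activations."""
    return 2 * params_b / (tp * pp * fsdp_shards)

dp = dp_degree(256, tp=8, pp=2)
print(dp)                                           # 16
print(param_gb_per_gpu(70, tp=8, pp=2))             # 8.75 GB of weights per GPU
print(param_gb_per_gpu(70, 8, 2, fsdp_shards=dp))   # 0.546875 with FSDP over DP
```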
The configuration is a multi-dimensional optimization problem: bigger TP reduces model size per GPU but increases communication; bigger PP reduces memory per GPU and reduces communication but creates pipeline bubbles; bigger DP reduces wall-clock time but does nothing for memory; FSDP/ZeRO can be layered on top of all of the above to further shard the optimizer state along the DP dimension.
The skill is knowing which knob to turn first when you run out of memory:
| Symptom | Reach for |
|---|---|
| Optimizer state OOM | ZeRO-1 / FSDP |
| Gradient OOM | ZeRO-2 / FSDP |
| Parameter OOM | ZeRO-3 / FSDP, then TP |
| Activation OOM (long sequences) | gradient checkpointing, then SP |
| Activation OOM (very deep model) | PP |
| All of the above | the kitchen sink (3D parallelism + FSDP + activation checkpointing) |
The frontier-scale training runs use the kitchen sink. Llama 3.1 405B used TP=8, PP=16, FSDP-style sharded data parallelism, and context (sequence) parallelism, with gradient checkpointing on top — a 4D-parallelism setup.
12.10 The collective communication library
The actual all-reduces and all-gathers are implemented by NCCL (NVIDIA Collective Communications Library), which sits below your training framework and below CUDA. NCCL knows about the cluster topology and uses the fastest available paths:
- NVLink within a node (~900 GB/s per GPU for H100 and H200, ~1.8 TB/s for B200; the older A100 generation tops out around 600 GB/s).
- NVSwitch for full-fabric NVLink connectivity across all 8 GPUs in a node.
- InfiniBand between nodes (400 Gb/s per NIC on a top-tier NDR cluster; nodes aggregate several NICs, for a few hundred GB/s of node-to-node bandwidth).
- Ethernet as a fallback (typically 10–100 Gb/s, much slower; not used in serious training clusters).
NCCL's ring all-reduce uses these in priority order. A well-tuned cluster connects all GPUs within a node over NVLink and all nodes over InfiniBand, and the all-reduce time is then dominated by the inter-node InfiniBand hop.
The non-obvious operational fact: NCCL hangs are the most common training failure mode. A bad GPU, a bad cable, a bad switch, or a bad driver can cause an all-reduce to hang silently — every GPU waits for a partner that never responds. Modern training scripts include NCCL timeout handlers and watchdog threads to catch these. The Llama 3 paper had a section on this; it’s a real problem at scale.
12.11 How the major frameworks pick combinations
Quick reference for what each framework does well:
- PyTorch FSDP — ZeRO-3 plus 2D parallelism (TP via `device_mesh`). Modern default for most PyTorch training. Good for 7B–70B.
- DeepSpeed — ZeRO-1/2/3, with extensive features (CPU offload, NVMe offload, ZeRO-Infinity, Megatron integration). The original ZeRO implementation. Still widely used, especially for very large models.
- Megatron-LM — TP and PP done right, with excellent kernels. The frontier framework for the biggest models. Used in the Llama 3 405B run, the Falcon 180B run, and most >100B-parameter open releases.
- JAX with `pjit`/`shard_map` — declarative sharding over a device mesh. Used by Google internally and by some open projects. Cleaner API than PyTorch's, fewer features.
- MosaicML / Composer — wraps PyTorch FSDP with operational tooling. Good for "managed" training runs.
- Lightning Fabric — wraps PyTorch FSDP and DeepSpeed with a unified API. Good for portability.
The choice depends on model size and operational maturity. For 7B models, plain FSDP is enough. For 70B, FSDP + TP via Megatron-style integration. For 405B, Megatron-LM with full 3D + sequence parallelism is the only path.
12.12 The mental model
Eight points to take into Chapter 13:
- A 70B model needs ~1 TB of GPU memory in mixed precision. This is the entire reason distributed training exists.
- Four parallelism axes: DP (replicate, shard data), TP (split matmuls), PP (split layers), SP (split sequence).
- DDP is the simple case. All-reduce gradients, every GPU has the same model.
- ZeRO shards the optimizer/gradient/parameter state across data-parallel ranks. Stage 1, 2, 3 are progressively more aggressive.
- FSDP is PyTorch’s first-class ZeRO-3. It’s the modern default for medium-large model training.
- Tensor parallelism uses column-then-row matmul splits. One all-reduce per sub-block. Demands fast intra-node interconnect.
- Pipeline parallelism splits layers and uses micro-batching to hide the bubble. Lower communication, works across nodes.
- 3D parallelism = DP × TP × PP, often with FSDP and gradient checkpointing on top. This is what frontier runs actually use.
In Chapter 13 we look at the precision side of the same problem: how mixed-precision training enables all of the above.
Read it yourself
- Rajbhandari et al., ZeRO: Memory Optimizations Toward Training Trillion Parameter Models (2020). The ZeRO paper. Read sections 4 and 5.
- Shoeybi et al., Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism (2019). The TP paper.
- Huang et al., GPipe (2018) and Narayanan et al., PipeDream (2019) for pipeline parallelism schedules.
- Liu et al., Ring Attention with Blockwise Transformers for Near-Infinite Context (2023). The sequence-parallelism story for 1M+ context.
- The PyTorch FSDP tutorial and the `torch.distributed` documentation. Read the FSDP API page in full.
- The Megatron-LM repository on GitHub. The README and the `megatron/core/parallel_state.py` file together explain the production parallelism setup better than any paper.
Practice
- Compute the per-GPU memory of a 13B-parameter dense model with AdamW in mixed precision (a) on a single GPU, (b) on 8-way ZeRO-3, (c) on 8-way ZeRO-3 plus TP=2.
- Why does plain DDP not scale to a 70B model? Compute the all-reduce time per step for the gradients alone, on a 200 GB/s interconnect.
- The Megatron-style FFN uses column parallelism on the first matmul and row parallelism on the second. Sketch why this requires only one all-reduce per FFN, not two.
- A pipeline-parallel run with PP=4 and 32 micro-batches per global batch has what fraction of pipeline-bubble overhead? (Use the GPipe formula: `bubble fraction ≈ (PP - 1) / (PP + microbatches - 1)`.)
- Why does ZeRO-3 demand more communication than DDP? Where does the extra all-gather come from, and how is it hidden behind compute?
- Read the FSDP `device_mesh` API. Configure (in pseudocode) a 64-GPU run with TP=4 and FSDP=16, on a cluster with 8 H100s per node.
- Stretch: take a small open transformer (e.g., GPT-2 small), wrap it with FSDP, run training on a 4-GPU box, and compare wall-clock per step vs. single-GPU training. Where does the speedup go vs. the comm cost?
Concept check
- 1. ZeRO Stage 3 shards model parameters across all data-parallel ranks in addition to optimizer state and gradients. What is the main cost of this compared to Stage 1 or 2?
- 2. Tensor parallelism splits individual matrix multiplications across GPUs. Which communication primitive is required between the split operations?
- 3. Pipeline parallelism assigns contiguous groups of transformer layers to different GPUs. What is the main efficiency problem it introduces?
- 4. For a 70B AdamW model with bf16 weights and fp32 optimizer state, what is the minimum per-GPU memory needed to hold just the optimizer state when training with ZeRO Stage 1 across 64 GPUs?