01 · First principles
What actually fills GPU memory
Before fixing a memory problem, do the accounting. Training state per parameter, under the standard recipe of mixed precision with Adam, looks like this:
| Tensor | Precision | Bytes / param |
|---|---|---|
| Parameters (working copy) | bf16 / fp16 | 2 |
| Gradients | bf16 / fp16 | 2 |
| Master weights | fp32 | 4 |
| Adam first moment (m) | fp32 | 4 |
| Adam second moment (v) | fp32 | 4 |
| Total model state | | 16 |
So a 7B model carries about 112 GB of state and a 70B model about 1.1 TB, before a single activation is stored. Note where the bulk sits: the fp32 optimizer triplet (master weights, m, v) accounts for 12 of the 16 bytes. The optimizer, not the model, owns most of the memory.
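The accounting above reduces to a one-line multiply. A minimal sketch, assuming decimal gigabytes (1 GB = 10^9 bytes) and the byte counts from the table; `model_state_bytes` and `optimizer_share` are hypothetical helper names, not from any library:

```python
# Bytes per parameter under mixed precision + Adam (from the accounting table).
BYTES_PER_PARAM = {
    "params_bf16": 2,   # working copy
    "grads_bf16": 2,
    "master_fp32": 4,
    "adam_m_fp32": 4,   # first moment
    "adam_v_fp32": 4,   # second moment
}
OPTIMIZER_KEYS = ("master_fp32", "adam_m_fp32", "adam_v_fp32")


def model_state_bytes(n_params: int) -> int:
    """Training-state bytes for n_params parameters, excluding activations."""
    return n_params * sum(BYTES_PER_PARAM.values())


def optimizer_share() -> float:
    """Fraction of the state owned by the fp32 optimizer triplet (master, m, v)."""
    opt = sum(BYTES_PER_PARAM[k] for k in OPTIMIZER_KEYS)
    return opt / sum(BYTES_PER_PARAM.values())


if __name__ == "__main__":
    for n_params in (7_000_000_000, 70_000_000_000):
        gb = model_state_bytes(n_params) / 1e9  # decimal GB
        print(f"{n_params / 1e9:.0f}B params -> {gb:,.0f} GB of state")
    print(f"optimizer share: {optimizer_share():.0%}")
```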
02 · Failure first
DDP replicates all of it
DDP keeps a full copy of those 16 bytes per parameter on every rank. On a 64-GPU cluster that means 64 identical copies of the optimizer state, of the master weights, of everything. The redundancy buys nothing: after the gradient allreduce every rank holds the same averaged gradient, so every rank applies the same update to the same state and ends up with identical results.
The observation behind ZeRO: across N data-parallel ranks, the training state is N-fold redundant. Eliminating that redundancy divides state memory by N without changing the math at all.
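Why sharding leaves the math unchanged: every Adam operation is elementwise, so updating N slices independently and concatenating them gives exactly the same numbers as updating the full vector. A pure-Python sketch on plain lists; `adam_step`, `shard`, and `sharded_adam_step` are hypothetical helpers, and the hyperparameters are the usual Adam defaults, chosen only for illustration:

```python
import math


def adam_step(p, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update on plain lists. Every operation is elementwise."""
    new_p, new_m, new_v = [], [], []
    for pi, gi, mi, vi in zip(p, g, m, v):
        mi = b1 * mi + (1 - b1) * gi           # first moment
        vi = b2 * vi + (1 - b2) * gi * gi      # second moment
        m_hat = mi / (1 - b1 ** t)             # bias corrections
        v_hat = vi / (1 - b2 ** t)
        new_p.append(pi - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(mi)
        new_v.append(vi)
    return new_p, new_m, new_v


def shard(xs, n_ranks):
    """Split a flat list into n_ranks equal contiguous slices."""
    size = len(xs) // n_ranks
    return [xs[r * size:(r + 1) * size] for r in range(n_ranks)]


def sharded_adam_step(p, g, m, v, t, n_ranks):
    """Each simulated rank updates only its own slice; results are concatenated."""
    out_p, out_m, out_v = [], [], []
    slices = zip(shard(p, n_ranks), shard(g, n_ranks),
                 shard(m, n_ranks), shard(v, n_ranks))
    for ps, gs, ms, vs in slices:
        sp, sm, sv = adam_step(ps, gs, ms, vs, t)
        out_p.extend(sp)
        out_m.extend(sm)
        out_v.extend(sv)
    return out_p, out_m, out_v


if __name__ == "__main__":
    p = [0.1 * i for i in range(8)]
    g = [0.5 - 0.1 * i for i in range(8)]
    zeros = [0.0] * 8
    full = adam_step(p, g, zeros, zeros, t=1)
    sharded = sharded_adam_step(p, g, zeros, zeros, t=1, n_ranks=4)
    assert full == sharded  # tuples of lists compare elementwise
```

The equality is exact, not approximate, because each element sees the same floating-point operations in the same order. An operation that couples elements across slices, such as global-norm gradient clipping, would need one extra small reduction across ranks.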
03 · The mechanism
Shard everything, gather just in time
FSDP (Fully Sharded Data Parallel, PyTorch's implementation of ZeRO stage 3) gives each rank a 1/N slice of every parameter, every gradient, and every optimizer state. No rank holds anything in full, except briefly and one layer at a time:
Forward, per layer: allgather that layer's parameter shards into the full layer, run the forward, then free the full parameters immediately.
Backward, per layer: allgather the parameters again (they were freed after the forward), compute gradients, then reduce-scatter the gradients so that each rank receives only the already-summed 1/N slice it owns. Free the rest.
Optimizer step: each rank updates only its own slice, using only its own slice of m, v, and master weights. No communication at all.
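The three steps can be simulated on one machine, with Python lists standing in for ranks and plain functions standing in for the collectives. This is a toy sketch, not PyTorch's FSDP: the layers are independent (each rank's loss per layer is 0.5·‖w − x‖²), so the forward has nothing to compute and the example exercises only the communication schedule and the sharded update. All names (`fsdp_step`, `ddp_step`, `allgather`, `reduce_scatter`, `local_grad`) are hypothetical. It checks that the sharded step produces exactly the weights a DDP-style step would:

```python
N = 4  # simulated data-parallel ranks
L = 3  # layers, each wrapped as its own FSDP-style unit
D = 8  # parameters per layer (must be divisible by N)


def shard(xs, n):
    """Split a flat list into n equal contiguous slices."""
    size = len(xs) // n
    return [xs[r * size:(r + 1) * size] for r in range(n)]


def allgather(shards):
    """Every rank receives the concatenation of all ranks' shards."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]


def reduce_scatter(per_rank_vectors):
    """Sum the ranks' full vectors elementwise; rank r keeps only slice r of the sum."""
    summed = [sum(col) for col in zip(*per_rank_vectors)]
    return shard(summed, len(per_rank_vectors))


def local_grad(w, x):
    """Toy per-rank loss 0.5 * ||w - x||^2, whose gradient is w - x."""
    return [wi - xi for wi, xi in zip(w, x)]


def fsdp_step(shards, data, lr, log):
    """shards[l][r] is rank r's slice of layer l; data[r][l] is rank r's target for layer l."""
    n = len(data)
    # Forward: gather each layer just in time, run it, free it.
    for l in range(len(shards)):
        full = allgather(shards[l])
        log.append(("allgather", "fwd", l))
        # (the toy forward produces nothing; a real one would run the layer here)
        del full
    # Backward, last layer first: gather again, compute grads, reduce-scatter.
    grad_shards = [None] * len(shards)
    for l in reversed(range(len(shards))):
        full = allgather(shards[l])
        log.append(("allgather", "bwd", l))
        grads = [local_grad(full[r], data[r][l]) for r in range(n)]
        grad_shards[l] = reduce_scatter(grads)
        log.append(("reduce_scatter", "bwd", l))
        del full, grads
    # Optimizer step (plain SGD here): purely local, averaging the summed grad over n.
    return [
        [[w - lr * g / n for w, g in zip(shards[l][r], grad_shards[l][r])]
         for r in range(n)]
        for l in range(len(shards))
    ]


def ddp_step(full_params, data, lr):
    """Reference: every rank holds every layer; gradients are allreduced (summed)."""
    n = len(data)
    new = []
    for l, w in enumerate(full_params):
        grads = [local_grad(w, data[r][l]) for r in range(n)]
        summed = [sum(col) for col in zip(*grads)]
        new.append([wi - lr * g / n for wi, g in zip(w, summed)])
    return new


if __name__ == "__main__":
    full = [[float(l * D + i) for i in range(D)] for l in range(L)]
    shards = [shard(layer, N) for layer in full]
    data = [[[float(r - i) for i in range(D)] for _ in range(L)] for r in range(N)]
    log = []
    new_shards = fsdp_step(shards, data, 0.1, log)
    assert [sum(s, []) for s in new_shards] == ddp_step(full, data, 0.1)
    print(log)
```

The log shows the schedule the list describes: one allgather per layer in forward, then an allgather and a reduce-scatter per layer in reverse order during backward, and nothing at all during the step.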
The reason this costs so little extra communication is an identity from the primitives note: allreduce is reduce-scatter followed by allgather. DDP was already paying for both halves inside its allreduce; FSDP simply pulls the halves apart and does the optimizer step in between, on sharded data. Gradient synchronization therefore costs no more than it did under DDP: the reduce-scatter is exactly the first half of the allreduce DDP already ran. The genuinely new cost is on the parameter side: the allgather in forward takes the place of the allreduce's second half, and the re-gather in backward is pure extra.
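The identity is easy to check directly. A minimal sketch with lists as ranks (hypothetical helpers, not a real collective library): reduce-scatter followed by allgather reproduces allreduce, and the gap between the two calls is where the sharded optimizer step runs. `ring_send_fraction` encodes the standard ring-algorithm cost, in which each rank sends (N−1)/N of the buffer in each half:

```python
def shard(xs, n):
    """Split a flat list into n equal contiguous slices."""
    size = len(xs) // n
    return [xs[r * size:(r + 1) * size] for r in range(n)]


def allreduce(per_rank):
    """Every rank ends with the elementwise sum of all ranks' vectors."""
    summed = [sum(col) for col in zip(*per_rank)]
    return [list(summed) for _ in per_rank]


def reduce_scatter(per_rank):
    """Elementwise sum, but rank r keeps only slice r of it."""
    summed = [sum(col) for col in zip(*per_rank)]
    return shard(summed, len(per_rank))


def allgather(shards):
    """Every rank receives the concatenation of all slices."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]


def ring_send_fraction(n_ranks):
    """Fraction of the full buffer each rank sends in a ring reduce-scatter or allgather."""
    return (n_ranks - 1) / n_ranks


if __name__ == "__main__":
    grads = [[1, 2, 3, 4], [10, 20, 30, 40]]  # two ranks, four elements each
    halves = reduce_scatter(grads)            # rank 0 owns [11, 22], rank 1 owns [33, 44]
    # ... a sharded optimizer step would update `halves` here ...
    assert allgather(halves) == allreduce(grads)
    print("allreduce send volume per rank, in buffers:", 2 * ring_send_fraction(len(grads)))
```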
Each ZeRO stage shards one more component. Stage 3 divides all 16 bytes/param by N; activations are not sharded and are usually handled separately, with activation checkpointing.
04 · The map
ZeRO stages 1 / 2 / 3
| Stage | What is sharded | State bytes / param / rank | Extra comms vs DDP |
|---|---|---|---|
| ZeRO-1 | Optimizer states (master weights, m, v) | 4 + 12/N | ~none |
| ZeRO-2 | + gradients | 2 + 14/N | ~none (reduce-scatter replaces allreduce) |
| ZeRO-3 / FSDP | + parameters | 16/N | parameter allgather in forward and again in backward (~1.5× DDP volume) |
Stages 1 and 2 are close to free and are almost always worth taking. Stage 3 is the one with a real bill attached.
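The per-rank column of the table follows from sharding the 2 / 2 / 12 byte split one component at a time. A sketch using exact fractions; `state_bytes_per_param` is a hypothetical helper, and stage 0 stands for plain DDP:

```python
from fractions import Fraction


def state_bytes_per_param(stage: int, n_ranks: int) -> Fraction:
    """Per-rank model-state bytes per parameter; stage 0 means plain DDP.

    Uses the split from the accounting table: 2 B bf16 params, 2 B bf16 grads,
    12 B fp32 optimizer state (master weights + m + v).
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2, or 3")
    params, grads, opt = Fraction(2), Fraction(2), Fraction(12)
    if stage >= 1:
        opt /= n_ranks      # ZeRO-1: shard optimizer state
    if stage >= 2:
        grads /= n_ranks    # ZeRO-2: also shard gradients
    if stage >= 3:
        params /= n_ranks   # ZeRO-3 / FSDP: also shard parameters
    return params + grads + opt


if __name__ == "__main__":
    n_params, n_ranks = 7_000_000_000, 64
    for stage in range(4):
        per_param = state_bytes_per_param(stage, n_ranks)
        gb = float(per_param * n_params) / 1e9
        print(f"stage {stage}: {float(per_param):.4f} B/param -> {gb:.2f} GB per rank")
```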
05 · The cost
What you pay for stage 3
More communication volume. Roughly 1.5× DDP per step, and it sits on the critical path of the forward as well as the backward. FSDP prefetches the next layer's allgather under the current layer's compute, but on slow interconnects the gathers take longer than the compute and can no longer be hidden.
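Where the 1.5× comes from, counting per-rank traffic in units of Ψ (the parameter count times bytes per element) and dropping the (N−1)/N ring factor that both sides share. This is the same accounting the ZeRO paper uses for its stage-3 estimate:

```latex
\begin{aligned}
V_{\text{DDP}}    &= \underbrace{\Psi}_{\text{reduce-scatter}} + \underbrace{\Psi}_{\text{allgather}} = 2\Psi \\
V_{\text{ZeRO-3}} &= \underbrace{\Psi}_{\text{grad reduce-scatter}} + \underbrace{\Psi}_{\text{fwd allgather}} + \underbrace{\Psi}_{\text{bwd allgather}} = 3\Psi \\
\frac{V_{\text{ZeRO-3}}}{V_{\text{DDP}}} &= \tfrac{3}{2}
\end{aligned}
```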
Latency sensitivity. DDP issues a few large bucketed allreduces; FSDP issues an allgather per wrapped unit in forward, another in backward, and a reduce-scatter per unit, so many smaller, latency-bound collectives. Wrapping granularity (one FSDP unit per transformer block is the usual choice) is a real tuning knob: coarse units gather more parameters at once and raise peak memory, while fine units multiply the per-collective latency.
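A simple alpha-beta cost model makes the granularity trade-off concrete: each collective pays a fixed latency α, and each byte pays 1/bandwidth. This is a sketch under stated assumptions, not a measurement: every unit is assumed to reshard after forward (so it is gathered twice per step), units are assumed equal in size, there is no overlap with compute, and `alpha_s`, `bw_bytes_per_s`, and all function names are hypothetical inputs and helpers:

```python
def collectives_per_step(n_units: int) -> int:
    """Forward allgather + backward allgather + gradient reduce-scatter, per unit."""
    return 3 * n_units


def comm_time_s(n_units, param_bytes, alpha_s, bw_bytes_per_s, n_ranks):
    """Alpha-beta estimate of per-step communication time (no overlap with compute).

    alpha_s: fixed latency per collective; bw_bytes_per_s: per-rank link bandwidth.
    Both are hypothetical inputs, not measured values.
    """
    latency = collectives_per_step(n_units) * alpha_s
    # Three full passes over the parameter bytes, each sending (N-1)/N per rank (ring).
    volume = 3 * (n_ranks - 1) / n_ranks * param_bytes
    return latency + volume / bw_bytes_per_s


def gathered_bytes_per_unit(param_bytes, n_units):
    """Full parameter bytes materialized at once when one (equal-sized) unit is gathered."""
    return param_bytes / n_units


if __name__ == "__main__":
    # Hypothetical setup: 14e9 bytes of bf16 params, 64 ranks, 50 us latency, 100 GB/s.
    for units in (32, 320):
        t = comm_time_s(units, 14e9, 50e-6, 100e9, 64)
        peak = gathered_bytes_per_unit(14e9, units) / 1e9
        print(f"{units:>4} units: {collectives_per_step(units)} collectives, "
              f"~{t:.3f} s comm, {peak:.2f} GB gathered per unit")
```

In this model the bandwidth term does not depend on the number of units at all; only the latency term grows with finer wrapping, while the bytes gathered per unit shrink. Real FSDP overlaps much of this with compute, so treat the totals as a rough scale, not a prediction.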
Peak memory is per-layer, not zero. One block's full parameters plus its activations must still fit; a single enormous layer can break stage 3, at which point you need tensor parallelism inside the layer.
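A rough peak-memory check makes the per-layer floor visible: the sharded term shrinks with N, but the largest unit's full parameters and gradients do not. A sketch under loud assumptions: bf16 parameters and gradients, one prefetched unit as large as the largest unit, and activations passed in as a number rather than modeled; `peak_bytes_per_rank` is a hypothetical helper:

```python
def peak_bytes_per_rank(total_params, largest_unit_params, n_ranks,
                        activation_bytes=0.0, prefetch_units=1):
    """Rough stage-3 peak for one rank (a sketch, not PyTorch's accounting).

    Assumes: 16 B/param of state sharded N ways; during backward the current
    unit holds full bf16 params (2 B) and full bf16 grads (2 B) until its
    reduce-scatter; each prefetched unit adds full bf16 params and is assumed
    to be as large as the largest unit; activations are supplied by the caller.
    """
    sharded_state = 16 * total_params / n_ranks
    working_unit = (2 + 2) * largest_unit_params
    prefetched = 2 * largest_unit_params * prefetch_units
    return sharded_state + working_unit + prefetched + activation_bytes


if __name__ == "__main__":
    # Hypothetical: 70B params whose largest unit has 1B params (e.g. a big embedding).
    for n in (64, 1024, 1_000_000):
        gb = peak_bytes_per_rank(70e9, 1e9, n) / 1e9
        print(f"N={n:>9,}: ~{gb:.1f} GB before activations")
```

However many ranks are added, the estimate never drops below about six bytes per parameter of the largest unit; once that floor alone exceeds device memory, more data parallelism cannot help.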
Rule of thumb: FSDP keeps DDP's programming model (same loss, same optimizer semantics) and trades bandwidth for memory. Use it when state memory is the binding constraint and the interconnect is decent; use TP/PP when the interconnect, or a single layer, is the constraint.
Mental Model
Training state ≈ 16 bytes/param under Adam mixed precision, and 12 of those bytes belong to the optimizer.
DDP stores N identical copies of that state; ZeRO deletes the redundancy, dividing it by N.
Gradient sync is free to shard because allreduce = reduce-scatter + allgather; ZeRO just stops reassembling.
Stage 3's new cost is gathering parameters just-in-time, every forward and every backward — bandwidth and latency bought memory.
FSDP shards state across ranks and rebuilds it one layer at a time; if a single layer is too large to rebuild, or its computation is the bottleneck, that is tensor parallelism's job.