Applied ML

FSDP / ZeRO

Shard the training state, not just the batch

01 · First principles · What actually fills GPU memory

Before fixing a memory problem, do the accounting. Training state per parameter, under the standard recipe of mixed precision with Adam, looks like this:

| Tensor | Precision | Bytes / param |
| --- | --- | --- |
| Parameters (working copy) | bf16 / fp16 | 2 |
| Gradients | bf16 / fp16 | 2 |
| Master weights | fp32 | 4 |
| Adam first moment (m) | fp32 | 4 |
| Adam second moment (v) | fp32 | 4 |
| Total model state | | 16 |

So a 7B model carries about 112 GB of state, and a 70B model about 1.1 TB, before a single activation is stored. Note where the bulk sits: the fp32 optimizer triplet (master weights, m, v) accounts for 12 of the 16 bytes. The optimizer, not the model, owns most of the memory.
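The accounting is simple enough to script. A minimal sketch, assuming decimal gigabytes (1 GB = 1e9 bytes) and the per-tensor byte counts from the table; the dictionary and function names are illustrative, not from any library:

```python
# Bytes of training state per parameter under mixed precision + Adam,
# copied from the table above.
BYTES_PER_PARAM = {
    "params (bf16/fp16)": 2,
    "grads (bf16/fp16)": 2,
    "master weights (fp32)": 4,
    "adam m (fp32)": 4,
    "adam v (fp32)": 4,
}

def state_bytes(n_params: float) -> float:
    """Total model-state bytes for n_params parameters, activations excluded."""
    return n_params * sum(BYTES_PER_PARAM.values())

def optimizer_fraction() -> float:
    """Share of the per-parameter bytes owned by the fp32 optimizer triplet."""
    fp32 = sum(v for k, v in BYTES_PER_PARAM.items() if "fp32" in k)
    return fp32 / sum(BYTES_PER_PARAM.values())

for name, n in [("7B", 7e9), ("70B", 70e9)]:
    print(f"{name}: {state_bytes(n) / 1e9:.0f} GB")
print(f"optimizer share: {optimizer_fraction():.0%}")
```

This prints 112 GB for 7B, 1120 GB for 70B, and a 75% optimizer share, matching the figures in the text.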

02 · Failure first · DDP replicates all of it

DDP keeps a full copy of those 16 bytes per parameter on every rank. On a 64-GPU cluster that is 64 identical copies of the optimizer state, of the master weights, of everything. The redundancy buys nothing: after the gradient allreduce, every rank holds the same gradients and the same state and applies the identical update, so 63 of the 64 copies only duplicate work.

The observation behind ZeRO: across N data-parallel ranks, the training state is N-fold redundant. Eliminating that redundancy divides state memory by N without changing the math at all.
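"Without changing the math" holds because Adam is elementwise: each parameter's update depends only on that parameter's own gradient, m, and v. A minimal numpy check, assuming textbook Adam with common default hyperparameters and float64 arrays standing in for the fp32 master weights (the `adam_update` helper is hypothetical, not a framework API):

```python
import numpy as np

def adam_update(w, g, m, v, step, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One textbook Adam step. Purely elementwise, so it can run on any slice."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** step)
    v_hat = v / (1 - b2 ** step)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v

rng = np.random.default_rng(0)
P, N = 1000, 8                      # parameter count, number of ranks
w, g = rng.standard_normal(P), rng.standard_normal(P)
m, v = np.zeros(P), np.zeros(P)

# Replicated (DDP-style): every rank runs this same full update.
w_full, m_full, v_full = adam_update(w, g, m, v, step=1)

# Sharded (ZeRO-style): rank r updates only its 1/N slice of w, m, v.
slices = np.array_split(np.arange(P), N)
results = [adam_update(w[s], g[s], m[s], v[s], step=1) for s in slices]
w_shard = np.concatenate([r[0] for r in results])
m_shard = np.concatenate([r[1] for r in results])
v_shard = np.concatenate([r[2] for r in results])

print(np.allclose(w_full, w_shard))  # True
```

Each rank stores 1/N of w, m, and v, yet the concatenated result is the same update the replicated ranks would have computed.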

03 · The mechanismShard everything, gather just in time

FSDP (PyTorch's implementation of ZeRO stage 3) gives each rank a 1/N slice of every parameter, every gradient, and every optimizer state. Nothing exists in full anywhere — except briefly, one layer at a time:

  1. Forward, per layer: allgather that layer's parameter shards into a full layer, run the forward, free the full parameters immediately.
  2. Backward, per layer: allgather the parameters again, compute gradients, then reduce-scatter the gradients — each rank receives only the (already summed) 1/N slice it owns. Free the rest.
  3. Optimizer step: each rank updates only its own slice, using only its own slice of m, v, and master weights. No communication at all.
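The three steps can be simulated in one process. This is a sketch under loud assumptions: the N "ranks" are list entries, `allgather` and `reduce_scatter` are hypothetical in-memory stand-ins for the real collectives, the model is a single linear layer with a squared-error loss, and plain SGD replaces Adam for brevity. The final check is that the sharded update equals a DDP-style update on a full copy.

```python
import numpy as np

# Toy setup (hypothetical): N ranks, one linear layer with D weights,
# each rank holding its own micro-batch of B examples.
N, D, B = 4, 12, 5
lr = 0.1
rng = np.random.default_rng(0)
W = rng.standard_normal(D)
X = [rng.standard_normal((B, D)) for _ in range(N)]
Y = [rng.standard_normal(B) for _ in range(N)]

def allgather(shards):
    """In-memory stand-in: every rank receives the concatenation of all shards."""
    full = np.concatenate(shards)
    return [full.copy() for _ in shards]

def reduce_scatter(tensors):
    """In-memory stand-in: sum across ranks, rank r keeps only slice r of the sum."""
    total = np.sum(tensors, axis=0)
    return np.array_split(total, len(tensors))

def local_grad(w, x, y):
    """Gradient of mean((x @ w - y)**2) w.r.t. w (recomputes the forward)."""
    return 2 * x.T @ (x @ w - y) / len(y)

# Each rank persistently stores only its 1/N shard of W.
shards = np.array_split(W.copy(), N)

# 1. Forward: gather the full layer, run it, free the gathered copy.
full = allgather(shards)
losses = [np.mean((x @ w - y) ** 2) for w, x, y in zip(full, X, Y)]
del full

# 2. Backward: gather again, compute local grads, reduce-scatter, free.
full = allgather(shards)
grads = [local_grad(w, x, y) for w, x, y in zip(full, X, Y)]
del full
grad_shards = [gs / N for gs in reduce_scatter(grads)]  # summed slice -> mean, as DDP averages

# 3. Optimizer step: each rank touches only its own shard, no communication.
shards = [s - lr * gs for s, gs in zip(shards, grad_shards)]

# DDP-style reference: average full gradients, update a full copy.
W_ddp = W - lr * np.mean([local_grad(W, x, y) for x, y in zip(X, Y)], axis=0)

print(np.allclose(np.concatenate(shards), W_ddp))  # True
```

Outside the two `allgather` windows, no rank ever holds more than D/N weights, which is the whole memory argument.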

The reason this costs so little extra communication is an identity from the primitives note: allreduce is reduce-scatter followed by allgather. DDP was already paying for both halves inside its allreduce; FSDP pulls the halves apart and does the optimizer step in between, on sharded data. The gradient reduce-scatter therefore costs nothing beyond what DDP already paid (it is half of DDP's allreduce). The genuinely new cost is that parameters are allgathered twice per step, once in forward and again in backward, where DDP's allreduce contained only one allgather's worth of traffic.
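The identity is easy to confirm with in-memory stand-ins for the three collectives. The helpers below are hypothetical and model only the result each rank ends up holding, not how the data moves:

```python
import numpy as np

def reduce_scatter(xs):
    """Sum across ranks; rank r keeps slice r of the sum."""
    total = np.sum(xs, axis=0)
    return np.array_split(total, len(xs))

def allgather(shards):
    """Every rank receives the concatenation of all shards."""
    full = np.concatenate(shards)
    return [full.copy() for _ in shards]

def allreduce(xs):
    """Every rank receives the full sum."""
    total = np.sum(xs, axis=0)
    return [total.copy() for _ in xs]

rng = np.random.default_rng(1)
xs = [rng.standard_normal(10) for _ in range(4)]   # one gradient per rank

composed = allgather(reduce_scatter(xs))
print(all(np.array_equal(a, b) for a, b in zip(allreduce(xs), composed)))  # True
```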

Figure: Per-GPU state memory for 7B params, N = 8 ranks (bytes/param: params 2, grads 2, optimizer 12 = fp32 master + m + v; activations not shown). DDP 112 GB · ZeRO-1 ~38 GB (shard optimizer) · ZeRO-2 ~26 GB (+ shard gradients) · ZeRO-3 / FSDP 14 GB (+ shard params). Under ZeRO-3, state memory per rank = 16 bytes × params / N.

Each ZeRO stage shards one more component. Stage 3 divides all 16 bytes/param by N; activations are not part of this state and are handled separately, with activation checkpointing.
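The bar heights follow directly from the byte counts. A small calculator sketch, assuming decimal GB and using a hypothetical `stage` argument where 0 means plain DDP:

```python
def state_gb_per_rank(n_params: float, n_ranks: int, stage: int) -> float:
    """Per-rank model-state memory in GB (1e9 bytes), activations excluded.

    stage 0 = DDP, 1 = shard optimizer, 2 = + gradients, 3 = + parameters.
    """
    p, g, o = 2, 2, 12   # bytes/param: bf16 params, bf16 grads, fp32 master + m + v
    if stage >= 1:
        o /= n_ranks
    if stage >= 2:
        g /= n_ranks
    if stage >= 3:
        p /= n_ranks
    return n_params * (p + g + o) / 1e9

for stage, label in [(0, "DDP"), (1, "ZeRO-1"), (2, "ZeRO-2"), (3, "ZeRO-3")]:
    print(f"{label}: {state_gb_per_rank(7e9, 8, stage):.2f} GB")
```

For 7B parameters on 8 ranks this gives 112, 38.5, 26.25, and 14 GB, the values drawn in the figure.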

04 · The map · ZeRO stages 1 / 2 / 3

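| Stage | What is sharded | State bytes / param / rank | Extra comms vs DDP |
| --- | --- | --- | --- |
| ZeRO-1 | Optimizer states | 4 + 12/N | ~none |
| ZeRO-2 | + gradients | 2 + 14/N | ~none (gradient reduce-scatter + parameter allgather replace the allreduce) |
| ZeRO-3 / FSDP | + parameters | 16/N | parameters allgathered in forward and again in backward: one extra allgather per step (roughly 1.5× DDP volume) |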

Stages 1 and 2 are close to free and are almost always worth taking. Stage 3 is the one with a real bill attached.
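The "1.5× DDP" figure can be checked with back-of-the-envelope arithmetic. A sketch assuming ring-style collectives, where one reduce-scatter or one allgather sends (N − 1)/N of a parameter-sized tensor per rank; `comm_volume` is a hypothetical helper, and real implementations may overlap or bucket traffic differently:

```python
def comm_volume(stage: int, n_ranks: int) -> float:
    """Per-rank send volume per step, in units of one parameter-sized tensor.

    Assumes ring algorithms: one reduce-scatter or one allgather costs (N-1)/N.
    stage 0 = DDP.
    """
    unit = (n_ranks - 1) / n_ranks
    if stage <= 2:
        # DDP: one allreduce (= reduce-scatter + allgather).
        # ZeRO-1/2: gradient reduce-scatter + allgather of updated params.
        return 2 * unit
    # ZeRO-3: gradient reduce-scatter + param allgather in forward + in backward.
    return 3 * unit

for stage in range(4):
    print(stage, comm_volume(stage, 8) / comm_volume(0, 8))
```

Stages 0 through 2 come out at 1.0× and stage 3 at 1.5×, which is why only stage 3 carries a real bill.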

05 · The cost · What you pay for stage 3

Rule of thumb: FSDP keeps DDP's programming model (same loss, same optimizer semantics) and trades bandwidth for memory. Use it when state memory is the binding constraint and the interconnect is decent; use tensor or pipeline parallelism (TP/PP) when the interconnect is the constraint, or when a single layer is too large to gather onto one GPU.
Mental Model