Sharding divides model state by N, but there is a second memory consumer it never sees: activations. Backward needs the inputs of every layer to compute that layer's weight gradients, so the default autograd contract is brutal — everything the forward pass produced stays resident until backward consumes it.
Activation memory scales as layers × batch × sequence length × hidden width. For a transformer it is roughly tens of bytes per token per layer per unit of hidden width, even in bf16; at long sequence lengths or deep stacks the total routinely exceeds the model state's 16 bytes/param. Model state is fixed at startup; activations are what actually produce the mid-training OOM when someone doubles the context length.
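To make the scaling concrete, here is a back-of-envelope estimator. It is only a sketch: the constant of 34 bytes per token, per layer, per unit of hidden width is an assumed figure for a standard 16-bit transformer block (attention-score matrices ignored), the 12 × hidden² parameters per block is likewise an approximation, and the model shape is hypothetical.

```python
def activation_bytes(layers, batch, seq_len, hidden, bytes_per_unit=34):
    """Rough activation footprint when every layer's intermediates are kept.

    bytes_per_unit is an ASSUMED constant: bytes stored per token, per layer,
    per unit of hidden width. Attention-score matrices are ignored.
    """
    return layers * batch * seq_len * hidden * bytes_per_unit


def model_state_bytes(layers, hidden, bytes_per_param=16):
    """Model state at 16 bytes/param, assuming ~12 * hidden^2 params per block."""
    params = 12 * layers * hidden * hidden
    return params * bytes_per_param


if __name__ == "__main__":
    L, h = 32, 4096  # hypothetical model shape
    state = model_state_bytes(L, h)
    for seq in (2048, 8192, 32768):
        act = activation_bytes(L, batch=1, seq_len=seq, hidden=h)
        print(f"seq={seq:6d}  activations={act / 2**30:7.1f} GiB  "
              f"state={state / 2**30:6.1f} GiB")
```

The state column does not move as the sequence length grows, while the activation column grows linearly with it, which is the asymmetry the paragraph above describes.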
02 · Failure first: The naive alternatives both lose
Store everything (default)
Memory grows as O(L) in depth. At L = 100 layers and 8k context, activations alone can run to hundreds of GB per device. Forward is paid once; memory is the casualty.
Store nothing, recompute from input
O(1) memory, but reaching layer k's activations during backward means rerunning layers 1..k. Summed over all layers that is O(L²) compute — a 100-layer model pays roughly 50 extra forwards.
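The quadratic bill is easy to check by counting. The helper below is a hypothetical illustration that follows the accounting in the text: reaching layer k during backward costs k layer evaluations.

```python
def store_nothing_recompute(num_layers):
    """Recompute cost when no activation is kept.

    Returns (layer evaluations, equivalent number of full forward passes).
    Reaching layer k during backward reruns layers 1..k, so the total is
    1 + 2 + ... + L = L * (L + 1) / 2 layer evaluations.
    """
    layer_evals = sum(range(1, num_layers + 1))
    return layer_evals, layer_evals / num_layers


if __name__ == "__main__":
    evals, forwards = store_nothing_recompute(100)
    print(evals, forwards)  # 5050 layer evaluations, 50.5 forward passes
```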
One end of the spectrum is unaffordable in memory, the other in FLOPs. Checkpointing is the observation that the spectrum has a usable middle.
03 · The mechanism: Checkpoints and segments
Forward: keep activations only at chosen checkpoint layers; everything between two checkpoints is computed and immediately discarded.
Backward, per segment (last to first): rerun the forward from the segment's checkpoint to repopulate its activations, then run backward through the segment and free them again.
Peak activation memory = the checkpoints + one segment's worth of live activations.
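The three steps above can be sketched on a toy chain of scalar layers. This is an illustration only, not how any framework implements it: the layer `tanh(w * x)`, the function names, and the way "memory" is counted (number of stored values) are all assumptions made for the example.

```python
import math


def layer_forward(w, x):
    return math.tanh(w * x)


def layer_backward(w, x, grad_out):
    """Return (grad_w, grad_x). Needs the layer's INPUT x, hence the storage."""
    d = 1.0 - math.tanh(w * x) ** 2
    return grad_out * d * x, grad_out * d * w


def backward_store_all(weights, x0):
    """Default contract: every layer input stays resident until backward."""
    inputs, x = [], x0
    for w in weights:
        inputs.append(x)
        x = layer_forward(w, x)
    grad, grads = 1.0, [0.0] * len(weights)
    for i in reversed(range(len(weights))):
        grads[i], grad = layer_backward(weights[i], inputs[i], grad)
    return grads, len(inputs)


def backward_checkpointed(weights, x0, seg):
    """Keep one checkpoint per segment; rebuild each segment during backward."""
    n = len(weights)
    checkpoints, x = {}, x0
    for i, w in enumerate(weights):
        if i % seg == 0:
            checkpoints[i] = x      # keep only the segment's input
        x = layer_forward(w, x)     # everything else is discarded
    grad, grads, peak = 1.0, [0.0] * n, 0
    for start in sorted(checkpoints, reverse=True):  # last segment first
        end = min(start + seg, n)
        live, x = [], checkpoints[start]
        for i in range(start, end):                  # rerun the forward
            live.append(x)
            x = layer_forward(weights[i], x)
        # The segment's first input is counted as both checkpoint and live
        # value, so this slightly overcounts.
        peak = max(peak, len(checkpoints) + len(live))
        for i in reversed(range(start, end)):        # backward through segment
            grads[i], grad = layer_backward(weights[i], live[i - start], grad)
    return grads, peak


if __name__ == "__main__":
    ws = [0.5 + 0.1 * i for i in range(16)]
    g_all, stored_all = backward_store_all(ws, 0.3)
    g_ck, peak_ck = backward_checkpointed(ws, 0.3, seg=4)
    print("stored values:", stored_all, "vs", peak_ck)
    print("max gradient difference:",
          max(abs(a - b) for a, b in zip(g_all, g_ck)))
```

Both routines return the same gradients; only the number of values held at once differs.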
With L layers split into segments of length s, memory is roughly L/s checkpoints plus s live activations. Minimising L/s + s gives s = √L:
memory ∝ L/s + s ⟶ min at s = √L ⟹ O(√L) memory, one extra forward
The compute overhead is one additional forward pass regardless of s, because each layer's forward runs twice in total. A forward is roughly one third of a training step's FLOPs (backward costs about two forwards), so checkpointing everything costs about 33% more compute. In practice the wall-time cost is often nearer 20–30%, since some of the recompute overlaps other work and parts of a step, such as the optimizer update, are not repeated.
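A quick numeric check of both claims, as a sketch. Memory is counted in abstract "layer units" (one unit per checkpoint or live activation), and the 1:2 forward-to-backward cost ratio is the approximation stated above; the function names are made up for the example.

```python
import math


def peak_units(num_layers, seg):
    """Checkpoints kept plus one segment of live activations, in layer units."""
    return math.ceil(num_layers / seg) + seg


def best_segment(num_layers):
    """Segment length minimising peak_units, found by brute force."""
    return min(range(1, num_layers + 1),
               key=lambda s: peak_units(num_layers, s))


def recompute_overhead(forward_cost=1.0, backward_cost=2.0):
    """Extra FLOPs fraction from one additional forward pass per step."""
    return forward_cost / (forward_cost + backward_cost)


if __name__ == "__main__":
    L = 100
    s = best_segment(L)
    print(f"L={L}: best s={s}, peak={peak_units(L, s)} units "
          f"(store-all would be {L})")
    print(f"extra compute: {recompute_overhead():.0%}")
```

For L = 100 the brute-force minimum lands at s = 10 = √L, with a peak of 20 units instead of 100.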
Only √L checkpoints survive forward. Backward revives one segment at a time; peak memory is checkpoints + one live segment.
04 · Practice: What actually gets checkpointed
Nobody solves the optimisation per-model; the working convention is one checkpoint per transformer block (torch.utils.checkpoint.checkpoint around each block, or activation_checkpointing policies in FSDP). Two practical wrinkles deserve attention:
Selective checkpointing beats uniform: matmul outputs are expensive to recompute, but attention's softmax intermediates are large and cheap to recompute, so policies that recompute only the expensive-to-store, cheap-to-redo ops (PyTorch's selective AC) recover most of the memory for well under 33% overhead. FlashAttention is the limiting case: it never materialises the attention matrix at all.
Randomness must replay. Dropout inside a checkpointed segment must produce identical masks on recompute, or backward differentiates a different function than forward ran. Frameworks stash and restore RNG state; custom kernels that do not are a classic source of silent wrongness.
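The replay requirement can be shown with nothing but Python's standard `random` module. This is a toy stand-in for what frameworks do with device RNG state, not their actual code; `dropout_mask` and `checkpointed_masks` are hypothetical helpers.

```python
import random


def dropout_mask(n, p, rng):
    """Bernoulli keep-mask: 1 keeps the unit, 0 drops it."""
    return [0 if rng.random() < p else 1 for _ in range(n)]


def checkpointed_masks(n, p, rng, restore_state):
    """Return (forward mask, recompute mask) for one checkpointed segment."""
    saved = rng.getstate()            # stash RNG state at the segment boundary
    forward = dropout_mask(n, p, rng)
    rng.random()                      # unrelated draws happen before backward
    if restore_state:
        rng.setstate(saved)           # replay: recompute sees the same stream
    recompute = dropout_mask(n, p, rng)
    return forward, recompute


if __name__ == "__main__":
    fwd, rec = checkpointed_masks(64, 0.5, random.Random(0), restore_state=True)
    print("with restore, masks match:   ", fwd == rec)
    fwd, rec = checkpointed_masks(64, 0.5, random.Random(0), restore_state=False)
    print("without restore, masks match:", fwd == rec)
```

Without the restore, the recomputed segment applies a different mask than the forward pass did, so the gradients belong to a different function. Nothing crashes, which is why the bug stays silent.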
Bigger batches can repay the recompute with interest.
Long-context training: usually yes. Activations scale with sequence length; state does not.
Compute-bound and memory comfortable: no. You would pay 33% FLOPs for memory you do not need.
Order of operations: checkpointing trades the cheap resource (FLOPs are growing faster than HBM capacity) for the scarce one. It composes with everything — FSDP shards state, checkpointing caps activations, accumulation shrinks the live micro-batch. Large-model recipes typically use all three.
Mental Model
Backward needs forward's intermediates; by default they all stay alive, O(L) in depth and linear in sequence length.
Store-all and store-nothing are both losing endpoints; checkpoints every √L layers give O(√L) memory.
The compute bill is flat: one extra forward, ~33% of a step's FLOPs, however you segment.
Selective recompute (and FlashAttention as its limit) gets most of the memory for less than the full bill.
Take the trade when memory is the binding constraint or when the freed memory buys utilisation; otherwise decline it.