Applied ML

Gradient Accumulation

Simulate the batch that does not fit

01 · First principles: The batch the recipe wants vs the batch that fits

Training recipes are tuned around an effective batch size — large-model pretraining often wants somewhere in the millions of tokens per step, because the learning rate, schedule, and noise scale were all chosen against it. Activation memory, meanwhile, scales linearly with the batch actually resident on the device (see checkpointing), and the device has a fixed amount of HBM. The batch the recipe wants and the batch that fits routinely differ by an order of magnitude.

The escape hatch is an identity about sums. A gradient over a batch is the mean of per-example gradients, and a mean does not care how you chunk the sum, provided the k chunks are the same size:

$$\nabla L_{\text{batch}} \;=\; \frac{1}{k} \sum_{i=1}^{k} \nabla L_{\text{micro-batch } i}$$

So: run k micro-batches forward and backward, letting gradients accumulate in .grad (PyTorch adds by default — for once, the footgun is the feature), and call optimizer.step() once at the end. Memory holds one micro-batch of activations; the optimizer sees the gradient of the full batch. Exactly, not approximately — with one caveat in section 04.
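The identity is easy to check numerically without any framework. The sketch below is an illustration only: it assumes a one-parameter linear model with squared-error loss and equal-sized micro-batches, and `grad_mean` and `accumulated_grad` are hypothetical helper names, not part of any library.

```python
import math

def grad_mean(w, examples):
    """Gradient w.r.t. w of the mean of 0.5 * (w*x - y)**2 over `examples`."""
    return sum((w * x - y) * x for x, y in examples) / len(examples)

def accumulated_grad(w, examples, k):
    """Sum k micro-batch gradients, each pre-scaled by 1/k (equal-sized chunks assumed)."""
    size = len(examples) // k
    assert size * k == len(examples), "this sketch assumes k divides the batch size"
    total = 0.0
    for i in range(k):
        micro = examples[i * size:(i + 1) * size]
        total += grad_mean(w, micro) / k   # what (loss / k).backward() would add to .grad
    return total

data = [(float(x), 3.0 * x + 1.0) for x in range(1, 9)]   # 8 toy examples
full = grad_mean(0.5, data)                 # gradient of the whole batch at once
chunked = accumulated_grad(0.5, data, k=4)  # same batch, four micro-batches of two

print(full, chunked, math.isclose(full, chunked, rel_tol=1e-12))
```

The two numbers agree because each micro-batch mean carries equal weight; with unequal chunk sizes the plain 1/k scaling would no longer reproduce the full-batch mean.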

02 · Failure first: The classic bug: forgetting to divide by k

Each micro-batch's loss is already a mean over its own examples, so summing k of their gradients gives k× the full-batch gradient. Skip the division and the gradient handed to the optimizer is k times too large; under plain SGD that is equivalent to silently multiplying the learning rate by k. With k = 8 the run usually diverges quickly enough to notice; with k = 2 it may just train slightly worse forever, which is the more expensive version of the bug.

# correct: scale each micro-loss so the k backward passes sum to a mean
for micro in micro_batches:
    loss = compute_loss(model, micro) / k     # ← the line people forget
    loss.backward()                            # grads accumulate into .grad
optimizer.step()                               # one update for the whole batch
optimizer.zero_grad()                          # and only now

The mirrored bug is calling zero_grad() inside the micro-batch loop, which quietly turns accumulation into "train on the last micro-batch only". Both bugs produce running, converging, wrong training — the worst kind.
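Both bugs can be reproduced on a toy accumulator. The sketch below is framework-free and makes several assumptions for illustration: a plain float stands in for `.grad`, the model is a one-parameter squared-error fit, and `run` with its `divide` and `zero_inside_loop` flags is a hypothetical helper rather than any real API.

```python
def micro_grad(w, micro):
    """Mean gradient of 0.5 * (w*x - y)**2 over one micro-batch."""
    return sum((w * x - y) * x for x, y in micro) / len(micro)

def run(w, micro_batches, divide=True, zero_inside_loop=False):
    """Return what the accumulator holds at the point optimizer.step() would run."""
    k = len(micro_batches)
    grad = 0.0                              # stands in for .grad
    for micro in micro_batches:
        if zero_inside_loop:
            grad = 0.0                      # mirrored bug: wipes earlier micro-batches
        g = micro_grad(w, micro)
        grad += g / k if divide else g      # backward() adds into the accumulator
    return grad

micro_batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (4.0, 8.0)]]
w = 1.0

correct = run(w, micro_batches)                            # scaled and accumulated
unscaled = run(w, micro_batches, divide=False)             # forgot the / k
last_only = run(w, micro_batches, zero_inside_loop=True)   # zeroed inside the loop

print(correct, unscaled, last_only)
```

Here `unscaled` comes out at exactly k times `correct`, and `last_only` contains nothing but the final micro-batch's contribution. Neither raises an error, which is why both bugs survive so long in practice.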

03 · The interaction: With DDP: no_sync or pay k allreduces

DDP hooks backward() and launches a gradient allreduce as buckets fill. Combine that naively with accumulation and every micro-batch triggers a full allreduce — k synchronisations per step where one suffices, since only the accumulated total ever reaches the optimizer. The intermediate sums are computed, shipped across the wire, and never used.

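from contextlib import nullcontext

for i, micro in enumerate(micro_batches):
    ctx = model.no_sync() if i < k - 1 else nullcontext()   # sync only on the last one
    with ctx:
        (compute_loss(model, micro) / k).backward()   # comms only on the k-th
optimizer.step()
optimizer.zero_grad()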

Wrapping the first k−1 backward passes in no_sync() defers communication so the single allreduce on the final micro-batch carries the whole accumulated gradient. The saving is a factor of k in gradient traffic; on a bandwidth-limited cluster that is frequently the difference between scaling and not. PyTorch's FSDP wrapper exposes a context manager with the same name, though there the deferred gradients are held unsharded, so the saved traffic is paid for in extra memory.
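The factor of k can be made concrete by counting synchronisations. The sketch below does not use PyTorch: `ToyDDP` is a hypothetical stand-in that mimics only the "sync on backward unless inside no_sync" behaviour described above, and a single counter tick represents the whole bucketed allreduce for one backward pass.

```python
from contextlib import contextmanager, nullcontext

class ToyDDP:
    """Hypothetical stand-in mimicking only DDP's sync-on-backward behaviour."""

    def __init__(self):
        self.sync_enabled = True
        self.allreduce_calls = 0

    @contextmanager
    def no_sync(self):
        previous = self.sync_enabled
        self.sync_enabled = False
        try:
            yield
        finally:
            self.sync_enabled = previous   # restore on exit, even after an exception

    def backward(self):
        # One tick stands for all gradient buckets reduced during this backward pass.
        if self.sync_enabled:
            self.allreduce_calls += 1

def syncs_per_step(k, use_no_sync):
    """Count synchronisations for one optimizer step built from k micro-batches."""
    model = ToyDDP()
    for i in range(k):
        ctx = model.no_sync() if use_no_sync and i < k - 1 else nullcontext()
        with ctx:
            model.backward()
    return model.allreduce_calls

print(syncs_per_step(8, use_no_sync=False), syncs_per_step(8, use_no_sync=True))
```

With k = 8 the naive loop synchronises eight times per optimizer step and the `no_sync` variant once, matching the pattern in the snippet above.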

04 · The fine print: Where the equivalence leaks

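The leak is any layer whose output depends on batch statistics: BatchNorm computes its mean and variance per micro-batch, so its forward pass does not match that of the true large batch.

Distinguish the neighbours as well: accumulation serialises a big batch on one device over time; data parallelism spreads it over devices in space. They compose multiplicatively: effective batch = micro-batch × k × world size. Keep that product fixed when changing any factor, or you have silently changed the recipe.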
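Keeping the product fixed is simple arithmetic, and it is worth encoding rather than doing by hand. A minimal sketch, assuming batch sizes are counted in examples per device; `effective_batch` and `accumulation_steps` are hypothetical helper names, and the numbers are made up for illustration.

```python
def effective_batch(micro_batch, k, world_size):
    """Examples contributing to one optimizer step."""
    return micro_batch * k * world_size

def accumulation_steps(target, micro_batch, world_size):
    """Return the k that keeps the effective batch at `target`, or fail loudly."""
    per_pass = micro_batch * world_size   # examples covered by one micro-step on all ranks
    if target % per_pass != 0:
        raise ValueError(
            f"target {target} is not a multiple of micro_batch * world_size = {per_pass}"
        )
    return target // per_pass

# Doubling the number of devices halves k; the recipe's batch stays put.
k_small = accumulation_steps(target=1024, micro_batch=8, world_size=16)
k_large = accumulation_steps(target=1024, micro_batch=8, world_size=32)

print(k_small, k_large)
print(effective_batch(8, k_small, 16), effective_batch(8, k_large, 32))
```

Raising an error when the target does not divide evenly is a deliberate choice in this sketch: silently rounding k would change the effective batch, which is exactly the drift the text warns about.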

05 · Summary table: What changes, what does not

| Quantity | vs true large batch |
| --- | --- |
| Gradient handed to optimizer | Identical (up to rounding) |
| Optimizer trajectory, LR schedule semantics | Identical |
| Peak activation memory | ÷ k |
| Wall-clock per optimizer step | × k (roughly) |
| BatchNorm statistics | Computed per micro-batch, not identical |
| DDP gradient traffic | × 1 with no_sync |
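The BatchNorm row is the one that breaks exact equivalence, and a few numbers show why. The sketch below is a simplification, not a real BatchNorm layer: `normalise` is a hypothetical helper that applies only the mean and variance step to a list of scalars, ignoring the learned scale and shift and the running statistics.

```python
from statistics import fmean, pvariance

def normalise(values, eps=1e-5):
    """Batch-norm style normalisation using the statistics of `values` only."""
    mu = fmean(values)
    var = pvariance(values, mu)   # population variance over this batch
    return [(v - mu) / (var + eps) ** 0.5 for v in values]

batch = [1.0, 2.0, 3.0, 10.0]

whole = normalise(batch)                                # one large batch
chunked = normalise(batch[:2]) + normalise(batch[2:])   # two micro-batches of two

print([round(v, 3) for v in whole])
print([round(v, 3) for v in chunked])
```

The same inputs produce different activations depending on how the batch was chunked, so the gradients downstream differ too. Scaling the loss by 1/k cannot repair that, because the discrepancy enters in the forward pass.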
Mental Model