Simulate the batch that does not fit
Training recipes are tuned around an effective batch size — large-model pretraining often wants somewhere in the millions of tokens per step, because the learning rate, schedule, and noise scale were all chosen against it. Activation memory, meanwhile, scales linearly with the batch actually resident on the device (see checkpointing), and the device has a fixed amount of HBM. The batch the recipe wants and the batch that fits routinely differ by an order of magnitude.
The escape hatch is an identity about sums. A gradient over a batch is the mean of per-example gradients, and a mean does not care how you chunk the sum:
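Written out, as a sketch under the assumption that a batch B of N examples is split into k equal micro-batches B_1, ..., B_k of m = N/k examples each, with ℓ_i the loss on example i (this notation is introduced here for illustration and may not match the original figure):

```latex
\nabla_\theta L(B)
  = \frac{1}{N} \sum_{i \in B} \nabla_\theta \ell_i
  = \frac{1}{k} \sum_{j=1}^{k} \left( \frac{1}{m} \sum_{i \in B_j} \nabla_\theta \ell_i \right)
  = \frac{1}{k} \sum_{j=1}^{k} \nabla_\theta L(B_j),
  \qquad N = k\,m .
```

With unequal micro-batch sizes the outer average would need to be weighted by each micro-batch's share of the examples.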
So: run k micro-batches forward and backward, letting gradients accumulate in .grad (PyTorch adds by default — for once, the footgun is the feature), and call optimizer.step() once at the end. Memory holds one micro-batch of activations; the optimizer sees the gradient of the full batch. Exactly, not approximately — with one caveat in section 04.
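To make the sizes concrete, here is a minimal sketch of how k might be chosen. The function name `accumulation_steps` and every number below are hypothetical, picked only for illustration. The one relation it relies on is that tokens per optimizer step equal micro-batch size × sequence length × data-parallel ranks × k.

```python
import math

def accumulation_steps(target_tokens, micro_batch_size, seq_len, dp_ranks=1):
    """Smallest k whose effective batch reaches target_tokens per optimizer step."""
    tokens_per_micro_step = micro_batch_size * seq_len * dp_ranks
    return math.ceil(target_tokens / tokens_per_micro_step)

# Hypothetical numbers: the recipe wants 1,048,576 tokens per step, and
# 4 sequences of 2048 tokens fit on each of 8 data-parallel devices.
k = accumulation_steps(target_tokens=1_048_576, micro_batch_size=4,
                       seq_len=2048, dp_ranks=8)
effective_tokens = k * 4 * 2048 * 8
print(k, effective_tokens)  # 16 1048576
```

Rounding up means the effective batch can overshoot the target slightly when the sizes do not divide evenly.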
Each micro-batch's loss is already a mean over its own examples, so summing k of their gradients gives k× the full-batch gradient. Skip the division by k and every accumulated gradient is k times too large. For plain SGD that is equivalent to silently multiplying the learning rate by k; adaptive optimizers such as Adam rescale much of the factor away, but gradient-norm clipping still sees the inflated gradient. With k = 8 the run usually diverges quickly enough to notice; with k = 2 it may just train slightly worse forever, which is the more expensive version of the bug.
# correct: scale each micro-loss so the k backwards sum to a mean
k = len(micro_batches)
for micro in micro_batches:
    loss = compute_loss(model, micro) / k  # ← the line people forget
    loss.backward()                        # grads accumulate into .grad
optimizer.step()
optimizer.zero_grad()                      # and only now
The mirrored bug is calling zero_grad() inside the micro-batch loop, which quietly turns accumulation into "train on the last micro-batch only". Both bugs produce running, converging, wrong training — the worst kind.
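Both bugs can be reproduced without a framework. The sketch below is a toy stand-in, not PyTorch: a scalar weight `w`, a squared-error loss, and a plain float playing the role of `.grad`. The data values and the helper `mean_grad` are made up for illustration, and the micro-batches are assumed to be equal in size.

```python
def mean_grad(w, examples):
    """Gradient w.r.t. w of the mean squared error of y ≈ w * x over `examples`."""
    return sum(2.0 * (w * x - y) * x for x, y in examples) / len(examples)

w = 0.5
batch = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 8.0)]  # made-up (x, y) pairs
k = 2
micro_batches = [batch[0:2], batch[2:4]]  # equal-sized chunks

full = mean_grad(w, batch)  # what a true large batch would hand the optimizer

# Correct: scale each micro-gradient by 1/k before adding it to "grad".
grad = 0.0
for micro in micro_batches:
    grad += mean_grad(w, micro) / k
correct = grad

# Bug 1: no division, so the accumulated value is k times too large.
grad = 0.0
for micro in micro_batches:
    grad += mean_grad(w, micro)
unscaled = grad

# Bug 2: "zero_grad()" inside the loop, so only the last micro-batch survives.
for micro in micro_batches:
    grad = 0.0
    grad += mean_grad(w, micro) / k
last_only = grad

print(full, correct, unscaled, last_only)  # -23.0 -23.0 -46.0 -20.25
```

All three variants produce a finite, plausible-looking gradient, which is why neither bug announces itself.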
DDP hooks backward() and launches a gradient allreduce as buckets fill. Combine that naively with accumulation and every micro-batch triggers a full allreduce — k synchronisations per step where one suffices, since only the accumulated total ever reaches the optimizer. The intermediate sums are computed, shipped across the wire, and never used.
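from contextlib import nullcontext
for i, micro in enumerate(micro_batches):
    ctx = model.no_sync() if i < k - 1 else nullcontext()  # sync only on the last
    with ctx:  # the forward pass must run inside the context too
        (compute_loss(model, micro) / k).backward()  # comms only on the k-th
optimizer.step()
optimizer.zero_grad()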
Wrapping the first k−1 forward and backward passes in no_sync() defers communication so the single allreduce on the final micro-batch carries the whole accumulated gradient. The saving is a factor of k in gradient traffic; on a bandwidth-limited cluster that is frequently the difference between scaling and not. FSDP exposes a context manager with the same name, though there the deferred gradients accumulate unsharded until the final sync, so it trades extra memory for the saved traffic.
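As a rough worked example of the factor-of-k saving, the sketch below estimates the gradient bytes each rank sends per optimizer step. It assumes a bandwidth-optimal ring allreduce, in which each rank sends about 2(n − 1)/n times the buffer size; the algorithm actually used is chosen by the communication backend and may differ. The function name and the model size are hypothetical.

```python
def grad_bytes_per_step(num_params, bytes_per_grad, world_size, k, use_no_sync):
    """Approximate bytes each rank sends for gradient allreduce in one optimizer step."""
    buffer_bytes = num_params * bytes_per_grad
    # Assumed ring allreduce: each rank sends ~2 * (n - 1) / n of the buffer.
    per_allreduce = 2 * (world_size - 1) / world_size * buffer_bytes
    allreduces = 1 if use_no_sync else k  # naive accumulation syncs every micro-batch
    return allreduces * per_allreduce

# Hypothetical setup: 1B parameters, fp32 gradients, 8 ranks, k = 8.
naive = grad_bytes_per_step(1_000_000_000, 4, 8, 8, use_no_sync=False)
deferred = grad_bytes_per_step(1_000_000_000, 4, 8, 8, use_no_sync=True)
print(f"naive: {naive / 1e9:.0f} GB, with no_sync: {deferred / 1e9:.0f} GB")
# naive: 56 GB, with no_sync: 7 GB
```

The absolute numbers depend entirely on the assumed setup; the ratio between the two is k regardless.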
| Quantity | vs true large batch |
|---|---|
| Gradient handed to optimizer | Identical (up to rounding) |
| Optimizer trajectory, LR schedule semantics | Identical |
| Peak activation memory | ÷ k |
| Wall-clock per optimizer step | × k relative to a single micro-batch step (roughly); total FLOPs unchanged |
| BatchNorm statistics | Computed per micro-batch — not identical |
| DDP gradient traffic | ×1 with no_sync |
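The BatchNorm row is easy to check by hand. The sketch below uses only the standard library and made-up activation values for a single channel. Averaging the micro-batch means recovers the full-batch mean, but averaging the micro-batch variances does not recover the full-batch variance, so normalising per micro-batch produces different outputs than normalising over the full batch.

```python
from statistics import fmean, pvariance

activations = [1.0, 2.0, 3.0, 4.0, 10.0, 12.0, 14.0, 16.0]  # made-up values
micro_batches = [activations[:4], activations[4:]]

# Statistics a true large batch would normalise with (population variance).
full_mean = fmean(activations)
full_var = pvariance(activations)

# Statistics each micro-batch actually normalises with.
micro_means = [fmean(m) for m in micro_batches]
micro_vars = [pvariance(m) for m in micro_batches]

print("full batch:   ", full_mean, full_var)
print("micro-batches:", micro_means, micro_vars)
print("avg of micro variances:", fmean(micro_vars))
```

The gap in the variance is the spread between the micro-batch means, which no amount of gradient scaling restores.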