A circuit breaker, not a cure
Gradient descent's update is lr·g, and nothing in the algorithm bounds g. The loss landscape of a deep network contains cliffs — regions where curvature explodes and the local gradient is orders of magnitude larger than typical. A rare pathological batch (a corrupted document, a degenerate sequence), a loss cliff, or a transient numerical event can produce a gradient 100× the usual norm.
The asymmetry is what makes this dangerous: a thousand good steps build the model slowly, and one enormous step can throw the weights into a region from which the optimizer never recovers — the loss spikes, then plateaus at a worse level, or goes straight to NaN. Weeks of compute can hinge on the worst single batch in the corpus. Clipping exists because the cost of capping every step is tiny and the cost of not capping one step is occasionally everything.
Global-norm clipping treats the entire gradient, concatenated over all parameters, as one vector and rescales it only if its L2 norm exceeds a threshold c:
$$ g \leftarrow g \cdot \min\!\left(1, \frac{c}{\lVert g \rVert_2}\right) $$
Two properties matter. First, below the threshold it is the identity — in a healthy run, clipping should be doing nothing most of the time. Second, above the threshold it preserves the gradient's direction exactly and shortens only its length. The step still points downhill; it is merely no longer allowed to be a leap.
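To make the two properties concrete, here is a minimal sketch of global-norm clipping in plain Python. The function name clip_by_global_norm and the list-of-lists representation of per-parameter gradients are assumptions made for illustration; a real framework would operate on tensors.

```python
import math

def clip_by_global_norm(grads, c):
    """Rescale per-parameter gradients so their joint L2 norm is at most c.

    grads: list of flattened per-parameter gradients (lists of floats).
    Returns (clipped_grads, norm_before_clipping).
    """
    # One norm over every element of every parameter, as if concatenated.
    total_norm = math.sqrt(sum(x * x for g in grads for x in g))
    # Identity below the threshold; one uniform scale factor above it.
    scale = min(1.0, c / total_norm) if total_norm > 0.0 else 1.0
    return [[x * scale for x in g] for g in grads], total_norm

# Two hypothetical "parameters" whose joint norm is sqrt(9 + 16 + 144) = 13.
grads = [[3.0, 4.0], [12.0]]
unchanged, norm_before = clip_by_global_norm(grads, c=20.0)  # below threshold
shortened, _ = clip_by_global_norm(grads, c=1.0)             # above threshold
```

Because every element is multiplied by the same factor, the ratios between components survive clipping, which is what "preserves the direction" means in practice.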
In PyTorch, global-norm clipping is a single call: torch.nn.utils.clip_grad_norm_.
Element-wise clipping is not strictly wrong — it bounds the step too — but it distorts geometry in a way that is hard to reason about, and there is rarely a reason to prefer it when norm clipping is one line.
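The geometric difference can be checked numerically. The sketch below, again in plain Python with hypothetical helper names, clips the same two-component gradient both ways and compares each result's direction with the original using cosine similarity.

```python
import math

def norm(v):
    return math.sqrt(sum(x * x for x in v))

def cosine(u, v):
    # Cosine of the angle between u and v; 1.0 means same direction.
    return sum(a * b for a, b in zip(u, v)) / (norm(u) * norm(v))

def clip_by_norm(g, c):
    # Shorten the whole vector uniformly if it is longer than c.
    n = norm(g)
    scale = min(1.0, c / n) if n > 0.0 else 1.0
    return [x * scale for x in g]

def clip_by_value(g, c):
    # Clamp each component independently to [-c, c].
    return [max(-c, min(c, x)) for x in g]

g = [10.0, 1.0]  # one dominant component, one small one
cos_norm = cosine(g, clip_by_norm(g, 1.0))    # direction kept, up to rounding
cos_value = cosine(g, clip_by_value(g, 1.0))  # direction rotated toward [1, 1]
```

In this toy case element-wise clipping flattens the dominant component down to the size of the small one, so the step points somewhere visibly different from the gradient, while norm clipping only shortens it.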
Under mixed precision with a gradient scaler, the order matters: scaler.unscale_(optimizer) first, then clip, then step. Otherwise the threshold is compared against scaled gradients rather than the true ones.
A clipped step is a biased gradient estimate: you are no longer optimising the loss exactly, you are optimising it subject to a trust region. When clips are rare this bias is negligible and the insurance is nearly free. When clips are frequent, the run is telling you something, and the threshold is muffling the message:
| If spikes are… | Likely cause | Actual fix |
|---|---|---|
| Tied to specific batches, reproducible | Data | Find and filter the offending examples; dedupe; cap document weirdness |
| Growing in frequency over the run | Learning rate / schedule | Lower peak LR, lengthen warmup; check Adam eps and β₂ at scale |
| Correlated with inf/nan, precision-dependent | Numerics | Audit fp16 range, reductions, attention logits (see precision tricks) |
| Structural — from the architecture's depth | Conditioning | Init, normalisation, residual scaling (see exploding gradients) |
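Telling "rare" from "frequent" requires measuring how often the threshold actually fires. One possible approach, sketched below, is to log the pre-clip gradient norm at every step and track the fraction of recent steps that exceeded the threshold. The class name ClipRateMonitor and the window size are assumptions for illustration; in PyTorch the pre-clip norm is the value returned by torch.nn.utils.clip_grad_norm_.

```python
from collections import deque

class ClipRateMonitor:
    """Tracks the fraction of recent steps whose gradient norm exceeded the threshold."""

    def __init__(self, threshold, window=1000):
        self.threshold = threshold
        # Only the most recent `window` steps are kept.
        self.history = deque(maxlen=window)

    def update(self, grad_norm):
        # Record whether this step would have been clipped.
        self.history.append(grad_norm > self.threshold)

    def clip_rate(self):
        if not self.history:
            return 0.0
        return sum(self.history) / len(self.history)

# Hypothetical pre-clip norms from five consecutive steps.
monitor = ClipRateMonitor(threshold=1.0, window=4)
for grad_norm in [0.5, 0.7, 3.0, 0.6, 0.4]:
    monitor.update(grad_norm)
rate = monitor.clip_rate()  # fraction clipped among the last four steps
```

A clip rate that stays near zero matches the circuit-breaker picture; a rate that climbs over the run, or jumps on particular batches, points at one of the rows in the table rather than at the threshold.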