Applied ML

Mixed Precision Training

Compute in 16 bits, accumulate where it is safe

01 · First principles
Why anyone leaves fp32

Halving the bits halves everything that hurts: half the memory for weights, gradients, and activations; half the bytes across every memory bus and every interconnect. And on modern GPUs the win is more than 2×, because tensor cores run 16-bit matrix multiplies at several times the fp32 rate (an A100 does roughly 19 TFLOPS in fp32 and roughly 312 in 16-bit on tensor cores). For the bandwidth-bound parts of a network — most of it, see profiling — halving the bytes roughly halves the time.

So the question was never whether to use 16 bits. It was which failures show up when you do, and what the patches cost.

02 · Failure first
fp16's narrow range eats gradients

fp16 spends its 16 bits as 1 sign, 5 exponent, 10 mantissa (see floating point). Five exponent bits give a smallest normal number around 6×10⁻⁵, with subnormals reaching about 6×10⁻⁸. That sounds small until you histogram real gradients: in deep networks a large fraction of gradient values sit below 10⁻⁵. Cast to fp16, those below the subnormal floor flush to zero, and those just above it keep only a few bits of precision. The layers furthest from the loss silently stop learning, and the run degrades without ever crashing.
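The flush is easy to reproduce without a GPU. The sketch below is a minimal illustration, assuming Python's standard `struct` module (whose `e` format is IEEE half precision) as a stand-in for a hardware cast; the helper `to_fp16` and the sample gradient values are made up for the example.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to the nearest fp16 value and return it as a Python float."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

# Hypothetical gradient magnitudes, chosen to straddle fp16's limits.
grads = [3e-4, 1e-5, 2e-7, 1e-8, 4e-9]

for g in grads:
    print(f"{g:.0e} -> {to_fp16(g):.3e}")

# Values below half the smallest subnormal (2**-25, about 3e-8) become exactly zero.
flushed = [g for g in grads if to_fp16(g) == 0.0]

# Values that survive as subnormals lose precision: 2e-7 lands on 3 * 2**-24.
rel_err_subnormal = abs(to_fp16(2e-7) - 2e-7) / 2e-7
```

Here 1e-8 and 4e-9 come back as exact zeros, while 2e-7 survives only as a subnormal with roughly 10% relative error.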

The patch is loss scaling. Multiply the loss by a factor S (say 2¹⁴) before backward; by the chain rule every gradient is then scaled by S, sliding the whole histogram up into representable range. Unscale before the optimizer step. Because the right S is workload-dependent and drifts, dynamic scaling is standard: grow S (typically by doubling) after a few thousand consecutive overflow-free steps, and on any inf/nan in the gradients, skip the step and halve S.
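The dynamic-scaling rule can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not any framework's implementation: `DynamicLossScaler`, `unscale_or_skip`, and the `growth_interval` default of 2000 are hypothetical names and values, and `to_fp16` emulates the 16-bit cast with the standard `struct` module.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round x to fp16; out-of-range values become inf, as in a hardware cast."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

class DynamicLossScaler:
    """Hypothetical helper for illustration, not a library class."""

    def __init__(self, init_scale: float = 2.0**14, growth_interval: int = 2000):
        self.scale = init_scale
        self.growth_interval = growth_interval  # assumed value for "a few thousand"
        self.good_steps = 0

    def unscale_or_skip(self, scaled_grads):
        """Return unscaled gradients, or None if this step must be skipped."""
        if any(math.isinf(g) or math.isnan(g) for g in scaled_grads):
            self.scale /= 2.0          # overflow: halve S and skip the step
            self.good_steps = 0
            return None
        grads = [g / self.scale for g in scaled_grads]  # unscale with the S that was used
        self.good_steps += 1
        if self.good_steps >= self.growth_interval:
            self.scale *= 2.0          # long overflow-free streak: try a larger S
            self.good_steps = 0
        return grads

scaler = DynamicLossScaler()
true_grads = [1e-8, 3e-7, 2e-3]
# Scaling the loss by S scales every gradient by S before it is stored in fp16.
scaled = [to_fp16(g * scaler.scale) for g in true_grads]
recovered = scaler.unscale_or_skip(scaled)  # 1e-8 survives instead of flushing

# A gradient too large for the current scale overflows, so the step is skipped.
overflowed = scaler.unscale_or_skip([to_fp16(10.0 * scaler.scale)])
```

With S = 2¹⁴ the 1e-8 gradient is stored as a normal fp16 number and is recovered to within fp16 rounding error; the overflowing step returns `None` and leaves the scale at 2¹³.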

Cost of the patch: occasional skipped steps, one more failure mode to monitor, and a knob that can interact badly with clipping if you clip before unscaling (always unscale first).
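The clipping interaction is simple arithmetic, and a toy example makes it concrete. The sketch below assumes global-norm clipping with a threshold of 1.0; `clip_by_norm`, `l2_norm`, and the gradient values are hypothetical, and everything runs in ordinary Python floats.

```python
import math

def l2_norm(grads):
    return math.sqrt(sum(g * g for g in grads))

def clip_by_norm(grads, max_norm):
    """Rescale grads so their L2 norm is at most max_norm."""
    norm = l2_norm(grads)
    factor = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * factor for g in grads]

S = 2.0**14
true_grads = [0.3, 0.4]              # norm 0.5, already under the threshold
scaled = [g * S for g in true_grads]  # what backward produces under loss scaling

# Wrong order: the threshold is compared against the scaled norm (about 8192).
wrong = [g / S for g in clip_by_norm(scaled, 1.0)]

# Right order: unscale first, so the threshold sees the true gradient.
right = clip_by_norm([g / S for g in scaled], 1.0)
```

Clipping first shrinks the effective gradient norm to 1/S (about 6e-5) even though the true norm of 0.5 needed no clipping at all; unscaling first leaves the gradient untouched.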

03 · The fix that won
bf16 trades precision for range

bfloat16 makes the opposite trade: 1 sign, 8 exponent, 7 mantissa. Its eight exponent bits match fp32's exactly, so bf16 covers the same dynamic range as fp32 (roughly 10⁻³⁸ to 10³⁸), and gradients never underflow in practice. No loss scaling, no skipped steps, no knob. The price is precision: 7 mantissa bits mean a relative rounding error of up to about 0.4%, versus fp16's 0.05%.
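Because bf16 is the top half of an fp32 bit pattern, the trade can be checked with nothing but bit manipulation. The following is a sketch assuming round-to-nearest-even and finite inputs within fp32 range; `to_bf16`, `to_fp16`, and `rel_err` are hypothetical helpers built on the standard `struct` module.

```python
import struct

def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]   # fp32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                    # round to nearest even
    bits &= 0xFFFF0000                                     # keep sign, 8 exp, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def to_fp16(x: float) -> float:
    """Round x to the nearest fp16 value (raises OverflowError above fp16 range)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def rel_err(approx: float, exact: float) -> float:
    return abs(approx - exact) / abs(exact)

# Range: a tiny gradient survives in bf16 but flushes to zero in fp16.
tiny = 1e-8
tiny_bf16, tiny_fp16 = to_bf16(tiny), to_fp16(tiny)

# Precision: bf16 is coarser than fp16 on an ordinary value.
third = 1.0 / 3.0
err_bf16 = rel_err(to_bf16(third), third)   # bounded by 2**-8, about 0.4%
err_fp16 = rel_err(to_fp16(third), third)   # bounded by 2**-11, about 0.05%
```

The same input that vanishes in fp16 keeps its magnitude in bf16, at the cost of a visibly larger rounding error on values both formats can represent.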

[Figure: bit layouts of fp32, fp16, and bf16, split into sign, exponent (range), and mantissa (precision). fp16: range ~6e-5 .. 65504. bf16: fp32's exponent, ~3 decimal digits.]

bf16 keeps fp32's exponent and pays with mantissa. Training tolerates noise far better than it tolerates zeros.

The empirical fact that decided the matter: SGD-family training is remarkably tolerant of low-precision noise in gradients (it is already a noisy estimator) but intolerant of systematic zeros. bf16 therefore won by default wherever the hardware supports it (Ampere onward, all TPUs), and fp16 with loss scaling survives mainly on older GPUs and in inference.

04 · The second failure
Tiny updates vanish: master weights

Range was only the first pathology. The second is at the update itself. A weight is around 10⁻¹; a per-step update lr·g is often around 10⁻⁶ or smaller. Adding them in 16 bits requires their ratio to fit in the mantissa, and it does not:

w + lr·g  =  w   whenever   lr·g < w · 2^(−(mantissa bits + 1))
bf16 (7 mantissa bits): updates below ~w/256 are lost entirely

Without a fix, thousands of real updates round to nothing and training stalls at a mediocre loss. The fix is to keep an fp32 master copy of the weights: the update is applied in fp32, where an update of this size still registers, and a 16-bit cast of the master weights is what forward and backward actually use. The cost is 4 extra bytes per parameter, which, together with Adam's fp32 moments, is why the FSDP note counts 16 bytes of state per parameter, not 4.
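Both the stall and the fix show up in a toy loop. This sketch uses the numbers from the text (a weight near 0.1, an update of 1e-6) and assumes a constant update for 1000 steps; `to_bf16` and `to_fp32` are hypothetical helpers that emulate the casts with the standard `struct` module.

```python
import struct

def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

update = 1e-6          # stands in for lr * g
steps = 1000

# Pure bf16 training: every sum rounds straight back to the old weight.
w_bf16 = to_bf16(0.1)
start_bf16 = w_bf16
for _ in range(steps):
    w_bf16 = to_bf16(w_bf16 + update)

# fp32 master weights: the update accumulates, and the bf16 copy follows.
master = to_fp32(0.1)
for _ in range(steps):
    master = to_fp32(master + update)
compute_copy = to_bf16(master)   # what forward and backward would use
```

After 1000 steps the pure-bf16 weight has not moved at all, while the master weight has drifted to about 0.101 and its bf16 cast has moved with it.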

05 · The map
What stays fp32, and why

| Tensor / op | Precision | Reason |
| --- | --- | --- |
| Matmuls, convolutions | bf16 / fp16 | Tensor cores; tolerant of noise; accumulation inside the matmul is fp32 anyway |
| Activations stored for backward | bf16 / fp16 | The bulk of activation memory; noise tolerated |
| Master weights | fp32 | Updates smaller than ~w/256 vanish in 16 bits |
| Optimizer states (m, v) | fp32 | Long-running accumulations of small quantities |
| LayerNorm / BatchNorm internals | fp32 | Mean/variance are reductions; cancellation in 16 bits (see precision tricks) |
| Softmax, loss, large reductions | fp32 | exp overflows fp16 above ~11; sums over long axes accumulate rounding error |

This split is what "mixed" means, and it is what torch.autocast implements: a per-op policy table, not a global cast. The pattern generalises — compute the big dense math in 16 bits, and route anything that accumulates (updates, moments, statistics, reductions) through fp32.
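Two rows of the table can be demonstrated directly: a long sum stalling in fp16, and exp overflowing just above 11. The sketch below is an emulation under stated assumptions, with `to_fp16` and `to_fp32` as hypothetical helpers over the standard `struct` module and a running sum of ones as a stand-in for a reduction over a long axis.

```python
import math
import struct

def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_fp16(x: float) -> float:
    """Round x to fp16; out-of-range values become inf, as in a hardware cast."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

# Reduction: add 4096 ones, rounding the accumulator after every addition.
acc16 = 0.0
acc32 = 0.0
for _ in range(4096):
    acc16 = to_fp16(acc16 + 1.0)   # stalls once the spacing between values exceeds 1
    acc32 = to_fp32(acc32 + 1.0)

# Softmax: exp of a modest logit already exceeds fp16's maximum of 65504.
exp_11 = to_fp16(math.exp(11.0))   # about 59874, still finite
exp_12 = to_fp16(math.exp(12.0))   # about 162755, overflows to inf
```

The fp16 accumulator stops at 2048, where adding 1 rounds back to the same value, while the fp32 accumulator reaches 4096 exactly.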

Total cost of mixed precision: +4 bytes/param of master weights, a scaler to babysit if fp16, and a handful of ops pinned to fp32. In exchange: roughly half the memory traffic and several-fold faster matmuls. Few trades in systems are this lopsided.
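The memory side of the trade is plain arithmetic. The breakdown below is an assumption about how the 16 bytes per parameter are usually tallied for Adam with bf16 compute (the text itself only states the master-weight and moment costs), and the 7-billion-parameter model is a hypothetical example.

```python
# Assumed per-parameter state for mixed-precision Adam training.
BYTES_PER_PARAM = {
    "bf16 weights (compute copy)": 2,
    "bf16 gradients": 2,
    "fp32 master weights": 4,
    "fp32 Adam first moment (m)": 4,
    "fp32 Adam second moment (v)": 4,
}

def training_state_gib(n_params: float) -> float:
    """Model and optimizer state in GiB, excluding activations."""
    total_bytes = n_params * sum(BYTES_PER_PARAM.values())
    return total_bytes / 2**30

bytes_per_param = sum(BYTES_PER_PARAM.values())

# Hypothetical 7-billion-parameter model.
state_gib = training_state_gib(7e9)
print(f"{bytes_per_param} bytes/param, {state_gib:.1f} GiB of state")
```

Under these assumptions a 7B model carries a little over 100 GiB of weight, gradient, and optimizer state before any activations are counted, which is why this state is typically sharded across devices.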