Exact attention, reorganised around the memory hierarchy
Profile standard attention on a GPU and you find something odd: the arithmetic units are mostly idle. Attention at sequence length N does O(N²d) floating-point operations, and a modern GPU can do those operations far faster than it can fetch the data they need. The kernel is not compute-bound. It is waiting on memory.
So the first-principles question is not "how do we do fewer FLOPs?" (years of approximate-attention papers asked that) but "how do we do fewer reads and writes?" Flash Attention computes exactly the same output as standard attention. It changes nothing about the math and everything about where the intermediate values live.
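The textbook implementation runs in three passes, and each pass round-trips through HBM (the GPU's large, slow main memory): compute the scores S = QKᵀ and write the N × N matrix to HBM; read S back, apply the row-wise softmax to get P, and write P to HBM; read P back and multiply by V to produce the output O = PV.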
At N = 8192 in fp16, the score matrix alone is 8192² × 2 bytes = 128 MiB per head per batch element, and it crosses the HBM bus several times: written as scores, read back for the softmax, written again as probabilities, read back for the final product. The fast on-chip SRAM that sits next to the compute units totals only ~20 MB per chip, so the matrix cannot simply be kept there.
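The three passes and the size arithmetic above can be sketched in NumPy. This is a minimal illustration, assuming a single head, no masking, and the usual 1/sqrt(d) scaling (which the text does not spell out); the names `naive_attention` and `score_matrix_bytes` are made up for this example. NumPy keeps everything in host memory, so the comments only mark where a GPU kernel would round-trip through HBM.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Standard attention in three passes; each intermediate is a full N x N matrix."""
    d = Q.shape[-1]
    # Pass 1: scores. On a GPU, S (N x N) would be written out to HBM.
    S = (Q @ K.T) / np.sqrt(d)          # 1/sqrt(d) scaling is an assumption here
    # Pass 2: row-wise softmax. S is read back and P (N x N) is written out.
    S_max = S.max(axis=-1, keepdims=True)   # subtract the row max for stability
    P = np.exp(S - S_max)
    P /= P.sum(axis=-1, keepdims=True)
    # Pass 3: weighted sum. P is read back once more to form the output.
    return P @ V

def score_matrix_bytes(n, bytes_per_element=2):
    """Size of one N x N score matrix (fp16 by default)."""
    return n * n * bytes_per_element

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((64, 16)) for _ in range(3))
out = naive_attention(Q, K, V)
print(out.shape)                          # (64, 16)
print(score_matrix_bytes(8192) / 2**20)   # 128.0 (MiB)
```

Both `S` and `P` are materialised in full here, which is exactly the cost the rest of the section sets out to avoid.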
SRAM is roughly an order of magnitude faster than HBM and three orders of magnitude smaller. Flash Attention is designed around that gap.
The analogy: HBM is the warehouse, SRAM is your desk. The naive kernel assembles the entire product in the warehouse aisle, walking back and forth for every part. The fix is obvious once stated — bring small batches of parts to the desk, finish each batch completely, and only walk the finished result back.
Tiling alone is not enough, because softmax is the one step that seems to need a whole row at once: the denominator sums over all N scores. Flash Attention's enabling trick is the online softmax — compute softmax incrementally over blocks, keeping two running scalars per row: the running max m and the running denominator ℓ.
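For a single row, the two running scalars can be sketched in plain Python. This is an illustrative sketch, not a kernel: the function name `online_softmax_stats` and the variable names `m_block` and `m_updated` are assumptions made for this example.

```python
import math

def online_softmax_stats(scores, block_size):
    """Return (m, l): the row max and the denominator sum(exp(s - m)),
    computed one block at a time without ever seeing the whole row at once."""
    m = -math.inf   # running max
    l = 0.0         # running denominator, expressed relative to the current m
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_block = max(block)
        m_updated = max(m, m_block)
        # Old terms were summed against the stale max m; rescale them to m_updated,
        # then add the new block's terms.
        l = l * math.exp(m - m_updated) + sum(math.exp(s - m_updated) for s in block)
        m = m_updated
    return m, l

scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25, 3.0]
m, l = online_softmax_stats(scores, block_size=3)
reference = sum(math.exp(s - max(scores)) for s in scores)
print(m, math.isclose(l, reference))   # 3.5 True
```

The block size changes how the work is split, not the result: any `block_size` yields the same denominator up to rounding.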
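For each query block, we stream over key/value blocks. When a new block of scores s_j arrives with block max m_new, the old partial sums were normalised against a stale max, so we rescale them and fold the new block in: the updated max is m′ = max(m, m_new), and the updated denominator is ℓ′ = e^(m − m′)·ℓ + Σ_j e^(s_j − m′).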
The partial output accumulator O is rescaled by the same factor. After the last block, O/ℓ equals the exact softmax-weighted sum: mathematically the same attention, matching the standard result up to floating-point rounding, without ever holding more than one tile of S at a time.
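Putting the tiling and the running statistics together, the forward pass might be sketched in NumPy as follows. This is a readable reference under stated assumptions (single head, no masking, 1/sqrt(d) scaling, hypothetical names `tiled_attention` and `reference_attention`, arbitrary block sizes), not the actual fused GPU kernel; it only demonstrates that the blockwise update reproduces standard attention to within floating-point tolerance.

```python
import numpy as np

def tiled_attention(Q, K, V, block_q=16, block_k=16):
    """Attention computed tile by tile with an online softmax."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((N, V.shape[1]))
    for qs in range(0, N, block_q):
        Qi = Q[qs:qs + block_q]
        rows = Qi.shape[0]
        m = np.full(rows, -np.inf)          # running max per query row
        l = np.zeros(rows)                  # running denominator per query row
        O = np.zeros((rows, V.shape[1]))    # unnormalised output accumulator
        for ks in range(0, K.shape[0], block_k):
            S = (Qi @ K[ks:ks + block_k].T) * scale   # one tile of scores
            m_updated = np.maximum(m, S.max(axis=1))
            alpha = np.exp(m - m_updated)             # rescale factor for old sums
            P = np.exp(S - m_updated[:, None])        # unnormalised tile weights
            l = alpha * l + P.sum(axis=1)
            O = alpha[:, None] * O + P @ V[ks:ks + block_k]
            m = m_updated
        out[qs:qs + block_q] = O / l[:, None]         # normalise once, at the end
    return out

def reference_attention(Q, K, V):
    """Standard attention that materialises the full N x N matrices."""
    S = (Q @ K.T) / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((100, 32)) for _ in range(3))
print(np.allclose(tiled_attention(Q, K, V), reference_attention(Q, K, V)))  # True
```

The largest intermediate inside the loop is one `block_q` × `block_k` tile, regardless of N.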
One query block sweeps left to right over key/value tiles. Each tile is scored, softmaxed, and discarded; only running statistics persist.
| | Standard attention | Flash Attention |
|---|---|---|
| Output | exact | exact (same values) |
| FLOPs | O(N²d) | O(N²d), slightly more (recompute) |
| Extra memory | O(N²) | O(N) |
| HBM traffic | O(N² + Nd) | O(N²d² / M) with M the SRAM size, far smaller in practice |
| Wall-clock | baseline | 2–4× faster, longer N feasible |
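
To get a feel for the HBM traffic row, the two leading-order expressions can be plugged into a few lines of Python. This is a rough sketch only: constant factors hidden by the big-O notation are dropped, and the SRAM budget `M` (about 100 KB of fp16 elements per compute unit) and head dimension `d = 64` are assumed values chosen for illustration, not figures from the text.

```python
def hbm_accesses_standard(n, d):
    """Leading-order element accesses for standard attention: N*d + N^2."""
    return n * d + n * n

def hbm_accesses_flash(n, d, sram_elements):
    """Leading-order element accesses for Flash Attention: N^2 * d^2 / M."""
    return n * n * d * d / sram_elements

n, d = 8192, 64
M = 100 * 1024 // 2    # assumed: ~100 KB of SRAM, 2 bytes per fp16 element
ratio = hbm_accesses_standard(n, d) / hbm_accesses_flash(n, d, M)
print(f"standard / flash ~ {ratio:.1f}x")   # standard / flash ~ 12.6x
```

The ratio is roughly M / d², so the saving comes from d² being much smaller than the on-chip budget; it shrinks as the head dimension grows.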
The backward pass does not store P either: it recomputes the tiles from Q and K during backprop, using the saved (m, ℓ) statistics. That is the tradeoff in its purest form — Flash Attention deliberately spends extra FLOPs to avoid memory traffic, the exact inverse of the usual instinct, and it wins because FLOPs are the cheap currency on this hardware.
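The recomputation step can be illustrated in NumPy: given Q, K and the saved per-row statistics, any tile of P can be rebuilt on demand. This sketch covers only that step, not the gradient computation itself; the helper name `recompute_tile` and the 1/sqrt(d) scaling are assumptions, and the statistics are computed in one shot here for brevity rather than by a tiled forward pass.

```python
import numpy as np

def recompute_tile(Q, K, m, l, q_rows, k_cols):
    """Rebuild one tile of P = softmax(Q K^T / sqrt(d)) from saved row statistics."""
    S = (Q[q_rows] @ K[k_cols].T) / np.sqrt(Q.shape[1])
    return np.exp(S - m[q_rows, None]) / l[q_rows, None]

rng = np.random.default_rng(0)
Q, K = rng.standard_normal((64, 16)), rng.standard_normal((64, 16))

# Statistics a forward pass would save: one max and one denominator per row.
S_full = (Q @ K.T) / np.sqrt(16)
m = S_full.max(axis=1)
l = np.exp(S_full - m[:, None]).sum(axis=1)

# What standard attention would have stored: the full N x N matrix.
P_full = np.exp(S_full - m[:, None]) / l[:, None]

tile = recompute_tile(Q, K, m, l, slice(16, 32), slice(48, 64))
print(np.allclose(tile, P_full[16:32, 48:64]))   # True
```

Storing `m` and `l` costs 2N numbers per head, against N² for `P_full`; the price is one extra tile matmul per recomputed tile.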
Linformer, Performer and friends attacked O(N²) FLOPs by approximating attention, and paid in quality. Flash Attention showed the quality tax was never necessary: the quadratic compute was affordable all along, once the memory traffic was cut. It is now the default attention kernel in most mainstream training and inference stacks, and it is what makes the long contexts assumed by causal attention and the KV cache economically real. The lesson generalises: on modern accelerators, count bytes moved before you count operations.