Causal Attention

One triangular mask, and why all of GPT falls out of it

01 · First principles: The factorisation sets the rule

An autoregressive language model commits to one specific factorisation of the probability of a sequence:

$$p(x_1, \dots, x_N) = \prod_{t=1}^{N} p(x_t \mid x_{<t})$$
each token conditioned only on its predecessors

This is exact (chain rule, no approximation), and it imposes one hard constraint on the network: when computing the distribution for position t, no information from positions ≥ t may leak in. Plain attention violates this constraint maximally — every query attends to every key, including future ones. Train that way and the model learns the world's most useless skill: predicting a token it can already see. Training loss collapses; the generator is garbage, because at generation time the future genuinely does not exist.
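The factorisation can be checked numerically on a toy model. The sketch below is illustrative only: `next_token_probs` is a hypothetical stand-in for any network that maps a prefix to a distribution over the vocabulary, and the three-token vocabulary is an assumption. It scores a sequence by summing conditional log-probabilities, each computed from predecessors only, and confirms that the resulting joint sums to 1 over all sequences of a fixed length.

```python
import itertools
import math

VOCAB = ["a", "b", "c"]  # assumed toy vocabulary

def next_token_probs(prefix):
    """Hypothetical toy model: maps a prefix to a distribution over VOCAB.
    It simply favours repeating the last token of the prefix."""
    weights = [1.0] * len(VOCAB)
    if prefix:
        weights[VOCAB.index(prefix[-1])] += 2.0
    total = sum(weights)
    return [w / total for w in weights]

def sequence_log_prob(tokens):
    """log p(x_1..x_N) = sum over t of log p(x_t | x_<t)."""
    logp = 0.0
    for t, tok in enumerate(tokens):
        probs = next_token_probs(tokens[:t])  # sees only the predecessors
        logp += math.log(probs[VOCAB.index(tok)])
    return logp

# Because every conditional is a proper distribution, the joint is too:
# the probabilities of all length-3 sequences add up to 1.
total = sum(
    math.exp(sequence_log_prob(seq))
    for seq in itertools.product(VOCAB, repeat=3)
)
print(round(total, 12))
```

If `next_token_probs` were allowed to look at `tokens[t]` or anything after it, the per-step numbers would no longer be conditionals of this factorisation, which is the leak the rest of the note is about.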

02 · The fix: Add −∞ above the diagonal

The constraint is enforced with arithmetic rather than control flow. Before the softmax, add a mask M to the score matrix: Mij = 0 where j ≤ i, and −∞ where j > i. Since exp(−∞) = 0, every masked entry receives exactly zero attention weight, so each position's output is a mixture of its own value and the values before it. Future tokens contribute nothing, not even gradient.

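$$A = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}} + M\right), \qquad M_{ij} = \begin{cases} 0 & \text{if } j \le i \\ -\infty & \text{if } j > i \end{cases}$$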
lower-triangular additive mask
[Figure: the mask as a grid, keys j across and queries i down. Cells with j ≤ i (blue) are allowed; cells with j > i hold −∞ and are zeroed by the softmax. Row i is what token i may see: itself and everything before it.]

The causal mask for N = 6. Each row is one token's field of view; the staircase is the arrow of time.

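import torch
N, d = 6, 16                                   # illustrative sizes
q, k = torch.randn(N, d), torch.randn(N, d)
scores = q @ k.T / d**0.5                      # (N, N)
mask = torch.triu(torch.ones(N, N), diagonal=1).bool()   # True where j > i
scores = scores.masked_fill(mask, float('-inf'))
attn = scores.softmax(dim=-1)   # future weights are exactly 0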
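One way to convince yourself the mask really blocks leakage is a perturbation test: change the later tokens and check that earlier outputs do not move. The following is a minimal sketch, written in NumPy rather than PyTorch to keep it dependency-light; the function name `causal_attention`, the random projection matrices and the sizes are all assumptions made for illustration.

```python
import numpy as np

def causal_attention(x, Wq, Wk, Wv):
    """Single-head masked self-attention over x of shape (N, d)."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    N, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    future = np.triu(np.ones((N, N), dtype=bool), k=1)   # True where j > i
    scores = np.where(future, -np.inf, scores)
    scores -= scores.max(axis=-1, keepdims=True)          # stable softmax
    weights = np.exp(scores)                              # exp(-inf) == 0
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
N, d = 6, 8
x = rng.normal(size=(N, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))

out, weights = causal_attention(x, Wq, Wk, Wv)

# Overwrite the last two tokens and rerun: rows 0..3 should not change,
# because those rows put zero weight on positions 4 and 5.
x_changed = x.copy()
x_changed[4:] = rng.normal(size=(2, d))
out_changed, _ = causal_attention(x_changed, Wq, Wk, Wv)

past_unchanged = np.allclose(out[:4], out_changed[:4])
max_future_weight = np.triu(weights, k=1).max()
print(past_unchanged, max_future_weight)
```

Dropping the `np.where` line turns this into plain attention, and the same test then reports that the early rows do change.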

03 · The payoff: N training examples in one forward pass

Here is why this mask, and not some sequential scheme, is the foundation of the GPT era. An RNN must process tokens one after another, because the hidden state at step t is a function of the state at step t − 1. The masked transformer has no such dependency: all N positions are computed simultaneously, and the mask guarantees that position t's output is identical to what it would be if the future tokens were not there. One forward pass over one sequence therefore yields N honest next-token predictions, each conditioned exactly as it would be at generation time: N supervised examples for the price of one parallel forward pass (this is teacher forcing, made legal by the mask). Training throughput on parallel hardware is the entire ballgame at scale, and this is the trick that buys it; the Transformer vs RNN vs S4 note ranks the alternatives on exactly this axis.
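The claim that one masked pass equals N separate prefix-by-prefix predictions can be tested directly. The sketch below assumes a deliberately tiny, hypothetical one-layer language model (`model`, with random weights `E`, `Wq`, `Wk`, `Wv`, `Wout`); it is not any real architecture. It computes the per-position next-token loss once in parallel, then again by rerunning the model on every prefix and keeping only the last row, and compares the two.

```python
import numpy as np

rng = np.random.default_rng(0)
V, d = 10, 8                                  # toy vocabulary size and width
E = rng.normal(size=(V, d))                   # token embeddings
Wq, Wk, Wv = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
Wout = rng.normal(size=(d, V))                # projection to vocabulary logits

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def model(tokens):
    """Toy one-layer causal LM: next-token log-probs, shape (len(tokens), V)."""
    x = E[tokens]
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    n = len(tokens)
    scores = q @ k.T / np.sqrt(d)
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    attn = np.exp(log_softmax(np.where(future, -np.inf, scores)))
    return log_softmax(attn @ v @ Wout)

tokens = np.array([3, 1, 4, 1, 5, 9, 2])
inputs, targets = tokens[:-1], tokens[1:]     # position t predicts token t+1

# Parallel: one pass over the sequence, one loss term per position.
logp = model(inputs)
parallel_nll = -logp[np.arange(len(targets)), targets]

# Sequential reference: rerun on each prefix, keep only the final position.
sequential_nll = np.array([
    -model(inputs[: t + 1])[-1, targets[t]] for t in range(len(targets))
])

print(np.allclose(parallel_nll, sequential_nll))
```

The sequential version calls the model once per position; the parallel version gets the same loss terms from a single call, which is the throughput argument in miniature.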

04 · The consequence: The KV cache is the mask at inference

Causality has a second gift. Because token t never looks rightward, the keys and values of positions 1…t are final the moment they are computed — no future token will ever revise them. So at generation time we store them (the KV cache) and, for each new token, compute one query against the cached past instead of re-running the whole prefix:

  1. Prefill: run the prompt once in parallel; cache every layer's K and V.
  2. Decode: for each new token, compute q, k and v, append the new k, v to the cache, then attend over the cache (which now includes the token itself).
  3. Cost per token: O(N) attention reads, not O(N²) recompute — but the cache grows linearly, and at long context the cache, not the weights, dominates memory.
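The prefill and decode steps above can be sketched for a single attention head as follows. This is a simplified illustration, not a production decoder: the names `decode_step` and `full_causal_attention` are hypothetical, the weights are random, and random vectors stand in for embedded tokens (a real generator would sample each new token from the previous step's output). The new k and v are appended before attending so that the token can see itself, matching the j ≤ i rule.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
Wq, Wk, Wv = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def full_causal_attention(x):
    """Reference: recompute masked attention over the whole sequence."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    n = len(x)
    scores = q @ k.T / np.sqrt(d)
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    return softmax(np.where(future, -np.inf, scores)) @ v

def decode_step(x_new, k_cache, v_cache):
    """One decode step: a single query against the cached keys and values."""
    k_cache = np.vstack([k_cache, x_new @ Wk])    # append the new k ...
    v_cache = np.vstack([v_cache, x_new @ Wv])    # ... and the new v
    q = x_new @ Wq                                 # shape (d,)
    weights = softmax(k_cache @ q / np.sqrt(d))    # one score per cached key
    return weights @ v_cache, k_cache, v_cache

prompt = rng.normal(size=(4, d))       # stand-in for embedded prompt tokens
new_tokens = rng.normal(size=(3, d))   # stand-in for embedded generated tokens

# 1. Prefill: process the prompt in parallel and cache its K and V.
k_cache, v_cache = prompt @ Wk, prompt @ Wv

# 2. Decode: one query per new token; the cache grows by one row per step.
outputs = []
for x_new in new_tokens:
    out, k_cache, v_cache = decode_step(x_new, k_cache, v_cache)
    outputs.append(out)

# The cached path should reproduce the last rows of a full masked recompute.
reference = full_causal_attention(np.vstack([prompt, new_tokens]))[len(prompt):]
print(np.allclose(np.array(outputs), reference), k_cache.shape)
```

No mask appears in `decode_step`: the cache only ever contains the past and the current token, so the triangular structure is enforced by what has been stored rather than by −∞ entries.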

This is why the mask is not just a training-time correctness device. Its triangular structure is the precise reason incremental decoding is possible at all. Bidirectional models like BERT have no equivalent, since every new token would invalidate every old representation. The memory and bandwidth bill the cache runs up is what GQA and paged caches exist to pay down, what FlashAttention trims on the attention side, and what speculative decoding works around.

Mental Model