35 Efficient attention
Where we are. This opens Part VI (efficiency and deployment). In Ch. 18 we drew the map of attention variants (FlashAttention, linear, GQA, Mamba… and what won). Here we don’t rehash the catalog; we look under the hood at the mechanics: what makes attention expensive, why the real bottleneck is moving data rather than computing on it, and how, exactly, FlashAttention and linear attention solve it.
35.1 The idea in one sentence
Attention costs O(n²) and, on a modern GPU, its bottleneck isn’t the arithmetic but moving the n×n matrix between the GPU’s memory tiers; efficient techniques win mostly by cutting that traffic (FlashAttention) or by never building the matrix at all (linear attention).
35.2 Key concepts and their role in the transformer
Before we dig in, let’s define this chapter’s terms and what each one is for inside a transformer. It’s the “concept map” so you don’t get lost:
- O(n²) cost (“quadratic”). Definition: a growth rate in which doubling the input quadruples the cost. In the transformer: it’s the cost of attention, which compares every token with every other; it sets the context limit, which is why models “struggle” with very long sequences.
- Arithmetic intensity. Definition: how many operations you do per byte you bring from memory. In the transformer: attention has low intensity (it moves a lot of data for little compute), and that’s why its bottleneck is memory, not compute.
- Memory-bound vs compute-bound. Definition: whether the time is set by moving data or by doing operations. In the transformer: attention is memory-bound → optimizing it means moving less, not computing faster.
- HBM and SRAM. Definition: a GPU’s two memory tiers: HBM, large and slow, and SRAM, tiny and ~10× faster. In the transformer: weights and activations live in HBM, but the arithmetic runs on small pieces staged in SRAM; the real cost is the traffic between the two.
- FlashAttention. Definition: an algorithm that computes exact attention without ever writing the n×n matrix to slow memory. In the transformer: it’s today’s default attention kernel in training and inference —same quality, much faster.
- Online softmax. Definition: computing the softmax block by block, with a running maximum and a running sum, without needing the whole row. In the transformer: it’s what makes it possible to tile attention without altering the result.
- Linear attention. Definition: replacing the softmax with a kernel that lets you reorder the computation and bring the cost down to O(n). In the transformer: it swaps attention for a fixed-size state (RNN-like) → cheap at very long context, at the cost of recall.
- KV cache. Definition: the memory where already-computed keys and values are stored so they aren’t recomputed during generation (Ch. 20). In the transformer: it’s the memory cost of inference; many techniques (GQA, PagedAttention) attack exactly this.
With these concepts in hand, let’s unpack them one by one.
35.3 Anatomy of the cost: where the O(n²) comes from
Before “fixing” attention it helps to understand exactly what’s expensive. Recall (Ch. 4) that attention does two chained matrix products:
- The scores \(QK^\top\): each of the n tokens is compared with the other n, using vectors of dimension d. The result is an n×n matrix.
- The mixing \(AV\): that n×n weight matrix multiplies the n values of dimension d, giving the output.
Each product costs on the order of n²·d operations, so the compute is O(n²·d). But there’s a second cost, and it’s the one that really blows up: in a standard implementation, applying the softmax requires first writing out the full n×n matrix of scores. That’s O(n²) in memory and, this is the key point, it doesn’t depend on d: even with small heads, the matrix grows with the square of the length.
The notation O(·) describes how a cost grows with the size of the input, ignoring constants. O(n²) (“quadratic”) means that doubling the input quadruples the cost (2² = 4). That’s why, in attention, going from 2,000 to 4,000 tokens doesn’t double the work: it quadruples it —and multiplies the matrix you have to store by 4.
The nuance almost no one explains: the rest of the transformer is only linear in n. The Q, K, V projections and the FFN (Ch. 6) process each token separately, so their cost grows proportionally to n (not n²). The practical consequence is clear-cut:
- With short sequences, the linear term (the FFN) dominates the time; attention is barely noticeable.
- The quadratic attention term only overtakes the rest when n gets large relative to d.
That’s why “efficient attention” is a long-context problem, not a short-sequence one. When you read that a model “struggles” with 100,000 tokens, the culprit is this O(n²).
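To put numbers on that crossover, here is a back-of-the-envelope FLOP count for one layer. It is a sketch under assumptions that are not in the text: standard multi-head attention (no GQA), an FFN with hidden size 4d, 2 FLOPs per multiply-add, and softmax/normalization costs ignored. Under those assumptions the linear part costs 24·n·d² and the attention mixing 4·n²·d, so attention only dominates once n > 6d.

```python
# Back-of-the-envelope FLOPs for one transformer layer (forward pass).
# Assumptions: standard multi-head attention, FFN hidden size 4*d,
# 2 FLOPs per multiply-add, softmax and layer-norm costs ignored.

def linear_flops(n: int, d: int) -> int:
    """Per-token work that grows linearly in n."""
    projections = 2 * n * 4 * d * d   # Q, K, V and output projections (four d x d matmuls)
    ffn = 2 * n * 2 * d * (4 * d)     # two matmuls: d -> 4d and 4d -> d
    return projections + ffn          # = 24 * n * d^2

def attention_flops(n: int, d: int) -> int:
    """Token-to-token work that grows quadratically in n."""
    scores = 2 * n * n * d            # Q K^T
    mixing = 2 * n * n * d            # A V
    return scores + mixing            # = 4 * n^2 * d

def crossover_length(d: int) -> int:
    """Smallest n at which the quadratic term exceeds the linear one (n > 6d)."""
    n = 1
    while attention_flops(n, d) <= linear_flops(n, d):
        n += 1
    return n

d = 4096
for n in (2_048, 8_192, 32_768, 131_072):
    ratio = attention_flops(n, d) / linear_flops(n, d)
    print(f"n={n:>7,}: attention/linear = {ratio:.2f}")
```

Under these assumptions the ratio is simply n/(6d): with d = 4096 it is about 0.08 at 2K tokens, crosses 1 near 24K tokens, and is above 5 at 128K.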
35.4 The key that changes everything: attention is memory-bound
Here’s the chapter’s conceptual twist. You’d assume that, with O(n²·d) operations, the limit is the GPU’s compute speed. It isn’t. The limit is moving the n×n matrix between the different kinds of memory. The key distinction is between being compute-bound (the time is set by the operations; the GPU is maxed out computing) and being memory-bound (the time is set by data movement; the GPU sits idle waiting for it). Which regime you’re in is decided by the arithmetic intensity, the number of operations per byte moved. The softmax, masking and dropout steps of attention do only a handful of operations per element of the huge n×n matrix, yet every element has to be read from and written back to memory, so standard attention ends up memory-bound.
And a GPU’s “memory” isn’t a single thing; it’s a hierarchy with a brutal trade-off between size and speed (figures from the FlashAttention paper, A100 GPU):
| Memory | Size | Speed | Role |
|---|---|---|---|
| HBM (high-bandwidth) | ~40-80 GB | ~1.5-2.0 TB/s | large but “slow”; the model lives here |
| SRAM (on-chip) | ~20 MB | ~19 TB/s (~10× more) | tiny but blazing fast; where computing happens |
The problem with standard attention is that it treats the n×n matrix like any other data: it writes the whole thing to HBM (slow), rereads it for the softmax, writes the probabilities, and rereads them again for the \(AV\) product. These are repeated round trips of an O(n²) object to slow memory, and that —not the compute— is what eats the clock. From this comes the premise that changed everything: if you cut the HBM traffic, you cut the time —even if you do more operations.
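A quick tally makes the imbalance concrete. The sketch below counts only the HBM round trips just listed, for one attention head, assuming 2-byte (FP16/BF16) elements and ignoring masking, dropout and cache effects; the function name is ours, not from any library.

```python
# HBM bytes moved by *standard* attention for one head, counting only
# the round trips described in the text. Assumption: 2-byte elements.

def standard_attention_traffic(n: int, d: int, bytes_per_elem: int = 2):
    """Return (bytes for Q/K/V/O, bytes for the n x n round trips)."""
    io = 4 * n * d * bytes_per_elem          # read Q, K, V; write O
    nxn = 4 * n * n * bytes_per_elem         # write S, read S, write P, read P
    return io, nxn

MiB = 2 ** 20
for n in (1_024, 4_096, 16_384):
    io, nxn = standard_attention_traffic(n, d=64)
    print(f"n={n:>6}: Q/K/V/O {io / MiB:7.1f} MiB | n x n traffic {nxn / MiB:8.1f} MiB"
          f" | ratio {nxn / io:.0f}x")
```

The ratio is exactly n/d: at n = 4,096 with d = 64, the n×n round trips move 128 MiB against 2 MiB of actual inputs and outputs, per head and per layer.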
🧩 Analogy — the chef waiting on ingredients. Picture a lightning-fast cook (the compute units) who stands idle most of the time because every ingredient has to be fetched from a distant warehouse (HBM) instead of grabbed from the counter right next to the stove (SRAM). The bottleneck is the fetching, not the chopping. Speed up the fetching (make fewer trips, carrying more each time) and the food comes out sooner, even if the cook “works” a little more.
35.5 FlashAttention: computing the same thing while moving far less
FlashAttention (Dao et al. 2022) computes exact attention (mathematically identical to the standard computation; results differ only by floating-point rounding, because the sums are accumulated in a different order) but without ever writing the full n×n matrix to slow memory: it doesn’t change what is computed, only how the data moves. It rests on three ideas worth understanding one at a time.
1. Tiling (breaking into blocks). Instead of operating on whole matrices, it splits Q, K and V into small blocks that fit in SRAM. It loads a block of Q, streams the blocks of K and V through SRAM, computes that chunk’s attention on-chip, and accumulates the output. The payoff? The full n×n matrix is never formed in HBM: all that lives there are the inputs/outputs (of size O(n·d)) and a few statistics. The expensive “fetching” disappears.
2. Online softmax (the crucial trick). Here a problem arises: the softmax of a row needs, in principle, the whole row —its maximum value (to avoid overflow when exponentiating) and the sum of all the exponentials (to normalize). But if we tile, we never have the whole row at once. The solution (Milakov and Gimelshein 2018) is to compute it in a single pass over blocks, keeping two numbers that update on the fly:
- a running maximum \(m\) (the largest value seen so far), and
- a running sum \(\ell\) (the accumulated sum of exponentials).
When a new block arrives with a value larger than anything seen, what’s already accumulated is rescaled by a correction factor \(\exp(m_{\text{old}} - m_{\text{new}})\) before adding the new contribution. The intuition: every time you discover a bigger number, you retroactively shrink the weight you’d given the earlier ones, so that the final normalization comes out exact, as if you’d had the whole row from the start.
🧩 Analogy — adding up page by page. It’s like summing a huge column of figures one page at a time, jotting on a sticky note only a running total and a running maximum, instead of needing the whole column spread out on the table. If a page brings a number bigger than all the previous ones, you adjust the total in proportion (the rescaling correction). At the end, the total is the same as if you’d had it all in view.
3. Recomputation in the backward pass. Training needs the backward pass (Ch. 11), which would normally need the n×n matrix again. Instead of storing it (hugely expensive in memory), FlashAttention keeps only the output and the softmax statistics (\(m\) and \(\ell\) per row) and, when the backward pass needs the attention weights, recomputes each block on-chip from Q, K and those statistics. It’s the same trade-off as gradient checkpointing (Ch. 25): pay a bit more compute to save a lot of memory.
🧩 Analogy — redoing the draft. Instead of keeping all the intermediate spreadsheets just in case, you throw them out and redo them from two jotted numbers when you actually need them. Recomputing is cheaper than storing.
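The online softmax from idea 2 fits in a few lines. This is a minimal pure-Python sketch for a single row of scores split into blocks (the block size is an arbitrary illustration parameter). It checks the result against the usual two-pass softmax and shows that the running maximum keeps `exp` from overflowing on large scores. FlashAttention itself never materializes the normalized row; it uses \(m\) and \(\ell\) to rescale the accumulated output instead.

```python
import math

def online_softmax_stats(row, block_size):
    """One pass over blocks: running max m and running sum l = sum(exp(x - m))."""
    m, l = -math.inf, 0.0
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale what was summed under the old max by exp(m_old - m_new),
        # then add this block's contributions under the new max.
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in block)
        m = m_new
    return m, l

def online_softmax(row, block_size):
    m, l = online_softmax_stats(row, block_size)
    return [math.exp(x - m) / l for x in row]

def two_pass_softmax(row):
    """Reference: needs the whole row to find the max before exponentiating."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

row = [1.0, 3.0, -2.0, 5.0, 0.5, 4.0, 2.0]
p_online = online_softmax(row, block_size=3)
p_ref = two_pass_softmax(row)
print(max(abs(a - b) for a, b in zip(p_online, p_ref)))   # rounding-level difference

big = [1000.0, 1001.0, 999.0, 1002.0]   # math.exp(1000.0) alone would overflow
print(online_softmax(big, block_size=2))
```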
FlashAttention gives exact attention (not approximate), with memory that grows as O(n) instead of O(n²), and it runs faster because it makes fewer trips to HBM: about 15% faster end-to-end training on BERT-large (sequence length 512) than the MLPerf 1.1 training record, 3× on GPT-2 (1K tokens) and 2.4× on Long-Range Arena (1K-4K tokens); it also made it possible to train contexts that were previously out of reach. FlashAttention-2 (Dao 2023) is roughly 2× faster again (better work partitioning and fewer non-matmul operations); FlashAttention-3 (Shah et al. 2024) exploits the asynchrony and FP8 support of Hopper GPUs (~1.5-2× over FA2).
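Putting tiling and the online softmax together gives a toy forward pass. The sketch below is pure Python for one head; `block_q` and `block_kv` are illustrative parameters, and there is no real memory hierarchy here, only the bookkeeping: at any moment just one tile of scores exists, and each query row carries its running max `m`, running sum `l` and an unnormalized output accumulator. The loop order (query blocks outside, key/value blocks inside) follows FlashAttention-2; the original FlashAttention nests the loops the other way round.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def naive_attention(Q, K, V):
    """Reference: build each full row of scores, softmax it, then mix the values."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [dot(q, k) * scale for k in K]            # a full row of the n x n matrix
        m = max(s)
        w = [math.exp(x - m) for x in s]
        total = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, V)) / total
                    for c in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, block_q=2, block_kv=3):
    """Exact attention computed tile by tile with an online softmax."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    dv = len(V[0])
    out = []
    for qs in range(0, len(Q), block_q):              # outer loop: blocks of queries
        q_blk = Q[qs:qs + block_q]
        m = [-math.inf] * len(q_blk)                  # running max per query row
        l = [0.0] * len(q_blk)                        # running sum of exponentials
        acc = [[0.0] * dv for _ in q_blk]             # unnormalized output rows
        for ks in range(0, len(K), block_kv):         # inner loop: blocks of keys/values
            k_blk, v_blk = K[ks:ks + block_kv], V[ks:ks + block_kv]
            for i, q in enumerate(q_blk):
                s = [dot(q, k) * scale for k in k_blk]   # one row of the current tile
                m_new = max(m[i], max(s))
                corr = math.exp(m[i] - m_new)            # rescales earlier blocks
                p = [math.exp(x - m_new) for x in s]
                l[i] = l[i] * corr + sum(p)
                acc[i] = [a * corr + sum(pj * v[c] for pj, v in zip(p, v_blk))
                          for c, a in enumerate(acc[i])]
                m[i] = m_new
        out.extend([a / l[i] for a in acc[i]] for i in range(len(q_blk)))
    return out

random.seed(0)
n, d = 7, 4
Q, K, V = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(n)] for _ in range(3))
ref = naive_attention(Q, K, V)
tiled = tiled_attention(Q, K, V)
max_err = max(abs(a - b) for ra, rb in zip(ref, tiled) for a, b in zip(ra, rb))
print(f"max |naive - tiled| = {max_err:.2e}")   # rounding-level, not necessarily 0
```

The design choice to keep the accumulator unnormalized until the very end is what lets every earlier block be corrected with a single multiply by `corr`.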
35.6 Linear attention: not moving the matrix faster, but not building it
FlashAttention attacks how the matrix moves. The other big family attacks something more radical: never forming the n×n matrix —dropping the cost to O(n) by replacing the softmax with a function that lets you reorder the multiplications. To see how, we first have to see why normal attention is forced to form that matrix.
The softmax computes \(\exp(q_i\cdot k_j)\) for every pair (i, j) → it forces you to build the n×n table before you can multiply by V. The idea of linear attention is to replace that \(\exp(q_i\cdot k_j)\) with a kernel that factorizes as \(\varphi(q_i)\cdot \varphi(k_j)\) —where \(\varphi\) (“phi”) is a feature map applied to each vector. If similarity factorizes that way, the associativity of matrix multiplication lets you reorder the computation:
\[ \big(\varphi(Q)\,\varphi(K)^\top\big)\,V \;=\; \varphi(Q)\,\big(\varphi(K)^\top V\big) \]
Let’s go through the why term by term:
- On the left, you first form \(\varphi(Q)\varphi(K)^\top\), which is the n×n matrix —and you’re back to O(n²).
- On the right, you first compute \(\varphi(K)^\top V\): the product of a d×n matrix by an n×d one gives a small d×d matrix, at cost O(n·d²). Then you multiply it by \(\varphi(Q)\), again O(n·d²).
- Both sides give the same result (it’s the same computation reordered), but the one on the right never builds the n×n object. That reassociation is the whole trick, and it makes the cost linear in n.
🧩 Analogy — changing the order of multiplication. It’s choosing between (A·B)·C and A·(B·C): the result is identical, but one order builds a giant intermediate table and the other only small tables. Linear attention always picks the second order.
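The reassociation can be checked numerically. The sketch assumes the feature map \(\varphi(x) = \mathrm{elu}(x) + 1\) used by Katharopoulos et al. (2020), and it also carries the row normalization that the equation above leaves implicit: the softmax denominator \(\sum_j \exp(q_i\cdot k_j)\) becomes \(\varphi(q_i)\cdot\sum_j \varphi(k_j)\), which reassociates the same way. Both orders give the same output; only the left one ever holds a row of the n×n matrix.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def phi(x):
    """Feature map elu(x) + 1: x + 1 for x > 0, exp(x) otherwise (always positive)."""
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]

def linear_attn_left(Q, K, V):
    """(phi(Q) phi(K)^T) V: builds every row of the n x n matrix -> O(n^2 d)."""
    fK = [phi(k) for k in K]
    out = []
    for q in Q:
        fq = phi(q)
        w = [dot(fq, fk) for fk in fK]                 # one row of the n x n matrix
        z = sum(w)                                     # row normalizer
        out.append([sum(wj * v[c] for wj, v in zip(w, V)) / z
                    for c in range(len(V[0]))])
    return out

def linear_attn_right(Q, K, V):
    """phi(Q) (phi(K)^T V): builds a small d x dv matrix once -> O(n d dv)."""
    fK = [phi(k) for k in K]
    d, dv = len(fK[0]), len(V[0])
    S = [[sum(fk[a] * v[c] for fk, v in zip(fK, V)) for c in range(dv)]
         for a in range(d)]                            # phi(K)^T V, shape d x dv
    z = [sum(fk[a] for fk in fK) for a in range(d)]    # phi(K)^T 1, shape d
    out = []
    for q in Q:
        fq = phi(q)
        den = dot(fq, z)
        out.append([sum(fq[a] * S[a][c] for a in range(d)) / den
                    for c in range(dv)])
    return out

random.seed(1)
n, d = 6, 3
Q, K, V = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(n)] for _ in range(3))
left, right = linear_attn_left(Q, K, V), linear_attn_right(Q, K, V)
print(max(abs(a - b) for ra, rb in zip(left, right) for a, b in zip(ra, rb)))
```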
This has a lovely reading: in causal (autoregressive) mode, \(\varphi(K)^\top V\) can be kept as a fixed-size state \(S_t=\sum_{j\le t} \varphi(k_j)v_j^\top\) that updates token by token, \(S_t = S_{t-1} + \varphi(k_t)v_t^\top\). In other words, linear attention behaves like an RNN (a recurrent network) with a constant-size state, which gives O(1) memory per step when generating; Linear Transformers (Katharopoulos et al. 2020) report up to 4000× faster autoregressive prediction on very long sequences. Another branch, Performer (Choromanski et al. 2021), doesn’t swap the softmax for another kernel but instead approximates the softmax kernel itself with random features.
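In causal mode the same computation runs as a recurrence. This sketch (again assuming the elu + 1 feature map, and keeping a running normalizer \(z_t = \sum_{j\le t}\varphi(k_j)\) alongside the state) compares the token-by-token update with a masked computation that looks back at every previous key. The recurrent version touches only a d×d state and a d-vector per step, however many tokens have gone by.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def phi(x):
    """Feature map elu(x) + 1 (positive features)."""
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]

def causal_masked(Q, K, V):
    """Reference: token i attends to tokens j <= i, rebuilding the weights each time."""
    out = []
    for i, q in enumerate(Q):
        fq = phi(q)
        w = [dot(fq, phi(k)) for k in K[:i + 1]]
        z = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, V[:i + 1])) / z
                    for c in range(len(V[0]))])
    return out

def causal_rnn(Q, K, V):
    """Recurrent form: a d x dv state S and a d-vector z, updated once per token."""
    d, dv = len(K[0]), len(V[0])
    S = [[0.0] * dv for _ in range(d)]   # S_t = S_{t-1} + phi(k_t) v_t^T
    z = [0.0] * d                        # z_t = z_{t-1} + phi(k_t)
    out = []
    for q, k, v in zip(Q, K, V):
        fk = phi(k)
        for a in range(d):
            z[a] += fk[a]
            for c in range(dv):
                S[a][c] += fk[a] * v[c]
        fq = phi(q)
        den = dot(fq, z)
        out.append([sum(fq[a] * S[a][c] for a in range(d)) / den
                    for c in range(dv)])
    return out

random.seed(2)
n, d = 8, 3
Q, K, V = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(n)] for _ in range(3))
masked, rnn = causal_masked(Q, K, V), causal_rnn(Q, K, V)
print(max(abs(a - b) for ra, rb in zip(masked, rnn) for a, b in zip(ra, rb)))
```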
Here’s the hidden cost: a fixed-size state has to compress the entire past into d×d numbers, whereas full attention keeps every token in the cache. That’s why linear attention suffers a measurable quality gap on recall tasks (copying an exact piece of data, finding the needle in the haystack, in-context retrieval). It’s exactly the verdict of Ch. 18: it did not dethrone full attention. It makes sense only when the O(n²) is genuinely infeasible and the task tolerates that trade-off.
35.7 Other savings, in a quick mechanical sketch
- O(1) memory (Rabe and Staats 2021): the precursor to FlashAttention. It uses the same running-max-and-sum idea, but presents it as a memory result rather than a speed one: it showed that attention doesn’t need O(n²) memory (O(1) in the sequence length for a single query, O(log n) for self-attention; 59× less memory for inference at 16K tokens). FlashAttention added the HBM-traffic awareness that turned it into acceleration.
- PagedAttention (Kwon et al. 2023): attacks the KV cache waste (Ch. 20) by storing it in non-contiguous blocks mapped like an operating system’s virtual memory → almost zero waste and sharing across requests (2-4× throughput). We’ll see it when we get to serving (Ch. 36).
- MQA/GQA (Ch. 18): reduce the number of KV heads, cutting memory and bandwidth at decode time. It’s a bandwidth lever, not a compute one.
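How much MQA/GQA saves is easy to estimate: the cache stores one key and one value vector per layer, per KV head, per token. The model shape below is hypothetical (32 layers, 32 query heads of dimension 128, 2-byte elements, a 32K-token context, batch of 1) and the function is ours; the point is that the cache size, and therefore the bytes every decode step must read, scales linearly with the number of KV heads.

```python
def kv_cache_bytes(layers, kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    """Size of the KV cache: the leading 2 counts one tensor for K and one for V."""
    return 2 * layers * kv_heads * head_dim * seq_len * batch * bytes_per_elem

GiB = 2 ** 30
# Hypothetical model: 32 layers, 32 query heads, head_dim 128, 32K context.
for name, kv_heads in (("MHA", 32), ("GQA (8 KV heads)", 8), ("MQA", 1)):
    size = kv_cache_bytes(layers=32, kv_heads=kv_heads, head_dim=128, seq_len=32_768)
    print(f"{name:>17}: {size / GiB:5.2f} GiB")
```

With these numbers the cache drops from 16 GiB (MHA) to 4 GiB (GQA with 8 KV heads) and 0.5 GiB (MQA), with the number of query heads unchanged.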
35.8 When to use what: the roofline
The roofline model is a mental model for reasoning about performance: it draws a “roof” formed by two limits, the compute one (the GPU’s FLOP/s) and the memory one (bandwidth). A computation “hits” one roof or the other depending on its arithmetic intensity. Knowing which one you hit tells you which optimization will help: if you’re against the memory roof, speeding up the compute won’t help; you have to move less data.
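The roofline itself is one line of arithmetic: attainable throughput is the smaller of peak compute and intensity × bandwidth. The sketch plugs in A100-class figures as assumptions (312 TFLOP/s, NVIDIA’s dense FP16 tensor-core peak, and 1.5 TB/s, the low end of the HBM range in the table above) and two illustrative intensities: a rough guess of about 1 FLOP/byte for an element-wise pass such as the softmax, and the computed intensity of a large square FP16 matmul.

```python
def attainable_flops(intensity, peak_flops, bandwidth):
    """Roofline: throughput is capped by compute or by intensity x bandwidth."""
    return min(peak_flops, intensity * bandwidth)

def matmul_intensity(m, n, k, bytes_per_elem=2):
    """FLOPs per byte for C = A @ B, reading A (m x k), B (k x n) once and writing C once."""
    flops = 2 * m * n * k
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved

PEAK = 312e12       # assumed peak FLOP/s (A100 dense FP16 tensor cores)
BW = 1.5e12         # assumed HBM bandwidth in bytes/s
print(f"ridge point: {PEAK / BW:.0f} FLOP/byte")   # below this you are memory-bound

for label, intensity in (("softmax-like pass (~1 FLOP/byte, rough guess)", 1.0),
                         ("4096^3 FP16 matmul", matmul_intensity(4096, 4096, 4096))):
    perf = attainable_flops(intensity, PEAK, BW)
    print(f"{label}: {intensity:7.1f} FLOP/byte -> {perf / 1e12:6.1f} TFLOP/s "
          f"({100 * perf / PEAK:.1f}% of peak)")
```

Under these assumptions the ridge point sits near 208 FLOP/byte: the matmul (about 1,365 FLOP/byte) is pinned to the compute roof, while the element-wise pass reaches well under 1% of peak, so only moving fewer bytes can speed it up.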
With that, each method attacks a different axis:
- FlashAttention: always use it. It’s exact and works in training and inference, at any length; it only changes the IO pattern. It’s the default kernel, and it wins more the larger n is. It attacks the constant (the traffic) of the quadratic, not its exponent.
- Linear attention: only at very long context, where the O(n²) is fatal and the task can withstand the recall trade-off. It’s the only thing that changes the complexity class (the exponent).
- KV reductions (GQA/MLA, PagedAttention): inference memory and bandwidth. They don’t touch the training FLOPs; they make decoding fit and run fast (the decode step is bound by the cache’s bandwidth).
In one sentence: FlashAttention attacks the constant (IO); linear attention, the exponent; the KV methods, the inference cache. Three different axes of the same problem.
35.9 Bridge to our theme (attention across distance)
Efficient attention is the engineering answer to the same cost we studied in physics —that of attending across distance. And there’s an honest distinction worth making: FlashAttention is content-agnostic, i.e. it computes all the interactions equally, just moving the bytes optimally. Our D_f window derived from γ (Ch. 20) is the content-aware complement: it says which distant KV entries have negligible attention mass and can be dropped —it changes what you compute, not just how you move it. They are orthogonal and composable: FlashAttention makes the window you keep cheaper; D_f decides how wide that window should be. (Honest: it’s a conceptual bridge, not a claim that D_f is a published efficient-attention method.)
tafagent computes the KV budget from γ (Ch. 20): how much cache you really need at the target length. Combine it with this chapter’s logic: FlashAttention makes exact attention cheap, and γ/D_f tells you how much distant context contributes little and you could compress —the “how much” that the exact kernel doesn’t decide for you.
35.10 Summary
- Cost: attention is O(n²·d) in compute and O(n²) in memory (the n×n matrix); the rest of the model is linear in n → attention only rules at long context.
- Memory-bound: the real limit is moving the n×n matrix between HBM (large, slow) and SRAM (tiny, ~10× faster), because of attention’s low arithmetic intensity, not the FLOPs.
- FlashAttention (Dao et al. 2022): tiling + online softmax (running max/sum + rescaling) + recomputation → exact attention, O(n) memory, faster (FA2/FA3).
- Linear attention: drop the softmax + reassociate \(\varphi(Q)(\varphi(K)^\top V)\) → O(n); forms a fixed-state RNN. Honest: quality gap on recall (Ch. 18).
- Others: O(1) memory (precursor), PagedAttention (paged KV cache → Ch. 36), GQA/MLA (bandwidth).
- Roofline: FA attacks the constant; linear, the exponent; KV, the cache.
- Bridge: FA is content-agnostic; our D_f (γ) is the content-aware complement.
Next chapter: another route to efficiency, making the model smaller without losing (much) quality: quantization, distillation and pruning.
35.11 Exercises
- Two costs. Distinguish the compute cost and the memory cost of attention. Which one “blows up” and why doesn’t it depend on d?
- Memory-bound. Define “memory-bound” versus “compute-bound”. Why does attention fall into the first case? What role do HBM and SRAM play?
- Online softmax. Why does the normal softmax need the whole row, and how does the block-wise computation with a running max and sum avoid it?
- Exact vs approximate. Why is FlashAttention not an approximation, unlike linear attention?
- Reassociation. In \(\varphi(Q)(\varphi(K)^\top V)\), what is computed first and why does that make the cost linear in n?
- Roofline. Match each method (FlashAttention / linear / GQA) with what it attacks: the IO constant, the exponent, or the inference cache.