6 Multi-head attention

Where we are. We now know how one attention blends information between words. But a single attention is short-sighted: it can only focus on one type of relationship at a time. This chapter explains why transformers run several attentions in parallel (the heads) and what the model gains from them. It’s a small change to the formula with an enormous consequence for what the model can understand.

6.1 The idea in one sentence

Instead of a single attention, the model runs several at once (the heads), each looking at the sentence with a different criterion, and then combines what each one found.

6.2 Key concepts and their role in the transformer

Before diving into the details, let’s define this chapter’s terms and what each one is for inside a transformer:

  • Head. Definition: a complete, independent attention (the one from Ch. 4) that works on its own d_k-sized portion of the projected vector, with its own tables \(W_i^Q, W_i^K, W_i^V\). In the transformer: it’s one of the parallel attention channels; each head can track a different type of relationship at the same time.
  • d_k (per-head dimension). Definition: the size of the chunk of vector each head receives, d_k = d_model / h. In the transformer: it splits the vector among the h heads; because it’s smaller, it makes many heads cost almost the same as a single large attention.
  • Concatenation (Concat). Definition: gluing the h outputs one after another to reconstruct a vector of the original size. In the transformer: it gathers in one place the notes each head took separately.
  • \(W^O\) (final mix). Definition: a learned table that reprojects the concatenated vector. In the transformer: it combines what each head contributed into a single coherent result to return to the residual stream.
  • Emergent specialization. Definition: each head’s roles (positional, syntactic, rare-word…) aren’t programmed: they arise from training. In the transformer: it explains why heads don’t resemble one another and why they split up the work of reading the sentence.
  • Induction head. Definition: a head that matches prefixes and copies what came after them ([A][B] … [A] → [B]). In the transformer: it’s believed to be the engine behind much of in-context learning, that is, learning from the examples in the prompt itself.
  • Retrieval head. Definition: a head specialized in finding a fact buried in a very long context. In the transformer: a tiny group (<5%) sustains the “needle in a haystack” ability; we return to it in Part II.
  • Head pruning. Definition: removing heads at inference time with little loss. In the transformer: it reveals that the heads are partially redundant and that a specialized few do the heavy lifting.
  • MQA / GQA. Definition: schemes that share keys/values across heads (a single one, or by groups). In the transformer: they cut the KV-cache memory at inference time while keeping almost all the quality.

With these in mind, let’s see why a single attention isn’t enough.

6.3 What it’s for (its role in the transformer)

Why isn’t one attention enough? Because a single one produces a single set of weights: it forces a choice of which type of relationship to highlight. If the model uses that one gaze to track grammatical agreement, it has nothing left to simultaneously track what a pronoun three sentences back refers to. Vaswani et al. (2017) say so themselves: with a single attention head, “averaging inhibits this”; distinct relationships get blended together and lost.

Its job: to give the model several parallel attention channels, so it can attend to different things in different places at the same time (one head for syntax, another for a distant referent, another for a rare keyword). Without multi-head, each layer would have to squeeze all of those relationships into a single set of weights.

🧩 Analogy. Picture a panel of specialists reading the same sentence at once: the grammarian marks who agrees with whom, the reference expert tracks what “she” points to, the jargon watcher keeps an eye on the rare technical term. Each takes their own notes in parallel (the heads) and at the end they’re all gathered and summarized into a single report (that’s the concatenation + the final mix).

6.4 The mechanics

The trick is elegant: instead of one attention over the full vector, the work is split among h heads. Each head projects the input vectors down to a smaller size d_k = d_model / h (its “chunk”) and does its own attention there; then the h outputs are concatenated (glued together) and passed through a final learned table, W^O:

\[ \text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)\,W^O \]

\[ \text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V) \]

What each piece does:

  • \(\text{head}_i\) = a complete, independent attention (the one from Ch. 4), working in its own chunk i of size d_k. Each one learns its own tables \(W_i^Q, W_i^K, W_i^V\), each of size \(d_{\text{model}} \times d_k\), so each one “seeks and advertises” different things. (Here \(Q, K, V\) are the input vectors; in self-attention all three are the same token vectors, and each head projects them with its own tables, just like the \(W_Q, W_K, W_V\) from Ch. 4.)
  • Concat(…) = gluing the h outputs one after another to reconstruct a vector of the original size.
  • \(W^O\) = a learned table that mixes and reprojects those concatenated notes into a single output vector. What it’s for: combining what each head contributed into something coherent to return to the residual stream.
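A quick way to check the “chunk” picture: each head’s table \(W_i^Q\) reads the whole input vector but outputs only d_k numbers. Gluing the h tables side by side therefore gives one d_model × d_model table, and cutting its output into h chunks yields exactly the per-head queries. The sketch below uses NumPy instead of PyTorch, random weights, and GPT-2-small sizes chosen for illustration; it checks the equivalence for \(W^Q\), and the same argument applies to \(W^K\) and \(W^V\).

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, h = 5, 768, 12           # 5 tokens, GPT-2 small sizes (illustrative)
d_k = d_model // h                   # 64

X = rng.standard_normal((n, d_model))                              # one row per token
W_Q_heads = [rng.standard_normal((d_model, d_k)) for _ in range(h)]  # one W_i^Q per head

# View 1: each head projects the FULL input with its own d_model x d_k table.
per_head = [X @ W for W in W_Q_heads]              # h arrays of shape (n, d_k)

# View 2: glue the h tables side by side into one d_model x d_model table,
# project once, then cut the result into h chunks of size d_k.
W_Q_big = np.concatenate(W_Q_heads, axis=1)        # (d_model, h * d_k) = (768, 768)
Q_split = (X @ W_Q_big).reshape(n, h, d_k)         # (n, h, d_k)

# Both views give the same per-head queries.
same = all(np.allclose(per_head[i], Q_split[:, i, :]) for i in range(h))
print(same)  # True
```

This is why implementations usually keep one big projection per Q, K and V and simply reshape: it is one matrix product instead of h small ones.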
Tip✓ Verified — and it’s surprising: it costs no more

You’d think h heads cost h times more. They don’t. Since each head works on a smaller chunk (d_k = d_model/h), the total compute is practically the same as a single large attention; in the words of Vaswani et al. (2017), “the total computational cost is similar to that of single-head attention with full dimensionality.” That’s why multi-head is almost free: more gazes, same cost.
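The “same cost” claim can be checked with a rough multiply-add count. The helper `attention_cost` below is hypothetical and counts only the large matrix products (projections, score matrices, weighted sums, and the W^O mix). Because h · d_k = d_model, every term comes out identical for 1 head and for 12.

```python
def attention_cost(n, d_model, h):
    """Rough multiply-add count for one attention layer (hypothetical helper).

    Counts only the big matrix products; softmax, scaling and biases are ignored.
    """
    d_k = d_model // h
    projections = 3 * n * d_model * (h * d_k)   # X times W^Q, W^K, W^V (all heads at once)
    scores = h * n * n * d_k                    # one n x n score matrix per head
    weighted_sum = h * n * n * d_k              # attention weights times V, per head
    mix = n * (h * d_k) * d_model               # Concat(head_1..head_h) times W^O
    return projections + scores + weighted_sum + mix

n, d_model = 1024, 768
one_head = attention_cost(n, d_model, h=1)
twelve_heads = attention_cost(n, d_model, h=12)
print(one_head, twelve_heads, twelve_heads / one_head)  # 4026531840 4026531840 1.0
```

The count ignores the softmax, which does run h times (h·n² exponentials instead of n²); that is one reason the paper says “similar” rather than “identical”.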

6.5 Typical numbers

d_k = d_model / h is not a free choice: it comes from splitting the vector among the heads. Some real cases:

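Table 6.1: Multi-head configurations

| Model                       | d_model | # of heads (h) | d_k per head |
|-----------------------------|---------|----------------|--------------|
| Original Transformer (2017) | 512     | 8              | 64           |
| GPT-2 small                 | 768     | 12             | 64           |
| LLaMA-2-7B                  | 4096    | 32             | 128          |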

d_k = 64 is a very common per-head size; larger models such as LLaMA-2-7B use 128.
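The split rule is simple enough to write down. The hypothetical helper `head_dim` below assumes the classic design, where h · d_k = d_model, and refuses values of h that don’t divide d_model evenly. It reproduces the d_k column of Table 6.1.

```python
def head_dim(d_model: int, h: int) -> int:
    """d_k = d_model / h; the split only works if h divides d_model evenly."""
    if d_model % h != 0:
        raise ValueError(f"d_model={d_model} is not divisible by h={h}")
    return d_model // h

# Rows of Table 6.1: (d_model, h)
configs = {
    "Original Transformer (2017)": (512, 8),
    "GPT-2 small": (768, 12),
    "LLaMA-2-7B": (4096, 32),
}
for name, (d_model, h) in configs.items():
    print(f"{name}: d_k = {head_dim(d_model, h)}")
```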

6.6 The fascinating part: heads specialize on their own

Here’s the interesting bit. Nobody programs a head for a role: the roles emerge from training. Analyzing already-trained models reveals heads with surprisingly crisp jobs (Voita et al. 2019): positional heads (mostly looking at the previous or next token), syntactic heads (tracking grammatical relationships), and rare-word heads (watching the least frequent tokens).
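To see what a positional head’s attention map looks like, here is a hand-built one (not taken from a trained model). Each row of the weight matrix puts all its weight on the previous token, so every token receives the value vector of its predecessor. A trained positional head only approximates this pattern with soft weights; the values `V` are stand-ins.

```python
import numpy as np

n, d_k = 5, 4
V = np.arange(n * d_k, dtype=float).reshape(n, d_k)  # stand-in value vectors, one row per token

# Hand-built weights for a "previous-token" head: row t puts all its weight on column t-1.
# Token 0 has no predecessor, so it looks at itself.
weights = np.zeros((n, n))
weights[0, 0] = 1.0
for t in range(1, n):
    weights[t, t - 1] = 1.0

out = weights @ V  # each token receives the value vector of the token before it
print(out)
```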

Note🧠 Curiosity — “induction heads”

The most famous role is the induction head (Olsson et al. 2022). It does two things as a team: it matches prefixes (looks back for where the current token appeared before) and it copies (raises the probability of whatever came after it last time). In other words, it implements the pattern [A][B] … [A] → [B]: “last time I saw A, B followed, so I predict B.” The astonishing part: these heads appear all at once within a narrow window of training (a small bump in the loss curve marks that moment), and right there the model’s in-context learning ability takes a leap. They’re believed to be the mechanism behind much of in-context learning: the ability of large models to learn from the examples you give them in the prompt itself, without retraining.
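The [A][B] … [A] → [B] rule can be imitated with a few lines of plain Python. The function `induction_predict` is a hypothetical, rule-based toy: it performs the match step and the copy step explicitly, whereas a real induction head does both softly, through attention weights learned in training.

```python
def induction_predict(tokens):
    """Toy, rule-based imitation of an induction head (not a trained model).

    Match: find the most recent earlier position where the current (last) token appeared.
    Copy: return the token that came right after it.
    Returns None if the current token has not appeared before.
    """
    current = tokens[-1]
    # Scan backwards, skipping the current position itself.
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == current:   # match: an earlier occurrence of A
            return tokens[i + 1]   # copy: the B that followed it
    return None

seq = "the cat sat on the mat . the cat".split()
print(induction_predict(seq))  # sat
```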

6.7 Do we need all the heads?

Curiously, not all of them. You can prune (remove) many heads at inference time with barely any loss: Voita et al. (2019) removed 38 of the 48 encoder heads of an English-Russian translation model while losing only 0.15 BLEU; the last to fall are the specialized ones. The honest reading: the heads are partially redundant, and a specialized few “do the heavy lifting.” (Careful: that they’re prunable afterward doesn’t mean they’re superfluous during training.)
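One simple way to picture pruning is to multiply each head’s output by a 0/1 gate before the concatenation. The sketch below (NumPy, random values, toy sizes; `multihead_output` is a hypothetical helper) checks that a 0 gate is equivalent to deleting that head together with its d_k rows of W^O. Here the gates are set by hand; this is not the procedure Voita et al. used to decide which heads to remove.

```python
import numpy as np

def multihead_output(head_outputs, W_O, keep):
    """Combine per-head outputs, dropping pruned heads (a sketch).

    head_outputs: (h, n, d_k) array, one attention output per head.
    W_O: (h * d_k, d_model) final mixing table.
    keep: length-h sequence of 0/1 gates; 0 means the head is pruned.
    """
    gated = head_outputs * np.asarray(keep, dtype=float)[:, None, None]
    h, n, d_k = gated.shape
    concat = gated.transpose(1, 0, 2).reshape(n, h * d_k)  # Concat(head_1, ..., head_h)
    return concat @ W_O

rng = np.random.default_rng(0)
h, n, d_k, d_model = 4, 3, 8, 32
heads = rng.standard_normal((h, n, d_k))
W_O = rng.standard_normal((h * d_k, d_model))
keep = [1, 1, 0, 1]  # prune head 2 (0-indexed)
pruned = multihead_output(heads, W_O, keep)

# Same result: drop head 2 and the matching d_k rows of W^O entirely.
kept = [i for i in range(h) if keep[i]]
rows = np.concatenate([np.arange(i * d_k, (i + 1) * d_k) for i in kept])
smaller = heads[kept].transpose(1, 0, 2).reshape(n, len(kept) * d_k) @ W_O[rows]
print(np.allclose(pruned, smaller))  # True
```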

Note🧠 Curiosity — “retrieval heads”

A model’s ability to find a fact buried in a very long context (the “needle in a haystack”) lives in a tiny group of heads: less than 5% of them (Wu et al. 2024). If you knock them out, the model hallucinates and can’t find the needle; if you knock out random heads, almost nothing happens. They’re universal (every long-context model studied has them) and they’re already present in the base model. We return to this in Part II (long context).

6.8 A look ahead: GQA and MQA

At inference time, the model stores the keys and values of every head for every past token (the KV-cache), and that fills up memory. That’s why modern models share keys/values across heads: MQA (one set shared by all heads) and GQA (each group of heads shares one set) cut that memory while keeping almost all the quality. We’ll see this when discussing efficiency and long context (Parts II and VI).
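The memory saving is easy to estimate. The hypothetical helper `kv_cache_bytes` counts one key and one value vector of size d_k per KV head, per layer, per token, assuming 2-byte (fp16/bf16) storage. The shapes follow the LLaMA-2-7B row of Table 6.1 plus its 32 layers; the choice of 8 KV groups for the GQA line is only an example, not a claim about that model.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_k, seq_len, batch=1, bytes_per_value=2):
    """Approximate KV-cache size in bytes (hypothetical helper).

    The leading 2 counts one key vector plus one value vector per KV head,
    per layer, per token. bytes_per_value=2 assumes fp16/bf16 storage.
    """
    return 2 * n_layers * n_kv_heads * d_k * seq_len * batch * bytes_per_value

layers, heads, d_k, seq = 32, 32, 128, 4096       # LLaMA-2-7B-like shapes, one 4096-token sequence
mha = kv_cache_bytes(layers, heads, d_k, seq)     # every head has its own K/V
gqa = kv_cache_bytes(layers, 8, d_k, seq)         # 8 K/V sets, each shared by a group of 4 heads
mqa = kv_cache_bytes(layers, 1, d_k, seq)         # one K/V set shared by all 32 heads
for name, b in [("MHA", mha), ("GQA-8", gqa), ("MQA", mqa)]:
    print(f"{name}: {b / 2**30:.2f} GiB")
```

With full multi-head attention, a single 4096-token sequence already needs 2 GiB of cache under these assumptions; sharing K/V across groups of 4 heads cuts that to a quarter, and a single shared set to 1/32.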

6.9 In code

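import torch

n, d_model, h = 5, 768, 12       # 5 tokens, GPT-2 small sizes
d_k = d_model // h               # 64 dimensions per head

x = torch.randn(n, d_model)      # one d_model vector per token
# In a real layer, x is first projected with W^Q, W^K, W^V; here we only show the split.
# Split into h heads: each token's vector becomes h chunks of size d_k
heads = x.view(n, h, d_k)        # (5, 12, 64)
print(heads.shape)               # torch.Size([5, 12, 64]); each head does its attention over its chunk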

Each of those 12 chunks goes through the Ch. 4 attention; the 12 outputs are concatenated back to size 768 and mixed with W^O.
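Putting the pieces together, here is a minimal multi-head self-attention forward pass in NumPy, assuming the Ch. 4 attention is the standard scaled dot-product form \(\text{softmax}(QK^\top/\sqrt{d_k})\,V\). It is a sketch for one sequence with random weights: no mask, no biases, no batching. The function name `multi_head_attention` is ours, not a library API.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, h):
    """Multi-head self-attention for one sequence (no mask, no biases).

    X: (n, d_model). W_Q, W_K, W_V, W_O: (d_model, d_model).
    Returns the output (n, d_model) and the per-head weights (h, n, n).
    """
    n, d_model = X.shape
    d_k = d_model // h

    def split(M):  # (n, d_model) -> (h, n, d_k): one chunk per head
        return M.reshape(n, h, d_k).transpose(1, 0, 2)

    Q, K, V = split(X @ W_Q), split(X @ W_K), split(X @ W_V)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)        # (h, n, n): one score matrix per head
    weights = softmax(scores, axis=-1)                      # each row sums to 1
    heads = weights @ V                                     # (h, n, d_k)
    concat = heads.transpose(1, 0, 2).reshape(n, d_model)   # Concat(head_1, ..., head_h)
    return concat @ W_O, weights

rng = np.random.default_rng(0)
n, d_model, h = 5, 768, 12
X = rng.standard_normal((n, d_model))
W_Q, W_K, W_V, W_O = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))
out, weights = multi_head_attention(X, W_Q, W_K, W_V, W_O, h)
print(out.shape, weights.shape)  # (5, 768) (12, 5, 5)
```

`weights[i]` is head i’s n × n attention map, the kind of per-head picture a visualizer draws.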

Note🧪 Try it — tafagent

In a visualizer like BertViz you can see the attention map of each head separately. You’ll find they don’t resemble one another: some look at the previous token, others jump far away, others fixate on the first token. That diversity is the reason for multi-head.

6.10 Summary

  • A single attention is short-sighted (one type of relationship); multi-head runs several in parallel, each with its own learned criterion.
  • Job: several attention channels at once → the model attends to different things in different places simultaneously.
  • Mechanics: split the vector into h heads (d_k = d_model/h), one attention per head, concatenate, and mix with W^O. It costs almost the same as a single one.
  • The roles emerge (they aren’t designed): positional, syntactic, rare-word, and the induction heads (the engine of in-context learning).
  • Many heads are prunable → partial redundancy; a specialized few do the heavy lifting.

Next (Chapter 7): attention moves information between words, but each word also needs to process what it has gathered. That’s the job of the feed-forward network.

6.11 Exercises

  1. The split. If d_model = 1024 and you want h = 16 heads, what is d_k? And with h = 8?
  2. Cost. Explain in your own words why 12 heads don’t cost 12 times more than a single one. (Hint: how big is each head’s chunk?)
  3. Induction head. Given the sequence ... cat dog ... cat ?, what would an induction head predict, and via which two steps (match + copy)?
  4. Pruning. If you can remove 38 of 48 heads while losing almost nothing, does that mean multi-head is useless? Argue (think training vs inference).

References

Olsson, Catherine et al. 2022. “In-Context Learning and Induction Heads.” Transformer Circuits Thread (Anthropic). https://arxiv.org/abs/2209.11895.
Vaswani, Ashish, Noam Shazeer, Niki Parmar, et al. 2017. “Attention Is All You Need.” Advances in Neural Information Processing Systems (NeurIPS). https://arxiv.org/abs/1706.03762.
Voita, Elena, David Talbot, Fedor Moiseev, Rico Sennrich, and Ivan Titov. 2019. “Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned.” ACL. https://arxiv.org/abs/1905.09418.
Wu, Wenhao, Yizhong Wang, Guangxuan Xiao, Hao Peng, and Yao Fu. 2024. “Retrieval Head Mechanistically Explains Long-Context Factuality.” arXiv preprint. https://arxiv.org/abs/2404.15574.