12  Training a transformer

Where we are. We’ve assembled the entire model (Ch. 1–10), but it’s born empty: its millions of weights start out random and it knows nothing. This chapter explains how it learns: the objective, how the weights are adjusted, and why making it bigger with more data works in a surprisingly predictable way. No heavy formulas, just the intuition for what happens when a model “studies”.

12.1 The idea in one sentence

Training is playing “fill in the next word” trillions of times, nudging each weight a tiny bit after every failed attempt.

🧩 Analogy. Imagine someone reading all of the internet as a game: cover the next word, guess it, compare it with the real one, and adjust your habits a smidgen. Repeated trillions of times, those accumulated minuscule adjustments turn into grammar, knowledge and reasoning.

12.2 Key concepts and their role in the transformer

Before we dig in, let’s define this chapter’s terms and what each one is for inside a transformer:

  • Next-token prediction (self-supervision). Definition: the pretraining objective; the text itself is its label. In the transformer: what it learns to do, with no need for human annotations.
  • Cross-entropy. Definition: the loss −log p(correct token). In the transformer: it measures “how wrong it was” on each prediction; training is lowering it.
  • Perplexity. Definition: e^loss, the readable version of the loss. In the transformer: “among how many tokens it’s effectively hesitating”; less is better.
  • Gradient (backpropagation). Definition: for each weight, the direction of change that lowers the loss. In the transformer: the compass of learning; it’s computed backward through all the layers.
  • Optimizer (Adam/AdamW). Definition: decides how much to move each weight, with an adaptive step per parameter. In the transformer: crucial because the scales vary enormously; AdamW adds weight decay to generalize better.
  • Warmup + decay. Definition: start the learning rate low, raise it and then lower it. In the transformer: it stabilizes the first steps, when the network is fragile and Adam’s statistics aren’t reliable yet.
  • Scaling laws. Definition: the loss drops as a predictable power of size, data and compute. In the transformer: they let you predict how good a model will be before training it.
  • Chinchilla (compute-optimal). Definition: at equal compute, model and data should grow together (~20 tokens/parameter). In the transformer: it corrects “bigger = better”; the earlier giants were undertrained.

With this in hand, the rest of the chapter is seeing how they fit into a single training loop.

12.3 The objective: predict the next token

What exactly does it learn from? From predicting the next token. We give it a text, the model predicts what comes after, and we measure how wrong it was. The beauty: it’s self-supervised (the text itself is its label). Nobody needs to annotate anything: the “correct answer” is, quite simply, the word that actually came next.
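
To make “the text itself is its label” concrete, here is a minimal sketch in plain Python. It splits on whitespace instead of using a real tokenizer (an assumption made for readability), and the helper name `next_token_pairs` is hypothetical.

```python
def next_token_pairs(tokens):
    """Return (context, target) pairs: every prefix predicts the token after it."""
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]

# Toy "tokenizer": a whitespace split. Real models use subword tokens.
tokens = "the cat sat on the mat".split()

pairs = next_token_pairs(tokens)
for context, target in pairs:
    print(f"{' '.join(context):<20} -> {target}")
```

A six-token sentence yields five training examples, and none of them needed a human annotator.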

(Comprehension models like BERT play a different variant: covering random words and reconstructing them, known as masked language modeling.)

12.4 Measuring the error: cross-entropy and perplexity

How do we measure “how wrong it was”? With cross-entropy:

\[ \text{loss} = -\log(\,p_{\text{model}}(\text{correct token})\,) \]

What it says: we take the probability the model assigned to the correct token and apply −log. If it assigned high probability (it guessed confidently), the loss is low; if it assigned little, the loss is high. Training = lowering that loss.

A more intuitive way to read it is perplexity = \(e^{\text{loss}}\), with the loss averaged over many predictions: “among how many tokens the model is effectively hesitating”. Perplexity 10 ≈ hesitating as if choosing among ~10 equally likely words; perplexity 2 ≈ it almost has it. Less is better.
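
The two formulas are easy to check by hand. The sketch below uses only Python’s `math` module; the function names are illustrative, and the probabilities are made-up inputs rather than outputs of any real model.

```python
import math

def cross_entropy(p_correct):
    """Loss for one prediction: -log of the probability given to the correct token."""
    return -math.log(p_correct)

def perplexity(losses):
    """e raised to the mean loss over a set of predictions."""
    return math.exp(sum(losses) / len(losses))

confident = cross_entropy(0.9)    # about 0.105: low loss
unsure = cross_entropy(0.01)      # about 4.605: high loss

# A model that always gives the correct token probability 1/10
# has perplexity 10 (up to floating-point rounding).
ppl = perplexity([cross_entropy(0.1)] * 4)
print(confident, unsure, ppl)
```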

12.5 Adjusting the weights: gradient and Adam

Once the error is known, how is it corrected? With two ideas:

  • The gradient (via backpropagation): for each weight, an arrow that says “move it a little in this direction and the loss goes down”. It’s the compass of learning.
  • The optimizer (Adam/AdamW): decides how much to move each weight. Adam gives each parameter its own step size, adapted to how its gradient has been behaving, which is crucial in transformers, where the scales vary enormously. AdamW adds weight decay (shrinking the weights a bit so the model generalizes better) and is the standard today.
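
import torch.nn.functional as F

# Assumes modelo (the transformer), datos (an iterable of token-id batches of
# shape [batch, seq_len]) and optim (e.g. torch.optim.AdamW) are already defined.
for lote in datos:                                  # batches of tokens
    logits = modelo(lote[:, :-1])                   # predict the next token at every position
    loss = F.cross_entropy(logits.flatten(0, 1),    # [batch * (seq_len - 1), vocab]
                           lote[:, 1:].flatten())   # targets: the same tokens shifted by one
    loss.backward()                                 # gradients (backpropagation)
    optim.step()                                    # the optimizer adjusts the weights
    optim.zero_grad()                               # clear gradients before the next batch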
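
To see what “its own step size” means, here is a sketch of a single AdamW update for one scalar weight, written in plain Python. It follows the standard published update rule with commonly used default hyperparameters; the function name `adamw_step` and the toy gradients are illustrative assumptions, not the internals of any particular library.

```python
import math

def adamw_step(w, g, m, v, t, lr=0.05, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.0):
    """One AdamW update for a single scalar weight w with gradient g (t starts at 1)."""
    m = b1 * m + (1 - b1) * g            # running mean of gradients
    v = b2 * v + (1 - b2) * g * g        # running mean of squared gradients
    m_hat = m / (1 - b1 ** t)            # bias correction for the early steps
    v_hat = v / (1 - b2 ** t)
    w = w - lr * weight_decay * w        # decoupled weight decay (the "W" in AdamW)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v

# Two weights whose gradients differ by a factor of a million...
w_small, _, _ = adamw_step(w=0.0, g=-0.001, m=0.0, v=0.0, t=1)
w_large, _, _ = adamw_step(w=0.0, g=-1000.0, m=0.0, v=0.0, t=1)
# ...both move by roughly lr on the first step.
print(w_small, w_large)
```

Dividing by the square root of the running squared gradient is what normalizes the step: plain gradient descent would have moved the second weight a million times further than the first.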

12.6 A detail that matters: warmup

The learning rate isn’t constant: it’s started low, raised gradually (warmup) over the first few thousand steps, and then lowered (cosine decay) toward the end. What for? At the start the network is unstable and Adam’s statistics aren’t reliable yet; a large step could break it. Warmup gives both time to settle before full-size steps begin. (It matters mostly with Post-LN, Ch. 7.)
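
One common shape for this schedule is a linear ramp followed by a half cosine. The sketch below is one way to write it; the function name `lr_at` and every number in its defaults (peak rate, warmup length, total steps, floor) are illustrative assumptions, not recommendations.

```python
import math

def lr_at(step, peak_lr=3e-4, warmup_steps=2000, total_steps=100_000, min_lr=3e-5):
    """Learning rate at a given step: linear warmup, then cosine decay to min_lr."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

for step in (0, 1000, 2000, 50_000, 100_000):
    print(step, lr_at(step))
```

The rate climbs from nearly zero to the peak during warmup, then glides down to the floor by the final step.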

Two more practical notes: training is typically done in bf16 mixed precision (16 bits, fast and with good range, more stable than fp16), and gradient clipping is used (scaling down unusually large gradients) so nothing blows up.
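
Gradient clipping is usually applied to the combined norm of all gradients rather than to each value separately. Here is a minimal sketch of that idea on a plain list of numbers; the helper name `clip_by_global_norm` is hypothetical, and in PyTorch the same idea is available on real parameters as `torch.nn.utils.clip_grad_norm_`.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Scale all gradients by one factor so their combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads)
    scale = max_norm / total_norm
    return [g * scale for g in grads]

clipped = clip_by_global_norm([3.0, 4.0], max_norm=1.0)    # norm 5, scaled down to norm 1
untouched = clip_by_global_norm([0.3, 0.4], max_norm=1.0)  # norm 0.5, left alone
print(clipped, untouched)
```

Because every gradient is multiplied by the same factor, the direction of the update is preserved; only its length is capped.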

12.7 The big picture: scaling laws

Here’s one of the most influential findings of the last decade. The loss doesn’t drop chaotically as you make the model bigger: it follows a predictable power law in model size, data and compute, with trends that hold across more than 7 orders of magnitude (Kaplan et al. 2020). In other words: you can predict how good a model will be before training it.
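
A power law is a straight line on a log-log plot, which is what makes extrapolation possible. The sketch below fits such a line to a handful of small-model losses and extends it to a larger model. The data points are synthetic, generated from a made-up power law purely for illustration; they are not results from Kaplan et al. or from any real training run, and `fit_power_law` is a hypothetical helper.

```python
import math

def fit_power_law(sizes, losses):
    """Least-squares line through (log size, log loss); returns (a, alpha) for loss = a * size**(-alpha)."""
    xs = [math.log(n) for n in sizes]
    ys = [math.log(l) for l in losses]
    x_mean = sum(xs) / len(xs)
    y_mean = sum(ys) / len(ys)
    slope = (sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
             / sum((x - x_mean) ** 2 for x in xs))
    intercept = y_mean - slope * x_mean
    return math.exp(intercept), -slope

# Synthetic "small model" results, generated from a made-up power law
# (a = 10, alpha = 0.08). These are NOT measurements from any real model.
sizes = [1e6, 1e7, 1e8, 1e9]
losses = [10.0 * n ** -0.08 for n in sizes]

a, alpha = fit_power_law(sizes, losses)
predicted = a * 1e10 ** -alpha   # extrapolate to a model 10x larger than any we "trained"
print(a, alpha, predicted)
```

Real measurements are noisy and the fit is never this clean, but the workflow is the same: train small, fit the line, read off the prediction for the big model.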

🧠 Curiosity. Bigger isn’t always better (Chinchilla)

For years “bigger model = better” was assumed. In 2022, Chinchilla (Hoffmann et al. 2022) corrected it: at equal compute, the model’s size and the amount of data should grow together, at roughly 20 tokens per parameter. It turned out that the giants of the time (GPT-3, Gopher) were undertrained: too many parameters, too little data. Chinchilla-70B, trained on 1.4 trillion tokens, beat Gopher-280B despite being 4× smaller. Moral: sometimes the smartest model is a smaller but better-fed one.
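
The 20-tokens-per-parameter rule turns into simple arithmetic once you have a cost model. The sketch below assumes the widely used approximation that training costs about 6 FLOPs per parameter per token (C ≈ 6·N·D); that approximation does not appear in this chapter and is introduced here only for illustration, as is the helper name `compute_optimal`.

```python
import math

def compute_optimal(flops, tokens_per_param=20.0, flops_per_param_token=6.0):
    """Split a compute budget into (parameters, tokens) under C = 6 * N * D and D = 20 * N."""
    n_params = math.sqrt(flops / (flops_per_param_token * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens

# Budget implied by Chinchilla's own numbers under the 6*N*D approximation:
# 6 * 70e9 parameters * 1.4e12 tokens = 5.88e23 FLOPs.
budget = 6.0 * 70e9 * 1.4e12
n_params, n_tokens = compute_optimal(budget)
print(f"{n_params:.3g} parameters, {n_tokens:.3g} tokens")
```

Fed that budget, the rule recovers about 70 billion parameters and 1.4 trillion tokens, which is consistent with the figures quoted above. Under the same assumptions, a 280-billion-parameter model would need far more tokens than that budget allows.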

12.8 What learning looks like

The loss curve falls fast at the start (the model quickly picks up frequencies and grammar) and then slowly (it keeps refining subtler structure: references, facts, reasoning). There’s also a curious phenomenon, grokking: a late, sudden generalization after a phase of apparent memorization. We’ll see it in Part III (it connects to our own work).

12.9 Summary

  • Training = predicting the next token trillions of times and adjusting the weights a little after each error. It’s self-supervised (the text is its label).
  • The error is measured with cross-entropy (−log p(correct)); perplexity (e^loss) is its readable version (“among how many tokens it hesitates”).
  • The gradient says which way to move each weight; Adam/AdamW decides how much, with an adaptive step + weight decay.
  • Warmup + decay of the learning rate stabilize training; bf16 and gradient clipping make it viable and safe.
  • Scaling laws: the loss improves predictably with size/data/compute (Kaplan et al. 2020); Chinchilla showed you must scale data and parameters together (~20 tokens/parameter), and that earlier models were undertrained.

Next (Chapter 13): we now have a trained model. How do we generate text with it? Decoding, sampling, temperature and the KV-cache.

12.10 Exercises

  1. Self-supervision. Explain why training with “predict the next token” needs no human labels. Where does the “correct answer” come from?
  2. Perplexity. If a model’s perplexity is 1, what does that mean? And if it equals the vocabulary size?
  3. Chinchilla. You have a fixed compute budget and a huge model that performs only so-so. According to Chinchilla, what might you be doing wrong?
  4. Warmup. Why can starting with a high learning rate from step 1 “break” training?

References

Hoffmann, Jordan et al. 2022. Training Compute-Optimal Large Language Models (Chinchilla). https://arxiv.org/abs/2203.15556.
Kaplan, Jared et al. 2020. Scaling Laws for Neural Language Models. https://arxiv.org/abs/2001.08361.