12 Entrenar un transformer
Dónde estamos. Hemos montado el modelo entero (Cap. 1–10), pero nace vacío: sus millones de pesos empiezan al azar y no sabe nada. Este capítulo explica cómo aprende —el objetivo, cómo se ajustan los pesos, y por qué hacerlo más grande con más datos funciona de forma sorprendentemente predecible—. Sin fórmulas pesadas: la intuición de qué pasa cuando un modelo “estudia”.
12.1 La idea en una frase
Entrenar es jugar a “rellena la siguiente palabra” miles de millones de veces, ajustando cada peso un poquitín después de cada intento fallido.
🧩 Analogía. Imagina a alguien leyendo todo internet con un juego: tapa la siguiente palabra, la adivina, la compara con la real y corrige sus hábitos un pelín. Repetido billones de veces, esos ajustes minúsculos acumulados se vuelven gramática, conocimiento y razonamiento.
12.2 Conceptos clave y su papel en el transformer
Antes de entrar en detalle, definimos los términos de este capítulo y para qué sirve cada uno dentro de un transformer:
- Predicción del siguiente token (autosupervisión). Definición: el objetivo de preentrenamiento; el propio texto es su etiqueta. En el transformer: lo que aprende a hacer, sin necesidad de anotaciones humanas.
- Entropía cruzada. Definición: la pérdida
−log p(token correcto). En el transformer: mide “cuánto se equivocó” en cada predicción; entrenar es bajarla. - Perplejidad. Definición:
e^pérdida, la versión legible de la pérdida. En el transformer: “entre cuántos tokens duda efectivamente”; menos es mejor. - Gradiente (retropropagación). Definición: para cada peso, la dirección en que moverlo baja la pérdida. En el transformer: la brújula del aprendizaje; se calcula hacia atrás a través de todas las capas.
- Optimizador (Adam/AdamW). Definición: decide cuánto mover cada peso, con paso adaptativo por parámetro. En el transformer: clave porque las escalas varían muchísimo; AdamW añade weight decay para generalizar mejor.
- Warmup + decay. Definición: arrancar el learning rate bajo, subirlo y luego bajarlo. En el transformer: estabiliza los primeros pasos, cuando la red es frágil y las estadísticas de Adam aún no son fiables.
- Leyes de escala. Definición: la pérdida baja como una potencia predecible del tamaño, los datos y el cómputo. En el transformer: permiten predecir lo bueno que será un modelo antes de entrenarlo.
- Chinchilla (cómputo óptimo). Definición: a igual cómputo, modelo y datos deben crecer a la par (~20 tokens/parámetro). En el transformer: corrige el “más grande = mejor”; los gigantes previos estaban infraentrenados.
Con esto en mano, el resto del capítulo es ver cómo encajan en un único bucle de entrenamiento.
12.3 El objetivo: predecir el siguiente token
¿De qué aprende exactamente? De predecir el siguiente token. Le damos un texto, el modelo predice qué viene después, y se mide cuánto se equivocó. Lo bonito: es autosupervisado —el propio texto es su etiqueta—. No hace falta que nadie anote nada: la “respuesta correcta” es, sencillamente, la palabra que de verdad seguía.
(Los modelos de comprensión tipo BERT juegan a otra variante: tapar palabras al azar y reconstruirlas —masked language modeling—.)
12.4 Medir el error: entropía cruzada y perplejidad
¿Cómo se mide “cuánto se equivocó”? Con la entropía cruzada:
\[ \text{pérdida} = -\log(\,p_{\text{modelo}}(\text{token correcto})\,) \]
Qué dice: tomamos la probabilidad que el modelo le dio al token correcto y le aplicamos −log. Si le dio alta probabilidad (acertó con confianza), la pérdida es baja; si le dio poca, la pérdida es alta. Entrenar = bajar esa pérdida.
Una forma más intuitiva de leerla es la perplejidad = \(e^{\text{pérdida}}\): “entre cuántos tokens está dudando efectivamente el modelo”. Perplejidad 10 ≈ duda como si eligiera entre ~10 palabras; perplejidad 2 ≈ casi lo tiene. Menos es mejor.
12.5 Ajustar los pesos: gradiente y Adam
Sabido el error, ¿cómo se corrige? Con dos ideas:
- El gradiente (vía retropropagación): para cada peso, una flecha que dice “muévelo un poquito en esta dirección y la pérdida baja”. Es la brújula del aprendizaje.
- El optimizador (Adam/AdamW): decide cuánto mover cada peso. Adam le da a cada parámetro su propio tamaño de paso, adaptado según cómo se ha venido comportando su gradiente —clave en transformers, donde las escalas varían muchísimo—. AdamW añade weight decay (encoger un poco los pesos para que el modelo generalice mejor) y es el estándar hoy.
import torch.nn.functional as F
for lote in datos: # lotes de tokens
logits = modelo(lote[:, :-1]) # predice el siguiente token
loss = F.cross_entropy(logits.flatten(0, 1),
lote[:, 1:].flatten()) # cuánto se equivocó
loss.backward() # gradientes (retropropagación)
optim.step(); optim.zero_grad() # Adam ajusta los pesos12.6 Un detalle que importa: el warmup
El ritmo de aprendizaje (learning rate) no es constante: se arranca bajo, se sube poco a poco (warmup) durante los primeros miles de pasos, y luego se baja (decaimiento coseno) hacia el final. ¿Para qué? Al principio la red es inestable y las estadísticas de Adam aún no son fiables; un paso grande podría romperla. El warmup la deja calentar. (Importa sobre todo con Post-LN, Cap. 7.)
Dos apuntes prácticos más: se entrena en precisión mixta bf16 (16 bits, rápido y con buen rango, más estable que fp16) y se usa gradient clipping (recortar gradientes enormes) para que nada explote.
12.7 Lo grande: las leyes de escala
Aquí está uno de los hallazgos más influyentes de la última década. La pérdida no baja de forma caótica al hacer el modelo más grande: baja siguiendo una ley de potencia predecible con el tamaño del modelo, los datos y el cómputo —a lo largo de más de 7 órdenes de magnitud (Kaplan et al. 2020)—. Es decir: puedes predecir cómo de bueno será un modelo antes de entrenarlo.
Durante años se asumió “modelo más grande = mejor”. En 2022, Chinchilla (Hoffmann et al. 2022) lo corrigió: a igualdad de cómputo, el tamaño del modelo y la cantidad de datos deben crecer a la par —unos ~20 tokens por parámetro—. Resultó que los gigantes de la época (GPT-3, Gopher) estaban infraentrenados: demasiados parámetros, pocos datos. Chinchilla-70B, entrenado con 1,4 billones de tokens, superó a Gopher-280B (¡4× más pequeño!). Moraleja: a veces el modelo más listo es uno más pequeño pero mejor alimentado.
12.8 Cómo se ve el aprendizaje
La curva de pérdida cae rápido al principio (el modelo pilla enseguida las frecuencias y la gramática) y luego despacio (va refinando estructura más sutil: referencias, hechos, razonamiento). Hay un fenómeno curioso, el grokking —una generalización tardía y repentina tras una fase de aparente memorización—, que veremos en la Parte III (y que conecta con nuestro propio trabajo).
12.9 Resumen
- Entrenar = predecir el siguiente token billones de veces y ajustar los pesos un poco tras cada error. Es autosupervisado (el texto es su etiqueta).
- El error se mide con entropía cruzada (
−log p(correcto)); la perplejidad (eᵖérdida) es su versión legible (“entre cuántos tokens duda”). - El gradiente dice hacia dónde mover cada peso; Adam/AdamW decide cuánto, con paso adaptativo + weight decay.
- Warmup + decay del learning rate estabilizan el entrenamiento; bf16 y gradient clipping lo hacen viable y seguro.
- Leyes de escala: la pérdida mejora de forma predecible con tamaño/datos/ cómputo (Kaplan et al. 2020); Chinchilla mostró que hay que escalar datos y parámetros a la par (~20 tokens/parámetro) — los modelos previos estaban infraentrenados.
Siguiente (Capítulo 12): ya tenemos un modelo entrenado. ¿Cómo genera texto a partir de él? Decodificación, muestreo, temperatura y el KV-cache.
12.10 Ejercicios
- Autosupervisión. Explica por qué entrenar con “predecir el siguiente token” no necesita etiquetas humanas. ¿De dónde sale la “respuesta correcta”?
- Perplejidad. Si la perplejidad de un modelo es 1, ¿qué significa? ¿Y si es igual al tamaño del vocabulario?
- Chinchilla. Tienes un presupuesto de cómputo fijo y un modelo enorme que rinde regular. Según Chinchilla, ¿qué podrías estar haciendo mal?
- Warmup. ¿Por qué arrancar con un learning rate alto desde el paso 1 puede “romper” el entrenamiento?