35  Atención eficiente

Dónde estamos. Abre la Parte VI (eficiencia y despliegue). En el Cap. 18 hicimos el mapa de las variantes de atención (FlashAttention, lineal, GQA, Mamba… y qué ganó). Aquí no repetimos el catálogo: abrimos el capó y explicamos la mecánica —qué hace cara a la atención, por qué el cuello de botella real es mover datos y no calcularlos, y cómo, exactamente, FlashAttention y la atención lineal lo resuelven—.

35.1 La idea en una frase

La atención cuesta O(n²) y, en una GPU moderna, su cuello de botella no son las operaciones, sino mover la matriz n×n entre memorias; las técnicas eficientes ganan sobre todo recortando ese tráfico (FlashAttention) o evitando construir la matriz (atención lineal).

35.2 Conceptos clave y su papel en el transformer

Antes de entrar en detalle, definimos los términos de este capítulo y para qué sirve cada uno dentro de un transformer. Es el “mapa de conceptos” para no perderte:

  • Coste O(n²) (“cuadrático”). Definición: una forma de crecer en la que doblar la entrada cuadruplica el coste. En el transformer: es el coste de la atención, que compara cada token con todos; marca el límite de contexto —por qué los modelos “sufren” con secuencias muy largas—.
  • Intensidad aritmética. Definición: cuántas operaciones haces por cada byte que traes de memoria. En el transformer: la atención tiene intensidad baja (mueve mucho dato para poco cálculo), y por eso su cuello de botella es la memoria, no el cálculo.
  • Limitado por memoria vs por cómputo (memory-/compute-bound). Definición: si el tiempo lo marca mover datos o hacer operaciones. En el transformer: la atención está limitada por memoria → optimizarla es mover menos, no calcular más rápido.
  • HBM y SRAM. Definición: los dos niveles de memoria de una GPU —HBM grande y lenta, SRAM minúscula y ~10× más rápida—. En el transformer: el modelo vive en HBM y se calcula en SRAM; el coste real es el tráfico entre ambas.
  • FlashAttention. Definición: un algoritmo que calcula la atención exacta sin escribir nunca la matriz n×n en la memoria lenta. En el transformer: es hoy el kernel de atención por defecto en entrenamiento e inferencia —misma calidad, mucho más rápido—.
  • Online softmax. Definición: calcular el softmax por bloques, con un máximo y una suma “móviles”, sin necesitar la fila entera. En el transformer: es lo que hace posible trocear la atención sin alterar el resultado.
  • Atención lineal. Definición: sustituir el softmax por un núcleo que permite reordenar el cálculo y bajar el coste a O(n). En el transformer: cambia la atención por un estado de tamaño fijo (tipo RNN) → barata en contexto larguísimo, a costa de recuerdo.
  • Caché KV. Definición: la memoria donde se guardan las claves y valores ya calculados para no recomputarlos al generar (Cap. 20). En el transformer: es el coste de memoria de la inferencia; muchas técnicas (GQA, PagedAttention) atacan justo esto.

Con estos conceptos en la mano, vamos a desarrollarlos uno a uno.

35.3 Anatomía del coste: de dónde sale el O(n²)

Antes de “arreglar” la atención conviene entender exactamente qué es lo caro. La atención, recordemos (Cap. 4), hace dos productos de matrices encadenados:

  1. Las puntuaciones \(QK^\top\): cada uno de los n tokens se compara con los otros n, con vectores de dimensión d. El resultado es una matriz n×n.
  2. La mezcla \(AV\): esa matriz n×n de pesos multiplica a los n valores de dimensión d, dando la salida.

Cada producto cuesta del orden de n²·d operaciones, así que el cómputo es O(n²·d). Pero hay un segundo coste, y es el que de verdad explota: para aplicar el softmax hay que tener escrita la matriz n×n de puntuaciones. Eso es O(n²) en memoria, y —dato importante— no depende de d: aunque las cabezas sean pequeñas, la matriz crece con el cuadrado de la longitud.

Tip✓ Qué significa “O(n²)”

La notación O(·) describe cómo crece un coste con el tamaño de la entrada, ignorando constantes. O(n²) (“cuadrático”) significa que doblar la entrada cuadruplica el coste (2² = 4). Por eso, en atención, pasar de 2.000 a 4.000 tokens no duplica el trabajo: lo cuadruplica —y multiplica por 4 la matriz que hay que guardar—.

El matiz que casi nadie explica: el resto del transformer es solo lineal en n. Las proyecciones de Q, K, V y la FFN (Cap. 6) procesan cada token por separado, así que su coste crece proporcional a n (no a n²). La consecuencia práctica es nítida:

  • Con secuencias cortas, el término lineal (la FFN) domina el tiempo; la atención apenas se nota.
  • El término cuadrático de la atención solo adelanta a los demás cuando n se hace grande frente a d.

Por eso la “atención eficiente” es un problema de contexto largo, no de secuencias cortas. Cuando lees que un modelo “sufre” con 100.000 tokens, el culpable es este O(n²).

35.4 La clave que lo cambia todo: la atención está limitada por la memoria

Aquí está el giro conceptual del capítulo. Uno supondría que, siendo O(n²·d) operaciones, el límite es la velocidad de cálculo de la GPU. No lo es. El límite es mover la matriz n×n entre los distintos tipos de memoria. La distinción clave es entre estar limitado por cómputo (el tiempo lo marcan las operaciones, la GPU está a tope calculando) y estar limitado por memoria (el tiempo lo marca el trasiego de datos, la GPU está parada esperándolos). ¿Cuál de los dos? Lo decide la intensidad aritmética —operaciones por byte movido—: la atención hace pocas operaciones por cada byte de la enorme matriz n×n, así que está limitada por memoria.

Y “la memoria” de una GPU no es una sola; es una jerarquía con un compromiso brutal entre tamaño y velocidad (cifras del paper de FlashAttention, GPU A100):

Tabla 35.1: La jerarquía de memoria de la GPU
Memoria Tamaño Velocidad Papel
HBM (alto ancho de banda) ~40-80 GB ~1,5-2,0 TB/s grande pero “lenta”; vive el modelo
SRAM (en chip) ~20 MB ~19 TB/s (~10× más) minúscula pero rapidísima; donde se calcula

El problema de la atención estándar es que trata la matriz n×n como cualquier dato: la escribe entera en la HBM (lenta), la relee para el softmax, escribe las probabilidades, y las relee otra vez para el producto \(AV\). Son idas y venidas repetidas de un objeto O(n²) a la memoria lenta, y eso —no el cálculo— es lo que consume el reloj. De aquí sale la premisa que lo cambió todo: si recortas el tráfico a la HBM, recortas el tiempo —aunque hagas más operaciones—.

🧩 Analogía — el chef que espera ingredientes. Imagina un cocinero rapidísimo (las unidades de cálculo) que se pasa el rato parado porque cada ingrediente hay que traerlo de un almacén lejano (la HBM) en vez de cogerlo del mostrador de al lado (la SRAM). El cuello de botella es el acarreo, no el cortar. Acelera el acarreo —trae menos veces y en bloques— y la comida sale antes, aunque el cocinero “trabaje” un poco más.

35.5 FlashAttention: calcular lo mismo moviendo mucho menos

FlashAttention (Dao et al. 2022) calcula la atención exacta (idéntica bit a bit a la normal) pero sin escribir nunca la matriz n×n completa en la memoria lenta: no cambia qué se calcula, sino cómo se mueven los datos. Descansa en tres ideas que conviene entender una a una.

1. Tiling (trocear en bloques). En vez de operar sobre las matrices enteras, parte Q, K y V en bloques pequeños que caben en la SRAM. Carga un bloque de Q, hace pasar por la SRAM los bloques de K y V, calcula la atención de ese trozo dentro del chip y va acumulando la salida. ¿La ganancia? La matriz n×n completa nunca se forma en la HBM: allí solo viven las entradas/salidas (de tamaño O(n·d)) y unos pocos estadísticos. El “acarreo” caro desaparece.

2. Online softmax (el truco crucial). Aquí surge un problema: el softmax de una fila necesita, en principio, la fila entera —su valor máximo (para no desbordar al exponenciar) y la suma de todas las exponenciales (para normalizar)—. Pero si troceamos, nunca tenemos la fila entera a la vez. La solución (Milakov y Gimelshein 2018) es calcularlo en una sola pasada por bloques, manteniendo dos números que se actualizan sobre la marcha:

  • un máximo móvil \(m\) (el mayor valor visto hasta ahora), y
  • una suma móvil \(\ell\) (la suma de exponenciales acumulada).

Cuando llega un bloque nuevo con un valor mayor que cualquiera visto, se reescala lo ya acumulado por un factor de corrección \(\exp(m_{\text{viejo}} - m_{\text{nuevo}})\) antes de sumar la contribución nueva. La intuición: cada vez que descubres un número más grande, encoges retroactivamente el peso que habías dado a los anteriores, de modo que la normalización final sale exacta, como si hubieras tenido la fila entera desde el principio.

🧩 Analogía — sumar por páginas. Es como sumar una columna enorme de cifras una página a la vez, anotando en un post-it solo un total móvil y un máximo móvil, en vez de necesitar toda la columna extendida sobre la mesa. Si una página trae un número mayor que todos los anteriores, ajustas el total en proporción (la corrección de reescalado). Al acabar, el total es el mismo que si lo hubieras tenido todo a la vista.

3. Recomputación en la marcha atrás. Para entrenar hace falta la marcha atrás (backward, Cap. 11), que normalmente necesitaría la matriz n×n de nuevo. En vez de guardarla (carísimo en memoria), FlashAttention guarda solo los estadísticos del softmax (\(m\) y \(\ell\) por fila) y, cuando los necesita, recalcula el bloque de atención en el chip. Es el mismo intercambio del gradient checkpointing (Cap. 25): pagar un poco más de cómputo para ahorrar mucha memoria.

🧩 Analogía — rehacer el borrador. En vez de guardar todas las hojas de cálculo intermedias por si acaso, las tiras y las rehaces a partir de dos números apuntados cuando de verdad las necesitas. Recalcular sale más barato que almacenar.

Tip✓ El resultado, en cifras

FlashAttention da atención exacta (no aproximada), con memoria O(n) en vez de O(n²), y más rápida por hacer menos viajes a la HBM: +15 % en BERT-large, 3× en GPT-2 (1K tokens), 2,4× en Long-Range Arena —y permitió entrenar contextos antes imposibles—. FlashAttention-2 (Dao 2023) añade ~ (mejor reparto del trabajo y menos operaciones que no son multiplicaciones de matrices); FlashAttention-3 (Shah et al. 2024) exprime la asincronía y la precisión FP8 de las GPU Hopper (~1,5-2× sobre FA2).

35.6 Atención lineal: no mover la matriz más rápido, sino no construirla

FlashAttention ataca cómo se mueve la matriz. La otra gran familia ataca algo más radical: no formar nunca la matriz n×n —bajando el coste a O(n) al sustituir el softmax por una función que permite reordenar las multiplicaciones—. Para entender cómo, hay que ver primero por qué la atención normal está obligada a formar esa matriz.

El softmax calcula \(\exp(q_i\cdot k_j)\) para cada par (i, j) → te obliga a construir la tabla n×n antes de poder multiplicar por V. La idea de la atención lineal es reemplazar ese \(\exp(q_i\cdot k_j)\) por un núcleo que factoriza como \(\varphi(q_i)\cdot \varphi(k_j)\) —donde \(\varphi\) (“fi”) es una función de características aplicada a cada vector—. Si la similitud se factoriza así, la asociatividad del producto de matrices permite reordenar el cálculo:

\[ \big(\varphi(Q)\,\varphi(K)^\top\big)\,V \;=\; \varphi(Q)\,\big(\varphi(K)^\top V\big) \]

Vamos al porqué término a término:

  • A la izquierda, primero formas \(\varphi(Q)\varphi(K)^\top\), que es la matriz n×n —y vuelves a O(n²)—.
  • A la derecha, primero calculas \(\varphi(K)^\top V\): el producto de una matriz de tamaño d×n por una n×d da una matriz pequeña d×d, en coste O(n·d²). Luego la multiplicas por \(\varphi(Q)\), otra vez O(n·d²).
  • Ambos lados dan el mismo resultado (es la misma cuenta reordenada), pero el de la derecha nunca construye el objeto n×n. Esa reasociación es todo el truco, y vuelve el coste lineal en n.

🧩 Analogía — cambiar el orden de multiplicar. Es elegir entre (A·B)·C y A·(B·C): el resultado es idéntico, pero un orden construye una tabla intermedia gigante y el otro solo tablas pequeñas. La atención lineal siempre elige el segundo orden.

Esto tiene una lectura preciosa: en modo causal (autoregresivo), \(\varphi(K)^\top V\) se puede mantener como un estado de tamaño fijo \(S=\sum_j \varphi(k_j)v_j^\top\) que se actualiza token a token —es decir, la atención lineal se comporta como un RNN (una red recurrente) con un estado constante, lo que da memoria O(1) por paso al generar (Linear Transformers (Katharopoulos et al. 2020), hasta 4000× más rápido en secuencias muy largas). Otra rama, Performer (Choromanski et al. 2021), no cambia el softmax por otro núcleo, sino que aproxima el propio softmax con características aleatorias.

Advertencia⚠ Honesto — el estado fijo se paga

Aquí está el coste oculto: un estado de tamaño fijo tiene que comprimir todo el pasado en d×d números, mientras que la atención completa conserva cada token en la caché. Por eso la atención lineal sufre una brecha de calidad medible en tareas de recuerdo (copiar un dato exacto, encontrar la aguja en el pajar, recuperación en contexto). Es justo el veredicto del Cap. 18: no destronó a la atención completa. Tiene sentido solo cuando el O(n²) es realmente inviable y la tarea tolera ese recorte.

35.7 Otros ahorros, en una pincelada mecánica

  • Memoria O(1) (Rabe y Staats 2021): el precursor de FlashAttention. Usa la misma idea de online softmax, pero la presenta como un resultado de memoria (mostró que la atención no necesita O(n²) memoria; 59× menos a 16K tokens) más que de velocidad. FlashAttention le añadió la conciencia del tráfico HBM que lo convirtió en aceleración.
  • PagedAttention (Kwon et al. 2023): ataca el desperdicio de la caché KV (Cap. 20) guardándola en bloques no contiguos mapeados como la memoria virtual de un sistema operativo → casi cero desperdicio y compartición entre peticiones (2-4× de throughput). Lo veremos al servir (Cap. 36).
  • MQA/GQA (Cap. 18): reducen el número de cabezas KV, recortando memoria y ancho de banda al decodificar. Es una palanca de ancho de banda, no de cómputo.

35.8 Cuándo usar qué: el roofline

Tip✓ Qué es el “roofline”

Un modelo mental para razonar sobre el rendimiento: dibuja un “techo” (roof) formado por dos límites —el de cómputo (FLOPs/s de la GPU) y el de memoria (ancho de banda)—. Un cálculo “choca” contra uno u otro techo. Saber contra cuál chocas te dice qué optimización servirá: si estás contra el techo de memoria, acelerar el cálculo no ayuda; hay que mover menos datos.

Con eso, cada método ataca un eje distinto:

  • FlashAttention: úsalo siempre. Es exacto y sirve en entrenamiento e inferencia, a cualquier longitud; solo cambia el patrón de IO. Es el kernel por defecto, y gana más cuanto mayor es n. Ataca la constante (el tráfico) de la cuadrática, no su exponente.
  • Atención lineal: solo en contexto larguísimo, donde el O(n²) es fatal y la tarea aguanta el recorte de recuerdo. Es lo único que cambia la clase de complejidad (el exponente).
  • Reducciones de KV (GQA/MLA, PagedAttention): memoria y ancho de banda de inferencia. No tocan los FLOPs de entrenamiento; hacen que decodificar quepa y vaya rápido (el paso de decodificación está limitado por el ancho de banda de la caché).

En una frase: FlashAttention ataca la constante (IO); la atención lineal, el exponente; los métodos de KV, la caché de inferencia. Tres ejes distintos del mismo problema.

35.9 Puente con nuestro tema (atención a lo largo de la distancia)

La atención eficiente es la respuesta de ingeniería al mismo coste que estudiamos en física —el de atender a lo largo de la distancia—. Y hay una distinción honesta que merece la pena: FlashAttention es agnóstica al contenido, es decir, calcula todas las interacciones por igual, solo que moviendo los bytes de forma óptima. Nuestra ventana D_f derivada de γ (Cap. 20) es el complemento consciente del contenido: dice qué entradas KV lejanas tienen masa de atención despreciable y se pueden tirar —cambia qué calculas, no solo cómo lo mueves—. Son ortogonales y componibles: FlashAttention abarata la ventana que conservas; D_f decide cuán ancha debe ser esa ventana. (Honesto: es un puente conceptual, no una afirmación de que D_f sea un método de atención eficiente publicado.)

Nota🧪 Pruébalo — tafagent

tafagent calcula el presupuesto de KV a partir de γ (Cap. 20): cuánta caché necesitas de verdad a la longitud objetivo. Combínalo con la lógica de este capítulo: FlashAttention hace barata la atención exacta, y γ/D_f te dice cuánto contexto distante aporta poco y podrías comprimir —el “cuánto” que el kernel exacto no decide por ti—.

35.10 Resumen

  • Coste: la atención es O(n²·d) en cómputo y O(n²) en memoria (la matriz n×n); el resto del modelo es lineal en n → la atención solo manda en contexto largo.
  • Memory-bound: el límite real es mover la matriz n×n entre HBM (grande, lenta) y SRAM (mini, ~10× rápida) —baja intensidad aritmética—, no los FLOPs.
  • FlashAttention (Dao et al. 2022): tiling + online softmax (máximo/suma móviles + reescalado) + recomputación → atención exacta, O(n) memoria, más rápida (FA2/FA3).
  • Atención lineal: quitar el softmax + reasociar \(\varphi(Q)(\varphi(K)^\top V)\)O(n); forma RNN de estado fijo. Honesto: brecha de calidad en recuerdo (Cap. 18).
  • Otros: memoria O(1) (precursor), PagedAttention (caché KV paginada → Cap. 36), GQA/MLA (ancho de banda).
  • Roofline: FA ataca la constante; la lineal, el exponente; KV, la caché.
  • Puente: FA es agnóstica al contenido; nuestra D_f (γ) es el complemento consciente del contenido.

Siguiente (Capítulo 35): otra vía de eficiencia —hacer el modelo más pequeño sin perder (mucha) calidad: cuantización, destilación y poda.

35.11 Ejercicios

  1. Dos costes. Distingue el coste en cómputo y en memoria de la atención. ¿Cuál es el que “explota” y por qué no depende de d?
  2. Memory-bound. Define “limitado por memoria” frente a “limitado por cómputo”. ¿Por qué la atención cae en el primer caso? ¿Qué papel juegan HBM y SRAM?
  3. Online softmax. ¿Por qué el softmax normal necesita la fila entera, y cómo lo evita el cálculo por bloques con máximo y suma móviles?
  4. Exacta vs aproximada. ¿Por qué FlashAttention no es una aproximación, a diferencia de la atención lineal?
  5. Reasociación. En \(\varphi(Q)(\varphi(K)^\top V)\), ¿qué se calcula primero y por qué eso vuelve el coste lineal en n?
  6. Roofline. Asocia cada método (FlashAttention / lineal / GQA) con qué ataca: constante de IO, exponente, o caché de inferencia.

Referencias

Choromanski, Krzysztof et al. 2021. «Rethinking Attention with Performers». ICLR. https://arxiv.org/abs/2009.14794.
Dao, Tri. 2023. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. https://arxiv.org/abs/2307.08691.
Dao, Tri, Daniel Y. Fu, Stefano Ermon, Atri Rudra, y Christopher Ré. 2022. «FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness». NeurIPS. https://arxiv.org/abs/2205.14135.
Katharopoulos, Angelos, Apoorv Vyas, Nikolaos Pappas, y François Fleuret. 2020. «Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention». ICML. https://arxiv.org/abs/2006.16236.
Kwon, Woosuk, Zhuohan Li, Siyuan Zhuang, et al. 2023. «Efficient Memory Management for Large Language Model Serving with PagedAttention». SOSP. https://arxiv.org/abs/2309.06180.
Milakov, Maxim, y Natalia Gimelshein. 2018. Online Normalizer Calculation for Softmax. https://arxiv.org/abs/1805.02867.
Rabe, Markus N., y Charles Staats. 2021. Self-attention Does Not Need O(n^2) Memory. https://arxiv.org/abs/2112.05682.
Shah, Jay et al. 2024. FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision. https://arxiv.org/abs/2407.08608.