How Attention Actually Works (And Why It Costs You O(n²) Every Time)
Every time someone asks why they can't just throw a million tokens at GPT-4 and let it figure things out, the answer comes down to one equation. Understanding that equation, not just memorizing it but knowing where every term comes from and why each design choice was made, is the difference between guessing at model behavior and reasoning about it with confidence. This post walks through attention from first principles, derives the n² cost concretely, and explains what Flash Attention and Grouped Query Attention actually do about it. No hand-waving. Every number is earned.
The problem attention is solving
Before getting into the mechanics, it helps to understand what problem the designers were actually trying to solve. The predecessor to transformers was the recurrent neural network. RNNs processed sequences token by token, maintaining a hidden state that was passed from each step to the next. This worked, but it had two crippling limitations.
The first was the information bottleneck. Every fact the network learned about token 1 had to be compressed into the hidden state before it could influence token 100. Long-range dependencies like "the trophy didn't fit in the suitcase because it was too big" were notoriously hard to learn because the signal from the noun "trophy" had to survive through dozens of intermediate hidden states before it could inform the meaning of "it." The further apart two tokens were in a sentence, the more the signal between them degraded.
The second was parallelism. Because each step depended on the previous step's hidden state, you couldn't compute step 50 until you'd finished step 49. Training on modern hardware, which excels at massively parallel operations, was hobbled by this sequential dependency. Training large RNNs was slow in a way that didn't improve just by buying more GPUs.
Attention solves both problems simultaneously. Instead of routing information through a sequential chain, it gives every token direct access to every other token in a single operation. Token 100 doesn't have to wait for information about token 1 to filter through 99 intermediate steps. It reads directly from token 1's representation. And because all these reads happen at once as matrix multiplications, the whole operation is massively parallelizable on GPUs. The price you pay for this is the quadratic cost we'll derive shortly.
The three matrices that run every modern language model
Take the sentence "the cat sat down." The transformer first turns each word into a dense vector, say a 512-dimensional embedding. You now have a matrix of shape [4 × 512]: four tokens, each represented as a 512-dimensional point in a learned vector space. The geometry of this space is meaningful. Similar words cluster together, and arithmetic operations on these vectors capture semantic relationships.
Attention's job is to let each token look at every other token and decide how much information to borrow from each. It does this by projecting the input into three distinct roles: Query (what is this token looking for?), Key (what does this token advertise about itself?), and Value (what information does this token actually contribute when selected?).
The intuition is that of a soft, differentiable lookup. The Query is the question you're asking. The Key is the index you're searching through. The Value is the content you retrieve when you find a match. The "soft" part is what makes it learnable: instead of retrieving from one entry with a hard match, you retrieve a weighted blend from all entries simultaneously, where the weights are determined by how well each Key matches the Query.
Each role gets its own learned weight matrix. W_Q, W_K, and W_V each have shape [d_model × d_k]. For a transformer with d_model = 512 and d_k = 64, multiplying the input [4 × 512] by W_Q [512 × 64] gives the query matrix Q of shape [4 × 64]. The same projection applied with W_K gives K [4 × 64], and with W_V gives V [4 × 64]. These three projections happen simultaneously as three separate linear transformations. They are the learned parameters that define what "looking for" and "offering" mean in this particular model.
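In NumPy, the projections for the 4-token example look like this (random weights standing in for the learned ones):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_k = 4, 512, 64                 # 4 tokens, model width 512, head width 64

X = rng.standard_normal((n, d_model))        # token embeddings, shape [4 x 512]
W_Q = rng.standard_normal((d_model, d_k))    # learned in a real model; random here
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_k))

Q = X @ W_Q                                  # [4 x 64] -- what each token looks for
K = X @ W_K                                  # [4 x 64] -- what each token advertises
V = X @ W_V                                  # [4 x 64] -- what each token contributes
```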
The attention score between token i and token j is the dot product of Q_i and K_j. A large positive dot product means token i's query aligns strongly with token j's key. Scaling this by the square root of d_k keeps the magnitudes in a range where softmax remains numerically stable. Without the scaling, dot products grow with the dimension of the vectors: if the components of Q_i and K_j are roughly independent with zero mean and unit variance, their dot product has variance d_k, so typical magnitudes grow like sqrt(d_k). Large inputs push softmax into a saturation region where its gradient approaches zero, making learning nearly impossible.
score(i, j) = (Q_i · K_j) / sqrt(d_k)
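A quick simulation confirms the scale of the raw dot products, assuming unit-variance vector components:

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 64
q = rng.standard_normal((100_000, d_k))       # 100k random query vectors
k = rng.standard_normal((100_000, d_k))       # 100k random key vectors
dots = np.einsum('nd,nd->n', q, k)            # one dot product per (q, k) pair

print(round(dots.std(), 1))                   # close to sqrt(64) = 8
print(round((dots / np.sqrt(d_k)).std(), 1))  # after scaling: close to 1
```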
Compute this score for every possible pair (i, j) across all four tokens and you get a [4 × 4] matrix of raw scores. Apply softmax row-wise, where each row sums to 1, and you have a matrix of attention weights. Each row tells you how much token i should attend to each other token when constructing its new representation. Multiply by V and every token is updated to a weighted blend of all other tokens' values, where the weights encode learned, contextual relevance.
The full equation is:
Attention(Q, K, V) = softmax( QKᵀ / sqrt(d_k) ) · V
Working through the shapes concretely for the 4-token example: Q has shape [4 × 64] and Kᵀ has shape [64 × 4], so their product QKᵀ is [4 × 4], the raw score matrix with one scalar for every ordered token pair. Applying softmax row-wise keeps the shape at [4 × 4] but converts raw scores to normalized weights. Multiplying by V [4 × 64] gives the output matrix [4 × 64], the same dimensionality as the input projections but now every token embedding has been enriched by contextual information from every other token in the sequence.
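Putting the pieces together, a minimal single-head implementation in NumPy (with the standard max-subtraction trick for a stable softmax):

```python
import numpy as np

def attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                   # [n x n] raw scores
    scores -= scores.max(axis=-1, keepdims=True)      # stabilize softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)    # row-wise softmax
    return weights @ V, weights                       # [n x d_k], [n x n]

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((4, 64)) for _ in range(3))
out, weights = attention(Q, K, V)
# out has shape (4, 64); every row of weights sums to 1
```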
That [4 × 4] matrix is the entire mechanism. For 4 tokens it costs almost nothing to compute. For 4,096 tokens it is a [4,096 × 4,096] matrix containing 16.7 million entries. For 128,000 tokens it is a [128k × 128k] matrix containing 16.4 billion entries. At fp16 (2 bytes per value), storing that matrix alone requires 32.8 GB. This is before accounting for multiple layers, multiple heads, intermediate activations, or gradients during training.
Why the cost is genuinely quadratic, and what that means at scale
The quadratic scaling shows up simultaneously in both compute and memory, which is why it is so punishing.
On the compute side, the dominant operation is QKᵀ: a matrix multiplication of [n × d_k] by [d_k × n]. The number of multiply-accumulate operations is n² × d_k. For a 128k-token context with d_k = 128, that is 128,000² × 128 ≈ 2.1 × 10¹² multiply-accumulates per attention layer per head. A production 70B model has 80 transformer layers and 64 attention heads. Even if every other operation in the model were free, the attention FLOPs for a single forward pass at 128k tokens run to tens of petaFLOPs. No single GPU handles that.
On the memory side, storing the attention weight matrix [n × n] in fp16 costs exactly 2n² bytes. At 128k tokens, that is 2 × (128,000)² = 32.8 GB for one layer and one head. Across 80 layers, even storing just one head per layer, you need 2.6 TB of memory just for the attention matrices, before the model weights, before the KV cache, before activations. The A100 80GB has 80 GB of HBM. A naive implementation of 128k-token attention is not merely slowed down by memory pressure; it is physically impossible to execute without fundamentally rethinking how the computation is organized.
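The arithmetic is worth reproducing directly (decimal GB, fp16):

```python
def attn_matrix_gb(n, bytes_per_value=2):
    """Memory for one [n x n] attention weight matrix."""
    return n * n * bytes_per_value / 1e9

print(attn_matrix_gb(4_096))          # ~0.03 GB -- negligible
print(attn_matrix_gb(128_000))        # ~32.8 GB -- more than one A100's entire HBM
print(80 * attn_matrix_gb(128_000))   # ~2,621 GB across 80 layers, one head each
```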
Doubling your context window doesn't double your cost. It quadruples it. Tripling it makes it nine times more expensive.
There is also a more subtle boundary that reveals itself in profiling long before you run out of VRAM. In a standard transformer, the feedforward network (FFN) layers process tokens independently of each other with two linear projections: from [d_model → d_ff] and back from [d_ff → d_model], where d_ff is typically four times d_model. This operation scales linearly with sequence length because each token's FFN computation is independent of every other token. Attention scales quadratically. For short sequences, FFN dominates your compute budget. For long sequences, attention overtakes it.
The crossover point, where per-layer FLOPs for attention and FFN are roughly equal, falls out of counting both. Summed across all heads, attention's two quadratic matmuls (QKᵀ and the multiply by V) cost about 4n² × d_model FLOPs per layer, while the FFN's two projections cost about 4n × d_model × d_ff. Setting these equal gives n_crossover ≈ d_ff. For LLaMA-2-7B, with d_ff = 11,008:

n_crossover ≈ d_ff ≈ 11,000 tokens

For sequences well under that length, the FFN is the bottleneck for LLaMA-2-7B; above it, attention starts to dominate. (LLaMA's gated FFN actually uses three projections rather than two, which pushes the crossover out roughly another 50%.) For larger models with wider FFNs, specifically d_ff = 28,672 for LLaMA-2-70B, the crossover moves out to roughly 29,000 tokens. This is why optimizations like Flash Attention matter more as you push toward longer contexts, and why profiling your actual workload before optimizing is worth the 30 minutes it takes.
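Counting the matmul FLOPs directly is a useful sanity check (a matmul of [a × b] by [b × c] costs about 2abc FLOPs; the two quadratic attention matmuls summed over heads use h × d_k = d_model):

```python
def attn_flops_per_layer(n, d_model):
    # QK^T plus the multiply by V, summed across all heads
    return 2 * n * n * d_model + 2 * n * n * d_model

def ffn_flops_per_layer(n, d_model, d_ff):
    # up-projection plus down-projection, applied to each of n tokens
    return 2 * n * d_model * d_ff + 2 * n * d_ff * d_model

d_model, d_ff = 4096, 11008   # LLaMA-2-7B dimensions
# The two costs meet where n equals d_ff:
assert attn_flops_per_layer(d_ff, d_model) == ffn_flops_per_layer(d_ff, d_model, d_ff)
```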
Multi-head attention: richer representations, same quadratic bill
Single-head attention learns a single notion of relevance between token pairs. Run attention once and you get one way of answering how much token i should borrow from token j. This is limiting. The word "bank" in "river bank" and "bank account" needs to attend to different context clues to resolve its meaning, and a single attention pattern cannot simultaneously represent syntactic dependencies, semantic similarity, and positional proximity.
Multi-head attention solves this by running h independent attention operations in parallel, each on a lower-dimensional projection of the input. With d_model = 512 and h = 8 heads, each head works with d_k = d_model / h = 64 dimensions. The eight sets of weight matrices W_Q^i, W_K^i, W_V^i each project the input into a different 64-dimensional subspace and compute attention in that subspace independently. The eight resulting output matrices, each of shape [n × 64], are concatenated to form [n × 512] and then linearly projected back to [n × d_model] with a final weight matrix W_O.
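A sketch of the head-splitting in NumPy. In practice the per-head projections are fused into single [d_model × d_model] matrices and the result is reshaped, as here (random weights, scaled for stable magnitudes, stand in for learned ones):

```python
import numpy as np

def multi_head_attention(X, W_Q, W_K, W_V, W_O, h):
    n, d_model = X.shape
    d_k = d_model // h
    # project once, then split the result into h heads of d_k dims each
    def split(W):
        return (X @ W).reshape(n, h, d_k).transpose(1, 0, 2)   # [h, n, d_k]
    Q, K, V = split(W_Q), split(W_K), split(W_V)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)           # [h, n, n]
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)                         # softmax per head
    heads = (w @ V).transpose(1, 0, 2).reshape(n, d_model)     # concatenate heads
    return heads @ W_O                                         # [n, d_model]

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 512))
W = [rng.standard_normal((512, 512)) * 0.05 for _ in range(4)]
out = multi_head_attention(X, *W, h=8)    # shape (4, 512)
```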
Each head can learn a different type of relationship. In a well-trained model, some heads specialize in tracking syntactic subject-verb agreement across long distances, others in attending to the most recent mention of a noun, others in picking up positional proximity. The concatenated output captures all of these patterns simultaneously, giving the model a richer picture of context than any single attention pattern could provide.
What multi-head attention does not do is reduce the quadratic cost. Each of the h heads still computes its own [n × n] attention weight matrix. The total memory for attention matrices is h × 2n² bytes per layer. With h = 8 and n = 128k, that is 8 × 32.8 GB = 262 GB per transformer layer, still wildly over budget, just with a larger constant factor. The expressiveness benefit of multiple heads doesn't come with a computational discount. The n² problem remains, and any solution has to deal with it directly.
Note that the reduction in per-head dimensionality to d_k = d_model/h makes each individual head's matrix multiplications proportionally cheaper, so total attention FLOPs come out roughly the same as a single head operating at the full d_model. What multi-head buys is representational diversity, not efficiency.
Flash Attention: the bottleneck was IO, not arithmetic
The A100 GPU has two memory tiers that matter here. HBM (high-bandwidth memory) is the large 40–80 GB pool you think of as "GPU memory," fast by DRAM standards at around 2 TB/s on the A100 SXM. SRAM is the on-chip memory, about 20 MB on the A100 (192 KB per SM across 108 SMs), that sits much closer to the compute units and achieves roughly 19 TB/s, nearly ten times faster than HBM. Compute on the GPU runs at 312 TFLOPS for fp16 matrix multiplications.
With those numbers, a standard attention implementation reveals a problem: at realistic sequence lengths, the GPU spends most of its time doing IO (reading and writing to HBM) rather than arithmetic. The arithmetic intensity of naive attention, measured as FLOPs per byte transferred, is low enough that the operation is bandwidth-bound rather than compute-bound. You have 312 TFLOPS of math sitting idle while data shuffles back and forth from HBM.
The specific culprit is the materialization of the full [n × n] attention matrix in HBM. Standard attention computes QKᵀ/sqrt(d_k), writes the result to HBM, reads it back to apply softmax, writes the softmax output to HBM, and reads it again to multiply by V. That is three HBM round-trips for the n×n matrix on every forward pass. At n = 128k, that is three trips of 32.8 GB each, totaling nearly 100 GB of HBM reads and writes just for the attention scores in one layer.
Flash Attention (Dao et al., NeurIPS 2022) eliminates these round-trips without changing the mathematical result. The key insight is that softmax is computed row-wise. Each row of the attention weight matrix is independent of the other rows, which means you don't need the entire n×n matrix in memory at once. Instead, you can tile the Q, K, and V matrices into blocks small enough to fit in SRAM, compute the attention output for each tile of Q against all tiles of K and V entirely within SRAM, and maintain a running correction for the softmax normalization that gets updated incrementally as new K/V tiles are processed.
The running softmax correction is the mathematically non-trivial part. Standard softmax requires dividing each score by the sum of all scores in the row, which means you normally need to see all scores before you can normalize any of them. Flash Attention gets around this with the online softmax algorithm: it maintains the current row maximum m and the current sum of exponentials l as it processes K/V tiles one at a time, and rescales the accumulated output O accordingly when these running statistics are updated. At the end, O contains the exact same result as if you had computed the full n×n matrix and then multiplied by V in the standard way. Not an approximation. The exact result.
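The online-softmax bookkeeping is easier to trust after watching it reproduce the reference result exactly. A minimal single-head sketch, tiling only over K/V (not Q) for clarity:

```python
import numpy as np

def tiled_attention(Q, K, V, block=2):
    n, d_k = Q.shape
    O = np.zeros_like(Q)              # accumulated (unnormalized) output
    m = np.full(n, -np.inf)           # running row maximum
    l = np.zeros(n)                   # running sum of exponentials
    for j in range(0, K.shape[0], block):
        Kj, Vj = K[j:j + block], V[j:j + block]
        S = Q @ Kj.T / np.sqrt(d_k)               # scores against this K tile
        m_new = np.maximum(m, S.max(axis=1))
        scale = np.exp(m - m_new)                 # rescale previously seen tiles
        P = np.exp(S - m_new[:, None])
        l = l * scale + P.sum(axis=1)
        O = O * scale[:, None] + P @ Vj
        m = m_new
    return O / l[:, None]                         # exact softmax(QK^T/sqrt(d_k)) V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((6, 8)) for _ in range(3))
S = Q @ K.T / np.sqrt(8)
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V      # standard full-matrix attention
assert np.allclose(tiled_attention(Q, K, V), ref) # exact, not an approximation
```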
The practical outcome is significant. Flash Attention has the same arithmetic FLOPs as standard attention; it does not skip any math. But it reduces HBM accesses from O(n²) to O(n). For the 128k-token case, this is the difference between transferring roughly 100 GB and 1.3 GB of attention scores to and from HBM per layer. On an A100 with its 2 TB/s HBM bandwidth, that represents roughly 50ms versus 0.65ms per layer in pure IO time. Across 80 layers, the cumulative savings dominate the forward pass runtime. Benchmarks consistently show 2–4x end-to-end wall-clock speedup at long sequence lengths, with gains scaling as sequences grow longer.
Flash Attention 2 (2023) improved the work partitioning scheme between GPU thread blocks and increased parallelism in the backward pass, yielding another 2x in practice and reaching roughly 50–73% of A100 peak throughput. Flash Attention 3 (2024) targets the H100's Hopper architecture, overlapping attention computation with memory transfers via asynchronous warpgroup operations and adding FP8 support, pushing throughput to roughly 75% of the H100's theoretical fp16 peak.
The reason Flash Attention matters so much for the AI engineering stack is that it is a drop-in replacement: same inputs, same outputs, dramatically less memory and time. It is enabled by default in recent versions of PyTorch and is the default attention implementation in Hugging Face Transformers for supported GPUs. The upgrade from O(n²) memory to O(n) memory is what made the push from 4k to 32k to 128k context windows possible between 2022 and 2024.
Grouped Query Attention: the second bottleneck at inference time
Flash Attention largely solves the problem of computing attention during training and the prefill phase of inference. There is a second bottleneck that Flash Attention does not address: the memory cost of the KV cache during autoregressive generation.
When a language model generates text one token at a time, each new token's attention requires the K and V projections for every previous token in the sequence. You could recompute them from scratch on every step, but that means repeating the full O(n) projection computation for all previous tokens on every generation step, a cost that grows as the sequence lengthens. The standard solution is to cache the K and V matrices for all previous tokens and append the new token's K and V on each step. This is the KV cache.
The memory cost of the KV cache for LLaMA-2-70B during inference is:
KV cache = 2 (K and V) × n_layers × n_kv_heads × d_k × seq_len × batch_size × bytes_per_value
For a full multi-head baseline at LLaMA-2-70B's dimensions (80 layers, 64 heads each with its own K/V, d_k = 128, fp16 = 2 bytes; the released model actually ships with GQA, discussed below):
= 2 × 80 × 64 × 128 × seq_len × batch_size × 2
= 2,621,440 × seq_len × batch_size bytes
At seq_len = 4,096 and batch_size = 16, this evaluates to approximately 172 GB of KV cache. The model weights themselves occupy approximately 140 GB in fp16 (70 billion parameters × 2 bytes). Together that is over 300 GB just to load the model and maintain a 16-batch KV cache at 4k tokens, roughly four A100 80GB GPUs' worth of static memory alone. There is nothing left for activations, no headroom to increase batch size or sequence length, and serving latency will be bottlenecked by HBM bandwidth as tokens are generated one at a time.
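Evaluating the formula directly (decimal GB):

```python
def kv_cache_gb(n_layers, n_kv_heads, d_k, seq_len, batch, bytes_per_value=2):
    # 2 accounts for storing both K and V
    return 2 * n_layers * n_kv_heads * d_k * seq_len * batch * bytes_per_value / 1e9

# Full multi-head baseline at LLaMA-2-70B's dimensions
print(kv_cache_gb(80, 64, 128, 4096, 16))   # ~171.8 GB
# The same model with 8 KV heads (GQA) instead of 64
print(kv_cache_gb(80, 8, 128, 4096, 16))    # ~21.5 GB
```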
Multi-Query Attention (MQA) (Shazeer, 2019) was the first practical solution. Instead of h K/V head pairs, one per query head, MQA uses a single shared K/V pair that all query heads attend to. This reduces the KV cache by a factor of h, a 64x reduction against the 64-head baseline, cutting the roughly 172 GB computed above to about 2.7 GB. The problem is quality degradation: all 64 query heads read from identical K and V projections, eliminating the diversity of perspectives that makes multi-head attention useful. Benchmarks showed 1–3% degradation on tasks sensitive to long-range reasoning, which was acceptable for some applications but not others.
Grouped Query Attention (GQA) (Ainslie et al., 2023) lands the tradeoff more carefully. Instead of one K/V shared by all h query heads (MQA) or one K/V per query head (MHA), GQA uses g groups where each group has its own K/V pair shared by h/g query heads. LLaMA-3-70B uses g = 8 groups with h = 64 query heads: each group of 8 query heads shares one K/V pair. This reduces the KV cache by 8x relative to MHA while giving each group of queries some differentiated K/V structure. The quality impact is less than 1% degradation on standard benchmarks relative to full MHA, while the memory savings are substantial.
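One way to picture the sharing is that each group's K is broadcast to its h/g query heads before the usual per-head score computation. A sketch (shapes and grouping only, not a full implementation):

```python
import numpy as np

def gqa_scores(Q, K_groups, h, g):
    """Q: [h, n, d_k] query heads; K_groups: [g, n, d_k] shared keys (h % g == 0)."""
    d_k = Q.shape[-1]
    K_full = np.repeat(K_groups, h // g, axis=0)         # broadcast to [h, n, d_k]
    return Q @ K_full.transpose(0, 2, 1) / np.sqrt(d_k)  # [h, n, n] scores

rng = np.random.default_rng(0)
h, g, n, d_k = 8, 2, 5, 16
Q = rng.standard_normal((h, n, d_k))
K_groups = rng.standard_normal((g, n, d_k))
S = gqa_scores(Q, K_groups, h, g)
# Heads 0-3 all score against K_groups[0]; heads 4-7 against K_groups[1]
assert np.allclose(S[3], Q[3] @ K_groups[0].T / np.sqrt(d_k))
```

Only g distinct K (and V) tensors ever need to be cached, which is where the h/g memory reduction comes from.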
The practical consequence of GQA is not just a lower memory footprint. It is the ability to serve larger batches and longer sequences within the same GPU memory budget, which directly drives inference throughput. A serving system running LLaMA-3-70B with GQA on two A100s can handle roughly 8x more concurrent requests than a full MHA variant at the same sequence length, which translates directly into cost per token in production.
What actually limits context windows in practice
The architectural innovations above, Flash Attention, GQA, and rotary positional embeddings for extrapolation, have made 128k and 1M-token context windows technically possible. But "possible" and "useful" are different things, and understanding the gap matters for engineering decisions.
Positional encoding is one constraint. Original transformers used learned absolute position embeddings that could not generalize beyond their training length. Rotary Position Embeddings (RoPE), used in LLaMA and most modern open models, encode position as a rotation applied to the Q and K vectors before the dot product. This makes position differences relative rather than absolute and allows models to extrapolate to lengths longer than those seen during training, though not indefinitely. ALiBi (Attention with Linear Biases) takes a different approach, adding a position-dependent bias directly to the attention scores, which also generalizes more gracefully than absolute embeddings.
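ALiBi's bias is simple enough to write down directly: for a head with slope m, the score between query position i and key position j (j ≤ i in a causal model) gets -m·(i - j) added before softmax. A sketch:

```python
import numpy as np

def alibi_bias(n, slope):
    """Distance-proportional penalty added to causal attention scores."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    bias = -slope * (i - j).astype(float)  # 0 on the diagonal, more negative further back
    bias[j > i] = -np.inf                  # causal mask: no attending to the future
    return bias

print(alibi_bias(4, 0.5))
```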
A second, harder-to-fix constraint is what researchers call the "lost-in-the-middle" problem (Liu et al., 2023). When relevant information is placed in the middle of a very long context rather than at the beginning or end, model performance drops significantly. In retrieval tasks where the correct document was placed at position 10 out of 20 retrieved passages, GPT-3.5-Turbo-16k dropped from roughly 70% accuracy (when the document was at position 1 or 20) to below 50% accuracy. The model's effective attention is not uniformly distributed across its context window. It is biased toward the start and end of the sequence. This is a property of how attention weights are distributed in practice, not a theoretical constraint, but it means that stuffing 100k tokens into a context window is not the same as giving the model coherent access to all of it.
The engineering implication is that a retrieval strategy which surfaces the top 5 relevant chunks near the beginning of a prompt often outperforms a long-context strategy that buries relevant information inside a 50k-token document dump, both in quality and cost. Understanding this is the difference between using long context intelligently and using it as an excuse not to build retrieval.
The practical heuristic: when do these optimizations actually matter?
The right way to approach this is as a decision based on your actual workload profile.
If your sequences are under 2k tokens, attention is not your bottleneck at inference time regardless of model size. The FFN layers dominate. Flash Attention gives you memory savings because it avoids n×n matrix materialization, but the wall-clock speedup will be modest. Focus optimization energy elsewhere: quantization of FFN weights, continuous batching, operator fusion.
Between 2k and 16k tokens, attention starts to become a meaningful contributor to runtime and memory. Flash Attention is important here, and GQA matters for batch sizes above 4. This is the most common regime for production LLM applications: RAG with retrieved context, chat applications with conversation history, and document summarization.
Above 32k tokens, attention is almost certainly your dominant cost. Flash Attention becomes essential rather than optional. The memory savings from Flash Attention are what make these workloads physically possible on commercial hardware. GQA is load-bearing at this scale. Without it, the KV cache alone will exhaust your GPU memory at the batch sizes needed for production throughput.
Profile before you optimize. Running torch.profiler on your serving stack for 100 representative requests will tell you in 20 minutes whether you're attention-bound or FFN-bound, whether you're memory-constrained or compute-constrained, and exactly what percentage of wall-clock time each operation consumes. The numbers above are useful for back-of-envelope estimates before you build, but they are not a substitute for measurement on your actual workload and hardware.
The deeper principle behind all of this is that attention's quadratic cost is not a flaw waiting to be fixed. It is the price of the capability. Every token attending to every other token is what gives transformers their remarkable ability to track long-range dependencies and maintain coherent context. Flash Attention and GQA don't reduce that price fundamentally. They make paying it more efficient. Understanding where the cost comes from is the first step to making intelligent decisions about when to pay it and when to find a cheaper approximation.
References
Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., and Polosukhin, I. "Attention Is All You Need." NeurIPS, 2017. https://arxiv.org/abs/1706.03762
Dao, T., Fu, D. Y., Ermon, S., Rudra, A., and Ré, C. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS, 2022. https://arxiv.org/abs/2205.14135
Dao, T. "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." ICLR, 2024. https://arxiv.org/abs/2307.08691
Shah, J., Bikshandi, G., Zhang, Y., Kirby, V., Pandey, P., and Dao, T. "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision." arXiv, 2024. https://arxiv.org/abs/2407.08608
Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebrón, F., and Sanghai, S. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP, 2023. https://arxiv.org/abs/2305.13245
Shazeer, N. "Fast Transformer Decoding: One Write-Head is All You Need." arXiv, 2019. https://arxiv.org/abs/1911.02150
Liu, N. F., Lin, K., Hewitt, J., Paranjape, A., Bevilacqua, M., Petroni, F., and Liang, P. "Lost in the Middle: How Language Models Use Long Contexts." Transactions of the ACL, 2024. https://arxiv.org/abs/2307.03172
Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B., and Liu, Y. "RoFormer: Enhanced Transformer with Rotary Position Embedding." Neurocomputing, 2024. https://arxiv.org/abs/2104.09864
Press, O., Smith, N. A., and Lewis, M. "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation." ICLR, 2022. https://arxiv.org/abs/2108.12409