Flash Attention is an I/O-aware algorithm that computes exact attention in Transformer models significantly faster and with far less memory overhead than standard implementations. It does this by optimizing the transfer of data between the GPU’s large, relatively slow High Bandwidth Memory (HBM) and its small, much faster on-chip Static Random Access Memory (SRAM).
By fusing the attention operations into a single kernel and minimizing memory reads and writes, Flash Attention addresses the “memory wall” bottleneck in deep learning hardware. It is one of the key enablers for training large language models (LLMs) with long, real-world context windows (100K to 1M+ tokens).
The Bottleneck: Standard Attention is Memory-Bound
In standard self-attention, the computation involves three main steps to calculate the output matrix O:
- Compute the attention scores S = QK^T.
- Compute the attention probabilities P = softmax(S).
- Compute the final output O = PV.
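The three steps above can be sketched in a few lines of NumPy. This is a minimal illustration rather than a production kernel: the function name `standard_attention` and the toy shapes are assumptions made for this example, and the sketch also applies the conventional 1/sqrt(d) scaling and a row-max subtraction for numerical stability.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive attention that materializes the full N x N matrices S and P."""
    d = Q.shape[-1]
    S = (Q @ K.T) / np.sqrt(d)               # step 1: scores, shape (N, N)
    S = S - S.max(axis=-1, keepdims=True)    # subtract the row max for stability
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)    # step 2: row-wise softmax, shape (N, N)
    return P @ V                             # step 3: output, shape (N, d)

# Toy sizes chosen only for illustration.
rng = np.random.default_rng(0)
N, d = 6, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O = standard_attention(Q, K, V)
```

Note that both S and P are full N x N arrays here, which is exactly the cost the rest of this article is about.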
The primary bottleneck in standard implementations is not arithmetic throughput (FLOPs) but memory traffic. The standard approach writes the intermediate matrices S and P, which grow quadratically with the sequence length N (O(N^2)), to the GPU’s slow HBM, only to read them straight back into the fast SRAM for the next step. For long token sequences, these round trips dominate the runtime and consume vast amounts of VRAM, severely restricting maximum context lengths.
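To get a feel for the quadratic growth, the following back-of-the-envelope calculation estimates the size of a single score matrix. It assumes 2 bytes per element (fp16 or bf16) and counts one attention head of one sequence; the helper name `score_matrix_bytes` is made up for this example.

```python
def score_matrix_bytes(seq_len, bytes_per_element=2):
    """Bytes needed to hold one N x N score matrix (one head, one sequence)."""
    return seq_len * seq_len * bytes_per_element

for n in (1_024, 8_192, 65_536):
    mib = score_matrix_bytes(n) / 2**20
    print(f"N = {n:>6}: {mib:>8.0f} MiB")
```

Under these assumptions the matrix takes 2 MiB at 1,024 tokens, 128 MiB at 8,192 tokens, and 8,192 MiB (8 GiB) at 65,536 tokens. That is for S alone; P is the same size, and the total is multiplied by the number of heads and the batch size.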
The Flash Attention Solution: Tiling and Recomputation
Flash Attention redesigns the algorithm to compute the exact same mathematical result without ever materializing the large S and P matrices in HBM. It relies on two fundamental hardware-aware techniques:
- Tiling (Blocking): It divides the Query (Q), Key (K), and Value (V) matrices into smaller blocks (tiles) that comfortably fit within the fast SRAM. The attention is then computed block-by-block.
- Recomputation: Instead of storing the massive intermediate attention matrix P in HBM during the forward pass for later use in the backward pass (gradient calculation), Flash Attention keeps only the small per-row softmax statistics and recomputes the blocks of P on the fly during the backward pass. Although this costs extra FLOPs, recomputing locally in SRAM is much faster than reading a massive stored matrix back from HBM.
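The recomputation idea can be illustrated with a small NumPy sketch. It assumes that the forward pass saves one log-sum-exp value per query row (O(N) memory) and that a tile of P is later rebuilt from Q, K and that statistic. The helper names `forward_stats` and `recompute_P_block` are hypothetical, and the statistic is computed densely here only to keep the example short; a real kernel would accumulate it tile by tile.

```python
import numpy as np

def forward_stats(Q, K):
    """Per-row log-sum-exp of the scaled scores (computed densely for brevity)."""
    S = (Q @ K.T) / np.sqrt(Q.shape[-1])
    m = S.max(axis=-1)
    return m + np.log(np.exp(S - m[:, None]).sum(axis=-1))

def recompute_P_block(Q_i, K_j, lse_i):
    """Rebuild one tile of P from a Q block, a K block and the saved row statistics."""
    S_ij = (Q_i @ K_j.T) / np.sqrt(Q_i.shape[-1])
    return np.exp(S_ij - lse_i[:, None])

rng = np.random.default_rng(0)
N, d, B = 8, 4, 4          # toy sizes; B is the tile size
Q, K = rng.standard_normal((N, d)), rng.standard_normal((N, d))
lse = forward_stats(Q, K)

# Dense reference softmax, used only to check the recomputed tile.
S = (Q @ K.T) / np.sqrt(d)
P_full = np.exp(S - S.max(axis=-1, keepdims=True))
P_full /= P_full.sum(axis=-1, keepdims=True)

P_tile = recompute_P_block(Q[0:B], K[B:2 * B], lse[0:B])
assert np.allclose(P_tile, P_full[0:B, B:2 * B])
```

The point of the sketch is that the tile matches the corresponding slice of the full matrix even though the full matrix was never stored.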
Architectural Comparison: Standard vs. Flash Attention
Deep Dive: Safe Softmax and Online Tracking
To make the tiled, block-by-block approach work, Flash Attention has to compute the softmax without ever seeing a whole row at once. Standard softmax needs the entire row to compute the denominator (the sum of exponentials), and the numerically “safe” variant also needs the row maximum. Flash Attention gets around this with an “online softmax”: it keeps running statistics for each row and corrects earlier partial results as new blocks arrive.
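For a single row of scores, the running-statistics idea can be sketched in plain Python. This is an illustrative sketch, not the kernel itself: the function names and the block size are assumptions for the example, and the running maximum starts at negative infinity so that the first block always sets it.

```python
import math

def online_softmax(scores, block_size=2):
    """Softmax over one row, visiting the row one block at a time."""
    m = float("-inf")   # running maximum seen so far
    l = 0.0             # running sum of exp(x - m)
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old sum to the new maximum, then add this block's terms.
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in block)
        m = m_new
    return [math.exp(x - m) / l for x in scores]

def softmax(scores):
    """Ordinary two-pass safe softmax, used as the reference."""
    m = max(scores)
    exps = [math.exp(x - m) for x in scores]
    total = sum(exps)
    return [e / total for e in exps]

row = [1.0, 3.0, -2.0, 0.5, 4.0]
assert all(math.isclose(a, b) for a, b in zip(online_softmax(row), softmax(row)))
```

The rescaling factor exp(m - m_new) is what lets earlier partial sums stay valid after a larger maximum shows up in a later block.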
As defined in the core algorithm:
- Initialization: The output matrix O and the running sum vector l are initialized with zeros, and the running maximum vector m is initialized with negative infinity.
- Blocking: Divide Q, K, V into blocks based on SRAM’s memory limits, and iterate over them (where i is the row block, j is the column block).
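- Local Computations (Inside SRAM):
  - Compute block scores: S_ij = Q_i K_j^T
  - Find the local maximum: m̃_ij = rowmax(S_ij)
  - Compute local exponentials: P̃_ij = exp(S_ij - m̃_ij)
  - Compute local sum of exponentials: l̃_ij = rowsum(P̃_ij)
- Global Updates:
  - Find the new global maximum: m_i^new = max(m_i, m̃_ij)
  - Update the running sum: l_i^new = exp(m_i - m_i^new) l_i + exp(m̃_ij - m_i^new) l̃_ij
  - Rescale the previous output to the shifted maximum and add the new block (row-wise): O_i = (l_i exp(m_i - m_i^new) O_i + exp(m̃_ij - m_i^new) P̃_ij V_j) / l_i^new
  - Store m_i = m_i^new and l_i = l_i^new for the next block.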
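The full loop can be sketched in NumPy and checked against a dense reference. This is a teaching sketch under stated assumptions, not the fused GPU kernel: the names `tiled_attention` and `reference_attention`, the block sizes, and the 1/sqrt(d) scaling are choices made for this example, and NumPy array slices stand in for tiles loaded into SRAM.

```python
import numpy as np

def tiled_attention(Q, K, V, block_q=2, block_k=2):
    """Block-wise attention with a running max m and running sum l per query row."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, V.shape[1]))
    m = np.full(N, -np.inf)   # running row maxima
    l = np.zeros(N)           # running row sums of exponentials

    for j in range(0, K.shape[0], block_k):           # column blocks (K_j, V_j)
        K_j, V_j = K[j:j + block_k], V[j:j + block_k]
        for i in range(0, N, block_q):                # row blocks (Q_i)
            rows = slice(i, i + block_q)
            S_ij = (Q[rows] @ K_j.T) * scale          # block scores
            m_tilde = S_ij.max(axis=1)                # local row maxima
            P_tilde = np.exp(S_ij - m_tilde[:, None]) # local exponentials
            l_tilde = P_tilde.sum(axis=1)             # local row sums

            m_new = np.maximum(m[rows], m_tilde)
            alpha = np.exp(m[rows] - m_new)           # rescales the old state
            beta = np.exp(m_tilde - m_new)            # rescales the new block
            l_new = alpha * l[rows] + beta * l_tilde

            # Undo the old normalization, add the new block, renormalize.
            O[rows] = (alpha * l[rows])[:, None] * O[rows] + beta[:, None] * (P_tilde @ V_j)
            O[rows] /= l_new[:, None]
            m[rows], l[rows] = m_new, l_new
    return O

def reference_attention(Q, K, V):
    """Dense attention that materializes the full score matrix."""
    S = (Q @ K.T) / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 4)) for _ in range(3))
assert np.allclose(tiled_attention(Q, K, V), reference_attention(Q, K, V))
```

The largest temporary inside the loop is a block_q x block_k tile, so the working set no longer grows with N^2. The final assertion checks that the tiled result matches the dense one up to floating-point tolerance.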
The Impact and Evolution
The introduction of Flash Attention removed a major practical barrier to scaling context length. It yielded practical speedups of 2x to 4x for attention operations and cut attention memory consumption from O(N^2) to O(N) with respect to sequence length. Later versions build on the same idea:
- FlashAttention-2: Reduces the number of non-matmul FLOPs, parallelizes the computation across GPU thread blocks along the sequence length, and partitions work between warps more carefully, which substantially increases hardware utilization.
- FlashAttention-3: Designed specifically for the Hopper architecture (such as NVIDIA H100 GPUs). It uses asynchronous memory transfers (TMA), overlaps compute with memory I/O, and adds support for FP8 tensor cores to push utilization closer to the hardware’s peak.
Today, Flash Attention is a standard, widely used optimization. It is available in major frameworks and libraries such as PyTorch, Hugging Face Transformers, and xFormers, and in LLM inference engines such as vLLM and TGI.