
Flash Attention

Flash Attention is an I/O-aware, exact attention algorithm that tackles the memory bottleneck of attention in Transformer models. By fusing operations and minimizing costly reads and writes between the GPU's High Bandwidth Memory (HBM) and on-chip SRAM, it speeds up attention and makes much longer context windows practical.

Flash Attention is a highly optimized algorithm that computes exact attention (not an approximation) in Transformer models significantly faster and with far less memory overhead than standard implementations. It achieves this by being I/O-aware: it explicitly optimizes the movement of data between the GPU’s large but relatively slow High Bandwidth Memory (HBM) and its small, very fast on-chip Static Random Access Memory (SRAM).

By fusing mathematical operations and minimizing memory reads and writes, Flash Attention addresses the notorious “memory wall” in deep learning hardware, where compute throughput has outpaced memory bandwidth. Together with other techniques, it has been a key enabler of training large language models (LLMs) with long context windows, in some cases 100K to 1M+ tokens.

The Bottleneck: Standard Attention is Memory-Bound

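In standard self-attention, the output matrix $O$ is computed from the query, key, and value matrices $Q$, $K$, $V$ (each with $N$ rows, one per token, and head dimension $d$) in three main steps. The usual $1/\sqrt{d}$ scaling of the scores is omitted here for clarity: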

  1. Compute the attention scores $S = QK^\top$.
  2. Compute the attention probabilities $P = \mathrm{softmax}(S)$.
  3. Compute the final output $O = PV$.
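To make the three steps concrete, here is a minimal pure-Python sketch of standard attention, with matrices stored as lists of rows. The helper names are illustrative, and scaling, masking, and batching are left out. Note that it builds the full $N \times N$ matrices $S$ and $P$ before producing $O$; on a GPU, these are the intermediates that get written to HBM.

```python
import math


def transpose(A):
    """Return the transpose of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*A)]


def matmul(A, B):
    """Multiply an (n x k) matrix A by a (k x m) matrix B."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def softmax_rows(S):
    """Row-wise softmax (each row's max is subtracted to avoid overflow)."""
    P = []
    for row in S:
        mx = max(row)
        exps = [math.exp(x - mx) for x in row]
        total = sum(exps)
        P.append([e / total for e in exps])
    return P


def standard_attention(Q, K, V):
    S = matmul(Q, transpose(K))  # step 1: full N x N score matrix
    P = softmax_rows(S)          # step 2: full N x N probability matrix
    return matmul(P, V)          # step 3: N x d output
```

For example, if every key is identical, every query assigns equal probability to all positions, so each output row is simply the average of the rows of $V$.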

The primary bottleneck in standard implementations is not the arithmetic (FLOPs) but the cost of moving data. The standard approach writes the large intermediate matrices $S$ and $P$, which grow quadratically with sequence length ($O(N^2)$), to the GPU’s slow HBM, only to read them back into fast SRAM for the next step. For long token sequences, these repeated transfers become prohibitively slow and consume large amounts of GPU memory, severely restricting maximum context length.
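A quick back-of-the-envelope calculation shows how fast these intermediates grow. This sketch assumes FP16 storage (2 bytes per element) and counts a single attention head for a single sequence; real memory use also scales with batch size and the number of heads.

```python
def intermediate_bytes(seq_len, bytes_per_elem=2):
    """Bytes for one N x N intermediate matrix (S or P), one head, one sequence."""
    return seq_len * seq_len * bytes_per_elem


for n in (4_096, 32_768, 131_072):
    # Standard attention materializes both S and P, so count two N x N matrices.
    gib = 2 * intermediate_bytes(n) / 2**30
    print(f"N = {n:>7,}: S + P = {gib:g} GiB per head")
```

This prints 0.0625, 4, and 64 GiB: multiplying the sequence length by 8 (from 4,096 to 32,768) multiplies the memory for $S$ and $P$ by 64.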

The Flash Attention Solution: Tiling and Recomputation

Flash Attention redesigns the algorithm to compute the exact same mathematical result without ever materializing the large S and P matrices in HBM. It relies on two fundamental hardware-aware techniques:

  1. Tiling (Blocking): It divides the Query (Q), Key (K), and Value (V) matrices into smaller blocks (tiles) that comfortably fit within the fast SRAM. The attention is then computed block-by-block.
  2. Recomputation: Instead of storing the large intermediate matrix $P$ in HBM during the forward pass for reuse in the backward pass (gradient computation), Flash Attention keeps only the output and small per-row softmax statistics, then recomputes the blocks of $P$ on the fly during the backward pass. This costs extra FLOPs, but recomputing in SRAM is still much faster than reading a large stored matrix back from HBM.
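The recomputation idea can be sketched for a single query row. This sketch assumes the forward pass saves one number per row, the log-sum-exp $L = m + \log l$ of that row's scores (one compact way to keep the running statistics $m$ and $l$); the function names are hypothetical. Any block of probabilities can then be rebuilt exactly as $\exp(S_{ij} - L)$ from the query row and a block of keys, without ever storing $P$.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def row_logsumexp(scores):
    """Stable log(sum(exp(s))) of one row; equals m + log(l) for that row."""
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))


def recompute_prob_block(q_row, k_block, lse):
    """Rebuild one slice of a row of P from a query row, a block of keys, and the saved log-sum-exp."""
    return [math.exp(dot(q_row, k) - lse) for k in k_block]


# "Forward pass": keep only one float per query row, not the row of P.
q = [0.5, -1.0]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
lse = row_logsumexp([dot(q, k) for k in K])

# "Backward pass": recompute the row of P block by block from q, K_j and lse.
p = recompute_prob_block(q, K[:2], lse) + recompute_prob_block(q, K[2:], lse)
print(p, sum(p))  # the recomputed probabilities sum to 1 (up to rounding)
```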

Architectural Comparison: Standard vs. Flash Attention

[Figure: Standard attention vs. Flash Attention. In the standard implementation, each step ($S = QK^\top$, $P = \mathrm{softmax}(S)$, $O = PV$) loads its inputs from HBM and writes its result back. In Flash Attention, $Q$, $K$, $V$ are divided into SRAM-sized blocks; for each $K_j, V_j$ and each $Q_i$, one fused kernel computes $S_{ij}$, the running statistics $m$ and $l$, and the updated $O_i$, writing only $O_i$, $l_i$, $m_i$ back to HBM.]

Deep Dive: Safe Softmax and Online Tracking

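To make the tiled, block-by-block approach work, Flash Attention has to compute the softmax without seeing a whole row at once. Standard softmax needs the entire row to compute the denominator (the sum of exponentials), and the numerically “safe” version also subtracts the row maximum before exponentiating so that $\exp$ does not overflow. Flash Attention circumvents both problems by computing the softmax incrementally (“online”) from running statistics: a running row maximum $m$ and a running sum of exponentials $l$, which are rescaled whenever a new block raises the maximum.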
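A minimal sketch of this online softmax for a single row, in pure Python: the row is read in blocks, and only the running maximum `m` and running sum `l` are carried from one block to the next. The function name and `block_size` parameter are illustrative. The final normalization line reads the row again purely for demonstration; in Flash Attention the division by $l$ is applied to the output $O$ instead.

```python
import math


def online_softmax(row, block_size):
    """Softmax of one row, read in blocks, keeping only a running max m and sum l."""
    m = float("-inf")  # running maximum seen so far
    l = 0.0            # running sum of exp(x - m) over the elements seen so far
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old sum to the new maximum, then add this block's terms.
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in block)
        m = m_new
    # Demonstration-only normalization pass over the row.
    return [math.exp(x - m) / l for x in row]
```

Because every exponent is taken relative to a running maximum, inputs such as 1000.0, which would overflow a plain `math.exp`, are handled safely, and the result does not depend on the block size.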

As defined in the core algorithm:

  1. Initialization: The output matrix $O$ and the running sum vector $l$ are initialized with zeros, and the running maximum vector $m$ is initialized to $-\infty$.
  2. Blocking: Divide Q, K, V into blocks based on SRAM’s memory limits, and iterate over them (where i is the row block, j is the column block).
  3. Local Computations (Inside SRAM):
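    • Compute block scores: $S_{ij} = Q_i K_j^\top$
    • Find the local row maximum: $\tilde{m}_{ij} = \mathrm{rowmax}(S_{ij})$
    • Compute local exponentials: $\tilde{P}_{ij} = \exp(S_{ij} - \tilde{m}_{ij})$
    • Compute the local sum of exponentials: $\tilde{l}_{ij} = \mathrm{rowsum}(\tilde{P}_{ij})$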
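  4. Running Updates (for row block $i$):
    • Find the new running maximum: $m_i^{\text{new}} = \max(m_i, \tilde{m}_{ij})$
    • Update the running sum, rescaling both terms to the new maximum: $l_i^{\text{new}} = e^{m_i - m_i^{\text{new}}} l_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}} \tilde{l}_{ij}$
    • Rescale the previous output and add this block's contribution: $O_i \leftarrow \mathrm{diag}(l_i^{\text{new}})^{-1}\left(\mathrm{diag}(l_i)\, e^{m_i - m_i^{\text{new}}} O_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}} \tilde{P}_{ij} V_j\right)$, then set $m_i \leftarrow m_i^{\text{new}}$ and $l_i \leftarrow l_i^{\text{new}}$.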
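The steps above can be combined into a small, unoptimized pure-Python sketch of the tiled forward pass. It uses an outer loop over $K_j, V_j$ blocks and an inner loop over $Q_i$ blocks, the loop order of the original FlashAttention paper's forward algorithm as we understand it, and applies the running updates row by row; a real kernel performs each tile's work as matrix operations in SRAM inside one fused GPU kernel. The function name and block-size parameters are hypothetical, and scaling, masking, and dropout are omitted.

```python
import math


def flash_attention_forward(Q, K, V, block_q=2, block_k=2):
    """Tiled exact attention using a running row max m and running row sum l.

    Q is N x d; K and V have N_k rows (lists of rows). Only one tile of scores
    (at most block_q x block_k) exists at a time; S and P are never stored whole.
    """
    n, n_k, d_v = len(Q), len(K), len(V[0])
    O = [[0.0] * d_v for _ in range(n)]
    m = [float("-inf")] * n  # running row maximum
    l = [0.0] * n            # running row sum of exponentials
    for j0 in range(0, n_k, block_k):      # outer loop: block j of K and V
        Kj, Vj = K[j0:j0 + block_k], V[j0:j0 + block_k]
        for i0 in range(0, n, block_q):    # inner loop: block i of Q
            for r in range(i0, min(i0 + block_q, n)):
                # Local computations for row r of tile (i, j).
                s = [sum(a * b for a, b in zip(Q[r], k)) for k in Kj]  # row of S_ij
                m_tilde = max(s)
                p = [math.exp(x - m_tilde) for x in s]                  # row of P~_ij
                l_tilde = sum(p)
                # Running updates: rescale old and new terms to the new maximum.
                m_new = max(m[r], m_tilde)
                alpha = math.exp(m[r] - m_new)
                beta = math.exp(m_tilde - m_new)
                l_new = alpha * l[r] + beta * l_tilde
                pv = [sum(p[c] * Vj[c][k] for c in range(len(Vj))) for k in range(d_v)]
                O[r] = [(alpha * l[r] * o + beta * x) / l_new for o, x in zip(O[r], pv)]
                m[r], l[r] = m_new, l_new
    return O
```

For any block sizes, including ones that do not divide $N$ evenly, the result should match ordinary softmax attention up to floating-point rounding, even though no full $N \times N$ matrix is ever built.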

The Impact and Evolution

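The introduction of Flash Attention changed the practical limits on sequence length. It delivered reported speedups of roughly 2x to 4x for attention operations compared with standard implementations, and reduced the extra memory attention requires from $O(N^2)$ to $O(N)$ in sequence length, since only $O$ and the per-row statistics $m$ and $l$ are kept rather than $S$ and $P$.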
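For reference, the FlashAttention paper's I/O analysis compares the number of HBM accesses, with $N$ the sequence length, $d$ the head dimension, and $M$ the SRAM size (for $d \le M \le Nd$):

$$
\underbrace{\Theta\!\left(Nd + N^2\right)}_{\text{standard attention}}
\qquad \text{vs.} \qquad
\underbrace{\Theta\!\left(\frac{N^2 d^2}{M}\right)}_{\text{Flash Attention}}
$$

For typical head dimensions ($d$ between 64 and 128) and SRAM sizes of around 100 KB, $d^2$ is many times smaller than $M$, so Flash Attention needs many times fewer HBM accesses even though it performs extra FLOPs for recomputation.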

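  • FlashAttention-2: Improved how work is partitioned on the GPU: it reduces non-matmul FLOPs, parallelizes across the sequence-length dimension (in addition to batch and heads) so that more thread blocks run concurrently, and splits work more efficiently between warps within each thread block, substantially raising hardware utilization.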
  • FlashAttention-3: Designed specifically for the NVIDIA Hopper architecture (e.g., H100 GPUs), using the Tensor Memory Accelerator (TMA) for asynchronous memory transfers, overlapping computation with data movement, and adding FP8 low-precision support to push tensor core utilization higher.

Today, Flash Attention is a standard, widely used optimization. It is available in major frameworks such as PyTorch (as one of the backends of scaled_dot_product_attention), Hugging Face Transformers, and xFormers, and in LLM inference engines such as vLLM and TGI.
