FlashAttention-4: Attention at Matmul Speed
Created: 2026-03-27 | Size: 9595 bytes
TL;DR
FlashAttention-4 (FA4) is an algorithm-kernel co-design for NVIDIA Blackwell GPUs that reaches ~1600 TFLOPs/s (71% hardware utilization) for BF16 attention. BF16 (bfloat16) is a 16-bit floating-point format from Google Brain, widely used for LLM training because it keeps the dynamic range of 32-bit floats at half the memory cost. The key insight: on modern accelerators, matmul is no longer the bottleneck. The exponential unit (forward pass) and shared memory bandwidth (backward pass) are the new walls. FA4 redesigns attention computation to overlap matmul with these non-matmul bottlenecks as much as possible, using software-emulated exponentials, Blackwell's new Tensor Memory, and 2-CTA MMA mode. Its techniques have already been adopted into NVIDIA's cuDNN library.
The Bottleneck Shifted and Nobody Adjusted
Every generation of NVIDIA hardware scales tensor core throughput aggressively. The B200 pushes BF16 matmul to 2.25 PFLOPs, a 2.25x jump from H100. But shared memory bandwidth and the special function unit (SFU) count stayed flat.
This creates asymmetric hardware scaling: the fast part got faster, but the slow parts didn't move. For attention specifically:
| Resource | H100 | B200 | Scaling |
|---|---|---|---|
| BF16 Tensor Core TFLOPs | 1000 | 2250 | 2.25x |
| Shared Memory BW | ~fixed | ~fixed | 1x |
| SFU (exp unit) Count | ~fixed | ~fixed | 1x |
Previous FlashAttention versions optimized for a world where matmul dominated the cost. That world no longer exists on Blackwell. FA4 is built for the new reality.
Where the Cycles Actually Go
A roofline analysis on B200 with 128x128 tiles reveals the real bottlenecks:
Forward pass: Tensor cores and the exponential unit are co-bottlenecked at ~1024 cycles each. Shared memory is not the issue here; the exp computation for softmax is.
Backward pass: Tensor cores need ~2560 cycles for 5 MMAs, but shared memory traffic needs ~3328 cycles. SMEM is the clear limiter, not compute.
This is why you can't just "optimize the matmul" anymore. FA4 attacks both bottlenecks with distinct strategies.
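The cycle counts above are enough to work out which unit limits each pass and how busy the tensor cores can be at best. Here is a minimal sketch of that arithmetic; the `roofline` helper and the dictionary keys are illustrative names, and only the cycle numbers come from the analysis above.

```python
def roofline(cycles_by_unit):
    """Return (limiting unit, cycle bound, best-case tensor core utilization).

    Assumes all units run fully overlapped, so the slowest unit sets the bound.
    """
    limiter = max(cycles_by_unit, key=cycles_by_unit.get)
    bound = cycles_by_unit[limiter]
    tensor_core_util = cycles_by_unit["tensor_core"] / bound
    return limiter, bound, tensor_core_util


# Per-tile cycle estimates quoted in the text (B200, 128x128 tiles).
forward = {"tensor_core": 1024, "exp_unit": 1024}
backward = {"tensor_core": 2560, "shared_memory": 3328}

if __name__ == "__main__":
    for name, cycles in (("forward", forward), ("backward", backward)):
        limiter, bound, util = roofline(cycles)
        print(f"{name}: limited by {limiter} at {bound} cycles, "
              f"tensor cores at most {util:.0%} busy")
```

Under this simple model the backward pass cannot keep the tensor cores more than about 77% busy (2560 / 3328) unless shared memory traffic shrinks, which is the motivation for the backward-pass techniques later in the article.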
Forward Pass: Beating the Exponential Wall
FA4 uses three techniques to overlap matmul with softmax computation:
Ping-pong Q scheduling. Two Q tiles alternate per CTA. While one tile's MMA runs on tensor cores, the other tile's softmax runs on the exponential unit. Two dedicated softmax warpgroups are synchronized so they never compete for the MUFU (multi-function unit) simultaneously.
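To see why alternating two tiles helps, here is a toy timeline model, not the real kernel schedule. It assumes each step of a tile is one MMA followed by one dependent softmax, that there is a single tensor-core resource and a single exponential resource, and it reuses the ~1024-cycle figures from the roofline analysis. The function names are hypothetical.

```python
def single_tile_cycles(n_steps, t_mma, t_exp):
    """One tile alone: each softmax waits for its MMA and vice versa."""
    return n_steps * (t_mma + t_exp)


def pingpong_cycles(n_steps, t_mma, t_exp):
    """Two tiles alternating on one MMA unit and one exp unit (toy model)."""
    mma_free = 0          # time at which the tensor cores become free
    exp_free = 0          # time at which the exponential unit becomes free
    ready = [0, 0]        # earliest time each tile can issue its next MMA
    for _ in range(n_steps):
        for tile in (0, 1):
            start = max(mma_free, ready[tile])
            mma_free = start + t_mma
            softmax_start = max(exp_free, mma_free)
            exp_free = softmax_start + t_exp
            ready[tile] = exp_free
    return max(mma_free, exp_free)


if __name__ == "__main__":
    steps, t = 4, 1024
    print("two tiles back to back:", 2 * single_tile_cycles(steps, t, t))
    print("two tiles ping-ponged: ", pingpong_cycles(steps, t, t))
```

In this model, two tiles processed one after the other take 16384 cycles for four steps each, while the ping-pong schedule takes 9216, because after the first MMA both units stay busy.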
Software-emulated exp2. Since the hardware MUFU.EX2 instruction is a bottleneck, FA4 supplements it with a software exponential running on otherwise-idle FMA units. The approach uses Cody-Waite range reduction to split x into an integer part n and a fractional part f, so that 2^x = 2^n * 2^f, then approximates 2^f with a degree-3 polynomial evaluated in Horner form. The coefficients are optimized with the Sollya package. This effectively doubles exponential throughput.
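The range-reduction-plus-polynomial idea can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not FA4's kernel code: it reduces with round-to-nearest so the remainder lies in [-0.5, 0.5], and it uses Taylor coefficients of 2^f as stand-ins for the Sollya-tuned coefficients, which are not reproduced here.

```python
import math

LN2 = math.log(2.0)

# Assumption: Taylor coefficients of 2**f = exp(f * ln 2), used as stand-ins
# for tuned minimax coefficients.
C0, C1, C2, C3 = 1.0, LN2, LN2 ** 2 / 2.0, LN2 ** 3 / 6.0


def exp2_emulated(x: float) -> float:
    """Approximate 2**x using only multiplies, adds, and an exponent shift."""
    n = math.floor(x + 0.5)       # integer part (round to nearest)
    f = x - n                     # remainder in [-0.5, 0.5]
    # Degree-3 polynomial in Horner form: three fused multiply-adds on a GPU.
    p = ((C3 * f + C2) * f + C1) * f + C0
    return math.ldexp(p, n)       # scale by 2**n via the exponent field


def max_relative_error(lo=-20.0, hi=20.0, samples=4001):
    """Worst relative error against Python's own 2.0**x on a uniform grid."""
    worst = 0.0
    for i in range(samples):
        x = lo + (hi - lo) * i / (samples - 1)
        ref = 2.0 ** x
        worst = max(worst, abs(exp2_emulated(x) - ref) / ref)
    return worst


if __name__ == "__main__":
    print(f"max relative error: {max_relative_error():.2e}")
```

Even with untuned coefficients the relative error stays below 1e-3, which is inside BF16's 8-bit mantissa precision. Minimax coefficients would tighten this further for the same instruction count.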
Conditional rescaling. Online softmax normally rescales every iteration. FA4 only rescales when the running max changes significantly, keeping the correction off the critical path without sacrificing numerical correctness.
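A minimal scalar sketch shows why skipping small max updates is safe. The reference max only sets the scale of the running sums, and that scale cancels in the final division. The threshold `tau`, the block size, and the function names are assumptions for illustration; the kernel's actual threshold is not reproduced here.

```python
import math


def online_softmax_weighted_sum(scores, values, block=8, tau=0.0):
    """Blockwise softmax(scores) . values, rescaling only on large max jumps.

    tau=0.0 rescales whenever the max increases (classic online softmax).
    """
    m_ref = float("-inf")   # reference max currently used for scaling
    denom = 0.0             # running sum of exp(s - m_ref)
    acc = 0.0               # running sum of exp(s - m_ref) * v
    rescales = 0
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_blk = max(s_blk)
        if m_blk > m_ref + tau:              # rescale only on a large jump
            scale = math.exp(m_ref - m_blk)  # exp(-inf) == 0.0 on first block
            denom *= scale
            acc *= scale
            m_ref = m_blk
            rescales += 1
        for s, v in zip(s_blk, v_blk):
            p = math.exp(s - m_ref)          # bounded above by exp(tau)
            denom += p
            acc += p * v
    return acc / denom, rescales


def reference(scores, values):
    """Plain two-pass softmax-weighted sum for comparison."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return sum(wi * vi for wi, vi in zip(w, values)) / sum(w)


if __name__ == "__main__":
    scores = [0.05 * i for i in range(64)]
    values = [float(i % 7) for i in range(64)]
    eager, n_eager = online_softmax_weighted_sum(scores, values, tau=0.0)
    lazy, n_lazy = online_softmax_weighted_sum(scores, values, tau=1.0)
    print(eager, n_eager, lazy, n_lazy, reference(scores, values))
```

The trade-off is that exponent arguments can now exceed zero by up to `tau`, so the threshold has to stay small enough that the intermediate values remain well inside the number format's range.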
Backward Pass: Taming Shared Memory Traffic
The backward pass chains 5 MMA operations and is dominated by shared memory bandwidth. FA4 attacks this with Blackwell-specific hardware features:
Tensor Memory (TMEM). Blackwell introduces 256 KB per SM of on-chip scratchpad wired directly into tensor cores. FA4 stores accumulators in TMEM instead of registers, enabling multiple MMAs in flight while CUDA cores handle element-wise softmax work. Intermediates like S and P-transposed are stored directly in TMEM in the layout consumed by the next MMA, eliminating extra shared memory round-trips.
2-CTA MMA mode. Blackwell can execute one MMA across a CTA pair spanning both CTAs' TMEM, scaling tile dimensions to 256x256x16. This roughly halves operand-B shared memory traffic. A reduction axis conflict for dQ is resolved via distributed shared memory (DSMEM) exchange between CTAs.
The combined effect: shared memory pressure drops enough that the backward pass approaches tensor-core-limited performance.
Deterministic Mode That Actually Performs
Nondeterministic reductions are a common source of flaky training runs. FA4 provides a deterministic backward mode using semaphore locks and memory fences to enforce a fixed accumulation order for dQ global reductions.
The clever part: CTA swizzling and shortest-processing-time-first ordering minimize lock contention, achieving up to 75% of nondeterministic throughput per the paper (the authors' blog posts report 85-90% in some configurations). That's a small price for reproducibility.
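The reason accumulation order matters is that floating-point addition is not associative. The sketch below uses CPU threads as stand-ins for CTAs and a condition variable as a stand-in for the semaphore-and-fence scheme; the names and the four-element example are hypothetical and chosen only to make the rounding effect visible.

```python
import threading


def arrival_order_reduce(partials, arrival):
    """Nondeterministic flavor: add partial sums in whatever order they arrive."""
    total = 0.0
    for idx in arrival:
        total += partials[idx]
    return total


def ordered_reduce(partials, start_order):
    """Deterministic flavor: workers start in any order but add in index order."""
    total = [0.0]
    turn = [0]
    cv = threading.Condition()

    def worker(idx):
        with cv:
            cv.wait_for(lambda: turn[0] == idx)  # wait for this worker's turn
            total[0] += partials[idx]
            turn[0] += 1
            cv.notify_all()

    threads = [threading.Thread(target=worker, args=(i,)) for i in start_order]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return total[0]


# Large and small magnitudes mixed so that rounding depends on the order.
PARTIALS = [1e16, 1.0, -1e16, 1.0]

if __name__ == "__main__":
    print(arrival_order_reduce(PARTIALS, [0, 1, 2, 3]))  # 1.0
    print(arrival_order_reduce(PARTIALS, [1, 3, 0, 2]))  # 2.0
    print(ordered_reduce(PARTIALS, [1, 3, 0, 2]))        # 1.0
    print(ordered_reduce(PARTIALS, [3, 2, 1, 0]))        # 1.0
```

The cost of the ordered version is waiting: a worker that finishes early still has to hold its result until its turn comes. Scheduling tricks such as swizzling and shortest-processing-time-first ordering aim to reduce exactly that idle time.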
Implementation: Python All the Way Down
FA4 is implemented entirely in CuTe-DSL, CUTLASS's Python kernel DSL that lowers to PTX. This cuts compile times by ~20-30x compared to C++ templates while preserving access to low-level PTX instructions. The entire attention kernel, forward and backward, is Python code that generates GPU assembly.
This matters for the ecosystem. Faster iteration on kernel development means faster adoption of new hardware features across the community.
Benchmarks
On B200 with BF16:
| Comparison | Forward | Backward |
|---|---|---|
| FA4 vs cuDNN 9.13 | 1.1-1.3x faster | Consistently faster at long sequences |
| FA4 vs Triton | 2.1-2.7x faster | Significantly faster |
FA4 is Blackwell-specific - its techniques (TMEM, 2-CTA MMA, async tensor cores) depend on hardware features that don't exist on Hopper. If you're on H100, FlashAttention-3 remains the right choice.
The FA4 team collaborated with NVIDIA's cuDNN team to incorporate these techniques into cuDNN 9.13+. The latest cuDNN (9.19) now offers similar performance, so these optimizations are available to anyone using cuDNN, not just to direct FA4 users. To use FA4 directly, run pip install flash-attn on a machine with a Blackwell GPU and CUDA 12.8+.
Why This Matters Beyond GPUs
The asymmetric scaling problem isn't unique to Blackwell. Every accelerator generation will face the same pattern: some units scale, others don't. The FA4 approach (profiling the actual cycle bottlenecks and co-designing algorithms around them) is a template for the next decade of kernel engineering.
For practitioners, the takeaway is concrete: if you're training or serving models on Blackwell hardware, FlashAttention-4 (or cuDNN 9.19+) is not optional. The gap between optimized and unoptimized attention is now 2-3x, and that compounds across every layer of every forward and backward pass.
One open question: FA4 targets BF16, but FP8 attention is where the next efficiency jump lives. The same asymmetric scaling analysis applies - FP8 tensor cores are even faster relative to SFU and SMEM, which means the non-matmul bottlenecks will be even more dominant. Expect FA4's co-design approach to be the starting point for FP8 attention kernels.
The benchmarking gap between synthetic and real-world performance we see in LLM code generation has a hardware parallel: peak FLOPS on a spec sheet means nothing if your kernel can't actually use them. FA4 closes that gap for attention.
References
- FlashAttention-4 Paper (PDF) - Original paper
- FlashAttention-4 Code - CuTe-DSL implementation
- Together AI Blog Post - Together AI announcement
- Tri Dao Blog Post - Author's technical walkthrough
- Colfax Research Blog Post - Colfax Research analysis
- Your LLM Scores 88% on Code Benchmarks. In Production, It Hits 30%. - Daita blog