Flash Attention Is All You Need: From Custom CUDA Kernel to PyTorch Native
By allocguy · February 2026 · 12 min read
In 2022, a Stanford PhD student published a paper that quietly changed how every major AI lab trains transformers. Three years later, Flash Attention is baked into PyTorch, powers every serious LLM training pipeline, and has been cited thousands of times. If you're training models and not using it, you're leaving 10–20x memory savings and 2–3x speedups on the table.
Here's the full story: what it does, the real benchmarks at each version, and why PyTorch's SDPA means you probably don't need to think about it anymore.
The Problem: O(N²) Attention Is a Memory Wall
Standard self-attention computes a full N × N attention matrix, where N is the sequence length. For a sequence of 4,096 tokens with 32 attention heads in FP16, that attention matrix alone consumes roughly 1 GB per layer. Stack 32 layers and you're looking at 32 GB just for attention intermediates. That's before model weights, optimizer states, and gradients even enter the picture.
The memory grows quadratically. Double your sequence length from 2K to 4K, and attention memory quadruples. This is why most models trained before 2023 were stuck at 512 or 1,024 token context windows. The math worked, but the memory didn't.
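Here's the arithmetic as a quick sanity check. The helper below is a throwaway sketch using the same configuration as above (4,096 tokens, 32 heads, 32 layers, FP16, batch size 1):

```python
def attention_matrix_gb(seq_len, heads, layers, bytes_per_elem=2):
    """Memory for the raw N x N attention scores alone, per sample (FP16 = 2 bytes/element)."""
    per_layer = seq_len * seq_len * heads * bytes_per_elem
    return per_layer * layers / 1024**3

print(f"{attention_matrix_gb(4096, 32, 32):.0f} GB")   # 32 GB at 4K context
print(f"{attention_matrix_gb(8192, 32, 32):.0f} GB")   # 128 GB: double the length, 4x the memory
```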
Flash Attention v1 (2022): IO-Aware Tiling
Tri Dao's key insight was that the bottleneck wasn't compute. It was memory bandwidth. Standard attention reads and writes the full N × N matrix to GPU HBM (high bandwidth memory). Flash Attention never materializes it. Instead, it processes attention in tiles that fit in SRAM (the GPU's fast on-chip memory), fusing the entire attention computation into a single kernel pass.
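For intuition, here is a minimal pure-PyTorch sketch of the tiling-plus-online-softmax idea for a single head, with no masking or dropout. It reproduces the math without ever building the full N × N score matrix, but it's an illustration of the algorithm, not the fused CUDA kernel, so it won't show the speedups:

```python
import torch

def tiled_attention(q, k, v, tile=128):
    """Attention over key/value tiles with an online softmax (illustrative only).

    q, k, v: (N, d) single-head tensors. The (N, N) score matrix is never materialized.
    """
    n, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((n, 1), float("-inf"), dtype=q.dtype, device=q.device)
    row_sum = torch.zeros(n, 1, dtype=q.dtype, device=q.device)

    for start in range(0, n, tile):
        kt, vt = k[start:start + tile], v[start:start + tile]
        scores = (q @ kt.T) * scale                        # (N, tile) block only

        new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
        correction = torch.exp(row_max - new_max)          # rescale old accumulators
        p = torch.exp(scores - new_max)

        out = out * correction + p @ vt
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        row_max = new_max

    return out / row_sum

# Matches the reference softmax(QK^T / sqrt(d)) V, which materializes the full matrix:
q = k = v = torch.randn(1024, 64)
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(tiled_attention(q, k, v), ref, atol=1e-4)
```

The real kernel does this same accumulator bookkeeping tile by tile inside SRAM, fused with the matmuls, which is where the bandwidth savings come from.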
[Chart: Flash Attention v1 benchmarks. All benchmarks on 8x A100 80GB GPUs. From the original paper (Dao et al., NeurIPS 2022).]
The quality gains were just as significant. GPT-2 trained with 4K context (only possible with Flash Attention) ran 30% faster and achieved 0.7 lower perplexity than the same model at 1K context with Megatron-LM. It was also the first transformer to beat chance on the Path-X benchmark (sequence length 16K) at 61.4% accuracy, and Path-256 (sequence length 64K) at 63.1%.
Flash Attention v2 (2023): Better Parallelism
Version 2 didn't change the core algorithm. It reworked the parallelism and work partitioning to better saturate GPU hardware. The original Flash Attention only hit 25–40% of A100 theoretical peak FLOPS. Version 2 pushed that to 73%.
| Metric | FA v1 | FA v2 |
|---|---|---|
| Forward pass (% of A100 peak) | 25–40% | 73% |
| Backward pass (% of A100 peak) | ~30% | 63% |
| Attention kernel TFLOPS (A100) | ~100 | 230 |
| End-to-end training TFLOPS | ~170 | 225 (72% MFU) |
A100 80GB SXM4 theoretical peak: 312 TFLOPS (FP16/BF16). Source: Dao 2023, arXiv:2307.08691
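The MFU figure in the last row is simply achieved training throughput divided by the hardware peak quoted in the note:

```python
PEAK_TFLOPS = 312          # A100 80GB SXM4 dense FP16/BF16 peak
achieved_tflops = 225      # FA v2 end-to-end training throughput (table above)

mfu = achieved_tflops / PEAK_TFLOPS
print(f"MFU: {mfu:.0%}")   # -> MFU: 72%
```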
Lambda Labs ran the definitive third-party benchmark comparing H100 and A100 with Flash Attention v2 on GPT3-2.7B (OpenWebText dataset). The H100 hit 22,282 tokens/sec, roughly 2.1x the A100's throughput. Both GPUs scaled to 8x with near-linear efficiency: 96% on A100, 98% on H100.
The memory savings also had a practical side effect: batch size could jump from 1 to 4 on the same hardware, because all that freed attention memory could be used for larger batches instead.
Flash Attention v3 (2024): Hopper-Native
Version 3 was purpose-built for NVIDIA's Hopper architecture (H100 and beyond). It exploits three features unique to Hopper: asynchronous Tensor Cores via warp specialization, interleaved block-wise matmul and softmax pipelining, and native FP8 hardware support with block quantization.
The result: Flash Attention v2, which was designed for Ampere, topped out around 35% utilization on an H100. Version 3 reaches 75%.
[Chart: Flash Attention v3 on H100. NeurIPS 2024 Spotlight. Source: Dao et al., arXiv:2407.08608.]
Flash Attention v4 (2025): Blackwell-Native
Announced by Tri Dao at Hot Chips 2025, version 4 is purpose-built for NVIDIA's Blackwell architecture. Where v3 adapted to Hopper, v4 squeezes out the last bits of performance on Blackwell's fifth-generation Tensor Cores and new memory hierarchy.
Two algorithmic changes make the difference. First, a new online softmax algorithm that skips rescaling when the row maximum hasn't changed enough to affect numerical stability. In practice, this eliminates roughly 90% of output rescaling operations. Second, software emulation of the exponential function (replacing the hardware Special Function Unit) to avoid SFU throughput bottlenecks on smaller attention heads.
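Based on that description (and the reverse-engineering write-up cited in the sources), the rescaling trick can be sketched roughly like this in PyTorch. The threshold `tol`, the function name, and the masked `torch.where` are illustrative assumptions, not the actual kernel logic:

```python
import torch

def accumulate_tile(out, row_sum, row_max, scores, vt, tol=2.0):
    """One online-softmax tile update that skips rescaling where it can.

    If a row's running max did not grow by more than `tol` (a hypothetical
    threshold), the old reference max is kept and the exp() correction of the
    existing accumulators is skipped. The final out / row_sum is unchanged;
    only the overflow headroom shrinks slightly.
    Shapes: out (N, d), row_sum and row_max (N, 1), scores (N, tile), vt (tile, d).
    """
    tile_max = scores.amax(dim=-1, keepdim=True)
    grew = (tile_max - row_max) > tol                  # rows whose max really moved

    new_max = torch.where(grew, tile_max, row_max)     # keep the old max where possible
    correction = torch.where(grew, torch.exp(row_max - new_max),
                             torch.ones_like(row_max))

    # The real kernel skips the rescaling work outright; torch.where here only
    # illustrates the decision -- it still evaluates both branches.
    p = torch.exp(scores - new_max)
    out = out * correction + p @ vt
    row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
    return out, row_sum, new_max
```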
[Chart: Flash Attention v4 on Blackwell. Hot Chips 2025. Source: Tri Dao / Together AI, SemiAnalysis.]
The kernel targets the tcgen05.mma instructions for Blackwell's 5th-gen Tensor Cores. The forward pass source code is already available; the backward pass is expected to follow shortly.
PyTorch SDPA: You Probably Already Have It
Starting with PyTorch 2.0 (March 2023), Flash Attention shipped as a built-in backend via torch.nn.functional.scaled_dot_product_attention. No separate install. No CUDA compilation. Just call the function and PyTorch automatically picks the fastest backend for your hardware and input shape.
| Backend | Requirements | Best For |
|---|---|---|
| FlashAttention | FP16/BF16, NVIDIA GPU | Training (Ampere/Hopper) |
| Memory-Efficient | Any dtype, any GPU | Long sequences, memory-constrained |
| CuDNN | PyTorch 2.5+, Hopper GPU | H100 (up to 75% faster than FA v2) |
| Math | Always available | Fallback / debugging |
Backend selection is automatic. Override with torch.nn.attention.sdpa_kernel() context manager.
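A minimal usage sketch (the shapes and the causal flag are arbitrary; a CUDA GPU with FP16/BF16 inputs is needed for the Flash backend to be eligible):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# (batch, heads, seq_len, head_dim) in FP16 on CUDA: eligible for the Flash backend
q = torch.randn(2, 32, 4096, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Default: PyTorch picks the fastest eligible backend for these inputs.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Optional: pin a backend explicitly (FLASH_ATTENTION, EFFICIENT_ATTENTION,
# CUDNN_ATTENTION, or MATH). Raises if the pinned backend can't handle the inputs.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```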
The CuDNN backend (added in PyTorch 2.5, October 2024) is particularly notable: it implements Flash Attention v3 optimizations for Hopper GPUs, delivering up to 75% speedup over the FlashAttention-2 kernel on H100s.
Adoption: Basically Everyone
Flash Attention is no longer optional infrastructure. It's the default.
TRAINING
- HuggingFace Transformers (Llama, Mistral, Falcon, Gemma, Phi, Qwen2, 40+ architectures)
- Megatron-LM (NVIDIA)
- DeepSpeed (Microsoft)
- PyTorch Lightning / Fabric
INFERENCE
- vLLM (FlashAttention-3 in V1 engine)
- SGLang (via FlashInfer)
- TensorRT-LLM (NVIDIA)
- Stable Diffusion
In HuggingFace Transformers, enabling it is one argument:
```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3-8B",
    attn_implementation="flash_attention_2",
)
```

What This Means for Your GPU Bill
Flash Attention saves memory, which saves money in two ways. First, you can fit models on smaller (cheaper) GPUs. A training run that OOMs on a 40GB A100 at sequence length 4K might fit comfortably with Flash Attention enabled, avoiding the jump to an 80GB card. Second, the freed memory can be reinvested into larger batch sizes, which improves throughput without needing more hardware.
Lambda Labs demonstrated this concretely: enabling Flash Attention v2 on GPT3-2.7B allowed batch size to increase from 1 to 4 on the same H100, delivering a 3x+ throughput improvement without touching the hardware configuration.
For long-context training (8K+ tokens), the impact is even more dramatic. Standard attention OOMs at sequence lengths where Flash Attention runs comfortably. That's the difference between "we need to shard across 8 GPUs" and "this fits on one."
The Full Timeline
| Version | Year | Target GPU | Peak Utilization | Key Innovation |
|---|---|---|---|---|
| FA v1 | 2022 | A100 | 25–40% | IO-aware tiling, O(N) memory |
| FA v2 | 2023 | A100 | 50–73% | Better parallelism + work partitioning |
| SDPA | 2023 | Any | Auto-selects | PyTorch 2.0 native, zero config |
| FA v3 | 2024 | H100 | 75% | Async Tensor Cores, FP8, warp specialization |
| CuDNN SDPA | 2024 | Hopper+ | 75%+ | PyTorch 2.5 backend, up to 75% over FA v2 |
| FA v4 | 2025 | Blackwell | ~22% faster than cuDNN | Smart softmax rescaling, SFU bypass |
Practical Takeaways
1. If you're on PyTorch 2.0+, you already have it
SDPA auto-selects Flash Attention when your inputs are FP16/BF16 on a supported GPU. No install, no flag, no config. Just make sure your model uses F.scaled_dot_product_attention instead of manual QKV matmul + softmax (a before/after sketch follows this list).
2. Upgrade to PyTorch 2.5+ for H100s
The CuDNN backend brings Flash Attention v3 optimizations to SDPA. If you're on Hopper hardware, this is a free speedup just by upgrading PyTorch.
3. Check if your VRAM bottleneck is attention
Flash Attention saves activation memory (the attention intermediates), not model weights or optimizer states. If you're OOMing because your model weights are too large, Flash Attention won't help. Run alloc ghost to see where your VRAM is actually going.
4. Longer contexts are now viable
If you've been limited to short sequences because of memory, Flash Attention changes the math. Training at 8K, 16K, or even 64K tokens is feasible on hardware where standard attention would OOM at 4K.
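As promised in takeaway 1, here is a minimal before/after sketch of swapping a hand-rolled attention for SDPA. The function names and shapes are hypothetical; the point is that the call is a drop-in replacement for the manual pattern:

```python
import torch
import torch.nn.functional as F

def manual_attention(q, k, v):
    # The pattern SDPA replaces: materializes the full N x N score matrix.
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

def sdpa_attention(q, k, v):
    # Same math, but PyTorch can dispatch to the Flash, memory-efficient,
    # or cuDNN kernels when the inputs allow it.
    return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, 8, 1024, 64)   # (batch, heads, seq_len, head_dim)
assert torch.allclose(manual_attention(q, k, v), sdpa_attention(q, k, v), atol=1e-4)
```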
See Where Your VRAM Actually Goes
Flash Attention handles the attention kernel. But attention is only one piece of the VRAM puzzle. Weights, optimizer states, gradients, and activations all compete for the same memory. If you want to know whether your model fits on a given GPU before you spend money on it:
pip install alloc && alloc ghost your_model.py

60 seconds, no GPU required, full VRAM breakdown.
Sources
- Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness," NeurIPS 2022 (arXiv:2205.14135)
- Dao, "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning," 2023 (arXiv:2307.08691)
- Dao et al., "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision," NeurIPS 2024 Spotlight (arXiv:2407.08608)
- Tri Dao / Together AI, "Flash Attention v4," Hot Chips 2025. Analysis: Modal, "We Reverse-Engineered Flash Attention 4"
- Lambda Labs, "How FlashAttention-2 Accelerates LLMs on NVIDIA H100 and A100 GPUs," 2023
- PyTorch 2.0 Release Blog, pytorch.org, March 2023
- PyTorch 2.5 Release Notes (CuDNN SDPA backend), October 2024