Flash Attention Is All You Need: From Custom CUDA Kernel to PyTorch Native

By allocguy · February 2026 · 12 min read

In 2022, a Stanford PhD student published a paper that quietly changed how every major AI lab trains transformers. Three years later, Flash Attention is baked into PyTorch, powers every serious LLM training pipeline, and has been cited thousands of times. If you're training models and not using it, you're leaving 10–20x memory savings and 2–3x speedups on the table.

Here's the full story: what it does, the real benchmarks at each version, and why PyTorch's SDPA means you probably don't need to think about it anymore.

The Problem: O(N²) Attention Is a Memory Wall

Standard self-attention computes a full N × N attention matrix, where N is the sequence length. For a sequence of 4,096 tokens with 32 attention heads in FP16, that attention matrix alone consumes roughly 1 GB per layer. Stack 32 layers and you're looking at 32 GB just for attention intermediates. That's before model weights, optimizer states, and gradients even enter the picture.

The memory grows quadratically. Double your sequence length from 2K to 4K, and attention memory quadruples. This is why most models trained before 2023 were stuck at 512 or 1,024 token context windows. The math worked, but the memory didn't.
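The arithmetic behind those figures is worth seeing once. A quick back-of-the-envelope check, counting only the raw N × N score matrices (other intermediates make the real number worse):

# Attention score matrix: seq_len x seq_len per head, 2 bytes per element in FP16
seq_len, n_heads, n_layers, bytes_per_el = 4096, 32, 32, 2

per_layer = seq_len ** 2 * n_heads * bytes_per_el     # 1,073,741,824 bytes = 1 GiB
total = per_layer * n_layers                          # 32 GiB across 32 layers
print(f"{per_layer / 2**30:.0f} GiB per layer, {total / 2**30:.0f} GiB total")

# Doubling seq_len to 8,192 quadruples both numbers; halving it to 2,048 quarters them.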

Flash Attention v1 (2022): IO-Aware Tiling

Tri Dao's key insight was that the bottleneck wasn't compute. It was memory bandwidth. Standard attention reads and writes the full N × N matrix to GPU HBM (high bandwidth memory). Flash Attention never materializes it. Instead, it processes attention in tiles that fit in SRAM (the GPU's fast on-chip memory), fusing the entire attention computation into a single kernel pass.
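Here is a minimal PyTorch sketch of the idea, not the fused CUDA kernel itself: walk over key/value blocks while maintaining a running max, running normalizer, and running output per query row, so the full N × N matrix never exists. The block size, single-head layout, and function name are illustrative choices, and causal masking is omitted for brevity.

import torch

def tiled_attention(q, k, v, block_size=128):
    # q, k, v: (seq_len, head_dim) for a single head
    n, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((n, 1), float("-inf"), device=q.device, dtype=q.dtype)
    row_sum = torch.zeros(n, 1, device=q.device, dtype=q.dtype)
    for start in range(0, n, block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        scores = (q @ k_blk.T) * scale                      # (n, block) tile, never (n, n)
        blk_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(row_max, blk_max)
        correction = torch.exp(row_max - new_max)           # rescale what was accumulated so far
        p = torch.exp(scores - new_max)
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ v_blk
        row_max = new_max
    return out / row_sum

# Matches torch.softmax(q @ k.T * scale, dim=-1) @ v up to floating-point error.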

FLASH ATTENTION V1 BENCHMARKS

All benchmarks on 8x A100 80GB GPUs. From the original paper (Dao et al., NeurIPS 2022).

Memory complexity: O(N) instead of O(N²). 10x savings at seq_len=2K, 20x savings at seq_len=4K.
Attention kernel speedup: up to 7.6x faster (forward pass, GPT-2).
End-to-end GPT-2 training: 3x faster than HuggingFace, 1.7x faster than Megatron-LM (seq_len=1K).
BERT-large: 15% faster than the MLPerf 1.1 training speed record (seq_len=512).
Long Range Arena: 2.4x faster than baselines (seq_len 1K–4K).

The quality gains were just as significant. GPT-2 trained with 4K context (only possible with Flash Attention) ran 30% faster and achieved 0.7 lower perplexity than the same model at 1K context with Megatron-LM. Flash Attention also enabled the first transformer to beat chance on the Path-X benchmark (sequence length 16K, 61.4% accuracy), and its block-sparse variant did the same on Path-256 (sequence length 64K, 63.1%).


Flash Attention v2 (2023): Better Parallelism

Version 2 didn't change the core algorithm. It reworked the parallelism and work partitioning to better saturate GPU hardware. The original Flash Attention only hit 25–40% of A100 theoretical peak FLOPS. Version 2 pushed that to 73%.

Metric                           FA v1     FA v2
Forward pass (% of A100 peak)    25–40%    73%
Backward pass (% of A100 peak)   ~30%      63%
Attention kernel TFLOPS (A100)   ~100      230
End-to-end training TFLOPS       ~170      225 (72% MFU)

A100 80GB SXM4 theoretical peak: 312 TFLOPS (FP16/BF16). Source: Dao 2023, arXiv:2307.08691
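MFU (model FLOPS utilization) in that last row is simply achieved training throughput divided by the hardware's theoretical peak:

# MFU = achieved TFLOPS / theoretical peak TFLOPS
a100_peak_tflops = 312          # A100 80GB SXM4, dense FP16/BF16
fa2_training_tflops = 225       # end-to-end training throughput from the table above
print(f"MFU: {fa2_training_tflops / a100_peak_tflops:.0%}")   # -> 72%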

Lambda Labs ran the definitive third-party benchmark comparing H100 and A100 with Flash Attention v2 on GPT3-2.7B (OpenWebText dataset). The H100 hit 22,282 tokens/sec, roughly 2.1x the A100's throughput. Both GPUs scaled to 8x with near-linear efficiency: 96% on A100, 98% on H100.

The memory savings also had a practical side effect: with the attention intermediates no longer hogging VRAM, batch size could jump from 1 to 4 on the same hardware.

Flash Attention v3 (2024): Hopper-Native

Version 3 was purpose-built for NVIDIA's Hopper architecture (H100 and beyond). It exploits three features unique to Hopper: asynchronous Tensor Cores via warp specialization, interleaved block-wise matmul and softmax pipelining, and native FP8 hardware support with block quantization.

The result: Flash Attention v2 only hit 35% utilization on H100 (it was designed for Ampere). Version 3 reaches 75%.

FLASH ATTENTION V3 ON H100

NeurIPS 2024 Spotlight. Source: Dao et al., arXiv:2407.08608

FP16 forward pass: 740 TFLOPS (75% of the H100's 989 TFLOPS peak).
FP8: close to 1.2 PFLOPS (H100 FP8 peak: 1,978 TFLOPS).
1.5–2x speedup over FA v2 in the forward pass. 1.5–1.75x in the backward pass.
FP8 with block quantization is 2.6x more accurate than standard FP8 per-tensor quantization.
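Block quantization itself is a simple idea: instead of one scale factor for the whole tensor, each small block of values gets its own, so a single outlier only distorts the scale of its own block. A toy PyTorch illustration of the mechanics follows; the block size, E4M3 format, and helper name are assumptions for the sketch, not FA v3's actual scheme.

import torch

def quantize_dequantize_per_block(x, block=128, fp8_max=448.0):
    # One scale per block of `block` values instead of one per tensor
    # (assumes x.numel() is a multiple of `block`); 448 is the E4M3 max value.
    blocks = x.reshape(-1, block)
    scales = (blocks.abs().amax(dim=1, keepdim=True) / fp8_max).clamp(min=1e-12)
    q = (blocks / scales).to(torch.float8_e4m3fn)              # quantize to FP8
    return (q.to(torch.float32) * scales).reshape(x.shape)     # dequantize to inspect error

x = torch.randn(4096)
x[::512] *= 50.0                                               # inject a few outliers
print((x - quantize_dequantize_per_block(x)).abs().max())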

Flash Attention v4 (2025): Blackwell-Native

Announced by Tri Dao at Hot Chips 2025, version 4 is purpose-built for NVIDIA's Blackwell architecture. Where v3 adapted to Hopper, v4 squeezes out the last bits of performance on Blackwell's fifth-generation Tensor Cores and new memory hierarchy.

Two algorithmic changes make the difference. First, a new online softmax algorithm that skips rescaling when the row maximum hasn't changed enough to affect numerical stability. In practice, this eliminates roughly 90% of output rescaling operations. Second, software emulation of the exponential function (replacing the hardware Special Function Unit) to avoid SFU throughput bottlenecks on smaller attention heads.
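Public details on v4 come mostly from reverse-engineering, so the following is only a sketch of the first idea, grafted onto the online-softmax recurrence from the v1 example above: keep the running shift (the "row max") fixed unless a new block's maximum exceeds it by more than a safety margin, and skip the exponential correction whenever it doesn't move. The margin value and helper name are illustrative assumptions, not FA v4's kernel code.

import torch

def update_softmax_shift(running_shift, block_max, margin=8.0):
    # The shift only exists for numerical stability, so it doesn't have to track the
    # exact running max: as long as scores - shift stays below ~margin, exp() is fine.
    needs_rescale = block_max > running_shift + margin
    new_shift = torch.where(needs_rescale, block_max, running_shift)
    # exp(old - new) == 1 wherever the shift stayed put, so those rows can skip
    # the rescaling multiply on the output and normalizer accumulators entirely.
    correction = torch.exp(running_shift - new_shift)
    return new_shift, correction, needs_rescale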

FLASH ATTENTION V4 ON BLACKWELL

Hot Chips 2025. Source: Tri Dao / Together AI, SemiAnalysis

22% faster than NVIDIA's cuDNN attention kernels (the previous state-of-the-art on Blackwell).
Warp-specialized pipeline: 1 load warp, 1 MMA warp, 8 softmax warps, 4 correction warps, 1–2 epilogue warps.
Smart scaling reduces correction operations by 10x vs. standard online softmax.
Uses tcgen05.mma instructions for Blackwell's 5th-gen Tensor Cores.

The forward pass source code is already available. The backward pass is expected to follow shortly.

PyTorch SDPA: You Probably Already Have It

Starting with PyTorch 2.0 (March 2023), Flash Attention shipped as a built-in backend via torch.nn.functional.scaled_dot_product_attention. No separate install. No CUDA compilation. Just call the function and PyTorch automatically picks the fastest backend for your hardware and input shape.
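A minimal call looks like this; the shapes are arbitrary, and the CUDA device plus FP16 dtype are what make the Flash backend eligible:

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim)
q = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# PyTorch dispatches to the fastest eligible backend; is_causal applies a causal mask
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)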

Backend            Requirements               Best For
FlashAttention     FP16/BF16, NVIDIA GPU      Training (Ampere/Hopper)
Memory-Efficient   Any dtype, any GPU         Long sequences, memory-constrained
CuDNN              PyTorch 2.5+, Hopper GPU   H100 (up to 75% faster than FA v2)
Math               Always available           Fallback / debugging

Backend selection is automatic. Override with torch.nn.attention.sdpa_kernel() context manager.

The CuDNN backend (added in PyTorch 2.5, October 2024) is particularly notable: it implements Flash Attention v3 optimizations for Hopper GPUs, delivering up to 75% speedup over the FlashAttention-2 kernel on H100s.
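To pin the choice down rather than trust the dispatcher, wrap the call in the sdpa_kernel context manager. A short example, reusing the q, k, v tensors from above (SDPBackend.CUDNN_ATTENTION requires PyTorch 2.5+):

from torch.nn.attention import SDPBackend, sdpa_kernel

# Only the listed backends are considered inside the context
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)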

Adoption: Basically Everyone

Flash Attention is no longer optional infrastructure. It's the default.

TRAINING

  • HuggingFace Transformers (Llama, Mistral, Falcon, Gemma, Phi, Qwen2, 40+ architectures)
  • Megatron-LM (NVIDIA)
  • DeepSpeed (Microsoft)
  • PyTorch Lightning / Fabric

INFERENCE

  • vLLM (FlashAttention-3 in V1 engine)
  • SGLang (via FlashInfer)
  • TensorRT-LLM (NVIDIA)
  • Stable Diffusion

In HuggingFace Transformers, enabling it is one argument (the flash-attn package must also be installed):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3-8B", attn_implementation="flash_attention_2")

What This Means for Your GPU Bill

Flash Attention saves memory, which saves money in two ways. First, you can fit models on smaller (cheaper) GPUs. A training run that OOMs on a 40GB A100 at sequence length 4K might fit comfortably with Flash Attention enabled, avoiding the jump to an 80GB card. Second, the freed memory can be reinvested into larger batch sizes, which improves throughput without needing more hardware.

Lambda Labs demonstrated this concretely: enabling Flash Attention v2 on GPT3-2.7B allowed batch size to increase from 1 to 4 on the same H100, delivering a 3x+ throughput improvement without touching the hardware configuration.

For long-context training (8K+ tokens), the impact is even more dramatic. Standard attention OOMs at sequence lengths where Flash Attention runs comfortably. That's the difference between "we need to shard across 8 GPUs" and "this fits on one."

The Full Timeline

Version      Year   Target GPU   Peak Utilization   Key Innovation
FA v1        2022   A100         25–40%             IO-aware tiling, O(N) memory
FA v2        2023   A100         50–73%             Better parallelism + work partitioning
SDPA         2023   Any          Auto-selects       PyTorch 2.0 native, zero config
FA v3        2024   H100         75%                Async Tensor Cores, FP8, warp specialization
CuDNN SDPA   2024   Hopper+      75%+               PyTorch 2.5 backend, up to 75% over FA v2
FA v4        2025   Blackwell    22% over cuDNN     Smart softmax rescaling, SFU bypass

Practical Takeaways

1. If you're on PyTorch 2.0+, you already have it

SDPA auto-selects Flash Attention when your inputs are FP16/BF16 on a supported GPU. No install, no flag, no config. Just make sure your model uses F.scaled_dot_product_attention instead of manual QKV matmul + softmax.
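The swap is mechanical. Roughly, assuming per-head q, k, v tensors and eliding mask and dropout handling:

# Before: materializes the full (seq_len, seq_len) score matrix in HBM
scores = (q @ k.transpose(-2, -1)) / q.shape[-1] ** 0.5
attn = scores.softmax(dim=-1) @ v

# After: same result, fused kernel, no N x N intermediate
attn = F.scaled_dot_product_attention(q, k, v)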

2. Upgrade to PyTorch 2.5+ for H100s

The CuDNN backend brings Flash Attention v3 optimizations to SDPA. If you're on Hopper hardware, this is a free speedup just by upgrading PyTorch.

3. Check if your VRAM bottleneck is attention

Flash Attention saves activation memory (the attention intermediates), not model weights or optimizer states. If you're OOMing because your model weights are too large, Flash Attention won't help. Run alloc ghost to see where your VRAM is actually going.

4. Longer contexts are now viable

If you've been limited to short sequences because of memory, Flash Attention changes the math. Training at 8K, 16K, or even 64K tokens is feasible on hardware where standard attention would OOM at 4K.

See Where Your VRAM Actually Goes

Flash Attention handles the attention kernel. But attention is only one piece of the VRAM puzzle. Weights, optimizer states, gradients, and activations all compete for the same memory. If you want to know whether your model fits on a given GPU before you spend money on it:

pip install alloc && alloc ghost your_model.py

60 seconds, no GPU required, full VRAM breakdown.

Sources

  • Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness," NeurIPS 2022 (arXiv:2205.14135)
  • Dao, "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning," 2023 (arXiv:2307.08691)
  • Dao et al., "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision," NeurIPS 2024 Spotlight (arXiv:2407.08608)
  • Tri Dao / Together AI, "Flash Attention v4," Hot Chips 2025. Analysis: Modal, "We Reverse-Engineered Flash Attention 4"
  • Lambda Labs, "How FlashAttention-2 Accelerates LLMs on NVIDIA H100 and A100 GPUs," 2023
  • PyTorch 2.0 Release Blog, pytorch.org, March 2023
  • PyTorch 2.5 Release Notes (CuDNN SDPA backend), October 2024
