跳到主要內容

優化 Flash Attention

通過 Flash Attention 優化 Transformer 注意力機制,實現 2-4 倍的速度提升和 10-20 倍的內存減少。當訓練或運行具有長序列(>512 個 token)的 Transformer、遇到注意力機制導致的 GPU 內存問題,或需要更快的推理速度時使用。支持 PyTorch 原生 SDPA、flash-attn 庫、H100 FP8 以及滑動窗口注意力。

技能元數據

來源可選 — 使用 hermes skills install official/mlops/flash-attention 安裝
路徑optional-skills/mlops/flash-attention
版本1.0.0
作者Orchestra Research
許可證MIT
依賴項flash-attn, torch, transformers
標籤Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers

參考:完整 SKILL.md

信息

以下是 Hermes 在觸發此技能時加載的完整技能定義。這是技能激活時代理看到的指令。

Flash Attention - 快速且內存高效的注意力機制

快速開始

Flash Attention 通過 IO 感知的分塊和重計算,為 Transformer 注意力機制提供 2-4 倍的速度提升和 10-20 倍的內存減少。

PyTorch 原生(最簡單,PyTorch 2.2+)

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)

# Automatically uses Flash Attention if available
out = F.scaled_dot_product_attention(q, k, v)

flash-attn 庫(更多功能)

pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func

# q, k, v: [batch, seqlen, nheads, headdim]
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)

常見工作流

工作流 1:在現有 PyTorch 模型中啟用

複製此檢查清單:

Flash Attention Integration:
- [ ] Step 1: Check PyTorch version (≥2.2)
- [ ] Step 2: Enable Flash Attention backend
- [ ] Step 3: Verify speedup with profiling
- [ ] Step 4: Test accuracy matches baseline

步驟 1:檢查 PyTorch 版本

python -c "import torch; print(torch.__version__)"
# Should be ≥2.2.0

如果 <2.2,請升級:

pip install --upgrade torch

步驟 2:啟用 Flash Attention 後端

替換標準注意力:

# Before (standard attention)
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
out = attn_weights @ v

# After (Flash Attention)
import torch.nn.functional as F
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

強制使用 Flash Attention 後端:

with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=False
):
out = F.scaled_dot_product_attention(q, k, v)

步驟 3:通過性能分析驗證加速效果

import torch.utils.benchmark as benchmark

def test_attention(use_flash):
q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

if use_flash:
with torch.backends.cuda.sdp_kernel(enable_flash=True):
return F.scaled_dot_product_attention(q, k, v)
else:
attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
return attn @ v

# Benchmark
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())

print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")

預期:對於超過 512 個 token 的序列,速度提升 2-4 倍。

步驟 4:測試準確率是否與基線一致

# Compare outputs
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Flash Attention
out_flash = F.scaled_dot_product_attention(q, k, v)

# Standard attention
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
out_standard = attn_weights @ v

# Check difference
diff = (out_flash - out_standard).abs().max()
print(f"Max difference: {diff:.6f}")
# Should be <1e-3 for float16

工作流 2:使用 flash-attn 庫獲取高級功能

適用於多查詢注意力、滑動窗口或 H100 FP8。

複製此檢查清單:

flash-attn Library Setup:
- [ ] Step 1: Install flash-attn library
- [ ] Step 2: Modify attention code
- [ ] Step 3: Enable advanced features
- [ ] Step 4: Benchmark performance

步驟 1:安裝 flash-attn 庫

# NVIDIA GPUs (CUDA 12.0+)
pip install flash-attn --no-build-isolation

# Verify installation
python -c "from flash_attn import flash_attn_func; print('Success')"

步驟 2:修改注意力代碼

from flash_attn import flash_attn_func

# Input: [batch_size, seq_len, num_heads, head_dim]
# Transpose from [batch, heads, seq, dim] if needed
q = q.transpose(1, 2) # [batch, seq, heads, dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)

out = flash_attn_func(
q, k, v,
dropout_p=0.1,
causal=True, # For autoregressive models
window_size=(-1, -1), # No sliding window
softmax_scale=None # Auto-scale
)

out = out.transpose(1, 2) # Back to [batch, heads, seq, dim]

步驟 3:啟用高級功能

多查詢注意力(跨頭共享 K/V):

from flash_attn import flash_attn_func

# q: [batch, seq, num_q_heads, dim]
# k, v: [batch, seq, num_kv_heads, dim] # Fewer KV heads
out = flash_attn_func(q, k, v) # Automatically handles MQA

滑動窗口注意力(局部注意力):

# Only attend to window of 256 tokens before/after
out = flash_attn_func(
q, k, v,
window_size=(256, 256), # (left, right) window
causal=True
)

步驟 4:基準測試性能

import torch
from flash_attn import flash_attn_func
import time

q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Warmup
for _ in range(10):
_ = flash_attn_func(q, k, v)

# Benchmark
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
out = flash_attn_func(q, k, v)
torch.cuda.synchronize()
end = time.time()

print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")

工作流 3:H100 FP8 優化(FlashAttention-3)

適用於在 H100 GPU 上獲得最大性能。

FP8 Setup:
- [ ] Step 1: Verify H100 GPU available
- [ ] Step 2: Install flash-attn with FP8 support
- [ ] Step 3: Convert inputs to FP8
- [ ] Step 4: Run with FP8 attention

步驟 1:驗證 H100 GPU

nvidia-smi --query-gpu=name --format=csv
# Should show "H100" or "H800"

步驟 2:安裝支持 FP8 的 flash-attn

pip install flash-attn --no-build-isolation
# FP8 support included for H100

步驟 3:將輸入轉換為 FP8

import torch

q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)

# Convert to float8_e4m3 (FP8)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)

步驟 4:運行 FP8 注意力

from flash_attn import flash_attn_func

# FlashAttention-3 automatically uses FP8 kernels on H100
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
# Result: ~1.2 PFLOPS, 1.5-2x faster than FP16

何時使用與替代方案對比

在以下情況使用 Flash Attention:

  • 訓練序列長度 >512 個 token 的 Transformer
  • 運行具有長上下文(>2K 個 token)的推理
  • GPU 內存受限(標準注意力導致 OOM)
  • 需要在不損失準確率的情況下獲得 2-4 倍的速度提升
  • 使用 PyTorch 2.2+ 或可以安裝 flash-attn

在以下情況使用替代方案:

  • 標準注意力:序列長度 <256 個 token(開銷不值得)
  • xFormers:需要更多注意力變體(不僅僅是速度)
  • 內存高效注意力:CPU 推理(Flash Attention 需要 GPU)

常見問題

問題:ImportError: cannot import flash_attn

使用 no-build-isolation 標誌安裝:

pip install flash-attn --no-build-isolation

或者先安裝 CUDA toolkit:

conda install cuda -c nvidia
pip install flash-attn --no-build-isolation

問題:比預期慢(無加速效果)

Flash Attention 的優勢隨序列長度增加而增加:

  • <512 個 token:最小加速(10-20%)
  • 512-2K 個 token:2-3 倍加速
  • 2K 個 token:3-4 倍加速

檢查序列長度是否足夠。

問題:RuntimeError: CUDA error

驗證 GPU 是否支持 Flash Attention:

import torch
print(torch.cuda.get_device_capability())
# Should be ≥(7, 5) for Turing+

Flash Attention 要求:

  • Ampere (A100, A10):✅ 完全支持
  • Turing (T4):✅ 支持
  • Volta (V100):❌ 不支持

問題:準確率下降

檢查 dtype 是否為 float16 或 bfloat16(而非 float32):

q = q.to(torch.float16)  # Or torch.bfloat16

Flash Attention 使用 float16/bfloat16 以提高速度。不支持 Float32。

高級主題

與 HuggingFace Transformers 集成:參見 references/transformers-integration.md 以瞭解如何在 BERT、GPT、Llama 模型中啟用 Flash Attention。

性能基準測試:參見 references/benchmarks.md 以獲取跨 GPU 和序列長度的詳細速度和內存比較。

算法細節:請參閱 references/algorithm.md 瞭解分塊策略、重計算以及 I/O 複雜度分析。

高級特性:請參閱 references/advanced-features.md 瞭解旋轉位置編碼(rotary embeddings)、ALiBi、分頁 KV 緩存和自定義注意力掩碼。

硬件要求

  • GPU:NVIDIA Ampere+(A100、A10、A30)或 AMD MI200+
  • VRAM:與標準注意力機制相同(Flash Attention 不會增加內存佔用)
  • CUDA:12.0+(最低 11.8)
  • PyTorch:2.2+ 以獲得原生支持

不支持:V100(Volta)、CPU 推理

資源