一、背景标准Attention的瓶颈标准自注意力计算Attention(Q,K,V)softmax(QK⊤d)V \text{Attention}(Q,K,V)\text{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)VAttention(Q,K,V)softmax(d​QK⊤​)V输入Q,K,V∈RN×dQ,K,V\in\mathbb{R}^{N\times d}Q,K,V∈RN×dNNN序列长度ddd维度中间QK⊤∈RN×NQK^\top\in\mathbb{R}^{N\times N}QK⊤∈RN×N显存O(N2)O(N^2)O(N2)、HBM反复读写、带宽瓶颈远大于计算瓶颈。问题长序列如N32768N32768N32768时N2N^2N2矩阵达GB级训练显存爆炸、速度极慢。二、FlashAttention核心原理IO感知分块在线Softmax重计算FlashAttention2022NeurIPS是精确、无近似误差的IO感知实现核心把计算从HBM搬到片上SRAM减少IO。1GPU内存层次关键背景HBM高带宽显存容量大40–80GB、带宽低~2TB/s、延迟高标准Attention主要瓶颈。SRAM片上共享内存容量小20–40MB、带宽极高~19TB/s、延迟低FlashAttention主战场。2分块Tiling把大矩阵切成小块把Q,K,VQ,K,VQ,K,V沿序列维分块tileQ→Q1,Q2,…,QTqQ\to Q_1,Q_2,\dots,Q_{T_q}Q→Q1​,Q2​,…,QTq​​每块大小Bq×dB_q\times dBq​×dK→K1,K2,…,KTkK\to K_1,K_2,\dots,K_{T_k}K→K1​,K2​,…,KTk​​每块大小Bk×dB_k\times dBk​×dV→V1,V2,…,VTkV\to V_1,V_2,\dots,V_{T_k}V→V1​,V2​,…,VTk​​每块大小Bk×dB_k\times dBk​×d每次只加载一个QQQ块一个K/VK/VK/V块到SRAM计算局部QiKj⊤Q_iK_j^\topQi​Kj⊤​、局部Softmax、局部OijPijVjO_{ij}P_{ij}V_jOij​Pij​Vj​全程不生成完整N×NN\times NN×N矩阵。3在线SoftmaxOnline Softmax分块也能正确归一化Softmax需要全局信息行最大值、指数和分块后无法一次拿到全行核心技巧维护运行时统计量。对每一行维护mmm当前块的最大值用于数值稳定防止exp溢出ℓ\ellℓ当前块的指数和归一化分母ooo当前块的输出累加块间合并规则以处理第jjj个K/VK/VK/V块为例mnewmax⁡(mold,mj)ℓnewℓoldemold−mnewℓjemj−mnewonewooldemold−mnewojemj−mnew \begin{aligned} m_{\text{new}}\max(m_{\text{old}},m_j)\\ \ell_{\text{new}}\ell_{\text{old}}e^{m_{\text{old}}-m_{\text{new}}}\ell_j e^{m_j-m_{\text{new}}}\\ o_{\text{new}}o_{\text{old}}e^{m_{\text{old}}-m_{\text{new}}}o_j e^{m_j-m_{\text{new}}} \end{aligned}mnew​ℓnew​onew​​max(mold​,mj​)ℓold​emold​−mnew​ℓj​emj​−mnew​oold​emold​−mnew​oj​emj​−mnew​​全程无完整Softmax矩阵仅在SRAM中更新统计量数值稳定、结果与标准Attention完全一致。4算子融合Kernel Fusion一次CUDA Kernel完成所有步骤标准QK⊤QK^\topQK⊤写HBM→ softmax读HBM→PVPVPV写HBM多次HBM读写。FlashAttention分块加载→矩阵乘→在线Softmax→加权求和→结果写回单Kernel、零中间HBM读写。5反向传播重计算Recomputation换显存标准前向存Psoftmax(QK⊤)P\text{softmax}(QK^\top)Psoftmax(QK⊤)反向读PPP显存O(N2)O(N^2)O(N2)。FlashAttention不存PPP反向时重新在SRAM分块计算QK⊤QK^\topQK⊤与Softmax显存降至**O(N)O(N)O(N)代价是前向反向共2次计算**但远少于HBM读写节省的时间。三、复杂度对比必考点维度标准AttentionFlashAttention时间复杂度O(N2d)O(N^2d)O(N2d)O(N2d)O(N^2d)O(N2d)FLOPs不变IO减少显存复杂度O(N2)O(N^2)O(N2)存QK⊤/PQK^\top/PQK⊤/PO(N)O(N)O(N)仅存统计量/分块HBM访问量O(N2)O(N^2)O(N2)O(N2d/B)O(N^2d/B)O(N2d/B)BBB块大小降10–100×数值误差无无精确等价四、常见考点面试/笔试高频1基础概念QFlashAttention提出时间、作者、核心动机A2022年斯坦福Tri Dao等解决标准Attention的HBM带宽瓶颈与显存爆炸长序列训练。QFlashAttention是近似Attention吗与Linformer/Performer区别A精确、无近似Linformer/Performer是近似、降秩/核化有精度损失。2核心原理重中之重Q为什么FlashAttention能提速、降显存A分块TilingSRAM计算、不写HBM在线Softmax分块归一化、无全局矩阵算子融合单Kernel、少IO反向重计算显存O(N)O(N)O(N)。Q在线Softmax解决什么问题如何保证数值稳定A解决分块后Softmax全局信息缺失每行维护最大值mmm防exp溢出、指数和ℓ\ellℓ归一化、输出ooo块间动态合并统计量。Q分块大小Bq/BkB_q/B_kBq​/Bk​如何选择A适配SRAM容量常见64/128/256BBB太小→块多、IO增多BBB太大→SRAM放不下、溢出。3硬件与IOQGPU内存层次对FlashAttention的影响AHBM慢、SRAM快FlashAttention核心是把计算从HBM转移到SRAM减少HBM读写IO瓶颈。Q标准Attention的瓶颈是计算FLOPs还是IOAIOHBM带宽GPU算力远超HBM带宽标准Attention大部分时间在等数据而非计算。4反向传播与显存QFlashAttention反向为何用重计算优缺点A前向不存PPP反向重算以换显存优点显存O(N)O(N)O(N)可训更长序列缺点FLOPs增加~33%但IO节省远大于计算开销。5演进与应用QFlashAttention-1/2/3关键改进AFA12022基础分块在线SoftmaxFA22023更好分块、支持更宽ddd、H100优化FA32024FP8、稀疏、长序列优化。QPyTorch中如何用FlashAttentionAPyTorch2.0内置torch.nn.functional.scaled_dot_product_attention默认自动启用FlashAttention硬件支持时。五、总结一句话记牢FlashAttention IO感知 分块计算 在线Softmax 算子融合 反向重计算在零精度损失下将显存从O(N2)O(N^2)O(N2)降至O(N)O(N)O(N)、HBM访问降10–100×提速2–4×是当前LLM长序列训练的标配。