Flash Attention 原理解析:IO-Aware 精确注意力计算一、问题的起点:Attention 为什么成为瓶颈?2023 年以来,LLaMA、GPT-4、Claude 等大语言模型席卷 AI 领域。这些模型的共同骨架是 Transformer,而 Transformer 的核心计算是Scaled Dot-Product Attention:Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk​​QKT​)V看起来不过三次矩阵乘法加一次 softmax,复杂度O(n2d)O(n^2d)O(n2d)——对于nnn个 token、ddd维隐藏状态。问题出在那个n2n^2n2上:当序列长度达到 8K、32K、128K 时,注意力矩阵S=QKTS = QK^TS=QKT的大小爆炸——128K 序列的注意力矩阵约 64GB(FP16),单张 H100 的 80GB 显存放不下不说,读写这个矩阵本身就消耗巨大。更关键的是,瓶颈不在计算(FLOPs),而在内存访问(I/O)。现代 GPU 计算能力远超内存带宽。以 NVIDIA H100 为例:计算能力:~1000 TFLOPS(FP16)HBM3 带宽:~3.35 TB/s一个n=16Kn=16\text{K}n=16K的注意力计算需要约 4B FLOPs,理论上 4 微秒就能算完;但在标准实现中,将QKTQK^TQKT矩阵写入 HBM 再读回需要约 30 毫秒——99.99% 的时间花在了数据搬运上。这就是 Flash Attention 要解决的问题:如何在不显式存储完整注意力矩阵的前提下,算出精确的注意力输出?二、核心洞察:GPU 内存层次与 IO-Awareness2.1 GPU 内存层次理解 Flash Attention 之前,需要先理解 GPU 的内存体系:内存层级大小(H100)带宽可编程性HBM(显存)80 GB~3.35 TB/s全局访问L2 Cache50 MB~12 TB/s自动缓存SRAM(Shared Memory)228 KB/SM~20 TB/s手动管理SRAM 速度快但极小(单 SM 仅 228KB),HBM 大但慢。标准 Attention 的做法是:在 SRAM 中计算S=QKTS = QK^TS=QKT写回 HBM(因为太大,SRAM 放不下)再从 HBM 读回做 softmax用 softmax 结果乘VVV,再写回 HBM这种「算一步、写回一步」的模式导致大量冗余 HBM 读写。Flash Attention 的洞见是:我们可以把计算切分成小块,让每个小块完全在 SRAM 内完成,无需写回中间结果到 HBM。2.2 Tiling:分块计算核心思想是将Q,K,VQ, K, VQ,K,V切分成 Block:Q=[Q1,Q2,…,QTr],K=[K1,K2,…,KTc],V=[V1,V2,…,VTc]Q = [Q_1, Q_2, \ldots, Q_{T_r}], \quad K = [K_1, K_2, \ldots, K_{T_c}], \quad V = [V_1, V_2, \ldots, V_{T_c}]Q=[Q1​,Q2​,…,QTr​​],K=[K1​,K2​,…,KTc​​],V=[V1​,V2​,…,V