Linear Attention 学习笔记0. Linear Attention 的目的与背景0.1 标准 Attention 的瓶颈在 Transformer 的标准 Self-Attention 机制中,注意力分数的计算方式如下:Attention(Q,K,V)=softmax(QKTd)V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)VAttention(Q,K,V)=softmax(d​QKT​)V其中:Q,K,V∈Rn×dQ, K, V \in \mathbb{R}^{n \times d}Q,K,V∈Rn×d:分别代表查询(Query)、键(Key)、值(Value)矩阵。nnn:序列长度(Sequence Length)。ddd:隐藏层维度(Head Dimension)。复杂度分析:计算QKTQK^TQKT:矩阵形状为(n×d)×(d×n)(n \times d) \times (d \times n)(n×d)×(d×n),结果是一个n×nn \times nn×n的注意力矩阵。计算复杂度为O(n2d)O(n^2d)O(n2d)(左行右列乘一下就知道了:d次乘法 * n行 * n行)。再乘以VVV:(n×n)×(n×d)(n \times n) \times (n \times d)(n×n)×(n×d),复杂度仍受n2n^2n2主导,计算复杂度仍为O(n2d)O(n^2d)O(n2d)。其中,ddd为固定值,那么当序列长度nnn变大时(例如长文本、高分辨率图像),n2n^2n2的内存和计算开销会急剧增加,这成为了限制 Transformer 处理长序列的主要瓶颈。0.2 Linear Attention 的核心思想Linear Attention 的目标是将复杂度从O(n2d)O(n^2d)O(n2d)降低到O(nd2)O(nd^2)O(nd2)。由于通常d≪nd \ll nd≪n且ddd是固定的,这相当于实现了关于序列长度nnn的线性复杂度。实现原理:利用矩阵乘法的结合律标准 Attention 的计算顺序是先算QKTQK^TQKT(n×nn \times nn×n),再乘VVV。如果我们改变计算顺序,先算KTVK^T VKTV(d×dd \times dd×d),再让QQQ去乘这个结果,复杂度就会改变。数学变换如下(忽略 softmax 和归一化系数):Output=Q(KTV) \text{Output} = Q (K^T V)