注意力机制:More!⚓︎
约 1342 个字 预计阅读时间 4 分钟 总阅读量 次
因为过于重要所以必须单开一节。不同于 Transformer 一节中大刀阔斧地对整个模型组建的梳理和若干知识点的回答,这一节更加专注于Attention的复杂度、变体以及分析。
自注意力机制的复杂度⚓︎
简单直接的答案是:对于一个长度为 n
、维度为 d
的输入序列,标准自注意力层的算法复杂度是
\[O(n^2 \cdot d)\]
这个复杂度主要来自于两个矩阵乘法步骤。下面我们来详细分解一下这个计算过程。
为了计算复杂度,我们首先定义几个关键变量:
n
:输入序列的长度(即 Token 的数量)。d
:模型的隐藏层维度(d_model
),也即每个 Token 的 embedding 维度。d_k
:Query 和 Key 向量的维度。d_v
:Value 向量的维度。
在经典的 Transformer 实现中,通常有 d_k = d_v = d / h
,其中 h
是注意力头的数量。为了简化分析,我们先假设只有一个头(h=1
),此时 d_k = d_v = d
。
分析⚓︎
我们来回顾一下自注意力的计算公式并分析每一步的计算量。
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]
假设输入矩阵为 X
,其维度为 [n, d]
。
1. 生成 Q, K, V 矩阵⚓︎
- 操作: 通过将输入矩阵
X
与三个独立的权重矩阵W_Q
,W_K
,W_V
相乘来生成Q
,K
,V
。Q = X @ W_Q
K = X @ W_K
V = X @ W_V
- 维度分析:
X
:[n, d]
W_Q
,W_K
,W_V
:[d, d_k]
(这里我们假设d_k=d
) 即[d, d]
Q
,K
,V
:[n, d]
- 复杂度: 一次
[n, d] @ [d, d]
矩阵乘法的复杂度是 \(O(n \cdot d \cdot d) = O(n \cdot d^2)\)。由于需要计算三次,总复杂度仍为 \(O(n \cdot d^2)\)。
2. 计算注意力分数 (Scores)⚓︎
- 操作: 计算
Q
和K
的转置的点积。Scores = Q @ K.T
- 维度分析:
Q
:[n, d]
K.T
:[d, n]
Scores
:[n, n]
- 复杂度:
[n, d] @ [d, n]
矩阵乘法的复杂度是 \(O(n \cdot d \cdot n) = O(n^2 \cdot d)\)。 - 注意: 后续的缩放操作(除以 \(\sqrt{d_k}\))是逐元素的,其复杂度为 \(O(n^2)\),远小于矩阵乘法,可以忽略不计。
3. Softmax 计算注意力权重⚓︎
- 操作: 对
Scores
矩阵的每一行应用 Softmax 函数。 - 维度分析:
Scores
:[n, n]
- 复杂度: 对一行
n
个元素计算 Softmax 的复杂度是 \(O(n)\)。因为有n
行,所以总复杂度是 \(O(n^2)\)。
4. 加权求和 V⚓︎
- 操作: 将 Softmax 得到的注意力权重矩阵与
V
矩阵相乘。Output = AttentionWeights @ V
- 维度分析:
AttentionWeights
:[n, n]
V
:[n, d]
(这里我们假设d_v = d
)Output
:[n, d]
- 复杂度:
[n, n] @ [n, d]
矩阵乘法的复杂度是
\[O(n^2 \cdot d)\]
抓大放小⚓︎
现在我们将所有步骤的复杂度放在一起:
- Q,K,V 生成:\(O(n \cdot d^2)\)
- 注意力分数计算:\(O(n^2 \cdot d)\)
- Softmax:\(O(n^2)\)
- 加权求和V:\(O(n^2 \cdot d)\)
在典型的 Transformer 模型中,序列长度 n
(如 512, 1024) 和维度 d
(如 768, 1024) 通常是同一数量级或者 d > n
。但是,序列长度 n
是可变的,而维度 d
是固定的超参数。 算法复杂度的分析主要关注随输入规模(即 n
)变化的趋势。
最耗时的步骤是第 2 步和第 4 步,它们的复杂度都是 \(O(n^2 \cdot d)\)。因此,整个注意力层的复杂度由这些主导项决定:
总复杂度 = \(O(n \cdot d^2 + n^2 \cdot d)\)
- 当
d > n
时(例如在处理短文本时),\(n \cdot d^2\) 项可能更大。 - 当
n > d
时(例如在处理长文档时),\(n^2 \cdot d\) 项会成为绝对的瓶颈。
然而,在算法复杂度的标准表示中,我们通常将这两个维度都视为变量,因此保留这两个项。但在讨论 Transformer 的主要瓶颈时,我们通常特指其对于序列长度 n
的二次方依赖性,即 \(O(n^2)\) 这一部分。这是因为 d
是一个固定的模型设计参数,而 n
是我们希望能够灵活处理的输入长度。
多头注意力 (Multi-Head Attention) 的影响⚓︎
在多头注意力中,维度 d
被切分成 h
个头,每个头的维度为 d_k = d/h
。
- 单个头的注意力计算复杂度变为 \(O(n^2 \cdot (d/h))\)。
- 因为
h
个头是并行计算的,所以总的计算量是 \(h \times O(n^2 \cdot d/h) = O(n^2 \cdot d)\)。
结论:多头注意力机制并不改变整体的渐进复杂度,但它将大的矩阵乘法分解为多个小的矩阵乘法,这种方式更适合现代 GPU 的并行计算,因此在实践中通常更快。
注意力机制的核心瓶颈在于它需要计算一个 [n, n]
大小的注意力分数矩阵。这意味着内存占用和计算量都与序列长度的平方成正比。如果你的输入序列长度翻倍,那么计算和存储注意力矩阵的开销将增长到原来的四倍。这正是限制标准 Transformer 模型处理超长序列(如整本书、高分辨率图像)的主要原因,并催生了如 FlashAttention、稀疏注意力等众多优化算法。