跳转至

注意力机制:More!⚓︎

约 1342 个字 预计阅读时间 4 分钟 总阅读量

因为过于重要所以必须单开一节。不同于 Transformer 一节中大刀阔斧地对整个模型组建的梳理和若干知识点的回答,这一节更加专注于Attention的复杂度、变体以及分析。

自注意力机制的复杂度⚓︎

简单直接的答案是:对于一个长度为 n、维度为 d 的输入序列,标准自注意力层的算法复杂度是

\[O(n^2 \cdot d)\]

这个复杂度主要来自于两个矩阵乘法步骤。下面我们来详细分解一下这个计算过程。

为了计算复杂度,我们首先定义几个关键变量:

  • n:输入序列的长度(即 Token 的数量)。
  • d:模型的隐藏层维度(d_model),也即每个 Token 的 embedding 维度。
  • d_k:Query 和 Key 向量的维度。
  • d_v:Value 向量的维度。

在经典的 Transformer 实现中,通常有 d_k = d_v = d / h,其中 h 是注意力头的数量。为了简化分析,我们先假设只有一个头(h=1),此时 d_k = d_v = d

分析⚓︎

我们来回顾一下自注意力的计算公式并分析每一步的计算量。

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

假设输入矩阵为 X,其维度为 [n, d]

1. 生成 Q, K, V 矩阵⚓︎

  • 操作: 通过将输入矩阵 X 与三个独立的权重矩阵 W_Q, W_K, W_V 相乘来生成 Q, K, V
    • Q = X @ W_Q
    • K = X @ W_K
    • V = X @ W_V
  • 维度分析:
    • X: [n, d]
    • W_Q, W_K, W_V: [d, d_k] (这里我们假设 d_k=d) 即 [d, d]
    • Q, K, V: [n, d]
  • 复杂度: 一次 [n, d] @ [d, d] 矩阵乘法的复杂度是 \(O(n \cdot d \cdot d) = O(n \cdot d^2)\)。由于需要计算三次,总复杂度仍为 \(O(n \cdot d^2)\)

2. 计算注意力分数 (Scores)⚓︎

  • 操作: 计算 QK 的转置的点积。
    • Scores = Q @ K.T
  • 维度分析:
    • Q: [n, d]
    • K.T: [d, n]
    • Scores: [n, n]
  • 复杂度: [n, d] @ [d, n] 矩阵乘法的复杂度是 \(O(n \cdot d \cdot n) = O(n^2 \cdot d)\)
  • 注意: 后续的缩放操作(除以 \(\sqrt{d_k}\))是逐元素的,其复杂度为 \(O(n^2)\),远小于矩阵乘法,可以忽略不计。

3. Softmax 计算注意力权重⚓︎

  • 操作: 对 Scores 矩阵的每一行应用 Softmax 函数。
  • 维度分析:
    • Scores: [n, n]
  • 复杂度: 对一行 n 个元素计算 Softmax 的复杂度是 \(O(n)\)。因为有 n 行,所以总复杂度是 \(O(n^2)\)

4. 加权求和 V⚓︎

  • 操作: 将 Softmax 得到的注意力权重矩阵与 V 矩阵相乘。
    • Output = AttentionWeights @ V
  • 维度分析:
    • AttentionWeights: [n, n]
    • V: [n, d] (这里我们假设 d_v = d)
    • Output: [n, d]
  • 复杂度: [n, n] @ [n, d] 矩阵乘法的复杂度是
\[O(n^2 \cdot d)\]

抓大放小⚓︎

现在我们将所有步骤的复杂度放在一起:

  1. Q,K,V 生成:\(O(n \cdot d^2)\)
  2. 注意力分数计算\(O(n^2 \cdot d)\)
  3. Softmax:\(O(n^2)\)
  4. 加权求和V\(O(n^2 \cdot d)\)

在典型的 Transformer 模型中,序列长度 n (如 512, 1024) 和维度 d (如 768, 1024) 通常是同一数量级或者 d > n。但是,序列长度 n 是可变的,而维度 d 是固定的超参数。 算法复杂度的分析主要关注随输入规模(即 n)变化的趋势。

最耗时的步骤是第 2 步和第 4 步,它们的复杂度都是 \(O(n^2 \cdot d)\)。因此,整个注意力层的复杂度由这些主导项决定:

总复杂度 = \(O(n \cdot d^2 + n^2 \cdot d)\)

  • d > n 时(例如在处理短文本时),\(n \cdot d^2\) 项可能更大。
  • n > d 时(例如在处理长文档时),\(n^2 \cdot d\) 项会成为绝对的瓶颈。

然而,在算法复杂度的标准表示中,我们通常将这两个维度都视为变量,因此保留这两个项。但在讨论 Transformer 的主要瓶颈时,我们通常特指其对于序列长度 n 的二次方依赖性,即 \(O(n^2)\) 这一部分。这是因为 d 是一个固定的模型设计参数,而 n 是我们希望能够灵活处理的输入长度。


多头注意力 (Multi-Head Attention) 的影响⚓︎

在多头注意力中,维度 d 被切分成 h 个头,每个头的维度为 d_k = d/h

  • 单个头的注意力计算复杂度变为 \(O(n^2 \cdot (d/h))\)
  • 因为 h 个头是并行计算的,所以总的计算量是 \(h \times O(n^2 \cdot d/h) = O(n^2 \cdot d)\)

结论:多头注意力机制并不改变整体的渐进复杂度,但它将大的矩阵乘法分解为多个小的矩阵乘法,这种方式更适合现代 GPU 的并行计算,因此在实践中通常更快。

注意力机制的核心瓶颈在于它需要计算一个 [n, n] 大小的注意力分数矩阵。这意味着内存占用和计算量都与序列长度的平方成正比。如果你的输入序列长度翻倍,那么计算和存储注意力矩阵的开销将增长到原来的四倍。这正是限制标准 Transformer 模型处理超长序列(如整本书、高分辨率图像)的主要原因,并催生了如 FlashAttention、稀疏注意力等众多优化算法。