缩放点积注意力为何除以根号dk

Transformer中的Scaled Dot-Product Attention:为何除以 dk\sqrt{d_k}

本文由AI生成,CJL的主要工作是编写提示词,并检验内容正确性

在Transformer模型中,Scaled Dot-Product Attention 是一个核心机制,其计算公式为:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

其中,QQ 是查询向量,KK 是键向量,VV 是值向量,dkd_k 是键向量的维度。公式中一个引人注目的细节是:点积 QKTQK^T 被除以 dk\sqrt{d_k}。为什么选择 dk\sqrt{d_k} 而不是其他缩放因子,比如直接除以 dkd_k?这篇文章将从理论和实践角度分析这一设计的合理性,并深入探讨其背后的统计特性。


为什么除以 dk\sqrt{d_k}

点积规模的问题

在注意力机制中,QKTQK^T 是查询和键的点积,表示两者的相似度。假设 QQKKdkd_k 维向量,如果每个分量的值是均值为 0、方差为 1 的随机变量,那么点积的数值会随着 dkd_k 的增加而变大。具体来说,点积的方差与 dkd_k 成正比(后面会证明)。当 dkd_k 较大时(例如 64 或 512),未经缩放的 QKTQK^T 可能达到数百甚至数千。

Softmax 的敏感性

点积直接输入到 softmax\text{softmax} 函数中,而 softmax\text{softmax} 对输入的规模非常敏感。如果 QKTQK^T 的值过大,softmax\text{softmax} 的输出会退化为接近“one-hot”分布(一个值接近 1,其他接近 0)。这会导致梯度变小,减慢训练过程,甚至引发数值不稳定。反过来,如果点积过小,softmax\text{softmax} 输出会趋于均匀,削弱注意力机制区分重要性的能力。

dk\sqrt{d_k} 的作用

除以 dk\sqrt{d_k} 可以将点积的规模标准化到一个合理的范围。直观上,dk\sqrt{d_k} 与点积标准差的增长率匹配(后面证明),从而避免上述问题:

  • 它防止点积随 dkd_k 无限制增长,保持 softmax\text{softmax} 输入的稳定性。

  • 相比直接除以 dkd_k(会导致过度缩放),dk\sqrt{d_k} 保留了足够的动态范围,让注意力机制有效工作。

这种缩放方式在《Attention is All You Need》中被提出,并在实践中被广泛验证。下面我们从数学角度证明其合理性。


证明:点积的期望值和方差与 dk\sqrt{d_k} 的关系

定义

考虑两个 dkd_k 维向量 Q=[q1,q2,,qdk]Q = [q_1, q_2, \dots, q_{d_k}]K=[k1,k2,,kdk]K = [k_1, k_2, \dots, k_{d_k}],点积为:

S=QK=i=1dkqikiS = Q \cdot K = \sum_{i=1}^{d_k} q_i k_i

假设 qiq_ikik_i 是独立同分布的随机变量,均值为 0,方差为 σ2\sigma^2。我们计算 SS 的期望和方差。

期望 E[S]E[S]

E[S]=E[i=1dkqiki]=i=1dkE[qiki]E[S] = E\left[\sum_{i=1}^{d_k} q_i k_i\right] = \sum_{i=1}^{d_k} E[q_i k_i]

由于 qiq_ikik_i 独立,且 E[qi]=0E[q_i] = 0E[ki]=0E[k_i] = 0

E[qiki]=E[qi]E[ki]=0E[q_i k_i] = E[q_i] \cdot E[k_i] = 0

E[S]=dk0=0E[S] = d_k \cdot 0 = 0

期望值为 0,与 dkd_k 无关。

方差 Var(S)Var(S)

方差定义为:

Var(S)=E[(SE[S])2]Var(S) = E[(S - E[S])^2]

由于 E[S]=0E[S] = 0,这简化为:

Var(S)=E[S2]=E[(i=1dkqiki)2]Var(S) = E[S^2] = E\left[\left(\sum_{i=1}^{d_k} q_i k_i\right)^2\right]

展开平方项:

S2=(i=1dkqiki)2=i=1dkj=1dkqikiqjkjS^2 = \left(\sum_{i=1}^{d_k} q_i k_i\right)^2 = \sum_{i=1}^{d_k} \sum_{j=1}^{d_k} q_i k_i q_j k_j

所以:

E[S2]=E[i=1dkj=1dkqikiqjkj]E[S^2] = E\left[\sum_{i=1}^{d_k} \sum_{j=1}^{d_k} q_i k_i q_j k_j\right]

将求和拆分为 i=ji = jiji \neq j 两种情况:

E[S2]=i=1dkE[qikiqiki]+ijE[qikiqjkj]E[S^2] = \sum_{i=1}^{d_k} E[q_i k_i q_i k_i] + \sum_{i \neq j} E[q_i k_i q_j k_j]

i=ji = j 时:

E[qikiqiki]=E[qi2ki2]E[q_i k_i q_i k_i] = E[q_i^2 k_i^2]

由于 qiq_ikik_i 独立:

E[qi2ki2]=E[qi2]Edirectory[ki2]E[q_i^2 k_i^2] = E[q_i^2] \cdot E directory[k_i^2]

  • E[qi2]=Var(qi)+E[qi]2=σq2+0=σq2E[q_i^2] = Var(q_i) + E[q_i]^2 = \sigma_q^2 + 0 = \sigma_q^2

  • E[ki2]=Var(ki)+E[ki]2=σk2+0=σk2E[k_i^2] = Var(k_i) + E[k_i]^2 = \sigma_k^2 + 0 = \sigma_k^2

所以:

E[qi2ki2]=σq2σk2E[q_i^2 k_i^2] = \sigma_q^2 \cdot \sigma_k^2

这样的项有 dkd_k 个(因为 ii 从 1 到 dkd_k)。

iji \neq j 时:

E[qikiqjkj]E[q_i k_i q_j k_j]

因为 qiq_ikik_iqjq_jkjk_j 都是独立的(iji \neq j 时,qiq_iqjq_j 独立,kik_ikjk_j 独立,且 QQKK 之间独立):

E[qikiqjkj]=E[qi]E[ki]E[qj]E[kj]=0000=0E[q_i k_i q_j k_j] = E[q_i] \cdot E[k_i] \cdot E[q_j] \cdot E[k_j] = 0 \cdot 0 \cdot 0 \cdot 0 = 0

合并结果:

Var(S)=E[S2]=i=1dkE[qi2ki2]+ij0=i=1dkσq2σk2=dkσq2σk2Var(S) = E[S^2] = \sum_{i=1}^{d_k} E[q_i^2 k_i^2] + \sum_{i \neq j} 0 = \sum_{i=1}^{d_k} \sigma_q^2 \cdot \sigma_k^2 = d_k \cdot \sigma_q^2 \cdot \sigma_k^2

结论:点积的方差为:

Var(S)=dkσq2σk2Var(S) = d_k \cdot \sigma_q^2 \cdot \sigma_k^2

dk\sqrt{d_k} 的关系

  • 方差 Var(S)=dkσq2σk2Var(S) = d_k \cdot \sigma_q^2 \cdot \sigma_k^2 是与 dkd_k 成正比的。

  • 标准差(方差的平方根)是:

Var(S)=dkσq2σk2=σq2σk2dk\sqrt{Var(S)} = \sqrt{d_k \cdot \sigma_q^2 \cdot \sigma_k^2} = \sqrt{\sigma_q^2 \cdot \sigma_k^2} \cdot \sqrt{d_k}

这里,σq2σk2\sqrt{\sigma_q^2 \cdot \sigma_k^2} 是常数(与 dkd_k 无关),因此标准差 Var(S)\sqrt{Var(S)}dk\sqrt{d_k} 成正比。


qiq_ikik_i 独立性假设的讨论

上述证明依赖于 qiq_ikik_i 独立的假设,这在随机初始化的情况下成立。但在训练后的Transformer中,Q=XWQQ = XW_QK=XWKK = XW_K,其中 XX 是输入嵌入,WQW_QWKW_K 是学到的权重。由于 QQKK 共享相同的输入 XX,且 WQW_QWKW_K 通过训练优化,qiq_ikik_i 通常不完全独立。

不独立时的影响

  • 期望:如果 qiq_ikik_i 存在协方差,E[S]=Cov(qi,ki)E[S] = \sum Cov(q_i, k_i) 可能不为 0,但具体值取决于相关性强度,不一定与 dkd_k 成比例。

  • 方差:交叉项 E[qikiqjkj]E[q_i k_i q_j k_j] 不为 0,可能引入 dk2d_k^2 级别的贡献,使 Var(S)dk2Var(S) \propto d_k^2,标准差 dk\propto d_k。这意味着 dk\sqrt{d_k} 缩放可能不足以完全标准化。

实际意义

尽管独立性不成立,dk\sqrt{d_k} 仍被广泛使用,可能因为:

  1. 训练过程使相关性局部化,dk2d_k^2 项不主导方差。

  2. dk\sqrt{d_k} 是一个经验折中,避免过度缩放(如除以 dkd_k)的同时保持稳定性。

  3. 实践验证表明,这种缩放在多种 dkd_k 下效果良好。


总结

除以 dk\sqrt{d_k} 是Transformer设计中的一个巧妙选择。它基于点积标准差与 dk\sqrt{d_k} 成比例的统计特性,解决了数值规模问题,同时保留注意力机制的表达力。虽然 qiq_ikik_i 的独立性在训练后不完全成立,但 dk\sqrt{d_k} 的实用性已在无数实验中得到证明。对于更精确的缩放,可能需要针对具体模型和数据进行分析,但当前设计无疑是一个优雅而有效的解决方案。


缩放点积注意力为何除以根号dk
https://blog.algorithmpark.xyz/2025/04/05/Scaled-dot-product-attention-sqrt-dk/index/
作者
CJL
发布于
2025年4月5日
更新于
2025年4月5日
许可协议