直觉:一次可微的"软查询"

把注意力想成一次数据库检索:你拿着一个查询(Query),去和一堆条目的键(Key)比对相似度,相似度高的条目,其值(Value)就被多取一点。区别在于,传统检索是硬选一条记录,而注意力是按相似度加权所有值——它是一次"软"的、处处可微的查表。正因为可微,整套机制能塞进反向传播里端到端训练。

这套 Query/Key/Value(QKV)抽象,是理解 Transformer 的钥匙。下面我们把它一步步拆到矩阵运算和数值细节。

机制:从单个 token 到矩阵形式

设输入是 nn 个 token 的表示 XRn×dmodelX \in \mathbb{R}^{n \times d_{model}}。注意力不直接用 XX,而是先用三个可学习的投影矩阵把它变换成 Q、K、V:

Q=XWQ,K=XWK,V=XWVQ = XW^Q,\quad K = XW^K,\quad V = XW^V

其中 WQ,WKRdmodel×dkW^Q, W^K \in \mathbb{R}^{d_{model}\times d_k}WVRdmodel×dvW^V \in \mathbb{R}^{d_{model}\times d_v}。然后核心公式——缩放点积注意力(scaled dot-product attention)

Attention(Q,K,V)=softmax(QKdk)V\text{Attention}(Q,K,V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V

逐步拆:

  1. QKRn×nQK^\top \in \mathbb{R}^{n\times n}:第 ii 行第 jj 列是 query ii 和 key jj 的内积,即"token ii 应该多关注 token jj"的原始打分(logits)。
  2. 除以 dk\sqrt{d_k}:缩放(下一节专门讲为什么)。
  3. 对每一行做 softmax:把打分变成和为 1 的权重分布 ARn×nA \in \mathbb{R}^{n\times n}AijA_{ij} 是 token ii 分给 token jj 的注意力权重。
  4. AVA V:用权重对所有 value 加权求和,得到每个 token 的新表示。

一句话:每个 token 的输出 = 所有 token 的 value 的加权平均,权重由 query-key 相似度经 softmax 归一化决定。

公式:为什么要除以 dk\sqrt{d_k}

这是面试高频题,也是真正的数学细节。假设 qqkk 的每个分量都是独立、均值 0、方差 1 的随机变量。它们的点积是:

qk=i=1dkqikiq^\top k = \sum_{i=1}^{d_k} q_i k_i

每一项 qikiq_i k_i 均值为 0、方差为 1(独立项乘积方差 = 各自方差乘积)。dkd_k 个独立项相加,方差线性叠加:

Var(qk)=i=1dkVar(qiki)=dk\mathrm{Var}(q^\top k) = \sum_{i=1}^{d_k}\mathrm{Var}(q_i k_i) = d_k

所以点积的标准差是 dk\sqrt{d_k}。当 dkd_k 较大(比如 64、128),未缩放的 logits 量级会很大。把这种大值喂进 softmax,会让分布极度尖锐(几乎是 one-hot),而 softmax 在饱和区的梯度趋近于 0——梯度消失,训练停滞。除以 dk\sqrt{d_k} 正好把 logits 的方差拉回 O(1)O(1),让 softmax 工作在梯度健康的区间。

1
2
未缩放: logits ~ N(0, d_k)   ->  softmax 尖锐  ->  梯度近 0
缩放后: logits ~ N(0, 1) -> softmax 平滑 -> 梯度健康

多头:在多个子空间并行关注

单个注意力只能学一种"关注模式"。**多头注意力(multi-head)**把 dmodeld_{model} 切成 hh 个头,每个头在 dk=dmodel/hd_k = d_{model}/h 维的子空间独立做注意力,再拼接、过一个输出投影:

\text{MultiHead}(X) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O,\quad \text{head}_i = \text{Attention}(XW_i^Q, XW_i^K, XW_i^V)

直觉上,不同头可以分工——有的关注相邻词,有的关注长距离依赖,有的关注语法、有的关注指代。切成多头几乎不增加总计算量(因为每头维度变小了),却让模型在多个表示子空间里并行建模关系。

掩码:因果与 padding

两种常见掩码:

  • 因果掩码(causal mask):自回归生成时,token ii 不能看到未来的 j>ij > i。做法是在 softmax 之前,把上三角(j>ij>i)的 logits 设为 -\infty,softmax 后这些位置权重为 0。
  • padding 掩码:batch 内序列长度不一,补齐的 padding 位置同样置 -\infty,避免模型关注无意义的填充。

最小实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import numpy as np

def softmax(x, axis=-1):
x = x - x.max(axis=axis, keepdims=True) # 数值稳定:减最大值防溢出
e = np.exp(x)
return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V, mask=None):
d_k = Q.shape[-1]
scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k) # (b, n, n)
if mask is not None:
scores = np.where(mask, scores, -1e9) # 屏蔽位置置 -inf
A = softmax(scores, axis=-1) # 注意力权重
return A @ V, A

b, n, d_k, d_v = 2, 5, 16, 16
Q, K, V = (np.random.randn(b, n, d_k) for _ in range(3))
out, A = attention(Q, K, V)
print(out.shape, A.sum(-1).round(3)) # 每行权重和为 1

注意 softmax 里"减最大值"的技巧——直接 exp 大数会数值溢出,这是必备的稳定化处理。

工程权衡与边界

  • 复杂度是 O(n2d)O(n^2 d) 那个 QKQK^\topn×nn\times n 矩阵,序列长度翻倍,计算和显存都翻四倍。长上下文的根本瓶颈就在这里。各种线性注意力、稀疏注意力、以及 IO 优化(把注意力分块、避免把完整 n×nn\times n 矩阵写回显存的 FlashAttention 类方法)都是为了缓解它。
  • 显存大头是注意力矩阵。 长序列训练时,存储 n×nn\times n 的注意力权重(及其反传所需中间量)往往是显存峰值来源,而非参数本身。
  • 常见误区一:以为 Q、K、V 来自不同输入。 在自注意力里它们都来自同一个 XX,只是过了不同投影;在交叉注意力(如编码器-解码器)里,Q 来自解码器、K/V 来自编码器。
  • 常见误区二:忘了缩放或缩放错维度。 缩放因子是 dk\sqrt{d_k}(每个头的维度),不是 dmodel\sqrt{d_{model}}
  • 常见误区三:注意力本身不含位置信息。 公式对 token 顺序是置换等变的——打乱输入顺序,输出只是相应打乱。位置信息必须靠位置编码额外注入。

小结

注意力是一次可微的软查表:Q 去匹配 K、用相似度加权 V。缩放点积的精髓在 dk\sqrt{d_k}——它把点积方差从 dkd_k 拉回 O(1)O(1),避免 softmax 饱和导致梯度消失。多头让模型在多个子空间并行建模不同关系,掩码控制可见范围。代价是 O(n2)O(n^2) 的复杂度与显存,这既是 Transformer 强表达力的来源,也是长上下文优化永恒的战场。