“为什么模型只能记住 8K / 128K 个 token,超了就忘?” 把这个问题讲清楚,需要拆开两件事:一是自注意力的平方复杂度决定了长度的根本代价,二是 KV Cache 这个推理时的关键优化既让自回归生成变快,又成了长上下文的主要显存瓶颈。理解了它俩,你就能算出"再长一倍要多花多少显存",也能看懂业界各种长上下文方案到底在优化什么。

直觉:每个 token 都要"看一遍"所有 token

Transformer 的核心是自注意力:生成序列里的每个位置,都要和之前所有位置算一遍相关性,再加权聚合信息。这正是它强大的原因——任意两个 token 可以直接交互,没有 RNN 那种信息逐步衰减的长程依赖问题。但代价是:序列长度 LL 时,要算的 token 对数量是 O(L2)O(L^2)。长度翻倍,注意力的计算与中间激活内存就涨到四倍。这是上下文长度受限的第一性根源

机制:注意力与 O(L2)O(L^2) 从哪来

每个 token 经线性投影得到三个向量:Query、Key、Value。注意力公式是

Attention(Q,K,V)=softmax(QKdk)V\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V

其中 Q,K,VQ, K, V 形状为 L×dkL \times d_k。关键在中间那个 QKQK^\top:它是 (L×dk)(dk×L)(L \times d_k)(d_k \times L),得到一个 L×LL \times L 的注意力分数矩阵——这就是平方项的来源。L=100KL = 100\text{K} 时,单头单层的这个矩阵就有 101010^{10} 个元素。计算量和(朴素实现下的)激活内存都随 L2L^2 增长,长度因此不能无限拉长。

(顺带一提:FlashAlttention 等优化通过分块计算避免把完整 L×LL\times L 矩阵物化到显存,把激活内存压到接近线性,但计算量本身仍是 O(L2)O(L^2)。)

KV Cache:自回归生成的命脉

推理是逐 token 自回归进行的:已经生成了 tt 个 token,要预测第 t+1t+1 个。朴素做法是把这 tt 个 token 整个重新喂进网络跑一遍——但这里有大量重复劳动。

注意一个事实:当我们生成第 t+1t+1 个 token 时,前面 tt 个 token 的 Key 和 Value 向量并不会改变(它们只依赖各自及更早的输入,与未来无关,这正是因果掩码保证的)。既然不变,就没必要每步重算。于是把每一层、每个历史 token 算出的 K、V 缓存下来,新 token 来时只算它自己的 Q、K、V,K/V 追加进缓存,Q 去和缓存里全部 K/V 做注意力即可。

1
2
3
4
5
6
7
8
# 概念示意:带 KV Cache 的单步解码
def decode_step(new_token, kv_cache, model):
q, k, v = model.qkv(new_token) # 只为新 token 计算 QKV
kv_cache.k.append(k) # K/V 追加进缓存,历史不重算
kv_cache.v.append(v)
# 新 token 的 Q 对全部历史 K/V 做注意力
attn = softmax(q @ kv_cache.k.T / sqrt(d_k)) @ kv_cache.v
return model.head(attn), kv_cache

收益巨大:没有缓存时,生成 LL 个 token 的总计算约 O(L3)O(L^3)(每步重算一遍 O(L2)O(L^2) 的前缀);有了 KV Cache,每步只处理一个新 token,整体降到 O(L2)O(L^2)KV Cache 是让长文本生成在实践中可行的核心优化。

显存账:KV Cache 才是长上下文的真瓶颈

天下没有免费的缓存。KV Cache 把"重复计算"换成了"持续占显存",而它的大小随上下文线性增长且相当可观。单条序列的 KV Cache 字节数约为:

\text{KV bytes} = 2 \times L \times n_{\text{layers}} \times n_{\text{kv\_heads}} \times d_{\text{head}} \times \text{bytes\_per\_elem}

其中开头的 2 是 K 和 V 各一份。逐项看含义:序列越长(LL)、层数越多、KV 头越多、每个头维度越大、精度越高(FP16 是 2 字节),缓存越大。代入一个量级感受一下:几十层、隐藏维数千、FP16 的模型,每个 token 的 KV 往往要几十到上百 KB。于是 128K 上下文单条序列就可能吃掉数 GB显存——而且这还要乘以并发的 batch size。

这解释了几个现实现象:

  • 显存而非算力,常常是长上下文与高并发的第一约束。推理服务的并发上限,往往卡在"KV Cache 总量塞不下显存"。
  • 首 token 慢、后续快。处理输入 prompt(prefill 阶段)要一次算完 O(L2)O(L^2) 的全注意力并填满缓存,长 prompt 的首 token 延迟明显;之后逐 token(decode 阶段)每步都很轻。

工程权衡:怎么把窗口做长、做省

围绕"O(L2)O(L^2) 算力 + 线性增长的 KV 显存"这两个约束,业界的优化基本都在做同一件事——省 KV 或省注意力

  • GQA / MQA(分组/多查询注意力):让多个 Query 头共享同一组 K/V 头,直接砍掉公式里的 n_{\text{kv\_heads}},KV Cache 成倍缩小。这是当下大模型几乎标配的省显存手段。
  • KV Cache 量化:把缓存从 FP16 降到 INT8/INT4 存储,用精度换显存,线性削减占用。
  • 稀疏 / 滑动窗口注意力:不让每个 token 看全部历史,只看局部窗口或少量全局 token,把注意力从 O(L2)O(L^2) 降到接近线性,代价是牺牲部分长程依赖。
  • PagedAttention:像操作系统管理虚拟内存那样分页管理 KV Cache,消除碎片、提升显存利用率,从而塞下更多并发序列。
  • 位置编码外推:RoPE 缩放等技巧让模型在比训练时更长的序列上仍能合理工作,但这是"能不能正确利用"的问题,与"显存够不够"正交。

常见误区

  • “上下文越长越好”。长上下文不仅烧显存、抬延迟,还有信息利用问题:模型对中段内容容易读漏(lost in the middle)。把最关键的指令放在头部或尾部,往往比无脑塞满窗口更可靠。
  • “窗口大就等于记忆力强”。窗口是每次请求的临时工作区,请求一结束就清空。模型没有跨会话的长期记忆,那需要靠外部存储 + 检索(RAG)另行搭建。
  • 混淆"算得动"和"放得下"。FlashAttention 解决的是激活内存与带宽,GQA/量化解决的是 KV 显存,稀疏注意力解决的是计算量——它们针对不同瓶颈,不能互相替代。

小结

上下文长度受限,根子在两处:自注意力天生 O(L2)O(L^2) 的计算代价,以及 KV Cache 随长度线性膨胀的显存占用。KV Cache 是让自回归生成从 O(L3)O(L^3) 降到 O(L2)O(L^2) 的命脉优化,却也成了长上下文与高并发的首要瓶颈。看懂那条 KV 显存公式,你就能估算"上下文翻倍要多花多少显存",也能一眼看穿 GQA、量化、稀疏注意力、PagedAttention 这些方案各自在对公式里的哪一项动手。