Speculative-Decoding

Speculative Decoding 知识学习

投机解码(Speculative Decoding)学习笔记

1. 大模型推理的本质:自回归

LLM 的生成过程是标准的自回归链:

因此必须:

  • 一步一步生成 token
  • 每生成一个 token 都要跑一次完整 Transformer forward
  • GPU 低利用率(序列化瓶颈)

推理慢的根本原因:

每生成 1 个 token,就需要完整跑一遍大模型 forward。


2. KV Cache:解决训练时不能解决的问题

KV Cache 用途:

缓存历史 token 的 Key / Value:

  • 避免重复计算 K/V
  • 将注意力复杂度从 O(N²) → O(N)
  • 加速 attention,但不减少 forward 次数

KV Cache 无法解决:

  • 每次生成 token 时必须计算新的 Q
  • FFN 完全无法缓存
  • 每一层仍要计算 attention(Q, K_cache, V_cache)
  • forward 次数仍然 = token 数量

训练为什么不能用 KV Cache?

训练是全序列并行计算,每个 token 的 Q/K/V 都第一次出现,不存在“历史”信息可缓存。


3. 推理加速真正要优化的是什么?

所有加速技术的目标都只有一个:

减少大模型 Full Forward 的次数。

因为 FFN 才是推理阶段的主要 FLOPs(60%~75%),attention 已经被 KV Cache 大幅加速。


4. 投机解码的核心思想

投机解码利用两点 Transformer 特性:

  1. 大模型无法并行生成多个 token(自回归约束)
  2. 但 Transformer 可以一次性并行验证一个序列的所有位置(序列前向并行)

于是我们可以:

  1. 小模型快速生成一串 token(draft)
  2. 大模型一次 forward验证整段草稿
  3. 使用拒绝采样决定哪些 token 可接受
  4. 接受的 token 直接作为真实输出,减少大模型 forward 次数

5. 为什么小模型可以“猜”大模型?

来源:

  • 蒸馏(distillation):小模型 q 学大模型 p 的分布
  • 语言分布高度集中:
    大模型和小模型在前 1–5 个候选 token 上高度重合

小模型猜得越准,接受率越高 → 加速越大。


6. 大模型为什么能一次性验证多个 token?

因为 Transformer 的结构天然支持:

  • 对整个序列做并行 forward
  • 因果 mask 保证每个位置只看左侧
  • 每个位置的 logits 可同时输出

因此:

大模型一次 forward 给出 y₁,y₂,…,yₖ 的所有概率。

但:

  • 它没有“并行预测”多个 token
  • 只是“并行验证”草稿序列

7. 拒绝采样机制:保证最终输出分布无偏

小模型生成草稿 token (y_i)。

大模型计算真实概率:

小模型概率:

接受概率 α:

采样随机数 u:

  • (u < \alpha_i) → 接受
  • (u \ge \alpha_i) → 拒绝并回退到大模型自回归

这个机制保证:

无论小模型再差,最终采样分布始终 = 大模型真实分布 p(无偏)。


8. 投机解码完整流程(图式)

假设草稿长度 k = 4

Step 1:小模型生成草稿

1
draft = [y1, y2, y3, y4]

Step 2:大模型一次 forward

输入序列 = X + draft

大模型一次性计算:

  • p(y1|X)
  • p(y2|X,y1)
  • p(y3|X,y1,y2)
  • p(y4|X,y1,y2,y3)

Step 3:拒绝采样逐个检查

若全部接受 → 一次“吞掉” k 个 token
若 y3 被拒绝 → 接受 y1,y2,y3 由大模型重新生成

Step 4:重复


9. 投机解码伪代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
X = initial_context
while not finished:
# Step 1: small model draft
draft = S.generate_k_tokens(X)

# Step 2: large model verify (one forward)
logits_L = L.forward(X + draft)

accepted = []
for i, y in enumerate(draft):
p = softmax(logits_L[i])[y]
q = S.prob(y)

alpha = p / (c * q)
u = random()

if u < alpha:
accepted.append(y)
else:
# fallback: large model generates from here
y_L = L.sample_token(X + accepted)
accepted.append(y_L)
break

X = X + accepted

10. 为什么不会破坏自回归因果性?

因为 causal mask 保证:

  • y₂ 的 hidden state 只能使用 X + y₁
  • y₃ 只能使用 X + y₁,y₂

未来 token 永远不会泄漏信息。


11. 投机解码的优缺点

优点

  • 巨大的推理加速(2×~4×,甚至更高)
  • 输出质量完全不变(无偏)
  • 适合 KV Cache + batch 的现代推理架构
  • 已被 GPT-4, Llama, Qwen, vLLM 等主流框架采用

缺点

  • 需要训练或蒸馏一个小模型
  • 小模型越准 → 加速越大
  • 太差的小模型会降低加速比
  • 不适合极少 token 输出的任务(如分类)

12. 投机解码的本质一句话总结

小模型预测草稿,大模型一次 forward 并行验证,通过拒绝采样保证分布无偏,从而减少大模型 forward 次数,实现推理加速。