Overview
Generating text with large language models is slow because tokens are generated one at a time, each requiring a full forward pass. Speculative decoding [1] speeds this up by letting a small draft model sprint ahead, while a larger target model checks its work in batches. The trick is that you still get exactly the same distribution as if you had sampled directly from the big model.
Notation
- $x_{0:n}$: your current prompt/prefix of length $n$ (already fixed; $n$ grows as tokens are accepted).
- $K$: lookahead window (how many draft tokens you try at once).
- $T$: minimum total length you want to reach.
- $\tilde{x}_t$: a draft token proposed by $p$.
- $q(\cdot \mid \text{context})$: target’s distribution over the next token given the context.
- $p(\cdot \mid \text{context})$: draft’s distribution likewise.
- $(a)_+ = \max(a, 0)$.
Algorithm (outer loop)
1) Draft phase (cheap and sequential)
Generate a short continuation of length $K$ from the draft model:
$ \tilde{x}_t \sim p(\cdot \mid x_{0:n}, \tilde{x}_{1:t-1}), \qquad t = 1, \ldots, K. $
This is standard autoregressive sampling, but using $p$.
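A minimal sketch of the draft loop, assuming a hypothetical helper p_next_probs(ctx) that returns the draft model's next-token distribution as a NumPy probability vector:

import numpy as np

rng = np.random.default_rng(0)

def draft(ctx, K, p_next_probs):
    # Sample K tokens autoregressively from the draft model p, keeping each
    # step's full distribution: verification needs p's probabilities later.
    tilde, p_dists = [], []
    for _ in range(K):
        probs = p_next_probs(ctx + tilde)  # assumed helper, not a real API
        tok = int(rng.choice(len(probs), p=probs))
        tilde.append(tok)
        p_dists.append(probs)
    return tilde, p_dists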
2) Target scoring (expensive but batched)
In one parallel/batched forward pass of $q$, compute the next-token logits/probs for the $K+1$ contexts
$ q(\cdot \mid x_{0:n}, \tilde{x}_{1:t-1}), \qquad t = 1, \ldots, K+1. $
This is the key efficiency trick: you pay one "big" pass of $q$ to score all steps in the lookahead chain.
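Because the $K+1$ contexts are all prefixes of one sequence, a causal transformer does not even need $K+1$ separate forwards: a single pass over the prefix concatenated with the draft yields next-token logits at every position. A minimal sketch, assuming a hypothetical q_forward helper that returns per-position logits:

import numpy as np

def score_with_target(ctx, tilde, q_forward):
    # One forward pass over the concatenated sequence; logits[i] predicts
    # token i+1, so rows len(ctx)-1 .. end are exactly the K+1 next-token
    # distributions needed for verification plus the bonus token.
    logits = q_forward(ctx + tilde)                    # assumed: [n+K, vocab]
    steps = logits[len(ctx) - 1:]                      # [K+1, vocab]
    steps = steps - steps.max(axis=-1, keepdims=True)  # stable softmax
    probs = np.exp(steps)
    return probs / probs.sum(axis=-1, keepdims=True)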
3) Verification / acceptance (exactness correction)
Walk through the draft tokens in order $t = 1, \ldots, K$:
- Compute the acceptance ratio for the proposed token:
$ a_t = \min\!\left(1,\; \frac{q(\tilde{x}_t \mid x_{0:n}, \tilde{x}_{1:t-1})} {p(\tilde{x}_t \mid x_{0:n}, \tilde{x}_{1:t-1})} \right). $
- Draw $r \sim \mathrm{Uniform}[0,1]$.
- If $r < a_t$: accept. Set $x_{n+t} \leftarrow \tilde{x}_t$ and continue.
- Else (reject): sample $x_{n+t}$ from the residual distribution
$ \propto \big(q(\cdot \mid x_{0:n}, \tilde{x}_{1:t-1}) - p(\cdot \mid x_{0:n}, \tilde{x}_{1:t-1})\big)_+, $
renormalized, and stop the inner loop (once you deviate, the rest of the draft is stale). A worked numeric example follows this list.
- Each accepted token advances $n \leftarrow n+1$.
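To make the accept/reject arithmetic concrete, here is a single verification step over an illustrative three-token vocabulary (all numbers are made up for the example):

import numpy as np

rng = np.random.default_rng(0)

q_t = np.array([0.6, 0.3, 0.1])  # target's next-token distribution at this step
p_t = np.array([0.3, 0.5, 0.2])  # draft's distribution at the same step
z = 1                            # suppose the draft proposed token 1

a = min(1.0, q_t[z] / p_t[z])    # min(1, 0.3/0.5) = 0.6
if rng.random() < a:
    token = z                    # accept the draft token
else:
    residual = np.maximum(q_t - p_t, 0.0)  # (q - p)_+ = [0.3, 0.0, 0.0]
    residual /= residual.sum()             # renormalize -> [1.0, 0.0, 0.0]
    token = int(rng.choice(3, p=residual))

Marginally, token 1 is emitted with probability $0.5 \times 0.6 = 0.3 = q_t(1)$: exactly the target's mass, which is the exactness guarantee in miniature.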
4) Bonus token if all $K$ were accepted
If all $\tilde{x}_{1:K}$ passed, you've "saved" $q$-compute. Spend it to sample one extra true $q$ token from the distribution already computed in the batched pass:
$ x_{n+K+1} \sim q(\cdot \mid x_{0:n}, \tilde{x}_{1:K}), $
and advance $n$ once more. Net effect: you produced $K+1$ tokens while doing one batched pass of $q$.
Repeat until you reach length $T$.
Why this works (intuition + guarantees)
- Think of $p$ as a proposal distribution and the accept/reject step as a rejection-sampling / Metropolis–Hastings-style correction.
- The acceptance ratio $\min(1, q/p)$ ensures unbiasedness: the final token stream is distributed exactly as if sampled from $q$ (a one-line derivation follows this list).
- The residual sampling on rejection fills in the parts where $q$ has more mass than $p$, preserving exactness.
- Efficiency comes from:
  - $p$ is cheaper and runs sequentially for $K$ steps.
  - $q$ runs once per outer loop but scores $K+1$ contexts in parallel.
  - If $p$ is close to $q$, acceptance rates are high (often 70–95% in practice), so you usually accept most of the draft and get the "bonus" $q$ token; 2–4× wall-clock speedups are common with a well-matched $p$.
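To check the unbiasedness claim at a single verification step, write $q(\cdot)$ and $p(\cdot)$ for the two next-token distributions there. A token $z$ is emitted either by acceptance or by a residual draw, so
$ \Pr[x = z] = p(z)\,\min\!\left(1, \frac{q(z)}{p(z)}\right) + \left(1 - \sum_{z'} \min\big(p(z'), q(z')\big)\right) \frac{\big(q(z) - p(z)\big)_+}{\sum_{z''} \big(q(z'') - p(z'')\big)_+}. $
The first term is $\min(p(z), q(z))$; the total rejection probability $1 - \sum_{z'} \min(p(z'), q(z'))$ equals $\sum_{z''} \big(q(z'') - p(z'')\big)_+$, so the second term collapses to $\big(q(z) - p(z)\big)_+$; and $\min(p(z), q(z)) + \big(q(z) - p(z)\big)_+ = q(z)$.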
Practical notes / gotchas
- Choose $K$ to match your hardware batch limits for $q$. Too large a $K$ can hurt if acceptance falls or memory blows up (see the expected-tokens formula after this list).
- Quality of $p$ matters: closer to $q$ → higher acceptance → bigger speedups; but $p$ must be cheaper (smaller model, lower precision).
- Stopping when you reject is crucial: once a draft token is replaced, the remaining draft no longer matches the new context.
- Numerics: acceptance uses probabilities (post-softmax), not raw logits; guard the ratio $q/p$ against division by zero (e.g., floor the denominator at a small $\epsilon$) to avoid NaNs.
- KV caching: you still benefit; the batched scoring pass of $q$ reuses the prefix KV cache and only computes the $K$ new draft positions.
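On choosing $K$: if each draft token were accepted independently with the same probability $\alpha$ (an idealization; this is the form analyzed in [1]), the expected number of tokens emitted per outer loop is
$ \mathbb{E}[\text{tokens per pass}] = \frac{1 - \alpha^{K+1}}{1 - \alpha}, $
which saturates geometrically in $K$. For $\alpha = 0.8$, $K = 4$ already gives about $3.36$ tokens per pass while $K = 8$ gives about $4.33$, so doubling the draft work buys less than one extra token.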
Minimal code sketch (Python)
The helpers p_next_probs(ctx) (draft's next-token probability vector) and q_forward_batch(contexts) (target's batched next-token logits) are assumed stand-ins for real model calls; x is a list of token ids with n = len(x), and T, K are as above.

import numpy as np

rng = np.random.default_rng(0)
eps = 1e-10  # floor for the p-probability in the acceptance ratio

def softmax(logits, axis=-1):
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

while n < T:
    # 1) Draft K tokens with p, saving p's distributions for step 3
    tilde, p_dists = [], []
    ctx = x[:n]
    for t in range(K):
        p_t = p_next_probs(ctx)
        tok = int(rng.choice(len(p_t), p=p_t))
        tilde.append(tok)
        p_dists.append(p_t)
        ctx = ctx + [tok]

    # 2) One batched pass of q for the K+1 contexts
    contexts = [x[:n] + tilde[:t] for t in range(K + 1)]
    q_probs = softmax(q_forward_batch(contexts))       # [K+1, vocab]

    # 3) Verify draft tokens in order; stop at the first rejection
    accepted_all = True
    for t in range(K):
        q_t, p_t, z = q_probs[t], p_dists[t], tilde[t]
        a = min(1.0, q_t[z] / max(p_t[z], eps))        # acceptance ratio
        if rng.random() < a:
            x.append(z); n += 1                        # accept draft token
        else:
            residual = np.maximum(q_t - p_t, 0.0)      # (q - p)_+
            residual /= residual.sum()
            z_star = int(rng.choice(len(residual), p=residual))
            x.append(z_star); n += 1                   # replace, drop stale draft
            accepted_all = False
            break

    # 4) Bonus q token if no rejection
    if accepted_all:
        bonus = q_probs[K]                             # dist after all K draft tokens
        z = int(rng.choice(len(bonus), p=bonus))
        x.append(z); n += 1
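As a quick empirical sanity check on exactness, the single-step correction can be simulated against the same toy distributions used earlier; the empirical frequencies should match $q$, not $p$:

import numpy as np

rng = np.random.default_rng(0)
q_t = np.array([0.6, 0.3, 0.1])
p_t = np.array([0.3, 0.5, 0.2])
residual = np.maximum(q_t - p_t, 0.0)
residual /= residual.sum()

counts = np.zeros(3)
for _ in range(100_000):
    z = rng.choice(3, p=p_t)                       # draft proposes
    if rng.random() < min(1.0, q_t[z] / p_t[z]):   # accept with min(1, q/p)
        counts[z] += 1
    else:
        counts[rng.choice(3, p=residual)] += 1     # residual draw on reject

print(counts / counts.sum())  # ~= [0.6, 0.3, 0.1], i.e. q_t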
Recap
- Use a small $p$ to sprint $K$ steps ahead.
- Score those $K$ steps with one big batched pass of $q$.
- Accept each draft token with $\min(1, q/p)$; on first rejection, draw from the leftover $q$-mass and stop verifying.
- If everything was accepted, grab one extra $q$ token “for free.”
- Result: exact $q$-samples, but much faster when $p \approx q$.
References
- [1] Leviathan, Y., Kalman, M., & Matias, Y. (2023). Fast Inference from Transformers via Speculative Decoding.
- [2] Chen, C., et al. (2023). Accelerating Large Language Model Decoding with Speculative Sampling.
- [3] Cai, T., et al. (2024). Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads.
- [4] Zhang, J., et al. (2023). Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding.