Speculative Decoding: Exact Speedups with Draft Models

3 min read

Speed up LLM sampling using a small draft model and a batched verifier while preserving the exact target distribution.

Overview

Generating text with large language models is slow because the model must compute every token one by one. Speculative decoding [1] speeds this up by letting a small draft model sprint ahead, while a larger target model checks its work in batches. The trick is that you still get exactly the same distribution as if you had sampled directly from the big model.

Notation

  • $x_{0:n}$: your current prompt/prefix (already fixed); $n$ indexes its last token.
  • $K$: lookahead window (how many draft tokens you try at once).
  • $T$: minimum total length you want to reach.
  • $\tilde{x}_t$: a draft token proposed by $p$.
  • $q(\cdot \mid \text{context})$: target’s distribution over the next token given the context.
  • $p(\cdot \mid \text{context})$: draft’s distribution likewise.
  • $(a)_+ = \max(a, 0)$.

Algorithm (outer loop)

1) Draft phase (cheap and sequential)

Generate a short continuation of length $K$ from the draft model:

$ \tilde{x}_1, \tilde{x}_2, \dots, \tilde{x}_K \quad\text{with}\quad \tilde{x}_t \sim p(\cdot \mid x_{0:n}, \tilde{x}_{1:t-1}) $

This is standard autoregressive (AR) sampling, just with the cheap draft model $p$.
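
A minimal sketch of this draft phase, assuming a hypothetical draft-model interface p.next_token_probs(ctx) that returns the full next-token distribution as a NumPy vector:

import numpy as np

def draft_k_tokens(p, ctx, K):
    # Run the cheap draft model autoregressively for K steps, keeping each
    # step's full distribution so the verifier can reuse it later.
    tokens, dists = [], []
    for _ in range(K):
        probs = p.next_token_probs(list(ctx) + tokens)   # assumed draft-model interface
        tok = int(np.random.choice(len(probs), p=probs))
        tokens.append(tok)
        dists.append(probs)
    return tokens, dists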

2) Target scoring (expensive but batched)

In one parallel/batched forward pass of $q$, compute the next-token logits/probabilities for the $K+1$ contexts:

$ \begin{aligned} & q(\cdot \mid x_{0:n}) \\ & q(\cdot \mid x_{0:n}, \tilde{x}_1) \\ & \dots \\ & q(\cdot \mid x_{0:n}, \tilde{x}_{1:K}) \end{aligned} $

This is the key efficiency trick: you pay one “big” pass of $q$ to score all steps in the lookahead chain.
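
In practice you often don't need $K+1$ separate calls at all: a causal transformer run once over the concatenation $x_{0:n}, \tilde{x}_{1:K}$ already produces a next-token distribution at every position. A minimal sketch of that bookkeeping, assuming a hypothetical q.logits(seq) that returns per-position next-token logits of shape (len(seq), vocab):

import numpy as np

def score_with_target(q, prefix, draft):
    # One target-model forward over prefix + draft tokens.
    seq = list(prefix) + list(draft)
    logits = q.logits(seq)                      # assumed: shape (len(seq), vocab)
    # Row i predicts token i+1, so rows len(prefix)-1 ... len(seq)-1 hold
    # q(. | prefix), q(. | prefix, draft[:1]), ..., q(. | prefix, draft).
    rows = logits[len(prefix) - 1:]
    z = rows - rows.max(axis=-1, keepdims=True)
    probs = np.exp(z)
    return probs / probs.sum(axis=-1, keepdims=True)   # shape (K+1, vocab)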

3) Verification / acceptance (exactness correction)

Walk through the draft tokens in order $t=1\ldots K$:

  • Compute the acceptance ratio for the proposed token

    $ a_t = \min\!\left(1,\; \frac{q(\tilde{x}_t \mid x_{0:n}, \tilde{x}_{1:t-1})} {p(\tilde{x}_t \mid x_{0:n}, \tilde{x}_{1:t-1})} \right). $
  • Draw $r \sim \mathrm{Uniform}[0,1]$.

    • If $r < a_t$: accept. Set $x_{n+t} \leftarrow \tilde{x}_t$ and continue.

    • Else (reject): sample $x_{n+t}$ from the residual distribution

      $ \propto \big(q(\cdot \mid \text{context}) - p(\cdot \mid \text{context})\big)_+, $

      and stop the inner loop (once you deviate, the rest of the draft is stale).

Each accepted token advances $n \leftarrow n+1$.
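
A minimal sketch of this accept/reject step for a single draft token, assuming q_probs and p_probs are NumPy vectors holding the two models' next-token distributions for the same context:

import numpy as np

def verify_token(q_probs, p_probs, z, eps=1e-10):
    # Accept draft token z with probability min(1, q(z)/p(z)); otherwise
    # resample from the residual (q - p)_+ and report the rejection.
    a = min(1.0, q_probs[z] / max(p_probs[z], eps))
    if np.random.rand() < a:
        return z, True
    residual = np.maximum(q_probs - p_probs, 0.0)
    residual /= residual.sum()
    z_star = int(np.random.choice(len(residual), p=residual))
    return z_star, False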

4) Bonus token if all $K$ were accepted

If all $\tilde{x}_{1:K}$ passed, you’ve “saved” $q$-compute. Spend it to sample one extra true $q$ token:

$ x_{n+1} \sim q(\cdot \mid x_{0:n}) $ (where $n$ is now the original $n + K$)

and set $n \leftarrow n+1$. Net effect: you produced $K+1$ tokens while doing one batched pass of $q$.

Repeat until you reach length $T$.
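
As a rough check on the payoff, suppose each draft token is accepted independently with probability $\alpha$ (a simplifying assumption). The expected number of tokens emitted per outer iteration, i.e. per batched pass of $q$, is

$ \sum_{j=0}^{K-1} \alpha^{j}(1-\alpha)\,(j+1) \;+\; \alpha^{K}(K+1) \;=\; \frac{1-\alpha^{K+1}}{1-\alpha}, $

so with, say, $\alpha = 0.8$ and $K = 4$ you average about $3.4$ tokens per target pass instead of $1$.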


Why this works (intuition + guarantees)

  • Think of $p$ as a proposal and the accept/reject as a rejection-sampling / Metropolis–Hastings-style correction.

  • The acceptance ratio $\min(1, q/p)$ ensures unbiasedness: the final token stream is exactly distributed as if sampled from $q$ (a one-line derivation follows after this list).

  • The residual sampling on rejection fills in the parts where $q$ has more mass than $p$, preserving exactness.

  • Efficiency comes from:

    • $p$ is cheaper and runs sequentially for $K$ steps.
    • $q$ runs once per outer loop but scores $K+1$ contexts in parallel.
    • If $p$ is close to $q$, acceptance rates are high (often 70–95% in practice), so you usually accept most of the draft and get the “bonus” $q$ token; a 2–4× wall-clock speedup is common with a well-matched $p$.
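
The derivation behind the exactness claim is short. Writing $p(x)$ and $q(x)$ for the two next-token distributions at a given step, the probability that the verification step emits token $x$ is

$ p(x)\,\min\!\left(1, \tfrac{q(x)}{p(x)}\right) \;+\; \Big(1 - \sum_y \min\big(p(y), q(y)\big)\Big) \frac{\big(q(x) - p(x)\big)_+}{\sum_y \big(q(y) - p(y)\big)_+} \;=\; \min\big(p(x), q(x)\big) + \big(q(x) - p(x)\big)_+ \;=\; q(x), $

where the first term is the accepted-draft path, the second is the rejection-plus-residual path, and the two normalizers cancel because $\sum_y \big(q(y) - p(y)\big)_+ = 1 - \sum_y \min\big(p(y), q(y)\big)$.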

Practical notes / gotchas

  • Choose $K$ to match your hardware batch limits for $q$. Too large $K$ can hurt if acceptance falls or memory blows up.
  • Quality of $p$ matters: closer to $q$ → higher acceptance → bigger speedups; but $p$ must be cheaper (smaller model, lower precision).
  • Stopping when you reject is crucial: once a draft token is replaced, the remaining draft no longer matches the new context.
  • Numerics: do the bulk computation on logits, but the acceptance test uses (softmaxed) probabilities; clamp the denominator or work in log space so the ratio can't blow up into NaNs (see the sketch after this list).
  • KV caching: you still benefit; the batch scoring of $q$ reuses the prefix KV, then branches $K$ times.
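
One way to keep the acceptance test well behaved numerically (a sketch, not the only option) is to compute it from log-probabilities, so nothing overflows or divides by zero:

import numpy as np

def acceptance_prob(logq_z, logp_z):
    # min(1, q(z)/p(z)) = exp(min(0, log q(z) - log p(z))), always in (0, 1].
    return float(np.exp(min(0.0, logq_z - logp_z)))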

Minimal pseudocode (clear version)

# Sketch in Python/NumPy. Free variables as above: x is the list of token ids,
# n its current length, K the lookahead, T the target length.
# Assumed (hypothetical) model interfaces:
#   p.next_token_probs(ctx)   -> draft model's next-token prob vector
#   q.forward_batch(contexts) -> target logits, shape (len(contexts), vocab)
import numpy as np

eps = 1e-10

def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

while n < T:
    # 1) Draft K tokens with p (cheap, sequential); cache p's distributions for step 3
    tilde, p_dists = [], []
    ctx = list(x[:n])
    for _ in range(K):
        probs_p = p.next_token_probs(ctx)
        tok = int(np.random.choice(len(probs_p), p=probs_p))
        tilde.append(tok); p_dists.append(probs_p)
        ctx.append(tok)

    # 2) One batched pass of q for the K+1 contexts
    contexts = [x[:n] + tilde[:t] for t in range(K + 1)]
    q_dists = softmax(q.forward_batch(contexts))    # shape (K+1, vocab), computed in parallel

    # 3) Verify draft tokens in order; stop at the first rejection
    accepted_all = True
    for t in range(K):
        probs_q, probs_p = q_dists[t], p_dists[t]   # both conditioned on prefix + tilde[:t]
        z = tilde[t]
        a = min(1.0, probs_q[z] / max(probs_p[z], eps))
        if np.random.rand() < a:
            x.append(z); n += 1                     # accept the draft token
        else:
            # residual draw proportional to (q - p)_+
            residual = np.maximum(probs_q - probs_p, 0.0)
            residual /= residual.sum()
            z_star = int(np.random.choice(len(residual), p=residual))
            x.append(z_star); n += 1
            accepted_all = False
            break                                   # rest of the draft is stale

    # 4) Bonus q token if no rejection
    if accepted_all:
        bonus_probs = q_dists[K]                    # conditioned on prefix + all K accepted drafts
        z = int(np.random.choice(len(bonus_probs), p=bonus_probs))
        x.append(z); n += 1

Recap

  • Use a small $p$ to sprint $K$ steps ahead.
  • Score those $K$ steps with one big batched pass of $q$.
  • Accept each draft token with $\min(1, q/p)$; on first rejection, draw from the leftover $q$-mass and stop verifying.
  • If everything was accepted, grab one extra $q$ token “for free.”
  • Result: exact $q$-samples, but much faster when $p \approx q$.

References

  1. Leviathan, Y., Matias, Y., et al. (2023). Fast Inference from Transformers via Speculative Decoding.
  2. Chen, C., et al. (2023). Accelerating Large Language Model Decoding with Speculative Sampling.
  3. Cai, T., et al. (2024). Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads.
  4. Fu, Y., et al. (2024). Self-Speculative Decoding.