The Autoregressive Bottleneck Nobody Talks About
Generating text with large language models is painfully slow, and the reason is embarrassingly simple: they produce one token at a time. Each token requires a full forward pass through billions of parameters. For a 70B model generating 100 tokens, that’s 100 separate inference calls. The GPU sits there, massively parallel hardware reduced to sequential token generation because each prediction depends on the previous one.
I’ve spent months trying to speed up LLM inference for a production chatbot, and traditional methods like quantization and batching only get you so far. Then I stumbled onto speculative decoding—a technique that feels like cheating but actually works. The core idea: use a small, fast “draft” model to generate multiple tokens speculatively, then verify them in parallel with the large model. When it works, you get 2-3x speedup with zero quality loss. When it fails, you’ve wasted some compute.
Two recent approaches—Medusa and EAGLE—take this concept in different directions, and I’ve tested both on LLaMA 2 7B/13B variants. Here’s what I learned.
Why Speculative Decoding Works (and When It Doesn’t)
The math behind speculative decoding is deceptively simple. Standard autoregressive generation produces each token $x_t$ as:

$$x_t \sim p_{\text{target}}(x_t \mid x_1, \dots, x_{t-1})$$

Each token requires waiting for the previous one. Latency scales linearly: $O(T)$ sequential forward passes, where $T$ is the sequence length.
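To make the bottleneck concrete, here is the plain decode loop with a Hugging Face causal LM, greedy and with KV caching omitted for clarity; every generated token costs one full forward pass:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def greedy_decode(model, tokenizer, prompt, max_new_tokens=100):
    # One full forward pass per generated token: the loop speculative decoding tries to shorten.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits = model(input_ids).logits              # forward pass over the whole sequence
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)
    return tokenizer.decode(input_ids[0])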
Speculative decoding introduces a small draft model (think LLaMA 68M or 160M parameters) alongside the large target model (7B+). The draft model generates $k$ tokens speculatively:

$$\tilde{x}_{t+i} \sim p_{\text{draft}}(\cdot \mid x_1, \dots, x_t, \tilde{x}_{t+1}, \dots, \tilde{x}_{t+i-1}), \quad i = 1, \dots, k$$

Then the target model verifies all $k$ tokens in parallel by running a single forward pass on the extended sequence. For each speculative token $\tilde{x}_{t+i}$, we check:

$$u_i \le \min\!\left(1, \frac{p_{\text{target}}(\tilde{x}_{t+i})}{p_{\text{draft}}(\tilde{x}_{t+i})}\right), \quad u_i \sim \mathrm{Uniform}(0, 1)$$

If this holds for all $k$ tokens, accept them. If any check fails, reject from that point onward and resample from the target model. The acceptance rate determines speedup: 80% acceptance with $k = 4$ means you generate ~3.2 tokens per target-model call instead of 1.
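For concreteness, here is a minimal sketch of that accept/reject rule. This is the standard speculative-sampling criterion; the function name, tensor shapes, and the residual-resampling step are my framing, not code lifted from any paper or library. With greedy decoding the rule degenerates to "does the draft token equal the target's argmax", which is all the implementations later in this post actually check.

import torch

def accept_speculative(p_target, p_draft, draft_tokens):
    # p_target, p_draft: [k, vocab] probabilities at the k draft positions (target vs draft model).
    # draft_tokens: [k] tokens proposed by the draft model.
    # Returns (number of accepted tokens, corrected token or None).
    k = draft_tokens.shape[0]
    for i in range(k):
        tok = draft_tokens[i]
        ratio = p_target[i, tok] / p_draft[i, tok]
        if torch.rand(()) < ratio:                    # accept with probability min(1, p_target/p_draft)
            continue
        # Reject: resample from the normalized residual max(0, p_target - p_draft)
        residual = torch.clamp(p_target[i] - p_draft[i], min=0)
        corrected = torch.multinomial(residual / residual.sum(), 1).item()
        return i, corrected                           # i tokens accepted, plus one corrected token
    return k, None                                    # all k drafts accepted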
But here’s the catch: acceptance rate depends heavily on draft model quality and domain. On in-distribution text (similar to training data), I’ve seen 75-85% acceptance. On out-of-distribution prompts or code generation, it drops to 40-50%, and suddenly your “speedup” becomes overhead.
Medusa: Predicting Multiple Future Tokens Simultaneously
Medusa (Cai et al., 2024) takes a fundamentally different approach: instead of using a separate draft model, it adds multiple prediction heads to the target model itself. Each head predicts a different future token position.
The architecture adds $K$ lightweight heads (each a 2-layer MLP with residual connections) on top of the base model's final hidden states. Head $k$ predicts the token $k+1$ positions ahead given hidden state $h_t$:

$$p_k(x_{t+k+1} \mid x_{\le t}) = \mathrm{softmax}\big(W_k \cdot \mathrm{MLP}_k(h_t)\big)$$

Training is supervised: freeze the base LLaMA model, then train only the Medusa heads to predict future tokens. The loss is a weighted sum of cross-entropies across heads:

$$\mathcal{L}_{\text{Medusa}} = \sum_{k=1}^{K} \lambda_k \, \mathrm{CE}\big(p_k(\cdot \mid x_{\le t}),\, x_{t+k+1}\big)$$

I trained Medusa heads ($K = 4$) on LLaMA 2 7B using 100k samples from ShareGPT (conversational data). Training took ~6 hours on a single A100 40GB with batch size 16. The heads add only ~200MB to model size.
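A head and its loss are small enough to sketch. This is a simplified version of what I trained; the SiLU activation, the 0.8 decay on the per-head weights, and the exact indexing are assumptions for illustration, not guaranteed to match the reference Medusa code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MedusaHead(nn.Module):
    # One head: 2-layer MLP with a residual connection, then a projection to the vocabulary.
    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.lm = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, h):                      # h: [batch, seq, hidden_dim]
        return self.lm(h + self.mlp(h))

def medusa_loss(heads, hidden_states, labels, decay=0.8):
    # Head i (0-indexed) is trained to predict the token i+2 positions ahead
    # (the base LM head already covers the next token); weights decay with head index.
    loss = 0.0
    for i, head in enumerate(heads):
        shift = i + 2
        logits = head(hidden_states[:, :-shift, :])
        target = labels[:, shift:]
        loss = loss + (decay ** i) * F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), target.reshape(-1))
    return loss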
Decoding with Medusa is wild. At each step, all heads fire simultaneously, producing a tree of candidate sequences. Head 1 might output top-3 tokens, head 2 expands each into top-2, etc. You end up with dozens of candidate paths. The target model verifies all of them in a single batched forward pass.
Here’s simplified inference code (this is my actual implementation, minus some tensor reshaping hell):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class MedusaModel:
    def __init__(self, base_model, medusa_heads, num_heads=4):
        self.base = base_model
        self.heads = medusa_heads  # List of nn.Module heads
        self.num_heads = num_heads

    def generate_candidates(self, hidden_states, top_k=3):
        # Each head predicts its future position independently from the same hidden state
        candidates = []
        for i, head in enumerate(self.heads):
            logits = head(hidden_states)  # [batch, vocab_size]
            top_tokens = torch.topk(logits, k=top_k, dim=-1).indices
            candidates.append(top_tokens)
        # Build tree of candidate sequences (this gets complex fast):
        # a pruned Cartesian product of each head's top-k candidates
        tree = self._build_tree(candidates)  # Returns tensor [num_paths, num_heads]
        return tree

    def verify_batch(self, input_ids, candidate_tree):
        # Extend input with each candidate path, verify in parallel
        batch = torch.cat([input_ids.repeat(len(candidate_tree), 1),
                           candidate_tree], dim=1)
        with torch.no_grad():
            outputs = self.base(batch)
        # Check which paths match target model predictions
        logits = outputs.logits[:, -self.num_heads-1:-1, :]  # Logits that predict the candidate positions
        pred_tokens = logits.argmax(dim=-1)  # [num_paths, num_heads]
        # Accept the first path the target model agrees with entirely
        matches = (pred_tokens == candidate_tree).all(dim=1)
        if matches.any():
            best_path = candidate_tree[matches][0]  # Accept first full match
            return best_path, len(best_path)
        else:
            # Partial match: longest prefix of agreement (cumprod stops counting at the first mismatch)
            prefix_lens = (pred_tokens == candidate_tree).int().cumprod(dim=1).sum(dim=1)
            best_idx = prefix_lens.argmax()
            best_len = prefix_lens[best_idx].item()
            return candidate_tree[best_idx][:best_len], best_len
The tree construction (_build_tree) is the tricky part. A naïve Cartesian product explodes quickly: $3^4 = 81$ paths for 4 heads with top-3 each. Medusa uses heuristics to prune unlikely branches (e.g., drop paths where head predictions conflict with each other's logits).
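For reference, here is roughly the shape of my _build_tree. The scoring heuristic (rank paths by the sum of per-head log-probs, keep the top max_paths) is one reasonable pruning rule, not the exact scheme from the Medusa paper, and it assumes batch size 1 with the per-head tensors flattened to 1-D.

import itertools
import torch

def build_tree(candidates, head_logits, max_paths=16):
    # candidates: list of [top_k] token tensors, one per head (batch size 1 assumed).
    # head_logits: list of [vocab] logit tensors from the same forward pass, one per head.
    # Returns a [num_paths, num_heads] tensor of candidate continuations.
    log_probs = [torch.log_softmax(l, dim=-1) for l in head_logits]
    scored = []
    for path in itertools.product(*[c.tolist() for c in candidates]):
        score = sum(log_probs[i][tok].item() for i, tok in enumerate(path))
        scored.append((score, path))
    scored.sort(key=lambda x: x[0], reverse=True)     # most plausible joint paths first
    keep = [path for _, path in scored[:max_paths]]
    return torch.tensor(keep, dtype=torch.long)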
In my tests on conversational prompts (average 50 tokens), Medusa achieved:
– Acceptance rate: 68% (2.7 tokens per step on average)
– Speedup: 2.1x vs baseline autoregressive
– GPU memory: +15% for heads and candidate batch
The big win: no separate draft model to maintain. The big loss: training is dataset-specific. Heads trained on chat data performed poorly on code (acceptance dropped to 45%).
EAGLE: Learning What the Target Model Will Say Next
EAGLE (Li et al., 2024) flips the script. Instead of predicting raw future tokens, it trains a small auto-regressive draft model to predict the target model’s next hidden state, then decodes from that.
The intuition: hidden states are continuous and smoother than discrete tokens, so they're easier to predict. EAGLE's draft model is a 2-layer transformer (68M params for LLaMA 7B) that takes the previous hidden states and predicts the next one:

$$\hat{h}_{t+1} = f_{\text{draft}}(h_1, \dots, h_t)$$

Then decode the predicted hidden state to a draft token using the target model's LM head:

$$\tilde{x}_{t+1} = \arg\max_x \, \mathrm{softmax}\big(W_{\text{LM}} \, \hat{h}_{t+1}\big)_x$$
This draft token is verified against the target model’s actual prediction, just like classic speculative decoding.
Training EAGLE is surprisingly fast. I used 50k samples from ShareGPT (same as Medusa) and trained for 3 hours on one A100. The loss is simple MSE between predicted and actual hidden states:

$$\mathcal{L}_{\text{EAGLE}} = \big\lVert \hat{h}_{t+1} - h_{t+1} \big\rVert_2^2$$
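The training step is short enough to sketch. The draft network here is a stand-in, a tiny causal transformer encoder over the target's hidden-state sequence, and the layer and head counts are illustrative rather than EAGLE's exact architecture; at training time it consumes the whole sequence, while the simplified decoder below rolls it out one state at a time.

import torch
import torch.nn as nn
import torch.nn.functional as F

class DraftNet(nn.Module):
    # Tiny autoregressive predictor over target hidden states (stand-in for EAGLE's draft model).
    def __init__(self, hidden_dim, n_layers=2, n_heads=8):
        super().__init__()
        layer = nn.TransformerEncoderLayer(hidden_dim, n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, n_layers)

    def forward(self, h):                              # h: [batch, seq, hidden_dim]
        seq = h.size(1)
        causal = torch.triu(torch.full((seq, seq), float("-inf"), device=h.device), diagonal=1)
        return self.encoder(h, mask=causal)            # position t only sees h_1..h_t

def train_step(draft, target_hidden, optimizer):
    # MSE between the draft's prediction at position t and the target's hidden state at t+1.
    pred = draft(target_hidden[:, :-1, :])             # predict h_{t+1} from h_1..h_t
    loss = F.mse_loss(pred, target_hidden[:, 1:, :])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()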
Implementation is cleaner than Medusa because you’re back to standard speculative decoding—no tree search, just sequential draft tokens verified in parallel:
class EAGLEDecoder:
    def __init__(self, target_model, draft_model, k=4):
        self.target = target_model
        self.draft = draft_model
        self.k = k  # Speculation depth

    def generate_step(self, input_ids, past_hidden):
        # Draft model rolls out the next k hidden states, decoding a token from each
        draft_hidden = past_hidden[:, -1, :]  # Last target hidden state, [1, hidden_dim]
        draft_tokens = []
        for _ in range(self.k):
            draft_hidden = self.draft(draft_hidden.unsqueeze(1)).squeeze(1)  # [1, hidden_dim]
            logits = self.target.lm_head(draft_hidden)  # Decode to token with the target's LM head
            token = logits.argmax(dim=-1).item()
            draft_tokens.append(token)
        # Verify all draft tokens in parallel with one target forward pass
        candidate_ids = torch.cat(
            [input_ids, torch.tensor([draft_tokens], device=input_ids.device)], dim=1)
        with torch.no_grad():
            outputs = self.target(candidate_ids, output_hidden_states=True)
        target_tokens = outputs.logits[:, -self.k-1:-1].argmax(dim=-1).squeeze(0)
        # Acceptance length = longest prefix where draft and target agree
        matches = (target_tokens == torch.tensor(draft_tokens, device=target_tokens.device))
        accept_len = matches.int().cumprod(dim=0).sum().item()
        seq_len = input_ids.shape[1]
        if accept_len > 0:
            return draft_tokens[:accept_len], outputs.hidden_states[-1][:, :seq_len + accept_len]
        else:
            # Reject all drafts, fall back to the target's own next token
            return [target_tokens[0].item()], outputs.hidden_states[-1][:, :seq_len + 1]
EAGLE’s acceptance rate on my chatbot dataset was 72%—slightly higher than Medusa. Speedup was 2.3x, and the draft model added only 68MB (vs Medusa’s 200MB heads).
The real advantage: EAGLE generalizes better. When I switched to code generation prompts (Python function completions), acceptance rate dropped to 58% (vs Medusa’s 45%). My best guess is that hidden state prediction is less brittle than direct token prediction—the draft model learns “semantic direction” rather than exact tokens.
Where Both Approaches Hit Walls
Neither technique is a silver bullet. I hit three major problems:
1. Out-of-distribution collapse. Both Medusa and EAGLE were trained on conversational English. When I fed them LaTeX math or SQL queries, acceptance rates tanked (35-40%). The draft models confidently predict wrong continuations, and verification overhead kills any speedup.
2. Memory overhead during verification. Verifying $k$ draft tokens requires running the target model on a sequence $k$ tokens longer. For LLaMA 13B with $k = 4$, I hit OOM on a 24GB GPU at batch size 1 when the context exceeded 1800 tokens (the KV cache for the extended sequence is just too big). You can mitigate this by dynamically reducing $k$ as the context grows (see the sketch after this list), but it's annoying.
3. Highly uncertain tokens break everything. When the target model's next-token distribution is high-entropy (e.g., creative writing, open-ended questions), the draft model has no chance of guessing right. Acceptance rate correlates inversely with the target model's next-token entropy:

$$H(p_{\text{target}}) = -\sum_{x} p_{\text{target}}(x) \log p_{\text{target}}(x)$$

I measured entropy on each prompt and saw speedup vanish entirely when $H > 4.5$ nats (roughly uniform over 90 tokens, since $\ln 90 \approx 4.5$). For these cases, you're better off just disabling speculative decoding.
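Both mitigations are small helpers in my serving code; this is roughly what they look like. The 1800-token breakpoint and the 4.5-nat threshold are the empirical values from above, while the intermediate 1200-token step and the function shapes are illustrative choices, not anything from the papers.

import torch

def speculation_depth(context_len, k_max=4):
    # Shrink speculation depth as the context (and its KV cache during verification) grows.
    # 1800 is where I hit OOM on a 24GB card with LLaMA 13B; the 1200 step is an arbitrary middle ground.
    if context_len > 1800:
        return 1
    if context_len > 1200:
        return 2
    return k_max

def should_speculate(next_token_logits, max_entropy=4.5):
    # Disable speculation when the target's next-token distribution is too flat.
    # 4.5 nats is roughly uniform over ~90 tokens; logits assumed to be a 1-D [vocab] tensor.
    probs = torch.softmax(next_token_logits, dim=-1)
    entropy = -(probs * torch.log(probs.clamp_min(1e-12))).sum()
    return bool(entropy < max_entropy)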
When to Use Which (and When to Skip Both)
After running both for a month in production:
Use Medusa if:
– Your prompts are in-distribution (similar to training data)
– You can’t afford a separate draft model in memory
– You’re willing to retrain heads per domain (chat vs code vs math)
Use EAGLE if:
– You need better out-of-distribution generalization
– You have memory for a small draft model (~70MB)
– You want simpler inference code (no tree search)
Skip both if:
– Prompts are highly diverse (mixing domains)
– Target model entropy is consistently high (creative tasks)
– Context lengths exceed 2000 tokens (verification memory overhead)
– You’re already using 4-bit quantization + flash attention (diminishing returns)
One more thing: I’m curious whether these techniques stack with other speedup methods. I tried combining EAGLE with 8-bit quantization (bitsandbytes) and got 3.1x total speedup (vs 2.3x from EAGLE alone), but stability was sketchy—occasional NaN logits that I never fully debugged. The interaction between quantization rounding and speculative acceptance thresholds needs more investigation.
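If you want to reproduce that combination, the only change on the model side is loading the target in 8-bit through transformers and bitsandbytes (the EAGLE draft net stayed in fp16 for me); this follows the standard transformers quantization path, with the model id shown as an example:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Target model in 8-bit via bitsandbytes; the EAGLE draft model stays in fp16.
quant_config = BitsAndBytesConfig(load_in_8bit=True)
target = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",          # example model id
    quantization_config=quant_config,
    torch_dtype=torch.float16,
    device_map="auto",
)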
For my chatbot, I ended up using EAGLE with dynamic $k$ adjustment (reduce speculation depth as the context grows) and an entropy-based fallback (disable speculation when $H > 4.5$ nats). That gave 1.9x average speedup across all prompts without quality loss. Not the 3x promised in the papers, but enough to cut our inference bill by 40%.