Expose effective MLLM MTP draft stats #473
Conversation
janhilgard
left a comment
Clean, small observability addition. A few notes:
bypass_reasons is static documentation, not runtime telemetry
bypass_reasons is a hardcoded dict of all possible bypass conditions — it's the same regardless of which ones are actually firing. This is useful as inline documentation, but for diagnosing #471-style problems it would be far more actionable as dynamic counters:
"bypass_counts": {
"prefill": 0,
"no_active_batch": 0,
"concurrent_batch": 0,
"non_greedy_sampling": 142, # <-- this tells you WHY attempted=0
"logits_processors": 0,
}Without this, an operator sees attempted=0 and still has to read source to figure out which bypass is hitting. Not a blocker, but would make this much more useful.
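To illustrate why live counters pay off: with a payload shaped like the proposal above (the payload shape and field names are assumed for illustration, not taken from the actual PR), the dominant bypass reason falls out of a one-liner instead of a source dive:

```python
# Hypothetical /v1/status payload shaped like the bypass_counts proposal above
status = {
    "mtp": {
        "attempted": 0,
        "bypass_counts": {
            "prefill": 0,
            "no_active_batch": 0,
            "concurrent_batch": 0,
            "non_greedy_sampling": 142,
            "logits_processors": 0,
        },
    }
}

counts = status["mtp"]["bypass_counts"]
# The reason with the highest count explains attempted=0
dominant = max(counts, key=counts.get)
print(dominant)  # non_greedy_sampling
```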
Note on #471 root cause
Separate from this PR, but worth mentioning since it references #471: the benchmark in that issue uses 4-bit quantized MTP weights (scripts/add_mtp_weights_qwen35.py with default quantization → 314 MB). From production experience, quantized MTP weights give 0% acceptance rate — the quantization error is baked into the draft logits, making every draft token wrong.
The fix is the `--no-quantize` flag to extract native BF16 weights from the original HF model (e.g., Qwen/Qwen3.6-27B). See vllm-project/vllm#36331 for background. With BF16 MTP weights, acceptance rates of 78-85% are expected on Qwen3.5/3.6 models.
Minor
`server.py` `main()` also has an MTP startup log (`print(f"MTP: enabled, ...")`) at ~line 5412 — it wasn't updated to match the `cli.py` rename from `draft_tokens` to `requested_draft_tokens`. Worth keeping them consistent.
CI 9/9 green. No merge conflicts.
Thanks for the review.

- `bypass_reasons` as dynamic counters — agreed, that's more useful. I'll change it.
- `server.py` startup log consistency — will update.
- The #471 quantized MTP weights — that's a good catch. Worth commenting on #471 directly so the reporter knows BF16 MTP weights are needed for real acceptance rates. I'll add a note there.
- black lint failure —
waybarrios
left a comment
@Thump604 thanks for the ping. Please address the points below before this is ready to merge.
In vllm_mlx/server.py the /v1/status handler now looks like this:
```python
"cache": stats.get("memory_aware_cache")
or stats.get("paged_cache")
or stats.get("prefix_cache"),
"mtp": stats.get("mtp"),
"requests": stats.get("requests", []),
```

When MTP isn't active, `stats.get("mtp")` returns None and the endpoint emits `"mtp": null`. Any external client that does `response["mtp"]["enabled"]` will eat a TypeError. Other fields like cache already use the `or` pattern to collapse to something non-null; this one was left dangling. Please either return a sentinel dict like `{"enabled": False}` or omit the key with `if stats.get("mtp") is not None`.
In vllm_mlx/mllm_batch_generator.py _get_mtp_stats computes the acceptance rate by reading two fields from the shared dict without a lock:
```python
def _get_mtp_stats() -> Dict[str, Any]:
    verified = _mtp_stats["accepted"] + _mtp_stats["rejected"]
    acceptance_rate = _mtp_stats["accepted"] / verified if verified > 0 else 0.0
```

The engine thread is incrementing accepted/rejected while the HTTP handler reads. The GIL makes each `+=` atomic on its own, but reading the two fields is not an atomic snapshot, so every once in a while you'll see acceptance_rate > 1.0 or transient under-estimates. Cosmetic today, but the moment this moves to a worker in another process it stops being cosmetic. Please snapshot both reads under a lock, or pull them into local variables before the computation.
Also, bypass_reasons ended up as a static dict with hardcoded strings:
```python
"bypass_reasons": {
    "prefill": "input_tokens.shape[1] > 1",
    "no_active_batch": "active_batch is None",
    "concurrent_batch": "len(active_batch) > 1",
    "non_greedy_sampling": "temperature/top_p/top_k/min_p not greedy",
    "logits_processors": "request-local logits processors active",
},
```

That documents the guards but doesn't stay in sync with the actual code. The day someone relaxes the concurrent_batch guard (the comment inside `_mtp_step` already anticipates that's possible once a sampler-aware verifier exists), this dict will silently lie. Please turn the keys into real counters (increment them each time the bypass fires) and move the descriptions to a comment next to the guards, not in the payload.
One more thing to note (probably out of scope for this PR, fine as a follow-up): observability only covers the MLLM path. The _install_mtp for the non-MLLM path in vllm_mlx/scheduler.py keeps its own _mtp_stats that is never surfaced, so running with --enable-mtp without --mllm returns "mtp": null even though MTP is doing work. Worth opening a follow-up ticket for parity.
Pushed 7470b12 addressing the requested changes. Covered:
Validation:
Notes:
janhilgard
left a comment
Clean observability PR. The stats design is well thought out:
- `attempted` counts actual MTP forward calls (not bypassed steps) — semantics are correct: `attempted == accepted + rejected + errors`
- Bypass counts track each reason independently (a single step can trigger multiple), which is the right model for understanding frequency of each condition
- `acceptance_rate` correctly excludes errors from the denominator
- Thread-safe with `_mtp_stats_lock` for concurrent reads from `/v1/status`
- `effective_draft_tokens: 1` is honest about the current MLLM limitation
- Startup message transparently tells the operator what to expect
- Default `{"enabled": False}` keeps the `/v1/status` shape consistent when MTP is off
Tests cover all bypass categories (prefill, no_active_batch, concurrent_batch, non_greedy_sampling, logits_processors) and the happy path (attempted/accepted/acceptance_rate). The structural bypass test exercising three reasons in sequence is particularly thorough.
Two minor observations (non-blocking):
- Bypass counts can overlap — a step that hits both `no_active_batch` and `prefill` increments both counters. This is the right behavior for per-condition frequency analysis, but worth noting that `sum(bypass_counts.values())` may exceed the actual bypassed-step count. A brief docstring mention could help future readers.
- 3 commits — repo convention is a single squash commit per PR, but that's your call as maintainer.
LGTM.
Summary
This makes the MLLM MTP path observable and less misleading for #471.
The current MLLM path can inject MTP weights and print that MTP is enabled with the requested draft count, while the actual generator path is constrained to one effective draft token per verify step and bypasses MTP for prefill, concurrent batches, non-greedy sampling, and active logits processors. That makes a `--mtp-num-draft-tokens 5` run look enabled even when it cannot provide the expected speedup.

Changes:

- Rename `draft_tokens` to `requested_draft_tokens`
- `/v1/status.mtp` for MLLM with requested/effective draft count, attempts, accepted, rejected, errors, acceptance rate, and bypass reasons

This PR does not claim to make MLLM MTP faster. It makes the runtime honest enough to measure whether the path is doing useful speculative work before a follow-up performance change.
Refs #471.
Validation
```shell
python -m py_compile vllm_mlx/mllm_batch_generator.py vllm_mlx/mllm_scheduler.py vllm_mlx/engine/batched.py vllm_mlx/server.py vllm_mlx/cli.py tests/test_mllm_continuous_batching.py
git diff --check
```

I did not run the MLX-importing unit test file locally because a resident qualification model is active on this machine. CI should exercise the added test assertions.