diff --git a/tests/test_lifecycle_server.py b/tests/test_lifecycle_server.py
index 453af6a0..eb2a357e 100644
--- a/tests/test_lifecycle_server.py
+++ b/tests/test_lifecycle_server.py
@@ -501,6 +501,39 @@ async def test_health_endpoint_handles_engine_none(self, monkeypatch):
         assert result["status"] == "healthy"
         assert result["model_loaded"] is False
 
+    @pytest.mark.anyio
+    async def test_status_endpoint_returns_disabled_mtp_object_when_absent(
+        self, monkeypatch
+    ):
+        """/v1/status should keep the mtp field object-shaped when MTP is off."""
+        import vllm_mlx.server as srv
+
+        class EngineWithoutMTPStats:
+            def get_stats(self):
+                return {
+                    "running": True,
+                    "uptime_seconds": 10,
+                    "steps_executed": 1,
+                    "num_running": 0,
+                    "num_waiting": 0,
+                    "num_requests_processed": 0,
+                    "total_prompt_tokens": 0,
+                    "total_completion_tokens": 0,
+                    "metal_active_memory_gb": 0,
+                    "metal_peak_memory_gb": 0,
+                    "metal_cache_memory_gb": 0,
+                    "requests": [],
+                }
+
+        monkeypatch.setattr(srv, "_engine", EngineWithoutMTPStats())
+        monkeypatch.setattr(srv, "_model_name", "test")
+        monkeypatch.setattr(srv, "_residency_manager", None)
+        monkeypatch.setattr(srv, "_default_model_key", None)
+
+        result = await srv.status()
+
+        assert result["mtp"] == {"enabled": False}
+
 
 class TestToolParserUsesLocalEngine:
     """Tool parser should initialize from the request-local engine."""
diff --git a/tests/test_mllm_continuous_batching.py b/tests/test_mllm_continuous_batching.py
index 4a9838de..f31fbd39 100644
--- a/tests/test_mllm_continuous_batching.py
+++ b/tests/test_mllm_continuous_batching.py
@@ -1005,6 +1005,19 @@ def __init__(self):
 
         install_mtp_mllm(batch_gen, language_model, num_draft_tokens=4)
 
+        stats = batch_gen.get_mtp_stats()
+        assert stats["enabled"] is True
+        assert stats["requested_draft_tokens"] == 4
+        assert stats["effective_draft_tokens"] == 1
+        assert stats["attempted"] == 0
+        assert stats["bypass_counts"] == {
+            "prefill": 0,
+            "no_active_batch": 0,
+            "concurrent_batch": 0,
+            "non_greedy_sampling": 0,
+            "logits_processors": 0,
+        }
+
         logits_processor = MagicMock()
         tokens, logprobs = batch_gen._step(
             mx.array([[123]]),
@@ -1021,6 +1034,9 @@ def __init__(self):
         original_step.assert_called_once()
         language_model.assert_not_called()
         language_model.mtp_forward.assert_not_called()
+        stats = batch_gen.get_mtp_stats()
+        assert stats["attempted"] == 0
+        assert stats["bypass_counts"]["logits_processors"] == 1
 
     def test_install_mtp_mllm_disables_mtp_for_non_greedy_sampling(self):
         from vllm_mlx.mllm_batch_generator import install_mtp_mllm
@@ -1065,6 +1081,75 @@ def __init__(self):
         original_step.assert_called_once()
         language_model.assert_not_called()
         language_model.mtp_forward.assert_not_called()
+        stats = batch_gen.get_mtp_stats()
+        assert stats["attempted"] == 0
+        assert stats["bypass_counts"]["non_greedy_sampling"] == 1
+
+    def test_install_mtp_mllm_counts_structural_bypasses(self):
+        from vllm_mlx.mllm_batch_generator import install_mtp_mllm
+
+        expected_tokens = mx.array([11])
+        expected_logprobs = [mx.array([0.3, 0.7])]
+        original_step = MagicMock(return_value=(expected_tokens, expected_logprobs))
+
+        class FakeBatchGen:
+            def __init__(self):
+                self._step = original_step
+                self._next = MagicMock(return_value=[])
+                self.active_batch = MagicMock()
+                self.active_batch.__len__.return_value = 1
+                self.active_batch.requests = [
+                    MagicMock(
+                        temperature=0.0,
+                        top_p=1.0,
+                        top_k=0,
+                        min_p=0.0,
+                    )
+                ]
+                self.sampler = MagicMock()
+
+        batch_gen = FakeBatchGen()
+        language_model = MagicMock()
+
+        install_mtp_mllm(batch_gen, language_model, num_draft_tokens=4)
+
+        batch_gen._step(
+            mx.array([[1, 2]]),
+            cache=[],
+            logits_processors=None,
+            output_tokens=None,
+            samplers=None,
+        )
+        batch_gen.active_batch = None
+        batch_gen._step(
+            mx.array([[3]]),
+            cache=[],
+            logits_processors=None,
+            output_tokens=None,
+            samplers=None,
+        )
+        batch_gen.active_batch = MagicMock()
+        batch_gen.active_batch.__len__.return_value = 2
+        batch_gen.active_batch.requests = [
+            MagicMock(temperature=0.0, top_p=1.0, top_k=0, min_p=0.0),
+            MagicMock(temperature=0.0, top_p=1.0, top_k=0, min_p=0.0),
+        ]
+        batch_gen._step(
+            mx.array([[4]]),
+            cache=[],
+            logits_processors=None,
+            output_tokens=None,
+            samplers=None,
+        )
+
+        assert original_step.call_count == 3
+        language_model.assert_not_called()
+        language_model.mtp_forward.assert_not_called()
+        stats = batch_gen.get_mtp_stats()
+        assert stats["attempted"] == 0
+        assert stats["bypass_counts"]["prefill"] == 1
+        assert stats["bypass_counts"]["no_active_batch"] == 1
+        assert stats["bypass_counts"]["concurrent_batch"] == 1
 
     def test_install_mtp_mllm_accepted_drafts_bypass_request_sampler(self):
         from vllm_mlx.mllm_batch_generator import MLLMBatchResponse, install_mtp_mllm
@@ -1130,6 +1215,9 @@ def __call__(self, verify_input, cache=None, return_hidden=False):
         assert [r.token for r in responses] == [1, 2]
         assert request_sampler.call_count == 1
         assert batch_gen.sampler.call_count == 0
+        assert batch_gen.get_mtp_stats()["attempted"] == 1
+        assert batch_gen.get_mtp_stats()["accepted"] == 1
+        assert batch_gen.get_mtp_stats()["acceptance_rate"] == 1.0
 
     def test_next_keeps_retired_processors_by_default(self, monkeypatch):
         from vllm_mlx.mllm_batch_generator import (
diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py
index 7ce1e540..9b21c14f 100644
--- a/vllm_mlx/cli.py
+++ b/vllm_mlx/cli.py
@@ -261,7 +261,12 @@ def serve_command(args):
     if args.chunked_prefill_tokens > 0:
         print(f"Chunked prefill: {args.chunked_prefill_tokens} tokens per step")
     if args.enable_mtp:
-        print(f"MTP: enabled, draft_tokens={args.mtp_num_draft_tokens}")
+        print(f"MTP: enabled, requested_draft_tokens={args.mtp_num_draft_tokens}")
+        if args.mllm:
+            print(
+                "MTP: MLLM path currently uses effective_draft_tokens=1 "
+                "per verify step; inspect /v1/status for attempts and acceptance"
+            )
     print(f"Stream interval: {args.stream_interval} tokens")
     if args.use_paged_cache:
         print(
diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py
index 415b4dfc..3c24aafc 100644
--- a/vllm_mlx/engine/batched.py
+++ b/vllm_mlx/engine/batched.py
@@ -1100,6 +1100,7 @@ def get_stats(self) -> dict[str, Any]:
             "paged_cache",
             "prefix_cache",
             "batch_generator",
+            "mtp",
             "requests",
         ):
             if key in mllm_stats:
diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py
index 52656aac..9c03889f 100644
--- a/vllm_mlx/mllm_batch_generator.py
+++ b/vllm_mlx/mllm_batch_generator.py
@@ -1973,8 +1973,41 @@ def install_mtp_mllm(
     # Deferred drafts keyed by UID
     _deferred_drafts: Dict[int, dict] = {}
 
-    # MTP stats
-    _mtp_stats = {"accepted": 0, "rejected": 0, "errors": 0}
+    # MTP stats. These are intentionally exposed through get_mtp_stats() so
+    # /v1/status can distinguish "weights injected" from useful draft work.
+    _mtp_stats_lock = threading.Lock()
+    _mtp_stats = {"attempted": 0, "accepted": 0, "rejected": 0, "errors": 0}
+    _bypass_counts = {
+        "prefill": 0,
+        "no_active_batch": 0,
+        "concurrent_batch": 0,
+        "non_greedy_sampling": 0,
+        "logits_processors": 0,
+    }
+
+    def _get_mtp_stats() -> Dict[str, Any]:
+        with _mtp_stats_lock:
+            attempted = _mtp_stats["attempted"]
+            accepted = _mtp_stats["accepted"]
+            rejected = _mtp_stats["rejected"]
+            errors = _mtp_stats["errors"]
+            bypass_counts = dict(_bypass_counts)
+        verified = accepted + rejected
+        acceptance_rate = accepted / verified if verified > 0 else 0.0
+        return {
+            "enabled": True,
+            "requested_draft_tokens": num_draft_tokens,
+            "effective_draft_tokens": 1,
+            "mode": "always_advance_verified",
+            "attempted": attempted,
+            "accepted": accepted,
+            "rejected": rejected,
+            "errors": errors,
+            "acceptance_rate": acceptance_rate,
+            "bypass_counts": bypass_counts,
+        }
+
+    batch_gen.get_mtp_stats = _get_mtp_stats
 
     def _mtp_step(
         input_tokens: mx.array,
@@ -2006,13 +2039,40 @@
         # from the greedy draft/argmax verify path; they do not pass through the
         # request-local sampler. Non-greedy decoding needs a sampler-aware
         # verifier before this guard can be safely relaxed.
+        prefill_bypass = input_tokens.shape[1] > 1
+        no_active_batch_bypass = batch_gen.active_batch is None
+        concurrent_batch_bypass = (
+            batch_gen.active_batch is not None and len(batch_gen.active_batch) > 1
+        )
+        non_greedy_bypass = has_non_greedy_sampling
+        logits_processors_bypass = logits_processors is not None and any(
+            logits_processors
+        )
         if (
-            input_tokens.shape[1] > 1
-            or batch_gen.active_batch is None
-            or len(batch_gen.active_batch) > 1
-            or has_non_greedy_sampling
-            or (logits_processors is not None and any(logits_processors))
+            prefill_bypass
+            or no_active_batch_bypass
+            or concurrent_batch_bypass
+            or non_greedy_bypass
+            or logits_processors_bypass
         ):
+            # Keep the descriptions near the guards so operator-facing
+            # telemetry stays dynamic instead of duplicating code predicates:
+            # prefill=input_tokens.shape[1] > 1
+            # no_active_batch=active_batch is None
+            # concurrent_batch=len(active_batch) > 1
+            # non_greedy_sampling=request sampler is not greedy
+            # logits_processors=request-local processors are active
+            with _mtp_stats_lock:
+                if prefill_bypass:
+                    _bypass_counts["prefill"] += 1
+                if no_active_batch_bypass:
+                    _bypass_counts["no_active_batch"] += 1
+                if concurrent_batch_bypass:
+                    _bypass_counts["concurrent_batch"] += 1
+                if non_greedy_bypass:
+                    _bypass_counts["non_greedy_sampling"] += 1
+                if logits_processors_bypass:
+                    _bypass_counts["logits_processors"] += 1
             _skip_state[0] = None
             return _orig_step(
                 input_tokens, cache, logits_processors, output_tokens, samplers
@@ -2067,6 +2127,8 @@
 
         # MTP draft + always-advance verify
         try:
+            with _mtp_stats_lock:
+                _mtp_stats["attempted"] += 1
             draft_logits = language_model.mtp_forward(
                 hidden_states[:, -1:, :],
                 primary_tokens[:, None],
@@ -2123,7 +2185,8 @@
                         "token": draft_list[e],
                         "logprobs": verify_lp[e],
                     }
-                    _mtp_stats["accepted"] += 1
+                    with _mtp_stats_lock:
+                        _mtp_stats["accepted"] += 1
 
             else:
                 # REJECT
@@ -2181,19 +2244,22 @@
                 _skip_state[0] = None
                 for uid in current_uids:
                     _deferred_drafts.pop(uid, None)
-                _mtp_stats["rejected"] += 1
+                with _mtp_stats_lock:
+                    _mtp_stats["rejected"] += 1
 
         except Exception as e:
             logger.warning(f"[MTP-MLLM] draft/verify failed: {e}")
             _skip_state[0] = None
-            _mtp_stats["errors"] += 1
+            with _mtp_stats_lock:
+                _mtp_stats["errors"] += 1
 
         # Log MTP stats every 50 steps
-        total = _mtp_stats["accepted"] + _mtp_stats["rejected"] + _mtp_stats["errors"]
-        if total > 0 and total % 50 == 0:
+        with _mtp_stats_lock:
             acc = _mtp_stats["accepted"]
             rej = _mtp_stats["rejected"]
             err = _mtp_stats["errors"]
+            total = acc + rej + err
+        if total > 0 and total % 50 == 0:
             rate = acc / (acc + rej) * 100 if (acc + rej) > 0 else 0
             logger.info(
                 f"[MTP-MLLM] stats: accepted={acc} rejected={rej} "
diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py
index e812e339..31955037 100644
--- a/vllm_mlx/mllm_scheduler.py
+++ b/vllm_mlx/mllm_scheduler.py
@@ -1130,6 +1130,8 @@ def get_stats(self) -> Dict[str, Any]:
         # Vision embedding cache stats from batch generator
         vec_stats = self.batch_generator.get_vision_cache_stats()
         stats["vision_embedding_cache"] = vec_stats
+        if hasattr(self.batch_generator, "get_mtp_stats"):
+            stats["mtp"] = self.batch_generator.get_mtp_stats()
 
         # Include Metal memory stats
         try:
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index 1a546f84..38c6b9a5 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -3025,6 +3025,7 @@ async def status():
         "cache": stats.get("memory_aware_cache")
         or stats.get("paged_cache")
         or stats.get("prefix_cache"),
+        "mtp": stats.get("mtp") or {"enabled": False},
         "requests": stats.get("requests", []),
     }
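
The new "mtp" object in /v1/status mirrors what _get_mtp_stats() returns above. Below is a minimal polling sketch for an operator script, assuming a vllm-mlx server is reachable at http://localhost:8000 and that /v1/status answers a plain GET; the host, port, and transport details are assumptions, not fixed by this diff.

import json
import urllib.request

# Sketch only: fetch /v1/status and print the MTP telemetry added in this diff.
# The base URL is an assumption; point it at wherever `vllm-mlx serve` is running.
with urllib.request.urlopen("http://localhost:8000/v1/status") as resp:
    status = json.load(resp)

mtp = status.get("mtp") or {"enabled": False}
if mtp.get("enabled"):
    print(
        f"MTP attempted={mtp['attempted']} accepted={mtp['accepted']} "
        f"rejected={mtp['rejected']} acceptance_rate={mtp['acceptance_rate']:.2f}"
    )
    print(f"bypass_counts={mtp['bypass_counts']}")
else:
    print("MTP disabled")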