33 changes: 33 additions & 0 deletions tests/test_lifecycle_server.py
@@ -501,6 +501,39 @@ async def test_health_endpoint_handles_engine_none(self, monkeypatch):
assert result["status"] == "healthy"
assert result["model_loaded"] is False

@pytest.mark.anyio
async def test_status_endpoint_returns_disabled_mtp_object_when_absent(
self, monkeypatch
):
"""/v1/status should keep the mtp field object-shaped when MTP is off."""
import vllm_mlx.server as srv

class EngineWithoutMTPStats:
def get_stats(self):
return {
"running": True,
"uptime_seconds": 10,
"steps_executed": 1,
"num_running": 0,
"num_waiting": 0,
"num_requests_processed": 0,
"total_prompt_tokens": 0,
"total_completion_tokens": 0,
"metal_active_memory_gb": 0,
"metal_peak_memory_gb": 0,
"metal_cache_memory_gb": 0,
"requests": [],
}

monkeypatch.setattr(srv, "_engine", EngineWithoutMTPStats())
monkeypatch.setattr(srv, "_model_name", "test")
monkeypatch.setattr(srv, "_residency_manager", None)
monkeypatch.setattr(srv, "_default_model_key", None)

result = await srv.status()

assert result["mtp"] == {"enabled": False}


class TestToolParserUsesLocalEngine:
"""Tool parser should initialize from the request-local engine."""
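Note on the contract this test pins: `mtp` in the `/v1/status` payload is always a JSON object, never `null` or absent. A minimal client-side sketch of why that matters (the host/port and the `requests` dependency are assumptions, not part of this PR):

```python
import requests

# Query the status endpoint extended by this PR (URL assumed).
status = requests.get("http://localhost:8000/v1/status").json()

# Because "mtp" is always object-shaped, clients can branch on "enabled"
# without type-checking the field first.
mtp = status["mtp"]
if mtp["enabled"]:
    print(f"MTP acceptance rate: {mtp['acceptance_rate']:.2%}")
else:
    print("MTP disabled")
```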
88 changes: 88 additions & 0 deletions tests/test_mllm_continuous_batching.py
@@ -1005,6 +1005,19 @@ def __init__(self):

install_mtp_mllm(batch_gen, language_model, num_draft_tokens=4)

stats = batch_gen.get_mtp_stats()
assert stats["enabled"] is True
assert stats["requested_draft_tokens"] == 4
assert stats["effective_draft_tokens"] == 1
assert stats["attempted"] == 0
assert stats["bypass_counts"] == {
"prefill": 0,
"no_active_batch": 0,
"concurrent_batch": 0,
"non_greedy_sampling": 0,
"logits_processors": 0,
}

logits_processor = MagicMock()
tokens, logprobs = batch_gen._step(
mx.array([[123]]),
@@ -1021,6 +1034,9 @@
original_step.assert_called_once()
language_model.assert_not_called()
language_model.mtp_forward.assert_not_called()
stats = batch_gen.get_mtp_stats()
assert stats["attempted"] == 0
assert stats["bypass_counts"]["logits_processors"] == 1

def test_install_mtp_mllm_disables_mtp_for_non_greedy_sampling(self):
from vllm_mlx.mllm_batch_generator import install_mtp_mllm
@@ -1065,6 +1081,75 @@ def __init__(self):
original_step.assert_called_once()
language_model.assert_not_called()
language_model.mtp_forward.assert_not_called()
stats = batch_gen.get_mtp_stats()
assert stats["attempted"] == 0
assert stats["bypass_counts"]["non_greedy_sampling"] == 1

def test_install_mtp_mllm_counts_structural_bypasses(self):
from vllm_mlx.mllm_batch_generator import install_mtp_mllm

expected_tokens = mx.array([11])
expected_logprobs = [mx.array([0.3, 0.7])]
original_step = MagicMock(return_value=(expected_tokens, expected_logprobs))

class FakeBatchGen:
def __init__(self):
self._step = original_step
self._next = MagicMock(return_value=[])
self.active_batch = MagicMock()
self.active_batch.__len__.return_value = 1
self.active_batch.requests = [
MagicMock(
temperature=0.0,
top_p=1.0,
top_k=0,
min_p=0.0,
)
]
self.sampler = MagicMock()

batch_gen = FakeBatchGen()
language_model = MagicMock()

install_mtp_mllm(batch_gen, language_model, num_draft_tokens=4)

batch_gen._step(
mx.array([[1, 2]]),
cache=[],
logits_processors=None,
output_tokens=None,
samplers=None,
)
batch_gen.active_batch = None
batch_gen._step(
mx.array([[3]]),
cache=[],
logits_processors=None,
output_tokens=None,
samplers=None,
)
batch_gen.active_batch = MagicMock()
batch_gen.active_batch.__len__.return_value = 2
batch_gen.active_batch.requests = [
MagicMock(temperature=0.0, top_p=1.0, top_k=0, min_p=0.0),
MagicMock(temperature=0.0, top_p=1.0, top_k=0, min_p=0.0),
]
batch_gen._step(
mx.array([[4]]),
cache=[],
logits_processors=None,
output_tokens=None,
samplers=None,
)

assert original_step.call_count == 3
language_model.assert_not_called()
language_model.mtp_forward.assert_not_called()
stats = batch_gen.get_mtp_stats()
assert stats["attempted"] == 0
assert stats["bypass_counts"]["prefill"] == 1
assert stats["bypass_counts"]["no_active_batch"] == 1
assert stats["bypass_counts"]["concurrent_batch"] == 1

def test_install_mtp_mllm_accepted_drafts_bypass_request_sampler(self):
from vllm_mlx.mllm_batch_generator import MLLMBatchResponse, install_mtp_mllm
@@ -1130,6 +1215,9 @@ def __call__(self, verify_input, cache=None, return_hidden=False):
assert [r.token for r in responses] == [1, 2]
assert request_sampler.call_count == 1
assert batch_gen.sampler.call_count == 0
assert batch_gen.get_mtp_stats()["attempted"] == 1
assert batch_gen.get_mtp_stats()["accepted"] == 1
assert batch_gen.get_mtp_stats()["acceptance_rate"] == 1.0

def test_next_keeps_retired_processors_by_default(self, monkeypatch):
from vllm_mlx.mllm_batch_generator import (
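Taken together, these tests fix the full stats schema. A condensed sketch of what `get_mtp_stats()` returns right after `install_mtp_mllm(..., num_draft_tokens=4)` plus one step bypassed by logits processors (keys copied from the implementation below; values follow the test assertions):

```python
expected_stats = {
    "enabled": True,
    "requested_draft_tokens": 4,   # what the operator asked for
    "effective_draft_tokens": 1,   # what the MLLM path currently delivers
    "mode": "always_advance_verified",
    "attempted": 0,                # no draft/verify work has happened yet
    "accepted": 0,
    "rejected": 0,
    "errors": 0,
    "acceptance_rate": 0.0,        # accepted / (accepted + rejected); 0.0 when nothing verified
    "bypass_counts": {
        "prefill": 0,
        "no_active_batch": 0,
        "concurrent_batch": 0,
        "non_greedy_sampling": 0,
        "logits_processors": 1,    # the bypassed step lands here
    },
}
```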
7 changes: 6 additions & 1 deletion vllm_mlx/cli.py
@@ -261,7 +261,12 @@ def serve_command(args):
if args.chunked_prefill_tokens > 0:
print(f"Chunked prefill: {args.chunked_prefill_tokens} tokens per step")
if args.enable_mtp:
print(f"MTP: enabled, draft_tokens={args.mtp_num_draft_tokens}")
print(f"MTP: enabled, requested_draft_tokens={args.mtp_num_draft_tokens}")
if args.mllm:
print(
"MTP: MLLM path currently uses effective_draft_tokens=1 "
"per verify step; inspect /v1/status for attempts and acceptance"
)
print(f"Stream interval: {args.stream_interval} tokens")
if args.use_paged_cache:
print(
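For reference, a sketch of the startup banner this change produces on the MLLM path (the CLI flag spellings are inferred from the `args` attributes above and should be treated as assumptions):

```python
# Launching with --enable-mtp --mtp-num-draft-tokens 4 --mllm would now print:
#
#   MTP: enabled, requested_draft_tokens=4
#   MTP: MLLM path currently uses effective_draft_tokens=1 per verify step;
#   inspect /v1/status for attempts and acceptance
```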
1 change: 1 addition & 0 deletions vllm_mlx/engine/batched.py
@@ -1100,6 +1100,7 @@ def get_stats(self) -> dict[str, Any]:
"paged_cache",
"prefix_cache",
"batch_generator",
"mtp",
"requests",
):
if key in mllm_stats:
90 changes: 78 additions & 12 deletions vllm_mlx/mllm_batch_generator.py
@@ -1973,8 +1973,41 @@ def install_mtp_mllm(
# Deferred drafts keyed by UID
_deferred_drafts: Dict[int, dict] = {}

# MTP stats
_mtp_stats = {"accepted": 0, "rejected": 0, "errors": 0}
# MTP stats. These are intentionally exposed through get_mtp_stats() so
# /v1/status can distinguish "weights injected" from useful draft work.
_mtp_stats_lock = threading.Lock()
_mtp_stats = {"attempted": 0, "accepted": 0, "rejected": 0, "errors": 0}
_bypass_counts = {
"prefill": 0,
"no_active_batch": 0,
"concurrent_batch": 0,
"non_greedy_sampling": 0,
"logits_processors": 0,
}

def _get_mtp_stats() -> Dict[str, Any]:
with _mtp_stats_lock:
attempted = _mtp_stats["attempted"]
accepted = _mtp_stats["accepted"]
rejected = _mtp_stats["rejected"]
errors = _mtp_stats["errors"]
bypass_counts = dict(_bypass_counts)
verified = accepted + rejected
acceptance_rate = accepted / verified if verified > 0 else 0.0
return {
"enabled": True,
"requested_draft_tokens": num_draft_tokens,
"effective_draft_tokens": 1,
"mode": "always_advance_verified",
"attempted": attempted,
"accepted": accepted,
"rejected": rejected,
"errors": errors,
"acceptance_rate": acceptance_rate,
"bypass_counts": bypass_counts,
}

batch_gen.get_mtp_stats = _get_mtp_stats

def _mtp_step(
input_tokens: mx.array,
@@ -2006,13 +2039,40 @@ def _mtp_step(
# from the greedy draft/argmax verify path; they do not pass through the
# request-local sampler. Non-greedy decoding needs a sampler-aware
# verifier before this guard can be safely relaxed.
prefill_bypass = input_tokens.shape[1] > 1
no_active_batch_bypass = batch_gen.active_batch is None
concurrent_batch_bypass = (
batch_gen.active_batch is not None and len(batch_gen.active_batch) > 1
)
non_greedy_bypass = has_non_greedy_sampling
logits_processors_bypass = logits_processors is not None and any(
logits_processors
)
if (
input_tokens.shape[1] > 1
or batch_gen.active_batch is None
or len(batch_gen.active_batch) > 1
or has_non_greedy_sampling
or (logits_processors is not None and any(logits_processors))
prefill_bypass
or no_active_batch_bypass
or concurrent_batch_bypass
or non_greedy_bypass
or logits_processors_bypass
):
# Keep the descriptions near the guards so operator-facing
# telemetry stays dynamic instead of duplicating code predicates:
# prefill=input_tokens.shape[1] > 1
# no_active_batch=active_batch is None
# concurrent_batch=len(active_batch) > 1
# non_greedy_sampling=request sampler is not greedy
# logits_processors=request-local processors are active
with _mtp_stats_lock:
if prefill_bypass:
_bypass_counts["prefill"] += 1
if no_active_batch_bypass:
_bypass_counts["no_active_batch"] += 1
if concurrent_batch_bypass:
_bypass_counts["concurrent_batch"] += 1
if non_greedy_bypass:
_bypass_counts["non_greedy_sampling"] += 1
if logits_processors_bypass:
_bypass_counts["logits_processors"] += 1
_skip_state[0] = None
return _orig_step(
input_tokens, cache, logits_processors, output_tokens, samplers
@@ -2067,6 +2127,8 @@ def _mtp_step(

# MTP draft + always-advance verify
try:
with _mtp_stats_lock:
_mtp_stats["attempted"] += 1
draft_logits = language_model.mtp_forward(
hidden_states[:, -1:, :],
primary_tokens[:, None],
@@ -2123,7 +2185,8 @@
"token": draft_list[e],
"logprobs": verify_lp[e],
}
_mtp_stats["accepted"] += 1
with _mtp_stats_lock:
_mtp_stats["accepted"] += 1

else:
# REJECT
@@ -2181,19 +2244,22 @@
_skip_state[0] = None
for uid in current_uids:
_deferred_drafts.pop(uid, None)
_mtp_stats["rejected"] += 1
with _mtp_stats_lock:
_mtp_stats["rejected"] += 1

except Exception as e:
logger.warning(f"[MTP-MLLM] draft/verify failed: {e}")
_skip_state[0] = None
_mtp_stats["errors"] += 1
with _mtp_stats_lock:
_mtp_stats["errors"] += 1

# Log MTP stats every 50 steps
total = _mtp_stats["accepted"] + _mtp_stats["rejected"] + _mtp_stats["errors"]
if total > 0 and total % 50 == 0:
with _mtp_stats_lock:
acc = _mtp_stats["accepted"]
rej = _mtp_stats["rejected"]
err = _mtp_stats["errors"]
total = acc + rej + err
if total > 0 and total % 50 == 0:
rate = acc / (acc + rej) * 100 if (acc + rej) > 0 else 0
logger.info(
f"[MTP-MLLM] stats: accepted={acc} rejected={rej} "
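Since `get_mtp_stats()` is the only new public surface here, a hypothetical operator-side helper illustrates how the counters are meant to be read together (the helper itself is illustrative, not part of this PR):

```python
def summarize_mtp(batch_gen) -> str:
    """Illustrative triage of the stats dict exposed by install_mtp_mllm."""
    stats = batch_gen.get_mtp_stats()
    if stats["attempted"] == 0:
        # Weights were injected but every step bypassed drafting;
        # the per-guard counters explain which guard fired.
        top = max(stats["bypass_counts"], key=stats["bypass_counts"].get)
        return f"no draft work yet (top bypass: {top})"
    return (
        f"acceptance_rate={stats['acceptance_rate']:.2%} "
        f"over {stats['attempted']} attempted drafts"
    )
```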
2 changes: 2 additions & 0 deletions vllm_mlx/mllm_scheduler.py
@@ -1130,6 +1130,8 @@ def get_stats(self) -> Dict[str, Any]:
# Vision embedding cache stats from batch generator
vec_stats = self.batch_generator.get_vision_cache_stats()
stats["vision_embedding_cache"] = vec_stats
if hasattr(self.batch_generator, "get_mtp_stats"):
stats["mtp"] = self.batch_generator.get_mtp_stats()

# Include Metal memory stats
try:
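The `hasattr` guard keeps this duck-typed: only generators patched by `install_mtp_mllm` grow a `get_mtp_stats` attribute, so unpatched generators simply omit the key. A runnable sketch of that contract (`PlainGen` is hypothetical):

```python
class PlainGen:
    pass  # no get_mtp_stats attribute, like a generator without MTP installed

stats: dict = {}
gen = PlainGen()
if hasattr(gen, "get_mtp_stats"):
    stats["mtp"] = gen.get_mtp_stats()

assert "mtp" not in stats  # server.py normalizes this to {"enabled": False}
```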
1 change: 1 addition & 0 deletions vllm_mlx/server.py
@@ -3025,6 +3025,7 @@ async def status():
"cache": stats.get("memory_aware_cache")
or stats.get("paged_cache")
or stats.get("prefix_cache"),
"mtp": stats.get("mtp") or {"enabled": False},
"requests": stats.get("requests", []),
}

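The `or` fallback closes the loop: the field stays object-shaped whether or not the scheduler contributed an `mtp` entry. A quick sketch, assuming `stats` is the dict returned by the engine's `get_stats()`:

```python
# Engine without MTP installed: no "mtp" key in stats.
stats = {"requests": []}
assert (stats.get("mtp") or {"enabled": False}) == {"enabled": False}

# Engine with MTP installed: the scheduler's dict passes through untouched.
stats = {"mtp": {"enabled": True, "attempted": 12, "acceptance_rate": 0.75}}
assert (stats.get("mtp") or {"enabled": False})["enabled"] is True
```

Using `or` rather than a plain `.get(..., default)` also normalizes a falsy value such as an empty dict into the disabled shape.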