diff --git a/tests/test_mllm_mtp_routing.py b/tests/test_mllm_mtp_routing.py
index 9fa1e72f..337fc6bb 100644
--- a/tests/test_mllm_mtp_routing.py
+++ b/tests/test_mllm_mtp_routing.py
@@ -189,6 +189,81 @@ def test_has_media_content_none_content():
     from unittest.mock import MagicMock
 
 
+def test_mllm_mtp_attempt_metadata_only_marks_real_attempts():
+    from vllm_mlx.mllm_batch_generator import (
+        MLLMBatchResponse,
+        _mark_mtp_attempts_on_primary_responses,
+    )
+
+    attempted = {7: 1}
+    attempted_response = MLLMBatchResponse(
+        uid=7,
+        request_id="req-7",
+        token=11,
+        logprobs=None,
+    )
+    skipped_response = MLLMBatchResponse(
+        uid=9,
+        request_id="req-9",
+        token=21,
+        logprobs=None,
+    )
+
+    _mark_mtp_attempts_on_primary_responses(
+        [attempted_response, skipped_response],
+        attempted,
+    )
+
+    assert attempted_response.mtp_attempted is True
+    assert attempted_response.mtp_attempted_count == 1
+    assert skipped_response.mtp_attempted is False
+    assert skipped_response.mtp_attempted_count == 0
+
+
+def test_mllm_scheduler_exposes_mtp_attempts_and_accepts_on_outputs():
+    from vllm_mlx.mllm_batch_generator import MLLMBatchResponse
+    from vllm_mlx.mllm_scheduler import MLLMRequest, MLLMScheduler
+
+    class _Tokenizer:
+        clean_up_tokenization_spaces = False
+
+        def decode(self, tokens):
+            return "".join(str(token) for token in tokens)
+
+    scheduler = MLLMScheduler.__new__(MLLMScheduler)
+    scheduler.uid_to_request_id = {7: "req-7"}
+    scheduler.running = {"req-7": MLLMRequest(request_id="req-7", prompt="prompt")}
+    scheduler.processor = _Tokenizer()
+    scheduler._detokenizer_pool = {}
+    scheduler.total_completion_tokens = 0
+    scheduler.num_requests_processed = 0
+
+    outputs, finished = scheduler._process_batch_responses(
+        [
+            MLLMBatchResponse(
+                uid=7,
+                request_id="req-7",
+                token=11,
+                logprobs=None,
+                mtp_attempted=True,
+                mtp_attempted_count=1,
+            ),
+            MLLMBatchResponse(
+                uid=7,
+                request_id="req-7",
+                token=12,
+                logprobs=None,
+                finish_reason="stop",
+                from_draft=True,
+            ),
+        ]
+    )
+
+    assert finished == {"req-7"}
+    assert outputs[-1].mtp_drafts == 1
+    assert outputs[-1].mtp_accepted == 1
+
+
 def test_get_language_model():
     from vllm_mlx.models.mllm import MLXMultimodalLM
 
diff --git a/vllm_mlx/engine/base.py b/vllm_mlx/engine/base.py
index 31e5b036..d5435863 100644
--- a/vllm_mlx/engine/base.py
+++ b/vllm_mlx/engine/base.py
@@ -30,6 +30,9 @@ class GenerationOutput:
     # For streaming
     new_text: str = ""
    finished: bool = True
+    # MTP speculative decoding counters. Zero means no MTP attempt occurred.
+    mtp_drafts: int = 0
+    mtp_accepted: int = 0
 
 
 @contextmanager
diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py
index 342fa207..58919486 100644
--- a/vllm_mlx/engine/batched.py
+++ b/vllm_mlx/engine/batched.py
@@ -756,13 +756,16 @@ async def generate(
                 logits_processors=kwargs.pop("logits_processors", None),
             )
 
-            return GenerationOutput(
+            generation_output = GenerationOutput(
                 text=clean_output_text(output.output_text),
                 tokens=output.output_token_ids,
                 prompt_tokens=output.prompt_tokens,
                 completion_tokens=output.completion_tokens,
                 finish_reason=output.finish_reason,
+                mtp_drafts=output.mtp_drafts,
+                mtp_accepted=output.mtp_accepted,
             )
+            return generation_output
 
         # Use LLM engine for text-only (non-MLLM models)
         from ..request import SamplingParams
@@ -844,14 +847,17 @@ async def stream_generate(
             )
 
             async for output in self._mllm_scheduler.stream_outputs(request_id):
-                yield GenerationOutput(
+                generation_output = GenerationOutput(
                     text=clean_output_text(output.output_text),
                     new_text=output.new_text,
                     prompt_tokens=output.prompt_tokens,
                     completion_tokens=output.completion_tokens,
                     finished=output.finished,
                     finish_reason=output.finish_reason,
+                    mtp_drafts=output.mtp_drafts,
+                    mtp_accepted=output.mtp_accepted,
                 )
+                yield generation_output
             return
 
         # Use LLM engine for text-only
diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py
index 52656aac..a3c9c71a 100644
--- a/vllm_mlx/mllm_batch_generator.py
+++ b/vllm_mlx/mllm_batch_generator.py
@@ -42,6 +42,19 @@ def _processors_can_retire(processors: Optional[List[Callable]]) -> bool:
     )
 
 
+def _mark_mtp_attempts_on_primary_responses(
+    responses: List["MLLMBatchResponse"],
+    attempted_drafts_by_uid: Dict[int, int],
+) -> None:
+    """Mark only responses from steps that actually attempted MTP drafts."""
+    for response in responses:
+        draft_count = attempted_drafts_by_uid.pop(response.uid, 0)
+        if draft_count <= 0 or response.finish_reason is not None:
+            continue
+        response.mtp_attempted = True
+        response.mtp_attempted_count = draft_count
+
+
 def _drop_retired_processors(
     processors: Optional[List[Callable]],
 ) -> tuple[Optional[List[Callable]], int]:
@@ -161,6 +174,9 @@ class MLLMBatchResponse:
     logprobs: mx.array  # Log probabilities
     finish_reason: Optional[str] = None  # "stop", "length", or None
     prompt_cache: Optional[Callable[[], List[Any]]] = None  # Cache extraction function
+    from_draft: bool = False  # True when this response is an accepted MTP draft
+    mtp_attempted: bool = False  # True when the primary step attempted MTP
+    mtp_attempted_count: int = 0  # Number of draft tokens attempted
 
 
 @dataclass
@@ -1972,6 +1988,7 @@ def install_mtp_mllm(
 
     # Deferred drafts keyed by UID
     _deferred_drafts: Dict[int, dict] = {}
+    _attempted_drafts_by_uid: Dict[int, int] = {}
 
     # MTP stats
    _mtp_stats = {"accepted": 0, "rejected": 0, "errors": 0}
@@ -2077,6 +2094,8 @@ def _mtp_step(
                 draft_logits, axis=-1, keepdims=True
             )
             draft_tokens = _draft_sampler(draft_logprobs)
+            for uid in current_uids:
+                _attempted_drafts_by_uid[uid] = 1
 
             # Snapshot RNN state for hybrid models
             _rnn_snapshots = {}
@@ -2210,6 +2229,7 @@ def _mtp_next() -> List[MLLMBatchResponse]:
         if batch_gen.active_batch is None:
             _skip_state[0] = None
             _deferred_drafts.clear()
+            _attempted_drafts_by_uid.clear()
 
         # Save deferred drafts from previous step
         prev_deferred: Dict[int, dict] = {}
@@ -2220,6 +2240,9 @@ def _mtp_next() -> List[MLLMBatchResponse]:
 
         responses = batch_gen._inner_next()
 
+        if responses:
+            _mark_mtp_attempts_on_primary_responses(responses, _attempted_drafts_by_uid)
+
         if not prev_deferred or not responses:
             return responses
 
@@ -2249,6 +2272,7 @@ def _mtp_next() -> List[MLLMBatchResponse]:
                         token=draft_t,
                         logprobs=draft_lp,
                         finish_reason="stop",
+                        from_draft=True,
                     )
                 )
                 draft_end_uids.add(uid)
@@ -2272,6 +2296,7 @@ def _mtp_next() -> List[MLLMBatchResponse]:
                         token=draft_t,
                         logprobs=draft_lp,
                         finish_reason=draft_finish,
+                        from_draft=True,
                     )
                 )
 
diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py
index e812e339..0eb75564 100644
--- a/vllm_mlx/mllm_scheduler.py
+++ b/vllm_mlx/mllm_scheduler.py
@@ -115,6 +115,8 @@ class MLLMRequest:
     # Token counts
     num_prompt_tokens: int = 0
     num_output_tokens: int = 0
+    mtp_drafts: int = 0
+    mtp_accepted: int = 0
 
     # Timing
     first_token_time: Optional[float] = None
@@ -618,6 +620,10 @@ def _process_batch_responses(
             # Append token to request
             request.output_tokens.append(response.token)
             request.num_output_tokens = len(request.output_tokens)
+            if response.mtp_attempted:
+                request.mtp_drafts += response.mtp_attempted_count
+            if response.from_draft:
+                request.mtp_accepted += 1
 
             if request.first_token_time is None and request.num_output_tokens > 0:
                 request.first_token_time = time.time()
@@ -642,6 +648,8 @@ def _process_batch_responses(
                 output_token_ids=list(request.output_tokens),
                 prompt_tokens=request.num_prompt_tokens,
                 completion_tokens=request.num_output_tokens,
+                mtp_drafts=request.mtp_drafts,
+                mtp_accepted=request.mtp_accepted,
             )
 
             # Check if finished
diff --git a/vllm_mlx/request.py b/vllm_mlx/request.py
index b61a1217..ca408533 100644
--- a/vllm_mlx/request.py
+++ b/vllm_mlx/request.py
@@ -213,6 +213,9 @@ class RequestOutput:
     # Timing
     prompt_tokens: int = 0
     completion_tokens: int = 0
+    # MTP speculative decoding counters. Zero means no MTP attempt occurred.
+    mtp_drafts: int = 0
+    mtp_accepted: int = 0
 
     @property
     def usage(self) -> Dict[str, int]:
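Note (not part of the patch): a minimal sketch of how a caller might consume the new counters. Per the diff, both `GenerationOutput` (vllm_mlx/engine/base.py) and `RequestOutput` (vllm_mlx/request.py) now carry `mtp_drafts` / `mtp_accepted`, and a value of zero means no MTP attempt occurred. The `MTPCounters` protocol and `mtp_acceptance_rate` helper below are hypothetical names introduced only for this illustration.

```python
from typing import Optional, Protocol


class MTPCounters(Protocol):
    """Anything exposing the counters added in this PR.

    Both GenerationOutput and RequestOutput satisfy this after the change.
    (Hypothetical protocol, defined here only for the example.)
    """

    mtp_drafts: int
    mtp_accepted: int


def mtp_acceptance_rate(output: MTPCounters) -> Optional[float]:
    """Fraction of attempted MTP draft tokens that were accepted.

    Returns None when mtp_drafts == 0, which per the new field comments
    means no MTP attempt occurred for this request.
    """
    if output.mtp_drafts == 0:
        return None
    return output.mtp_accepted / output.mtp_drafts
```

A caller that already logs per-request usage could pass the finished output through this helper to track MTP acceptance without touching the scheduler internals.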