tests/test_mllm_mtp_routing.py (75 additions, 0 deletions)
@@ -189,6 +189,81 @@ def test_has_media_content_none_content():
from unittest.mock import MagicMock


def test_mllm_mtp_attempt_metadata_only_marks_real_attempts():
from vllm_mlx.mllm_batch_generator import (
MLLMBatchResponse,
_mark_mtp_attempts_on_primary_responses,
)

attempted = {7: 1}
attempted_response = MLLMBatchResponse(
uid=7,
request_id="req-7",
token=11,
logprobs=None,
)
skipped_response = MLLMBatchResponse(
uid=9,
request_id="req-9",
token=21,
logprobs=None,
)

_mark_mtp_attempts_on_primary_responses(
[attempted_response, skipped_response],
attempted,
)

assert attempted_response.mtp_attempted is True
assert attempted_response.mtp_attempted_count == 1
assert skipped_response.mtp_attempted is False
assert skipped_response.mtp_attempted_count == 0


def test_mllm_scheduler_exposes_mtp_attempts_and_accepts_on_outputs():
from vllm_mlx.mllm_batch_generator import MLLMBatchResponse
from vllm_mlx.mllm_scheduler import MLLMRequest, MLLMScheduler

class _Tokenizer:
clean_up_tokenization_spaces = False

def decode(self, tokens):
return "".join(str(token) for token in tokens)

scheduler = MLLMScheduler.__new__(MLLMScheduler)
scheduler.uid_to_request_id = {7: "req-7"}
scheduler.running = {"req-7": MLLMRequest(request_id="req-7", prompt="prompt")}
scheduler.processor = _Tokenizer()
scheduler._detokenizer_pool = {}
scheduler.total_completion_tokens = 0
scheduler.num_requests_processed = 0

outputs, finished = scheduler._process_batch_responses(
[
MLLMBatchResponse(
uid=7,
request_id="req-7",
token=11,
logprobs=None,
mtp_attempted=True,
mtp_attempted_count=1,
),
MLLMBatchResponse(
uid=7,
request_id="req-7",
token=12,
logprobs=None,
finish_reason="stop",
from_draft=True,
),
]
)

assert finished == {"req-7"}
assert outputs[-1].mtp_drafts == 1
assert outputs[-1].mtp_accepted == 1


def test_get_language_model():
from vllm_mlx.models.mllm import MLXMultimodalLM

vllm_mlx/engine/base.py (3 additions, 0 deletions)
@@ -30,6 +30,9 @@ class GenerationOutput:
# For streaming
new_text: str = ""
finished: bool = True
# MTP speculative decoding counters. Zero means no MTP attempt occurred.
mtp_drafts: int = 0
mtp_accepted: int = 0


@contextmanager
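The two counters default to zero, so existing callers of GenerationOutput are unaffected. A minimal sketch of how downstream code might read them (the helper below is hypothetical and not part of this change):

def mtp_acceptance_rate(output) -> float:
    # output is a GenerationOutput; zero drafts means MTP never ran for this request,
    # so report 0.0 instead of dividing by zero.
    if output.mtp_drafts == 0:
        return 0.0
    return output.mtp_accepted / output.mtp_drafts
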
vllm_mlx/engine/batched.py (8 additions, 2 deletions)
@@ -756,13 +756,16 @@ async def generate(
logits_processors=kwargs.pop("logits_processors", None),
)

return GenerationOutput(
generation_output = GenerationOutput(
text=clean_output_text(output.output_text),
tokens=output.output_token_ids,
prompt_tokens=output.prompt_tokens,
completion_tokens=output.completion_tokens,
finish_reason=output.finish_reason,
mtp_drafts=output.mtp_drafts,
mtp_accepted=output.mtp_accepted,
)
return generation_output

# Use LLM engine for text-only (non-MLLM models)
from ..request import SamplingParams
@@ -844,14 +847,17 @@ async def stream_generate(
)

async for output in self._mllm_scheduler.stream_outputs(request_id):
yield GenerationOutput(
generation_output = GenerationOutput(
text=clean_output_text(output.output_text),
new_text=output.new_text,
prompt_tokens=output.prompt_tokens,
completion_tokens=output.completion_tokens,
finished=output.finished,
finish_reason=output.finish_reason,
mtp_drafts=output.mtp_drafts,
mtp_accepted=output.mtp_accepted,
)
yield generation_output
return

# Use LLM engine for text-only
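Both the blocking and streaming MLLM paths now copy the counters from the scheduler output onto the GenerationOutput they return or yield. A sketch of a streaming consumer, assuming an engine object whose stream_generate(prompt) yields these chunks; everything except the two new attribute names is illustrative:

async def consume(engine, prompt: str) -> None:
    # Sketch only: engine construction and the stream_generate arguments are assumed.
    last = None
    async for chunk in engine.stream_generate(prompt):
        print(chunk.new_text, end="", flush=True)
        last = chunk  # counters accumulate per request, so the final chunk holds the totals
    if last is not None:
        print(f"\nMTP: {last.mtp_accepted}/{last.mtp_drafts} drafts accepted")
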
vllm_mlx/mllm_batch_generator.py (25 additions, 0 deletions)
@@ -42,6 +42,19 @@ def _processors_can_retire(processors: Optional[List[Callable]]) -> bool:
)


def _mark_mtp_attempts_on_primary_responses(
responses: List["MLLMBatchResponse"],
attempted_drafts_by_uid: Dict[int, int],
) -> None:
"""Mark only responses from steps that actually attempted MTP drafts."""
for response in responses:
draft_count = attempted_drafts_by_uid.pop(response.uid, 0)
if draft_count <= 0 or response.finish_reason is not None:
continue
response.mtp_attempted = True
response.mtp_attempted_count = draft_count


def _drop_retired_processors(
processors: Optional[List[Callable]],
) -> tuple[Optional[List[Callable]], int]:
@@ -161,6 +174,9 @@ class MLLMBatchResponse:
logprobs: mx.array # Log probabilities
finish_reason: Optional[str] = None # "stop", "length", or None
prompt_cache: Optional[Callable[[], List[Any]]] = None # Cache extraction function
from_draft: bool = False # True when this response is an accepted MTP draft
mtp_attempted: bool = False # True when the primary step attempted MTP
mtp_attempted_count: int = 0 # Number of draft tokens attempted


@dataclass
@@ -1972,6 +1988,7 @@ def install_mtp_mllm(

# Deferred drafts keyed by UID
_deferred_drafts: Dict[int, dict] = {}
_attempted_drafts_by_uid: Dict[int, int] = {}

# MTP stats
_mtp_stats = {"accepted": 0, "rejected": 0, "errors": 0}
@@ -2077,6 +2094,8 @@ def _mtp_step(
draft_logits, axis=-1, keepdims=True
)
draft_tokens = _draft_sampler(draft_logprobs)
for uid in current_uids:
_attempted_drafts_by_uid[uid] = 1

# Snapshot RNN state for hybrid models
_rnn_snapshots = {}
@@ -2210,6 +2229,7 @@ def _mtp_next() -> List[MLLMBatchResponse]:
if batch_gen.active_batch is None:
_skip_state[0] = None
_deferred_drafts.clear()
_attempted_drafts_by_uid.clear()

# Save deferred drafts from previous step
prev_deferred: Dict[int, dict] = {}
@@ -2220,6 +2240,9 @@

responses = batch_gen._inner_next()

if responses:
_mark_mtp_attempts_on_primary_responses(responses, _attempted_drafts_by_uid)

if not prev_deferred or not responses:
return responses

@@ -2249,6 +2272,7 @@ def _mtp_next() -> List[MLLMBatchResponse]:
token=draft_t,
logprobs=draft_lp,
finish_reason="stop",
from_draft=True,
)
)
draft_end_uids.add(uid)
@@ -2272,6 +2296,7 @@ def _mtp_next() -> List[MLLMBatchResponse]:
token=draft_t,
logprobs=draft_lp,
finish_reason=draft_finish,
from_draft=True,
)
)

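Taken together, the hunks above appear to split the bookkeeping into two response streams: the step that samples draft tokens records the per-UID count in _attempted_drafts_by_uid and its primary response comes back with mtp_attempted=True and mtp_attempted_count set, while an accepted draft is emitted on a later step as an extra response flagged from_draft=True. A minimal sketch using only the new helper (it mirrors the unit test rather than the full _mtp_next loop):

from vllm_mlx.mllm_batch_generator import (
    MLLMBatchResponse,
    _mark_mtp_attempts_on_primary_responses,
)

attempted = {7: 1}  # uid 7 attempted one draft token this step
primary = MLLMBatchResponse(uid=7, request_id="req-7", token=11, logprobs=None)
_mark_mtp_attempts_on_primary_responses([primary], attempted)
assert primary.mtp_attempted and primary.mtp_attempted_count == 1

# An accepted draft would arrive later as its own response:
accepted = MLLMBatchResponse(
    uid=7, request_id="req-7", token=12, logprobs=None, from_draft=True
)
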
vllm_mlx/mllm_scheduler.py (8 additions, 0 deletions)
@@ -115,6 +115,8 @@ class MLLMRequest:
# Token counts
num_prompt_tokens: int = 0
num_output_tokens: int = 0
mtp_drafts: int = 0
mtp_accepted: int = 0

# Timing
first_token_time: Optional[float] = None
@@ -618,6 +620,10 @@ def _process_batch_responses(
# Append token to request
request.output_tokens.append(response.token)
request.num_output_tokens = len(request.output_tokens)
if response.mtp_attempted:
request.mtp_drafts += response.mtp_attempted_count
if response.from_draft:
request.mtp_accepted += 1

if request.first_token_time is None and request.num_output_tokens > 0:
request.first_token_time = time.time()
@@ -642,6 +648,8 @@
output_token_ids=list(request.output_tokens),
prompt_tokens=request.num_prompt_tokens,
completion_tokens=request.num_output_tokens,
mtp_drafts=request.mtp_drafts,
mtp_accepted=request.mtp_accepted,
)

# Check if finished
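Per response, the scheduler adds mtp_attempted_count to the request's mtp_drafts and counts each from_draft response as one accepted token, then copies the running totals onto every RequestOutput it emits. A worked example with made-up numbers:

# Illustrative only: three primary steps that each attempted one draft, two drafts accepted.
responses = [
    {"mtp_attempted": True, "mtp_attempted_count": 1, "from_draft": False},
    {"mtp_attempted": False, "mtp_attempted_count": 0, "from_draft": True},  # accepted draft
    {"mtp_attempted": True, "mtp_attempted_count": 1, "from_draft": False},
    {"mtp_attempted": False, "mtp_attempted_count": 0, "from_draft": True},  # accepted draft
    {"mtp_attempted": True, "mtp_attempted_count": 1, "from_draft": False},
]
mtp_drafts = sum(r["mtp_attempted_count"] for r in responses if r["mtp_attempted"])
mtp_accepted = sum(1 for r in responses if r["from_draft"])
print(mtp_drafts, mtp_accepted)  # 3 2, i.e. RequestOutput reports mtp_drafts=3, mtp_accepted=2
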
vllm_mlx/request.py (3 additions, 0 deletions)
@@ -213,6 +213,9 @@ class RequestOutput:
# Timing
prompt_tokens: int = 0
completion_tokens: int = 0
# MTP speculative decoding counters. Zero means no MTP attempt occurred.
mtp_drafts: int = 0
mtp_accepted: int = 0

@property
def usage(self) -> Dict[str, int]:
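RequestOutput mirrors the same two counters, and the existing usage property is left untouched by this diff. If a server wanted to report the MTP statistics next to token usage, one possible shape (purely illustrative, not part of this change) would be:

def usage_with_mtp(output) -> dict:
    # output is a RequestOutput; output.usage is the unchanged prompt/completion token dict.
    stats = dict(output.usage)
    stats["mtp_drafts"] = output.mtp_drafts
    stats["mtp_accepted"] = output.mtp_accepted
    return stats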