24 changes: 24 additions & 0 deletions docs/guides/continuous-batching.md
@@ -28,6 +28,30 @@ vllm-mlx serve mlx-community/Qwen3-0.6B-8bit --continuous-batching --use-paged-c
- Better throughput for concurrent users
- Small overhead per request

### MLLM MTP and Prefill Notes

For multimodal models served through the batched MLLM scheduler, MTP
(multi-token prediction) is currently a conservative, greedy-only optimization.
The MLLM MTP verifier is used only when the active batch holds exactly one
request with `temperature=0`, `top_p=1`, `top_k=0`, `min_p=0`, and no
request-local logits processors. Requests outside that envelope fall back to
the normal scheduler path instead of using MTP. This keeps sampling correctness
ahead of throughput until the MLLM verifier is sampler-aware.

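A minimal sketch of that eligibility gate, assuming hypothetical request fields
named after the sampling parameters above (this is not the actual vllm-mlx code
path):

```python
def mllm_mtp_eligible(batch) -> bool:
    """Illustrative check: MLLM MTP only for a single greedy request."""
    if len(batch.requests) != 1:
        return False
    req = batch.requests[0]
    return (
        req.temperature == 0
        and req.top_p == 1
        and req.top_k == 0
        and req.min_p == 0
        and not getattr(req, "logits_processors", None)
    )
```
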
Thinking/logits processors stay active by default for the whole request. The
experimental retirement-to-MTP handoff is opt-in via
`VLLM_MLX_ENABLE_THINKING_RETIREMENT_RESUME=1`; leave it unset unless you have
validated that the processor advertises a safe `is_retired` transition.

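For illustration only, a thinking-style processor could advertise retirement
like this; the class is hypothetical, and only the `is_retired` name comes from
the guard above:

```python
class ThinkingLogitsProcessor:
    """Hypothetical processor that constrains output until </think> is seen."""

    def __init__(self, think_end_token_id: int):
        self.think_end_token_id = think_end_token_id
        self._done = False

    def __call__(self, token_ids, logits):
        if token_ids and token_ids[-1] == self.think_end_token_id:
            self._done = True  # processor can no longer affect future logits
        return logits

    @property
    def is_retired(self) -> bool:
        # Only once this is True is it safe to hand the request to the MTP path.
        return self._done
```
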
MLLM prefill uses the regular scheduler `prefill_step_size` unless a future
MLLM-specific override is provided. This value controls the language-model
prefill chunk size; image/video preprocessing remains per request.

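Roughly, `prefill_step_size` bounds how many prompt tokens the language model
consumes per prefill call. The loop below is a sketch with made-up names, not
the scheduler's implementation:

```python
def prefill_in_chunks(language_model, cache, prompt_tokens, prefill_step_size):
    """Feed a (possibly long) prompt through the model in fixed-size chunks."""
    for start in range(0, len(prompt_tokens), prefill_step_size):
        chunk = prompt_tokens[start : start + prefill_step_size]
        language_model(chunk, cache=cache)  # each call extends the KV cache
```
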
MLX generation streams are thread-local. The runtime rebinds mlx-lm/mlx-vlm
generation streams at worker-entry boundaries so generation does not reuse a
stream created on a different thread. This is a correctness guard for worker
ownership, not a throughput feature.

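Callers are expected to rebind at worker entry, before any tokens are
generated. A sketch follows; only `bind_generation_streams` comes from
`vllm_mlx.mlx_streams`, its exact signature may differ, and the wrapper is
illustrative:

```python
from vllm_mlx.mlx_streams import bind_generation_streams

def worker_entry(generate_fn, *args, **kwargs):
    # Rebind mlx-lm/mlx-vlm generation streams to a stream owned by the
    # current worker thread before generation starts.
    bind_generation_streams()
    return generate_fn(*args, **kwargs)
```
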
### Paged Cache
- KV cache stored in fixed-size blocks
- Shared system prompts use same blocks
22 changes: 13 additions & 9 deletions tests/test_mllm_continuous_batching.py
@@ -1022,7 +1022,18 @@ def __init__(self):
language_model.assert_not_called()
language_model.mtp_forward.assert_not_called()

def test_install_mtp_mllm_disables_mtp_for_non_greedy_sampling(self):
@pytest.mark.parametrize(
"sampling_kwargs",
[
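# Each case breaks exactly one greedy constraint (temperature, top_p, top_k, min_p).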
{"temperature": 0.6, "top_p": 1.0, "top_k": 0, "min_p": 0.0},
{"temperature": 0.0, "top_p": 0.95, "top_k": 0, "min_p": 0.0},
{"temperature": 0.0, "top_p": 1.0, "top_k": 20, "min_p": 0.0},
{"temperature": 0.0, "top_p": 1.0, "top_k": 0, "min_p": 0.05},
],
)
def test_install_mtp_mllm_disables_mtp_for_non_greedy_sampling(
self, sampling_kwargs
):
from vllm_mlx.mllm_batch_generator import install_mtp_mllm

expected_tokens = mx.array([11])
Expand All @@ -1035,14 +1046,7 @@ def __init__(self):
self._next = MagicMock(return_value=[])
self.active_batch = MagicMock()
self.active_batch.__len__.return_value = 1
self.active_batch.requests = [
MagicMock(
temperature=0.6,
top_p=0.95,
top_k=20,
min_p=0.0,
)
]
self.active_batch.requests = [MagicMock(**sampling_kwargs)]
self.sampler = MagicMock()

batch_gen = FakeBatchGen()
5 changes: 5 additions & 0 deletions vllm_mlx/mlx_streams.py
@@ -20,6 +20,11 @@ def bind_generation_streams(
MLX streams are thread-local. If a model is loaded on one thread and
generation runs on another, module-level generation streams created during
import can point at a stream that does not exist in the worker thread.

This intentionally creates a fresh stream for the current worker call and
replaces module-level generation_stream handles under a process-local lock.
It is a worker-ownership correctness fix, not a batching optimization; callers
should invoke it at worker-entry boundaries rather than inside token loops.
"""
with _STREAM_REBIND_LOCK:
default_stream = mx.new_stream(mx.default_device())