diff --git a/docs/guides/continuous-batching.md b/docs/guides/continuous-batching.md
index b4ef42cf..f0f56ef2 100644
--- a/docs/guides/continuous-batching.md
+++ b/docs/guides/continuous-batching.md
@@ -28,6 +28,30 @@ vllm-mlx serve mlx-community/Qwen3-0.6B-8bit --continuous-batching --use-paged-c
 - Better throughput for concurrent users
 - Small overhead per request

+### MLLM MTP and Prefill Notes
+
+For multimodal models served through the batched MLLM scheduler, MTP is
+currently a conservative greedy-only optimization. The MLLM MTP verifier is
+used only when the active batch has one request, `temperature=0`, `top_p=1`,
+`top_k=0`, `min_p=0`, and no request-local logits processors. Requests outside
+that envelope fall back to the normal scheduler path instead of using MTP. This
+keeps sampling correctness ahead of throughput until the MLLM verifier is
+sampler-aware.
+
+Thinking/logits processors stay active by default for the whole request. The
+experimental retirement-to-MTP handoff is opt-in via
+`VLLM_MLX_ENABLE_THINKING_RETIREMENT_RESUME=1`; leave it unset unless you have
+validated that the processor advertises a safe `is_retired` transition.
+
+MLLM prefill uses the regular scheduler `prefill_step_size` unless a future
+MLLM-specific override is provided. This value controls the language-model
+prefill chunk size; image/video preprocessing remains per request.
+
+MLX generation streams are thread-local. The runtime rebinds mlx-lm/mlx-vlm
+generation streams at worker-entry boundaries so generation does not reuse a
+stream created on a different thread. This is a correctness guard for worker
+ownership, not a throughput feature.
+
 ### Paged Cache
 - KV cache stored in fixed-size blocks
 - Shared system prompts use same blocks
diff --git a/tests/test_mllm_continuous_batching.py b/tests/test_mllm_continuous_batching.py
index 4a9838de..7ceb572c 100644
--- a/tests/test_mllm_continuous_batching.py
+++ b/tests/test_mllm_continuous_batching.py
@@ -1022,7 +1022,18 @@ def __init__(self):
         language_model.assert_not_called()
         language_model.mtp_forward.assert_not_called()

-    def test_install_mtp_mllm_disables_mtp_for_non_greedy_sampling(self):
+    @pytest.mark.parametrize(
+        "sampling_kwargs",
+        [
+            {"temperature": 0.6, "top_p": 1.0, "top_k": 0, "min_p": 0.0},
+            {"temperature": 0.0, "top_p": 0.95, "top_k": 0, "min_p": 0.0},
+            {"temperature": 0.0, "top_p": 1.0, "top_k": 20, "min_p": 0.0},
+            {"temperature": 0.0, "top_p": 1.0, "top_k": 0, "min_p": 0.05},
+        ],
+    )
+    def test_install_mtp_mllm_disables_mtp_for_non_greedy_sampling(
+        self, sampling_kwargs
+    ):
         from vllm_mlx.mllm_batch_generator import install_mtp_mllm

         expected_tokens = mx.array([11])
@@ -1035,14 +1046,7 @@ def __init__(self):
                 self._next = MagicMock(return_value=[])
                 self.active_batch = MagicMock()
                 self.active_batch.__len__.return_value = 1
-                self.active_batch.requests = [
-                    MagicMock(
-                        temperature=0.6,
-                        top_p=0.95,
-                        top_k=20,
-                        min_p=0.0,
-                    )
-                ]
+                self.active_batch.requests = [MagicMock(**sampling_kwargs)]
                 self.sampler = MagicMock()

         batch_gen = FakeBatchGen()
diff --git a/vllm_mlx/mlx_streams.py b/vllm_mlx/mlx_streams.py
index d7ac5fb3..6597faa5 100644
--- a/vllm_mlx/mlx_streams.py
+++ b/vllm_mlx/mlx_streams.py
@@ -20,6 +20,11 @@ def bind_generation_streams(
     MLX streams are thread-local. If a model is loaded on one thread and
     generation runs on another, module-level generation streams created during
     import can point at a stream that does not exist in the worker thread.
+
+    This intentionally creates a fresh stream for the current worker call and
+    replaces module-level generation_stream handles under a process-local lock.
+    It is an admission/ownership fix, not a batching optimization; callers
+    should invoke it at worker-entry boundaries rather than inside token loops.
     """
     with _STREAM_REBIND_LOCK:
         default_stream = mx.new_stream(mx.default_device())
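To make the documented greedy-only envelope concrete, here is a minimal sketch of the eligibility check the new guide text and test parametrization describe. The `Request` dataclass and `mtp_eligible` helper are illustrative stand-ins, not the project's API; the real gating lives inside `install_mtp_mllm` and may be shaped differently.

```python
from dataclasses import dataclass, field


@dataclass
class Request:
    """Illustrative request shape; mirrors the attributes the test fakes stub."""

    temperature: float = 0.0
    top_p: float = 1.0
    top_k: int = 0
    min_p: float = 0.0
    logits_processors: list = field(default_factory=list)


def mtp_eligible(batch: list[Request]) -> bool:
    """Greedy-only envelope from the docs: one request, temperature=0,
    top_p=1, top_k=0, min_p=0, and no request-local logits processors."""
    if len(batch) != 1:
        return False
    req = batch[0]
    return (
        req.temperature == 0.0
        and req.top_p == 1.0
        and req.top_k == 0
        and req.min_p == 0.0
        and not req.logits_processors
    )


# Each parametrized case in the new test flips exactly one knob away from
# greedy, so every one of them should be rejected here as well.
assert not mtp_eligible([Request(temperature=0.6)])
assert not mtp_eligible([Request(top_p=0.95)])
assert not mtp_eligible([Request(top_k=20)])
assert not mtp_eligible([Request(min_p=0.05)])
assert mtp_eligible([Request()])
```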
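The docstring added to `mlx_streams.py` describes a rebind-at-worker-entry pattern. The sketch below illustrates that pattern only; `rebind_generation_stream`, `worker_entry`, and the module-level handle are hypothetical stand-ins, and the real implementation is `vllm_mlx.mlx_streams.bind_generation_streams`, whose full signature is not shown in this hunk.

```python
import threading

import mlx.core as mx

_STREAM_REBIND_LOCK = threading.Lock()  # stand-in for the module's lock

# Stand-in for the module-level stream that mlx-lm / mlx-vlm create at import
# time, on whichever thread happened to import them first.
generation_stream = mx.new_stream(mx.default_device())


def rebind_generation_stream() -> mx.Stream:
    """Create a stream owned by the calling worker thread and swap it into the
    module-level handle under the lock. Call once at worker entry, not per token."""
    global generation_stream
    with _STREAM_REBIND_LOCK:
        generation_stream = mx.new_stream(mx.default_device())
        return generation_stream


def worker_entry(run_generation):
    # Worker-entry boundary: rebind before any generation work on this thread.
    stream = rebind_generation_stream()
    with mx.stream(stream):
        return run_generation()
```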