diff --git a/tests/test_mlx_worker_thread.py b/tests/test_mlx_worker_thread.py
new file mode 100644
index 00000000..4be1e4be
--- /dev/null
+++ b/tests/test_mlx_worker_thread.py
@@ -0,0 +1,106 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for MLXWorkerThread persistent worker thread.
+
+Verifies that model loading and inference on the same persistent thread avoids
+the ``RuntimeError: There is no Stream(gpu, N) in current thread`` crash that
+occurs when MLX >= 0.31.2 thread-local CommandEncoders are accessed from
+transient worker threads (e.g. asyncio.to_thread).
+"""
+
+import asyncio
+import threading
+
+import pytest
+
+
+def _mlx_available() -> bool:
+    try:
+        import mlx.core as mx
+
+        return mx.metal.is_available()
+    except (ImportError, AttributeError):
+        return False
+
+
+@pytest.mark.anyio
+async def test_mlx_worker_thread_runs_on_persistent_thread():
+    """All submissions execute on the same OS thread."""
+    from vllm_mlx.mlx_streams import MLXWorkerThread
+
+    worker = MLXWorkerThread(name="test-worker")
+    loop = asyncio.get_running_loop()
+
+    thread_ids = []
+    for _ in range(5):
+        tid = await worker.submit(loop, threading.get_ident)
+        thread_ids.append(tid)
+
+    assert len(set(thread_ids)) == 1, "Worker must use a single persistent thread"
+    assert thread_ids[0] != threading.get_ident(), (
+        "Worker thread must differ from event loop thread"
+    )
+    worker.shutdown()
+
+
+@pytest.mark.anyio
+async def test_mlx_worker_thread_preserves_exception():
+    """Exceptions from submitted callables propagate correctly."""
+    from vllm_mlx.mlx_streams import MLXWorkerThread
+
+    worker = MLXWorkerThread(name="test-exc")
+    loop = asyncio.get_running_loop()
+
+    def raise_value_error():
+        raise ValueError("test error")
+
+    with pytest.raises(ValueError, match="test error"):
+        await worker.submit(loop, raise_value_error)
+
+    worker.shutdown()
+
+
+@pytest.mark.anyio
+async def test_mlx_worker_thread_sequential_execution():
+    """Tasks execute sequentially (FIFO) on the single worker thread."""
+    from vllm_mlx.mlx_streams import MLXWorkerThread
+
+    worker = MLXWorkerThread(name="test-seq")
+    loop = asyncio.get_running_loop()
+
+    results = []
+
+    def append_value(val):
+        results.append(val)
+        return val
+
+    futs = [worker.submit(loop, append_value, i) for i in range(10)]
+    await asyncio.gather(*futs)
+
+    assert results == list(range(10)), "Tasks must execute in submission order"
+    worker.shutdown()
+
+
+@pytest.mark.anyio
+@pytest.mark.skipif(
+    not _mlx_available(),
+    reason="MLX not available",
+)
+async def test_mlx_ops_on_worker_thread_no_stream_error():
+    """MLX array operations on worker thread do not raise stream errors."""
+    from vllm_mlx.mlx_streams import MLXWorkerThread
+
+    worker = MLXWorkerThread(name="test-mlx")
+    loop = asyncio.get_running_loop()
+
+    def mlx_matmul():
+        import mlx.core as mx
+
+        a = mx.ones((32, 32))
+        b = mx.ones((32, 32))
+        c = a @ b
+        mx.eval(c)
+        return c[0, 0].item()
+
+    result = await worker.submit(loop, mlx_matmul)
+    assert result == 32.0
+    worker.shutdown()
diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py
index 47343892..944104b5 100644
--- a/vllm_mlx/engine/simple.py
+++ b/vllm_mlx/engine/simple.py
@@ -22,16 +22,12 @@
     BaseEngine,
     GenerationOutput,
     cleanup_startup_cancellation,
-    run_blocking_startup_work,
 )
-from ..mlx_streams import bind_generation_streams
+from ..mlx_streams import MLXWorkerThread
 
 logger = logging.getLogger(__name__)
 
-
-def _bind_worker_generation_streams() -> None:
-    """Rebind mlx generation streams inside the current worker thread."""
thread.""" - bind_generation_streams() +_mlx_worker = MLXWorkerThread() def _seed_logits_processors( @@ -224,26 +220,14 @@ def prepare_for_start(self) -> None: self._model.load() - def _uses_default_prepare_for_start(self) -> bool: - """Return True when prepare_for_start is the class implementation.""" - method = getattr(self.prepare_for_start, "__func__", None) - return method is SimpleEngine.prepare_for_start - async def start(self) -> None: """Start the engine (load model if not loaded).""" if self._loaded: return try: if self._model is None: - if self._uses_default_prepare_for_start(): - # MLX generation streams are thread-local. Keep model load on - # the event-loop thread so default LLM stream_generate() runs - # on the same thread that owns model-associated streams. - self.prepare_for_start() - else: - # Test doubles and custom overrides may block; preserve the - # cancellation-safe threaded startup helper for those cases. - await run_blocking_startup_work(self.prepare_for_start) + loop = asyncio.get_event_loop() + await _mlx_worker.submit(loop, self.prepare_for_start) self._loaded = True if self._mtp and self._mtp_num_draft_tokens != 1: @@ -348,12 +332,10 @@ async def _run_blocking_serialized(self, func, /, *args, on_cancel=None, **kwarg corrupt the command-buffer state. """ async with self._generation_lock: - - def run_bound(): - _bind_worker_generation_streams() - return func(*args, **kwargs) - - task = asyncio.create_task(asyncio.to_thread(run_bound)) + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + _mlx_worker.submit(loop, func, *args, **kwargs) + ) try: return await asyncio.shield(task) except asyncio.CancelledError: @@ -511,17 +493,8 @@ async def stream_generate( yield output return - async with self._generation_lock: - # Non-stream chat runs in a worker thread and rebinds generation - # streams there. Rebind again on the current thread before - # stream_generate so nonstream->stream mode switches remain valid. 
-            _bind_worker_generation_streams()
-
-            accumulated_text = ""
-            prompt_tokens = 0
-            completion_tokens = 0
-            finished = False
-
+        def _run_stream_generate():
+            results = []
             for chunk in self._model.stream_generate(
                 prompt=prompt,
                 max_tokens=max_tokens,
@@ -530,45 +503,56 @@
                 temperature=temperature,
                 stop=stop,
                 **kwargs,
            ):
-                prompt_tokens = (
-                    chunk.prompt_tokens
-                    if hasattr(chunk, "prompt_tokens") and chunk.prompt_tokens
-                    else prompt_tokens
-                )
-                completion_tokens += 1
-                new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
-                accumulated_text += new_text
+                results.append(chunk)
+            return results
 
-                finished = (
-                    getattr(chunk, "finished", False) or completion_tokens >= max_tokens
-                )
-                finish_reason = None
-                if finished:
-                    finish_reason = getattr(chunk, "finish_reason", "stop")
+        chunks = await self._run_blocking_serialized(_run_stream_generate)
 
-                yield GenerationOutput(
-                    text=accumulated_text,
-                    new_text=new_text,
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    finished=finished,
-                    finish_reason=finish_reason,
-                )
+        accumulated_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+        finished = False
 
-                if finished:
-                    break
+        for chunk in chunks:
+            prompt_tokens = (
+                chunk.prompt_tokens
+                if hasattr(chunk, "prompt_tokens") and chunk.prompt_tokens
+                else prompt_tokens
+            )
+            completion_tokens += 1
+            new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
+            accumulated_text += new_text
 
-        if not finished:
-            if prompt_tokens == 0:
-                prompt_tokens = len(self._model.tokenizer.encode(prompt))
-            yield GenerationOutput(
-                text=accumulated_text,
-                new_text="",
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                finished=True,
-                finish_reason=None,
-            )
+            finished = (
+                getattr(chunk, "finished", False) or completion_tokens >= max_tokens
+            )
+            finish_reason = None
+            if finished:
+                finish_reason = getattr(chunk, "finish_reason", "stop")
+
+            yield GenerationOutput(
+                text=accumulated_text,
+                new_text=new_text,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                finished=finished,
+                finish_reason=finish_reason,
+            )
+
+            if finished:
+                break
+
+        if not finished:
+            if prompt_tokens == 0:
+                prompt_tokens = len(self._model.tokenizer.encode(prompt))
+            yield GenerationOutput(
+                text=accumulated_text,
+                new_text="",
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                finished=True,
+                finish_reason=None,
+            )
 
     async def chat(
         self,
@@ -758,40 +742,45 @@ def stream_chat(
         accumulated_text = ""
         token_count = 0
 
-        # Text-only fallback when no TextModel exists: keep execution on the
-        # current thread. Routing through to_thread can break mlx_vlm stream
-        # ownership on some models (Stream(gpu, N) mismatch).
+        # Text-only fallback when no TextModel exists: route through the
+        # persistent worker thread to maintain stream ownership.
         if self._text_model is None and not has_media_content(messages):
             local_kwargs = dict(kwargs)
             if chat_template_kwargs:
                 local_kwargs["chat_template_kwargs"] = chat_template_kwargs
 
-            async with self._generation_lock:
-                _bind_worker_generation_streams()
-                for chunk in self._model.stream_chat(
-                    messages=messages,
-                    max_tokens=max_tokens,
-                    temperature=temperature,
-                    tools=template_tools,
-                    **local_kwargs,
-                ):
-                    token_count += 1
-                    new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
-                    accumulated_text += new_text
-
-                    finished = chunk.finish_reason is not None
-
-                    yield GenerationOutput(
-                        text=accumulated_text,
-                        new_text=new_text,
-                        prompt_tokens=getattr(chunk, "prompt_tokens", 0),
-                        completion_tokens=token_count,
-                        finished=finished,
-                        finish_reason=chunk.finish_reason if finished else None,
+            def _run_text_fallback_stream():
+                return list(
+                    self._model.stream_chat(
+                        messages=messages,
+                        max_tokens=max_tokens,
+                        temperature=temperature,
+                        tools=template_tools,
+                        **local_kwargs,
                     )
+                )
+
+            chunks = await self._run_blocking_serialized(
+                _run_text_fallback_stream
+            )
+            for chunk in chunks:
+                token_count += 1
+                new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
+                accumulated_text += new_text
+
+                finished = chunk.finish_reason is not None
+
+                yield GenerationOutput(
+                    text=accumulated_text,
+                    new_text=new_text,
+                    prompt_tokens=getattr(chunk, "prompt_tokens", 0),
+                    completion_tokens=token_count,
+                    finished=finished,
+                    finish_reason=chunk.finish_reason if finished else None,
+                )
 
-                    if finished:
-                        break
+                if finished:
+                    break
 
             return
 
         # Run stream_chat in thread pool since it's synchronous
diff --git a/vllm_mlx/mlx_streams.py b/vllm_mlx/mlx_streams.py
index d7ac5fb3..ead9795b 100644
--- a/vllm_mlx/mlx_streams.py
+++ b/vllm_mlx/mlx_streams.py
@@ -1,12 +1,34 @@
 # SPDX-License-Identifier: Apache-2.0
-"""Helpers for binding MLX generation streams to worker threads."""
+"""Helpers for binding MLX generation streams to worker threads.
 
+MLX >= 0.31.2 (PR #3348) makes Metal CommandEncoders thread-local. Arrays
+created on one thread carry a stream index that does not exist in other threads'
+TLS. This means a model loaded on thread A cannot be used for inference on
+thread B — the runtime raises:
+
+    RuntimeError: There is no Stream(gpu, N) in current thread
+
+The fundamental fix is to ensure model loading AND all inference happen on the
+**same** persistent thread. ``MLXWorkerThread`` provides this guarantee by
+running a dedicated daemon thread with a simple task queue.
+"""
+
+from __future__ import annotations
+
+import asyncio
 import importlib
+import logging
+import queue
 import threading
-from collections.abc import Iterable
+from collections.abc import Awaitable, Callable, Iterable
+from typing import Any, TypeVar
 
 import mlx.core as mx
 
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
 # Serialize stream rebinding so module-level generation_stream references are
 # updated atomically across concurrent engine threads.
 _STREAM_REBIND_LOCK = threading.Lock()
@@ -20,6 +42,10 @@ def bind_generation_streams(
     MLX streams are thread-local. If a model is loaded on one thread and
     generation runs on another, module-level generation streams created during
     import can point at a stream that does not exist in the worker thread.
+
+    .. deprecated::
+        Prefer ``MLXWorkerThread`` which keeps model load and inference on the
+        same persistent thread, avoiding the need to rebind streams entirely.
""" with _STREAM_REBIND_LOCK: default_stream = mx.new_stream(mx.default_device()) @@ -32,3 +58,80 @@ def bind_generation_streams( if hasattr(module, "generation_stream"): setattr(module, "generation_stream", default_stream) return default_stream + + +class MLXWorkerThread: + """Persistent single-threaded executor for MLX/Metal operations. + + All submitted callables run on the **same** OS thread for the lifetime of + the process. This guarantees that: + + 1. ``mx.new_stream()`` is called exactly once (during thread init). + 2. Model loading and inference share the same thread-local stream context. + 3. No ``RuntimeError: There is no Stream(gpu, N)`` can occur. + + Usage:: + + worker = MLXWorkerThread() + + # From an async context: + loop = asyncio.get_event_loop() + result = await worker.submit(loop, heavy_mlx_function, arg1, arg2) + + # Shutdown (optional, thread is daemon): + worker.shutdown() + """ + + def __init__(self, name: str = "mlx-worker") -> None: + self._task_queue: queue.SimpleQueue = queue.SimpleQueue() + self._thread = threading.Thread( + target=self._run, name=name, daemon=True + ) + self._thread.start() + logger.debug("MLXWorkerThread '%s' started (tid=%d)", name, self._thread.ident) + + def _run(self) -> None: + """Worker loop: pull (fn, args, kwargs, future) tuples and execute.""" + while True: + item = self._task_queue.get() + if item is None: + # Shutdown sentinel + break + fn, args, kwargs, fut = item + try: + result = fn(*args, **kwargs) + fut.get_loop().call_soon_threadsafe(fut.set_result, result) + except BaseException as exc: + fut.get_loop().call_soon_threadsafe(fut.set_exception, exc) + + def submit( + self, + loop: asyncio.AbstractEventLoop, + fn: Callable[..., T], + *args: Any, + **kwargs: Any, + ) -> Awaitable[T]: + """Submit a callable to the worker thread, returning an awaitable Future. + + Args: + loop: The asyncio event loop to create the Future on. + fn: The callable to execute on the worker thread. + *args: Positional arguments for ``fn``. + **kwargs: Keyword arguments for ``fn``. + + Returns: + An ``asyncio.Future`` that resolves with ``fn``'s return value or + raises its exception. + """ + fut = loop.create_future() + self._task_queue.put((fn, args, kwargs, fut)) + return fut + + def shutdown(self) -> None: + """Signal the worker thread to exit (best-effort, non-blocking).""" + self._task_queue.put(None) + + @property + def is_alive(self) -> bool: + """Whether the worker thread is running.""" + return self._thread.is_alive()