106 changes: 106 additions & 0 deletions tests/test_mlx_worker_thread.py
@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for MLXWorkerThread persistent worker thread.

Verifies that running model loading and inference on the same persistent thread
avoids the ``RuntimeError: There is no Stream(gpu, N) in current thread`` crash
that occurs when MLX's thread-local CommandEncoders (MLX >= 0.31.2) are accessed
from transient worker threads (e.g. via asyncio.to_thread).
"""

import asyncio
import threading

import pytest


def _mlx_available() -> bool:
try:
import mlx.core as mx

return mx.metal.is_available()
except (ImportError, AttributeError):
return False


@pytest.mark.anyio
async def test_mlx_worker_thread_runs_on_persistent_thread():
"""All submissions execute on the same OS thread."""
from vllm_mlx.mlx_streams import MLXWorkerThread

worker = MLXWorkerThread(name="test-worker")
loop = asyncio.get_event_loop()

thread_ids = []
for _ in range(5):
tid = await worker.submit(loop, threading.get_ident)
thread_ids.append(tid)

assert len(set(thread_ids)) == 1, "Worker must use a single persistent thread"
assert thread_ids[0] != threading.get_ident(), (
"Worker thread must differ from event loop thread"
)
worker.shutdown()


@pytest.mark.anyio
async def test_mlx_worker_thread_preserves_exception():
"""Exceptions from submitted callables propagate correctly."""
from vllm_mlx.mlx_streams import MLXWorkerThread

worker = MLXWorkerThread(name="test-exc")
loop = asyncio.get_event_loop()

def raise_value_error():
raise ValueError("test error")

with pytest.raises(ValueError, match="test error"):
await worker.submit(loop, raise_value_error)

worker.shutdown()


@pytest.mark.anyio
async def test_mlx_worker_thread_sequential_execution():
"""Tasks execute sequentially (FIFO) on the single worker thread."""
from vllm_mlx.mlx_streams import MLXWorkerThread

worker = MLXWorkerThread(name="test-seq")
loop = asyncio.get_event_loop()

results = []

def append_value(val):
results.append(val)
return val

futs = [worker.submit(loop, append_value, i) for i in range(10)]
await asyncio.gather(*futs)

assert results == list(range(10)), "Tasks must execute in submission order"
worker.shutdown()


@pytest.mark.anyio
@pytest.mark.skipif(
not _mlx_available(),
reason="MLX not available",
)
async def test_mlx_ops_on_worker_thread_no_stream_error():
"""MLX array operations on worker thread do not raise stream errors."""
from vllm_mlx.mlx_streams import MLXWorkerThread

worker = MLXWorkerThread(name="test-mlx")
loop = asyncio.get_event_loop()

def mlx_matmul():
import mlx.core as mx

a = mx.ones((32, 32))
b = mx.ones((32, 32))
c = a @ b
mx.eval(c)
return c[0, 0].item()

result = await worker.submit(loop, mlx_matmul)
assert result == 32.0
worker.shutdown()
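The MLXWorkerThread implementation these tests exercise lives in vllm_mlx/mlx_streams.py and is not part of this diff. As a rough, hypothetical sketch of the interface used above — submit(loop, fn, *args, **kwargs) returning something awaitable from the event loop, plus shutdown() — a queue-backed persistent thread could be built as follows; the class name, internals, and absence of future-cancellation handling are assumptions, not the real implementation:

# Hypothetical sketch only -- the real MLXWorkerThread in vllm_mlx/mlx_streams.py
# may differ. It shows one way to run every callable on a single persistent
# thread and resolve an asyncio future back on the submitting event loop.
import asyncio
import queue
import threading


class WorkerThreadSketch:
    """Execute submitted callables FIFO on one long-lived OS thread."""

    _STOP = object()

    def __init__(self, name: str = "mlx-worker") -> None:
        self._queue: queue.Queue = queue.Queue()
        self._thread = threading.Thread(target=self._run, name=name, daemon=True)
        self._thread.start()

    def _run(self) -> None:
        # All submitted work lands here, so thread-local GPU stream state is
        # created and reused on the same thread for the life of the process.
        while True:
            item = self._queue.get()
            if item is self._STOP:
                return
            loop, fut, fn, args, kwargs = item
            try:
                result = fn(*args, **kwargs)
            except BaseException as exc:
                loop.call_soon_threadsafe(fut.set_exception, exc)
            else:
                loop.call_soon_threadsafe(fut.set_result, result)

    def submit(self, loop: asyncio.AbstractEventLoop, fn, *args, **kwargs) -> asyncio.Future:
        # The future is created on (and resolved via) the caller's event loop,
        # so it can simply be awaited; cancellation of the future is not handled.
        fut = loop.create_future()
        self._queue.put((loop, fut, fn, args, kwargs))
        return fut

    def shutdown(self) -> None:
        self._queue.put(self._STOP)
        self._thread.join()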
187 changes: 88 additions & 99 deletions vllm_mlx/engine/simple.py
@@ -22,16 +22,12 @@
BaseEngine,
GenerationOutput,
cleanup_startup_cancellation,
run_blocking_startup_work,
)
from ..mlx_streams import bind_generation_streams
from ..mlx_streams import MLXWorkerThread

logger = logging.getLogger(__name__)


def _bind_worker_generation_streams() -> None:
"""Rebind mlx generation streams inside the current worker thread."""
bind_generation_streams()
_mlx_worker = MLXWorkerThread()


def _seed_logits_processors(
@@ -224,26 +220,14 @@ def prepare_for_start(self) -> None:

self._model.load()

def _uses_default_prepare_for_start(self) -> bool:
"""Return True when prepare_for_start is the class implementation."""
method = getattr(self.prepare_for_start, "__func__", None)
return method is SimpleEngine.prepare_for_start

async def start(self) -> None:
"""Start the engine (load model if not loaded)."""
if self._loaded:
return
try:
if self._model is None:
if self._uses_default_prepare_for_start():
# MLX generation streams are thread-local. Keep model load on
# the event-loop thread so default LLM stream_generate() runs
# on the same thread that owns model-associated streams.
self.prepare_for_start()
else:
# Test doubles and custom overrides may block; preserve the
# cancellation-safe threaded startup helper for those cases.
await run_blocking_startup_work(self.prepare_for_start)
loop = asyncio.get_event_loop()
await _mlx_worker.submit(loop, self.prepare_for_start)
self._loaded = True

if self._mtp and self._mtp_num_draft_tokens != 1:
@@ -348,12 +332,10 @@ async def _run_blocking_serialized(self, func, /, *args, on_cancel=None, **kwarg
corrupt the command-buffer state.
"""
async with self._generation_lock:

def run_bound():
_bind_worker_generation_streams()
return func(*args, **kwargs)

task = asyncio.create_task(asyncio.to_thread(run_bound))
loop = asyncio.get_event_loop()
task = asyncio.ensure_future(
_mlx_worker.submit(loop, func, *args, **kwargs)
)
try:
return await asyncio.shield(task)
except asyncio.CancelledError:
@@ -511,17 +493,8 @@ async def stream_generate(
yield output
return

async with self._generation_lock:
# Non-stream chat runs in a worker thread and rebinds generation
# streams there. Rebind again on the current thread before
# stream_generate so nonstream->stream mode switches remain valid.
_bind_worker_generation_streams()

accumulated_text = ""
prompt_tokens = 0
completion_tokens = 0
finished = False

def _run_stream_generate():
results = []
for chunk in self._model.stream_generate(
prompt=prompt,
max_tokens=max_tokens,
@@ -530,45 +503,56 @@ async def stream_generate(
stop=stop,
**kwargs,
):
prompt_tokens = (
chunk.prompt_tokens
if hasattr(chunk, "prompt_tokens") and chunk.prompt_tokens
else prompt_tokens
)
completion_tokens += 1
new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
accumulated_text += new_text
results.append(chunk)
return results

finished = (
getattr(chunk, "finished", False) or completion_tokens >= max_tokens
)
finish_reason = None
if finished:
finish_reason = getattr(chunk, "finish_reason", "stop")
chunks = await self._run_blocking_serialized(_run_stream_generate)

yield GenerationOutput(
text=accumulated_text,
new_text=new_text,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
finished=finished,
finish_reason=finish_reason,
)
accumulated_text = ""
prompt_tokens = 0
completion_tokens = 0
finished = False

if finished:
break
for chunk in chunks:
prompt_tokens = (
chunk.prompt_tokens
if hasattr(chunk, "prompt_tokens") and chunk.prompt_tokens
else prompt_tokens
)
completion_tokens += 1
new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
accumulated_text += new_text

if not finished:
if prompt_tokens == 0:
prompt_tokens = len(self._model.tokenizer.encode(prompt))
yield GenerationOutput(
text=accumulated_text,
new_text="",
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
finished=True,
finish_reason=None,
)
finished = (
getattr(chunk, "finished", False) or completion_tokens >= max_tokens
)
finish_reason = None
if finished:
finish_reason = getattr(chunk, "finish_reason", "stop")

yield GenerationOutput(
text=accumulated_text,
new_text=new_text,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
finished=finished,
finish_reason=finish_reason,
)

if finished:
break

if not finished:
if prompt_tokens == 0:
prompt_tokens = len(self._model.tokenizer.encode(prompt))
yield GenerationOutput(
text=accumulated_text,
new_text="",
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
finished=True,
finish_reason=None,
)

async def chat(
self,
@@ -758,40 +742,45 @@ async def stream_chat(
accumulated_text = ""
token_count = 0

# Text-only fallback when no TextModel exists: keep execution on the
# current thread. Routing through to_thread can break mlx_vlm stream
# ownership on some models (Stream(gpu, N) mismatch).
# Text-only fallback when no TextModel exists: route through the
# persistent worker thread to maintain stream ownership.
if self._text_model is None and not has_media_content(messages):
local_kwargs = dict(kwargs)
if chat_template_kwargs:
local_kwargs["chat_template_kwargs"] = chat_template_kwargs

async with self._generation_lock:
_bind_worker_generation_streams()
for chunk in self._model.stream_chat(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
tools=template_tools,
**local_kwargs,
):
token_count += 1
new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
accumulated_text += new_text

finished = chunk.finish_reason is not None

yield GenerationOutput(
text=accumulated_text,
new_text=new_text,
prompt_tokens=getattr(chunk, "prompt_tokens", 0),
completion_tokens=token_count,
finished=finished,
finish_reason=chunk.finish_reason if finished else None,
def _run_text_fallback_stream():
return list(
self._model.stream_chat(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
tools=template_tools,
**local_kwargs,
)
)

chunks = await self._run_blocking_serialized(
_run_text_fallback_stream
)
for chunk in chunks:
token_count += 1
new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
accumulated_text += new_text

finished = chunk.finish_reason is not None

yield GenerationOutput(
text=accumulated_text,
new_text=new_text,
prompt_tokens=getattr(chunk, "prompt_tokens", 0),
completion_tokens=token_count,
finished=finished,
finish_reason=chunk.finish_reason if finished else None,
)

if finished:
break
if finished:
break
return

# Run stream_chat in thread pool since it's synchronous
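One detail of the _run_blocking_serialized change above is worth spelling out: the worker submission is wrapped in asyncio.ensure_future and guarded with asyncio.shield, so cancelling the awaiting request does not interrupt the MLX call already executing on the worker thread. The snippet below is a simplified, hypothetical sketch of that pattern only — it drops the generation lock and the on_cancel hook, and because the body of the except branch is cut off in this diff, the wait-then-reraise shown here is an assumption rather than the actual code.

# Simplified sketch of the shield pattern in _run_blocking_serialized; the
# function name and the cancellation handling below are illustrative assumptions.
import asyncio


async def run_serialized(worker, loop, func, *args, **kwargs):
    # submit() returns a future resolved on the event loop; ensure_future gives
    # a handle that outlives cancellation of this coroutine.
    task = asyncio.ensure_future(worker.submit(loop, func, *args, **kwargs))
    try:
        # shield() keeps the submitted work running even if the caller is
        # cancelled, so an aborted request cannot leave MLX mid-command-buffer.
        return await asyncio.shield(task)
    except asyncio.CancelledError:
        # Assumed behaviour (the except body is not shown in the diff): wait for
        # the in-flight call to finish, discard its outcome, then propagate.
        try:
            await task
        except Exception:
            pass
        raise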