57 changes: 46 additions & 11 deletions vllm_mlx/engine/batched.py
@@ -27,6 +27,7 @@
     cleanup_startup_cancellation,
     run_blocking_startup_work,
 )
+from ..mlx_streams import MLXWorkerThread

 logger = logging.getLogger(__name__)

@@ -216,6 +217,7 @@ def __init__(
         self._tokenizer = None  # For LLM
         self._engine = None  # AsyncEngineCore for LLM
         self._mllm_scheduler = None  # MLLMScheduler for MLLM
+        self._mlx_worker = None  # Shared MLXWorkerThread for MLLM load + inference
         self._mllm_instance = None  # MLXMultimodalLM instance
         self._loaded = False

@@ -237,12 +239,18 @@ def tokenizer(self) -> Any:
         return self._tokenizer

     def prepare_for_start(self) -> None:
-        """Load heavyweight model state off the serving event loop."""
+        """Load heavyweight model state off the serving event loop.
+
+        For MLLM models this is a no-op: the async _prepare_mllm_model()
+        requires the MLXWorkerThread and event loop, so model loading is
+        deferred to _start_mllm(), which is properly awaited.
+        """
         if self._model is not None:
             return

         if self._is_mllm:
-            self._prepare_mllm_model()
+            # MLLM loading is async and handled by _start_mllm().
             return
         else:
             self._prepare_llm_model()

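To make the split concrete, here is a condensed sketch of the two startup paths after this change. It paraphrases the hunks above and below for readability; it is not additional code from the PR, and the surrounding plumbing that calls these methods is assumed.

# Sketch: how the sync and async startup phases divide the work.
def prepare_for_start(self) -> None:          # sync, may run off the event loop
    if self._is_mllm:
        return                                # defer: load needs loop + worker
    self._prepare_llm_model()                 # LLM path still loads eagerly

async def _start_mllm(self) -> None:          # async, runs on the event loop
    if self._model is None or self._processor is None:
        await self._prepare_mllm_model()      # loads on the MLXWorkerThread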
@@ -282,17 +290,43 @@ def _uses_default_prepare_for_start(self) -> bool:
         method = getattr(self.prepare_for_start, "__func__", None)
         return method is BatchedEngine.prepare_for_start

-    def _prepare_mllm_model(self) -> None:
-        """Load the MLLM model before scheduler startup."""
+    async def _prepare_mllm_model(self) -> None:
+        """Load the MLLM model on a dedicated MLXWorkerThread.
+
+        MLX >= 0.31.2 makes Metal command encoders thread-local. Model
+        loading creates arrays and streams that are bound to the thread
+        that executes them. If the model is loaded on the event-loop
+        thread but inference runs on a different thread, ``mx.eval()``
+        raises ``RuntimeError: There is no Stream(gpu, N) in current
+        thread``.
+
+        By loading the model on the same ``MLXWorkerThread`` that will
+        later run ``step()`` calls, we guarantee all MLX state lives on
+        a single thread.
+        """
+        import mlx.core as mx
+
         from ..models.mllm import MLXMultimodalLM

+        self._mlx_worker = MLXWorkerThread(name="mllm-worker")
+        loop = asyncio.get_event_loop()
+
         max_kv_size = getattr(self._scheduler_config, "max_kv_size", 0)
-        self._mllm_instance = MLXMultimodalLM(
-            self._model_name,
-            trust_remote_code=self._trust_remote_code,
-            max_kv_size=max_kv_size,
-        )
-        self._mllm_instance.load()
+        model_name = self._model_name
+        trust_remote_code = self._trust_remote_code
+
+        def _load_on_worker():
+            mx.new_thread_local_stream(mx.gpu)
+            inst = MLXMultimodalLM(
+                model_name,
+                trust_remote_code=trust_remote_code,
+                max_kv_size=max_kv_size,
+            )
+            inst.load()
+            logger.info("MLLM model loaded on MLXWorkerThread")
+            return inst
+
+        self._mllm_instance = await self._mlx_worker.submit(loop, _load_on_worker)
         self._model = self._mllm_instance.model
         self._processor = self._mllm_instance.processor

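The call `self._mlx_worker.submit(loop, _load_on_worker)` relies on `MLXWorkerThread` from `..mlx_streams`, which is not shown in this diff. Below is a minimal sketch of the pattern such a worker implements: one dedicated thread that consumes callables from a queue and resolves asyncio futures back on the event loop. The names and structure here are assumptions for illustration, not the module's actual implementation.

import asyncio
import queue
import threading

class WorkerThreadSketch:
    """Illustrative stand-in for MLXWorkerThread; not the real class."""

    def __init__(self, name: str = "mlx-worker") -> None:
        self._jobs: queue.Queue = queue.Queue()
        self._thread = threading.Thread(target=self._run, name=name, daemon=True)
        self._thread.start()

    def _run(self) -> None:
        # Every submitted callable executes here, on one fixed thread,
        # so all MLX arrays and streams it creates share that thread.
        while True:
            fn, loop, fut = self._jobs.get()
            try:
                result = fn()
                loop.call_soon_threadsafe(fut.set_result, result)
            except Exception as exc:
                loop.call_soon_threadsafe(fut.set_exception, exc)

    def submit(self, loop: asyncio.AbstractEventLoop, fn):
        # Called from the event loop; returns a future the caller awaits.
        fut = loop.create_future()
        self._jobs.put((fn, loop, fut))
        return fut

Because `_load_on_worker` runs inside the worker's loop, the `mx.new_thread_local_stream(mx.gpu)` call and every array the load creates are bound to that single thread, which is exactly the invariant the docstring above describes.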
@@ -329,7 +363,7 @@ async def _start_mllm(self) -> None:
         from ..mllm_scheduler import MLLMScheduler, MLLMSchedulerConfig

         if self._model is None or self._processor is None:
-            self._prepare_mllm_model()
+            await self._prepare_mllm_model()

         # Create MLLM scheduler config with batch generator support
         if self._scheduler_config and hasattr(self._scheduler_config, "max_num_seqs"):
@@ -403,6 +437,7 @@ async def _start_mllm(self) -> None:
             model=self._model,
             processor=self._processor,
             config=mllm_config,
+            mlx_worker=self._mlx_worker,
         )
         await self._mllm_scheduler.start()

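Passing the shared `mlx_worker` into `MLLMScheduler` closes the loop: the thread that loaded the model is also the thread that runs inference. The scheduler's internals are outside this diff, but the intended usage is presumably along these lines, where `_generate_step` and `_dispatch` are invented names for illustration only.

# Hypothetical scheduler step loop routed through the shared worker.
async def _run_loop(self) -> None:
    loop = asyncio.get_running_loop()
    while self._running:
        # One batched decode step executes on the same MLXWorkerThread
        # that performed the model load in _prepare_mllm_model().
        outputs = await self._mlx_worker.submit(loop, self._generate_step)
        self._dispatch(outputs)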