fix(simple): use persistent MLX worker thread to fix thread-local stream crash #478
xykong wants to merge 1 commit into waybarrios:main
Conversation
fix(simple): use persistent MLX worker thread to fix thread-local stream crash

MLX >= 0.31.2 (PR #3348) made CommandEncoders thread-local. The previous approach of spawning transient threads via asyncio.to_thread() and calling bind_generation_streams() fails because:

1. Model loading creates Metal compiled kernels with a stream index (N) bound to the loading thread's TLS
2. Transient worker threads create new stream indices via mx.new_stream() which do NOT include the model's original stream N
3. RuntimeError: There is no Stream(gpu, N) in current thread

Fix: Replace asyncio.to_thread with a persistent MLXWorkerThread that:
- Runs a single daemon thread for the lifetime of the process
- Loads the model AND runs all inference on the same thread
- Guarantees thread-local stream consistency

This matches the approach used by:
- oMLX (jundot/omlx PR #891)
- MLX maintainer recommendations (ml-explore/mlx #3216)

Tested on Apple M4 Max with MLX 0.31.2, Gemma 4 26B.
janhilgard
left a comment
Good problem analysis and clean MLXWorkerThread implementation. The root cause (MLX 0.31.2 thread-local CommandEncoders) is real, and a persistent worker thread is the right approach.
However, this PR introduces a critical streaming regression and has a few additional concerns:
1. Critical: stream_generate and stream_chat text-only fallback lost true streaming
The old code yielded chunks as they were generated (true token-by-token streaming):
```python
async with self._generation_lock:
    for chunk in self._model.stream_generate(...):
        yield GenerationOutput(...)  # client sees each token immediately
```

The new code collects ALL chunks into a list first, then yields them after generation is complete:
```python
def _run_stream_generate():
    results = []
    for chunk in self._model.stream_generate(...):
        results.append(chunk)
    return results

chunks = await self._run_blocking_serialized(_run_stream_generate)
for chunk in chunks:
    yield GenerationOutput(...)  # client sees everything only after full generation
```

This means TTFT = total generation time — for a 500-token response at 100 tok/s, the client waits 5 seconds before seeing any output instead of seeing the first token in ~50ms. This is a major UX regression for any SSE/streaming client.
The fix would be to use an asyncio.Queue or similar mechanism to bridge the synchronous generator on the worker thread with the async generator on the event loop, yielding chunks as they're produced.
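A minimal sketch of that bridge, assuming the `MLXWorkerThread.submit(loop, func, *args, **kwargs)` signature used elsewhere in this PR (it returns an `asyncio.Future`); the `_DONE` sentinel and the `GenerationOutput(...)` placeholder are illustrative:

```python
import asyncio

_DONE = object()  # sentinel marking the end of the worker-side generator


async def stream_generate(self, *args, **kwargs):
    loop = asyncio.get_running_loop()
    chunk_queue: asyncio.Queue = asyncio.Queue()

    def _run_stream_generate():
        # Runs on the persistent MLX worker thread; FIFO submission keeps requests serialized.
        try:
            for chunk in self._model.stream_generate(*args, **kwargs):
                # Hand each chunk to the event loop as soon as it exists.
                loop.call_soon_threadsafe(chunk_queue.put_nowait, chunk)
        finally:
            loop.call_soon_threadsafe(chunk_queue.put_nowait, _DONE)

    worker_future = _mlx_worker.submit(loop, _run_stream_generate)

    while True:
        chunk = await chunk_queue.get()
        if chunk is _DONE:
            break
        yield GenerationOutput(...)  # client sees each token as soon as it is generated

    await worker_future  # re-raise any exception from the worker thread
```

A bounded queue could be added if backpressure matters, but even this unbounded version restores token-level TTFT.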
2. Module-level singleton starts thread at import time
```python
_mlx_worker = MLXWorkerThread()
```

This starts a daemon thread the moment `simple.py` is imported, even if `SimpleEngine` is never instantiated (e.g., when using `BatchedEngine` or `MllmBatchGenerator`). Consider lazy initialization or moving the worker into the `SimpleEngine` instance.
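One lazy-initialization option (a sketch; `get_mlx_worker()` is a hypothetical accessor, not something in this PR):

```python
import threading

_mlx_worker = None
_mlx_worker_lock = threading.Lock()


def get_mlx_worker() -> "MLXWorkerThread":
    """Create the persistent worker thread on first use rather than at import time."""
    global _mlx_worker
    if _mlx_worker is None:
        with _mlx_worker_lock:
            if _mlx_worker is None:  # double-checked so concurrent callers share one worker
                _mlx_worker = MLXWorkerThread()
    return _mlx_worker
```

Call sites in `SimpleEngine` would then use `get_mlx_worker().submit(...)`; alternatively the worker could simply be created in `SimpleEngine.start()`.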
3. Missing generation_stream rebinding on worker thread
The old code called bind_generation_streams() before each inference to rebind mlx_lm.generate.generation_stream and mlx_vlm.generate.generation_stream to the current thread. The new code skips this entirely.
These module-level generation_stream variables are created at import time (main thread). If mlx_lm internally uses mx.stream(generation_stream) to schedule Metal operations, those operations reference a stream whose CommandEncoder lives in the main thread's TLS — which would crash on the worker thread.
This may work in your test because either (a) the default stream is special, or (b) mlx_lm doesn't reference generation_stream in the code path you tested. But it should be verified more broadly, or a single bind_generation_streams() call should be added to the worker thread's init.
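If rebinding does turn out to be necessary, it only needs to happen once, on the worker thread itself. A sketch of what that could look like during startup, assuming `bind_generation_streams()` takes no arguments (as in its previous per-call usage) and using `_load_model()` as a placeholder for the PR's actual loading call:

```python
def _worker_init_and_load():
    # Runs once on the persistent MLX worker thread, before any inference is scheduled.
    bind_generation_streams()  # rebind mlx_lm / mlx_vlm generation_stream to this thread
    return _load_model()       # placeholder for the real model-loading callable

model = await _mlx_worker.submit(loop, _worker_init_and_load)
```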
4. Minor: asyncio.ensure_future on a Future is a no-op
```python
task = asyncio.ensure_future(
    _mlx_worker.submit(loop, func, *args, **kwargs)
)
```

`worker.submit()` already returns an `asyncio.Future`. `ensure_future` on a Future just returns it. The variable name `task` is misleading since it's a Future, not a Task. `asyncio.shield()` works on both, so it's functionally correct but confusing.
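If the shield is the only reason for wrapping, the call site could simply shield the Future it already has (a sketch; the body of `_run_blocking_serialized` here is illustrative, not the PR's exact code):

```python
async def _run_blocking_serialized(self, func, *args, **kwargs):
    loop = asyncio.get_running_loop()
    future = _mlx_worker.submit(loop, func, *args, **kwargs)  # already an asyncio.Future
    # shield() lets the MLX work finish on the worker thread even if this coroutine is cancelled
    return await asyncio.shield(future)
```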
Tests
The new tests are well-structured and cover the right cases (thread persistence, FIFO ordering, exception propagation, MLX ops). No issues there.
Summary
The MLXWorkerThread itself is solid. The main blocker is the streaming regression — stream_generate and stream_chat must preserve true token-by-token streaming, not batch-then-yield.
Replaces asyncio.to_thread() + bind_generation_streams() in SimpleEngine with a persistent MLXWorkerThread that loads the model and runs all inference on the same OS thread, fixing 'RuntimeError: There is no Stream(gpu, N) in current thread' on MLX >= 0.31.2 (which made Metal CommandEncoder storage thread-local).

Known issues per reviewer feedback (waybarrios#478):
- Streaming regression: stream_generate/stream_chat collect all chunks before yielding (TTFT = total generation time). Smoke tests use non-streaming so this doesn't block initial DeepSeek-V4-Flash bring-up, but must be fixed before relying on SSE clients.
- Module-level singleton starts a daemon thread at import.
- generation_stream rebinding on worker thread is dropped (may break paths that explicitly reference module-level generation_stream).

Pulled from xykong/vllm-mlx@fix/persistent-mlx-worker-thread.
Summary
Replaces `asyncio.to_thread()` + `bind_generation_streams()` in `SimpleEngine` with a persistent `MLXWorkerThread` that loads the model and runs all inference on the same OS thread, fixing `RuntimeError: There is no Stream(gpu, N) in current thread` on MLX >= 0.31.2.

Problem
MLX 0.31.2 (ml-explore/mlx#3348) made Metal `CommandEncoder` storage thread-local. The current `SimpleEngine._run_blocking_serialized()` spawns transient threads via `asyncio.to_thread()` and calls `bind_generation_streams()`, which creates a new stream index via `mx.new_stream()`. However, the model's compiled Metal kernels reference the original stream index from the loading thread — that index does not exist in the transient worker thread's TLS.

Reproduction: Run `vllm-mlx` with `SimpleEngine` on Apple Silicon with MLX >= 0.31.2 and any model that triggers Metal kernel compilation during load (most models). The server crashes on the first inference request.

Root Cause
1. Model load creates compiled Metal kernels bound to stream index N on the loading thread (`mx.new_stream()`)
2. `asyncio.to_thread()` spawns a new thread each time
3. `bind_generation_streams()` calls `mx.new_stream()` → creates stream M (M ≠ N)
4. The compiled kernels still reference stream N → `There is no Stream(gpu, N) in current thread`

Fix
Introduce `MLXWorkerThread` — a persistent single-threaded executor: a single daemon thread that lives for the lifetime of the process, loads the model, and runs all inference, guaranteeing thread-local stream consistency.
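For reviewers reading the description without the diff open, the rough shape of such a worker is sketched below; this is illustrative rather than the PR's exact code, and it assumes the `queue.SimpleQueue` dispatch and `submit(loop, func, *args, **kwargs)` signature mentioned elsewhere in this PR:

```python
import asyncio
import queue
import threading


class MLXWorkerThread:
    """Persistent daemon thread: model load and all inference run on this one thread."""

    def __init__(self):
        self._tasks: queue.SimpleQueue = queue.SimpleQueue()
        self._thread = threading.Thread(target=self._run, name="mlx-worker", daemon=True)
        self._thread.start()

    def _run(self):
        # Tasks execute strictly in FIFO order on this single thread.
        while True:
            loop, future, func, args, kwargs = self._tasks.get()
            try:
                result = func(*args, **kwargs)
            except BaseException as exc:
                self._resolve(loop, future, exc=exc)
            else:
                self._resolve(loop, future, result=result)

    @staticmethod
    def _resolve(loop, future, result=None, exc=None):
        def _set():
            if future.cancelled():
                return
            if exc is not None:
                future.set_exception(exc)
            else:
                future.set_result(result)
        loop.call_soon_threadsafe(_set)

    def submit(self, loop: asyncio.AbstractEventLoop, func, *args, **kwargs) -> asyncio.Future:
        """Schedule func on the worker thread and resolve the Future on the given event loop."""
        future = loop.create_future()
        self._tasks.put((loop, future, func, args, kwargs))
        return future
```

A single consumer thread drains the queue, so submitted tasks run one at a time on the same thread that loaded the model.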
This matches the proven approach in:

- `ThreadPoolExecutor(max_workers=1)` with initializer
Changes

- `vllm_mlx/mlx_streams.py`: Added `MLXWorkerThread` class (persistent daemon thread with `queue.SimpleQueue` task dispatch). `bind_generation_streams()` retained for backward compatibility but deprecated.
- `vllm_mlx/engine/simple.py`:
  - `start()` loads the model via `_mlx_worker.submit()` instead of an event-loop-thread load
  - `_run_blocking_serialized()` submits to `_mlx_worker` instead of `asyncio.to_thread()`
  - `stream_generate()` and `stream_chat()` MLLM text-only fallback route through the worker
  - Dropped `_bind_worker_generation_streams()` and `run_blocking_startup_work` usage
- `tests/test_mlx_worker_thread.py`: New test file verifying thread persistence, FIFO ordering, exception propagation, and MLX Metal ops.

Testing
Tested on Apple M4 Max 128GB, macOS Tahoe 26.3.1, MLX 0.31.2, with `mlx-community/gemma-4-26b-a4b-it-4bit`.

Related Issues
- `RuntimeError: There is no Stream(gpu, N) in current thread` when calling `/v1/messages` or `/v1/chat/completions` on Qwen3.5 MLX models (MLLM path runs in worker thread) #407