fix: run BatchedEngine MLLM on dedicated MLXWorkerThread to prevent cross-thread stream errors #479
xykong wants to merge 4 commits into waybarrios:main from
Conversation
MLX >= 0.31.2 makes Metal command encoders thread-local. When BatchedEngine loads the model on the event-loop thread but runs inference on a worker thread, `mx.eval()` raises:

```
RuntimeError: There is no Stream(gpu, N) in current thread
```

Fix by:

1. Adding `MLXWorkerThread` to `mlx_streams.py` — a persistent single-thread executor that guarantees model load + inference share the same thread-local stream context.
2. Moving `_prepare_mllm_model()` to run on `MLXWorkerThread` via async submit.
3. Passing the worker to `MLLMScheduler` so `_process_loop` submits `step()` to the same thread.
4. Removing `MLLMBatchGenerator._stream` (class-level `mx.new_stream`) in favor of the worker thread's default stream, eliminating the last source of cross-thread stream references.

The legacy event-loop path (no worker thread) is preserved as a fallback with `bind_generation_streams()` for backward compatibility.

Tested on M4 Max 128GB with gemma-4-26b-a4b-it-4bit:

- Single request: 82-93 tok/s generation
- 4x concurrent: 179 tok/s aggregate throughput
janhilgard
left a comment
The overall approach is solid — model load + inference on the same persistent MLXWorkerThread is the right fix for MLX 0.31.2 thread-local CommandEncoders. The legacy fallback path is a nice touch for backward compatibility. No streaming regression here (unlike #478 for SimpleEngine).
A few issues to address:
1. `if True:` replaces `with mx.stream(...)` — should dedent instead
```python
if True:  # use thread-default stream
    logits = self.language_model(last_token, cache=request_cache)
```

This appears 4 times. The `if True:` block is dead code that exists only to preserve indentation. Just remove the `with mx.stream(...)` wrapper and dedent the body. A `# noqa` or a comment is fine, but a no-op `if True:` is a code smell that will confuse future readers.
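For reference, the fix is just the dedented body (a sketch; the surrounding method context is assumed):

```python
# No explicit stream context needed: this now runs on the MLX worker
# thread, whose thread-default stream is already the right one.
logits = self.language_model(last_token, cache=request_cache)
```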
2. Preprocessing unnecessarily serialized on the MLX worker thread
The old code ran text-only preprocessing in the default thread-pool executor (`loop.run_in_executor(None, ...)`):

```python
await loop.run_in_executor(None, bg._preprocess_request, req)
```

The new worker-thread path submits preprocessing to the same single MLX worker:

```python
await self._mlx_worker.submit(loop, bg._preprocess_request, req)
```

The comment says "Preprocessing creates MLX arrays (input_ids) so it must run on the same worker thread" — but `_preprocess_request` does Jinja2 template rendering + `tokenizer.encode()`, which produces Python lists, not `mx.array`. The `mx.array()` conversion happens inside `_process_prompts()` called from `step()`.
Running preprocessing on the single MLX worker thread serializes it with `step()`, which means a 30-second preprocessing of a 40K+ token conversation blocks ALL inference for 30 seconds. The old approach of offloading to the thread pool kept inference unblocked during preprocessing. Consider keeping `run_in_executor` for preprocessing, as sketched below.
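As an illustration of the suggested split (a sketch only — `new_requests` and the `bg.step` call site are assumptions; `submit(loop, fn, *args)` matches the snippet above):

```python
# Hypothetical arrangement: preprocessing fans out to the default pool,
# so the single MLX worker stays free for inference in the meantime.
await asyncio.gather(*(
    loop.run_in_executor(None, bg._preprocess_request, r)
    for r in new_requests  # `new_requests` is an assumed local
))
# Only MLX-touching work is funneled through the dedicated worker thread.
await self._mlx_worker.submit(loop, bg.step)
```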
3. Lost slow-preprocessing logging and dynamic n_yields
The old code logged slow preprocessing:
```python
elapsed = time.perf_counter() - tic
if elapsed > 1.0:
    logger.info(f"Preprocessing {req.request_id[:12]}: {n_tok} tokens in {elapsed:.2f}s")
```

And had dynamic yield counts:

```python
n_yields = 10 if elapsed > 1.0 else 5
```

The worker-thread path drops both — no timing, always 5 yields. Please preserve the slow-preprocessing logging (it's very useful for debugging production issues) and the dynamic yield count.
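Reassembled from the old snippets above, the preserved diagnostics could look like this (sketch inside the scheduler's async path; `logger`, `n_tok`, and the yield-loop shape are carried over or assumed):

```python
tic = time.perf_counter()
await loop.run_in_executor(None, bg._preprocess_request, req)
elapsed = time.perf_counter() - tic

# Keep the slow-preprocessing diagnostic: invaluable for production debugging.
if elapsed > 1.0:
    logger.info(f"Preprocessing {req.request_id[:12]}: {n_tok} tokens in {elapsed:.2f}s")

# Keep the dynamic yield count: extra event-loop turns after a slow preprocess.
n_yields = 10 if elapsed > 1.0 else 5
for _ in range(n_yields):
    await asyncio.sleep(0)  # assumed yield mechanism
```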
4. `mx.new_thread_local_stream(mx.gpu)` should live in `MLXWorkerThread._run()`, not in `_load_on_worker`
Currently the thread-local stream is initialized inside _load_on_worker():
```python
def _load_on_worker():
    mx.new_thread_local_stream(mx.gpu)
    inst = MLXMultimodalLM(...)
```

But the `MLXWorkerThread` docstring promises "mx.new_stream() is called exactly once (during thread init)." If someone submits a task before model loading, there's no stream. Move `mx.new_thread_local_stream(mx.gpu)` into `MLXWorkerThread._run()` at the top, before the task loop.
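A sketch of that placement (the queue/sentinel internals of `MLXWorkerThread` are assumptions, not taken from the diff):

```python
import mlx.core as mx

class MLXWorkerThread:
    # ... __init__/submit omitted; only the thread body is sketched here ...

    def _run(self) -> None:
        # Create the thread-local GPU stream once, before any task executes,
        # so tasks submitted before model load also see a valid stream.
        mx.new_thread_local_stream(mx.gpu)
        while True:
            item = self._queue.get()
            if item is None:  # shutdown sentinel (assumed)
                return
            fn, args, loop, future = item
            try:
                result = fn(*args)
                loop.call_soon_threadsafe(future.set_result, result)
            except Exception as exc:
                loop.call_soon_threadsafe(future.set_exception, exc)
```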
5. Overlap with PR #478 on mlx_streams.py
Both #478 and #479 add MLXWorkerThread to mlx_streams.py. If both land, there will be a merge conflict. Consider coordinating — either land #478 first and rebase #479 on top, or extract MLXWorkerThread into a shared PR that both depend on.
6. Comments stripped from legacy path
Several explanatory comments were removed from the legacy path (early preprocessing phase rationale, health-check yield explanation). These comments document why the code works the way it does and are valuable for maintainability. Please keep them.
Summary
The `MLXWorkerThread` + BatchedEngine integration is architecturally sound. Main blockers: (1) preprocessing should stay on the thread pool to avoid blocking inference, (2) `if True:` should be a proper dedent, (3) preserve the slow-preprocessing diagnostics logging. The rest are minor cleanups.
…aterialization

- Add `_tokenize_text_only()`: CPU-only tokenization (Jinja2 template + tokenizer) that produces a plain Python list. Safe to run on any thread via the default ThreadPoolExecutor, unblocking the MLX worker for generation during long prompt tokenization (10-30s for 40K+ tokens).
- Add `_materialize_tokens()`: fast `mx.array()` conversion on the MLX worker thread. Microseconds per request, ensures arrays are on the correct stream.
- Update `_process_loop` in the scheduler to use the two-phase approach (sketched below):
  - Phase 1: CPU tokenization on the thread pool (parallel, non-blocking)
  - Phase 2: `mx.array()` on the worker thread (serial, fast)
- Remove dead `if True:  # use thread-default stream` blocks and dedent their bodies (leftover from `mx.stream` removal).

Benchmark: aggregate throughput peaks at ~197 tok/s (16 concurrent), sweet spot 4 concurrent at 177 tok/s with 44.7 tok/s per-request.
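A sketch of the two-phase flow in `_process_loop` (signatures inferred from this commit message; exact call sites may differ):

```python
# Phase 1: pure-CPU tokenization on the default thread pool — parallel
# across requests, never touches MLX, so inference keeps running.
token_ids = await loop.run_in_executor(None, bg._tokenize_text_only, req)

# Phase 2: tiny mx.array() materialization on the MLX worker thread, so
# the array lands on the worker's thread-local stream. Microseconds each.
req.input_ids = await self._mlx_worker.submit(loop, bg._materialize_tokens, token_ids)
```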
Two fixes for multimodal (vision) support:

1. `prepare_for_start()`: Make the MLLM path a no-op. The async `_prepare_mllm_model()` requires MLXWorkerThread and an event loop, so model loading must be deferred to `_start_mllm()`, which is properly awaited. Calling it synchronously caused a RuntimeWarning (coroutine never awaited) and broken model state.
2. `_process_prompts()`: Skip the prefix cache lookup when the request has `pixel_values` (see the sketch below). The VLM forward pass must run with `pixel_values` to encode vision features into the KV cache. Previously, a prefix cache hit would route multimodal requests through the language-model-only path, which cannot process image placeholder tokens — producing garbage output that ignored the image content entirely. The existing `image_token_index` guard did not catch this because some models (e.g. Gemma 4) do not set `config.image_token_index`.
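A rough sketch of the `_process_prompts()` guard (the `prefix_cache.lookup` helper name is hypothetical; the condition follows the commit message):

```python
# Requests carrying pixel_values must take the full VLM forward pass so
# vision features are encoded into the KV cache — never the prefix-cache
# fast path, which runs the language model only.
if req.pixel_values is not None:
    cached_prefix = None
else:
    cached_prefix = self.prefix_cache.lookup(req.input_ids)  # hypothetical helper
```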
Quantized Gemma 4 models (e.g. 4-bit) sometimes emit tool call arguments with single quotes instead of the expected `<|"|">` delimiter tokens. For example: `{location:'Tokyo'}` instead of `{location:<|"|">Tokyo<|"|">}`.

Add a pre-processing step (Step 0.5) in `_gemma4_args_to_json` that converts single-quoted strings to the canonical `<|"|">`-delimited format before the existing parsing pipeline runs.
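A minimal sketch of such a normalization (illustrative only — the real Step 0.5 in `_gemma4_args_to_json` may handle escaping and nesting differently):

```python
import re

# Rewrite single-quoted values into the canonical <|"|"> delimiter form,
# e.g. {location:'Tokyo'} -> {location:<|"|">Tokyo<|"|">}.
_SINGLE_QUOTED = re.compile(r"'([^']*)'")

def normalize_single_quotes(args: str) -> str:
    # Illustrative helper name, not the actual function from the diff.
    return _SINGLE_QUOTED.sub(r'<|"|">\1<|"|">', args)

assert normalize_single_quotes("{location:'Tokyo'}") == '{location:<|"|">Tokyo<|"|">}'
```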
Summary
Fixes `RuntimeError: There is no Stream(gpu, N) in current thread` when running BatchedEngine (`--continuous-batching`) with MLX >= 0.31.2 on Apple Silicon.

Problem
MLX 0.31.2 made Metal CommandEncoders thread-local. When `BatchedEngine._prepare_mllm_model()` loads the model on the event-loop thread but `MLLMScheduler._process_loop()` runs `step()` on a different thread, the runtime raises the error above. This happens because MLX arrays created during model loading are tagged with the loading thread's stream, which doesn't exist in the inference thread's TLS.
Fix
- `mlx_streams.py`: Add `MLXWorkerThread` — a persistent single-thread executor that guarantees model load + inference share the same thread-local stream context.
- `engine/batched.py`: Make `_prepare_mllm_model()` async and run model loading on `MLXWorkerThread` (sketched below). Pass the worker to `MLLMScheduler`.
- `mllm_scheduler.py`: Accept an `mlx_worker` parameter. When provided, submit `step()` and preprocessing to the worker thread. Falls back to the legacy event-loop path (with `bind_generation_streams()`) when no worker is provided.
- `mllm_batch_generator.py`: Remove `MLLMBatchGenerator._stream` (class-level `mx.new_stream()`) in favor of the worker thread's default stream, eliminating the last source of cross-thread stream references.
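A sketch of how these pieces fit together (the `load_model` call and constructor arguments are illustrative, not from the diff):

```python
import asyncio

# engine/batched.py — inside the async start path (sketch)
self._mlx_worker = MLXWorkerThread()
loop = asyncio.get_running_loop()

# Model weights are created on the worker's thread-local stream...
self.model = await self._mlx_worker.submit(loop, load_model, self.model_path)

# ...and the scheduler submits every step() back to that same thread.
self.scheduler = MLLMScheduler(self.model, mlx_worker=self._mlx_worker)
```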
Testing

Tested on M4 Max 128GB with `mlx-community/gemma-4-26b-a4b-it-4bit`:

- Single request: 82-93 tok/s generation
- 4x concurrent: 179 tok/s aggregate throughput

Previously this configuration crashed immediately on the first request.
Backward Compatibility

The legacy event-loop path (no `MLXWorkerThread`) is preserved as a fallback when `mlx_worker=None`.