
fix: run BatchedEngine MLLM on dedicated MLXWorkerThread to prevent cross-thread stream errors#479

Open
xykong wants to merge 4 commits into waybarrios:main from xykong:fix/batched-engine-stream-thread

Conversation


@xykong xykong commented May 1, 2026

Summary

Fixes RuntimeError: There is no Stream(gpu, N) in current thread when running BatchedEngine (--continuous-batching) with MLX >= 0.31.2 on Apple Silicon.

Problem

MLX 0.31.2 made Metal CommandEncoders thread-local. When BatchedEngine._prepare_mllm_model() loads the model on the event-loop thread but MLLMScheduler._process_loop() runs step() on a different thread, the runtime raises:

RuntimeError: There is no Stream(gpu, N) in current thread

This happens because MLX arrays created during model loading are tagged with the loading thread's stream, which doesn't exist in the inference thread's TLS.

Fix

  1. mlx_streams.py: Add MLXWorkerThread — a persistent single-thread executor that guarantees model load + inference share the same thread-local stream context (a minimal sketch follows this list).

  2. engine/batched.py: Make _prepare_mllm_model() async and run model loading on MLXWorkerThread. Pass the worker to MLLMScheduler.

  3. mllm_scheduler.py: Accept mlx_worker parameter. When provided, submit step() and preprocessing to the worker thread. Falls back to the legacy event-loop path (with bind_generation_streams()) when no worker is provided.

  4. mllm_batch_generator.py: Remove MLLMBatchGenerator._stream (class-level mx.new_stream()) in favor of the worker thread's default stream, eliminating the last source of cross-thread stream references.
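For orientation, here is a minimal sketch of the kind of persistent single-thread executor described in item 1. The class name and the submit(loop, fn, ...) shape follow this PR's description; the queue-based loop and future plumbing are illustrative assumptions, not the PR's actual code.

```python
import queue
import threading


class MLXWorkerThread:
    """Persistent worker thread: model loading and inference both run here,
    so they share a single thread-local MLX stream context."""

    def __init__(self):
        self._tasks: queue.Queue = queue.Queue()
        self._thread = threading.Thread(target=self._run, daemon=True, name="mlx-worker")
        self._thread.start()

    def _run(self):
        # Every submitted callable executes on this one thread, so any MLX
        # arrays it creates are tagged with this thread's stream, never the
        # event loop's.
        while True:
            fut, loop, fn, args, kwargs = self._tasks.get()
            try:
                result = fn(*args, **kwargs)
                loop.call_soon_threadsafe(fut.set_result, result)
            except Exception as exc:
                loop.call_soon_threadsafe(fut.set_exception, exc)

    async def submit(self, loop, fn, *args, **kwargs):
        # Run fn on the worker thread and await its result from asyncio code.
        fut = loop.create_future()
        self._tasks.put((fut, loop, fn, args, kwargs))
        return await fut
```

Model loading and each step() would then go through the same worker, e.g. await worker.submit(loop, load_model) at startup and await worker.submit(loop, generator.step) from the scheduler loop.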

Testing

Tested on M4 Max 128GB with mlx-community/gemma-4-26b-a4b-it-4bit:

| Metric | Result |
| --- | --- |
| Single request (short) | 93 tok/s |
| Single request (long) | 82 tok/s |
| 4x concurrent aggregate | 179 tok/s |
| TTFT (short) | 35 ms |

Previously this configuration crashed immediately on the first request.

Backward Compatibility

  • The legacy event-loop path (no MLXWorkerThread) is preserved as a fallback when mlx_worker=None (a dispatch sketch follows this list)
  • SimpleEngine path is unaffected
  • LLM-only BatchedEngine path is unaffected
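For illustration only, the optional-worker dispatch could look roughly like the following. The mlx_worker parameter name and bind_generation_streams come from this PR; the import path, method structure, and call sites are assumptions.

```python
from mlx_streams import MLXWorkerThread, bind_generation_streams  # import path is an assumption


class MLLMScheduler:
    def __init__(self, batch_generator, mlx_worker: MLXWorkerThread | None = None):
        self._bg = batch_generator
        self._mlx_worker = mlx_worker  # None -> legacy event-loop path

    async def _run_step(self, loop):
        if self._mlx_worker is not None:
            # New path: step() executes on the persistent MLX worker thread.
            return await self._mlx_worker.submit(loop, self._bg.step)
        # Legacy path: pre-existing behaviour on the event-loop thread.
        bind_generation_streams()  # exact usage of this helper is an assumption
        return self._bg.step()
```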

MLX >= 0.31.2 makes Metal command encoders thread-local. When
BatchedEngine loads the model on the event-loop thread but runs
inference on a worker thread, mx.eval() raises:

  RuntimeError: There is no Stream(gpu, N) in current thread

Fix by:
1. Adding MLXWorkerThread to mlx_streams.py — a persistent single-thread
   executor that guarantees model load + inference share the same
   thread-local stream context.
2. Moving _prepare_mllm_model() to run on MLXWorkerThread via async submit.
3. Passing the worker to MLLMScheduler so _process_loop submits step()
   to the same thread.
4. Removing MLLMBatchGenerator._stream (class-level mx.new_stream) in
   favor of the worker thread default stream, eliminating the last
   source of cross-thread stream references.

The legacy event-loop path (no worker thread) is preserved as fallback
with bind_generation_streams() for backward compatibility.

Tested on M4 Max 128GB with gemma-4-26b-a4b-it-4bit:
- Single request: 82-93 tok/s generation
- 4x concurrent: 179 tok/s aggregate throughput
@janhilgard janhilgard (Collaborator) left a comment


The overall approach is solid — model load + inference on the same persistent MLXWorkerThread is the right fix for MLX 0.31.2 thread-local CommandEncoders. The legacy fallback path is a nice touch for backward compatibility. No streaming regression here (unlike #478 for SimpleEngine).

A few issues to address:


1. `if True:` replaces `with mx.stream(...)` — should dedent instead

```python
if True:  # use thread-default stream
    logits = self.language_model(last_token, cache=request_cache)
```

This appears 4 times. The if True: block is dead code that exists only to preserve indentation. Just remove the with mx.stream(...) wrapper and dedent the body. A # noqa or a comment is fine, but a no-op if True: is a code smell that will confuse future readers.
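That is, keep only the dedented body:

```python
# use thread-default stream; no explicit mx.stream(...) wrapper needed
logits = self.language_model(last_token, cache=request_cache)
```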

2. Preprocessing unnecessarily serialized on the MLX worker thread

The old code ran text-only preprocessing in the default thread-pool executor (loop.run_in_executor(None, ...)):

```python
await loop.run_in_executor(None, bg._preprocess_request, req)
```

The new worker-thread path submits preprocessing to the same single MLX worker:

```python
await self._mlx_worker.submit(loop, bg._preprocess_request, req)
```

The comment says "Preprocessing creates MLX arrays (input_ids) so it must run on the same worker thread" — but _preprocess_request does Jinja2 template rendering + tokenizer.encode(), which produces Python lists, not mx.array. The mx.array() conversion happens inside _process_prompts() called from step().

Running preprocessing on the single MLX worker thread serializes it with step(), which means a 30-second preprocessing of a 40K+ token conversation blocks ALL inference for 30 seconds. The old approach of offloading to the thread pool kept inference unblocked during preprocessing. Consider keeping run_in_executor for preprocessing.
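One way to keep inference unblocked is a two-phase split. The helper names _tokenize_text_only and _materialize_tokens are taken from a later commit on this PR; the exact signatures here are assumptions.

```python
# Phase 1: CPU-only work (Jinja2 template + tokenizer.encode) on the shared
# thread pool, in parallel with generation.
token_ids = await loop.run_in_executor(None, bg._tokenize_text_only, req)

# Phase 2: fast mx.array() materialization on the MLX worker thread
# (microseconds per request, arrays land on the correct stream).
await self._mlx_worker.submit(loop, bg._materialize_tokens, req, token_ids)
```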

3. Lost slow-preprocessing logging and dynamic n_yields

The old code logged slow preprocessing:

```python
elapsed = time.perf_counter() - tic
if elapsed > 1.0:
    logger.info(f"Preprocessing {req.request_id[:12]}: {n_tok} tokens in {elapsed:.2f}s")
```

And had dynamic yield counts:

```python
n_yields = 10 if elapsed > 1.0 else 5
```

The worker-thread path drops both — no timing, always 5 yields. Please preserve the slow-preprocessing logging (it's very useful for debugging production issues) and the dynamic yield count.
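A sketch of how the worker-thread path could keep both, reusing the names from the snippets above (illustrative only, not the PR's code):

```python
tic = time.perf_counter()
token_ids = await loop.run_in_executor(None, bg._tokenize_text_only, req)  # helper name from a later commit
elapsed = time.perf_counter() - tic
if elapsed > 1.0:
    logger.info(f"Preprocessing {req.request_id[:12]}: {len(token_ids)} tokens in {elapsed:.2f}s")
n_yields = 10 if elapsed > 1.0 else 5
```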

4. mx.new_thread_local_stream(mx.gpu) should live in MLXWorkerThread._run(), not in _load_on_worker

Currently the thread-local stream is initialized inside _load_on_worker():

```python
def _load_on_worker():
    mx.new_thread_local_stream(mx.gpu)
    inst = MLXMultimodalLM(...)
```

But the MLXWorkerThread docstring promises "mx.new_stream() is called exactly once (during thread init)." If someone submits a task before model loading, there's no stream. Move mx.new_thread_local_stream(mx.gpu) into MLXWorkerThread._run() at the top, before the task loop.
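In terms of a queue-based worker loop, that means the stream init becomes the first statement of _run(). The mx.new_thread_local_stream(mx.gpu) call is copied from this PR; the surrounding loop is illustrative (assumes import mlx.core as mx).

```python
def _run(self):
    # Create the thread-local GPU stream once, before any task can run, so
    # work submitted ahead of model loading still sees a valid stream.
    mx.new_thread_local_stream(mx.gpu)
    while True:
        fut, loop, fn, args, kwargs = self._tasks.get()
        try:
            loop.call_soon_threadsafe(fut.set_result, fn(*args, **kwargs))
        except Exception as exc:
            loop.call_soon_threadsafe(fut.set_exception, exc)
```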

5. Overlap with PR #478 on mlx_streams.py

Both #478 and #479 add MLXWorkerThread to mlx_streams.py. If both land, there will be a merge conflict. Consider coordinating — either land #478 first and rebase #479 on top, or extract MLXWorkerThread into a shared PR that both depend on.

6. Comments stripped from legacy path

Several explanatory comments were removed from the legacy path (early preprocessing phase rationale, health-check yield explanation). These comments document why the code works the way it does and are valuable for maintainability. Please keep them.


Summary

The MLXWorkerThread + BatchedEngine integration is architecturally sound. Main blockers: (1) preprocessing should stay on the thread pool to avoid blocking inference, (2) if True: should be a proper dedent, (3) preserve the slow-preprocessing diagnostics logging. The rest are minor cleanups.

xykong added 3 commits May 1, 2026 23:00
…aterialization

- Add _tokenize_text_only(): CPU-only tokenization (Jinja2 template +
  tokenizer) that produces a plain Python list. Safe to run on any thread
  via the default ThreadPoolExecutor, unblocking the MLX worker for
  generation during long prompt tokenization (10-30s for 40K+ tokens).

- Add _materialize_tokens(): fast mx.array() conversion on the MLX worker
  thread. Microseconds per request, ensures arrays are on the correct
  stream.

- Update _process_loop in scheduler to use the two-phase approach:
  Phase 1: CPU tokenization on thread pool (parallel, non-blocking)
  Phase 2: mx.array() on worker thread (serial, fast)

- Remove dead "if True:  # use thread-default stream" blocks and dedent
  their bodies (leftover from mx.stream removal).

Benchmark: aggregate throughput peaks at ~197 tok/s (16 concurrent),
sweet spot 4 concurrent at 177 tok/s with 44.7 tok/s per-request.

Two fixes for multimodal (vision) support:

1. prepare_for_start(): Make MLLM path a no-op. The async
   _prepare_mllm_model() requires MLXWorkerThread and event loop,
   so model loading must be deferred to _start_mllm() which is
   properly awaited. Calling it synchronously caused a
   RuntimeWarning (coroutine never awaited) and broken model state.

2. _process_prompts(): Skip prefix cache lookup when the request
   has pixel_values. The VLM forward pass must run with pixel_values
   to encode vision features into the KV cache. Previously, a prefix
   cache hit would route multimodal requests through the language-model-only
   path, which cannot process image placeholder tokens — producing
   garbage output that ignored the image content entirely.

   The existing image_token_index guard did not catch this because
   some models (e.g. Gemma 4) do not set config.image_token_index.

Quantized Gemma 4 models (e.g. 4-bit) sometimes emit tool call
arguments with single quotes instead of the expected <|"|"> delimiter
tokens. For example: {location:'Tokyo'} instead of
{location:<|"|">Tokyo<|"|">}.

Add a pre-processing step (Step 0.5) in _gemma4_args_to_json that
converts single-quoted strings to the canonical <|"|">-delimited
format before the existing parsing pipeline runs.