diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index 342fa2076..a0b9fc072 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -27,6 +27,7 @@ cleanup_startup_cancellation, run_blocking_startup_work, ) +from ..mlx_streams import MLXWorkerThread logger = logging.getLogger(__name__) @@ -216,6 +217,7 @@ def __init__( self._tokenizer = None # For LLM self._engine = None # AsyncEngineCore for LLM self._mllm_scheduler = None # MLLMScheduler for MLLM + self._mlx_worker = None # Shared MLXWorkerThread for MLLM load + inference self._mllm_instance = None # MLXMultimodalLM instance self._loaded = False @@ -237,12 +239,18 @@ def tokenizer(self) -> Any: return self._tokenizer def prepare_for_start(self) -> None: - """Load heavyweight model state off the serving event loop.""" + """Load heavyweight model state off the serving event loop. + + For MLLM models this is a no-op: the async _prepare_mllm_model() + requires the MLXWorkerThread and event loop, so model loading is + deferred to _start_mllm() which is properly awaited. + """ if self._model is not None: return if self._is_mllm: - self._prepare_mllm_model() + # MLLM loading is async and handled by _start_mllm(). + return else: self._prepare_llm_model() @@ -282,17 +290,43 @@ def _uses_default_prepare_for_start(self) -> bool: method = getattr(self.prepare_for_start, "__func__", None) return method is BatchedEngine.prepare_for_start - def _prepare_mllm_model(self) -> None: - """Load the MLLM model before scheduler startup.""" + async def _prepare_mllm_model(self) -> None: + """Load the MLLM model on a dedicated MLXWorkerThread. + + MLX >= 0.31.2 makes Metal command encoders thread-local. Model + loading creates arrays and streams that are bound to the thread + that executes them. If the model is loaded on the event-loop + thread but inference runs on a different thread, ``mx.eval()`` + raises ``RuntimeError: There is no Stream(gpu, N) in current + thread``. + + By loading the model on the same ``MLXWorkerThread`` that will + later run ``step()`` calls, we guarantee all MLX state lives on + a single thread. 
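+
+        Illustrative thread-affinity contract (a sketch, not an exact
+        trace of the call sequence):
+
+            MLX worker: _load_on_worker()       # arrays bound to this thread
+            MLX worker: scheduler step() calls  # same thread, streams resolve
+            event loop: mx.eval(worker_array)   # raises the stream error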
+        """
+        import mlx.core as mx
+        from ..models.mllm import MLXMultimodalLM
+
+        self._mlx_worker = MLXWorkerThread(name="mllm-worker")
+        loop = asyncio.get_running_loop()
+
         max_kv_size = getattr(self._scheduler_config, "max_kv_size", 0)
-        self._mllm_instance = MLXMultimodalLM(
-            self._model_name,
-            trust_remote_code=self._trust_remote_code,
-            max_kv_size=max_kv_size,
-        )
-        self._mllm_instance.load()
+        model_name = self._model_name
+        trust_remote_code = self._trust_remote_code
+
+        def _load_on_worker():
+            mx.new_thread_local_stream(mx.gpu)
+            inst = MLXMultimodalLM(
+                model_name,
+                trust_remote_code=trust_remote_code,
+                max_kv_size=max_kv_size,
+            )
+            inst.load()
+            logger.info("MLLM model loaded on MLXWorkerThread")
+            return inst
+
+        self._mllm_instance = await self._mlx_worker.submit(loop, _load_on_worker)
 
         self._model = self._mllm_instance.model
         self._processor = self._mllm_instance.processor
@@ -329,7 +363,7 @@ async def _start_mllm(self) -> None:
         from ..mllm_scheduler import MLLMScheduler, MLLMSchedulerConfig
 
         if self._model is None or self._processor is None:
-            self._prepare_mllm_model()
+            await self._prepare_mllm_model()
 
         # Create MLLM scheduler config with batch generator support
         if self._scheduler_config and hasattr(self._scheduler_config, "max_num_seqs"):
@@ -403,6 +437,7 @@ async def _start_mllm(self) -> None:
             model=self._model,
             processor=self._processor,
             config=mllm_config,
+            mlx_worker=self._mlx_worker,
         )
 
         await self._mllm_scheduler.start()
diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py
index 52656aac4..8a8d788e4 100644
--- a/vllm_mlx/mllm_batch_generator.py
+++ b/vllm_mlx/mllm_batch_generator.py
@@ -379,9 +379,6 @@ class MLLMBatchGenerator:
         ...     print(f"Request {resp.request_id}: token={resp.token}")
     """
 
-    # Generation stream for async eval
-    _stream = None
-
     def __init__(
         self,
         model: nn.Module,
@@ -521,10 +518,6 @@ def __init__(
         # Stripping the suffix from cache keys enables clean PREFIX match.
         self._think_suffix_len = self._compute_think_suffix_len()
 
-        # Generation stream
-        if MLLMBatchGenerator._stream is None:
-            MLLMBatchGenerator._stream = mx.new_stream(mx.default_device())
-
         # Memory management
         self._old_wired_limit = None
         if mx.metal.is_available():
@@ -661,7 +654,7 @@ def _compute_think_suffix_len(self) -> int:
     def close(self) -> None:
         """Release resources and reset wired limit."""
         if self._old_wired_limit is not None:
-            mx.synchronize(MLLMBatchGenerator._stream)
+            mx.synchronize()
             mx.set_wired_limit(self._old_wired_limit)
             self._old_wired_limit = None
 
@@ -770,6 +763,80 @@ def remove(self, uids: List[int]) -> None:
             r for r in self.unprocessed_requests if r.uid not in uid_set
         ]
 
+
+    def _tokenize_text_only(self, request: MLLMBatchRequest) -> None:
+        """CPU-only tokenization for text-only requests.
+
+        This performs Jinja2 template rendering and tokenization WITHOUT
+        creating any MLX arrays. The result is stored as Python lists in
+        ``request._tokenized_ids`` and ``request._tokenized_mask``.
+        The actual ``mx.array()`` conversion happens later on the MLX
+        worker thread (see ``_materialize_tokens``).
+
+        Safe to call from any thread (no MLX operations).
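+
+        Illustrative two-phase flow (a sketch; the real call sites are in
+        ``MLLMScheduler._process_loop``):
+
+            executor thread: bg._tokenize_text_only(req)  # sets _tokenized_ids
+            MLX worker:      bg._materialize_tokens(req)  # input_ids = mx.array(...)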
+        """
+        if request.input_ids is not None:
+            return  # Already preprocessed
+        if request.images or request.videos or request.audio:
+            return  # Not text-only, needs full _preprocess_request
+
+        # Check if already tokenized (idempotent)
+        if getattr(request, "_tokenized_ids", None) is not None:
+            return
+
+        tokenizer = (
+            self.processor.tokenizer
+            if hasattr(self.processor, "tokenizer")
+            else self.processor
+        )
+
+        # Ensure pad_token exists
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        # Run tokenizer — returns Python/numpy, NOT mx.array
+        inputs = tokenizer(
+            request.prompt,
+            add_special_tokens=False,
+            padding=True,
+            padding_side="left",
+            return_tensors=None,  # Return Python lists, not tensors
+        )
+
+        # Store as Python lists for later mx.array() on worker
+        request._tokenized_ids = inputs["input_ids"]
+        request._tokenized_mask = inputs.get("attention_mask")
+
+    def _materialize_tokens(self, request: MLLMBatchRequest) -> None:
+        """Convert pre-tokenized Python lists to mx.array on the current thread.
+
+        Must be called on the MLX worker thread. This is a fast operation
+        (microseconds) compared to tokenization (milliseconds to seconds).
+        """
+        ids = getattr(request, "_tokenized_ids", None)
+        if ids is None:
+            return  # Not pre-tokenized, will use _preprocess_request
+
+        import mlx.core as mx
+
+        request.input_ids = mx.array(ids) if not isinstance(ids, mx.array) else ids
+        mask = getattr(request, "_tokenized_mask", None)
+        if mask is not None:
+            request.attention_mask = (
+                mx.array(mask) if not isinstance(mask, mx.array) else mask
+            )
+
+        # Mark as text-only for prefix cache eligibility
+        request.is_text_only = True
+        request.extra_kwargs = {}
+
+        # Clean up temporary storage
+        del request._tokenized_ids
+        if hasattr(request, "_tokenized_mask"):
+            del request._tokenized_mask
+
     def _preprocess_request(self, request: MLLMBatchRequest) -> None:
         """
         Preprocess a single MLLM request (vision encoding).
@@ -1305,7 +1372,14 @@ def _sample_first_token(req: MLLMBatchRequest, logits: mx.array):
             # running them through the language model alone.
             cached_kv = None
             remaining_ids = None
-            if self.prefix_cache is not None and req.input_ids is not None:
+            if req.pixel_values is not None:
+                # Multimodal request — skip prefix cache entirely.
+                # The VLM forward must run with pixel_values to encode
+                # the vision features into the KV cache. Running the
+                # language model alone on image placeholder tokens
+                # produces garbage output.
+                pass
+            elif self.prefix_cache is not None and req.input_ids is not None:
                 input_ids_list = req.input_ids.reshape(-1).tolist()
                 # Strip think suffix from lookup key so stored entries
                 # (also stripped) match as clean PREFIX.
@@ -1351,62 +1425,61 @@ def _sample_first_token(req: MLLMBatchRequest, logits: mx.array):
                 total_tokens = len(input_ids_list)
                 remaining_count = len(remaining_ids)
 
-                with mx.stream(MLLMBatchGenerator._stream):
-                    step = self.prefill_step_size
-                    if remaining_count <= step:
-                        # Short remaining — process in one shot
-                        self._prefill_progress[req.request_id] = (
-                            total_tokens,
-                            total_tokens,
-                        )
-                        logits = self.language_model(remaining, cache=request_cache)
-                    else:
-                        # Chunked prefill on remaining tokens
-                        self._prefill_progress[req.request_id] = (
-                            cached_count,
-                            total_tokens,
-                        )
-                        processed = 0
-                        chunk_count = 0
-                        while processed + step < remaining_count:
-                            # Check for abort between chunks
-                            if req.request_id in self._aborted_request_ids:
-                                self._aborted_request_ids.discard(req.request_id)
-                                logger.info(
-                                    f"[chunked_prefill] Aborted {req.request_id} "
-                                    f"at {cached_count + processed}/{total_tokens} tokens"
-                                )
-                                raise PrefillAbortedError(req.request_id)
-
-                            chunk = remaining[:, processed : processed + step]
-                            self.language_model(chunk, cache=request_cache)
-                            # Eval ALL cache types (see _run_chunked_text_prefill)
-                            _eval_prompt_cache(request_cache)
-                            processed += step
-                            chunk_count += 1
-                            self._prefill_progress[req.request_id] = (
-                                cached_count + processed,
-                                total_tokens,
+                step = self.prefill_step_size
+                if remaining_count <= step:
+                    # Short remaining — process in one shot
+                    self._prefill_progress[req.request_id] = (
+                        total_tokens,
+                        total_tokens,
+                    )
+                    logits = self.language_model(remaining, cache=request_cache)
+                else:
+                    # Chunked prefill on remaining tokens
+                    self._prefill_progress[req.request_id] = (
+                        cached_count,
+                        total_tokens,
+                    )
+                    processed = 0
+                    chunk_count = 0
+                    while processed + step < remaining_count:
+                        # Check for abort between chunks
+                        if req.request_id in self._aborted_request_ids:
+                            self._aborted_request_ids.discard(req.request_id)
+                            logger.info(
+                                f"[chunked_prefill] Aborted {req.request_id} "
+                                f"at {cached_count + processed}/{total_tokens} tokens"
                             )
-                            if chunk_count % 4 == 0:
-                                mx.clear_cache()
-                        # Last chunk — return logits
-                        remaining = remaining[:, processed:]
-                        logits = self.language_model(remaining, cache=request_cache)
+                            raise PrefillAbortedError(req.request_id)
+
+                        chunk = remaining[:, processed : processed + step]
+                        self.language_model(chunk, cache=request_cache)
+                        # Eval ALL cache types (see _run_chunked_text_prefill)
+                        _eval_prompt_cache(request_cache)
+                        processed += step
+                        chunk_count += 1
                         self._prefill_progress[req.request_id] = (
-                            total_tokens,
+                            cached_count + processed,
                             total_tokens,
                         )
+                        if chunk_count % 4 == 0:
+                            mx.clear_cache()
+                    # Last chunk — return logits
+                    remaining = remaining[:, processed:]
+                    logits = self.language_model(remaining, cache=request_cache)
+                    self._prefill_progress[req.request_id] = (
+                        total_tokens,
+                        total_tokens,
+                    )
 
-                    if hasattr(logits, "logits"):
-                        logits = logits.logits
+                if hasattr(logits, "logits"):
+                    logits = logits.logits
 
-                    last_logits = logits[:, -1, :]
+                last_logits = logits[:, -1, :]
 
-                    sampled, logprobs = _sample_first_token(req, last_logits)
+                sampled, logprobs = _sample_first_token(req, last_logits)
 
-                    first_tokens.append(sampled.item())
-                    all_logprobs.append(logprobs.squeeze(0))
+                first_tokens.append(sampled.item())
+                all_logprobs.append(logprobs.squeeze(0))
 
                 per_request_caches.append(request_cache)
                 req.vision_encoded = True
@@ -1431,17 +1504,16 @@ def _sample_first_token(req: MLLMBatchRequest, logits: mx.array):
                     total_tokens,
                 )
 
-                with mx.stream(MLLMBatchGenerator._stream):
-                    logits = self.language_model(last_token, cache=request_cache)
-                    if hasattr(logits, "logits"):
-                        logits = logits.logits
+                logits = self.language_model(last_token, cache=request_cache)
+                if hasattr(logits, "logits"):
+                    logits = logits.logits
 
-                    last_logits = logits[:, -1, :]
+                last_logits = logits[:, -1, :]
 
-                    sampled, logprobs = _sample_first_token(req, last_logits)
+                sampled, logprobs = _sample_first_token(req, last_logits)
 
-                    first_tokens.append(sampled.item())
-                    all_logprobs.append(logprobs.squeeze(0))
+                first_tokens.append(sampled.item())
+                all_logprobs.append(logprobs.squeeze(0))
 
                 per_request_caches.append(request_cache)
                 req.vision_encoded = True
@@ -1457,23 +1529,22 @@ def _sample_first_token(req: MLLMBatchRequest, logits: mx.array):
                     max_kv_size=self.max_kv_size or None,
                 )
 
-                with mx.stream(MLLMBatchGenerator._stream):
-                    # Text-only: chunked prefill with real progress tracking
-                    # Multimodal: atomic VLM forward (vision encoder needs full input)
-                    if req.is_text_only:
-                        logits = self._run_chunked_text_prefill(
-                            req, cache=request_cache
-                        )
-                    else:
-                        logits = self._run_vision_encoding(req, cache=request_cache)
+                # Text-only: chunked prefill with real progress tracking
+                # Multimodal: atomic VLM forward (vision encoder needs full input)
+                if req.is_text_only:
+                    logits = self._run_chunked_text_prefill(
+                        req, cache=request_cache
+                    )
+                else:
+                    logits = self._run_vision_encoding(req, cache=request_cache)
 
-                    # Extract last token logits
-                    last_logits = logits[:, -1, :]
+                # Extract last token logits
+                last_logits = logits[:, -1, :]
 
-                    sampled, logprobs = _sample_first_token(req, last_logits)
+                sampled, logprobs = _sample_first_token(req, last_logits)
 
-                    first_tokens.append(sampled.item())
-                    all_logprobs.append(logprobs.squeeze(0))
+                first_tokens.append(sampled.item())
+                all_logprobs.append(logprobs.squeeze(0))
 
                 per_request_caches.append(request_cache)
@@ -1871,8 +1942,7 @@ def next(self) -> List[MLLMBatchResponse]:
         Returns:
             List of MLLMBatchResponse, one per active request
         """
-        with mx.stream(MLLMBatchGenerator._stream):
-            return self._next()
+        return self._next()
 
     def stats(self) -> MLLMBatchStats:
         """
diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py
index e812e3394..e38019ba1 100644
--- a/vllm_mlx/mllm_scheduler.py
+++ b/vllm_mlx/mllm_scheduler.py
@@ -35,7 +35,7 @@
     MLLMBatchRequest,
     MLLMBatchResponse,
 )
-from .mlx_streams import bind_generation_streams
+from .mlx_streams import MLXWorkerThread, bind_generation_streams
 from .multimodal_processor import MultimodalProcessor
 from .request import RequestOutput, RequestStatus, SamplingParams
 
@@ -178,6 +178,7 @@ def __init__(
         model: Any,
         processor: Any,
         config: Optional[MLLMSchedulerConfig] = None,
+        mlx_worker: Optional[MLXWorkerThread] = None,
     ):
         """
         Initialize MLLM scheduler.
@@ -186,10 +187,14 @@ def __init__(
             model: The VLM model
             processor: The VLM processor
             config: Scheduler configuration
+            mlx_worker: MLXWorkerThread that owns the model. When provided,
+                ``step()`` runs on this thread so model load and inference
+                share the same thread-local Metal stream context.
         """
         self.model = model
         self.processor = processor
         self.config = config or MLLMSchedulerConfig()
+        self._mlx_worker = mlx_worker
 
         # Get model config
         self.model_config = getattr(model, "config", None)
@@ -814,16 +819,82 @@ async def stop(self) -> None:
     async def _process_loop(self) -> None:
         """Main async processing loop.
 
-        MLLM models are loaded on the server/event-loop thread, so their MLX
-        arrays and cache state must be consumed on that same thread. Unlike
-        the text-only EngineCore path, moving MLLM prefill to a worker crosses
-        MLX stream ownership and can fail with "no Stream in current thread".
+        When an ``MLXWorkerThread`` is provided (the default for
+        BatchedEngine), ``step()`` is submitted to that thread. This
+        is the same thread that loaded the model, so all MLX arrays,
+        caches and Metal streams are consistent — no cross-thread
+        stream errors.
 
-        Text-only preprocessing (Jinja2 template rendering + tokenization) is
-        run BEFORE ``step()`` with ``await asyncio.sleep(0)`` yields between
-        each request. This prevents long preprocessing (10-30+ s for 40K+
-        token conversations) from blocking health checks and new connections.
+        Text-only preprocessing (Jinja2 template rendering + tokenization)
+        is offloaded to the default executor so the event loop stays
+        responsive for health checks and new connections.
+
+        Without an ``MLXWorkerThread`` the loop falls back to running
+        ``step()`` directly on the event loop thread with
+        ``bind_generation_streams()`` (legacy behaviour).
         """
+        loop = asyncio.get_running_loop()
+
+        # ── Worker-thread path (preferred) ───────────────────────────
+        if self._mlx_worker is not None:
+            while self._running:
+                try:
+                    bg = self.batch_generator
+                    if bg is not None:
+                        # Phase 1: CPU-only tokenization on the default
+                        # thread pool. This handles Jinja2 template rendering
+                        # + tokenization without creating MLX arrays, so it
+                        # can safely run on any thread. For 40K+ token
+                        # prompts this takes 10-30s, so offloading keeps the
+                        # event loop responsive.
+                        for req in list(getattr(bg, "unprocessed_requests", ())):
+                            if (
+                                req.input_ids is None
+                                and not req.images
+                                and not req.videos
+                                and not req.audio
+                                and getattr(req, "_tokenized_ids", None) is None
+                            ):
+                                try:
+                                    await loop.run_in_executor(
+                                        None, bg._tokenize_text_only, req
+                                    )
+                                except Exception as e:
+                                    logger.error(
+                                        f"Tokenization failed for "
+                                        f"{req.request_id}: {e}"
+                                    )
+
+                        # Phase 2: Fast mx.array() conversion on the worker
+                        # thread. Converts pre-tokenized Python lists to MLX
+                        # arrays. This is microseconds per request, not a
+                        # bottleneck.
+                        for req in list(getattr(bg, "unprocessed_requests", ())):
+                            if getattr(req, "_tokenized_ids", None) is not None:
+                                try:
+                                    await self._mlx_worker.submit(
+                                        loop, bg._materialize_tokens, req
+                                    )
+                                except Exception as e:
+                                    logger.error(
+                                        f"Token materialization failed for "
+                                        f"{req.request_id}: {e}"
+                                    )
+
+                    if self.has_requests():
+                        await self._mlx_worker.submit(loop, self.step)
+                        # Yield event-loop cycles for HTTP health checks
+                        for _ in range(5):
+                            await asyncio.sleep(0)
+                    else:
+                        await asyncio.sleep(0.01)
+
+                except asyncio.CancelledError:
+                    raise
+                except Exception as e:
+                    logger.error(f"Error in MLLM process loop: {e}", exc_info=True)
+                    await asyncio.sleep(0.1)
+            return
+
+        # ── Legacy event-loop path (no worker thread) ────────────────
         streams_bound = False
 
         def _ensure_streams_bound() -> None:
@@ -832,17 +903,8 @@ def _ensure_streams_bound() -> None:
                 bind_generation_streams()
                 streams_bound = True
 
-        loop = asyncio.get_running_loop()
-
         while self._running:
             try:
-                # --- Early preprocessing phase ---
-                # Run text-only preprocessing (Jinja2 template rendering +
-                # tokenization) in a thread-pool executor so the event loop
-                # stays responsive for health checks, new connections, and
-                # active streaming requests. Preprocessing is CPU-bound
-                # (no MLX GPU work) and HuggingFace tokenizers are
-                # thread-safe, so this is safe to offload.
                 bg = self.batch_generator
                 if bg is not None:
                     for req in list(getattr(bg, "unprocessed_requests", ())):
@@ -874,7 +936,6 @@ def _ensure_streams_bound() -> None:
                                     f"{req.request_id}: {e}"
                                 )
 
-                # --- Step phase ---
                 if self.has_requests():
                     _ensure_streams_bound()
                     tic = time.perf_counter()
@@ -886,20 +947,10 @@ def _ensure_streams_bound() -> None:
                             f"(waiting={len(self.waiting)}, "
                             f"running={len(self.running)})"
                         )
-                        # Yield multiple event-loop cycles so that pending
-                        # HTTP health checks can complete. A single
-                        # asyncio.sleep() gives only ONE _run_once() cycle,
-                        # but an HTTP request needs ~3 cycles minimum:
-                        #   1. accept TCP connection
-                        #   2. read HTTP request / parse headers
-                        #   3. run handler / write response
-                        # Using repeated asyncio.sleep(0) gives many cycles
-                        # with negligible wall-clock overhead (<1ms total).
                         n_yields = 10 if elapsed > 1.0 else 5
                         for _ in range(n_yields):
                             await asyncio.sleep(0)
                 else:
-                    # No work, wait a bit
                     await asyncio.sleep(0.01)
 
             except asyncio.CancelledError:
diff --git a/vllm_mlx/mlx_streams.py b/vllm_mlx/mlx_streams.py
index d7ac5fb35..3d792407e 100644
--- a/vllm_mlx/mlx_streams.py
+++ b/vllm_mlx/mlx_streams.py
@@ -1,12 +1,38 @@
 # SPDX-License-Identifier: Apache-2.0
-"""Helpers for binding MLX generation streams to worker threads."""
+"""Helpers for running MLX operations on a dedicated worker thread.
 
+MLX >= 0.31.2 (PR #3348) makes Metal CommandEncoders thread-local. Arrays
+created on one thread carry a stream index that does not exist in other
+threads' TLS. This means a model loaded on thread A cannot be used for
+inference on thread B — the runtime raises:
+
+    RuntimeError: There is no Stream(gpu, N) in current thread
+
+The fundamental fix is to ensure model loading AND all inference happen on the
+**same** persistent thread. ``MLXWorkerThread`` provides this guarantee by
+running a dedicated daemon thread with a simple task queue.
+
+``bind_generation_streams`` is a legacy helper kept for the SimpleEngine path
+where a single persistent ``_MLXWorkerThread`` already owns model load +
+inference but still needs to rebind module-level stream references.
+"""
+
+from __future__ import annotations
+
+import asyncio
 import importlib
+import logging
+import queue
 import threading
-from collections.abc import Iterable
+from collections.abc import Awaitable, Iterable
+from typing import Any, Callable, TypeVar
 
 import mlx.core as mx
 
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
 # Serialize stream rebinding so module-level generation_stream references are
 # updated atomically across concurrent engine threads.
 _STREAM_REBIND_LOCK = threading.Lock()
@@ -20,6 +46,10 @@ def bind_generation_streams(
     MLX streams are thread-local. If a model is loaded on one thread and
     generation runs on another, module-level generation streams created during
     import can point at a stream that does not exist in the worker thread.
+
+    .. deprecated::
+        Prefer ``MLXWorkerThread`` which keeps model load and inference on the
+        same persistent thread, avoiding the need to rebind streams entirely.
     """
     with _STREAM_REBIND_LOCK:
         default_stream = mx.new_stream(mx.default_device())
@@ -32,3 +62,79 @@ def bind_generation_streams(
         if hasattr(module, "generation_stream"):
             setattr(module, "generation_stream", default_stream)
     return default_stream
+
+
+class MLXWorkerThread:
+    """Persistent single-threaded executor for MLX/Metal operations.
+
+    All submitted callables run on the **same** OS thread for the lifetime of
+    the process. This guarantees that:
+
+    1. Any stream the worker creates (e.g. via ``mx.new_thread_local_stream``)
+       remains valid for every later call executed on this thread.
+    2. Model loading and inference share the same thread-local stream context.
+    3. No ``RuntimeError: There is no Stream(gpu, N)`` can occur.
+
+    Usage::
+
+        worker = MLXWorkerThread()
+
+        # From an async context:
+        loop = asyncio.get_running_loop()
+        result = await worker.submit(loop, heavy_mlx_function, arg1, arg2)
+
+        # Shutdown (optional, thread is daemon):
+        worker.shutdown()
+    """
+
+    def __init__(self, name: str = "mlx-worker") -> None:
+        self._task_queue: queue.SimpleQueue = queue.SimpleQueue()
+        self._thread = threading.Thread(
+            target=self._run, name=name, daemon=True
+        )
+        self._thread.start()
+        logger.debug("MLXWorkerThread '%s' started (tid=%d)", name, self._thread.ident)
+
+    def _run(self) -> None:
+        """Worker loop: pull (fn, args, kwargs, future) tuples and execute."""
+
+        def _resolve(fut: asyncio.Future, result: Any, exc: BaseException | None) -> None:
+            # Runs on the future's event loop. Skip futures cancelled while
+            # the worker was busy; resolving them would raise
+            # InvalidStateError.
+            if fut.cancelled():
+                return
+            if exc is not None:
+                fut.set_exception(exc)
+            else:
+                fut.set_result(result)
+
+        while True:
+            item = self._task_queue.get()
+            if item is None:
+                break
+            fn, args, kwargs, fut = item
+            try:
+                result = fn(*args, **kwargs)
+                fut.get_loop().call_soon_threadsafe(_resolve, fut, result, None)
+            except BaseException as exc:
+                fut.get_loop().call_soon_threadsafe(_resolve, fut, None, exc)
+
+    def submit(
+        self,
+        loop: asyncio.AbstractEventLoop,
+        fn: Callable[..., T],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Awaitable[T]:
+        """Submit a callable to the worker thread, returning an awaitable Future.
+
+        Args:
+            loop: The asyncio event loop to create the Future on.
+            fn: The callable to execute on the worker thread.
+            *args: Positional arguments for ``fn``.
+            **kwargs: Keyword arguments for ``fn``.
+
+        Returns:
+            An ``asyncio.Future`` that resolves with ``fn``'s return value or
+            raises its exception.
+        """
+        fut = loop.create_future()
+        self._task_queue.put((fn, args, kwargs, fut))
+        return fut
+
+    def shutdown(self) -> None:
+        """Signal the worker thread to exit (best-effort, non-blocking)."""
+        self._task_queue.put(None)
+
+    @property
+    def is_alive(self) -> bool:
+        """Whether the worker thread is running."""
+        return self._thread.is_alive()
diff --git a/vllm_mlx/tool_parsers/gemma4_tool_parser.py b/vllm_mlx/tool_parsers/gemma4_tool_parser.py
index af1aeca6d..160e2af1b 100644
--- a/vllm_mlx/tool_parsers/gemma4_tool_parser.py
+++ b/vllm_mlx/tool_parsers/gemma4_tool_parser.py
@@ -54,6 +54,10 @@
 _BARE_VALUE = re.compile(r"(?<=[:\[,])(\s*)([A-Za-z_][\w\-]*)(?=\s*[,}\]])")
 _JSON_LITERALS = frozenset({"true", "false", "null"})
 
+# Pattern to convert single-quoted strings to the <|"|>-delimited form.
+# Quantized models sometimes emit 'value' instead of <|"|>value<|"|>.
+_SINGLE_QUOTED = re.compile(r"'([^']*)'")
+
 # Max arg block length to prevent runaway parsing on malformed input (1 MB)
 _MAX_ARG_BLOCK_LEN = 1_048_576
 
@@ -123,6 +127,10 @@ def _capture(m: re.Match) -> str:
         strings.append(m.group(1))
         return f"\x00{len(strings) - 1}\x00"
 
+    # Step 0.5: Convert single-quoted strings to <|"|>-delimited so Step 1
+    # captures them. Quantized models may emit 'value' instead of
+    # <|"|>value<|"|>; assumes no stray apostrophes outside strings.
+    text = _SINGLE_QUOTED.sub(lambda m: '<|"|>' + m.group(1) + '<|"|>', text)
+
     # Step 1: Extract <|"|>-delimited strings
     text = _STRING_DELIM_RE.sub(_capture, text)
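
A minimal end-to-end sketch of the worker-thread contract these changes
introduce, assuming only names visible in this diff; ``load_model`` and
``step_once`` are hypothetical stand-ins for ``MLXMultimodalLM.load()`` and
``MLLMScheduler.step()``:

    import asyncio

    from vllm_mlx.mlx_streams import MLXWorkerThread

    async def main() -> None:
        loop = asyncio.get_running_loop()
        worker = MLXWorkerThread(name="mllm-worker")

        # Load on the worker thread: every array and stream the model
        # creates is bound to that thread.
        model = await worker.submit(loop, load_model)

        # Route all later MLX work through the same worker; calling into
        # the model from another thread would hit the thread-local
        # stream error described in mlx_streams.py.
        await worker.submit(loop, step_once, model)

        worker.shutdown()

    asyncio.run(main())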