diff --git a/tests/test_reasoning_parser.py b/tests/test_reasoning_parser.py
index e2d0184e7..57a6b4291 100644
--- a/tests/test_reasoning_parser.py
+++ b/tests/test_reasoning_parser.py
@@ -13,6 +13,7 @@
 from vllm_mlx.reasoning import (
     DeltaMessage,
+    GLM4ReasoningParser,
     ReasoningParser,
     get_parser,
     list_parsers,
@@ -28,6 +29,7 @@ def test_list_parsers_includes_builtin(self):
         parsers = list_parsers()
         assert "qwen3" in parsers
         assert "deepseek_r1" in parsers
+        assert "glm4" in parsers
 
     def test_get_parser_qwen3(self):
         """Should be able to get Qwen3 parser."""
@@ -41,6 +43,12 @@ def test_get_parser_deepseek(self):
         parser = parser_cls()
         assert isinstance(parser, ReasoningParser)
 
+    def test_get_parser_glm4(self):
+        """Should be able to get GLM4 parser."""
+        parser_cls = get_parser("glm4")
+        parser = parser_cls()
+        assert isinstance(parser, ReasoningParser)
+
     def test_get_unknown_parser_raises(self):
         """Unknown parser name should raise KeyError."""
         with pytest.raises(KeyError) as exc_info:
@@ -914,8 +922,7 @@ def test_streaming_constrain_format(self, parser):
     def test_constrain_tokens_stripped(self, parser):
         """<|constrain|> should not leak into output."""
         output = (
-            "<|channel|>final <|constrain|>JSON<|message|>"
-            '{"hello":"world"}<|return|>'
+            '<|channel|>final <|constrain|>JSON<|message|>{"hello":"world"}<|return|>'
         )
         reasoning, content = parser.extract_reasoning(output)
         assert "<|constrain|>" not in (content or "")
diff --git a/tests/test_tool_parsers.py b/tests/test_tool_parsers.py
index dfe2bb6a1..c22b52d66 100644
--- a/tests/test_tool_parsers.py
+++ b/tests/test_tool_parsers.py
@@ -9,6 +9,7 @@
     AutoToolParser,
     DeepSeekToolParser,
     FunctionaryToolParser,
+    Glm47ToolParser,
     GraniteToolParser,
     HermesToolParser,
     KimiToolParser,
@@ -811,9 +812,7 @@ def test_qwen3_coder_multiline_parameter(self):
     def test_bare_function_without_tool_call_wrapper(self):
         """Test bare <function> blocks without <tool_call> wrapper."""
         parser = HermesToolParser()
-        text = (
-            "<function=get_weather>" "<parameter=location>Berlin</parameter>" "</function>"
-        )
+        text = "<function=get_weather><parameter=location>Berlin</parameter></function>"
         result = parser.extract_tool_calls(text)
 
         assert result.tools_called
@@ -1160,3 +1159,59 @@ def test_streaming_bare_multi_function_blocks(self):
         assert len(emitted_calls) == 2
         assert emitted_calls[0]["function"]["name"] == "func1"
         assert emitted_calls[1]["function"]["name"] == "func2"
+
+
+class TestGLM47ToolParser:
+    """Tests for GLM47 tool parser."""
+
+    def test_zero_arguments_tool_call(self):
+        """Test Fix 2: Handle zero-argument tool calls without crashing."""
+        parser = Glm47ToolParser()
+
+        output = "<tool_call>get_current_time</tool_call>"
+
+        result = parser.extract_tool_calls(output)
+
+        assert result.tools_called is True
+        assert len(result.tool_calls) == 1
+        assert result.tool_calls[0]["name"] == "get_current_time"
+        args = json.loads(result.tool_calls[0]["arguments"])
+        assert args == {}
+
+    def test_with_arguments(self):
+        """Test tool call with arguments."""
+        parser = Glm47ToolParser()
+
+        output = "<tool_call>search\n<arg_key>query</arg_key><arg_value>Python</arg_value></tool_call>"
+
+        result = parser.extract_tool_calls(output)
+
+        assert result.tools_called is True
+        assert len(result.tool_calls) == 1
+        assert result.tool_calls[0]["name"] == "search"
+        args = json.loads(result.tool_calls[0]["arguments"])
+        assert args["query"] == "Python"
+
+    def test_streaming_zero_args(self):
+        """Test Fix 2: Streaming with zero-argument tool call."""
+        parser = Glm47ToolParser()
+
+        chunks = ["<tool_call>", "get_status", "</tool_call>"]
+        accumulated = ""
+        tool_calls_found = False
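+
+        # Drive the parser the way the server does when streaming:
+        # previous_text/current_text grow monotonically while delta_text
+        # carries only the newly generated fragment.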
+        for chunk in chunks:
+            prev = accumulated
+            accumulated += chunk
+            r = parser.extract_tool_calls_streaming(
+                previous_text=prev,
+                current_text=accumulated,
+                delta_text=chunk,
+            )
+            if r is not None and "tool_calls" in r:
+                tool_calls_found = True
+                assert r["tool_calls"][0]["function"]["name"] == "get_status"
+                args = json.loads(r["tool_calls"][0]["function"]["arguments"])
+                assert args == {}
+
+        assert tool_calls_found, "Zero-argument tool call should have been detected"
diff --git a/vllm_mlx/reasoning/__init__.py b/vllm_mlx/reasoning/__init__.py
index f138796ff..51527ef26 100644
--- a/vllm_mlx/reasoning/__init__.py
+++ b/vllm_mlx/reasoning/__init__.py
@@ -25,6 +25,7 @@
 """
 
 from .base import DeltaMessage, ReasoningParser
+from .glm4_parser import GLM4ReasoningParser
 from .think_parser import BaseThinkingReasoningParser
 
 # Parser registry
@@ -76,10 +77,12 @@ def list_parsers() -> list[str]:
 def _register_builtin_parsers():
     """Register built-in parsers."""
     from .deepseek_r1_parser import DeepSeekR1ReasoningParser
+    from .glm4_parser import GLM4ReasoningParser
     from .gpt_oss_parser import GptOssReasoningParser
     from .harmony_parser import HarmonyReasoningParser
     from .qwen3_parser import Qwen3ReasoningParser
 
+    register_parser("glm4", GLM4ReasoningParser)
     register_parser("qwen3", Qwen3ReasoningParser)
     register_parser("deepseek_r1", DeepSeekR1ReasoningParser)
     register_parser("gpt_oss", GptOssReasoningParser)
@@ -99,4 +102,8 @@ def _register_builtin_parsers():
     "register_parser",
     "get_parser",
     "list_parsers",
+    # Built-in parsers
+    "GLM4ReasoningParser",
+    "Qwen3ReasoningParser",
+    "DeepSeekR1ReasoningParser",
 ]
diff --git a/vllm_mlx/reasoning/glm4_parser.py b/vllm_mlx/reasoning/glm4_parser.py
new file mode 100644
index 000000000..9dec757d4
--- /dev/null
+++ b/vllm_mlx/reasoning/glm4_parser.py
@@ -0,0 +1,75 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Reasoning parser for GLM4 models.
+
+GLM4 uses <think>...</think> tags for reasoning content. The GLM-4.7 chat
+template injects <think> directly into the prompt, so the model never
+outputs it natively; it only outputs </think> to end thinking.
+
+This breaks parsers that expect both tags, so GLM needs special handling.
+"""
+
+from .think_parser import BaseThinkingReasoningParser
+
+
+class GLM4ReasoningParser(BaseThinkingReasoningParser):
+    """
+    Reasoning parser for GLM4 models.
+
+    Handles the GLM-specific case where:
+    1. The chat template injects <think> into the prompt
+    2. The model starts its output already "in reasoning"
+    3. The model only outputs </think> to end thinking
+
+    This is different from Qwen3, where both tags may appear in the output.
+
+    Example:
+        Model output: "Let me analyze this...</think>The answer is 42."
+        Output: reasoning="Let me analyze this...", content="The answer is 42."
+    """
+
+    @property
+    def start_token(self) -> str:
+        return "<think>"
+
+    @property
+    def end_token(self) -> str:
+        return "</think>"
+
+    def extract_reasoning(
+        self,
+        model_output: str,
+    ) -> tuple[str | None, str | None]:
+        """
+        Extract reasoning from GLM4 output.
+
+        GLM4 typically only outputs </think> (not <think>) because the start
+        token was injected into the prompt by the chat template.
+
+        Args:
+            model_output: Complete model output text.
+
+        Returns:
+            (reasoning, content) tuple.
+        """
+        text = model_output
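+
+        # Illustrative walk-through: for the typical GLM output
+        #   "Let me think.</think>The answer is 42."
+        # only Case 2 below matches, yielding
+        #   reasoning="Let me think.", content="The answer is 42."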
+
+        # Case 1: Both tags present (rare, but handle it)
+        if self.start_token in text and self.end_token in text:
+            _, _, after_start = text.partition(self.start_token)
+            reasoning, _, content = after_start.partition(self.end_token)
+            return reasoning.strip() or None, content.strip() or None
+
+        # Case 2: Only closing tag (most common for GLM)
+        # Model was already "in reasoning" due to prompt injection
+        if self.end_token in text:
+            reasoning, _, content = text.partition(self.end_token)
+            return reasoning.strip() or None, content.strip() or None
+
+        # Case 3: Only start tag (reasoning in progress)
+        if self.start_token in text:
+            _, _, reasoning = text.partition(self.start_token)
+            return reasoning.strip() or None, None
+
+        # Case 4: No tags - pure content (thinking disabled)
+        return None, model_output
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index a0038d5f8..65eab557c 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -1982,13 +1982,11 @@ async def stream_chat_completion(
             delta_text = output.new_text
             last_output = output
 
-            # Track token counts from output (updated each chunk)
-            if hasattr(output, "prompt_tokens") and output.prompt_tokens:
-                prompt_tokens = output.prompt_tokens
-            if hasattr(output, "completion_tokens") and output.completion_tokens:
-                completion_tokens = output.completion_tokens
+            # Track content and reasoning for this chunk
+            content = None
+            reasoning = None
 
-            # Use reasoning parser if enabled
+            # 1. Use reasoning parser if enabled
             if _reasoning_parser and delta_text:
                 previous_text = accumulated_text
                 accumulated_text += delta_text
@@ -1996,90 +1994,113 @@ async def stream_chat_completion(
                     previous_text, accumulated_text, delta_text
                 )
 
-                if delta_msg is None:
-                    # Skip this chunk (e.g., </think> token itself)
-                    continue
-
-                chunk = ChatCompletionChunk(
-                    id=response_id,
-                    model=_model_name,
-                    choices=[
-                        ChatCompletionChunkChoice(
-                            delta=ChatCompletionChunkDelta(
-                                content=delta_msg.content,
-                                reasoning=delta_msg.reasoning,
-                            ),
-                            finish_reason=output.finish_reason if output.finished else None,
-                        )
-                    ],
-                    usage=get_usage(output) if output.finished else None,
-                )
-                yield f"data: {chunk.model_dump_json()}\n\n"
+                if delta_msg:
+                    content = delta_msg.content
+                    reasoning = delta_msg.reasoning
             else:
-                # Standard path without reasoning parsing
+                # Standard path: if no reasoning parser, use delta_text as content
                 content = delta_text
 
-                # Filter special tokens that may leak into streaming output
-                if content:
-                    content = SPECIAL_TOKENS_PATTERN.sub("", content)
-
-                # Add <think> prefix on first content chunk for thinking models
-                if is_thinking_model and not think_prefix_sent and content:
-                    content = "<think>" + content
-                    think_prefix_sent = True
-
-                # Tool call streaming parsing
-                if tool_parser and delta_text:
-                    # Fast path: skip full parsing until '<' is seen in the stream,
-                    # which could start tool markup (e.g. <tool_call>). This avoids
-                    # per-token string scanning on the growing accumulated text.
-                    if not tool_markup_possible and "<" not in delta_text:
-                        tool_accumulated_text += delta_text
-                        # No tool markup yet, fall through to normal chunk emission
-                    else:
-                        if not tool_markup_possible:
-                            tool_markup_possible = True
-                        tool_previous = tool_accumulated_text
-                        tool_accumulated_text += delta_text
-                        tool_result = tool_parser.extract_tool_calls_streaming(
-                            tool_previous, tool_accumulated_text, delta_text
-                        )
-
-                        if tool_result is None:
-                            # Inside tool markup - suppress output
-                            continue
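+            # Steps 2-5 below apply to `content` whether it came from the
+            # reasoning parser (step 1) or directly from delta_text.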
+            # 2. Filter special tokens
+            if content:
+                content = SPECIAL_TOKENS_PATTERN.sub("", content)
+
+            # 3. Add <think> prefix on first content chunk for thinking models
+            if is_thinking_model and not think_prefix_sent and content:
+                content = "<think>" + content
+                think_prefix_sent = True
+
+            # 4. Tool call streaming parsing
+            if tool_parser and delta_text:
+                # Fast path: skip full parsing until '<' is seen
+                if not tool_markup_possible and "<" not in delta_text:
+                    tool_accumulated_text += delta_text
+                    # No tool markup yet, content remains as is
+                else:
+                    if not tool_markup_possible:
+                        tool_markup_possible = True
+                    tool_previous = tool_accumulated_text
+                    tool_accumulated_text += delta_text
+                    tool_result = tool_parser.extract_tool_calls_streaming(
+                        tool_previous, tool_accumulated_text, delta_text
+                    )
+
+                    if tool_result is None:
+                        # Inside tool markup - suppress current output.
+                        # But if we have reasoning or content from step 1,
+                        # we must emit it first.
+                        if content or reasoning:
+                            chunk = ChatCompletionChunk(
+                                id=response_id,
+                                model=_model_name,
+                                choices=[
+                                    ChatCompletionChunkChoice(
+                                        delta=ChatCompletionChunkDelta(
+                                            content=content if content else None,
+                                            reasoning=reasoning if reasoning else None,
+                                        ),
+                                        finish_reason=output.finish_reason
+                                        if output.finished
+                                        else None,
+                                    )
+                                ],
+                                usage=get_usage(output) if output.finished else None,
+                            )
+                            yield f"data: {chunk.model_dump_json()}\n\n"
+                        continue
 
-                        if "tool_calls" in tool_result:
-                            # Emit structured tool calls
-                            tool_calls_detected = True
-                            chunk = ChatCompletionChunk(
-                                id=response_id,
-                                model=_model_name,
-                                choices=[
-                                    ChatCompletionChunkChoice(
-                                        delta=ChatCompletionChunkDelta(
-                                            tool_calls=tool_result["tool_calls"]
-                                        ),
-                                        finish_reason=(
-                                            "tool_calls" if output.finished else None
-                                        ),
-                                    )
-                                ],
-                                usage=get_usage(output) if output.finished else None,
-                            )
-                            yield f"data: {chunk.model_dump_json()}\n\n"
-                            continue
+                    if "tool_calls" in tool_result:
+                        # Emit structured tool calls.
+                        # If we had content/reasoning, emit that first as well.
+                        if content or reasoning:
+                            chunk = ChatCompletionChunk(
+                                id=response_id,
+                                model=_model_name,
+                                choices=[
+                                    ChatCompletionChunkChoice(
+                                        delta=ChatCompletionChunkDelta(
+                                            content=content if content else None,
+                                            reasoning=reasoning if reasoning else None,
+                                        ),
+                                        finish_reason=None,
+                                    )
+                                ],
+                            )
+                            yield f"data: {chunk.model_dump_json()}\n\n"
+                            content = None
+                            reasoning = None
+
+                        tool_calls_detected = True
+                        chunk = ChatCompletionChunk(
+                            id=response_id,
+                            model=_model_name,
+                            choices=[
+                                ChatCompletionChunkChoice(
+                                    delta=ChatCompletionChunkDelta(
+                                        tool_calls=tool_result["tool_calls"]
+                                    ),
+                                    finish_reason=(
+                                        "tool_calls" if output.finished else None
+                                    ),
+                                )
+                            ],
+                            usage=get_usage(output) if output.finished else None,
+                        )
+                        yield f"data: {chunk.model_dump_json()}\n\n"
+                        continue
 
-                    # Normal content from tool parser
-                    content = tool_result.get("content", "")
+                    # Normal content from tool parser
+                    content = tool_result.get("content", "")
 
+            # 5. Emit chunk if there's content or reasoning
+            if content or reasoning:
                 chunk = ChatCompletionChunk(
                     id=response_id,
                     model=_model_name,
                     choices=[
                         ChatCompletionChunkChoice(
                             delta=ChatCompletionChunkDelta(
-                                content=content if content else None
+                                content=content if content else None,
+                                reasoning=reasoning if reasoning else None,
                             ),
                             finish_reason=(
                                 "tool_calls"
diff --git a/vllm_mlx/tool_parsers/glm47_tool_parser.py b/vllm_mlx/tool_parsers/glm47_tool_parser.py
index fc8238be9..0656e1851 100644
--- a/vllm_mlx/tool_parsers/glm47_tool_parser.py
+++ b/vllm_mlx/tool_parsers/glm47_tool_parser.py
@@ -10,6 +10,7 @@
 import re
 import uuid
 from collections.abc import Sequence
+from enum import Enum
 from typing import Any
 
 from .abstract_tool_parser import (
@@ -19,6 +20,15 @@
 )
 
 
+class ParserState(Enum):
+    IDLE = "idle"
+    PARSING_NAME = "parsing_name"
+    PARSING_ARGUMENTS = "parsing_arguments"
+    PARSING_KEY = "parsing_key"
+    WAITING_FOR_VALUE = "waiting_for_value"
+    PARSING_VALUE = "parsing_value"
+
+
 def generate_tool_id() -> str:
     """Generate a unique tool call ID."""
     return f"call_{uuid.uuid4().hex[:8]}"
@@ -52,6 +62,27 @@ class Glm47ToolParser(ToolParser):
         r"<tool_call>\s*([^<\n]*)\s*(.*?)\s*</tool_call>", re.DOTALL
     )
 
+    def __init__(self, tokenizer: Any | None = None):
+        super().__init__(tokenizer)
+        self.last_parsed_index = 0
+        self.state = ParserState.IDLE
+        self.current_tool_id = generate_tool_id()
+        self.current_tool_name = ""
+        self.current_arg_key = ""
+        self.is_first_arg = True
+        self.tool_call_index = 0
+
+    def reset(self) -> None:
+        """Reset parser state for a new request."""
+        super().reset()
+        self.last_parsed_index = 0
+        self.state = ParserState.IDLE
+        self.current_tool_id = generate_tool_id()
+        self.current_tool_name = ""
+        self.current_arg_key = ""
+        self.is_first_arg = True
+        self.tool_call_index = 0
+
     def _deserialize(self, value: str) -> Any:
         """Convert string value to appropriate Python type.
 
@@ -95,6 +126,9 @@ def extract_tool_calls(
         for match in matches:
             func_name = match[0].strip() if match[0] else ""
             args_section = match[1] if len(match) > 1 and match[1] else ""
+            # Fix 2: Handle zero-argument tool calls - coalesce None to empty string
+            if args_section is None:
+                args_section = ""
 
             if not func_name:
                 continue
@@ -145,38 +179,364 @@ def extract_tool_calls_streaming(
         request: dict[str, Any] | None = None,
     ) -> dict[str, Any] | None:
         """
-        Extract tool calls from streaming GLM-4.7 model output.
+        Extract tool calls from streaming GLM-4.7 model output using a state machine.
         """
-        # Skip thinking content in streaming
+        # 1. Skip thinking content (defensive)
         if "<think>" in current_text and "</think>" not in current_text:
+            self.last_parsed_index = len(current_text)
             return None
 
-        # Once <tool_call> is detected, buffer everything until it closes.
-        # Do NOT emit content deltas here, because if tool calls are found
-        # the non-streaming path sets content=None (reasoning before the
-        # <tool_call> tag should not leak as regular content).
-        if "<tool_call>" in current_text:
-            if "</tool_call>" in delta_text:
-                result = self.extract_tool_calls(current_text, request)
-                if result.tools_called:
-                    return {
-                        "tool_calls": [
-                            {
-                                "index": i,
-                                "id": tc["id"],
-                                "type": "function",
-                                "function": {
-                                    "name": tc["name"],
-                                    "arguments": tc["arguments"],
-                                },
-                            }
-                            for i, tc in enumerate(result.tool_calls)
-                        ]
-                    }
-            return None
-
-        # No tool call detected yet; strip think tags and emit content
-        clean_delta = self.strip_think_tags(delta_text)
-        if clean_delta:
-            return {"content": clean_delta}
-        return None
+        # 2. State Machine
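+        # Expected stream shape (GLM-4.7):
+        #   <tool_call>NAME\n<arg_key>K</arg_key><arg_value>V</arg_value></tool_call>
+        # Transitions: IDLE -> PARSING_NAME -> PARSING_ARGUMENTS
+        #   -> (PARSING_KEY -> WAITING_FOR_VALUE -> PARSING_VALUE)* -> IDLE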
+        # We use a while loop to handle multiple transitions in one chunk
+        while self.last_parsed_index < len(current_text):
+            unparsed = current_text[self.last_parsed_index :]
+
+            if self.state == ParserState.IDLE:
+                start_tag = "<tool_call>"
+                start_pos = current_text.find(start_tag, self.last_parsed_index)
+                if start_pos != -1:
+                    # Found start of tool call
+                    self.state = ParserState.PARSING_NAME
+                    self.last_parsed_index = start_pos + len(start_tag)
+                    self.current_tool_id = generate_tool_id()
+                    self.current_tool_name = ""
+                    self.is_first_arg = True
+                    continue  # Try next state
+                else:
+                    # Still IDLE, emit delta as content if no partial tag.
+                    # Check for a partial <tool_call> at the very end.
+                    potential = "<tool_call>"
+                    for i in range(len(potential) - 1, 0, -1):
+                        if current_text.endswith(potential[:i]):
+                            return None  # Buffer partial tag
+
+                    # No tag, emit content
+                    self.last_parsed_index = len(current_text)
+                    return {"content": delta_text}
+
+            if self.state == ParserState.PARSING_NAME:
+                # End of name markers
+                nl_pos = current_text.find("\n", self.last_parsed_index)
+                ak_pos = current_text.find("<arg_key>", self.last_parsed_index)
+                tc_end_pos = current_text.find("</tool_call>", self.last_parsed_index)
+
+                markers = [m for m in [nl_pos, ak_pos, tc_end_pos] if m != -1]
+                if not markers:
+                    # Still parsing name
+                    name_chunk = unparsed.strip()
+                    if name_chunk:
+                        self.current_tool_name += name_chunk
+                    self.last_parsed_index = len(current_text)
+                    return None
+
+                # Found end of name
+                early = min(markers)
+                name_part = current_text[self.last_parsed_index : early].strip()
+                self.current_tool_name += name_part
+
+                if early == tc_end_pos:
+                    # Fix 2: Zero-argument tool call - emit one complete delta
+                    # with empty arguments instead of a dangling "{" header.
+                    self.state = ParserState.IDLE
+                    self.last_parsed_index = tc_end_pos + len("</tool_call>")
+                    index = self.tool_call_index
+                    self.tool_call_index += 1
+                    return {
+                        "tool_calls": [
+                            {
+                                "index": index,
+                                "id": self.current_tool_id,
+                                "type": "function",
+                                "function": {
+                                    "name": self.current_tool_name,
+                                    "arguments": "{}",
+                                },
+                            }
+                        ]
+                    }
+
+                self.last_parsed_index = early
+                self.state = ParserState.PARSING_ARGUMENTS
+
+                # Emit the first chunk of tool call (header)
+                return {
+                    "tool_calls": [
+                        {
+                            "index": self.tool_call_index,
+                            "id": self.current_tool_id,
+                            "type": "function",
+                            "function": {
+                                "name": self.current_tool_name,
+                                "arguments": "{",
+                            },
+                        }
+                    ]
+                }
+
+            if self.state == ParserState.PARSING_ARGUMENTS:
+                # Look for <arg_key> or </tool_call>
+                ak_pos = current_text.find("<arg_key>", self.last_parsed_index)
+                tc_end_pos = current_text.find("</tool_call>", self.last_parsed_index)
+
+                if ak_pos != -1 and (tc_end_pos == -1 or ak_pos < tc_end_pos):
+                    # Found next argument
+                    self.state = ParserState.PARSING_KEY
+                    self.last_parsed_index = ak_pos + len("<arg_key>")
+                    self.current_arg_key = ""
+                    continue
+                elif tc_end_pos != -1:
+                    # Tool call finished!
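+                    # The closing "}" completes the JSON object streamed since
+                    # the header's opening "{"; clients are assumed to
+                    # concatenate the arguments fragments in order.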
+                    self.state = ParserState.IDLE
+                    self.last_parsed_index = tc_end_pos + len("</tool_call>")
+                    self.tool_call_index += 1
+                    return {
+                        "tool_calls": [
+                            {
+                                "index": self.tool_call_index - 1,
+                                "id": self.current_tool_id,
+                                "function": {"arguments": "}"},
+                            }
+                        ]
+                    }
+                else:
+                    # Wait for more
+                    self.last_parsed_index = len(current_text)
+                    return None
+
+            if self.state == ParserState.PARSING_KEY:
+                ak_end_tag = "</arg_key>"
+                ak_end_pos = current_text.find(ak_end_tag, self.last_parsed_index)
+                if ak_end_pos != -1:
+                    # Key captured
+                    self.current_arg_key += current_text[
+                        self.last_parsed_index : ak_end_pos
+                    ].strip()
+                    self.last_parsed_index = ak_end_pos + len(ak_end_tag)
+                    self.state = ParserState.WAITING_FOR_VALUE
+                    continue  # Transition to WAITING_FOR_VALUE in same chunk
+                else:
+                    # Still capturing key; hold back any tail that could be
+                    # the start of a split </arg_key> tag
+                    key_chunk = unparsed
+                    for i in range(len(ak_end_tag) - 1, 0, -1):
+                        if key_chunk.endswith(ak_end_tag[:i]):
+                            key_chunk = key_chunk[:-i]
+                            break
+                    self.current_arg_key += key_chunk.strip()
+                    self.last_parsed_index += len(key_chunk)
+                    return None
+
+            if self.state == ParserState.WAITING_FOR_VALUE:
+                av_tag = "<arg_value>"
+                av_pos = current_text.find(av_tag, self.last_parsed_index)
+                if av_pos != -1:
+                    self.state = ParserState.PARSING_VALUE
+                    self.last_parsed_index = av_pos + len(av_tag)
+
+                    # Emit JSON prefix for this argument
+                    prefix = "" if self.is_first_arg else ", "
+                    self.is_first_arg = False
+                    return {
+                        "tool_calls": [
+                            {
+                                "index": self.tool_call_index,
+                                "id": self.current_tool_id,
+                                "function": {
+                                    "arguments": f'{prefix}"{self.current_arg_key}": "'
+                                },
+                            }
+                        ]
+                    }
+                else:
+                    # Buffer if tail looks like start of <arg_value>
+                    for i in range(len(av_tag) - 1, 0, -1):
+                        if current_text.endswith(av_tag[:i]):
+                            return None
+                    # Otherwise keep waiting
+                    self.last_parsed_index = len(current_text)
+                    return None
+
+            if self.state == ParserState.PARSING_VALUE:
+                av_end_tag = "</arg_value>"
+                av_end_pos = current_text.find(av_end_tag, self.last_parsed_index)
+                if av_end_pos != -1:
+                    # Value finished
+                    val_chunk = current_text[self.last_parsed_index : av_end_pos]
+                    self.last_parsed_index = av_end_pos + len(av_end_tag)
+                    self.state = ParserState.PARSING_ARGUMENTS
+
+                    return {
+                        "tool_calls": [
+                            {
+                                "index": self.tool_call_index,
+                                "id": self.current_tool_id,
+                                "function": {"arguments": f'{val_chunk}"'},
+                            }
+                        ]
+                    }
+                else:
+                    # Yield the value chunk incrementally, holding back any
+                    # tail that could be the start of a split </arg_value> tag
+                    val_chunk = unparsed
+                    for i in range(len(av_end_tag) - 1, 0, -1):
+                        if val_chunk.endswith(av_end_tag[:i]):
+                            val_chunk = val_chunk[:-i]
+                            break
+                    self.last_parsed_index += len(val_chunk)
+                    if not val_chunk:
+                        return None
+                    return {
+                        "tool_calls": [
+                            {
+                                "index": self.tool_call_index,
+                                "id": self.current_tool_id,
+                                "function": {"arguments": val_chunk},
+                            }
+                        ]
+                    }
+
+        return None
diff --git a/vllm_mlx/utils/tokenizer.py b/vllm_mlx/utils/tokenizer.py
index a50883951..aaaeae550 100644
--- a/vllm_mlx/utils/tokenizer.py
+++ b/vllm_mlx/utils/tokenizer.py
@@ -52,6 +52,7 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None):
 
     try:
         model, tokenizer = load(model_name, tokenizer_config=tokenizer_config)
+        return model, tokenizer
     except ValueError as e:
         # Fallback for models with non-standard tokenizers
         if "TokenizersBackend" in str(e) or "Tokenizer class" in str(e):