diff --git a/tests/test_qwen3_xml_parser.py b/tests/test_qwen3_xml_parser.py
index c17a48afc..7fb0bf713 100644
--- a/tests/test_qwen3_xml_parser.py
+++ b/tests/test_qwen3_xml_parser.py
@@ -2,6 +2,7 @@
 """Functional tests for Qwen3XMLToolParser parsing logic."""
 
 import json
+import logging
 
 import pytest
 
@@ -134,6 +135,43 @@ def test_no_tool_calls(self, parser):
         result = parser.extract_tool_calls(text)
         assert not result.tools_called
 
+    @pytest.mark.parametrize(
+        "text",
+        ["<tool_call></tool_call>", "<tool_call>\n</tool_call>", "<tool_call> </tool_call>"],
+    )
+    def test_empty_wrapper_is_not_synthesized(self, parser, caplog, text):
+        request = {
+            "messages": [
+                {
+                    "role": "user",
+                    "content": "list files in /tmp/projects",
+                }
+            ],
+            "tools": [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": "bash",
+                        "description": "Run a shell command",
+                        "parameters": {
+                            "type": "object",
+                            "properties": {"command": {"type": "string"}},
+                            "required": ["command"],
+                        },
+                    },
+                },
+            ],
+        }
+
+        with caplog.at_level(logging.WARNING):
+            result = parser.extract_tool_calls(text, request=request)
+
+        assert not result.tools_called
+        assert result.tool_calls == []
+        assert result.content == ""
+        assert "empty tool_call wrapper" in caplog.text
+        assert "ls -la" not in caplog.text
+
     def test_multiline_parameter(self, parser):
         text = (
             "<tool_call>\n"
@@ -187,6 +225,42 @@ def test_streaming_produces_deltas(self, parser):
             deltas.append(result)
         assert len(deltas) > 0
 
+    @pytest.mark.parametrize(
+        "text",
+        ["<tool_call></tool_call>", "<tool_call>\n</tool_call>", "<tool_call> </tool_call>"],
+    )
+    def test_streaming_empty_wrapper_does_not_emit_tool_call(
+        self, parser, caplog, text
+    ):
+        request = {
+            "messages": [{"role": "user", "content": "list files in /tmp/projects"}],
+            "tools": [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": "bash",
+                        "parameters": {
+                            "type": "object",
+                            "properties": {"command": {"type": "string"}},
+                            "required": ["command"],
+                        },
+                    },
+                }
+            ],
+        }
+
+        with caplog.at_level(logging.WARNING):
+            result = parser.extract_tool_calls_streaming(
+                "",
+                text,
+                text,
+                request=request,
+            )
+
+        assert result is None or not result.get("tool_calls")
+        assert "empty tool_call wrapper" in caplog.text
+        assert "ls -la" not in caplog.text
+
 
 class TestMalformedXML:
     """Edge cases with malformed model output."""
diff --git a/tests/test_server.py b/tests/test_server.py
index 80d42da6d..67cc632fa 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -1819,6 +1819,280 @@ async def stream_chat(self, messages, **kwargs):
         assert payloads[2]["choices"][0]["delta"]["content"] == "world"
         assert payloads[2]["choices"][0]["finish_reason"] == "stop"
 
+    @pytest.mark.anyio
+    async def test_stream_after_tool_message_uses_cumulative_text_deltas(
+        self, monkeypatch
+    ):
+        """Post-tool streams should stay incremental without leaking bad deltas."""
+        from vllm_mlx.engine.base import GenerationOutput
+        from vllm_mlx.server import (
+            ChatCompletionRequest,
+            Message,
+            stream_chat_completion,
+        )
+        import vllm_mlx.server as server
+
+        class FakeEngine:
+            model_name = "fake-engine"
+
+            async def stream_chat(self, messages, **kwargs):
+                chunks = [
+                    GenerationOutput(
+                        text="Here",
+                        new_text="Here~/",
+                        finished=False,
+                    ),
+                    GenerationOutput(
+                        text="Here are the files:",
+                        new_text=" |bad",
+                        finished=False,
+                    ),
+                    GenerationOutput(
+                        text="Here are the files:\n- alpha\n- beta",
+                        new_text=" `bad",
+                        finished=True,
+                        finish_reason="stop",
+                        prompt_tokens=8,
+                        completion_tokens=9,
+                    ),
+                ]
+                for chunk in chunks:
+                    yield chunk
+
+        monkeypatch.setattr(server, "_model_name", "served-model")
+        monkeypatch.setattr(server, "_reasoning_parser", None)
+        monkeypatch.setattr(server, "_enable_auto_tool_choice", False)
+        monkeypatch.setattr(server, "_tool_call_parser", None)
+        monkeypatch.setattr(server, "_tool_parser_instance", None)
+
+        request = ChatCompletionRequest(
+            model="served-model",
+            messages=[
+                Message(role="user", content="list files"),
+                Message(
+                    role="assistant",
+                    content=None,
+                    tool_calls=[
+                        {
+                            "id": "call_123",
+                            "type": "function",
+                            "function": {
+                                "name": "bash",
+                                "arguments": '{"command": "find . -maxdepth 1"}',
+                            },
+                        }
+                    ],
+                ),
+                Message(
+                    role="tool",
+                    content="./alpha\n./beta",
+                    tool_call_id="call_123",
+                ),
+            ],
+            stream=True,
+        )
+
+        chunks = [
+            chunk
+            async for chunk in stream_chat_completion(
+                FakeEngine(), request.messages, request
+            )
+        ]
+        payloads = [
+            json.loads(chunk.removeprefix("data: ").strip())
+            for chunk in chunks
+            if chunk != "data: [DONE]\n\n"
+        ]
+        content_payloads = [
+            payload
+            for payload in payloads
+            if payload["choices"] and payload["choices"][0]["delta"].get("content")
+        ]
+
+        assert [
+            payload["choices"][0]["delta"]["content"] for payload in content_payloads
+        ] == ["Here", " are the files:", "\n- alpha\n- beta"]
+        assert "Here~/" not in json.dumps(payloads)
+        assert "bad" not in json.dumps(payloads)
+        assert content_payloads[-1]["choices"][0]["finish_reason"] == "stop"
+        assert content_payloads[-1]["usage"] == {
+            "prompt_tokens": 8,
+            "completion_tokens": 9,
+            "total_tokens": 17,
+        }
+
+    @pytest.mark.anyio
+    async def test_stream_after_tool_message_emits_divergent_cumulative_text(
+        self, monkeypatch
+    ):
+        """If cumulative text diverges, emit it instead of stale new_text."""
+        from vllm_mlx.engine.base import GenerationOutput
+        from vllm_mlx.server import (
+            ChatCompletionRequest,
+            Message,
+            stream_chat_completion,
+        )
+        import vllm_mlx.server as server
+
+        class FakeEngine:
+            model_name = "fake-engine"
+
+            async def stream_chat(self, messages, **kwargs):
+                chunks = [
+                    GenerationOutput(
+                        text="Draft answer",
+                        new_text="Draft answer",
+                        finished=False,
+                    ),
+                    GenerationOutput(
+                        text="Restarted clean answer",
+                        new_text=" stale",
+                        finished=True,
+                        finish_reason="stop",
+                        prompt_tokens=8,
+                        completion_tokens=4,
+                    ),
+                ]
+                for chunk in chunks:
+                    yield chunk
+
+        monkeypatch.setattr(server, "_model_name", "served-model")
+        monkeypatch.setattr(server, "_reasoning_parser", None)
+        monkeypatch.setattr(server, "_enable_auto_tool_choice", False)
+        monkeypatch.setattr(server, "_tool_call_parser", None)
+        monkeypatch.setattr(server, "_tool_parser_instance", None)
+
+        request = ChatCompletionRequest(
+            model="served-model",
+            messages=[
+                Message(role="user", content="list files"),
+                Message(
+                    role="assistant",
+                    content=None,
+                    tool_calls=[
+                        {
+                            "id": "call_123",
+                            "type": "function",
+                            "function": {
+                                "name": "bash",
+                                "arguments": '{"command": "find . -maxdepth 1"}',
+                            },
+                        }
+                    ],
+                ),
+                Message(
+                    role="tool",
+                    content="./alpha\n./beta",
+                    tool_call_id="call_123",
+                ),
+            ],
+            stream=True,
+        )
+
+        chunks = [
+            chunk
+            async for chunk in stream_chat_completion(
+                FakeEngine(), request.messages, request
+            )
+        ]
+        payloads = [
+            json.loads(chunk.removeprefix("data: ").strip())
+            for chunk in chunks
+            if chunk != "data: [DONE]\n\n"
+        ]
+        content_payloads = [
+            payload
+            for payload in payloads
+            if payload["choices"] and payload["choices"][0]["delta"].get("content")
+        ]
+
+        assert [
+            payload["choices"][0]["delta"]["content"] for payload in content_payloads
+        ] == ["Draft answer", "Restarted clean answer"]
+        assert "stale" not in json.dumps(payloads)
+        assert content_payloads[-1]["choices"][0]["finish_reason"] == "stop"
+
+    @pytest.mark.anyio
+    async def test_stream_empty_xml_tool_wrapper_does_not_emit_tool_call(
+        self, monkeypatch, caplog
+    ):
+        """Malformed empty XML wrappers should not be synthesized into tool calls."""
+        from vllm_mlx.engine.base import GenerationOutput
+        from vllm_mlx.server import (
+            ChatCompletionRequest,
+            Message,
+            stream_chat_completion,
+        )
+        import vllm_mlx.server as server
+
+        class FakeEngine:
+            model_name = "fake-engine"
+
+            async def stream_chat(self, messages, **kwargs):
+                chunks = [
+                    GenerationOutput(
+                        text="<tool_call>",
+                        new_text="<tool_call>",
+                        finished=False,
+                    ),
+                    GenerationOutput(
+                        text="<tool_call></tool_call>",
+                        new_text="</tool_call>",
+                        finished=True,
+                        finish_reason="stop",
+                        prompt_tokens=8,
+                        completion_tokens=2,
+                    ),
+                ]
+                for chunk in chunks:
+                    yield chunk
+
+        monkeypatch.setattr(server, "_model_name", "served-model")
+        monkeypatch.setattr(server, "_reasoning_parser", None)
+        monkeypatch.setattr(server, "_enable_auto_tool_choice", True)
+        monkeypatch.setattr(server, "_tool_call_parser", "qwen3_xml")
+        monkeypatch.setattr(server, "_tool_parser_instance", None)
+
+        request = ChatCompletionRequest(
+            model="served-model",
+            messages=[Message(role="user", content="list files in /tmp/projects")],
+            tools=[
+                {
+                    "type": "function",
+                    "function": {
+                        "name": "bash",
+                        "description": "Run a shell command",
+                        "parameters": {
+                            "type": "object",
+                            "properties": {"command": {"type": "string"}},
+                            "required": ["command"],
+                        },
+                    },
+                }
+            ],
+            stream=True,
+        )
+
+        with caplog.at_level("WARNING"):
+            chunks = [
+                chunk
+                async for chunk in stream_chat_completion(
+                    FakeEngine(), request.messages, request
+                )
+            ]
+
+        payloads = [
+            json.loads(chunk.removeprefix("data: ").strip())
+            for chunk in chunks
+            if chunk != "data: [DONE]\n\n"
+        ]
+        serialized = json.dumps(payloads)
+
+        assert chunks[-1] == "data: [DONE]\n\n"
+        assert "tool_calls" not in serialized
+        assert "ls -la" not in serialized
+        assert "empty tool_call wrapper" in caplog.text
+
     @pytest.mark.anyio
     async def test_auto_parser_streams_bare_bracket_tool_calls(self, monkeypatch):
         """Bare bracket tool calls should stream as structured tool_calls."""
diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py
index bf9824f78..a9babeef2 100644
--- a/vllm_mlx/scheduler.py
+++ b/vllm_mlx/scheduler.py
@@ -2155,12 +2155,14 @@ def _process_batch_responses(
             detok = self._get_detokenizer(request_id)
             detok.add_token(response.token)
             new_text = detok.last_segment
+            output_text = detok.text
 
             # Create output
             output = RequestOutput(
                 request_id=request_id,
                 new_token_ids=[response.token],
                 new_text=new_text,
+                output_text=output_text if response.finish_reason != "stop" else "",
                 output_token_ids=list(request.output_token_ids),
                 prompt_tokens=request.num_prompt_tokens,
                 completion_tokens=request.num_output_tokens,
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index 75590d9ba..c46ab5f01 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -5547,6 +5547,12 @@ async def stream_chat_completion(
     tool_calls_detected = False
     tool_markup_possible = False  # Fast path: skip parsing until markers appear
     tool_parser = _get_streaming_tool_parser(request, engine)
+    request_dict = request.model_dump() if request is not None else {}
+    post_tool_stream = any(
+        isinstance(message, dict) and message.get("role") == "tool"
+        for message in request_dict.get("messages", [])
+    )
+    decoded_text_seen = ""
 
     try:
         # Stream content
@@ -5556,6 +5562,19 @@ async def stream_chat_completion(
            delta_text = output.new_text
            last_output = output

+            # Some post-tool engine paths expose trustworthy cumulative text
+            # while per-chunk new_text can contain stale decoder fragments.
+            # Preserve streaming by deriving only the newest suffix instead of
+            # waiting for the final chunk.
+            if post_tool_stream:
+                full_text = getattr(output, "text", "") or ""
+                if full_text.startswith(decoded_text_seen):
+                    delta_text = full_text[len(decoded_text_seen) :]
+                    decoded_text_seen = full_text
+                elif full_text:
+                    delta_text = full_text
+                    decoded_text_seen = full_text
+
             # Track token counts from output (updated each chunk)
             if hasattr(output, "prompt_tokens") and output.prompt_tokens:
                 prompt_tokens = output.prompt_tokens
@@ -5610,9 +5629,19 @@ async def stream_chat_completion(
                     tool_markup_possible = True
                     tool_previous = tool_accumulated_text
                     tool_accumulated_text += content
-                    tool_result = tool_parser.extract_tool_calls_streaming(
-                        tool_previous, tool_accumulated_text, content
-                    )
+                    try:
+                        tool_result = tool_parser.extract_tool_calls_streaming(
+                            tool_previous,
+                            tool_accumulated_text,
+                            content,
+                            request=request_dict,
+                        )
+                    except TypeError as exc:
+                        if "unexpected keyword argument 'request'" not in str(exc):
+                            raise
+                        tool_result = tool_parser.extract_tool_calls_streaming(
+                            tool_previous, tool_accumulated_text, content
+                        )
 
                     if tool_result is None:
                         # Inside tool markup - suppress content output
@@ -5639,7 +5668,7 @@ async def stream_chat_completion(
                         tool_calls_detected = True
                         # Coerce arguments against tool schemas
                         tools = (
-                            request.model_dump().get("tools")
+                            request_dict.get("tools")
                             if request and request.tools
                             else None
                         )
@@ -5734,9 +5763,19 @@ async def stream_chat_completion(
                 tool_markup_possible = True
                 tool_previous = tool_accumulated_text
                 tool_accumulated_text += delta_text
-                tool_result = tool_parser.extract_tool_calls_streaming(
-                    tool_previous, tool_accumulated_text, delta_text
-                )
+                try:
+                    tool_result = tool_parser.extract_tool_calls_streaming(
+                        tool_previous,
+                        tool_accumulated_text,
+                        delta_text,
+                        request=request_dict,
+                    )
+                except TypeError as exc:
+                    if "unexpected keyword argument 'request'" not in str(exc):
+                        raise
+                    tool_result = tool_parser.extract_tool_calls_streaming(
+                        tool_previous, tool_accumulated_text, delta_text
+                    )
 
                 if tool_result is None:
                     # Inside tool markup - suppress output
@@ -5747,7 +5786,7 @@ async def stream_chat_completion(
                     tool_calls_detected = True
                     # Coerce arguments against tool schemas
                     tools = (
-                        request.model_dump().get("tools")
+                        request_dict.get("tools")
                         if request and request.tools
                         else None
                     )
@@ -5817,13 +5856,11 @@ async def stream_chat_completion(
         and not tool_calls_detected
         and _streaming_tool_markup_possible(tool_accumulated_text)
     ):
-        final_parse_result = tool_parser.extract_tool_calls(tool_accumulated_text)
+        final_parse_result = tool_parser.extract_tool_calls(
+            tool_accumulated_text, request_dict
+        )
         if final_parse_result.tools_called:
-            tools = (
-                request.model_dump().get("tools")
-                if request and request.tools
-                else None
-            )
+            tools = request_dict.get("tools") if request and request.tools else None
             tool_chunk = ChatCompletionChunk(
                 id=response_id,
                 model=_model_name,
diff --git a/vllm_mlx/tool_parsers/qwen3_xml_tool_parser.py b/vllm_mlx/tool_parsers/qwen3_xml_tool_parser.py
index fc7d324a3..77cfb429a 100644
--- a/vllm_mlx/tool_parsers/qwen3_xml_tool_parser.py
+++ b/vllm_mlx/tool_parsers/qwen3_xml_tool_parser.py
@@ -44,6 +44,14 @@
 logger = logging.getLogger(__name__)
 
 
+def _is_empty_tool_wrapper(text: str) -> bool:
+    cleaned = text.strip()
+    return bool(
+        re.fullmatch(r"<tool_call>\s*</tool_call>", cleaned)
+        or re.fullmatch(r"<tool_call\s*/>", cleaned)
+    )
+
+
 # ---------------------------------------------------------------------------
 # Shim types matching vLLM protocol types used by StreamingXMLToolCallParser.
 # These are lightweight replacements so the core parser works unchanged.
@@ -1319,6 +1327,16 @@ def extract_tool_calls(
         # this guards against non-streaming paths or missing parser.)
         cleaned = self.strip_think_tags(model_output)
 
+        if _is_empty_tool_wrapper(cleaned):
+            logger.warning(
+                "Model emitted an empty tool_call wrapper; suppressing it"
+            )
+            return ExtractedToolCallInformation(
+                tools_called=False,
+                tool_calls=[],
+                content="",
+            )
+
         self._xml_parser.reset_streaming_state()
         tools = self._wrap_tools(request)
         if tools:
@@ -1365,6 +1383,12 @@ def extract_tool_calls_streaming(
         Returns dict with 'tool_calls' and/or 'content' keys, or None
         to suppress the chunk.
         """
+        if _is_empty_tool_wrapper(current_text):
+            logger.warning(
+                "Model emitted an empty tool_call wrapper; suppressing it"
+            )
+            return None
+
         if not previous_text:
             self._xml_parser.reset_streaming_state()
         tools = self._wrap_tools(request)