74 changes: 74 additions & 0 deletions tests/test_qwen3_xml_parser.py
@@ -2,6 +2,7 @@
"""Functional tests for Qwen3XMLToolParser parsing logic."""

import json
import logging

import pytest

@@ -134,6 +135,43 @@ def test_no_tool_calls(self, parser):
result = parser.extract_tool_calls(text)
assert not result.tools_called

@pytest.mark.parametrize(
"text",
["<tool_call></>", "<tool_call></tool_call>", "<tool_call/>"],
)
def test_empty_wrapper_is_not_synthesized(self, parser, caplog, text):
request = {
"messages": [
{
"role": "user",
"content": "list files in /tmp/projects",
}
],
"tools": [
{
"type": "function",
"function": {
"name": "bash",
"description": "Run a shell command",
"parameters": {
"type": "object",
"properties": {"command": {"type": "string"}},
"required": ["command"],
},
},
},
],
}

with caplog.at_level(logging.WARNING):
result = parser.extract_tool_calls(text, request=request)

assert not result.tools_called
assert result.tool_calls == []
assert result.content == ""
assert "empty tool_call wrapper" in caplog.text
assert "ls -la" not in caplog.text

def test_multiline_parameter(self, parser):
text = (
"<tool_call>\n"
@@ -187,6 +225,42 @@ def test_streaming_produces_deltas(self, parser):
deltas.append(result)
assert len(deltas) > 0

@pytest.mark.parametrize(
"text",
["<tool_call></>", "<tool_call></tool_call>", "<tool_call/>"],
)
def test_streaming_empty_wrapper_does_not_emit_tool_call(
self, parser, caplog, text
):
request = {
"messages": [{"role": "user", "content": "list files in /tmp/projects"}],
"tools": [
{
"type": "function",
"function": {
"name": "bash",
"parameters": {
"type": "object",
"properties": {"command": {"type": "string"}},
"required": ["command"],
},
},
}
],
}

with caplog.at_level(logging.WARNING):
result = parser.extract_tool_calls_streaming(
"",
text,
text,
request=request,
)

assert result is None or not result.get("tool_calls")
assert "empty tool_call wrapper" in caplog.text
assert "ls -la" not in caplog.text


class TestMalformedXML:
"""Edge cases with malformed model output."""
274 changes: 274 additions & 0 deletions tests/test_server.py
@@ -1819,6 +1819,280 @@ async def stream_chat(self, messages, **kwargs):
assert payloads[2]["choices"][0]["delta"]["content"] == "world"
assert payloads[2]["choices"][0]["finish_reason"] == "stop"

@pytest.mark.anyio
async def test_stream_after_tool_message_uses_cumulative_text_deltas(
self, monkeypatch
):
"""Post-tool streams should stay incremental without leaking bad deltas."""
from vllm_mlx.engine.base import GenerationOutput
from vllm_mlx.server import (
ChatCompletionRequest,
Message,
stream_chat_completion,
)
import vllm_mlx.server as server

class FakeEngine:
model_name = "fake-engine"

async def stream_chat(self, messages, **kwargs):
chunks = [
GenerationOutput(
text="Here",
new_text="Here~/",
finished=False,
),
GenerationOutput(
text="Here are the files:",
new_text=" |bad",
finished=False,
),
GenerationOutput(
text="Here are the files:\n- alpha\n- beta",
new_text=" `bad",
finished=True,
finish_reason="stop",
prompt_tokens=8,
completion_tokens=9,
),
]
for chunk in chunks:
yield chunk

monkeypatch.setattr(server, "_model_name", "served-model")
monkeypatch.setattr(server, "_reasoning_parser", None)
monkeypatch.setattr(server, "_enable_auto_tool_choice", False)
monkeypatch.setattr(server, "_tool_call_parser", None)
monkeypatch.setattr(server, "_tool_parser_instance", None)

request = ChatCompletionRequest(
model="served-model",
messages=[
Message(role="user", content="list files"),
Message(
role="assistant",
content=None,
tool_calls=[
{
"id": "call_123",
"type": "function",
"function": {
"name": "bash",
"arguments": '{"command": "find . -maxdepth 1"}',
},
}
],
),
Message(
role="tool",
content="./alpha\n./beta",
tool_call_id="call_123",
),
],
stream=True,
)

chunks = [
chunk
async for chunk in stream_chat_completion(
FakeEngine(), request.messages, request
)
]
payloads = [
json.loads(chunk.removeprefix("data: ").strip())
for chunk in chunks
if chunk != "data: [DONE]\n\n"
]
content_payloads = [
payload
for payload in payloads
if payload["choices"] and payload["choices"][0]["delta"].get("content")
]

assert [
payload["choices"][0]["delta"]["content"] for payload in content_payloads
] == ["Here", " are the files:", "\n- alpha\n- beta"]
assert "Here~/" not in json.dumps(payloads)
assert "bad" not in json.dumps(payloads)
assert content_payloads[-1]["choices"][0]["finish_reason"] == "stop"
assert content_payloads[-1]["usage"] == {
"prompt_tokens": 8,
"completion_tokens": 9,
"total_tokens": 17,
}

@pytest.mark.anyio
async def test_stream_after_tool_message_emits_divergent_cumulative_text(
self, monkeypatch
):
"""If cumulative text diverges, emit it instead of stale new_text."""
from vllm_mlx.engine.base import GenerationOutput
from vllm_mlx.server import (
ChatCompletionRequest,
Message,
stream_chat_completion,
)
import vllm_mlx.server as server

class FakeEngine:
model_name = "fake-engine"

async def stream_chat(self, messages, **kwargs):
chunks = [
GenerationOutput(
text="Draft answer",
new_text="Draft answer",
finished=False,
),
GenerationOutput(
text="Restarted clean answer",
new_text=" stale",
finished=True,
finish_reason="stop",
prompt_tokens=8,
completion_tokens=4,
),
]
for chunk in chunks:
yield chunk

monkeypatch.setattr(server, "_model_name", "served-model")
monkeypatch.setattr(server, "_reasoning_parser", None)
monkeypatch.setattr(server, "_enable_auto_tool_choice", False)
monkeypatch.setattr(server, "_tool_call_parser", None)
monkeypatch.setattr(server, "_tool_parser_instance", None)

request = ChatCompletionRequest(
model="served-model",
messages=[
Message(role="user", content="list files"),
Message(
role="assistant",
content=None,
tool_calls=[
{
"id": "call_123",
"type": "function",
"function": {
"name": "bash",
"arguments": '{"command": "find . -maxdepth 1"}',
},
}
],
),
Message(
role="tool",
content="./alpha\n./beta",
tool_call_id="call_123",
),
],
stream=True,
)

chunks = [
chunk
async for chunk in stream_chat_completion(
FakeEngine(), request.messages, request
)
]
payloads = [
json.loads(chunk.removeprefix("data: ").strip())
for chunk in chunks
if chunk != "data: [DONE]\n\n"
]
content_payloads = [
payload
for payload in payloads
if payload["choices"] and payload["choices"][0]["delta"].get("content")
]

assert [
payload["choices"][0]["delta"]["content"] for payload in content_payloads
] == ["Draft answer", "Restarted clean answer"]
assert "stale" not in json.dumps(payloads)
assert content_payloads[-1]["choices"][0]["finish_reason"] == "stop"

@pytest.mark.anyio
async def test_stream_empty_xml_tool_wrapper_does_not_emit_tool_call(
self, monkeypatch, caplog
):
"""Malformed empty XML wrappers should not become server-made tools."""
from vllm_mlx.engine.base import GenerationOutput
from vllm_mlx.server import (
ChatCompletionRequest,
Message,
stream_chat_completion,
)
import vllm_mlx.server as server

class FakeEngine:
model_name = "fake-engine"

async def stream_chat(self, messages, **kwargs):
chunks = [
GenerationOutput(
text="<tool_call>",
new_text="<tool_call>",
finished=False,
),
GenerationOutput(
text="<tool_call></tool_call>",
new_text="</tool_call>",
finished=True,
finish_reason="stop",
prompt_tokens=8,
completion_tokens=2,
),
]
for chunk in chunks:
yield chunk

monkeypatch.setattr(server, "_model_name", "served-model")
monkeypatch.setattr(server, "_reasoning_parser", None)
monkeypatch.setattr(server, "_enable_auto_tool_choice", True)
monkeypatch.setattr(server, "_tool_call_parser", "qwen3_xml")
monkeypatch.setattr(server, "_tool_parser_instance", None)

request = ChatCompletionRequest(
model="served-model",
messages=[Message(role="user", content="list files in /tmp/projects")],
tools=[
{
"type": "function",
"function": {
"name": "bash",
"description": "Run a shell command",
"parameters": {
"type": "object",
"properties": {"command": {"type": "string"}},
"required": ["command"],
},
},
}
],
stream=True,
)

with caplog.at_level("WARNING"):
chunks = [
chunk
async for chunk in stream_chat_completion(
FakeEngine(), request.messages, request
)
]

payloads = [
json.loads(chunk.removeprefix("data: ").strip())
for chunk in chunks
if chunk != "data: [DONE]\n\n"
]
serialized = json.dumps(payloads)

assert chunks[-1] == "data: [DONE]\n\n"
assert "tool_calls" not in serialized
assert "ls -la" not in serialized
assert "empty tool_call wrapper" in caplog.text

@pytest.mark.anyio
async def test_auto_parser_streams_bare_bracket_tool_calls(self, monkeypatch):
"""Bare bracket tool calls should stream as structured tool_calls."""
2 changes: 2 additions & 0 deletions vllm_mlx/scheduler.py
@@ -2155,12 +2155,14 @@ def _process_batch_responses(
detok = self._get_detokenizer(request_id)
detok.add_token(response.token)
new_text = detok.last_segment
output_text = detok.text

# Create output
output = RequestOutput(
request_id=request_id,
new_token_ids=[response.token],
new_text=new_text,
output_text=output_text if response.finish_reason != "stop" else "",
output_token_ids=list(request.output_token_ids),
prompt_tokens=request.num_prompt_tokens,
completion_tokens=request.num_output_tokens,