Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions tests/test_batched_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,39 @@ def test_wraps_non_mapping_arguments_for_template_items(self):
assert normalized[0]["tool_calls"][0]["function"]["arguments"] == {
"value": ["not", "object"]
}

def test_closes_dangling_think_before_raw_tool_call(self):
messages = [
{
"role": "assistant",
"content": (
"<think>Need the weather tool.\n"
"<tool_call>\n"
"<function=get_weather>\n"
"<parameter=city>\nParis\n</parameter>\n"
"</function>\n"
"</tool_call>"
),
}
]

normalized = _normalize_tool_call_arguments_for_template(messages)
content = normalized[0]["content"]

assert "Need the weather tool.\n</think><tool_call>" in content
assert "<tool_call>" not in content.split("</think>", 1)[0]
assert messages[0]["content"].startswith("<think>Need")

def test_normalizes_pydantic_style_messages_without_stringifying(self):
class MessageLike:
def model_dump(self, exclude_none=False):
assert exclude_none is True
return {
"role": "assistant",
"content": "plain response",
"unused": None,
}

normalized = _normalize_tool_call_arguments_for_template([MessageLike()])

assert normalized == [{"role": "assistant", "content": "plain response"}]
27 changes: 2 additions & 25 deletions vllm_mlx/engine/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import asyncio
import inspect
import json
import logging
import time
from collections.abc import AsyncIterator
Expand All @@ -27,36 +26,14 @@
cleanup_startup_cancellation,
run_blocking_startup_work,
)
from .chat_template_safety import normalize_messages_for_chat_template

logger = logging.getLogger(__name__)


def _normalize_tool_call_arguments_for_template(messages: list[dict]) -> list[dict]:
    """Normalize OpenAI tool-call replay for templates expecting mappings.

    Backward-compatible wrapper kept so existing imports of this module-level
    helper keep working; the implementation now lives in
    ``chat_template_safety.normalize_messages_for_chat_template`` so the same
    normalization is shared with the simple engine.

    Args:
        messages: Chat messages (dicts or pydantic-style model objects).

    Returns:
        A JSON-safe deep copy with tool-call argument strings parsed into
        mappings and dangling ``<think>`` spans closed; the input is not
        mutated.
    """
    return normalize_messages_for_chat_template(messages)


def _extract_media_from_messages(messages: list[dict[str, Any]]) -> tuple:
Expand Down
90 changes: 90 additions & 0 deletions vllm_mlx/engine/chat_template_safety.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
"""Safety normalization for messages before Jinja chat-template rendering."""

import json
from typing import Any


def _close_dangling_think_before_tool_call(content: str) -> str:
"""Keep raw tool XML out of an unterminated ``<think>`` section.

Qwen 3.6 can produce assistant history where ``<think>`` is opened and a
raw ``<tool_call>`` follows before ``</think>``. Rendering that history as-is
conditions the next turn as though the tool call is still reasoning. Close
the dangling thinking span immediately before the first tool call.

This mirrors the template-side repair described by Cheuk-Yiu Chan:
https://allanchan339.github.io/bug-fixes/2026/05/02/Qwen36-27B-updated-jinja.html
"""
if "<tool_call>" not in content or "<think>" not in content:
return content

last_think = content.rfind("<think>")
last_close = content.rfind("</think>")
tool_pos = content.find("<tool_call>")
if last_close >= last_think and last_close != -1:
return content
if tool_pos > last_think:
return content[:tool_pos] + "</think>" + content[tool_pos:]
return content + "</think>"


def _message_to_dict(message: Any) -> dict[str, Any] | Any:
"""Convert OpenAI message model objects without stringifying them."""
if isinstance(message, dict):
return dict(message)
model_dump = getattr(message, "model_dump", None)
if callable(model_dump):
return {
key: value
for key, value in model_dump(exclude_none=True).items()
if value is not None
}
legacy_dict = getattr(message, "dict", None)
if callable(legacy_dict):
return {k: v for k, v in legacy_dict().items() if v is not None}
return message


def normalize_messages_for_chat_template(messages: list[Any]) -> list[dict]:
    """Return a JSON-safe copy of messages for chat-template rendering.

    Normalizations applied to assistant messages:
    - close dangling ``<think>`` spans before raw ``<tool_call>`` XML in
      string content
    - parse OpenAI tool-call argument JSON strings into mappings for
      templates that iterate argument keys

    The input list is never mutated; non-JSON-serializable leaves are
    stringified by the round-trip's ``default=str``.
    """
    # Deep JSON round-trip: decouples the copy from caller objects and
    # coerces model objects (via _message_to_dict) into plain dicts.
    safe = json.loads(
        json.dumps([_message_to_dict(item) for item in messages], default=str)
    )

    for entry in safe:
        if not isinstance(entry, dict) or entry.get("role") != "assistant":
            continue

        text = entry.get("content")
        if isinstance(text, str):
            entry["content"] = _close_dangling_think_before_tool_call(text)

        calls = entry.get("tool_calls")
        if not isinstance(calls, list):
            continue
        for call in calls:
            if not isinstance(call, dict):
                continue
            fn = call.get("function")
            if not isinstance(fn, dict):
                continue
            raw: Any = fn.get("arguments")
            if not isinstance(raw, str):
                continue
            try:
                decoded = json.loads(raw)
            except (json.JSONDecodeError, ValueError, TypeError):
                # Unparseable argument strings are preserved under "value".
                decoded = {"value": raw}
            if not isinstance(decoded, dict):
                # Templates iterate mapping keys, so wrap scalars/arrays.
                decoded = {"value": decoded}
            fn["arguments"] = decoded

    return safe
11 changes: 7 additions & 4 deletions vllm_mlx/engine/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
cleanup_startup_cancellation,
run_blocking_startup_work,
)
from .chat_template_safety import normalize_messages_for_chat_template
from ..mlx_streams import bind_generation_streams

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -847,15 +848,16 @@ def run_stream():
template_kwargs.update(chat_template_kwargs)
if template_tools:
template_kwargs["tools"] = template_tools
safe_messages = normalize_messages_for_chat_template(messages)

try:
prompt = tokenizer.apply_chat_template(messages, **template_kwargs)
prompt = tokenizer.apply_chat_template(safe_messages, **template_kwargs)
except TypeError:
# Some templates don't support all kwargs
for key in ["tools", "enable_thinking", *chat_template_kwargs.keys()]:
if key in template_kwargs:
del template_kwargs[key]
prompt = tokenizer.apply_chat_template(messages, **template_kwargs)
prompt = tokenizer.apply_chat_template(safe_messages, **template_kwargs)
else:
prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
prompt += "\nassistant:"
Expand Down Expand Up @@ -1120,17 +1122,18 @@ async def _stream_generate_text(
template_kwargs.update(chat_template_kwargs)
if tools:
template_kwargs["tools"] = tools
safe_messages = normalize_messages_for_chat_template(messages)

try:
full_prompt = self._text_tokenizer.apply_chat_template(
messages, **template_kwargs
safe_messages, **template_kwargs
)
except TypeError:
# Template doesn't accept tools= or enable_thinking=
template_kwargs.pop("tools", None)
template_kwargs.pop("enable_thinking", None)
full_prompt = self._text_tokenizer.apply_chat_template(
messages, **template_kwargs
safe_messages, **template_kwargs
)

sampler = make_sampler(
Expand Down
Loading