Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
323 changes: 323 additions & 0 deletions tests/test_bench_serve_workload_hardening.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,323 @@
# SPDX-License-Identifier: Apache-2.0
"""
Targeted regression tests for the bench-serve workload runner.

The existing ``test_bench_serve.py`` already covers happy paths for workload
loading, sweep expansion, formatters, runner end-to-end, and basic quality
checks. Issue #499 highlighted that the load/validate/run/report code paths
in ``vllm_mlx/bench_serve.py`` are dense and worth pinning further. This
module fills the remaining corners around three areas the issue calls out:

- ``load_workload`` validation errors — every guard clause that rejects a
malformed workload JSON, plus tag-string normalisation.
- Streaming tool-call accumulation — ``accumulate_tool_calls`` /
``finalize_tool_calls`` chunk-boundary behaviour and ordering.
- ``validate_quality_checks`` diagnostics — tool-call argument validation
edge cases (invalid JSON, non-object, missing keys, count mismatch) and
the combination of ``no_tool_calls`` with content-length checks.
"""

from __future__ import annotations

import json
from pathlib import Path

import pytest

from vllm_mlx.bench_serve import (
accumulate_tool_calls,
finalize_tool_calls,
load_workload,
validate_quality_checks,
)

# ---------------------------------------------------------------------------
# load_workload — validation guard clauses
# ---------------------------------------------------------------------------


class TestLoadWorkloadValidation:
    """Exercise every guard clause in ``load_workload`` that rejects a
    malformed workload JSON, plus the tag-string normalisation path."""

    def _write(self, tmp_path: Path, payload) -> Path:
        # Serialise *payload* into the conventional workload filename and
        # return the path for the loader under test.
        target = tmp_path / "workload.json"
        target.write_text(json.dumps(payload))
        return target

    def test_root_must_be_object(self, tmp_path: Path):
        path = self._write(tmp_path, ["not", "an", "object"])
        with pytest.raises(ValueError, match="root must be a JSON object"):
            load_workload(path)

    def test_empty_cases_list_rejected(self, tmp_path: Path):
        path = self._write(tmp_path, {"cases": []})
        with pytest.raises(ValueError, match="non-empty cases list"):
            load_workload(path)

    def test_missing_cases_key_rejected(self, tmp_path: Path):
        path = self._write(tmp_path, {"defaults": {}})
        with pytest.raises(ValueError, match="non-empty cases list"):
            load_workload(path)

    def test_defaults_must_be_object(self, tmp_path: Path):
        payload = {
            "defaults": "not-an-object",
            "cases": [{"id": "a", "messages": [{"role": "user", "content": "hi"}]}],
        }
        path = self._write(tmp_path, payload)
        with pytest.raises(ValueError, match="defaults must be an object"):
            load_workload(path)

    def test_case_must_be_object(self, tmp_path: Path):
        path = self._write(tmp_path, {"cases": ["not-an-object"]})
        with pytest.raises(ValueError, match="case must be an object"):
            load_workload(path)

    def test_extra_body_invalid_type_rejected(self, tmp_path: Path):
        case = {
            "id": "a",
            "messages": [{"role": "user", "content": "hi"}],
            "extra_body": "not-an-object",
        }
        path = self._write(tmp_path, {"cases": [case]})
        # The case extra_body must be a dict; the loader raises on this case.
        with pytest.raises((ValueError, TypeError)):
            load_workload(path)

    def test_tags_string_is_normalised_to_list(self, tmp_path: Path):
        case = {
            "id": "a",
            "messages": [{"role": "user", "content": "hi"}],
            "tags": "single-tag",
        }
        path = self._write(tmp_path, {"cases": [case]})
        loaded = load_workload(path)
        assert loaded.cases[0].tags == ("single-tag",)

    def test_tags_invalid_type_rejected(self, tmp_path: Path):
        case = {
            "id": "a",
            "messages": [{"role": "user", "content": "hi"}],
            "tags": 42,
        }
        path = self._write(tmp_path, {"cases": [case]})
        with pytest.raises(ValueError, match="tags must be"):
            load_workload(path)


# ---------------------------------------------------------------------------
# Streaming tool-call accumulation
# ---------------------------------------------------------------------------


class TestStreamingToolCallAccumulation:
    """Chunk-boundary behaviour of ``accumulate_tool_calls`` and
    ``finalize_tool_calls`` over streamed tool-call deltas."""

    def test_concatenates_name_and_arguments_across_deltas(self):
        state: dict[int, dict] = {}
        # Opening delta carries the id plus the first fragments of the
        # function name and the JSON argument string.
        opening = {
            "index": 0,
            "id": "call_1",
            "type": "function",
            "function": {"name": "get_", "arguments": '{"city":'},
        }
        # Closing delta carries the remainder of both fragments.
        closing = {
            "index": 0,
            "function": {"name": "weather", "arguments": '"Tokyo"}'},
        }
        accumulate_tool_calls(state, [opening])
        accumulate_tool_calls(state, [closing])
        calls = finalize_tool_calls(state)
        assert len(calls) == 1
        only = calls[0]
        assert only["id"] == "call_1"
        assert only["function"]["name"] == "get_weather"
        assert json.loads(only["function"]["arguments"]) == {"city": "Tokyo"}

    def test_finalize_returns_index_sorted_even_when_inserted_out_of_order(self):
        state: dict[int, dict] = {}
        # Feed indices in the order 2, 0, 1 and expect index-sorted output.
        out_of_order = [
            {"index": 2, "id": "c", "function": {"name": "third"}},
            {"index": 0, "id": "a", "function": {"name": "first"}},
            {"index": 1, "id": "b", "function": {"name": "second"}},
        ]
        for delta in out_of_order:
            accumulate_tool_calls(state, [delta])
        names = [tc["function"]["name"] for tc in finalize_tool_calls(state)]
        assert names == ["first", "second", "third"]

    def test_id_set_on_first_delta_is_preserved_when_later_delta_omits_id(self):
        state: dict[int, dict] = {}
        accumulate_tool_calls(
            state,
            [{"index": 0, "id": "call_X", "function": {"name": "f"}}],
        )
        # An argument-only fragment with no id must not clobber the stored id.
        accumulate_tool_calls(
            state,
            [{"index": 0, "function": {"arguments": "{}"}}],
        )
        assert state[0]["id"] == "call_X"

    def test_default_index_is_zero_when_omitted(self):
        """OpenAI's spec says ``index`` is required, but mlx-lm style streams
        sometimes omit it on the first delta. Accumulator must default to 0
        rather than raise."""
        state: dict[int, dict] = {}
        delta = {"id": "call_1", "function": {"name": "f", "arguments": "{}"}}
        accumulate_tool_calls(state, [delta])
        assert 0 in state
        assert state[0]["function"]["name"] == "f"


# ---------------------------------------------------------------------------
# validate_quality_checks — diagnostics for tool-call argument checks
# ---------------------------------------------------------------------------


class TestQualityCheckDiagnostics:
    """Diagnostics emitted by ``validate_quality_checks`` for tool-call
    argument validation edge cases and combined-check failures."""

    def test_tool_call_count_mismatch_reports_actual_count(self):
        passed, problems = validate_quality_checks(
            finish_reason="stop",
            content="ignored",
            checks={"tool_call_count": 2},
            tool_calls=[{"function": {"name": "f", "arguments": "{}"}}],
        )
        assert not passed
        assert any("expected 2" in p and "got 1" in p for p in problems)

    def test_tool_call_args_invalid_json_reports_issue(self):
        passed, problems = validate_quality_checks(
            finish_reason="stop",
            content="ignored",
            checks={"tool_call_args_required_keys": {"f": ["x"]}},
            tool_calls=[{"function": {"name": "f", "arguments": "{not-json"}}],
        )
        assert not passed
        assert any("invalid JSON" in p for p in problems)

    def test_tool_call_args_non_object_reports_issue(self):
        passed, problems = validate_quality_checks(
            finish_reason="stop",
            content="ignored",
            checks={"tool_call_args_required_keys": {"f": ["x"]}},
            tool_calls=[{"function": {"name": "f", "arguments": "[1, 2, 3]"}}],
        )
        assert not passed
        assert any("not an object" in p for p in problems)

    def test_tool_call_args_missing_keys_lists_what_is_missing(self):
        passed, problems = validate_quality_checks(
            finish_reason="stop",
            content="ignored",
            checks={"tool_call_args_required_keys": {"f": ["a", "b", "c"]}},
            tool_calls=[{"function": {"name": "f", "arguments": '{"a": 1, "x": 2}'}}],
        )
        assert not passed
        # The diagnostic must name the missing keys so an operator can fix
        # the prompt or the check, not just say "something is missing".
        combined = " ".join(problems)
        assert "missing" in combined
        assert "b" in combined and "c" in combined
        after_missing = combined.split("missing")[1].split("]")[0]
        assert "a" not in after_missing

    def test_tool_call_args_missing_function_reports_issue(self):
        passed, problems = validate_quality_checks(
            finish_reason="stop",
            content="ignored",
            checks={"tool_call_args_required_keys": {"missing_fn": ["x"]}},
            tool_calls=[{"function": {"name": "other_fn", "arguments": "{}"}}],
        )
        assert not passed
        assert any("no tool call named" in p for p in problems)

    def test_no_tool_calls_combines_with_other_checks(self):
        passed, problems = validate_quality_checks(
            finish_reason="stop",
            content="hi",
            checks={"no_tool_calls": True, "min_chars": 100},
            tool_calls=[{"function": {"name": "f", "arguments": "{}"}}],
        )
        assert not passed
        # Both checks fail; both must surface so the operator sees the full
        # picture rather than chasing them one at a time.
        combined = " ".join(problems)
        assert "no_tool_calls" in combined
        assert "min_chars" in combined

    def test_invalid_required_regex_does_not_crash_runner(self):
        passed, problems = validate_quality_checks(
            finish_reason="stop",
            content="hi",
            checks={"required_regex": ["[unclosed"]},
            tool_calls=None,
        )
        assert not passed
        assert any("invalid required_regex" in p for p in problems)

    def test_finish_reason_list_accepts_any_member(self):
        # finish_reason="length" is treated as truncation by the basic check
        # in validate_response, so we exercise a non-truncation alternative
        # (tool_calls) that the test should accept when present in the list.
        passed, problems = validate_quality_checks(
            finish_reason="tool_calls",
            content="",
            checks={"finish_reason": ["stop", "tool_calls"]},
            tool_calls=[{"function": {"name": "f", "arguments": "{}"}}],
        )
        assert passed, problems

    def test_finish_reason_string_form_rejects_others(self):
        # Single-string form: only "stop" is allowed; finish_reason="tool_calls"
        # must surface as an explicit issue, not a generic basic-check failure.
        passed, problems = validate_quality_checks(
            finish_reason="tool_calls",
            content="",
            checks={"finish_reason": "stop"},
            tool_calls=[{"function": {"name": "f", "arguments": "{}"}}],
        )
        assert not passed
        assert any("finish_reason" in p and "not in" in p for p in problems)


if __name__ == "__main__":
    # Allow running this module directly (``python test_bench_serve_workload_hardening.py``)
    # in addition to discovery via a normal ``pytest`` invocation.
    pytest.main([__file__, "-v"])
Loading