From f980bed1666e63ef3af760c222b1b60a53b74cdd Mon Sep 17 00:00:00 2001
From: Thump604
Date: Thu, 23 Apr 2026 21:56:21 -0500
Subject: [PATCH 1/8] Add bench-serve workload contracts

---
 README.md                          |   3 +
 docs/benchmarks/README.md          |  51 +++
 docs/reference/cli.md              |  42 +++
 examples/bench_serve_workload.json |  44 +++
 tests/test_bench_serve.py          | 323 +++++++++++++++-
 vllm_mlx/bench_serve.py            | 572 ++++++++++++++++++++++++++++-
 vllm_mlx/cli.py                    |  51 ++-
 7 files changed, 1080 insertions(+), 6 deletions(-)
 create mode 100644 examples/bench_serve_workload.json

diff --git a/README.md b/README.md
index 1865b8e34..6ef5e94de 100644
--- a/README.md
+++ b/README.md
@@ -188,6 +188,9 @@ python examples/tts_multilingual.py "Hola mundo" --lang es --play
 ```bash
 vllm-mlx bench-serve --url http://localhost:8000 --concurrency 5 --prompts prompts.txt --output results.csv
+
+# Product-style workload with quality checks and metrics deltas
+vllm-mlx bench-serve --url http://localhost:8000 --workload workload.json --output results.json
 ```
 
 ### Prometheus metrics
diff --git a/docs/benchmarks/README.md b/docs/benchmarks/README.md
index ab3220d8b..96f56af0d 100644
--- a/docs/benchmarks/README.md
+++ b/docs/benchmarks/README.md
@@ -19,8 +19,59 @@ vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit
 
 # Video benchmark
 vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --video
+
+# Running-server prompt sweep with Prometheus metric deltas
+vllm-mlx bench-serve --url http://localhost:8000 --prompts short,long \
+  --concurrency 1,4 --output bench.json --format json
+
+# Running-server product-style workload with quality checks
+vllm-mlx bench-serve --url http://localhost:8000 \
+  --workload ./workload.json --output workload-results.json
 ```
+
+## Contract Workloads
+
+`vllm-mlx bench-serve --workload` runs declarative cases against an already
+running OpenAI-compatible server. This is intended for model and feature-stack
+qualification, where raw speed is not enough and every run needs provenance,
+quality checks, Prometheus metric deltas, and policy-timeout evidence.
+
+Example workload:
+
+```json
+{
+  "name": "writing-contract",
+  "description": "Representative long-form writing requests",
+  "defaults": {
+    "max_tokens": 32768,
+    "enable_thinking": true,
+    "policy_timeout_ms": 180000,
+    "checks": {
+      "finish_reason": "stop",
+      "forbidden_regex": ["<think>", "prompt leakage"],
+      "min_chars": 500
+    }
+  },
+  "cases": [
+    {
+      "id": "resume-golden-1",
+      "messages": [
+        {"role": "user", "content": "Write the requested artifact..."}
+      ],
+      "tags": ["resume", "quality-floor"]
+    }
+  ]
+}
+```
+
+`policy_timeout_ms` is recorded as comparison evidence. It is not treated as a
+hardware capability claim. Use it to answer "would this run fit my product
+policy?" after first measuring what the model and serving stack can actually do.
+
+Workload output defaults to JSON for full provenance. Use `--format csv` for
+flat per-case rows or `--format sql` to emit importable SQL for a local
+benchmark database.
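+
+As a sketch (the `bench.db` file name below is just an illustration, not a
+project convention), the emitted SQL can be loaded straight into a local
+SQLite database:
+
+```bash
+vllm-mlx bench-serve --url http://localhost:8000 \
+  --workload ./workload.json --format sql --output workload.sql
+sqlite3 bench.db < workload.sql
+```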
+
 ## Standalone Test Defaults
 
 Standalone benchmark test scripts have built-in default models, so you can run:
diff --git a/docs/reference/cli.md b/docs/reference/cli.md
index 90a752aef..ba87d729e 100644
--- a/docs/reference/cli.md
+++ b/docs/reference/cli.md
@@ -5,6 +5,7 @@
 | Command | Description |
 |---------|-------------|
 | `vllm-mlx serve` | Start OpenAI-compatible server |
+| `vllm-mlx bench-serve` | Benchmark a running server with prompt sweeps or workload contracts |
 | `vllm-mlx-bench` | Run performance benchmarks |
 | `vllm-mlx-chat` | Start Gradio chat interface |
 
@@ -132,6 +133,47 @@ curl http://localhost:8000/v1/models \
   -H "Authorization: Bearer your-secret-key"
 ```
 
+## `vllm-mlx bench-serve`
+
+Benchmark a running vllm-mlx server over HTTP. Prompt-sweep mode measures
+TTFT, TPOT, throughput, cache deltas, and Metal memory. Workload mode adds
+per-case quality checks and comparison-only product policy timeouts.
+
+### Usage
+
+```bash
+vllm-mlx bench-serve --url http://localhost:8000 [options]
+```
+
+### Options
+
+| Option | Description | Default |
+|--------|-------------|---------|
+| `--url` | Running server base URL | `http://127.0.0.1:8080` |
+| `--model` | API model id | Auto-detect |
+| `--prompts` | Comma-separated prompt sets or files for sweep mode | `short,medium,long` |
+| `--workload` | Declarative workload JSON for contract mode | None |
+| `--concurrency` | Comma-separated concurrency levels for sweep mode | `1,4` |
+| `--max-tokens` | Max tokens for sweep mode | `256` |
+| `--enable-thinking` | `true`, `false`, or `true,false` sweep | None |
+| `--scrape-metrics` | Scrape `/metrics` before/after runs | `true` |
+| `--include-content` | Include full generated content in workload JSON | False |
+| `--request-timeout-s` | Workload HTTP transport timeout, `0` disables | `300` |
+| `--output` | Output file | stdout |
+| `--format` | Output format: `table`, `json`, `csv`, `sql` | `table` for prompt sweeps, `json` for workloads |
+
+### Examples
+
+```bash
+# Prompt sweep
+vllm-mlx bench-serve --url http://localhost:8000 \
+  --prompts short,long --concurrency 1,4 --format json --output bench.json
+
+# Contract workload with quality checks and policy-timeout evidence
+vllm-mlx bench-serve --url http://localhost:8000 \
+  --workload workload.json --output workload-results.json
+```
+
 ## `vllm-mlx-bench`
 
 Run performance benchmarks.
diff --git a/examples/bench_serve_workload.json b/examples/bench_serve_workload.json
new file mode 100644
index 000000000..3f093cc08
--- /dev/null
+++ b/examples/bench_serve_workload.json
@@ -0,0 +1,44 @@
+{
+  "name": "quality-contract-smoke",
+  "description": "Small contract-style workload demonstrating bench-serve quality checks.",
+  "defaults": {
+    "max_tokens": 256,
+    "enable_thinking": false,
+    "policy_timeout_ms": 30000,
+    "checks": {
+      "finish_reason": "stop",
+      "forbidden_regex": [
+        "<think>",
+        "I cannot"
+      ],
+      "min_chars": 40
+    }
+  },
+  "cases": [
+    {
+      "id": "python-palindrome",
+      "tags": [
+        "code",
+        "quality"
+      ],
+      "messages": [
+        {
+          "role": "user",
+          "content": "Write a Python function with type hints that checks whether a string is a palindrome. Include one short example."
+        }
+      ],
+      "checks": {
+        "finish_reason": "stop",
+        "required_regex": [
+          "def ",
+          "-> bool"
+        ],
+        "forbidden_regex": [
+          "<think>",
+          "TODO"
+        ],
+        "min_chars": 80
+      }
+    }
+  ]
+}
diff --git a/tests/test_bench_serve.py b/tests/test_bench_serve.py
index e5b835220..1a64f2c75 100644
--- a/tests/test_bench_serve.py
+++ b/tests/test_bench_serve.py
@@ -13,6 +13,8 @@
     RESULT_COLUMNS,
     BenchServeResult,
     SweepConfig,
+    Workload,
+    WorkloadCase,
     compute_request_metrics,
     compute_summary_stats,
     detect_hardware_fingerprint,
@@ -21,11 +23,18 @@
     format_json,
     format_sql,
     format_table,
+    format_workload_csv,
+    format_workload_payload,
+    format_workload_sql,
     load_prompt_set,
+    load_workload,
     parse_health_response,
     parse_metrics_text,
     parse_sse_line,
     parse_status_response,
+    run_workload_case,
+    summarize_workload_results,
+    validate_quality_checks,
     validate_response,
 )
 
@@ -166,6 +175,54 @@ def test_thinking_prompts_contain_reasoning_keywords(self):
         assert any(kw in combined for kw in keywords)
 
 
+class TestWorkloadLoading:
+    """Tests for declarative contract workload loading."""
+
+    def test_load_workload_with_defaults(self, tmp_path: Path):
+        workload_file = tmp_path / "workload.json"
+        workload_file.write_text(
+            json.dumps(
+                {
+                    "name": "writing-contract",
+                    "defaults": {
+                        "max_tokens": 128,
+                        "enable_thinking": True,
+                        "policy_timeout_ms": 180000,
+                        "checks": {"forbidden_regex": ["<think>"]},
+                    },
+                    "cases": [
+                        {
+                            "id": "case-a",
+                            "messages": [
+                                {"role": "user", "content": "Write a short note."}
+                            ],
+                            "tags": ["quality"],
+                        }
+                    ],
+                }
+            )
+        )
+
+        workload = load_workload(workload_file)
+
+        assert workload.name == "writing-contract"
+        assert len(workload.cases) == 1
+        case = workload.cases[0]
+        assert case.case_id == "case-a"
+        assert case.max_tokens == 128
+        assert case.enable_thinking is True
+        assert case.policy_timeout_ms == 180000
+        assert case.checks == {"forbidden_regex": ["<think>"]}
+        assert case.tags == ("quality",)
+
+    def test_load_workload_rejects_missing_messages(self, tmp_path: Path):
+        workload_file = tmp_path / "workload.json"
+        workload_file.write_text(json.dumps({"cases": [{"id": "bad"}]}))
+
+        with pytest.raises(ValueError, match="messages"):
+            load_workload(workload_file)
+
+
 # ---------------------------------------------------------------------------
 # TestExpandSweep
 # ---------------------------------------------------------------------------
@@ -627,6 +684,190 @@ def test_http_error(self):
         assert "500" in msg
 
 
+class TestQualityChecks:
+    """Unit tests for workload quality checks."""
+
+    def test_required_and_forbidden_regex(self):
+        ok, issues = validate_quality_checks(
+            "stop",
+            "Dear team,\nThis is a clean response.\nSincerely",
+            {
+                "required_regex": ["Dear team", "Sincerely"],
+                "forbidden_regex": ["<think>", "unsupported claim"],
+                "min_chars": 20,
+                "finish_reason": "stop",
+            },
+        )
+
+        assert ok is True
+        assert issues == []
+
+    def test_forbidden_regex_failure(self):
+        ok, issues = validate_quality_checks(
+            "stop",
+            "Visible answer\n<think>hidden plan</think>",
+            {"forbidden_regex": ["<think>"]},
+        )
+
+        assert ok is False
+        assert any("forbidden_regex matched" in issue for issue in issues)
+
+    def test_json_check(self):
+        ok, issues = validate_quality_checks(
+            "stop",
+            '{"title": "Engineer", "priority": 1}',
+            {"json": True},
+        )
+
+        assert ok is True
+        assert issues == []
+
+    def test_json_check_failure(self):
+        ok, issues = validate_quality_checks("stop", "not json", {"json": True})
+
+        assert ok is False
+        assert any("not valid JSON" in issue for issue in issues)
+
+
+class TestWorkloadSummary:
+    """Unit tests for workload summary aggregation."""
+
+    def test_summarize_workload_results_tracks_quality_and_policy(self):
+        results = [
+            {
+                "ok": True,
+                "policy": {"within_timeout": True},
+                "quality": {"ok": True},
+                "metrics": {
+                    "e2e_latency_ms": 100.0,
+                    "ttft_ms": 10.0,
+                    "gen_tps": 20.0,
+                },
+            },
+            {
+                "ok": True,
+                "policy": {"within_timeout": False},
+                "quality": {"ok": True},
+                "metrics": {
+                    "e2e_latency_ms": 200.0,
+                    "ttft_ms": 20.0,
+                    "gen_tps": 10.0,
+                },
+            },
+        ]
+
+        summary = summarize_workload_results(results)
+
+        assert summary["passed"] is True
+        assert summary["quality_passed"] is True
+        assert summary["policy_timeout_passed"] is False
+        assert summary["failure_rate"] == pytest.approx(0.0)
+        assert summary["latency_ms"]["p50"] == pytest.approx(150.0)
+
+
+class TestWorkloadRunner:
+    """Unit tests for contract workload execution records."""
+
+    def test_run_workload_case_records_metrics_policy_and_quality(self, monkeypatch):
+        metrics_responses = iter(
+            [
+                {"cache_hits": 10, "cache_misses": 3, "tokens_saved": 100},
+                {"cache_hits": 12, "cache_misses": 4, "tokens_saved": 140},
+            ]
+        )
+
+        async def fake_scrape_metrics(client, base_url):
+            return next(metrics_responses)
+
+        async def fake_stream_chat_completion(**kwargs):
+            assert kwargs["max_tokens"] == 64
+            assert kwargs["enable_thinking"] is True
+            assert kwargs["extra_body"] == {"temperature": 0.6}
+            return {
+                "ttft_ms": 25.0,
+                "tpot_ms": 3.0,
+                "e2e_latency_ms": 2500.0,
+                "gen_tps": 12.5,
+                "prompt_tps": 500.0,
+                "prompt_tokens": 200,
+                "completion_tokens": 100,
+                "finish_reason": "stop",
+                "content": "Dear team,\nA clean benchmark artifact.\nSincerely",
+            }
+
+        class FakeResponse:
+            def raise_for_status(self):
+                return None
+
+            def json(self):
+                return {
+                    "cache": {"type": "paged"},
+                    "metal": {
+                        "active_gb": 42.0,
+                        "peak_gb": 45.0,
+                        "cache_gb": 4.0,
+                    },
+                }
+
+        class FakeClient:
+            async def get(self, url):
+                return FakeResponse()
+
+        monkeypatch.setattr("vllm_mlx.bench_serve.scrape_metrics", fake_scrape_metrics)
+        monkeypatch.setattr(
+            "vllm_mlx.bench_serve.stream_chat_completion",
+            fake_stream_chat_completion,
+        )
+
+        workload = Workload(
+            name="writing-contract",
+            description="",
+            defaults={"max_tokens": 64},
+            cases=[],
+        )
+        case = WorkloadCase(
+            case_id="resume-smoke",
+            messages=[{"role": "user", "content": "Write the artifact."}],
+            max_tokens=64,
+            enable_thinking=True,
+            extra_body={"temperature": 0.6},
+            policy_timeout_ms=1800,
+            checks={
+                "finish_reason": "stop",
+                "required_regex": ["Dear team", "Sincerely"],
+                "forbidden_regex": ["<think>"],
+            },
+            tags=("resume",),
+        )
+
+        record = asyncio.run(
+            run_workload_case(
+                FakeClient(),
+                "http://server",
+                workload=workload,
+                case=case,
+                model="test-model",
+                runtime={"engine_type": "mllm"},
+                hardware={"chip": "test"},
+                run_id="run123",
+                timestamp="2026-04-23T00:00:00+00:00",
+                scrape=True,
+                include_content=True,
+            )
+        )
+
+        assert record["ok"] is True
+        assert record["quality"]["ok"] is True
+        assert record["quality"]["issues"] == []
+        assert record["quality"]["content"].startswith("Dear team")
+        assert record["policy"]["within_timeout"] is False
+        assert record["metrics"]["cache_hits"] == 2
+        assert record["metrics"]["cache_misses"] == 1
+        assert record["metrics"]["tokens_saved"] == 40
+        assert record["metrics"]["metal"]["metal_active_gb"] == pytest.approx(42.0)
+        assert record["metrics"]["metal"]["cache_type"] == "paged"
+
+
 # ---------------------------------------------------------------------------
 # TestSummaryStats (Task 5)
 # ---------------------------------------------------------------------------
@@ -707,6 +948,68 @@ def _make_sample_result(**overrides) -> BenchServeResult:
     return BenchServeResult(**defaults)
 
 
+def _make_sample_workload_payload() -> dict:
+    return {
+        "run_id": "run123",
+        "timestamp": "2026-04-23T00:00:00+00:00",
+        "summary": {"passed": True},
+        "results": [
+            {
+                "run_id": "run123",
+                "timestamp": "2026-04-23T00:00:00+00:00",
+                "workload": "writing-contract",
+                "case_id": "resume-smoke",
+                "tags": ["resume", "quality"],
+                "model_id": "test-model",
+                "runtime": {
+                    "engine_type": "mllm",
+                    "model_type": "mllm",
+                    "mtp_enabled": False,
+                    "specprefill": False,
+                    "kv_quant": "",
+                    "cache_type": "paged",
+                },
+                "hardware": {
+                    "chip": "M4 Ultra",
+                    "memory_gb": 256.0,
+                    "os_version": "macOS-test",
+                },
+                "request": {
+                    "max_tokens": 32768,
+                    "enable_thinking": True,
+                    "extra_body": {"temperature": 0.6},
+                },
+                "policy": {"timeout_ms": 180000, "within_timeout": True},
+                "metrics": {
+                    "ttft_ms": 25.0,
+                    "tpot_ms": 3.0,
+                    "e2e_latency_ms": 2500.0,
+                    "gen_tps": 12.5,
+                    "prompt_tps": 500.0,
+                    "prompt_tokens": 200,
+                    "completion_tokens": 100,
+                    "cache_hits": 2,
+                    "cache_misses": 1,
+                    "tokens_saved": 40,
+                    "metal": {
+                        "metal_active_gb": 42.0,
+                        "metal_peak_gb": 45.0,
+                        "metal_cache_gb": 4.0,
+                    },
+                },
+                "quality": {
+                    "ok": True,
+                    "issues": [],
+                    "finish_reason": "stop",
+                    "content_chars": 512,
+                    "content_preview": "Clean artifact",
+                },
+                "ok": True,
+            }
+        ],
+    }
+
+
 # ---------------------------------------------------------------------------
 # TestFormatters (Task 6)
 # ---------------------------------------------------------------------------
@@ -766,6 +1069,24 @@ def test_result_columns_match_dataclass(self):
         field_names = {f.name for f in dataclasses.fields(BenchServeResult)}
         assert set(RESULT_COLUMNS) == field_names
 
+    def test_format_workload_csv_parseable(self):
+        output = format_workload_csv(_make_sample_workload_payload())
+        rows = list(csv.DictReader(output.splitlines()))
+        assert len(rows) == 1
+        assert rows[0]["case_id"] == "resume-smoke"
+        assert rows[0]["model_id"] == "test-model"
+        assert rows[0]["request_extra_body"] == '{"temperature": 0.6}'
+
+    def test_format_workload_sql_valid(self):
+        output = format_workload_sql(_make_sample_workload_payload())
+        assert "CREATE TABLE IF NOT EXISTS bench_serve_workload" in output
+        assert "INSERT INTO bench_serve_workload" in output
+        assert "resume-smoke" in output
+
+    def test_format_workload_payload_rejects_unknown_format(self):
+        with pytest.raises(ValueError, match="Unsupported workload output format"):
+            format_workload_payload(_make_sample_workload_payload(), "xml")
+
 
 # ---------------------------------------------------------------------------
 # TestBenchServeIntegration (Task 8)
@@ -807,7 +1128,7 @@ def test_smoke_run(self):
 
     def test_sql_output_is_valid(self):
         """Verify SQL output contains CREATE TABLE and INSERT."""
-        from vllm_mlx.bench_serve import run_bench_serve, format_sql
+        from vllm_mlx.bench_serve import format_sql, run_bench_serve
 
         results = asyncio.run(
             run_bench_serve(
diff --git a/vllm_mlx/bench_serve.py b/vllm_mlx/bench_serve.py
index bec5d21a5..d31f20fc7 100644
--- a/vllm_mlx/bench_serve.py
+++ b/vllm_mlx/bench_serve.py
@@ -32,7 +32,7 @@
 from dataclasses import dataclass
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import Optional
+from typing import Any, Optional
 
 import httpx
 from tabulate import tabulate as _tabulate
@@ -45,6 +45,30 @@
 _BUILTIN_NAMES = {"short", "medium", "long", "thinking"}
 
 
+@dataclass
+class WorkloadCase:
+    """One declarative benchmark case for contract-style serving tests."""
+
+    case_id: str
+    messages: list[dict]
+    max_tokens: Optional[int] = None
+    enable_thinking: Optional[bool] = None
+    extra_body: Optional[dict] = None
+    policy_timeout_ms: Optional[int] = None
+    checks: Optional[dict] = None
+    tags: tuple[str, ...] = ()
+
+
+@dataclass
+class Workload:
+    """Normalized bench-serve workload manifest."""
+
+    name: str
+    description: str
+    defaults: dict
+    cases: list[WorkloadCase]
+
+
 def load_prompt_set(name_or_path: str) -> list[list[dict]]:
     """Load a prompt set by builtin name or file path.
 
@@ -110,6 +134,82 @@ def load_prompt_set(name_or_path: str) -> list[list[dict]]:
     )
 
 
+def _require_message_list(value: Any, *, label: str) -> list[dict]:
+    if not isinstance(value, list) or not value:
+        raise ValueError(f"{label}: messages must be a non-empty list")
+    for idx, message in enumerate(value):
+        if not isinstance(message, dict):
+            raise ValueError(f"{label}: message {idx} must be an object")
+        if "role" not in message or "content" not in message:
+            raise ValueError(f"{label}: message {idx} must include role and content")
+    return value
+
+
+def load_workload(path: str | Path) -> Workload:
+    """Load a declarative serving benchmark workload.
+
+    Workloads are for product-like qualification where each case can carry
+    request settings, comparison-only policy timeouts, and quality checks.
+    Timeout fields are metadata unless the runner explicitly uses them as a
+    transport limit; they are not treated as hardware capability claims.
+    """
+    workload_path = Path(path).expanduser()
+    with workload_path.open() as fh:
+        raw = json.load(fh)
+
+    if not isinstance(raw, dict):
+        raise ValueError("workload root must be a JSON object")
+    raw_cases = raw.get("cases")
+    if not isinstance(raw_cases, list) or not raw_cases:
+        raise ValueError("workload must contain a non-empty cases list")
+
+    defaults = raw.get("defaults") or {}
+    if not isinstance(defaults, dict):
+        raise ValueError("workload defaults must be an object")
+
+    cases: list[WorkloadCase] = []
+    for idx, item in enumerate(raw_cases):
+        if not isinstance(item, dict):
+            raise ValueError(f"case {idx}: case must be an object")
+        case_id = str(item.get("id") or f"case_{idx + 1}")
+        messages = _require_message_list(item.get("messages"), label=case_id)
+        extra_body = item.get("extra_body", defaults.get("extra_body"))
+        if extra_body is not None and not isinstance(extra_body, dict):
+            raise ValueError(f"{case_id}: extra_body must be an object")
+        checks = item.get("checks", defaults.get("checks"))
+        if checks is not None and not isinstance(checks, dict):
+            raise ValueError(f"{case_id}: checks must be an object")
+        tags = item.get("tags", [])
+        if isinstance(tags, str):
+            tags = [tags]
+        if not isinstance(tags, list):
+            raise ValueError(f"{case_id}: tags must be a list or string")
+
+        cases.append(
+            WorkloadCase(
+                case_id=case_id,
+                messages=messages,
+                max_tokens=item.get("max_tokens", defaults.get("max_tokens")),
+                enable_thinking=item.get(
+                    "enable_thinking", defaults.get("enable_thinking")
+                ),
+                extra_body=extra_body,
+                policy_timeout_ms=item.get(
+                    "policy_timeout_ms", defaults.get("policy_timeout_ms")
+                ),
+                checks=checks,
+                tags=tuple(str(tag) for tag in tags),
+            )
+        )
+
+    return Workload(
+        name=str(raw.get("name") or workload_path.stem),
+        description=str(raw.get("description") or ""),
+        defaults=defaults,
+        cases=cases,
+    )
+
+
 # ---------------------------------------------------------------------------
 # Result dataclass
 # ---------------------------------------------------------------------------
@@ -701,6 +801,69 @@ def validate_response(
     return (True, "")
 
 
+def validate_quality_checks(
+    finish_reason: Optional[str],
+    content: str,
+    checks: Optional[dict],
+    *,
+    status_code: int = 200,
+) -> tuple[bool, list[str]]:
+    """Validate content against generic workload quality checks.
+
+    Supported checks:
+    - ``finish_reason``: string or list of allowed finish reasons
+    - ``required_regex``: list of regex patterns that must match
+    - ``forbidden_regex``: list of regex patterns that must not match
+    - ``min_chars`` / ``max_chars``: length bounds
+    - ``json``: when true, content must parse as JSON
+    """
+    basic_ok, basic_issue = validate_response(finish_reason, content, status_code)
+    issues: list[str] = [] if basic_ok else [basic_issue]
+    checks = checks or {}
+
+    allowed_finish = checks.get("finish_reason")
+    if allowed_finish is not None:
+        allowed = (
+            [allowed_finish]
+            if isinstance(allowed_finish, str)
+            else list(allowed_finish)
+        )
+        if finish_reason not in allowed:
+            issues.append(
+                f"finish_reason {finish_reason!r} not in allowed set {allowed!r}"
+            )
+
+    min_chars = checks.get("min_chars")
+    if min_chars is not None and len(content) < int(min_chars):
+        issues.append(f"content shorter than min_chars={min_chars}")
+
+    max_chars = checks.get("max_chars")
+    if max_chars is not None and len(content) > int(max_chars):
+        issues.append(f"content longer than max_chars={max_chars}")
+
+    for pattern in checks.get("required_regex", []) or []:
+        try:
+            if not re.search(str(pattern), content, re.MULTILINE):
+                issues.append(f"required_regex did not match: {pattern}")
+        except re.error as exc:
+            issues.append(f"invalid required_regex {pattern!r}: {exc}")
+
+    for pattern in checks.get("forbidden_regex", []) or []:
+        try:
+            if re.search(str(pattern), content, re.MULTILINE):
+                issues.append(f"forbidden_regex matched: {pattern}")
+        except re.error as exc:
+            issues.append(f"invalid forbidden_regex {pattern!r}: {exc}")
+
+    if checks.get("json"):
+        try:
+            json.loads(content)
+        except json.JSONDecodeError as exc:
+            issues.append(f"content is not valid JSON: {exc}")
+
+    return (not issues, issues)
+
+
 def compute_summary_stats(values: list[float]) -> dict:
     """Compute summary statistics over a list of floats.
@@ -812,6 +975,241 @@ async def _single(messages: list[dict]) -> dict:
     return list(results)
 
 
+def _summary_or_empty(values: list[float]) -> dict:
+    return compute_summary_stats(values) if values else {}
+
+
+async def run_workload_case(
+    client: httpx.AsyncClient,
+    base_url: str,
+    *,
+    workload: Workload,
+    case: WorkloadCase,
+    model: str,
+    runtime: dict,
+    hardware: dict,
+    run_id: str,
+    timestamp: str,
+    scrape: bool = True,
+    include_content: bool = False,
+) -> dict:
+    """Run one workload case and return a JSON-serializable result."""
+    metrics_before = await scrape_metrics(client, base_url) if scrape else {}
+    started_wall = datetime.now(timezone.utc).isoformat()
+
+    try:
+        result = await stream_chat_completion(
+            client=client,
+            base_url=base_url,
+            messages=case.messages,
+            model=model,
+            max_tokens=int(case.max_tokens or workload.defaults.get("max_tokens", 256)),
+            enable_thinking=case.enable_thinking,
+            extra_body=case.extra_body,
+        )
+        error = ""
+    except Exception as exc:
+        result = {
+            "ttft_ms": 0.0,
+            "tpot_ms": 0.0,
+            "e2e_latency_ms": 0.0,
+            "gen_tps": 0.0,
+            "prompt_tps": 0.0,
+            "prompt_tokens": 0,
+            "completion_tokens": 0,
+            "finish_reason": None,
+            "content": "",
+        }
+        error = str(exc)
+
+    metrics_after = await scrape_metrics(client, base_url) if scrape else {}
+    status_after: dict = {}
+    try:
+        resp = await client.get(f"{base_url}/v1/status")
+        resp.raise_for_status()
+        status_after = resp.json()
+    except Exception:
+        status_after = {}
+
+    cache_hits_delta = metrics_after.get("cache_hits", 0) - metrics_before.get(
+        "cache_hits", 0
+    )
+    cache_misses_delta = metrics_after.get("cache_misses", 0) - metrics_before.get(
+        "cache_misses", 0
+    )
+    tokens_saved_delta = metrics_after.get("tokens_saved", 0) - metrics_before.get(
+        "tokens_saved", 0
+    )
+
+    content = str(result.get("content") or "")
+    quality_ok, quality_issues = validate_quality_checks(
+        result.get("finish_reason"),
+        content,
+        case.checks,
+        status_code=500 if error else 200,
+    )
+    if error:
+        quality_issues.append(f"request error: {error}")
+
+    if case.policy_timeout_ms is None:
+        within_policy_timeout = None
+    elif error:
+        within_policy_timeout = False
+    else:
+        within_policy_timeout = result["e2e_latency_ms"] <= case.policy_timeout_ms
+
+    record = {
+        "run_id": run_id,
+        "timestamp": timestamp,
+        "started_at": started_wall,
+        "workload": workload.name,
+        "case_id": case.case_id,
+        "tags": list(case.tags),
+        "model_id": model,
+        "runtime": runtime,
+        "hardware": hardware,
+        "request": {
+            "max_tokens": int(
+                case.max_tokens or workload.defaults.get("max_tokens", 256)
+            ),
+            "enable_thinking": case.enable_thinking,
+            "extra_body": case.extra_body or {},
+            "message_count": len(case.messages),
+        },
+        "policy": {
+            "timeout_ms": case.policy_timeout_ms,
+            "within_timeout": within_policy_timeout,
+            "note": "comparison-only unless your product contract explicitly requires it",
+        },
+        "metrics": {
+            "ttft_ms": result["ttft_ms"],
+            "tpot_ms": result["tpot_ms"],
+            "e2e_latency_ms": result["e2e_latency_ms"],
+            "gen_tps": result["gen_tps"],
+            "prompt_tps": result["prompt_tps"],
+            "prompt_tokens": result["prompt_tokens"],
+            "completion_tokens": result["completion_tokens"],
+            "cache_hits": cache_hits_delta,
+            "cache_misses": cache_misses_delta,
+            "tokens_saved": tokens_saved_delta,
+            "metal": parse_status_response(status_after),
+        },
+        "quality": {
+            "ok": quality_ok,
+            "issues": quality_issues,
+            "finish_reason": result.get("finish_reason"),
+            "content_chars": len(content),
+            "content_preview": content[:240],
+        },
+        "ok": quality_ok,
+    }
+    if include_content:
+        record["quality"]["content"] = content
+    return record
+
+
+def summarize_workload_results(results: list[dict]) -> dict:
+    """Aggregate workload case records into stable qualification summary stats."""
+    latencies = [r["metrics"]["e2e_latency_ms"] for r in results]
+    ttft = [r["metrics"]["ttft_ms"] for r in results]
+    gen_tps = [r["metrics"]["gen_tps"] for r in results]
+    quality_failures = [r for r in results if not r["quality"]["ok"]]
+    policy_trials = [
+        r for r in results if r["policy"].get("within_timeout") is not None
+    ]
+    policy_failures = [
+        r for r in policy_trials if r["policy"].get("within_timeout") is False
+    ]
+    failures = [r for r in results if not r["quality"]["ok"]]
+    return {
+        "case_count": len(results),
+        "passed": not failures,
+        "failure_count": len(failures),
+        "failure_rate": round(len(failures) / len(results), 4) if results else 0.0,
+        "quality_passed": not quality_failures,
+        "quality_failure_count": len(quality_failures),
+        "policy_timeout_passed": not policy_failures if policy_trials else None,
+        "policy_timeout_failure_count": (
+            len(policy_failures) if policy_trials else None
+        ),
+        "latency_ms": _summary_or_empty(latencies),
+        "ttft_ms": _summary_or_empty(ttft),
+        "gen_tps": _summary_or_empty(gen_tps),
+    }
+
+
+async def run_bench_serve_workload(
+    *,
+    url: str,
+    workload_path: str,
+    model: Optional[str] = None,
+    output_path: Optional[str] = None,
+    output_format: str = "json",
+    scrape: bool = True,
+    include_content: bool = False,
+    request_timeout_s: Optional[float] = 300.0,
+) -> dict:
+    """Run a declarative workload against a running server.
+
+    This is the contract-style counterpart to prompt sweeps: it keeps product
+    policy knobs in the manifest, records them as evidence, and measures what
+    the server actually does before anyone promotes a model or feature stack.
+    """
+    workload = load_workload(workload_path)
+    run_id = str(uuid.uuid4())[:8]
+    timestamp = datetime.now(timezone.utc).isoformat()
+    timeout = httpx.Timeout(request_timeout_s) if request_timeout_s else None
+
+    async with httpx.AsyncClient(timeout=timeout) as client:
+        runtime = await auto_detect_runtime(client, url)
+        hardware = detect_hardware_fingerprint()
+        model_id = model or runtime.get("model_id", "")
+        if not model_id:
+            raise ValueError("could not determine model ID; pass --model")
+
+        records = []
+        for case in workload.cases:
+            record = await run_workload_case(
+                client,
+                url,
+                workload=workload,
+                case=case,
+                model=model_id,
+                runtime=runtime,
+                hardware=hardware,
+                run_id=run_id,
+                timestamp=timestamp,
+                scrape=scrape,
+                include_content=include_content,
+            )
+            records.append(record)
+
+    payload = {
+        "run_id": run_id,
+        "timestamp": timestamp,
+        "workload": {
+            "name": workload.name,
+            "description": workload.description,
+            "path": str(Path(workload_path).expanduser()),
+            "defaults": workload.defaults,
+        },
+        "transport": {
+            "request_timeout_s": request_timeout_s,
+            "note": "transport safety only; product policy timeouts live in workload cases",
+        },
+        "summary": summarize_workload_results(records),
+        "results": records,
+    }
+
+    rendered = format_workload_payload(payload, output_format)
+    if output_path:
+        Path(output_path).expanduser().write_text(rendered)
+        print(f"Workload results written to {output_path}")
+    else:
+        print(rendered)
+    return payload
+
+
 # ---------------------------------------------------------------------------
 # Task 6: Output formatters
 # ---------------------------------------------------------------------------
@@ -958,6 +1356,178 @@ def format_sql(results: list[BenchServeResult]) -> str:
     return "\n".join(lines)
 
 
+WORKLOAD_RESULT_COLUMNS = [
+    "run_id",
+    "timestamp",
+    "workload",
+    "case_id",
+    "tags",
+    "model_id",
+    "chip",
+    "memory_gb",
+    "os_version",
+    "engine_type",
+    "model_type",
+    "mtp_enabled",
+    "specprefill",
+    "kv_quant",
+    "cache_type",
+    "request_max_tokens",
+    "request_enable_thinking",
+    "request_extra_body",
+    "policy_timeout_ms",
+    "within_policy_timeout",
+    "ttft_ms",
+    "tpot_ms",
+    "e2e_latency_ms",
+    "gen_tps",
+    "prompt_tps",
+    "prompt_tokens",
+    "completion_tokens",
+    "cache_hits",
+    "cache_misses",
+    "tokens_saved",
+    "metal_active_gb",
+    "metal_peak_gb",
+    "metal_cache_gb",
+    "quality_ok",
+    "quality_issues",
+    "finish_reason",
+    "content_chars",
+    "content_preview",
+]
+
+_WORKLOAD_TABLE_COLUMNS = [
+    "case_id",
+    "tags",
+    "quality_ok",
+    "within_policy_timeout",
+    "ttft_ms",
+    "gen_tps",
+    "e2e_latency_ms",
+    "cache_hits",
+    "tokens_saved",
+    "finish_reason",
+]
+
+
+def _workload_record_to_row(record: dict) -> dict:
+    runtime = record.get("runtime") or {}
+    hardware = record.get("hardware") or {}
+    request = record.get("request") or {}
+    policy = record.get("policy") or {}
+    metrics = record.get("metrics") or {}
+    metal = metrics.get("metal") or {}
+    quality = record.get("quality") or {}
+    return {
+        "run_id": record.get("run_id", ""),
+        "timestamp": record.get("timestamp", ""),
+        "workload": record.get("workload", ""),
+        "case_id": record.get("case_id", ""),
+        "tags": ",".join(record.get("tags") or []),
+        "model_id": record.get("model_id", ""),
+        "chip": hardware.get("chip", ""),
+        "memory_gb": hardware.get("memory_gb", 0.0),
+        "os_version": hardware.get("os_version", ""),
+        "engine_type": runtime.get("engine_type", ""),
+        "model_type": runtime.get("model_type", ""),
+        "mtp_enabled": runtime.get("mtp_enabled", False),
+        "specprefill": runtime.get("specprefill", False),
+        "kv_quant": runtime.get("kv_quant", ""),
+        "cache_type": runtime.get("cache_type", ""),
+        "request_max_tokens": request.get("max_tokens"),
+        "request_enable_thinking": request.get("enable_thinking"),
+        "request_extra_body": json.dumps(
+            request.get("extra_body") or {}, sort_keys=True
+        ),
+        "policy_timeout_ms": policy.get("timeout_ms"),
+        "within_policy_timeout": policy.get("within_timeout"),
+        "ttft_ms": metrics.get("ttft_ms", 0.0),
+        "tpot_ms": metrics.get("tpot_ms", 0.0),
+        "e2e_latency_ms": metrics.get("e2e_latency_ms", 0.0),
+        "gen_tps": metrics.get("gen_tps", 0.0),
+        "prompt_tps": metrics.get("prompt_tps", 0.0),
+        "prompt_tokens": metrics.get("prompt_tokens", 0),
+        "completion_tokens": metrics.get("completion_tokens", 0),
+        "cache_hits": metrics.get("cache_hits", 0),
+        "cache_misses": metrics.get("cache_misses", 0),
+        "tokens_saved": metrics.get("tokens_saved", 0),
+        "metal_active_gb": metal.get("metal_active_gb", 0.0),
+        "metal_peak_gb": metal.get("metal_peak_gb", 0.0),
+        "metal_cache_gb": metal.get("metal_cache_gb", 0.0),
+        "quality_ok": quality.get("ok", False),
+        "quality_issues": json.dumps(quality.get("issues") or []),
+        "finish_reason": quality.get("finish_reason"),
+        "content_chars": quality.get("content_chars", 0),
+        "content_preview": quality.get("content_preview", ""),
+    }
+
+
+def format_workload_table(payload: dict) -> str:
+    rows = []
+    for record in payload.get("results") or []:
+        row = _workload_record_to_row(record)
+        rows.append(
+            [
+                round(value, 1) if isinstance(value, float) else value
+                for value in (row[col] for col in _WORKLOAD_TABLE_COLUMNS)
+            ]
+        )
+    return _tabulate(rows, headers=_WORKLOAD_TABLE_COLUMNS, tablefmt="simple")
+
+
+def format_workload_json(payload: dict) -> str:
+    return json.dumps(payload, indent=2)
+
+
+def format_workload_csv(payload: dict) -> str:
+    buf = io.StringIO()
+    writer = csv_mod.DictWriter(buf, fieldnames=WORKLOAD_RESULT_COLUMNS)
+    writer.writeheader()
+    for record in payload.get("results") or []:
+        writer.writerow(_workload_record_to_row(record))
+    return buf.getvalue()
+
+
+_WORKLOAD_SQL_SCHEMA = (
+    "run_id TEXT, timestamp TEXT, workload TEXT, case_id TEXT, tags TEXT, "
+    "model_id TEXT, chip TEXT, memory_gb REAL, os_version TEXT, "
+    "engine_type TEXT, model_type TEXT, mtp_enabled BOOLEAN, specprefill BOOLEAN, "
+    "kv_quant TEXT, cache_type TEXT, request_max_tokens INTEGER, "
+    "request_enable_thinking BOOLEAN, request_extra_body TEXT, "
+    "policy_timeout_ms INTEGER, within_policy_timeout BOOLEAN, "
+    "ttft_ms REAL, tpot_ms REAL, e2e_latency_ms REAL, gen_tps REAL, "
+    "prompt_tps REAL, prompt_tokens INTEGER, completion_tokens INTEGER, "
+    "cache_hits INTEGER, cache_misses INTEGER, tokens_saved INTEGER, "
+    "metal_active_gb REAL, metal_peak_gb REAL, metal_cache_gb REAL, "
+    "quality_ok BOOLEAN, quality_issues TEXT, finish_reason TEXT, "
+    "content_chars INTEGER, content_preview TEXT"
+)
+
+
+def format_workload_sql(payload: dict) -> str:
+    lines = [
+        f"CREATE TABLE IF NOT EXISTS bench_serve_workload ({_WORKLOAD_SQL_SCHEMA});",
+    ]
+    for record in payload.get("results") or []:
+        row = _workload_record_to_row(record)
+        values = ", ".join(_sql_escape(row[col]) for col in WORKLOAD_RESULT_COLUMNS)
+        lines.append(f"INSERT INTO bench_serve_workload VALUES ({values});")
+    return "\n".join(lines)
+
+
+def format_workload_payload(payload: dict, fmt: str = "json") -> str:
+    if fmt == "json":
+        return format_workload_json(payload)
+    if fmt == "csv":
+        return format_workload_csv(payload)
+    if fmt == "sql":
+        return format_workload_sql(payload)
+    if fmt == "table":
+        return format_workload_table(payload)
+    raise ValueError(f"Unsupported workload output format: {fmt}")
+
+
 # ---------------------------------------------------------------------------
 # Task 7: Main async orchestrator
 # ---------------------------------------------------------------------------
diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py
index d7b6fd6e9..a2e7fd4b4 100644
--- a/vllm_mlx/cli.py
+++ b/vllm_mlx/cli.py
@@ -696,7 +696,26 @@ def bench_kv_cache_command(args):
 def bench_serve_command(args):
     """Run serving benchmark."""
     import asyncio
-    from .bench_serve import run_bench_serve
+
+    from .bench_serve import run_bench_serve, run_bench_serve_workload
+
+    if args.workload:
+        request_timeout_s = (
+            None if args.request_timeout_s <= 0 else args.request_timeout_s
+        )
+        asyncio.run(
+            run_bench_serve_workload(
+                url=args.url,
+                workload_path=args.workload,
+                model=args.model,
+                output_path=args.output,
+                output_format=args.format or "json",
+                scrape=args.scrape_metrics == "true",
+                include_content=args.include_content,
+                request_timeout_s=request_timeout_s,
+            )
+        )
+        return
 
     prompt_sets = args.prompts.split(",")
     concurrencies = [int(c) for c in args.concurrency.split(",")]
@@ -743,7 +762,7 @@ def bench_serve_command(args):
             thinking_values=thinking_values,
             extra_bodies=extra_bodies,
             output_path=args.output,
-            fmt=args.format,
+            fmt=args.format or "table",
             do_validate=args.validate == "true",
             scrape=args.scrape_metrics == "true",
             tag=args.tag,
@@ -1342,6 +1361,16 @@ def create_parser() -> argparse.ArgumentParser:
         default=None,
         help="Model ID to benchmark (default: auto-detected from server)",
     )
+    bench_serve_parser.add_argument(
+        "--workload",
+        type=str,
+        default=None,
+        help=(
+            "Path to a declarative workload JSON file. When set, bench-serve "
+            "runs contract-style cases with per-case quality checks and "
+            "comparison-only policy timeouts instead of the prompt sweep."
+        ),
+    )
     bench_serve_parser.add_argument(
         "--prompts",
         type=str,
@@ -1422,9 +1451,9 @@ def create_parser() -> argparse.ArgumentParser:
     bench_serve_parser.add_argument(
         "--format",
         type=str,
-        default="table",
+        default=None,
         choices=["table", "json", "csv", "sql"],
-        help="Output format (default: table)",
+        help="Output format (default: table for prompt sweeps, json for workloads)",
     )
     bench_serve_parser.add_argument(
         "--validate",
@@ -1440,6 +1469,20 @@ def create_parser() -> argparse.ArgumentParser:
         choices=["true", "false"],
         help="Scrape /metrics before and after each run (default: true)",
     )
+    bench_serve_parser.add_argument(
+        "--include-content",
+        action="store_true",
+        help="Include full generated content in workload JSON output",
+    )
+    bench_serve_parser.add_argument(
+        "--request-timeout-s",
+        type=float,
+        default=300.0,
+        help=(
+            "HTTP transport timeout for workload mode in seconds (default: 300). "
+            "Use 0 to disable; product policy timeouts belong in the workload."
+        ),
+    )
     bench_serve_parser.add_argument(
         "--tag",
         type=str,

From e8db1df6f73dade1d8f593e2348cce02489986d0 Mon Sep 17 00:00:00 2001
From: Thump604
Date: Thu, 23 Apr 2026 22:05:07 -0500
Subject: [PATCH 2/8] Support workload request files

---
 docs/benchmarks/README.md | 23 ++++++++++++++++++
 docs/reference/cli.md     |  4 ++-
 tests/test_bench_serve.py | 45 ++++++++++++++++++++++++++++++++++
 vllm_mlx/bench_serve.py   | 51 ++++++++++++++++++++++++++++++++++++---
 4 files changed, 119 insertions(+), 4 deletions(-)

diff --git a/docs/benchmarks/README.md b/docs/benchmarks/README.md
index 96f56af0d..81e278195 100644
--- a/docs/benchmarks/README.md
+++ b/docs/benchmarks/README.md
@@ -64,6 +64,29 @@ Example workload:
 }
 ```
 
+Cases can also reference an existing OpenAI-compatible request JSON instead of
+duplicating a large prompt body:
+
+```json
+{
+  "name": "writing-contract",
+  "cases": [
+    {
+      "id": "resume-golden-1",
+      "request_path": "./fixtures/job543_resume_precise_request.json",
+      "checks": {
+        "finish_reason": "stop",
+        "forbidden_regex": ["<think>"]
+      }
+    }
+  ]
+}
+```
+
+When `request_path` is used, `messages`, `max_tokens`, `enable_thinking`, and
+extra request-body fields such as `thinking_token_budget` are read from that
+file. Case-level `extra_body` values override request-file values.
+
 `policy_timeout_ms` is recorded as comparison evidence. It is not treated as a
 hardware capability claim. Use it to answer "would this run fit my product
 policy?" after first measuring what the model and serving stack can actually do.
diff --git a/docs/reference/cli.md b/docs/reference/cli.md
index ba87d729e..16ea1fc81 100644
--- a/docs/reference/cli.md
+++ b/docs/reference/cli.md
@@ -137,7 +137,9 @@ curl http://localhost:8000/v1/models \
 
 Benchmark a running vllm-mlx server over HTTP. Prompt-sweep mode measures
 TTFT, TPOT, throughput, cache deltas, and Metal memory. Workload mode adds
-per-case quality checks and comparison-only product policy timeouts.
+per-case quality checks and comparison-only product policy timeouts. Workload
+cases can embed `messages` directly or point `request_path` at an existing
+OpenAI-compatible request JSON.
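+
+A minimal request file for `request_path` might look like the following
+sketch (the field values are illustrative; fields such as
+`thinking_token_budget` and `temperature` flow into the case's extra request
+body, while `model` and `stream` are ignored by the workload runner):
+
+```json
+{
+  "model": "ignored-by-workload-runner",
+  "messages": [{"role": "user", "content": "Write the artifact."}],
+  "stream": true,
+  "max_tokens": 32768,
+  "enable_thinking": true,
+  "thinking_token_budget": 8192,
+  "temperature": 0.6
+}
+```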
 
 ### Usage
diff --git a/tests/test_bench_serve.py b/tests/test_bench_serve.py
index 1a64f2c75..d901dfb26 100644
--- a/tests/test_bench_serve.py
+++ b/tests/test_bench_serve.py
@@ -222,6 +222,49 @@ def test_load_workload_rejects_missing_messages(self, tmp_path: Path):
         with pytest.raises(ValueError, match="messages"):
             load_workload(workload_file)
 
+    def test_load_workload_case_from_request_path(self, tmp_path: Path):
+        request_file = tmp_path / "request.json"
+        request_file.write_text(
+            json.dumps(
+                {
+                    "model": "ignored-by-workload-runner",
+                    "messages": [{"role": "user", "content": "Write the artifact."}],
+                    "stream": True,
+                    "max_tokens": 32768,
+                    "enable_thinking": True,
+                    "thinking_token_budget": 8192,
+                    "temperature": 0.6,
+                }
+            )
+        )
+        workload_file = tmp_path / "workload.json"
+        workload_file.write_text(
+            json.dumps(
+                {
+                    "cases": [
+                        {
+                            "id": "resume",
+                            "request_path": "request.json",
+                            "extra_body": {"top_p": 0.95},
+                        }
+                    ]
+                }
+            )
+        )
+
+        workload = load_workload(workload_file)
+        case = workload.cases[0]
+
+        assert case.messages == [{"role": "user", "content": "Write the artifact."}]
+        assert case.request_path == "request.json"
+        assert case.max_tokens == 32768
+        assert case.enable_thinking is True
+        assert case.extra_body == {
+            "thinking_token_budget": 8192,
+            "temperature": 0.6,
+            "top_p": 0.95,
+        }
+
 
 # ---------------------------------------------------------------------------
 # TestExpandSweep
@@ -828,6 +871,7 @@ async def get(self, url):
         case = WorkloadCase(
             case_id="resume-smoke",
             messages=[{"role": "user", "content": "Write the artifact."}],
+            request_path="/tmp/request.json",
             max_tokens=64,
             enable_thinking=True,
             extra_body={"temperature": 0.6},
@@ -860,6 +904,7 @@ async def get(self, url):
         assert record["quality"]["ok"] is True
         assert record["quality"]["issues"] == []
         assert record["quality"]["content"].startswith("Dear team")
+        assert record["request"]["request_path"] == "/tmp/request.json"
         assert record["policy"]["within_timeout"] is False
         assert record["metrics"]["cache_hits"] == 2
         assert record["metrics"]["cache_misses"] == 1
diff --git a/vllm_mlx/bench_serve.py b/vllm_mlx/bench_serve.py
index d31f20fc7..56564686b 100644
--- a/vllm_mlx/bench_serve.py
+++ b/vllm_mlx/bench_serve.py
@@ -51,6 +51,7 @@ class WorkloadCase:
 
     case_id: str
     messages: list[dict]
+    request_path: Optional[str] = None
     max_tokens: Optional[int] = None
     enable_thinking: Optional[bool] = None
     extra_body: Optional[dict] = None
@@ -145,6 +146,29 @@ def _require_message_list(value: Any, *, label: str) -> list[dict]:
     return value
 
 
+def _load_case_request(path: str, *, workload_path: Path, case_id: str) -> dict:
+    request_path = Path(path).expanduser()
+    if not request_path.is_absolute():
+        request_path = workload_path.parent / request_path
+    with request_path.open() as fh:
+        request = json.load(fh)
+    if not isinstance(request, dict):
+        raise ValueError(f"{case_id}: request_path must point to a JSON object")
+    return request
+
+
+def _request_extra_body(request: dict) -> dict:
+    reserved = {
+        "model",
+        "messages",
+        "max_tokens",
+        "stream",
+        "stream_options",
+        "enable_thinking",
+    }
+    return {key: value for key, value in request.items() if key not in reserved}
+
+
 def load_workload(path: str | Path) -> Workload:
     """Load a declarative serving benchmark workload.
@@ -172,8 +196,21 @@ def load_workload(path: str | Path) -> Workload:
         if not isinstance(item, dict):
             raise ValueError(f"case {idx}: case must be an object")
         case_id = str(item.get("id") or f"case_{idx + 1}")
-        messages = _require_message_list(item.get("messages"), label=case_id)
+        request_path = item.get("request_path")
+        request_defaults: dict = {}
+        if request_path is not None:
+            request_defaults = _load_case_request(
+                str(request_path), workload_path=workload_path, case_id=case_id
+            )
+        messages = _require_message_list(
+            item.get("messages", request_defaults.get("messages")),
+            label=case_id,
+        )
         extra_body = item.get("extra_body", defaults.get("extra_body"))
+        request_extra = _request_extra_body(request_defaults)
+        if extra_body:
+            request_extra.update(extra_body)
+        extra_body = request_extra or None
         if extra_body is not None and not isinstance(extra_body, dict):
             raise ValueError(f"{case_id}: extra_body must be an object")
         checks = item.get("checks", defaults.get("checks"))
@@ -189,9 +226,16 @@ def load_workload(path: str | Path) -> Workload:
             WorkloadCase(
                 case_id=case_id,
                 messages=messages,
-                max_tokens=item.get("max_tokens", defaults.get("max_tokens")),
+                request_path=str(request_path) if request_path is not None else None,
+                max_tokens=item.get(
+                    "max_tokens",
+                    request_defaults.get("max_tokens", defaults.get("max_tokens")),
+                ),
                 enable_thinking=item.get(
-                    "enable_thinking", defaults.get("enable_thinking")
+                    "enable_thinking",
+                    request_defaults.get(
+                        "enable_thinking", defaults.get("enable_thinking")
+                    ),
                 ),
                 extra_body=extra_body,
                 policy_timeout_ms=item.get(
@@ -1072,6 +1116,7 @@ async def run_workload_case(
             "max_tokens": int(
                 case.max_tokens or workload.defaults.get("max_tokens", 256)
             ),
+            "request_path": case.request_path,
             "enable_thinking": case.enable_thinking,
             "extra_body": case.extra_body or {},
             "message_count": len(case.messages),

From 2bb2747da9209e99e00ab5d2a8c94db698300ab9 Mon Sep 17 00:00:00 2001
From: Thump604
Date: Thu, 23 Apr 2026 23:22:30 -0500
Subject: [PATCH 3/8] Read current status memory keys in bench serve

---
 tests/test_bench_serve.py | 26 ++++++++++++++++++------
 vllm_mlx/bench_serve.py   | 12 +++++++++---
 2 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/tests/test_bench_serve.py b/tests/test_bench_serve.py
index d901dfb26..4d1f3bf5b 100644
--- a/tests/test_bench_serve.py
+++ b/tests/test_bench_serve.py
@@ -511,9 +511,9 @@ def test_parse_status_response(self):
         data = {
             "model": "mlx-community/Llama-3.2-1B-Instruct-4bit",
             "metal": {
-                "active_gb": 12.5,
-                "peak_gb": 14.0,
-                "cache_gb": 2.0,
+                "active_memory_gb": 12.5,
+                "peak_memory_gb": 14.0,
+                "cache_memory_gb": 2.0,
             },
             "cache": {"type": "paged"},
         }
@@ -524,6 +524,20 @@ def test_parse_status_response(self):
         assert result["metal_cache_gb"] == pytest.approx(2.0)
         assert result["cache_type"] == "paged"
 
+    def test_parse_status_response_accepts_legacy_metal_keys(self):
+        data = {
+            "model": "legacy-server",
+            "metal": {
+                "active_gb": 12.5,
+                "peak_gb": 14.0,
+                "cache_gb": 2.0,
+            },
+        }
+        result = parse_status_response(data)
+        assert result["metal_active_gb"] == pytest.approx(12.5)
+        assert result["metal_peak_gb"] == pytest.approx(14.0)
+        assert result["metal_cache_gb"] == pytest.approx(2.0)
+
     def test_parse_status_no_metal(self):
         data = {"model": "some-model"}
         result = parse_status_response(data)
@@ -846,9 +860,9 @@ def json(self):
                 return {
                     "cache": {"type": "paged"},
                     "metal": {
-                        "active_gb": 42.0,
-                        "peak_gb": 45.0,
-                        "cache_gb": 4.0,
+                        "active_memory_gb": 42.0,
+                        "peak_memory_gb": 45.0,
+                        "cache_memory_gb": 4.0,
                     },
                 }
diff --git a/vllm_mlx/bench_serve.py b/vllm_mlx/bench_serve.py
index 56564686b..965123e33 100644
--- a/vllm_mlx/bench_serve.py
+++ b/vllm_mlx/bench_serve.py
@@ -402,9 +402,15 @@ def parse_status_response(data: dict) -> dict:
     cache = data.get("cache") or {}
     return {
         "model": data.get("model", ""),
-        "metal_active_gb": float(metal.get("active_gb") or 0.0),
-        "metal_peak_gb": float(metal.get("peak_gb") or 0.0),
-        "metal_cache_gb": float(metal.get("cache_gb") or 0.0),
+        "metal_active_gb": float(
+            metal.get("active_memory_gb") or metal.get("active_gb") or 0.0
+        ),
+        "metal_peak_gb": float(
+            metal.get("peak_memory_gb") or metal.get("peak_gb") or 0.0
+        ),
+        "metal_cache_gb": float(
+            metal.get("cache_memory_gb") or metal.get("cache_gb") or 0.0
+        ),
         "cache_type": cache.get("type", "") or "",
     }

From af365f78bba634195fefc13cb5c7fb2f71c85221 Mon Sep 17 00:00:00 2001
From: Thump604
Date: Fri, 24 Apr 2026 03:11:13 -0500
Subject: [PATCH 4/8] Add workload variance summaries

---
 README.md                 |   2 +-
 docs/benchmarks/README.md |   7 +++
 docs/reference/cli.md     |   9 ++--
 tests/test_bench_serve.py | 104 +++++++++++++++++++++++++++++++++++++-
 vllm_mlx/bench_serve.py   | 101 +++++++++++++++++++++++++++++-------
 vllm_mlx/cli.py           |   3 +-
 6 files changed, 201 insertions(+), 25 deletions(-)

diff --git a/README.md b/README.md
index 6ef5e94de..e52b1241e 100644
--- a/README.md
+++ b/README.md
@@ -190,7 +190,7 @@ python examples/tts_multilingual.py "Hola mundo" --lang es --play
 vllm-mlx bench-serve --url http://localhost:8000 --concurrency 5 --prompts prompts.txt --output results.csv
 
 # Product-style workload with quality checks and metrics deltas
-vllm-mlx bench-serve --url http://localhost:8000 --workload workload.json --output results.json
+vllm-mlx bench-serve --url http://localhost:8000 --workload workload.json --repetitions 5 --output results.json
 ```
 
 ### Prometheus metrics
diff --git a/docs/benchmarks/README.md b/docs/benchmarks/README.md
index 81e278195..fcc46fb94 100644
--- a/docs/benchmarks/README.md
+++ b/docs/benchmarks/README.md
@@ -35,6 +35,8 @@ vllm-mlx bench-serve --url http://localhost:8000 \
 running OpenAI-compatible server. This is intended for model and feature-stack
 qualification, where raw speed is not enough and every run needs provenance,
 quality checks, Prometheus metric deltas, and policy-timeout evidence.
+Use `--repetitions` to measure variance; workload summaries report per-case
+sample counts, failure rates, and min/median/max latency and throughput.
 
 Example workload:
 
@@ -95,6 +97,11 @@ Workload output defaults to JSON for full provenance. Use `--format csv` for
 flat per-case rows or `--format sql` to emit importable SQL for a local
 benchmark database.
 
+```bash
+vllm-mlx bench-serve --url http://localhost:8000 \
+  --workload ./workload.json --repetitions 5 --output workload-results.json
+```
+
 ## Standalone Test Defaults
 
 Standalone benchmark test scripts have built-in default models, so you can run:
diff --git a/docs/reference/cli.md b/docs/reference/cli.md
index 16ea1fc81..c14876bef 100644
--- a/docs/reference/cli.md
+++ b/docs/reference/cli.md
@@ -137,9 +137,9 @@ curl http://localhost:8000/v1/models \
 
 Benchmark a running vllm-mlx server over HTTP. Prompt-sweep mode measures
 TTFT, TPOT, throughput, cache deltas, and Metal memory. Workload mode adds
-per-case quality checks and comparison-only product policy timeouts. Workload
-cases can embed `messages` directly or point `request_path` at an existing
-OpenAI-compatible request JSON.
+per-case quality checks, repeated samples for variance, and comparison-only
+product policy timeouts. Workload cases can embed `messages` directly or point
+`request_path` at an existing OpenAI-compatible request JSON.
 
 ### Usage
 
@@ -157,6 +157,7 @@ vllm-mlx bench-serve --url http://localhost:8000 [options]
 | `--workload` | Declarative workload JSON for contract mode | None |
 | `--concurrency` | Comma-separated concurrency levels for sweep mode | `1,4` |
 | `--max-tokens` | Max tokens for sweep mode | `256` |
+| `--repetitions` | Repetitions per sweep configuration or workload case | `3` |
 | `--enable-thinking` | `true`, `false`, or `true,false` sweep | None |
 | `--scrape-metrics` | Scrape `/metrics` before/after runs | `true` |
 | `--include-content` | Include full generated content in workload JSON | False |
@@ -173,7 +174,7 @@ vllm-mlx bench-serve --url http://localhost:8000 \
 
 # Contract workload with quality checks and policy-timeout evidence
 vllm-mlx bench-serve --url http://localhost:8000 \
-  --workload workload.json --output workload-results.json
+  --workload workload.json --repetitions 5 --output workload-results.json
 ```
 
 ## `vllm-mlx-bench`
diff --git a/tests/test_bench_serve.py b/tests/test_bench_serve.py
index 4d1f3bf5b..e59fa7b86 100644
--- a/tests/test_bench_serve.py
+++ b/tests/test_bench_serve.py
@@ -32,6 +32,7 @@
     parse_metrics_text,
     parse_sse_line,
     parse_status_response,
+    run_bench_serve_workload,
     run_workload_case,
     summarize_workload_results,
     validate_quality_checks,
@@ -793,8 +794,10 @@ def test_summarize_workload_results_tracks_quality_and_policy(self):
         results = [
             {
                 "ok": True,
+                "case_id": "case-a",
+                "repetition": 0,
                 "policy": {"within_timeout": True},
-                "quality": {"ok": True},
+                "quality": {"ok": True, "content_chars": 120},
                 "metrics": {
                     "e2e_latency_ms": 100.0,
                     "ttft_ms": 10.0,
@@ -803,8 +806,10 @@ def test_summarize_workload_results_tracks_quality_and_policy(self):
             },
             {
                 "ok": True,
+                "case_id": "case-a",
+                "repetition": 1,
                 "policy": {"within_timeout": False},
-                "quality": {"ok": True},
+                "quality": {"ok": True, "content_chars": 160},
                 "metrics": {
                     "e2e_latency_ms": 200.0,
                     "ttft_ms": 20.0,
@@ -820,6 +825,11 @@ def test_summarize_workload_results_tracks_quality_and_policy(self):
         assert summary["policy_timeout_passed"] is False
         assert summary["failure_rate"] == pytest.approx(0.0)
         assert summary["latency_ms"]["p50"] == pytest.approx(150.0)
+        assert summary["unique_case_count"] == 1
+        assert summary["repetition_count"] == 2
+        assert summary["case_summaries"]["case-a"]["repetitions"] == [0, 1]
+        assert summary["case_summaries"]["case-a"]["latency_ms"]["min"] == 100.0
+        assert summary["case_summaries"]["case-a"]["latency_ms"]["max"] == 200.0
 
 
 class TestWorkloadRunner:
@@ -909,6 +919,7 @@ async def get(self, url):
                 hardware={"chip": "test"},
                 run_id="run123",
                 timestamp="2026-04-23T00:00:00+00:00",
+                repetition=2,
                 scrape=True,
                 include_content=True,
             )
@@ -918,6 +929,7 @@ async def get(self, url):
         assert record["quality"]["ok"] is True
         assert record["quality"]["issues"] == []
         assert record["quality"]["content"].startswith("Dear team")
+        assert record["repetition"] == 2
         assert record["request"]["request_path"] == "/tmp/request.json"
         assert record["policy"]["within_timeout"] is False
         assert record["metrics"]["cache_hits"] == 2
@@ -926,6 +938,91 @@ async def get(self, url):
         assert record["metrics"]["tokens_saved"] == 40
         assert record["metrics"]["metal"]["metal_active_gb"] == pytest.approx(42.0)
         assert record["metrics"]["metal"]["cache_type"] == "paged"
 
+    def test_run_bench_serve_workload_repeats_each_case(self, tmp_path, monkeypatch):
+        workload_file = tmp_path / "workload.json"
+        workload_file.write_text(
+            json.dumps(
+                {
+                    "name": "repeat-contract",
+                    "cases": [
+                        {
+                            "id": "case-a",
+                            "messages": [{"role": "user", "content": "A"}],
+                        },
+                        {
+                            "id": "case-b",
+                            "messages": [{"role": "user", "content": "B"}],
+                        },
+                    ],
+                }
+            )
+        )
+        observed = []
+
+        async def fake_auto_detect_runtime(client, url):
+            return {"model_id": "test-model"}
+
+        def fake_detect_hardware_fingerprint():
+            return {"chip": "test"}
+
+        async def fake_run_workload_case(*args, **kwargs):
+            observed.append((kwargs["case"].case_id, kwargs["repetition"]))
+            return {
+                "run_id": kwargs["run_id"],
+                "timestamp": kwargs["timestamp"],
+                "workload": kwargs["workload"].name,
+                "case_id": kwargs["case"].case_id,
+                "repetition": kwargs["repetition"],
+                "tags": [],
+                "model_id": kwargs["model"],
+                "runtime": kwargs["runtime"],
+                "hardware": kwargs["hardware"],
+                "request": {},
+                "policy": {"within_timeout": None},
+                "metrics": {
+                    "e2e_latency_ms": 100.0 + kwargs["repetition"],
+                    "ttft_ms": 10.0,
+                    "gen_tps": 20.0,
+                },
+                "quality": {"ok": True, "content_chars": 20},
+                "ok": True,
+            }
+
+        monkeypatch.setattr(
+            "vllm_mlx.bench_serve.auto_detect_runtime", fake_auto_detect_runtime
+        )
+        monkeypatch.setattr(
+            "vllm_mlx.bench_serve.detect_hardware_fingerprint",
+            fake_detect_hardware_fingerprint,
+        )
+        monkeypatch.setattr(
+            "vllm_mlx.bench_serve.run_workload_case", fake_run_workload_case
+        )
+
+        payload = asyncio.run(
+            run_bench_serve_workload(
+                url="http://server",
+                workload_path=str(workload_file),
+                output_path=str(tmp_path / "results.json"),
+                output_format="json",
+                scrape=False,
+                request_timeout_s=None,
+                repetitions=3,
+            )
+        )
+
+        assert observed == [
+            ("case-a", 0),
+            ("case-b", 0),
+            ("case-a", 1),
+            ("case-b", 1),
+            ("case-a", 2),
+            ("case-b", 2),
+        ]
+        assert len(payload["results"]) == 6
+        assert payload["workload"]["repetitions"] == 3
+        assert payload["summary"]["case_summaries"]["case-a"]["sample_count"] == 3
+
 
 # ---------------------------------------------------------------------------
 # TestSummaryStats (Task 5)
@@ -1018,6 +1115,7 @@ def _make_sample_workload_payload() -> dict:
         "timestamp": "2026-04-23T00:00:00+00:00",
         "workload": "writing-contract",
         "case_id": "resume-smoke",
+        "repetition": 0,
         "tags": ["resume", "quality"],
         "model_id": "test-model",
         "runtime": {
@@ -1133,12 +1231,14 @@ def test_format_workload_csv_parseable(self):
         rows = list(csv.DictReader(output.splitlines()))
         assert len(rows) == 1
         assert rows[0]["case_id"] == "resume-smoke"
+        assert rows[0]["repetition"] == "0"
         assert rows[0]["model_id"] == "test-model"
         assert rows[0]["request_extra_body"] == '{"temperature": 0.6}'
 
     def test_format_workload_sql_valid(self):
         output = format_workload_sql(_make_sample_workload_payload())
         assert "CREATE TABLE IF NOT EXISTS bench_serve_workload" in output
+        assert "case_id TEXT, repetition INTEGER, tags TEXT" in output
         assert "INSERT INTO bench_serve_workload" in output
         assert "resume-smoke" in output
diff --git a/vllm_mlx/bench_serve.py b/vllm_mlx/bench_serve.py
index 965123e33..e26e88a93 100644
--- a/vllm_mlx/bench_serve.py
+++ b/vllm_mlx/bench_serve.py
@@ -1040,6 +1040,7 @@ async def run_workload_case(
     hardware: dict,
     run_id: str,
     timestamp: str,
+    repetition: int = 0,
     scrape: bool = True,
     include_content: bool = False,
 ) -> dict:
@@ -1114,6 +1115,7 @@ async def run_workload_case(
         "started_at": started_wall,
         "workload": workload.name,
         "case_id": case.case_id,
+        "repetition": repetition,
         "tags": list(case.tags),
         "model_id": model,
         "runtime": runtime,
@@ -1172,8 +1174,62 @@ def summarize_workload_results(results: list[dict]) -> dict:
         r for r in policy_trials if r["policy"].get("within_timeout") is False
     ]
     failures = [r for r in results if not r["quality"]["ok"]]
+    cases: dict[str, list[dict]] = {}
+    for result in results:
+        cases.setdefault(str(result.get("case_id", "")), []).append(result)
+
+    case_summaries = {}
+    for case_id, case_results in sorted(cases.items()):
+        case_quality_failures = [r for r in case_results if not r["quality"].get("ok")]
+        case_policy_trials = [
+            r for r in case_results if r["policy"].get("within_timeout") is not None
+        ]
+        case_policy_failures = [
+            r for r in case_policy_trials if r["policy"].get("within_timeout") is False
+        ]
+        case_summaries[case_id] = {
+            "sample_count": len(case_results),
+            "repetitions": sorted(
+                {
+                    int(r.get("repetition", 0))
+                    for r in case_results
+                    if r.get("repetition") is not None
+                }
+            ),
+            "passed": not case_quality_failures,
+            "failure_count": len(case_quality_failures),
+            "failure_rate": (
+                round(len(case_quality_failures) / len(case_results), 4)
+                if case_results
+                else 0.0
+            ),
+            "policy_timeout_passed": (
+                not case_policy_failures if case_policy_trials else None
+            ),
+            "policy_timeout_failure_count": (
+                len(case_policy_failures) if case_policy_trials else None
+            ),
+            "latency_ms": _summary_or_empty(
+                [r["metrics"]["e2e_latency_ms"] for r in case_results]
+            ),
+            "ttft_ms": _summary_or_empty(
+                [r["metrics"]["ttft_ms"] for r in case_results]
+            ),
+            "gen_tps": _summary_or_empty(
+                [r["metrics"]["gen_tps"] for r in case_results]
+            ),
+            "content_chars": _summary_or_empty(
+                [r["quality"].get("content_chars", 0) for r in case_results]
+            ),
+        }
+
     return {
         "case_count": len(results),
+        "unique_case_count": len(cases),
+        "repetition_count": max(
+            (len(summary["repetitions"]) for summary in case_summaries.values()),
+            default=0,
+        ),
         "passed": not failures,
         "failure_count": len(failures),
         "failure_rate": round(len(failures) / len(results), 4) if results else 0.0,
@@ -1186,6 +1242,7 @@ def summarize_workload_results(results: list[dict]) -> dict:
         "latency_ms": _summary_or_empty(latencies),
         "ttft_ms": _summary_or_empty(ttft),
         "gen_tps": _summary_or_empty(gen_tps),
+        "case_summaries": case_summaries,
     }
 
 
@@ -1199,6 +1256,7 @@ async def run_bench_serve_workload(
     scrape: bool = True,
     include_content: bool = False,
     request_timeout_s: Optional[float] = 300.0,
+    repetitions: int = 1,
 ) -> dict:
     """Run a declarative workload against a running server.
 
@@ -1206,6 +1264,9 @@ async def run_bench_serve_workload(
     policy knobs in the manifest, records them as evidence, and measures what
     the server actually does before anyone promotes a model or feature stack.
     """
+    if repetitions < 1:
+        raise ValueError("repetitions must be at least 1")
+
     workload = load_workload(workload_path)
     run_id = str(uuid.uuid4())[:8]
     timestamp = datetime.now(timezone.utc).isoformat()
@@ -1219,21 +1280,23 @@ async def run_bench_serve_workload(
             raise ValueError("could not determine model ID; pass --model")
 
         records = []
-        for case in workload.cases:
-            record = await run_workload_case(
-                client,
-                url,
-                workload=workload,
-                case=case,
-                model=model_id,
-                runtime=runtime,
-                hardware=hardware,
-                run_id=run_id,
-                timestamp=timestamp,
-                scrape=scrape,
-                include_content=include_content,
-            )
-            records.append(record)
+        for repetition in range(repetitions):
+            for case in workload.cases:
+                record = await run_workload_case(
+                    client,
+                    url,
+                    workload=workload,
+                    case=case,
+                    model=model_id,
+                    runtime=runtime,
+                    hardware=hardware,
+                    run_id=run_id,
+                    timestamp=timestamp,
+                    repetition=repetition,
+                    scrape=scrape,
+                    include_content=include_content,
+                )
+                records.append(record)
 
     payload = {
         "run_id": run_id,
@@ -1243,6 +1306,7 @@ async def run_bench_serve_workload(
             "description": workload.description,
             "path": str(Path(workload_path).expanduser()),
             "defaults": workload.defaults,
+            "repetitions": repetitions,
         },
         "transport": {
             "request_timeout_s": request_timeout_s,
@@ -1412,6 +1476,7 @@ def format_sql(results: list[BenchServeResult]) -> str:
     "timestamp",
     "workload",
     "case_id",
+    "repetition",
     "tags",
     "model_id",
     "chip",
@@ -1450,6 +1515,7 @@ def format_sql(results: list[BenchServeResult]) -> str:
 
 _WORKLOAD_TABLE_COLUMNS = [
     "case_id",
+    "repetition",
     "tags",
     "quality_ok",
     "within_policy_timeout",
@@ -1475,6 +1541,7 @@ def _workload_record_to_row(record: dict) -> dict:
         "timestamp": record.get("timestamp", ""),
         "workload": record.get("workload", ""),
         "case_id": record.get("case_id", ""),
+        "repetition": record.get("repetition", 0),
         "tags": ",".join(record.get("tags") or []),
         "model_id": record.get("model_id", ""),
         "chip": hardware.get("chip", ""),
@@ -1541,8 +1608,8 @@ def format_workload_csv(payload: dict) -> str:
 
 _WORKLOAD_SQL_SCHEMA = (
-    "run_id TEXT, timestamp TEXT, workload TEXT, case_id TEXT, tags TEXT, "
-    "model_id TEXT, chip TEXT, memory_gb REAL, os_version TEXT, "
+    "run_id TEXT, timestamp TEXT, workload TEXT, case_id TEXT, repetition INTEGER, "
+    "tags TEXT, model_id TEXT, chip TEXT, memory_gb REAL, os_version TEXT, "
     "engine_type TEXT, model_type TEXT, mtp_enabled BOOLEAN, specprefill BOOLEAN, "
     "kv_quant TEXT, cache_type TEXT, request_max_tokens INTEGER, "
     "request_enable_thinking BOOLEAN, request_extra_body TEXT, "
diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py
index a2e7fd4b4..02b21186c 100644
--- a/vllm_mlx/cli.py
+++ b/vllm_mlx/cli.py
@@ -713,6 +713,7 @@ def bench_serve_command(args):
                 scrape=args.scrape_metrics == "true",
                 include_content=args.include_content,
                 request_timeout_s=request_timeout_s,
+                repetitions=args.repetitions,
             )
         )
         return
@@ -1422,7 +1423,7 @@ def create_parser() -> argparse.ArgumentParser:
         "--repetitions",
         type=int,
         default=3,
-        help="Number of repetitions per sweep configuration (default: 3)",
+        help="Number of repetitions per sweep configuration or workload case (default: 3)",
     )
     bench_serve_parser.add_argument(
         "--warmup",

From fa40b8f064ff33d64ce1e40c372517b703ca5674 Mon Sep 17 00:00:00 2001
From: Thump604
Date: Fri, 24 Apr 2026 03:23:36 -0500
Subject: [PATCH 5/8] Add bench serve SQLite output

---
 README.md                 |  3 ++
 docs/benchmarks/README.md |  6 +++-
 docs/reference/cli.md     |  6 +++-
 tests/test_bench_serve.py | 21 +++++++++++++
 vllm_mlx/bench_serve.py   | 62
++++++++++++++++++++++++++++++++++++++- vllm_mlx/cli.py | 4 +-- 6 files changed, 97 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index e52b1241e..bac893121 100644 --- a/README.md +++ b/README.md @@ -191,6 +191,9 @@ vllm-mlx bench-serve --url http://localhost:8000 --concurrency 5 --prompts promp # Product-style workload with quality checks and metrics deltas vllm-mlx bench-serve --url http://localhost:8000 --workload workload.json --repetitions 5 --output results.json + +# Append workload rows into SQLite for longitudinal comparisons +vllm-mlx bench-serve --url http://localhost:8000 --workload workload.json --repetitions 5 --format sqlite --output bench.db ``` ### Prometheus metrics diff --git a/docs/benchmarks/README.md b/docs/benchmarks/README.md index fcc46fb94..3e23c86ce 100644 --- a/docs/benchmarks/README.md +++ b/docs/benchmarks/README.md @@ -94,12 +94,16 @@ hardware capability claim. Use it to answer "would this run fit my product policy?" after first measuring what the model and serving stack can actually do. Workload output defaults to JSON for full provenance. Use `--format csv` for -flat per-case rows or `--format sql` to emit importable SQL for a local +flat per-case rows, `--format sql` to emit importable SQL, or +`--format sqlite --output bench.db` to append rows directly into a local benchmark database. ```bash vllm-mlx bench-serve --url http://localhost:8000 \ --workload ./workload.json --repetitions 5 --output workload-results.json + +vllm-mlx bench-serve --url http://localhost:8000 \ + --workload ./workload.json --repetitions 5 --format sqlite --output bench.db ``` ## Standalone Test Defaults diff --git a/docs/reference/cli.md b/docs/reference/cli.md index c14876bef..1b79e2357 100644 --- a/docs/reference/cli.md +++ b/docs/reference/cli.md @@ -163,7 +163,7 @@ vllm-mlx bench-serve --url http://localhost:8000 [options] | `--include-content` | Include full generated content in workload JSON | False | | `--request-timeout-s` | Workload HTTP transport timeout, `0` disables | `300` | | `--output` | Output file | stdout | -| `--format` | Output format: `table`, `json`, `csv`, `sql` | `table` for prompt sweeps, `json` for workloads | +| `--format` | Output format: `table`, `json`, `csv`, `sql`, `sqlite` | `table` for prompt sweeps, `json` for workloads | ### Examples @@ -175,6 +175,10 @@ vllm-mlx bench-serve --url http://localhost:8000 \ # Contract workload with quality checks and policy-timeout evidence vllm-mlx bench-serve --url http://localhost:8000 \ --workload workload.json --repetitions 5 --output workload-results.json + +# Append contract rows directly into SQLite for longitudinal comparisons +vllm-mlx bench-serve --url http://localhost:8000 \ + --workload workload.json --repetitions 5 --format sqlite --output bench.db ``` ## `vllm-mlx-bench` diff --git a/tests/test_bench_serve.py b/tests/test_bench_serve.py index e59fa7b86..2523ee3b8 100644 --- a/tests/test_bench_serve.py +++ b/tests/test_bench_serve.py @@ -5,6 +5,7 @@ import csv import json import os +import sqlite3 from pathlib import Path import pytest @@ -37,6 +38,8 @@ summarize_workload_results, validate_quality_checks, validate_response, + write_sqlite, + write_workload_sqlite, ) # --------------------------------------------------------------------------- @@ -1220,6 +1223,15 @@ def test_format_sql_handles_nan_inf(self): assert "inf" not in output.lower().split("'")[-1] assert "NULL" in output + def test_write_sqlite_creates_prompt_sweep_rows(self, tmp_path): + db_path = tmp_path / "bench.db" + 
write_sqlite([_make_sample_result(run_id="sqlite-run")], str(db_path)) + with sqlite3.connect(db_path) as conn: + row = conn.execute( + "SELECT run_id, model_id, repetition FROM bench_serve" + ).fetchone() + assert row == ("sqlite-run", "mlx-community/gemma-3-4b-it-4bit", 0) + def test_result_columns_match_dataclass(self): import dataclasses @@ -1242,6 +1254,15 @@ def test_format_workload_sql_valid(self): assert "INSERT INTO bench_serve_workload" in output assert "resume-smoke" in output + def test_write_workload_sqlite_creates_case_rows(self, tmp_path): + db_path = tmp_path / "workload.db" + write_workload_sqlite(_make_sample_workload_payload(), str(db_path)) + with sqlite3.connect(db_path) as conn: + row = conn.execute( + "SELECT case_id, repetition, quality_ok FROM bench_serve_workload" + ).fetchone() + assert row == ("resume-smoke", 0, 1) + def test_format_workload_payload_rejects_unknown_format(self): with pytest.raises(ValueError, match="Unsupported workload output format"): format_workload_payload(_make_sample_workload_payload(), "xml") diff --git a/vllm_mlx/bench_serve.py b/vllm_mlx/bench_serve.py index e26e88a93..f93b2c6e4 100644 --- a/vllm_mlx/bench_serve.py +++ b/vllm_mlx/bench_serve.py @@ -25,6 +25,7 @@ import math import platform import re +import sqlite3 import statistics import sys import time @@ -1316,6 +1317,13 @@ async def run_bench_serve_workload( "results": records, } + if output_format == "sqlite": + if not output_path: + raise ValueError("--output is required when --format sqlite") + write_workload_sqlite(payload, output_path) + print(f"Workload SQLite results written to {output_path}") + return payload + rendered = format_workload_payload(payload, output_format) if output_path: Path(output_path).expanduser().write_text(rendered) @@ -1471,6 +1479,40 @@ def format_sql(results: list[BenchServeResult]) -> str: return "\n".join(lines) +def _write_sqlite_rows( + output_path: str, + *, + table: str, + schema: str, + columns: list[str], + rows: list[dict], +) -> None: + """Append benchmark rows to a SQLite database.""" + db_path = Path(output_path).expanduser() + placeholders = ", ".join("?" for _ in columns) + column_list = ", ".join(columns) + values = [[row.get(col) for col in columns] for row in rows] + with sqlite3.connect(db_path) as conn: + conn.execute(f"CREATE TABLE IF NOT EXISTS {table} ({schema})") + if values: + conn.executemany( + f"INSERT INTO {table} ({column_list}) VALUES ({placeholders})", + values, + ) + conn.commit() + + +def write_sqlite(results: list[BenchServeResult], output_path: str) -> None: + rows = [_result_to_dict(r) for r in results] + _write_sqlite_rows( + output_path, + table="bench_serve", + schema=_SQL_SCHEMA, + columns=RESULT_COLUMNS, + rows=rows, + ) + + WORKLOAD_RESULT_COLUMNS = [ "run_id", "timestamp", @@ -1634,6 +1676,17 @@ def format_workload_sql(payload: dict) -> str: return "\n".join(lines) +def write_workload_sqlite(payload: dict, output_path: str) -> None: + rows = [_workload_record_to_row(record) for record in payload.get("results") or []] + _write_sqlite_rows( + output_path, + table="bench_serve_workload", + schema=_WORKLOAD_SQL_SCHEMA, + columns=WORKLOAD_RESULT_COLUMNS, + rows=rows, + ) + + def format_workload_payload(payload: dict, fmt: str = "json") -> str: if fmt == "json": return format_workload_json(payload) @@ -1692,7 +1745,7 @@ async def run_bench_serve( output_path: File path to write results to. If ``None``, prints to stdout. fmt: Output format — one of ``"table"``, ``"json"``, ``"csv"``, - ``"sql"``. 
+ ``"sql"``, or ``"sqlite"``. do_validate: Whether to validate each response. scrape: Whether to scrape ``/metrics`` before and after each run. tag: Optional tag string stored in every result row. @@ -2046,6 +2099,13 @@ def _mean(key: str) -> float: results.append(result_obj) # 12. Format output + if fmt == "sqlite": + if not output_path: + raise ValueError("--output is required when --format sqlite") + write_sqlite(results, output_path) + print(f"\nSQLite results written to {output_path}") + return results + formatters = { "table": format_table, "json": format_json, diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index 02b21186c..0bf413568 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -1453,8 +1453,8 @@ def create_parser() -> argparse.ArgumentParser: "--format", type=str, default=None, - choices=["table", "json", "csv", "sql"], - help="Output format (default: table for prompt sweeps, json for workloads)", + choices=["table", "json", "csv", "sql", "sqlite"], + help="Output format (default: table for prompt sweeps, json for workloads; sqlite requires --output)", ) bench_serve_parser.add_argument( "--validate", From 36c664ef9cad505cdef473793bbc551abb2d8154 Mon Sep 17 00:00:00 2001 From: Thump604 Date: Fri, 24 Apr 2026 08:05:22 -0500 Subject: [PATCH 6/8] Add workload cache policy controls --- tests/test_bench_serve.py | 101 ++++++++++++++++++++++++++++++++++++++ vllm_mlx/bench_serve.py | 62 +++++++++++++++++++++++ vllm_mlx/cli.py | 11 +++++ 3 files changed, 174 insertions(+) diff --git a/tests/test_bench_serve.py b/tests/test_bench_serve.py index 2523ee3b8..8428e1117 100644 --- a/tests/test_bench_serve.py +++ b/tests/test_bench_serve.py @@ -935,6 +935,7 @@ async def get(self, url): assert record["repetition"] == 2 assert record["request"]["request_path"] == "/tmp/request.json" assert record["policy"]["within_timeout"] is False + assert record["cache_reset"] == {"attempted": False} assert record["metrics"]["cache_hits"] == 2 assert record["metrics"]["cache_misses"] == 1 assert record["metrics"]["tokens_saved"] == 40 @@ -1026,6 +1027,106 @@ async def fake_run_workload_case(*args, **kwargs): assert payload["workload"]["repetitions"] == 3 assert payload["summary"]["case_summaries"]["case-a"]["sample_count"] == 3 + def test_run_bench_serve_workload_clears_cache_before_case( + self, tmp_path, monkeypatch + ): + workload_file = tmp_path / "workload.json" + workload_file.write_text( + json.dumps( + { + "name": "cache-contract", + "cases": [ + { + "id": "case-a", + "messages": [{"role": "user", "content": "A"}], + }, + { + "id": "case-b", + "messages": [{"role": "user", "content": "B"}], + }, + ], + } + ) + ) + clear_events = [] + observed_resets = [] + + async def fake_auto_detect_runtime(client, url): + return {"model_id": "test-model"} + + def fake_detect_hardware_fingerprint(): + return {"chip": "test"} + + async def fake_clear_runtime_cache(client, url): + event = { + "attempted": True, + "ok": True, + "status_code": 200, + "response": {"sequence": len(clear_events)}, + "error": "", + } + clear_events.append(event) + return event + + async def fake_run_workload_case(*args, **kwargs): + observed_resets.append(kwargs["cache_reset"]) + return { + "run_id": kwargs["run_id"], + "timestamp": kwargs["timestamp"], + "workload": kwargs["workload"].name, + "case_id": kwargs["case"].case_id, + "repetition": kwargs["repetition"], + "tags": [], + "model_id": kwargs["model"], + "runtime": kwargs["runtime"], + "hardware": kwargs["hardware"], + "request": {}, + "policy": {"within_timeout": None}, + 
"metrics": { + "e2e_latency_ms": 100.0, + "ttft_ms": 10.0, + "gen_tps": 20.0, + }, + "quality": {"ok": True, "content_chars": 20}, + "ok": True, + } + + monkeypatch.setattr( + "vllm_mlx.bench_serve.auto_detect_runtime", fake_auto_detect_runtime + ) + monkeypatch.setattr( + "vllm_mlx.bench_serve.detect_hardware_fingerprint", + fake_detect_hardware_fingerprint, + ) + monkeypatch.setattr( + "vllm_mlx.bench_serve.clear_runtime_cache", fake_clear_runtime_cache + ) + monkeypatch.setattr( + "vllm_mlx.bench_serve.run_workload_case", fake_run_workload_case + ) + + payload = asyncio.run( + run_bench_serve_workload( + url="http://server", + workload_path=str(workload_file), + output_format="json", + scrape=False, + request_timeout_s=None, + repetitions=2, + cache_policy="before-case", + ) + ) + + assert len(clear_events) == 4 + assert observed_resets == clear_events + assert payload["cache_policy"]["mode"] == "before-case" + assert [event["scope"] for event in payload["cache_policy"]["events"]] == [ + "before-case", + "before-case", + "before-case", + "before-case", + ] + # --------------------------------------------------------------------------- # TestSummaryStats (Task 5) diff --git a/vllm_mlx/bench_serve.py b/vllm_mlx/bench_serve.py index f93b2c6e4..e80275b45 100644 --- a/vllm_mlx/bench_serve.py +++ b/vllm_mlx/bench_serve.py @@ -581,6 +581,38 @@ async def scrape_metrics(client: httpx.AsyncClient, base_url: str) -> dict: return {} +async def clear_runtime_cache(client: httpx.AsyncClient, base_url: str) -> dict: + """Clear server-side runtime caches and return a JSON-serializable event.""" + event: dict[str, Any] = { + "attempted": True, + "ok": False, + "status_code": 0, + "response": {}, + "error": "", + } + try: + resp = await client.delete(f"{base_url}/v1/cache") + event["status_code"] = resp.status_code + try: + event["response"] = resp.json() + except ValueError: + event["response"] = {"text": resp.text} + resp.raise_for_status() + event["ok"] = True + except Exception as exc: + event["error"] = str(exc) + return event + + +def _normalize_cache_policy(value: Optional[str]) -> str: + policy = (value or "preserve").strip().lower().replace("_", "-") + if policy not in {"preserve", "before-run", "before-case"}: + raise ValueError( + "cache policy must be one of: preserve, before-run, before-case" + ) + return policy + + # --------------------------------------------------------------------------- # Task 4: SSE streaming core + token counting + request timing # --------------------------------------------------------------------------- @@ -1044,6 +1076,7 @@ async def run_workload_case( repetition: int = 0, scrape: bool = True, include_content: bool = False, + cache_reset: Optional[dict] = None, ) -> dict: """Run one workload case and return a JSON-serializable result.""" metrics_before = await scrape_metrics(client, base_url) if scrape else {} @@ -1135,6 +1168,7 @@ async def run_workload_case( "within_timeout": within_policy_timeout, "note": "comparison-only unless your product contract explicitly requires it", }, + "cache_reset": cache_reset or {"attempted": False}, "metrics": { "ttft_ms": result["ttft_ms"], "tpot_ms": result["tpot_ms"], @@ -1258,6 +1292,7 @@ async def run_bench_serve_workload( include_content: bool = False, request_timeout_s: Optional[float] = 300.0, repetitions: int = 1, + cache_policy: Optional[str] = None, ) -> dict: """Run a declarative workload against a running server. 
@@ -1269,6 +1304,9 @@ async def run_bench_serve_workload( raise ValueError("repetitions must be at least 1") workload = load_workload(workload_path) + resolved_cache_policy = _normalize_cache_policy( + cache_policy or workload.defaults.get("cache_policy") + ) run_id = str(uuid.uuid4())[:8] timestamp = datetime.now(timezone.utc).isoformat() timeout = httpx.Timeout(request_timeout_s) if request_timeout_s else None @@ -1281,8 +1319,27 @@ async def run_bench_serve_workload( raise ValueError("could not determine model ID; pass --model") records = [] + cache_events = [] + if resolved_cache_policy == "before-run": + cache_events.append( + { + "scope": "before-run", + "event": await clear_runtime_cache(client, url), + } + ) for repetition in range(repetitions): for case in workload.cases: + cache_reset = None + if resolved_cache_policy == "before-case": + cache_reset = await clear_runtime_cache(client, url) + cache_events.append( + { + "scope": "before-case", + "case_id": case.case_id, + "repetition": repetition, + "event": cache_reset, + } + ) record = await run_workload_case( client, url, @@ -1296,6 +1353,7 @@ async def run_bench_serve_workload( repetition=repetition, scrape=scrape, include_content=include_content, + cache_reset=cache_reset, ) records.append(record) @@ -1313,6 +1371,10 @@ async def run_bench_serve_workload( "request_timeout_s": request_timeout_s, "note": "transport safety only; product policy timeouts live in workload cases", }, + "cache_policy": { + "mode": resolved_cache_policy, + "events": cache_events, + }, "summary": summarize_workload_results(records), "results": records, } diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index 0bf413568..9580021f3 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -714,6 +714,7 @@ def bench_serve_command(args): include_content=args.include_content, request_timeout_s=request_timeout_s, repetitions=args.repetitions, + cache_policy=args.cache_policy, ) ) return @@ -1484,6 +1485,16 @@ def create_parser() -> argparse.ArgumentParser: "Use 0 to disable; product policy timeouts belong in the workload." ), ) + bench_serve_parser.add_argument( + "--cache-policy", + type=str, + default=None, + choices=["preserve", "before-run", "before-case"], + help=( + "Workload cache handling (default: workload defaults or preserve). " + "Use before-case for cold, uncontaminated per-case qualification." 
+ ), + ) bench_serve_parser.add_argument( "--tag", type=str, From 63167855b068be0efd4491a4a94b99b5b727caa9 Mon Sep 17 00:00:00 2001 From: Thump604 Date: Fri, 24 Apr 2026 08:32:05 -0500 Subject: [PATCH 7/8] Respect workload timeout defaults --- tests/test_bench_serve.py | 23 +++++++++++++++++++++++ vllm_mlx/bench_serve.py | 29 +++++++++++++++++++---------- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/tests/test_bench_serve.py b/tests/test_bench_serve.py index 8428e1117..427505a1e 100644 --- a/tests/test_bench_serve.py +++ b/tests/test_bench_serve.py @@ -219,6 +219,29 @@ def test_load_workload_with_defaults(self, tmp_path: Path): assert case.checks == {"forbidden_regex": [""]} assert case.tags == ("quality",) + def test_load_workload_null_policy_timeout_falls_back_to_default( + self, tmp_path: Path + ): + workload_file = tmp_path / "workload.json" + workload_file.write_text( + json.dumps( + { + "defaults": {"policy_timeout_ms": 180000}, + "cases": [ + { + "id": "case-a", + "messages": [{"role": "user", "content": "A"}], + "policy_timeout_ms": None, + } + ], + } + ) + ) + + workload = load_workload(workload_file) + + assert workload.cases[0].policy_timeout_ms == 180000 + def test_load_workload_rejects_missing_messages(self, tmp_path: Path): workload_file = tmp_path / "workload.json" workload_file.write_text(json.dumps({"cases": [{"id": "bad"}]})) diff --git a/vllm_mlx/bench_serve.py b/vllm_mlx/bench_serve.py index e80275b45..d7041d5ba 100644 --- a/vllm_mlx/bench_serve.py +++ b/vllm_mlx/bench_serve.py @@ -170,6 +170,13 @@ def _request_extra_body(request: dict) -> dict: return {key: value for key, value in request.items() if key not in reserved} +def _first_not_none(*values: Any) -> Any: + for value in values: + if value is not None: + return value + return None + + def load_workload(path: str | Path) -> Workload: """Load a declarative serving benchmark workload. 
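The failure mode that motivates `_first_not_none`, sketched outside the diff: `dict.get` applies its default only when a key is absent, so an explicit JSON `null` (which `json.loads` turns into `None`) used to shadow the workload default.

```python
import json

defaults = {"policy_timeout_ms": 180000}
case = json.loads('{"policy_timeout_ms": null}')  # key present, value None

# Chained .get() keeps the stored None; the default never applies.
assert case.get("policy_timeout_ms", defaults["policy_timeout_ms"]) is None

def _first_not_none(*values):
    # Same shape as the helper added above.
    for value in values:
        if value is not None:
            return value
    return None

# The helper skips explicit nulls and lands on the workload default.
assert _first_not_none(
    case.get("policy_timeout_ms"), defaults["policy_timeout_ms"]
) == 180000
```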
@@ -228,19 +235,21 @@ def load_workload(path: str | Path) -> Workload: case_id=case_id, messages=messages, request_path=str(request_path) if request_path is not None else None, - max_tokens=item.get( - "max_tokens", - request_defaults.get("max_tokens", defaults.get("max_tokens")), + max_tokens=_first_not_none( + item.get("max_tokens"), + request_defaults.get("max_tokens"), + defaults.get("max_tokens"), ), - enable_thinking=item.get( - "enable_thinking", - request_defaults.get( - "enable_thinking", defaults.get("enable_thinking") - ), + enable_thinking=_first_not_none( + item.get("enable_thinking"), + request_defaults.get("enable_thinking"), + defaults.get("enable_thinking"), ), extra_body=extra_body, - policy_timeout_ms=item.get( - "policy_timeout_ms", defaults.get("policy_timeout_ms") + policy_timeout_ms=_first_not_none( + item.get("policy_timeout_ms"), + request_defaults.get("policy_timeout_ms"), + defaults.get("policy_timeout_ms"), ), checks=checks, tags=tuple(str(tag) for tag in tags), From ab33faa4ab61e3993e05e5464f1260522989314d Mon Sep 17 00:00:00 2001 From: Thump604 Date: Fri, 24 Apr 2026 09:09:50 -0500 Subject: [PATCH 8/8] Harden workload benchmark review edges --- docs/benchmarks/README.md | 8 ++++++++ docs/reference/cli.md | 10 +++++++++- tests/test_bench_serve.py | 8 ++++++++ vllm_mlx/bench_serve.py | 22 ++++++++++++++++++---- vllm_mlx/cli.py | 18 ++++++++++++------ 5 files changed, 55 insertions(+), 11 deletions(-) diff --git a/docs/benchmarks/README.md b/docs/benchmarks/README.md index 3e23c86ce..4a2a00ac1 100644 --- a/docs/benchmarks/README.md +++ b/docs/benchmarks/README.md @@ -37,6 +37,10 @@ qualification, where raw speed is not enough and every run needs provenance, quality checks, Prometheus metric deltas, and policy-timeout evidence. Use `--repetitions` to measure variance; workload summaries report per-case sample counts, failure rates, and min/median/max latency and throughput. +`required_regex` and `forbidden_regex` entries are Python regular expressions; +plain literal strings are valid regex patterns. Workload `cache_policy` accepts +`preserve`, `before-run`, and `before-case`; JSON/YAML-style underscore +spellings such as `before_case` are normalized to the same values. Example workload: @@ -98,6 +102,10 @@ flat per-case rows, `--format sql` to emit importable SQL, or `--format sqlite --output bench.db` to append rows directly into a local benchmark database. +`--request-timeout-s` is the HTTP transport ceiling for each request in +workload mode. Product policy timeouts belong in the workload as +`policy_timeout_ms` and are recorded as comparison evidence. 
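+For example (an illustrative invocation, not a new default), raising the
+transport ceiling does not change the policy check:
+
+```bash
+# Requests may stream for up to ten minutes; a case slower than its
+# policy_timeout_ms still completes and is recorded with
+# "within_timeout": false rather than being aborted.
+vllm-mlx bench-serve --url http://localhost:8000 \
+  --workload ./workload.json --request-timeout-s 600
+```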
+ ```bash vllm-mlx bench-serve --url http://localhost:8000 \ --workload ./workload.json --repetitions 5 --output workload-results.json diff --git a/docs/reference/cli.md b/docs/reference/cli.md index 1b79e2357..0336de2fb 100644 --- a/docs/reference/cli.md +++ b/docs/reference/cli.md @@ -162,8 +162,16 @@ vllm-mlx bench-serve --url http://localhost:8000 [options] | `--scrape-metrics` | Scrape `/metrics` before/after runs | `true` | | `--include-content` | Include full generated content in workload JSON | False | | `--request-timeout-s` | Workload HTTP transport timeout, `0` disables | `300` | +| `--cache-policy` | Workload cache handling: `preserve`, `before-run`, `before-case` | Workload default or `preserve` | | `--output` | Output file | stdout | -| `--format` | Output format: `table`, `json`, `csv`, `sql`, `sqlite` | `table` for prompt sweeps, `json` for workloads | +| `--format` | Output format: `auto`, `table`, `json`, `csv`, `sql`, `sqlite` | `auto` = `table` for prompt sweeps, `json` for workloads | + +In workload mode, `--request-timeout-s` is the HTTP transport ceiling for each +request. Product policy timeouts should live in the workload as +`policy_timeout_ms`. Workload `required_regex` and `forbidden_regex` values are +Python regex patterns, so literal strings are valid. Workload JSON may spell +cache policy values with underscores, such as `before_case`; they normalize to +the hyphenated CLI values. ### Examples diff --git a/tests/test_bench_serve.py b/tests/test_bench_serve.py index 427505a1e..04af27438 100644 --- a/tests/test_bench_serve.py +++ b/tests/test_bench_serve.py @@ -16,6 +16,7 @@ SweepConfig, Workload, WorkloadCase, + _validate_sql_identifier, compute_request_metrics, compute_summary_stats, detect_hardware_fingerprint, @@ -1387,6 +1388,13 @@ def test_write_workload_sqlite_creates_case_rows(self, tmp_path): ).fetchone() assert row == ("resume-smoke", 0, 1) + def test_sqlite_identifier_validation_rejects_unsafe_names(self): + with pytest.raises(ValueError, match="invalid SQLite table identifier"): + _validate_sql_identifier("bench; DROP TABLE bench_serve", kind="table") + + with pytest.raises(ValueError, match="invalid SQLite column identifier"): + _validate_sql_identifier("case-id", kind="column") + def test_format_workload_payload_rejects_unknown_format(self): with pytest.raises(ValueError, match="Unsupported workload output format"): format_workload_payload(_make_sample_workload_payload(), "xml") diff --git a/vllm_mlx/bench_serve.py b/vllm_mlx/bench_serve.py index d7041d5ba..a0c48ea81 100644 --- a/vllm_mlx/bench_serve.py +++ b/vllm_mlx/bench_serve.py @@ -44,6 +44,7 @@ _BUILTIN_DIR = Path(__file__).parent / "bench_serve_prompts" _BUILTIN_NAMES = {"short", "medium", "long", "thinking"} +_SQL_IDENTIFIER_RE = re.compile(r"^[a-z_][a-z0-9_]*$") @dataclass @@ -614,6 +615,11 @@ async def clear_runtime_cache(client: httpx.AsyncClient, base_url: str) -> dict: def _normalize_cache_policy(value: Optional[str]) -> str: + """Normalize cache-policy spelling from CLI or workload JSON. + + CLI choices are hyphenated, but workload JSON may use underscores when it + follows common Python/YAML identifier style. 
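+
+ Example: ``before_run`` normalizes to ``before-run``, and ``Before-Case``
+ to ``before-case``; spellings outside the three accepted values raise
+ ``ValueError``.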
+ """ policy = (value or "preserve").strip().lower().replace("_", "-") if policy not in {"preserve", "before-run", "before-case"}: raise ValueError( @@ -1210,14 +1216,13 @@ def summarize_workload_results(results: list[dict]) -> dict: latencies = [r["metrics"]["e2e_latency_ms"] for r in results] ttft = [r["metrics"]["ttft_ms"] for r in results] gen_tps = [r["metrics"]["gen_tps"] for r in results] - quality_failures = [r for r in results if not r["quality"]["ok"]] + failures = [r for r in results if not r["quality"]["ok"]] policy_trials = [ r for r in results if r["policy"].get("within_timeout") is not None ] policy_failures = [ r for r in policy_trials if r["policy"].get("within_timeout") is False ] - failures = [r for r in results if not r["quality"]["ok"]] cases: dict[str, list[dict]] = {} for result in results: cases.setdefault(str(result.get("case_id", "")), []).append(result) @@ -1277,8 +1282,8 @@ def summarize_workload_results(results: list[dict]) -> dict: "passed": not failures, "failure_count": len(failures), "failure_rate": round(len(failures) / len(results), 4) if results else 0.0, - "quality_passed": not quality_failures, - "quality_failure_count": len(quality_failures), + "quality_passed": not failures, + "quality_failure_count": len(failures), "policy_timeout_passed": not policy_failures if policy_trials else None, "policy_timeout_failure_count": ( len(policy_failures) if policy_trials else None @@ -1560,6 +1565,9 @@ def _write_sqlite_rows( ) -> None: """Append benchmark rows to a SQLite database.""" db_path = Path(output_path).expanduser() + _validate_sql_identifier(table, kind="table") + for column in columns: + _validate_sql_identifier(column, kind="column") placeholders = ", ".join("?" for _ in columns) column_list = ", ".join(columns) values = [[row.get(col) for col in columns] for row in rows] @@ -1573,6 +1581,12 @@ def _write_sqlite_rows( conn.commit() +def _validate_sql_identifier(identifier: str, *, kind: str) -> None: + """Reject unsafe SQL identifiers before string interpolation.""" + if not _SQL_IDENTIFIER_RE.fullmatch(identifier): + raise ValueError(f"invalid SQLite {kind} identifier: {identifier!r}") + + def write_sqlite(results: list[BenchServeResult], output_path: str) -> None: rows = [_result_to_dict(r) for r in results] _write_sqlite_rows( diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index 9580021f3..af1a94e9a 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -703,13 +703,14 @@ def bench_serve_command(args): request_timeout_s = ( None if args.request_timeout_s <= 0 else args.request_timeout_s ) + output_format = "json" if args.format == "auto" else args.format asyncio.run( run_bench_serve_workload( url=args.url, workload_path=args.workload, model=args.model, output_path=args.output, - output_format=args.format or "json", + output_format=output_format, scrape=args.scrape_metrics == "true", include_content=args.include_content, request_timeout_s=request_timeout_s, @@ -751,6 +752,7 @@ def bench_serve_command(args): k, v = kv.split("=", 1) overrides[k] = v + output_format = "table" if args.format == "auto" else args.format asyncio.run( run_bench_serve( url=args.url, @@ -764,7 +766,7 @@ def bench_serve_command(args): thinking_values=thinking_values, extra_bodies=extra_bodies, output_path=args.output, - fmt=args.format or "table", + fmt=output_format, do_validate=args.validate == "true", scrape=args.scrape_metrics == "true", tag=args.tag, @@ -1453,9 +1455,12 @@ def create_parser() -> argparse.ArgumentParser: bench_serve_parser.add_argument( "--format", 
type=str, - default=None, - choices=["table", "json", "csv", "sql", "sqlite"], - help="Output format (default: table for prompt sweeps, json for workloads; sqlite requires --output)", + default="auto", + choices=["auto", "table", "json", "csv", "sql", "sqlite"], + help=( + "Output format (auto = table for prompt sweeps, json for workloads; " + "sqlite requires --output)" + ), ) bench_serve_parser.add_argument( "--validate", @@ -1492,7 +1497,8 @@ def create_parser() -> argparse.ArgumentParser: choices=["preserve", "before-run", "before-case"], help=( "Workload cache handling (default: workload defaults or preserve). " - "Use before-case for cold, uncontaminated per-case qualification." + "Use before-case for cold, uncontaminated per-case qualification. " + "Workload JSON may also spell these with underscores." ), ) bench_serve_parser.add_argument(
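A closing sketch of the workload flow these flags enable, assuming a server already listening on `http://localhost:8000` and the example manifest at `examples/bench_serve_workload.json`; the combination shown is illustrative, not a recommended default:

```bash
# --format auto resolves to JSON in workload mode and prints to stdout.
vllm-mlx bench-serve --url http://localhost:8000 \
  --workload examples/bench_serve_workload.json --repetitions 3

# Cold per-case qualification appended to SQLite; the sqlite format
# requires an explicit --output path.
vllm-mlx bench-serve --url http://localhost:8000 \
  --workload examples/bench_serve_workload.json \
  --cache-policy before-case --repetitions 3 \
  --format sqlite --output bench.db
```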