diff --git a/README.md b/README.md
index 1865b8e34..bac893121 100644
--- a/README.md
+++ b/README.md
@@ -188,6 +188,12 @@ python examples/tts_multilingual.py "Hola mundo" --lang es --play
 ```bash
 vllm-mlx bench-serve --url http://localhost:8000 --concurrency 5 --prompts prompts.txt --output results.csv
+
+# Product-style workload with quality checks and metrics deltas
+vllm-mlx bench-serve --url http://localhost:8000 --workload workload.json --repetitions 5 --output results.json
+
+# Append workload rows into SQLite for longitudinal comparisons
+vllm-mlx bench-serve --url http://localhost:8000 --workload workload.json --repetitions 5 --format sqlite --output bench.db
 ```
 
 ### Prometheus metrics
diff --git a/docs/benchmarks/README.md b/docs/benchmarks/README.md
index ab3220d8b..4a2a00ac1 100644
--- a/docs/benchmarks/README.md
+++ b/docs/benchmarks/README.md
@@ -19,6 +19,99 @@ vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit
 
 # Video benchmark
 vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --video
+
+# Running-server prompt sweep with Prometheus metric deltas
+vllm-mlx bench-serve --url http://localhost:8000 --prompts short,long \
+  --concurrency 1,4 --output bench.json --format json
+
+# Running-server product-style workload with quality checks
+vllm-mlx bench-serve --url http://localhost:8000 \
+  --workload ./workload.json --output workload-results.json
+```
+
+## Contract Workloads
+
+`vllm-mlx bench-serve --workload` runs declarative cases against an already
+running OpenAI-compatible server. This is intended for model and feature-stack
+qualification, where raw speed is not enough and every run needs provenance,
+quality checks, Prometheus metric deltas, and policy-timeout evidence.
+Use `--repetitions` to measure variance; workload summaries report per-case
+sample counts, failure rates, and min/median/max latency and throughput.
+`required_regex` and `forbidden_regex` entries are Python regular expressions;
+plain literal strings are also valid patterns. Workload `cache_policy` accepts
+`preserve`, `before-run`, and `before-case`; JSON/YAML-style underscore
+spellings such as `before_case` are normalized to the same values.
+
+Example workload:
+
+```json
+{
+  "name": "writing-contract",
+  "description": "Representative long-form writing requests",
+  "defaults": {
+    "max_tokens": 32768,
+    "enable_thinking": true,
+    "policy_timeout_ms": 180000,
+    "checks": {
+      "finish_reason": "stop",
+      "forbidden_regex": ["<think>", "prompt leakage"],
+      "min_chars": 500
+    }
+  },
+  "cases": [
+    {
+      "id": "resume-golden-1",
+      "messages": [
+        {"role": "user", "content": "Write the requested artifact..."}
+      ],
+      "tags": ["resume", "quality-floor"]
+    }
+  ]
+}
+```
+
+Cases can also reference an existing OpenAI-compatible request JSON instead of
+duplicating a large prompt body:
+
+```json
+{
+  "name": "writing-contract",
+  "cases": [
+    {
+      "id": "resume-golden-1",
+      "request_path": "./fixtures/job543_resume_precise_request.json",
+      "checks": {
+        "finish_reason": "stop",
+        "forbidden_regex": ["<think>"]
+      }
+    }
+  ]
+}
+```
+
+When `request_path` is used, `messages`, `max_tokens`, `enable_thinking`, and
+extra request-body fields such as `thinking_token_budget` are read from that
+file. Case-level `extra_body` values override request-file values.
+
+`policy_timeout_ms` is recorded as comparison evidence; it is not a hardware
+capability claim. Use it to answer "would this run fit my product policy?"
+after first measuring what the model and serving stack can actually do.
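+
+For intuition, the verdict is a pure comparison against the measured latency,
+after the effective timeout is resolved with case-level values taking
+precedence over `request_path` contents and workload `defaults`. A minimal
+sketch (the latency value below is hypothetical; `first_not_none` mirrors the
+loader's internal precedence helper):
+
+```python
+def first_not_none(*values):
+    """Return the first non-None value: case, then request file, then defaults."""
+    for value in values:
+        if value is not None:
+            return value
+    return None
+
+timeout_ms = first_not_none(None, None, 180000)  # case, request file, defaults
+measured_e2e_ms = 152000.0                       # hypothetical measured latency
+within = None if timeout_ms is None else measured_e2e_ms <= timeout_ms
+# Recorded as {"timeout_ms": 180000, "within_timeout": True}: evidence, not a gate.
+```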
+ +Workload output defaults to JSON for full provenance. Use `--format csv` for +flat per-case rows, `--format sql` to emit importable SQL, or +`--format sqlite --output bench.db` to append rows directly into a local +benchmark database. + +`--request-timeout-s` is the HTTP transport ceiling for each request in +workload mode. Product policy timeouts belong in the workload as +`policy_timeout_ms` and are recorded as comparison evidence. + +```bash +vllm-mlx bench-serve --url http://localhost:8000 \ + --workload ./workload.json --repetitions 5 --output workload-results.json + +vllm-mlx bench-serve --url http://localhost:8000 \ + --workload ./workload.json --repetitions 5 --format sqlite --output bench.db ``` ## Standalone Test Defaults diff --git a/docs/reference/cli.md b/docs/reference/cli.md index 90a752aef..0336de2fb 100644 --- a/docs/reference/cli.md +++ b/docs/reference/cli.md @@ -5,6 +5,7 @@ | Command | Description | |---------|-------------| | `vllm-mlx serve` | Start OpenAI-compatible server | +| `vllm-mlx bench-serve` | Benchmark a running server with prompt sweeps or workload contracts | | `vllm-mlx-bench` | Run performance benchmarks | | `vllm-mlx-chat` | Start Gradio chat interface | @@ -132,6 +133,62 @@ curl http://localhost:8000/v1/models \ -H "Authorization: Bearer your-secret-key" ``` +## `vllm-mlx bench-serve` + +Benchmark a running vllm-mlx server over HTTP. Prompt-sweep mode measures +TTFT, TPOT, throughput, cache deltas, and Metal memory. Workload mode adds +per-case quality checks, repeated samples for variance, and comparison-only +product policy timeouts. Workload cases can embed `messages` directly or point +`request_path` at an existing OpenAI-compatible request JSON. + +### Usage + +```bash +vllm-mlx bench-serve --url http://localhost:8000 [options] +``` + +### Options + +| Option | Description | Default | +|--------|-------------|---------| +| `--url` | Running server base URL | `http://127.0.0.1:8080` | +| `--model` | API model id | Auto-detect | +| `--prompts` | Comma-separated prompt sets or files for sweep mode | `short,medium,long` | +| `--workload` | Declarative workload JSON for contract mode | None | +| `--concurrency` | Comma-separated concurrency levels for sweep mode | `1,4` | +| `--max-tokens` | Max tokens for sweep mode | `256` | +| `--repetitions` | Repetitions per sweep configuration or workload case | `3` | +| `--enable-thinking` | `true`, `false`, or `true,false` sweep | None | +| `--scrape-metrics` | Scrape `/metrics` before/after runs | `true` | +| `--include-content` | Include full generated content in workload JSON | False | +| `--request-timeout-s` | Workload HTTP transport timeout, `0` disables | `300` | +| `--cache-policy` | Workload cache handling: `preserve`, `before-run`, `before-case` | Workload default or `preserve` | +| `--output` | Output file | stdout | +| `--format` | Output format: `auto`, `table`, `json`, `csv`, `sql`, `sqlite` | `auto` = `table` for prompt sweeps, `json` for workloads | + +In workload mode, `--request-timeout-s` is the HTTP transport ceiling for each +request. Product policy timeouts should live in the workload as +`policy_timeout_ms`. Workload `required_regex` and `forbidden_regex` values are +Python regex patterns, so literal strings are valid. Workload JSON may spell +cache policy values with underscores, such as `before_case`; they normalize to +the hyphenated CLI values. 
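+
+The normalization is a plain string rewrite, so either spelling selects the
+same behavior. A minimal sketch mirroring `_normalize_cache_policy` in
+`vllm_mlx/bench_serve.py`:
+
+```python
+def normalize_cache_policy(value):
+    # None falls back to "preserve"; underscores become hyphens.
+    policy = (value or "preserve").strip().lower().replace("_", "-")
+    if policy not in {"preserve", "before-run", "before-case"}:
+        raise ValueError("cache policy must be one of: preserve, before-run, before-case")
+    return policy
+
+assert normalize_cache_policy("before_case") == "before-case"
+assert normalize_cache_policy(None) == "preserve"
+```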
+
+### Examples
+
+```bash
+# Prompt sweep
+vllm-mlx bench-serve --url http://localhost:8000 \
+  --prompts short,long --concurrency 1,4 --format json --output bench.json
+
+# Contract workload with quality checks and policy-timeout evidence
+vllm-mlx bench-serve --url http://localhost:8000 \
+  --workload workload.json --repetitions 5 --output workload-results.json
+
+# Append contract rows directly into SQLite for longitudinal comparisons
+vllm-mlx bench-serve --url http://localhost:8000 \
+  --workload workload.json --repetitions 5 --format sqlite --output bench.db
+```
+
 ## `vllm-mlx-bench`
 
 Run performance benchmarks.
diff --git a/examples/bench_serve_workload.json b/examples/bench_serve_workload.json
new file mode 100644
index 000000000..3f093cc08
--- /dev/null
+++ b/examples/bench_serve_workload.json
@@ -0,0 +1,44 @@
+{
+  "name": "quality-contract-smoke",
+  "description": "Small contract-style workload demonstrating bench-serve quality checks.",
+  "defaults": {
+    "max_tokens": 256,
+    "enable_thinking": false,
+    "policy_timeout_ms": 30000,
+    "checks": {
+      "finish_reason": "stop",
+      "forbidden_regex": [
+        "<think>",
+        "I cannot"
+      ],
+      "min_chars": 40
+    }
+  },
+  "cases": [
+    {
+      "id": "python-palindrome",
+      "tags": [
+        "code",
+        "quality"
+      ],
+      "messages": [
+        {
+          "role": "user",
+          "content": "Write a Python function with type hints that checks whether a string is a palindrome. Include one short example."
+        }
+      ],
+      "checks": {
+        "finish_reason": "stop",
+        "required_regex": [
+          "def ",
+          "-> bool"
+        ],
+        "forbidden_regex": [
+          "<think>",
+          "TODO"
+        ],
+        "min_chars": 80
+      }
+    }
+  ]
+}
diff --git a/tests/test_bench_serve.py b/tests/test_bench_serve.py
index e5b835220..04af27438 100644
--- a/tests/test_bench_serve.py
+++ b/tests/test_bench_serve.py
@@ -5,6 +5,7 @@
 import csv
 import json
 import os
+import sqlite3
 from pathlib import Path
 
 import pytest
@@ -13,6 +14,9 @@
     RESULT_COLUMNS,
     BenchServeResult,
     SweepConfig,
+    Workload,
+    WorkloadCase,
+    _validate_sql_identifier,
     compute_request_metrics,
     compute_summary_stats,
     detect_hardware_fingerprint,
@@ -21,12 +25,22 @@
     format_json,
     format_sql,
     format_table,
+    format_workload_csv,
+    format_workload_payload,
+    format_workload_sql,
     load_prompt_set,
+    load_workload,
     parse_health_response,
     parse_metrics_text,
     parse_sse_line,
     parse_status_response,
+    run_bench_serve_workload,
+    run_workload_case,
+    summarize_workload_results,
+    validate_quality_checks,
     validate_response,
+    write_sqlite,
+    write_workload_sqlite,
 )
 
 # ---------------------------------------------------------------------------
@@ -166,6 +180,120 @@ def test_thinking_prompts_contain_reasoning_keywords(self):
         assert any(kw in combined for kw in keywords)
 
 
+class TestWorkloadLoading:
+    """Tests for declarative contract workload loading."""
+
+    def test_load_workload_with_defaults(self, tmp_path: Path):
+        workload_file = tmp_path / "workload.json"
+        workload_file.write_text(
+            json.dumps(
+                {
+                    "name": "writing-contract",
+                    "defaults": {
+                        "max_tokens": 128,
+                        "enable_thinking": True,
+                        "policy_timeout_ms": 180000,
+                        "checks": {"forbidden_regex": ["<think>"]},
+                    },
+                    "cases": [
+                        {
+                            "id": "case-a",
+                            "messages": [
+                                {"role": "user", "content": "Write a short note."}
+                            ],
+                            "tags": ["quality"],
+                        }
+                    ],
+                }
+            )
+        )
+
+        workload = load_workload(workload_file)
+
+        assert workload.name == "writing-contract"
+        assert len(workload.cases) == 1
+        case = workload.cases[0]
+        assert case.case_id == "case-a"
+        assert case.max_tokens == 128
+        assert case.enable_thinking is True
+        assert case.policy_timeout_ms == 180000
+        assert case.checks == {"forbidden_regex": ["<think>"]}
+        assert case.tags == ("quality",)
+
+    def test_load_workload_null_policy_timeout_falls_back_to_default(
+        self, tmp_path: Path
+    ):
+        workload_file = tmp_path / "workload.json"
+        workload_file.write_text(
+            json.dumps(
+                {
+                    "defaults": {"policy_timeout_ms": 180000},
+                    "cases": [
+                        {
+                            "id": "case-a",
+                            "messages": [{"role": "user", "content": "A"}],
+                            "policy_timeout_ms": None,
+                        }
+                    ],
+                }
+            )
+        )
+
+        workload = load_workload(workload_file)
+
+        assert workload.cases[0].policy_timeout_ms == 180000
+
+    def test_load_workload_rejects_missing_messages(self, tmp_path: Path):
+        workload_file = tmp_path / "workload.json"
+        workload_file.write_text(json.dumps({"cases": [{"id": "bad"}]}))
+
+        with pytest.raises(ValueError, match="messages"):
+            load_workload(workload_file)
+
+    def test_load_workload_case_from_request_path(self, tmp_path: Path):
+        request_file = tmp_path / "request.json"
+        request_file.write_text(
+            json.dumps(
+                {
+                    "model": "ignored-by-workload-runner",
+                    "messages": [{"role": "user", "content": "Write the artifact."}],
+                    "stream": True,
+                    "max_tokens": 32768,
+                    "enable_thinking": True,
+                    "thinking_token_budget": 8192,
+                    "temperature": 0.6,
+                }
+            )
+        )
+        workload_file = tmp_path / "workload.json"
+        workload_file.write_text(
+            json.dumps(
+                {
+                    "cases": [
+                        {
+                            "id": "resume",
+                            "request_path": "request.json",
+                            "extra_body": {"top_p": 0.95},
+                        }
+                    ]
+                }
+            )
+        )
+
+        workload = load_workload(workload_file)
+        case = workload.cases[0]
+
+        assert case.messages == [{"role": "user", "content": "Write the artifact."}]
+        assert case.request_path == "request.json"
+        assert case.max_tokens == 32768
+        assert case.enable_thinking is True
+        assert case.extra_body == {
+            "thinking_token_budget": 8192,
+            "temperature": 0.6,
+            "top_p": 0.95,
+        }
+
+
 # ---------------------------------------------------------------------------
 # TestExpandSweep
 # ---------------------------------------------------------------------------
@@ -411,9 +539,9 @@ def test_parse_status_response(self):
         data = {
             "model": "mlx-community/Llama-3.2-1B-Instruct-4bit",
             "metal": {
-                "active_gb": 12.5,
-                "peak_gb": 14.0,
-                "cache_gb": 2.0,
+                "active_memory_gb": 12.5,
+                "peak_memory_gb": 14.0,
+                "cache_memory_gb": 2.0,
             },
             "cache": {"type": "paged"},
         }
@@ -424,6 +552,20 @@
         assert result["metal_cache_gb"] == pytest.approx(2.0)
         assert result["cache_type"] == "paged"
 
+    def test_parse_status_response_accepts_legacy_metal_keys(self):
+        data = {
+            "model": "legacy-server",
+            "metal": {
+                "active_gb": 12.5,
+                "peak_gb": 14.0,
+                "cache_gb": 2.0,
+            },
+        }
+        result = parse_status_response(data)
+        assert result["metal_active_gb"] == pytest.approx(12.5)
+        assert result["metal_peak_gb"] == pytest.approx(14.0)
+        assert result["metal_cache_gb"] == pytest.approx(2.0)
+
     def test_parse_status_no_metal(self):
         data = {"model": "some-model"}
         result = parse_status_response(data)
@@ -627,6 +769,389 @@ def test_http_error(self):
         assert "500" in msg
 
 
+class TestQualityChecks:
+    """Unit tests for workload quality checks."""
+
+    def test_required_and_forbidden_regex(self):
+        ok, issues = validate_quality_checks(
+            "stop",
+            "Dear team,\nThis is a clean response.\nSincerely",
+            {
+                "required_regex": ["Dear team", "Sincerely"],
+                "forbidden_regex": ["<think>", "unsupported claim"],
+                "min_chars": 20,
+                "finish_reason": "stop",
+            },
+        )
+
+        assert ok is True
+        assert issues == []
+
+    def test_forbidden_regex_failure(self):
+        ok, issues = validate_quality_checks(
+            "stop",
+            "Visible answer\n<think>hidden plan",
+            {"forbidden_regex": ["<think>"]},
+        )
+
+        assert ok is False
+        assert any("forbidden_regex matched" in issue for issue in issues)
+
+    def test_json_check(self):
+        ok, issues = validate_quality_checks(
+            "stop",
+            '{"title": "Engineer", "priority": 1}',
+            {"json": True},
+        )
+
+        assert ok is True
+        assert issues == []
+
+    def test_json_check_failure(self):
+        ok, issues = validate_quality_checks("stop", "not json", {"json": True})
+
+        assert ok is False
+        assert any("not valid JSON" in issue for issue in issues)
+
+
+class TestWorkloadSummary:
+    """Unit tests for workload summary aggregation."""
+
+    def test_summarize_workload_results_tracks_quality_and_policy(self):
+        results = [
+            {
+                "ok": True,
+                "case_id": "case-a",
+                "repetition": 0,
+                "policy": {"within_timeout": True},
+                "quality": {"ok": True, "content_chars": 120},
+                "metrics": {
+                    "e2e_latency_ms": 100.0,
+                    "ttft_ms": 10.0,
+                    "gen_tps": 20.0,
+                },
+            },
+            {
+                "ok": True,
+                "case_id": "case-a",
+                "repetition": 1,
+                "policy": {"within_timeout": False},
+                "quality": {"ok": True, "content_chars": 160},
+                "metrics": {
+                    "e2e_latency_ms": 200.0,
+                    "ttft_ms": 20.0,
+                    "gen_tps": 10.0,
+                },
+            },
+        ]
+
+        summary = summarize_workload_results(results)
+
+        assert summary["passed"] is True
+        assert summary["quality_passed"] is True
+        assert summary["policy_timeout_passed"] is False
+        assert summary["failure_rate"] == pytest.approx(0.0)
+        assert summary["latency_ms"]["p50"] == pytest.approx(150.0)
+        assert summary["unique_case_count"] == 1
+        assert summary["repetition_count"] == 2
+        assert summary["case_summaries"]["case-a"]["repetitions"] == [0, 1]
+        assert summary["case_summaries"]["case-a"]["latency_ms"]["min"] == 100.0
+        assert summary["case_summaries"]["case-a"]["latency_ms"]["max"] == 200.0
+
+
+class TestWorkloadRunner:
+    """Unit tests for contract workload execution records."""
+
+    def test_run_workload_case_records_metrics_policy_and_quality(self, monkeypatch):
+        metrics_responses = iter(
+            [
+                {"cache_hits": 10, "cache_misses": 3, "tokens_saved": 100},
+                {"cache_hits": 12, "cache_misses": 4, "tokens_saved": 140},
+            ]
+        )
+
+        async def fake_scrape_metrics(client, base_url):
+            return next(metrics_responses)
+
+        async def fake_stream_chat_completion(**kwargs):
+            assert kwargs["max_tokens"] == 64
+            assert kwargs["enable_thinking"] is True
+            assert kwargs["extra_body"] == {"temperature": 0.6}
+            return {
+                "ttft_ms": 25.0,
+                "tpot_ms": 3.0,
+                "e2e_latency_ms": 2500.0,
+                "gen_tps": 12.5,
+                "prompt_tps": 500.0,
+                "prompt_tokens": 200,
+                "completion_tokens": 100,
+                "finish_reason": "stop",
+                "content": "Dear team,\nA clean benchmark artifact.\nSincerely",
+            }
+
+        class FakeResponse:
+            def raise_for_status(self):
+                return None
+
+            def json(self):
+                return {
+                    "cache": {"type": "paged"},
+                    "metal": {
+                        "active_memory_gb": 42.0,
+                        "peak_memory_gb": 45.0,
+                        "cache_memory_gb": 4.0,
+                    },
+                }
+
+        class FakeClient:
+            async def get(self, url):
+                return FakeResponse()
+
+        monkeypatch.setattr("vllm_mlx.bench_serve.scrape_metrics", fake_scrape_metrics)
+        monkeypatch.setattr(
+            "vllm_mlx.bench_serve.stream_chat_completion",
+            fake_stream_chat_completion,
+        )
+
+        workload = Workload(
+            name="writing-contract",
+            description="",
+            defaults={"max_tokens": 64},
+            cases=[],
+        )
+        case = WorkloadCase(
+            case_id="resume-smoke",
+            messages=[{"role": "user", "content": "Write the artifact."}],
+            request_path="/tmp/request.json",
+            max_tokens=64,
+            enable_thinking=True,
+            extra_body={"temperature": 0.6},
+            policy_timeout_ms=1800,
+            checks={
"finish_reason": "stop", + "required_regex": ["Dear team", "Sincerely"], + "forbidden_regex": [""], + }, + tags=("resume",), + ) + + record = asyncio.run( + run_workload_case( + FakeClient(), + "http://server", + workload=workload, + case=case, + model="test-model", + runtime={"engine_type": "mllm"}, + hardware={"chip": "test"}, + run_id="run123", + timestamp="2026-04-23T00:00:00+00:00", + repetition=2, + scrape=True, + include_content=True, + ) + ) + + assert record["ok"] is True + assert record["quality"]["ok"] is True + assert record["quality"]["issues"] == [] + assert record["quality"]["content"].startswith("Dear team") + assert record["repetition"] == 2 + assert record["request"]["request_path"] == "/tmp/request.json" + assert record["policy"]["within_timeout"] is False + assert record["cache_reset"] == {"attempted": False} + assert record["metrics"]["cache_hits"] == 2 + assert record["metrics"]["cache_misses"] == 1 + assert record["metrics"]["tokens_saved"] == 40 + assert record["metrics"]["metal"]["metal_active_gb"] == pytest.approx(42.0) + assert record["metrics"]["metal"]["cache_type"] == "paged" + + def test_run_bench_serve_workload_repeats_each_case(self, tmp_path, monkeypatch): + workload_file = tmp_path / "workload.json" + workload_file.write_text( + json.dumps( + { + "name": "repeat-contract", + "cases": [ + { + "id": "case-a", + "messages": [{"role": "user", "content": "A"}], + }, + { + "id": "case-b", + "messages": [{"role": "user", "content": "B"}], + }, + ], + } + ) + ) + observed = [] + + async def fake_auto_detect_runtime(client, url): + return {"model_id": "test-model"} + + def fake_detect_hardware_fingerprint(): + return {"chip": "test"} + + async def fake_run_workload_case(*args, **kwargs): + observed.append((kwargs["case"].case_id, kwargs["repetition"])) + return { + "run_id": kwargs["run_id"], + "timestamp": kwargs["timestamp"], + "workload": kwargs["workload"].name, + "case_id": kwargs["case"].case_id, + "repetition": kwargs["repetition"], + "tags": [], + "model_id": kwargs["model"], + "runtime": kwargs["runtime"], + "hardware": kwargs["hardware"], + "request": {}, + "policy": {"within_timeout": None}, + "metrics": { + "e2e_latency_ms": 100.0 + kwargs["repetition"], + "ttft_ms": 10.0, + "gen_tps": 20.0, + }, + "quality": {"ok": True, "content_chars": 20}, + "ok": True, + } + + monkeypatch.setattr( + "vllm_mlx.bench_serve.auto_detect_runtime", fake_auto_detect_runtime + ) + monkeypatch.setattr( + "vllm_mlx.bench_serve.detect_hardware_fingerprint", + fake_detect_hardware_fingerprint, + ) + monkeypatch.setattr( + "vllm_mlx.bench_serve.run_workload_case", fake_run_workload_case + ) + + payload = asyncio.run( + run_bench_serve_workload( + url="http://server", + workload_path=str(workload_file), + output_path=str(tmp_path / "results.json"), + output_format="json", + scrape=False, + request_timeout_s=None, + repetitions=3, + ) + ) + + assert observed == [ + ("case-a", 0), + ("case-b", 0), + ("case-a", 1), + ("case-b", 1), + ("case-a", 2), + ("case-b", 2), + ] + assert len(payload["results"]) == 6 + assert payload["workload"]["repetitions"] == 3 + assert payload["summary"]["case_summaries"]["case-a"]["sample_count"] == 3 + + def test_run_bench_serve_workload_clears_cache_before_case( + self, tmp_path, monkeypatch + ): + workload_file = tmp_path / "workload.json" + workload_file.write_text( + json.dumps( + { + "name": "cache-contract", + "cases": [ + { + "id": "case-a", + "messages": [{"role": "user", "content": "A"}], + }, + { + "id": "case-b", + "messages": [{"role": 
"user", "content": "B"}], + }, + ], + } + ) + ) + clear_events = [] + observed_resets = [] + + async def fake_auto_detect_runtime(client, url): + return {"model_id": "test-model"} + + def fake_detect_hardware_fingerprint(): + return {"chip": "test"} + + async def fake_clear_runtime_cache(client, url): + event = { + "attempted": True, + "ok": True, + "status_code": 200, + "response": {"sequence": len(clear_events)}, + "error": "", + } + clear_events.append(event) + return event + + async def fake_run_workload_case(*args, **kwargs): + observed_resets.append(kwargs["cache_reset"]) + return { + "run_id": kwargs["run_id"], + "timestamp": kwargs["timestamp"], + "workload": kwargs["workload"].name, + "case_id": kwargs["case"].case_id, + "repetition": kwargs["repetition"], + "tags": [], + "model_id": kwargs["model"], + "runtime": kwargs["runtime"], + "hardware": kwargs["hardware"], + "request": {}, + "policy": {"within_timeout": None}, + "metrics": { + "e2e_latency_ms": 100.0, + "ttft_ms": 10.0, + "gen_tps": 20.0, + }, + "quality": {"ok": True, "content_chars": 20}, + "ok": True, + } + + monkeypatch.setattr( + "vllm_mlx.bench_serve.auto_detect_runtime", fake_auto_detect_runtime + ) + monkeypatch.setattr( + "vllm_mlx.bench_serve.detect_hardware_fingerprint", + fake_detect_hardware_fingerprint, + ) + monkeypatch.setattr( + "vllm_mlx.bench_serve.clear_runtime_cache", fake_clear_runtime_cache + ) + monkeypatch.setattr( + "vllm_mlx.bench_serve.run_workload_case", fake_run_workload_case + ) + + payload = asyncio.run( + run_bench_serve_workload( + url="http://server", + workload_path=str(workload_file), + output_format="json", + scrape=False, + request_timeout_s=None, + repetitions=2, + cache_policy="before-case", + ) + ) + + assert len(clear_events) == 4 + assert observed_resets == clear_events + assert payload["cache_policy"]["mode"] == "before-case" + assert [event["scope"] for event in payload["cache_policy"]["events"]] == [ + "before-case", + "before-case", + "before-case", + "before-case", + ] + + # --------------------------------------------------------------------------- # TestSummaryStats (Task 5) # --------------------------------------------------------------------------- @@ -707,6 +1232,69 @@ def _make_sample_result(**overrides) -> BenchServeResult: return BenchServeResult(**defaults) +def _make_sample_workload_payload() -> dict: + return { + "run_id": "run123", + "timestamp": "2026-04-23T00:00:00+00:00", + "summary": {"passed": True}, + "results": [ + { + "run_id": "run123", + "timestamp": "2026-04-23T00:00:00+00:00", + "workload": "writing-contract", + "case_id": "resume-smoke", + "repetition": 0, + "tags": ["resume", "quality"], + "model_id": "test-model", + "runtime": { + "engine_type": "mllm", + "model_type": "mllm", + "mtp_enabled": False, + "specprefill": False, + "kv_quant": "", + "cache_type": "paged", + }, + "hardware": { + "chip": "M4 Ultra", + "memory_gb": 256.0, + "os_version": "macOS-test", + }, + "request": { + "max_tokens": 32768, + "enable_thinking": True, + "extra_body": {"temperature": 0.6}, + }, + "policy": {"timeout_ms": 180000, "within_timeout": True}, + "metrics": { + "ttft_ms": 25.0, + "tpot_ms": 3.0, + "e2e_latency_ms": 2500.0, + "gen_tps": 12.5, + "prompt_tps": 500.0, + "prompt_tokens": 200, + "completion_tokens": 100, + "cache_hits": 2, + "cache_misses": 1, + "tokens_saved": 40, + "metal": { + "metal_active_gb": 42.0, + "metal_peak_gb": 45.0, + "metal_cache_gb": 4.0, + }, + }, + "quality": { + "ok": True, + "issues": [], + "finish_reason": "stop", + 
"content_chars": 512, + "content_preview": "Clean artifact", + }, + "ok": True, + } + ], + } + + # --------------------------------------------------------------------------- # TestFormatters (Task 6) # --------------------------------------------------------------------------- @@ -760,12 +1348,57 @@ def test_format_sql_handles_nan_inf(self): assert "inf" not in output.lower().split("'")[-1] assert "NULL" in output + def test_write_sqlite_creates_prompt_sweep_rows(self, tmp_path): + db_path = tmp_path / "bench.db" + write_sqlite([_make_sample_result(run_id="sqlite-run")], str(db_path)) + with sqlite3.connect(db_path) as conn: + row = conn.execute( + "SELECT run_id, model_id, repetition FROM bench_serve" + ).fetchone() + assert row == ("sqlite-run", "mlx-community/gemma-3-4b-it-4bit", 0) + def test_result_columns_match_dataclass(self): import dataclasses field_names = {f.name for f in dataclasses.fields(BenchServeResult)} assert set(RESULT_COLUMNS) == field_names + def test_format_workload_csv_parseable(self): + output = format_workload_csv(_make_sample_workload_payload()) + rows = list(csv.DictReader(output.splitlines())) + assert len(rows) == 1 + assert rows[0]["case_id"] == "resume-smoke" + assert rows[0]["repetition"] == "0" + assert rows[0]["model_id"] == "test-model" + assert rows[0]["request_extra_body"] == '{"temperature": 0.6}' + + def test_format_workload_sql_valid(self): + output = format_workload_sql(_make_sample_workload_payload()) + assert "CREATE TABLE IF NOT EXISTS bench_serve_workload" in output + assert "case_id TEXT, repetition INTEGER, tags TEXT" in output + assert "INSERT INTO bench_serve_workload" in output + assert "resume-smoke" in output + + def test_write_workload_sqlite_creates_case_rows(self, tmp_path): + db_path = tmp_path / "workload.db" + write_workload_sqlite(_make_sample_workload_payload(), str(db_path)) + with sqlite3.connect(db_path) as conn: + row = conn.execute( + "SELECT case_id, repetition, quality_ok FROM bench_serve_workload" + ).fetchone() + assert row == ("resume-smoke", 0, 1) + + def test_sqlite_identifier_validation_rejects_unsafe_names(self): + with pytest.raises(ValueError, match="invalid SQLite table identifier"): + _validate_sql_identifier("bench; DROP TABLE bench_serve", kind="table") + + with pytest.raises(ValueError, match="invalid SQLite column identifier"): + _validate_sql_identifier("case-id", kind="column") + + def test_format_workload_payload_rejects_unknown_format(self): + with pytest.raises(ValueError, match="Unsupported workload output format"): + format_workload_payload(_make_sample_workload_payload(), "xml") + # --------------------------------------------------------------------------- # TestBenchServeIntegration (Task 8) @@ -807,7 +1440,7 @@ def test_smoke_run(self): def test_sql_output_is_valid(self): """Verify SQL output contains CREATE TABLE and INSERT.""" - from vllm_mlx.bench_serve import run_bench_serve, format_sql + from vllm_mlx.bench_serve import format_sql, run_bench_serve results = asyncio.run( run_bench_serve( diff --git a/vllm_mlx/bench_serve.py b/vllm_mlx/bench_serve.py index bec5d21a5..a0c48ea81 100644 --- a/vllm_mlx/bench_serve.py +++ b/vllm_mlx/bench_serve.py @@ -25,6 +25,7 @@ import math import platform import re +import sqlite3 import statistics import sys import time @@ -32,7 +33,7 @@ from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path -from typing import Optional +from typing import Any, Optional import httpx from tabulate import tabulate as _tabulate @@ -43,6 
+44,32 @@ _BUILTIN_DIR = Path(__file__).parent / "bench_serve_prompts" _BUILTIN_NAMES = {"short", "medium", "long", "thinking"} +_SQL_IDENTIFIER_RE = re.compile(r"^[a-z_][a-z0-9_]*$") + + +@dataclass +class WorkloadCase: + """One declarative benchmark case for contract-style serving tests.""" + + case_id: str + messages: list[dict] + request_path: Optional[str] = None + max_tokens: Optional[int] = None + enable_thinking: Optional[bool] = None + extra_body: Optional[dict] = None + policy_timeout_ms: Optional[int] = None + checks: Optional[dict] = None + tags: tuple[str, ...] = () + + +@dataclass +class Workload: + """Normalized bench-serve workload manifest.""" + + name: str + description: str + defaults: dict + cases: list[WorkloadCase] def load_prompt_set(name_or_path: str) -> list[list[dict]]: @@ -110,6 +137,134 @@ def load_prompt_set(name_or_path: str) -> list[list[dict]]: ) +def _require_message_list(value: Any, *, label: str) -> list[dict]: + if not isinstance(value, list) or not value: + raise ValueError(f"{label}: messages must be a non-empty list") + for idx, message in enumerate(value): + if not isinstance(message, dict): + raise ValueError(f"{label}: message {idx} must be an object") + if "role" not in message or "content" not in message: + raise ValueError(f"{label}: message {idx} must include role and content") + return value + + +def _load_case_request(path: str, *, workload_path: Path, case_id: str) -> dict: + request_path = Path(path).expanduser() + if not request_path.is_absolute(): + request_path = workload_path.parent / request_path + with request_path.open() as fh: + request = json.load(fh) + if not isinstance(request, dict): + raise ValueError(f"{case_id}: request_path must point to a JSON object") + return request + + +def _request_extra_body(request: dict) -> dict: + reserved = { + "model", + "messages", + "max_tokens", + "stream", + "stream_options", + "enable_thinking", + } + return {key: value for key, value in request.items() if key not in reserved} + + +def _first_not_none(*values: Any) -> Any: + for value in values: + if value is not None: + return value + return None + + +def load_workload(path: str | Path) -> Workload: + """Load a declarative serving benchmark workload. + + Workloads are for product-like qualification where each case can carry + request settings, comparison-only policy timeouts, and quality checks. + Timeout fields are metadata unless the runner explicitly uses them as a + transport limit; they are not treated as hardware capability claims. 
+ """ + workload_path = Path(path).expanduser() + with workload_path.open() as fh: + raw = json.load(fh) + + if not isinstance(raw, dict): + raise ValueError("workload root must be a JSON object") + raw_cases = raw.get("cases") + if not isinstance(raw_cases, list) or not raw_cases: + raise ValueError("workload must contain a non-empty cases list") + + defaults = raw.get("defaults") or {} + if not isinstance(defaults, dict): + raise ValueError("workload defaults must be an object") + + cases: list[WorkloadCase] = [] + for idx, item in enumerate(raw_cases): + if not isinstance(item, dict): + raise ValueError(f"case {idx}: case must be an object") + case_id = str(item.get("id") or f"case_{idx + 1}") + request_path = item.get("request_path") + request_defaults: dict = {} + if request_path is not None: + request_defaults = _load_case_request( + str(request_path), workload_path=workload_path, case_id=case_id + ) + messages = _require_message_list( + item.get("messages", request_defaults.get("messages")), + label=case_id, + ) + extra_body = item.get("extra_body", defaults.get("extra_body")) + request_extra = _request_extra_body(request_defaults) + if extra_body: + request_extra.update(extra_body) + extra_body = request_extra or None + if extra_body is not None and not isinstance(extra_body, dict): + raise ValueError(f"{case_id}: extra_body must be an object") + checks = item.get("checks", defaults.get("checks")) + if checks is not None and not isinstance(checks, dict): + raise ValueError(f"{case_id}: checks must be an object") + tags = item.get("tags", []) + if isinstance(tags, str): + tags = [tags] + if not isinstance(tags, list): + raise ValueError(f"{case_id}: tags must be a list or string") + + cases.append( + WorkloadCase( + case_id=case_id, + messages=messages, + request_path=str(request_path) if request_path is not None else None, + max_tokens=_first_not_none( + item.get("max_tokens"), + request_defaults.get("max_tokens"), + defaults.get("max_tokens"), + ), + enable_thinking=_first_not_none( + item.get("enable_thinking"), + request_defaults.get("enable_thinking"), + defaults.get("enable_thinking"), + ), + extra_body=extra_body, + policy_timeout_ms=_first_not_none( + item.get("policy_timeout_ms"), + request_defaults.get("policy_timeout_ms"), + defaults.get("policy_timeout_ms"), + ), + checks=checks, + tags=tuple(str(tag) for tag in tags), + ) + ) + + return Workload( + name=str(raw.get("name") or workload_path.stem), + description=str(raw.get("description") or ""), + defaults=defaults, + cases=cases, + ) + + # --------------------------------------------------------------------------- # Result dataclass # --------------------------------------------------------------------------- @@ -258,9 +413,15 @@ def parse_status_response(data: dict) -> dict: cache = data.get("cache") or {} return { "model": data.get("model", ""), - "metal_active_gb": float(metal.get("active_gb") or 0.0), - "metal_peak_gb": float(metal.get("peak_gb") or 0.0), - "metal_cache_gb": float(metal.get("cache_gb") or 0.0), + "metal_active_gb": float( + metal.get("active_memory_gb") or metal.get("active_gb") or 0.0 + ), + "metal_peak_gb": float( + metal.get("peak_memory_gb") or metal.get("peak_gb") or 0.0 + ), + "metal_cache_gb": float( + metal.get("cache_memory_gb") or metal.get("cache_gb") or 0.0 + ), "cache_type": cache.get("type", "") or "", } @@ -430,6 +591,43 @@ async def scrape_metrics(client: httpx.AsyncClient, base_url: str) -> dict: return {} +async def clear_runtime_cache(client: httpx.AsyncClient, base_url: str) -> 
dict: + """Clear server-side runtime caches and return a JSON-serializable event.""" + event: dict[str, Any] = { + "attempted": True, + "ok": False, + "status_code": 0, + "response": {}, + "error": "", + } + try: + resp = await client.delete(f"{base_url}/v1/cache") + event["status_code"] = resp.status_code + try: + event["response"] = resp.json() + except ValueError: + event["response"] = {"text": resp.text} + resp.raise_for_status() + event["ok"] = True + except Exception as exc: + event["error"] = str(exc) + return event + + +def _normalize_cache_policy(value: Optional[str]) -> str: + """Normalize cache-policy spelling from CLI or workload JSON. + + CLI choices are hyphenated, but workload JSON may use underscores when it + follows common Python/YAML identifier style. + """ + policy = (value or "preserve").strip().lower().replace("_", "-") + if policy not in {"preserve", "before-run", "before-case"}: + raise ValueError( + "cache policy must be one of: preserve, before-run, before-case" + ) + return policy + + # --------------------------------------------------------------------------- # Task 4: SSE streaming core + token counting + request timing # --------------------------------------------------------------------------- @@ -701,6 +899,69 @@ def validate_response( return (True, "") +def validate_quality_checks( + finish_reason: Optional[str], + content: str, + checks: Optional[dict], + *, + status_code: int = 200, +) -> tuple[bool, list[str]]: + """Validate content against generic workload quality checks. + + Supported checks: + - ``finish_reason``: string or list of allowed finish reasons + - ``required_regex``: list of regex patterns that must match + - ``forbidden_regex``: list of regex patterns that must not match + - ``min_chars`` / ``max_chars``: length bounds + - ``json``: when true, content must parse as JSON + """ + basic_ok, basic_issue = validate_response(finish_reason, content, status_code) + issues: list[str] = [] if basic_ok else [basic_issue] + checks = checks or {} + + allowed_finish = checks.get("finish_reason") + if allowed_finish is not None: + allowed = ( + [allowed_finish] + if isinstance(allowed_finish, str) + else list(allowed_finish) + ) + if finish_reason not in allowed: + issues.append( + f"finish_reason {finish_reason!r} not in allowed set {allowed!r}" + ) + + min_chars = checks.get("min_chars") + if min_chars is not None and len(content) < int(min_chars): + issues.append(f"content shorter than min_chars={min_chars}") + + max_chars = checks.get("max_chars") + if max_chars is not None and len(content) > int(max_chars): + issues.append(f"content longer than max_chars={max_chars}") + + for pattern in checks.get("required_regex", []) or []: + try: + if not re.search(str(pattern), content, re.MULTILINE): + issues.append(f"required_regex did not match: {pattern}") + except re.error as exc: + issues.append(f"invalid required_regex {pattern!r}: {exc}") + + for pattern in checks.get("forbidden_regex", []) or []: + try: + if re.search(str(pattern), content, re.MULTILINE): + issues.append(f"forbidden_regex matched: {pattern}") + except re.error as exc: + issues.append(f"invalid forbidden_regex {pattern!r}: {exc}") + + if checks.get("json"): + try: + json.loads(content) + except json.JSONDecodeError as exc: + issues.append(f"content is not valid JSON: {exc}") + + return (not issues, issues) + + def compute_summary_stats(values: list[float]) -> dict: """Compute summary statistics over a list of floats. 
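A usage sketch for `validate_quality_checks` as added above; the sample
contents and check dicts are illustrative:

```python
# Illustrative usage of validate_quality_checks() from this module.
from vllm_mlx.bench_serve import validate_quality_checks

ok, issues = validate_quality_checks(
    "stop",
    '{"title": "Engineer", "priority": 1}',
    {"finish_reason": ["stop"], "json": True, "min_chars": 10},
)
assert ok is True and issues == []

# A forbidden pattern in the content fails the check and names the pattern.
ok, issues = validate_quality_checks(
    "stop",
    "Visible answer\n<think>hidden plan",
    {"forbidden_regex": ["<think>"]},
)
assert ok is False and "forbidden_regex matched: <think>" in issues[0]
```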
@@ -812,6 +1073,342 @@ async def _single(messages: list[dict]) -> dict: return list(results) +def _summary_or_empty(values: list[float]) -> dict: + return compute_summary_stats(values) if values else {} + + +async def run_workload_case( + client: httpx.AsyncClient, + base_url: str, + *, + workload: Workload, + case: WorkloadCase, + model: str, + runtime: dict, + hardware: dict, + run_id: str, + timestamp: str, + repetition: int = 0, + scrape: bool = True, + include_content: bool = False, + cache_reset: Optional[dict] = None, +) -> dict: + """Run one workload case and return a JSON-serializable result.""" + metrics_before = await scrape_metrics(client, base_url) if scrape else {} + started_wall = datetime.now(timezone.utc).isoformat() + + try: + result = await stream_chat_completion( + client=client, + base_url=base_url, + messages=case.messages, + model=model, + max_tokens=int(case.max_tokens or workload.defaults.get("max_tokens", 256)), + enable_thinking=case.enable_thinking, + extra_body=case.extra_body, + ) + error = "" + except Exception as exc: + result = { + "ttft_ms": 0.0, + "tpot_ms": 0.0, + "e2e_latency_ms": 0.0, + "gen_tps": 0.0, + "prompt_tps": 0.0, + "prompt_tokens": 0, + "completion_tokens": 0, + "finish_reason": None, + "content": "", + } + error = str(exc) + + metrics_after = await scrape_metrics(client, base_url) if scrape else {} + status_after: dict = {} + try: + resp = await client.get(f"{base_url}/v1/status") + resp.raise_for_status() + status_after = resp.json() + except Exception: + status_after = {} + + cache_hits_delta = metrics_after.get("cache_hits", 0) - metrics_before.get( + "cache_hits", 0 + ) + cache_misses_delta = metrics_after.get("cache_misses", 0) - metrics_before.get( + "cache_misses", 0 + ) + tokens_saved_delta = metrics_after.get("tokens_saved", 0) - metrics_before.get( + "tokens_saved", 0 + ) + + content = str(result.get("content") or "") + quality_ok, quality_issues = validate_quality_checks( + result.get("finish_reason"), + content, + case.checks, + status_code=500 if error else 200, + ) + if error: + quality_issues.append(f"request error: {error}") + + if case.policy_timeout_ms is None: + within_policy_timeout = None + elif error: + within_policy_timeout = False + else: + within_policy_timeout = result["e2e_latency_ms"] <= case.policy_timeout_ms + + record = { + "run_id": run_id, + "timestamp": timestamp, + "started_at": started_wall, + "workload": workload.name, + "case_id": case.case_id, + "repetition": repetition, + "tags": list(case.tags), + "model_id": model, + "runtime": runtime, + "hardware": hardware, + "request": { + "max_tokens": int( + case.max_tokens or workload.defaults.get("max_tokens", 256) + ), + "request_path": case.request_path, + "enable_thinking": case.enable_thinking, + "extra_body": case.extra_body or {}, + "message_count": len(case.messages), + }, + "policy": { + "timeout_ms": case.policy_timeout_ms, + "within_timeout": within_policy_timeout, + "note": "comparison-only unless your product contract explicitly requires it", + }, + "cache_reset": cache_reset or {"attempted": False}, + "metrics": { + "ttft_ms": result["ttft_ms"], + "tpot_ms": result["tpot_ms"], + "e2e_latency_ms": result["e2e_latency_ms"], + "gen_tps": result["gen_tps"], + "prompt_tps": result["prompt_tps"], + "prompt_tokens": result["prompt_tokens"], + "completion_tokens": result["completion_tokens"], + "cache_hits": cache_hits_delta, + "cache_misses": cache_misses_delta, + "tokens_saved": tokens_saved_delta, + "metal": parse_status_response(status_after), + }, + 
"quality": { + "ok": quality_ok, + "issues": quality_issues, + "finish_reason": result.get("finish_reason"), + "content_chars": len(content), + "content_preview": content[:240], + }, + "ok": quality_ok, + } + if include_content: + record["quality"]["content"] = content + return record + + +def summarize_workload_results(results: list[dict]) -> dict: + """Aggregate workload case records into stable qualification summary stats.""" + latencies = [r["metrics"]["e2e_latency_ms"] for r in results] + ttft = [r["metrics"]["ttft_ms"] for r in results] + gen_tps = [r["metrics"]["gen_tps"] for r in results] + failures = [r for r in results if not r["quality"]["ok"]] + policy_trials = [ + r for r in results if r["policy"].get("within_timeout") is not None + ] + policy_failures = [ + r for r in policy_trials if r["policy"].get("within_timeout") is False + ] + cases: dict[str, list[dict]] = {} + for result in results: + cases.setdefault(str(result.get("case_id", "")), []).append(result) + + case_summaries = {} + for case_id, case_results in sorted(cases.items()): + case_quality_failures = [r for r in case_results if not r["quality"].get("ok")] + case_policy_trials = [ + r for r in case_results if r["policy"].get("within_timeout") is not None + ] + case_policy_failures = [ + r for r in case_policy_trials if r["policy"].get("within_timeout") is False + ] + case_summaries[case_id] = { + "sample_count": len(case_results), + "repetitions": sorted( + { + int(r.get("repetition", 0)) + for r in case_results + if r.get("repetition") is not None + } + ), + "passed": not case_quality_failures, + "failure_count": len(case_quality_failures), + "failure_rate": ( + round(len(case_quality_failures) / len(case_results), 4) + if case_results + else 0.0 + ), + "policy_timeout_passed": ( + not case_policy_failures if case_policy_trials else None + ), + "policy_timeout_failure_count": ( + len(case_policy_failures) if case_policy_trials else None + ), + "latency_ms": _summary_or_empty( + [r["metrics"]["e2e_latency_ms"] for r in case_results] + ), + "ttft_ms": _summary_or_empty( + [r["metrics"]["ttft_ms"] for r in case_results] + ), + "gen_tps": _summary_or_empty( + [r["metrics"]["gen_tps"] for r in case_results] + ), + "content_chars": _summary_or_empty( + [r["quality"].get("content_chars", 0) for r in case_results] + ), + } + + return { + "case_count": len(results), + "unique_case_count": len(cases), + "repetition_count": max( + (len(summary["repetitions"]) for summary in case_summaries.values()), + default=0, + ), + "passed": not failures, + "failure_count": len(failures), + "failure_rate": round(len(failures) / len(results), 4) if results else 0.0, + "quality_passed": not failures, + "quality_failure_count": len(failures), + "policy_timeout_passed": not policy_failures if policy_trials else None, + "policy_timeout_failure_count": ( + len(policy_failures) if policy_trials else None + ), + "latency_ms": _summary_or_empty(latencies), + "ttft_ms": _summary_or_empty(ttft), + "gen_tps": _summary_or_empty(gen_tps), + "case_summaries": case_summaries, + } + + +async def run_bench_serve_workload( + *, + url: str, + workload_path: str, + model: Optional[str] = None, + output_path: Optional[str] = None, + output_format: str = "json", + scrape: bool = True, + include_content: bool = False, + request_timeout_s: Optional[float] = 300.0, + repetitions: int = 1, + cache_policy: Optional[str] = None, +) -> dict: + """Run a declarative workload against a running server. 
+ + This is the contract-style counterpart to prompt sweeps: it keeps product + policy knobs in the manifest, records them as evidence, and measures what + the server actually does before anyone promotes a model or feature stack. + """ + if repetitions < 1: + raise ValueError("repetitions must be at least 1") + + workload = load_workload(workload_path) + resolved_cache_policy = _normalize_cache_policy( + cache_policy or workload.defaults.get("cache_policy") + ) + run_id = str(uuid.uuid4())[:8] + timestamp = datetime.now(timezone.utc).isoformat() + timeout = httpx.Timeout(request_timeout_s) if request_timeout_s else None + + async with httpx.AsyncClient(timeout=timeout) as client: + runtime = await auto_detect_runtime(client, url) + hardware = detect_hardware_fingerprint() + model_id = model or runtime.get("model_id", "") + if not model_id: + raise ValueError("could not determine model ID; pass --model") + + records = [] + cache_events = [] + if resolved_cache_policy == "before-run": + cache_events.append( + { + "scope": "before-run", + "event": await clear_runtime_cache(client, url), + } + ) + for repetition in range(repetitions): + for case in workload.cases: + cache_reset = None + if resolved_cache_policy == "before-case": + cache_reset = await clear_runtime_cache(client, url) + cache_events.append( + { + "scope": "before-case", + "case_id": case.case_id, + "repetition": repetition, + "event": cache_reset, + } + ) + record = await run_workload_case( + client, + url, + workload=workload, + case=case, + model=model_id, + runtime=runtime, + hardware=hardware, + run_id=run_id, + timestamp=timestamp, + repetition=repetition, + scrape=scrape, + include_content=include_content, + cache_reset=cache_reset, + ) + records.append(record) + + payload = { + "run_id": run_id, + "timestamp": timestamp, + "workload": { + "name": workload.name, + "description": workload.description, + "path": str(Path(workload_path).expanduser()), + "defaults": workload.defaults, + "repetitions": repetitions, + }, + "transport": { + "request_timeout_s": request_timeout_s, + "note": "transport safety only; product policy timeouts live in workload cases", + }, + "cache_policy": { + "mode": resolved_cache_policy, + "events": cache_events, + }, + "summary": summarize_workload_results(records), + "results": records, + } + + if output_format == "sqlite": + if not output_path: + raise ValueError("--output is required when --format sqlite") + write_workload_sqlite(payload, output_path) + print(f"Workload SQLite results written to {output_path}") + return payload + + rendered = format_workload_payload(payload, output_format) + if output_path: + Path(output_path).expanduser().write_text(rendered) + print(f"Workload results written to {output_path}") + else: + print(rendered) + return payload + + # --------------------------------------------------------------------------- # Task 6: Output formatters # --------------------------------------------------------------------------- @@ -958,6 +1555,235 @@ def format_sql(results: list[BenchServeResult]) -> str: return "\n".join(lines) +def _write_sqlite_rows( + output_path: str, + *, + table: str, + schema: str, + columns: list[str], + rows: list[dict], +) -> None: + """Append benchmark rows to a SQLite database.""" + db_path = Path(output_path).expanduser() + _validate_sql_identifier(table, kind="table") + for column in columns: + _validate_sql_identifier(column, kind="column") + placeholders = ", ".join("?" 
for _ in columns) + column_list = ", ".join(columns) + values = [[row.get(col) for col in columns] for row in rows] + with sqlite3.connect(db_path) as conn: + conn.execute(f"CREATE TABLE IF NOT EXISTS {table} ({schema})") + if values: + conn.executemany( + f"INSERT INTO {table} ({column_list}) VALUES ({placeholders})", + values, + ) + conn.commit() + + +def _validate_sql_identifier(identifier: str, *, kind: str) -> None: + """Reject unsafe SQL identifiers before string interpolation.""" + if not _SQL_IDENTIFIER_RE.fullmatch(identifier): + raise ValueError(f"invalid SQLite {kind} identifier: {identifier!r}") + + +def write_sqlite(results: list[BenchServeResult], output_path: str) -> None: + rows = [_result_to_dict(r) for r in results] + _write_sqlite_rows( + output_path, + table="bench_serve", + schema=_SQL_SCHEMA, + columns=RESULT_COLUMNS, + rows=rows, + ) + + +WORKLOAD_RESULT_COLUMNS = [ + "run_id", + "timestamp", + "workload", + "case_id", + "repetition", + "tags", + "model_id", + "chip", + "memory_gb", + "os_version", + "engine_type", + "model_type", + "mtp_enabled", + "specprefill", + "kv_quant", + "cache_type", + "request_max_tokens", + "request_enable_thinking", + "request_extra_body", + "policy_timeout_ms", + "within_policy_timeout", + "ttft_ms", + "tpot_ms", + "e2e_latency_ms", + "gen_tps", + "prompt_tps", + "prompt_tokens", + "completion_tokens", + "cache_hits", + "cache_misses", + "tokens_saved", + "metal_active_gb", + "metal_peak_gb", + "metal_cache_gb", + "quality_ok", + "quality_issues", + "finish_reason", + "content_chars", + "content_preview", +] + +_WORKLOAD_TABLE_COLUMNS = [ + "case_id", + "repetition", + "tags", + "quality_ok", + "within_policy_timeout", + "ttft_ms", + "gen_tps", + "e2e_latency_ms", + "cache_hits", + "tokens_saved", + "finish_reason", +] + + +def _workload_record_to_row(record: dict) -> dict: + runtime = record.get("runtime") or {} + hardware = record.get("hardware") or {} + request = record.get("request") or {} + policy = record.get("policy") or {} + metrics = record.get("metrics") or {} + metal = metrics.get("metal") or {} + quality = record.get("quality") or {} + return { + "run_id": record.get("run_id", ""), + "timestamp": record.get("timestamp", ""), + "workload": record.get("workload", ""), + "case_id": record.get("case_id", ""), + "repetition": record.get("repetition", 0), + "tags": ",".join(record.get("tags") or []), + "model_id": record.get("model_id", ""), + "chip": hardware.get("chip", ""), + "memory_gb": hardware.get("memory_gb", 0.0), + "os_version": hardware.get("os_version", ""), + "engine_type": runtime.get("engine_type", ""), + "model_type": runtime.get("model_type", ""), + "mtp_enabled": runtime.get("mtp_enabled", False), + "specprefill": runtime.get("specprefill", False), + "kv_quant": runtime.get("kv_quant", ""), + "cache_type": runtime.get("cache_type", ""), + "request_max_tokens": request.get("max_tokens"), + "request_enable_thinking": request.get("enable_thinking"), + "request_extra_body": json.dumps( + request.get("extra_body") or {}, sort_keys=True + ), + "policy_timeout_ms": policy.get("timeout_ms"), + "within_policy_timeout": policy.get("within_timeout"), + "ttft_ms": metrics.get("ttft_ms", 0.0), + "tpot_ms": metrics.get("tpot_ms", 0.0), + "e2e_latency_ms": metrics.get("e2e_latency_ms", 0.0), + "gen_tps": metrics.get("gen_tps", 0.0), + "prompt_tps": metrics.get("prompt_tps", 0.0), + "prompt_tokens": metrics.get("prompt_tokens", 0), + "completion_tokens": metrics.get("completion_tokens", 0), + "cache_hits": 
metrics.get("cache_hits", 0), + "cache_misses": metrics.get("cache_misses", 0), + "tokens_saved": metrics.get("tokens_saved", 0), + "metal_active_gb": metal.get("metal_active_gb", 0.0), + "metal_peak_gb": metal.get("metal_peak_gb", 0.0), + "metal_cache_gb": metal.get("metal_cache_gb", 0.0), + "quality_ok": quality.get("ok", False), + "quality_issues": json.dumps(quality.get("issues") or []), + "finish_reason": quality.get("finish_reason"), + "content_chars": quality.get("content_chars", 0), + "content_preview": quality.get("content_preview", ""), + } + + +def format_workload_table(payload: dict) -> str: + rows = [] + for record in payload.get("results") or []: + row = _workload_record_to_row(record) + rows.append( + [ + round(value, 1) if isinstance(value, float) else value + for value in (row[col] for col in _WORKLOAD_TABLE_COLUMNS) + ] + ) + return _tabulate(rows, headers=_WORKLOAD_TABLE_COLUMNS, tablefmt="simple") + + +def format_workload_json(payload: dict) -> str: + return json.dumps(payload, indent=2) + + +def format_workload_csv(payload: dict) -> str: + buf = io.StringIO() + writer = csv_mod.DictWriter(buf, fieldnames=WORKLOAD_RESULT_COLUMNS) + writer.writeheader() + for record in payload.get("results") or []: + writer.writerow(_workload_record_to_row(record)) + return buf.getvalue() + + +_WORKLOAD_SQL_SCHEMA = ( + "run_id TEXT, timestamp TEXT, workload TEXT, case_id TEXT, repetition INTEGER, " + "tags TEXT, model_id TEXT, chip TEXT, memory_gb REAL, os_version TEXT, " + "engine_type TEXT, model_type TEXT, mtp_enabled BOOLEAN, specprefill BOOLEAN, " + "kv_quant TEXT, cache_type TEXT, request_max_tokens INTEGER, " + "request_enable_thinking BOOLEAN, request_extra_body TEXT, " + "policy_timeout_ms INTEGER, within_policy_timeout BOOLEAN, " + "ttft_ms REAL, tpot_ms REAL, e2e_latency_ms REAL, gen_tps REAL, " + "prompt_tps REAL, prompt_tokens INTEGER, completion_tokens INTEGER, " + "cache_hits INTEGER, cache_misses INTEGER, tokens_saved INTEGER, " + "metal_active_gb REAL, metal_peak_gb REAL, metal_cache_gb REAL, " + "quality_ok BOOLEAN, quality_issues TEXT, finish_reason TEXT, " + "content_chars INTEGER, content_preview TEXT" +) + + +def format_workload_sql(payload: dict) -> str: + lines = [ + f"CREATE TABLE IF NOT EXISTS bench_serve_workload ({_WORKLOAD_SQL_SCHEMA});", + ] + for record in payload.get("results") or []: + row = _workload_record_to_row(record) + values = ", ".join(_sql_escape(row[col]) for col in WORKLOAD_RESULT_COLUMNS) + lines.append(f"INSERT INTO bench_serve_workload VALUES ({values});") + return "\n".join(lines) + + +def write_workload_sqlite(payload: dict, output_path: str) -> None: + rows = [_workload_record_to_row(record) for record in payload.get("results") or []] + _write_sqlite_rows( + output_path, + table="bench_serve_workload", + schema=_WORKLOAD_SQL_SCHEMA, + columns=WORKLOAD_RESULT_COLUMNS, + rows=rows, + ) + + +def format_workload_payload(payload: dict, fmt: str = "json") -> str: + if fmt == "json": + return format_workload_json(payload) + if fmt == "csv": + return format_workload_csv(payload) + if fmt == "sql": + return format_workload_sql(payload) + if fmt == "table": + return format_workload_table(payload) + raise ValueError(f"Unsupported workload output format: {fmt}") + + # --------------------------------------------------------------------------- # Task 7: Main async orchestrator # --------------------------------------------------------------------------- @@ -1004,7 +1830,7 @@ async def run_bench_serve( output_path: File path to write results to. 
If ``None``, prints to stdout. fmt: Output format — one of ``"table"``, ``"json"``, ``"csv"``, - ``"sql"``. + ``"sql"``, or ``"sqlite"``. do_validate: Whether to validate each response. scrape: Whether to scrape ``/metrics`` before and after each run. tag: Optional tag string stored in every result row. @@ -1358,6 +2184,13 @@ def _mean(key: str) -> float: results.append(result_obj) # 12. Format output + if fmt == "sqlite": + if not output_path: + raise ValueError("--output is required when --format sqlite") + write_sqlite(results, output_path) + print(f"\nSQLite results written to {output_path}") + return results + formatters = { "table": format_table, "json": format_json, diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index d7b6fd6e9..af1a94e9a 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -696,7 +696,29 @@ def bench_kv_cache_command(args): def bench_serve_command(args): """Run serving benchmark.""" import asyncio - from .bench_serve import run_bench_serve + + from .bench_serve import run_bench_serve, run_bench_serve_workload + + if args.workload: + request_timeout_s = ( + None if args.request_timeout_s <= 0 else args.request_timeout_s + ) + output_format = "json" if args.format == "auto" else args.format + asyncio.run( + run_bench_serve_workload( + url=args.url, + workload_path=args.workload, + model=args.model, + output_path=args.output, + output_format=output_format, + scrape=args.scrape_metrics == "true", + include_content=args.include_content, + request_timeout_s=request_timeout_s, + repetitions=args.repetitions, + cache_policy=args.cache_policy, + ) + ) + return prompt_sets = args.prompts.split(",") concurrencies = [int(c) for c in args.concurrency.split(",")] @@ -730,6 +752,7 @@ def bench_serve_command(args): k, v = kv.split("=", 1) overrides[k] = v + output_format = "table" if args.format == "auto" else args.format asyncio.run( run_bench_serve( url=args.url, @@ -743,7 +766,7 @@ def bench_serve_command(args): thinking_values=thinking_values, extra_bodies=extra_bodies, output_path=args.output, - fmt=args.format, + fmt=output_format, do_validate=args.validate == "true", scrape=args.scrape_metrics == "true", tag=args.tag, @@ -1342,6 +1365,16 @@ def create_parser() -> argparse.ArgumentParser: default=None, help="Model ID to benchmark (default: auto-detected from server)", ) + bench_serve_parser.add_argument( + "--workload", + type=str, + default=None, + help=( + "Path to a declarative workload JSON file. When set, bench-serve " + "runs contract-style cases with per-case quality checks and " + "comparison-only policy timeouts instead of the prompt sweep." 
+ ), + ) bench_serve_parser.add_argument( "--prompts", type=str, @@ -1393,7 +1426,7 @@ def create_parser() -> argparse.ArgumentParser: "--repetitions", type=int, default=3, - help="Number of repetitions per sweep configuration (default: 3)", + help="Number of repetitions per sweep configuration or workload case (default: 3)", ) bench_serve_parser.add_argument( "--warmup", @@ -1422,9 +1455,12 @@ def create_parser() -> argparse.ArgumentParser: bench_serve_parser.add_argument( "--format", type=str, - default="table", - choices=["table", "json", "csv", "sql"], - help="Output format (default: table)", + default="auto", + choices=["auto", "table", "json", "csv", "sql", "sqlite"], + help=( + "Output format (auto = table for prompt sweeps, json for workloads; " + "sqlite requires --output)" + ), ) bench_serve_parser.add_argument( "--validate", @@ -1440,6 +1476,31 @@ def create_parser() -> argparse.ArgumentParser: choices=["true", "false"], help="Scrape /metrics before and after each run (default: true)", ) + bench_serve_parser.add_argument( + "--include-content", + action="store_true", + help="Include full generated content in workload JSON output", + ) + bench_serve_parser.add_argument( + "--request-timeout-s", + type=float, + default=300.0, + help=( + "HTTP transport timeout for workload mode in seconds (default: 300). " + "Use 0 to disable; product policy timeouts belong in the workload." + ), + ) + bench_serve_parser.add_argument( + "--cache-policy", + type=str, + default=None, + choices=["preserve", "before-run", "before-case"], + help=( + "Workload cache handling (default: workload defaults or preserve). " + "Use before-case for cold, uncontaminated per-case qualification. " + "Workload JSON may also spell these with underscores." + ), + ) bench_serve_parser.add_argument( "--tag", type=str,