11 changes: 11 additions & 0 deletions docs/reference/cli.md
@@ -44,6 +44,7 @@ vllm-mlx serve <model> [options]
| `--embedding-model` | Pre-load an embedding model at startup | None |
| `--enable-auto-tool-choice` | Enable automatic tool calling | False |
| `--tool-call-parser` | Tool call parser (`auto`, `mistral`, `qwen`, `llama`, `hermes`, `deepseek`, `kimi`, `granite`, `nemotron`, `xlam`, `functionary`, `glm47`) | None |
| `--compile` | Compile model forward pass with mx.compile for fused Metal kernels (experimental) | False |

### Examples

@@ -96,6 +97,16 @@ vllm-mlx serve mlx-community/Qwen3-4B-4bit \
  --continuous-batching
```

### Performance: mx.compile

```bash
# Serve with compiled model (may improve throughput 5-30%)
vllm-mlx serve mlx-community/Qwen3-8B-4bit --compile

# A/B benchmark to measure the real impact
vllm-mlx bench-compile mlx-community/Qwen3-8B-4bit --prompts 5 --max-tokens 128
```
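
The flag is backed by the `apply_compile` / `is_compiled` helpers in `vllm_mlx.compile` (usage inferred from this PR's tests); a minimal programmatic sketch, assuming an `mlx_lm`-loaded model:

```python
from mlx_lm import load

from vllm_mlx.compile import apply_compile, is_compiled

model, tokenizer = load("mlx-community/Qwen3-8B-4bit")
apply_compile(model)  # in-place and idempotent; enabled=False returns the model untouched
assert is_compiled(model)
```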

### Security

When `--api-key` is set, all API requests require the `Authorization: Bearer <api-key>` header:
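
```bash
# Illustrative request; the host and port depend on your serve options.
curl -H "Authorization: Bearer my-secret-key" http://localhost:8000/v1/models
```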
57 changes: 57 additions & 0 deletions tests/test_bench_compile.py
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
"""Integration test for mx.compile wrapper with mlx_lm models."""

import os

import pytest

try:
    import mlx.core as mx
    from mlx_lm import load

    HAS_MLX_LM = True
except ImportError:
    HAS_MLX_LM = False

from vllm_mlx.compile import apply_compile, is_compiled


@pytest.mark.skipif(not HAS_MLX_LM, reason="mlx_lm not available")
class TestCompileWithRealModel:
    @pytest.fixture
    def model_and_tokenizer(self):
        model_name = os.environ.get(
            "VLLM_MLX_TEST_MODEL", "mlx-community/Qwen3-0.6B-8bit"
        )
        try:
            model, tokenizer = load(model_name)
            return model, tokenizer
        except Exception:
            pytest.skip(f"Model {model_name} not available")

    def test_compile_real_model_produces_output(self, model_and_tokenizer):
        from mlx_lm import stream_generate

        model, tokenizer = model_and_tokenizer

        baseline_tokens = []
        for resp in stream_generate(
            model, tokenizer, "Hello", max_tokens=5,
            sampler=lambda logits: mx.argmax(logits, axis=-1),  # greedy, matching the sampler API used in cli.py
        ):
            baseline_tokens.append(resp.token)
            if resp.finish_reason:
                break

        apply_compile(model)
        assert is_compiled(model)

        compiled_tokens = []
        for resp in stream_generate(
            model, tokenizer, "Hello", max_tokens=5,
            sampler=lambda logits: mx.argmax(logits, axis=-1),  # greedy, matching the sampler API used in cli.py
        ):
            compiled_tokens.append(resp.token)
            if resp.finish_reason:
                break

        assert len(baseline_tokens) > 0
        assert len(compiled_tokens) > 0
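
The fixture reads the `VLLM_MLX_TEST_MODEL` environment variable, so this integration test can be pointed at any local checkpoint; for example (hypothetical invocation):

```bash
VLLM_MLX_TEST_MODEL=mlx-community/Qwen3-0.6B-8bit pytest tests/test_bench_compile.py -v
```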
60 changes: 60 additions & 0 deletions tests/test_compile.py
@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for mx.compile wrapper utility."""

import mlx.core as mx
import mlx.nn as nn
import pytest

from vllm_mlx.compile import apply_compile, is_compiled


class SimpleModel(nn.Module):
    """Tiny model for testing compilation."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def __call__(self, x):
        return self.linear(x)


class TestApplyCompile:
    def test_compile_wraps_model(self):
        """apply_compile returns a model whose __call__ is compiled."""
        model = SimpleModel()
        compiled_model = apply_compile(model)
        x = mx.ones((1, 8))
        original_out = model(x)
        compiled_out = compiled_model(x)
        assert mx.allclose(original_out, compiled_out).item()

    def test_compile_is_idempotent(self):
        """Applying compile twice doesn't double-wrap."""
        model = SimpleModel()
        compiled = apply_compile(model)
        double_compiled = apply_compile(compiled)
        assert compiled is double_compiled

    def test_is_compiled_flag(self):
        """is_compiled returns correct state."""
        model = SimpleModel()
        assert is_compiled(model) is False
        compiled = apply_compile(model)
        assert is_compiled(compiled) is True

    def test_no_compile_returns_original(self):
        """When compile=False, return model unchanged."""
        model = SimpleModel()
        result = apply_compile(model, enabled=False)
        assert result is model
        assert is_compiled(result) is False

    def test_compiled_model_handles_different_shapes(self):
        """shapeless=True means different input shapes don't crash."""
        model = SimpleModel()
        compiled = apply_compile(model)
        out1 = compiled(mx.ones((1, 8)))
        out2 = compiled(mx.ones((4, 8)))
        assert out1.shape == (1, 8)
        assert out2.shape == (4, 8)
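
The `vllm_mlx/compile.py` module under test is not shown in this diff. For context, the MLX primitive it builds on is `mx.compile`; below is a minimal sketch of the standard MLX pattern for compiling a module's forward pass, reusing `SimpleModel` from the tests above (the actual `apply_compile` wiring may differ):

```python
from functools import partial

import mlx.core as mx

model = SimpleModel()
mx.eval(model.parameters())  # materialize weights before tracing

# Passing model.state as inputs/outputs lets the compiled graph track the
# module's arrays across calls; shapeless=True avoids retracing when the
# batch or sequence dimension changes.
@partial(mx.compile, inputs=model.state, outputs=model.state, shapeless=True)
def forward(x):
    return model(x)

out = forward(mx.ones((4, 8)))  # first call triggers Metal kernel compilation
```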
189 changes: 189 additions & 0 deletions vllm_mlx/cli.py
@@ -189,6 +189,9 @@ def serve_command(args):
f"keep={args.specprefill_keep_pct*100:.0f}%)"
)

if args.compile:
print(" Compile: ENABLED (mx.compile with fused kernels)")

# Load model with unified server
load_model(
args.model,
@@ -204,6 +207,7 @@
        specprefill_threshold=args.specprefill_threshold,
        specprefill_keep_pct=args.specprefill_keep_pct,
        specprefill_draft_model=args.specprefill_draft_model,
        compile=args.compile,
    )

    # Start server
@@ -332,6 +336,155 @@ async def get_output(rid):
    asyncio.run(run_benchmark())


def bench_compile_command(args):
    """A/B benchmark: measure tok/s with and without mx.compile."""
    import statistics

    import mlx.core as mx
    from mlx_lm import load, stream_generate

    from .compile import apply_compile

    print(f"Model: {args.model}")
    print(f"Runs: {args.prompts}, Max tokens: {args.max_tokens}")
    print()

    # Build a prompt of approximately the right length
    base_prompt = "Explain the theory of " + " ".join(
        ["quantum"] * (args.prompt_tokens // 2)
    )

    def run_measurement(model, tokenizer, label, num_runs):
        results = []
        for i in range(num_runs):
            mx.clear_cache()

            prompt_tps = 0.0
            gen_tps = 0.0
            gen_tokens = 0

            for response in stream_generate(
                model,
                tokenizer,
                base_prompt,
                max_tokens=args.max_tokens,
                sampler=lambda logits: mx.argmax(logits, axis=-1),  # Greedy for determinism
            ):
                if response.finish_reason is not None:
                    prompt_tps = response.prompt_tps
                    gen_tps = response.generation_tps
                    gen_tokens = response.generation_tokens
                    break

            results.append(
                {
                    "prompt_tps": prompt_tps,
                    "gen_tps": gen_tps,
                    "gen_tokens": gen_tokens,
                }
            )
            print(
                f" {label} run {i+1}/{num_runs}: "
                f"prefill={prompt_tps:.1f} tok/s, decode={gen_tps:.1f} tok/s"
            )

        return results

    # Phase 1: Without compile
    print("Loading model (without compile)...")
    model, tokenizer = load(args.model)

    print("Warmup run (discarded)...")
    for _ in stream_generate(model, tokenizer, "Hello", max_tokens=10):
        pass
    mx.clear_cache()

    print(f"\nMeasuring WITHOUT compile ({args.prompts} runs)...")
    baseline_results = run_measurement(model, tokenizer, "baseline", args.prompts)

    # Phase 2: With compile
    print("\nApplying mx.compile to model...")
    apply_compile(model)

    print("Warmup run (triggers compilation, discarded)...")
    for _ in stream_generate(model, tokenizer, "Hello", max_tokens=10):
        pass
    mx.clear_cache()

    print(f"\nMeasuring WITH compile ({args.prompts} runs)...")
    compiled_results = run_measurement(model, tokenizer, "compiled", args.prompts)

    # Report
    def stats(results, key):
        values = [r[key] for r in results]
        return {
            "mean": statistics.mean(values),
            "std": statistics.stdev(values) if len(values) > 1 else 0,
            "values": values,
        }

    b_prefill = stats(baseline_results, "prompt_tps")
    c_prefill = stats(compiled_results, "prompt_tps")
    b_decode = stats(baseline_results, "gen_tps")
    c_decode = stats(compiled_results, "gen_tps")

    prefill_pct = (
        ((c_prefill["mean"] - b_prefill["mean"]) / b_prefill["mean"]) * 100
        if b_prefill["mean"] > 0
        else 0
    )
    decode_pct = (
        ((c_decode["mean"] - b_decode["mean"]) / b_decode["mean"]) * 100
        if b_decode["mean"] > 0
        else 0
    )

    print("\n" + "=" * 60)
    print(f"Model: {args.model}")
    print(f"Runs: {args.prompts}, Max tokens: {args.max_tokens}")
    print("=" * 60)
    print(
        f"{'':20s} {'Without compile':>18s} {'With compile':>18s} {'Change':>10s}"
    )
    print("-" * 68)
    print(
        f"{'Prefill (tok/s)':20s} {b_prefill['mean']:>18.1f} "
        f"{c_prefill['mean']:>18.1f} {prefill_pct:>+9.1f}%"
    )
    print(
        f"{'Decode (tok/s)':20s} {b_decode['mean']:>18.1f} "
        f"{c_decode['mean']:>18.1f} {decode_pct:>+9.1f}%"
    )
    print()
    print(
        f"Decode per-run (without): "
        f"{', '.join(f'{v:.1f}' for v in b_decode['values'])} "
        f"(std={b_decode['std']:.2f})"
    )
    print(
        f"Decode per-run (with): "
        f"{', '.join(f'{v:.1f}' for v in c_decode['values'])} "
        f"(std={c_decode['std']:.2f})"
    )
    print()

    if decode_pct > 2:
        print(
            f"VERDICT: mx.compile gives +{decode_pct:.1f}% decode speed. "
            f"Use --compile."
        )
    elif decode_pct < -2:
        print(
            f"VERDICT: mx.compile is {decode_pct:.1f}% SLOWER. "
            f"Do NOT use --compile for this model."
        )
    else:
        print(
            f"VERDICT: mx.compile has no significant effect ({decode_pct:+.1f}%). "
            f"Skip --compile for this model."
        )


def bench_detok_command(args):
"""Benchmark streaming detokenizer optimization."""
import statistics
@@ -730,6 +883,14 @@ def main():
help="Max prefill tokens per scheduler step (0=disabled). "
"Prevents starvation of active requests during long prefills.",
)
# Compile (mx.compile)
serve_parser.add_argument(
"--compile",
action="store_true",
default=False,
help="Compile model forward pass with mx.compile for fused Metal kernels. "
"May improve throughput 5-30%%. Experimental.",
)
# MTP (Multi-Token Prediction)
serve_parser.add_argument(
"--enable-mtp",
@@ -1023,6 +1184,32 @@ def main():
help="Quantization group size (default: 64)",
)

# Compile A/B benchmark
bench_compile = subparsers.add_parser(
"bench-compile",
help="A/B benchmark: measure tok/s with and without mx.compile",
)
bench_compile.add_argument("model", help="Model name or path")
bench_compile.add_argument(
"--prompts",
type=int,
default=5,
help="Number of measurement runs (default: 5)",
)
bench_compile.add_argument(
"--max-tokens",
type=int,
default=128,
help="Tokens to generate per run (default: 128)",
)
bench_compile.add_argument(
"--prompt-tokens",
type=int,
default=256,
help="Approximate prompt length (default: 256)",
)
bench_compile.set_defaults(func=bench_compile_command)

args = parser.parse_args()

if args.command == "serve":
@@ -1033,6 +1220,8 @@ def main():
        bench_detok_command(args)
    elif args.command == "bench-kv-cache":
        bench_kv_cache_command(args)
    elif args.command == "bench-compile":
        bench_compile_command(args)
    else:
        parser.print_help()
        sys.exit(1)