diff --git a/docs/reference/cli.md b/docs/reference/cli.md
index 2ba8b75e1..8dba28539 100644
--- a/docs/reference/cli.md
+++ b/docs/reference/cli.md
@@ -44,6 +44,7 @@ vllm-mlx serve [options]
 | `--embedding-model` | Pre-load an embedding model at startup | None |
 | `--enable-auto-tool-choice` | Enable automatic tool calling | False |
 | `--tool-call-parser` | Tool call parser (`auto`, `mistral`, `qwen`, `llama`, `hermes`, `deepseek`, `kimi`, `granite`, `nemotron`, `xlam`, `functionary`, `glm47`) | None |
+| `--compile` | Compile model forward pass with mx.compile for fused Metal kernels (experimental) | False |
 
 ### Examples
 
@@ -96,6 +97,16 @@ vllm-mlx serve mlx-community/Qwen3-4B-4bit \
   --continuous-batching
 ```
 
+### Performance: mx.compile
+
+```bash
+# Serve with compiled model (may improve throughput 5-30%)
+vllm-mlx serve mlx-community/Qwen3-8B-4bit --compile
+
+# A/B benchmark to measure the real impact
+vllm-mlx bench-compile mlx-community/Qwen3-8B-4bit --prompts 5 --max-tokens 128
+```
+
 ### Security
 
 When `--api-key` is set, all API requests require the `Authorization: Bearer ` header:
diff --git a/tests/test_bench_compile.py b/tests/test_bench_compile.py
new file mode 100644
index 000000000..3c7db05af
--- /dev/null
+++ b/tests/test_bench_compile.py
@@ -0,0 +1,57 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Integration test for mx.compile wrapper with mlx_lm models."""
+
+import os
+
+import pytest
+
+try:
+    import mlx.core as mx
+    from mlx_lm import load
+
+    HAS_MLX_LM = True
+except ImportError:
+    HAS_MLX_LM = False
+
+from vllm_mlx.compile import apply_compile, is_compiled
+
+
+@pytest.mark.skipif(not HAS_MLX_LM, reason="mlx_lm not available")
+class TestCompileWithRealModel:
+    @pytest.fixture
+    def model_and_tokenizer(self):
+        model_name = os.environ.get(
+            "VLLM_MLX_TEST_MODEL", "mlx-community/Qwen3-0.6B-8bit"
+        )
+        try:
+            model, tokenizer = load(model_name)
+            return model, tokenizer
+        except Exception:
+            pytest.skip(f"Model {model_name} not available")
+
+    def test_compile_real_model_produces_output(self, model_and_tokenizer):
+        from mlx_lm import stream_generate
+
+        model, tokenizer = model_and_tokenizer
+
+        baseline_tokens = []
+        for resp in stream_generate(
+            model, tokenizer, "Hello", max_tokens=5, temp=0.0
+        ):
+            baseline_tokens.append(resp.token)
+            if resp.finish_reason:
+                break
+
+        apply_compile(model)
+        assert is_compiled(model)
+
+        compiled_tokens = []
+        for resp in stream_generate(
+            model, tokenizer, "Hello", max_tokens=5, temp=0.0
+        ):
+            compiled_tokens.append(resp.token)
+            if resp.finish_reason:
+                break
+
+        assert len(baseline_tokens) > 0
+        assert len(compiled_tokens) > 0
diff --git a/tests/test_compile.py b/tests/test_compile.py
new file mode 100644
index 000000000..4596da436
--- /dev/null
+++ b/tests/test_compile.py
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for mx.compile wrapper utility."""
+
+import mlx.core as mx
+import mlx.nn as nn
+import pytest
+
+from vllm_mlx.compile import apply_compile, is_compiled
+
+
+class SimpleModel(nn.Module):
+    """Tiny model for testing compilation."""
+
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(8, 8)
+
+    def __call__(self, x):
+        return self.linear(x)
+
+
+class TestApplyCompile:
+    def test_compile_wraps_model(self):
+        """apply_compile returns a model whose __call__ is compiled."""
+        model = SimpleModel()
+        compiled_model = apply_compile(model)
+        x = mx.ones((1, 8))
+        original_out = model(x)
+        compiled_out = compiled_model(x)
+        assert mx.allclose(original_out, compiled_out).item()
+
+    def test_compile_is_idempotent(self):
+        """Applying compile twice doesn't double-wrap."""
+        model = SimpleModel()
+        compiled = apply_compile(model)
+        double_compiled = apply_compile(compiled)
+        assert compiled is double_compiled
+
+    def test_is_compiled_flag(self):
+        """is_compiled returns correct state."""
+        model = SimpleModel()
+        assert is_compiled(model) is False
+        compiled = apply_compile(model)
+        assert is_compiled(compiled) is True
+
+    def test_no_compile_returns_original(self):
+        """When compile=False, return model unchanged."""
+        model = SimpleModel()
+        result = apply_compile(model, enabled=False)
+        assert result is model
+        assert is_compiled(result) is False
+
+    def test_compiled_model_handles_different_shapes(self):
+        """shapeless=True means different input shapes don't crash."""
+        model = SimpleModel()
+        compiled = apply_compile(model)
+        out1 = compiled(mx.ones((1, 8)))
+        out2 = compiled(mx.ones((4, 8)))
+        assert out1.shape == (1, 8)
+        assert out2.shape == (4, 8)
diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py
index 8a90bc9be..81cad4d65 100644
--- a/vllm_mlx/cli.py
+++ b/vllm_mlx/cli.py
@@ -189,6 +189,9 @@ def serve_command(args):
             f"keep={args.specprefill_keep_pct*100:.0f}%)"
         )
 
+    if args.compile:
+        print(" Compile: ENABLED (mx.compile with fused kernels)")
+
     # Load model with unified server
     load_model(
         args.model,
@@ -204,6 +207,7 @@
         specprefill_threshold=args.specprefill_threshold,
         specprefill_keep_pct=args.specprefill_keep_pct,
         specprefill_draft_model=args.specprefill_draft_model,
+        compile=args.compile,
     )
 
     # Start server
@@ -332,6 +336,155 @@ async def get_output(rid):
     asyncio.run(run_benchmark())
 
 
+def bench_compile_command(args):
+    """A/B benchmark: measure tok/s with and without mx.compile."""
+    import statistics
+
+    import mlx.core as mx
+    from mlx_lm import load, stream_generate
+
+    from .compile import apply_compile
+
+    print(f"Model: {args.model}")
+    print(f"Runs: {args.prompts}, Max tokens: {args.max_tokens}")
+    print()
+
+    # Build a prompt of approximately the right length
+    base_prompt = "Explain the theory of " + " ".join(
+        ["quantum"] * (args.prompt_tokens // 2)
+    )
+
+    def run_measurement(model, tokenizer, label, num_runs):
+        results = []
+        for i in range(num_runs):
+            mx.clear_cache()
+
+            prompt_tps = 0.0
+            gen_tps = 0.0
+            gen_tokens = 0
+
+            for response in stream_generate(
+                model,
+                tokenizer,
+                base_prompt,
+                max_tokens=args.max_tokens,
+                sampler=lambda logits: mx.argmax(logits, axis=-1),  # Greedy for determinism
+            ):
+                if response.finish_reason is not None:
+                    prompt_tps = response.prompt_tps
+                    gen_tps = response.generation_tps
+                    gen_tokens = response.generation_tokens
+                    break
+
+            results.append(
+                {
+                    "prompt_tps": prompt_tps,
+                    "gen_tps": gen_tps,
+                    "gen_tokens": gen_tokens,
+                }
+            )
+            print(
+                f" {label} run {i+1}/{num_runs}: "
+                f"prefill={prompt_tps:.1f} tok/s, decode={gen_tps:.1f} tok/s"
+            )
+
+        return results
+
+    # Phase 1: Without compile
+    print("Loading model (without compile)...")
+    model, tokenizer = load(args.model)
+
+    print("Warmup run (discarded)...")
+    for _ in stream_generate(model, tokenizer, "Hello", max_tokens=10):
+        pass
+    mx.clear_cache()
+
+    print(f"\nMeasuring WITHOUT compile ({args.prompts} runs)...")
+    baseline_results = run_measurement(model, tokenizer, "baseline", args.prompts)
+
+    # Phase 2: With compile
+    print("\nApplying mx.compile to model...")
+    apply_compile(model)
+
+    print("Warmup run (triggers compilation, discarded)...")
+    for _ in stream_generate(model, tokenizer, "Hello", max_tokens=10):
+        pass
+    mx.clear_cache()
+
+    print(f"\nMeasuring WITH compile ({args.prompts} runs)...")
+    compiled_results = run_measurement(model, tokenizer, "compiled", args.prompts)
+
+    # Report
+    def stats(results, key):
+        values = [r[key] for r in results]
+        return {
+            "mean": statistics.mean(values),
+            "std": statistics.stdev(values) if len(values) > 1 else 0,
+            "values": values,
+        }
+
+    b_prefill = stats(baseline_results, "prompt_tps")
+    c_prefill = stats(compiled_results, "prompt_tps")
+    b_decode = stats(baseline_results, "gen_tps")
+    c_decode = stats(compiled_results, "gen_tps")
+
+    prefill_pct = (
+        ((c_prefill["mean"] - b_prefill["mean"]) / b_prefill["mean"]) * 100
+        if b_prefill["mean"] > 0
+        else 0
+    )
+    decode_pct = (
+        ((c_decode["mean"] - b_decode["mean"]) / b_decode["mean"]) * 100
+        if b_decode["mean"] > 0
+        else 0
+    )
+
+    print("\n" + "=" * 60)
+    print(f"Model: {args.model}")
+    print(f"Runs: {args.prompts}, Max tokens: {args.max_tokens}")
+    print("=" * 60)
+    print(
+        f"{'':20s} {'Without compile':>18s} {'With compile':>18s} {'Change':>10s}"
+    )
+    print("-" * 68)
+    print(
+        f"{'Prefill (tok/s)':20s} {b_prefill['mean']:>14.1f} "
+        f"{c_prefill['mean']:>14.1f} {prefill_pct:>+7.1f}%"
+    )
+    print(
+        f"{'Decode (tok/s)':20s} {b_decode['mean']:>14.1f} "
+        f"{c_decode['mean']:>14.1f} {decode_pct:>+7.1f}%"
+    )
+    print()
+    print(
+        f"Decode per-run (without): "
+        f"{', '.join(f'{v:.1f}' for v in b_decode['values'])} "
+        f"(std={b_decode['std']:.2f})"
+    )
+    print(
+        f"Decode per-run (with): "
+        f"{', '.join(f'{v:.1f}' for v in c_decode['values'])} "
+        f"(std={c_decode['std']:.2f})"
+    )
+    print()
+
+    if decode_pct > 2:
+        print(
+            f"VERDICT: mx.compile gives +{decode_pct:.1f}% decode speed. "
+            f"Use --compile."
+        )
+    elif decode_pct < -2:
+        print(
+            f"VERDICT: mx.compile is {decode_pct:.1f}% SLOWER. "
+            f"Do NOT use --compile for this model."
+        )
+    else:
+        print(
+            f"VERDICT: mx.compile has no significant effect ({decode_pct:+.1f}%). "
+            f"Skip --compile for this model."
+        )
+
+
 def bench_detok_command(args):
     """Benchmark streaming detokenizer optimization."""
     import statistics
@@ -730,6 +883,14 @@ def main():
         help="Max prefill tokens per scheduler step (0=disabled). "
         "Prevents starvation of active requests during long prefills.",
     )
+    # Compile (mx.compile)
+    serve_parser.add_argument(
+        "--compile",
+        action="store_true",
+        default=False,
+        help="Compile model forward pass with mx.compile for fused Metal kernels. "
+        "May improve throughput 5-30%%. Experimental.",
+    )
     # MTP (Multi-Token Prediction)
     serve_parser.add_argument(
         "--enable-mtp",
@@ -1023,6 +1184,32 @@
         help="Quantization group size (default: 64)",
     )
 
+    # Compile A/B benchmark
+    bench_compile = subparsers.add_parser(
+        "bench-compile",
+        help="A/B benchmark: measure tok/s with and without mx.compile",
+    )
+    bench_compile.add_argument("model", help="Model name or path")
+    bench_compile.add_argument(
+        "--prompts",
+        type=int,
+        default=5,
+        help="Number of measurement runs (default: 5)",
+    )
+    bench_compile.add_argument(
+        "--max-tokens",
+        type=int,
+        default=128,
+        help="Tokens to generate per run (default: 128)",
+    )
+    bench_compile.add_argument(
+        "--prompt-tokens",
+        type=int,
+        default=256,
+        help="Approximate prompt length (default: 256)",
+    )
+    bench_compile.set_defaults(func=bench_compile_command)
+
     args = parser.parse_args()
 
     if args.command == "serve":
@@ -1033,6 +1220,8 @@
         bench_detok_command(args)
     elif args.command == "bench-kv-cache":
         bench_kv_cache_command(args)
+    elif args.command == "bench-compile":
+        bench_compile_command(args)
     else:
         parser.print_help()
         sys.exit(1)
diff --git a/vllm_mlx/compile.py b/vllm_mlx/compile.py
new file mode 100644
index 000000000..7b5fce2f7
--- /dev/null
+++ b/vllm_mlx/compile.py
@@ -0,0 +1,51 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Compile wrapper for MLX model forward passes.
+
+Wraps a model's __call__ with mx.compile() for fused Metal kernels.
+Used when --compile flag is passed to vllm-mlx serve.
+"""
+
+import logging
+
+import mlx.core as mx
+import mlx.nn as nn
+
+logger = logging.getLogger(__name__)
+
+_COMPILED_ATTR = "_vllm_mlx_compiled"
+
+
+def apply_compile(model: nn.Module, enabled: bool = True) -> nn.Module:
+    """Wrap model's forward pass with mx.compile for fused kernels.
+
+    Args:
+        model: The MLX model to compile
+        enabled: If False, return model unchanged (for easy toggling)
+
+    Returns:
+        The model with compiled __call__, or original if disabled/already compiled
+    """
+    if not enabled:
+        return model
+
+    if is_compiled(model):
+        logger.debug("Model already compiled, skipping")
+        return model
+
+    try:
+        original_call = model.__call__
+        compiled_call = mx.compile(original_call, shapeless=True)
+        model.__call__ = compiled_call
+        setattr(model, _COMPILED_ATTR, True)
+        logger.info("Model forward pass compiled with mx.compile(shapeless=True)")
+        return model
+
+    except Exception as e:
+        logger.warning(f"mx.compile failed, using uncompiled model: {e}")
+        return model
+
+
+def is_compiled(model: nn.Module) -> bool:
+    """Check if a model has been compiled."""
+    return getattr(model, _COMPILED_ATTR, False)
diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py
index 3ac52b4b0..613d9cc28 100644
--- a/vllm_mlx/engine/batched.py
+++ b/vllm_mlx/engine/batched.py
@@ -207,6 +207,12 @@ async def _start_mllm(self) -> None:
         self._model = self._mllm_instance.model
         self._processor = self._mllm_instance.processor
 
+        if getattr(self, "_compile_on_start", False):
+            from ..compile import apply_compile
+
+            self._model = apply_compile(self._model)
+            logger.info("Compile: model forward pass compiled (batched MLLM)")
+
         # Create MLLM scheduler config with batch generator support
         if self._scheduler_config and hasattr(self._scheduler_config, "max_num_seqs"):
             max_num_seqs = self._scheduler_config.max_num_seqs
@@ -259,6 +265,12 @@ async def _start_llm(self) -> None:
             tokenizer_config=tokenizer_config,
         )
 
+        if getattr(self, "_compile_on_start", False):
+            from ..compile import apply_compile
+
+            self._model = apply_compile(self._model)
+            logger.info("Compile: model forward pass compiled (batched LLM)")
+
         # Validate MTP support if enabled
         if self._scheduler_config and self._scheduler_config.enable_mtp:
             from ..patches.qwen3_next_mtp import validate_mtp_support
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index af10e7341..f4edc218e 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -492,6 +492,7 @@ def load_model(
     specprefill_threshold: int = 8192,
     specprefill_keep_pct: float = 0.3,
     specprefill_draft_model: str = None,
+    compile: bool = False,
 ):
     """
     Load a model (auto-detects MLLM vs LLM).
@@ -529,6 +530,9 @@
             stream_interval=stream_interval,
             force_mllm=force_mllm,
         )
+        if compile:
+            _engine._compile_on_start = True
+
         # BatchedEngine will be started in lifespan (uvicorn's event loop)
         # Just log for now
         logger.info(f"Model loaded (batched mode): {model_name}")
@@ -549,6 +553,15 @@
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
         loop.run_until_complete(_engine.start())
+
+        if compile:
+            from .compile import apply_compile
+
+            inner_model = getattr(_engine._model, "model", None)
+            if inner_model is not None:
+                apply_compile(inner_model)
+                logger.info("Compile: model forward pass compiled")
+
         model_type = "MLLM" if _engine.is_mllm else "LLM"
         logger.info(f"{model_type} model loaded (simple mode): {model_name}")
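Reviewer note: a minimal standalone sketch (not part of the diff) of how the new `apply_compile` / `is_compiled` helpers are meant to be driven outside the server path, assuming `mlx_lm` is installed and the small test model from `tests/test_bench_compile.py` is available locally:

```python
# Illustrative sketch only: mirrors what `vllm-mlx serve --compile` and the new
# tests do with the helper added in vllm_mlx/compile.py.
from mlx_lm import load, stream_generate

from vllm_mlx.compile import apply_compile, is_compiled

# Small model used by the integration test; any local MLX model should work.
model, tokenizer = load("mlx-community/Qwen3-0.6B-8bit")

apply_compile(model)       # wraps the forward pass with mx.compile(shapeless=True)
assert is_compiled(model)  # a second apply_compile call is a no-op

# Generation still goes through the normal mlx_lm streaming path.
for resp in stream_generate(model, tokenizer, "Hello", max_tokens=16):
    print(resp.text, end="", flush=True)
print()
```

Whether compilation actually helps is model-dependent, which is why the PR pairs the flag with the `bench-compile` A/B command rather than enabling it by default.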