Skip to content

feat: add --compile flag for mx.compile model optimization#270

Open
jackneil wants to merge 4 commits intowaybarrios:mainfrom
jackneil:pr/mx-compile
Open

feat: add --compile flag for mx.compile model optimization#270
jackneil wants to merge 4 commits intowaybarrios:mainfrom
jackneil:pr/mx-compile

Conversation

@jackneil
Copy link
Copy Markdown
Contributor

@jackneil jackneil commented Apr 9, 2026

Summary

  • Adds --compile flag to vllm-mlx serve that wraps model forward pass with mx.compile(shapeless=True)
  • Fuses elementwise Metal kernels, reducing kernel launch overhead
  • Off by default, opt-in, zero impact when not used
  • Adds vllm-mlx bench-compile <model> A/B benchmark command with warmup, cache clearing, and statistical reporting
  • Benchmarked: +33% prefill speed, no decode change (decode is memory-bandwidth-bound on Apple Silicon)

Usage

# Serve with compiled model
vllm-mlx serve mlx-community/Qwen3-8B-4bit --compile

# A/B benchmark
vllm-mlx bench-compile mlx-community/Qwen3-8B-4bit --prompts 5

Test plan

  • 5 unit tests for compile wrapper
  • Integration test with real model
  • Benchmarked on Qwen3-0.6B and Qwen3.5-35B

🤖 Generated with Claude Code

Jack Neil and others added 4 commits April 9, 2026 14:26
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Replace temp/temperature kwarg with sampler for greedy decoding
- Use mx.clear_cache() instead of deprecated mx.metal.clear_cache()

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@Thump604
Copy link
Copy Markdown
Collaborator

Blocking issue: the current compile wrapper mutates model.__call__ on the instance, but Python does not use an instance attribute for obj(...) special-method dispatch.

Minimal repro:

class C:
    def __call__(self, x):
        return ('orig', x)

c = C()
def alt(x):
    return ('alt', x)

c.__call__ = alt
assert c.__call__(2) == ('alt', 2)
assert c(3) == ('orig', 3)  # still uses type(c).__call__

So apply_compile() can set the compiled flag and log success while the normal inference path still executes the original uncompiled __call__.

I think this needs a different implementation before merge: either wrap the model in a callable proxy used by the engine, or compile the actual function object the engine invokes at call sites rather than writing to the instance __call__ attribute.

Copy link
Copy Markdown
Collaborator

@Thump604 Thump604 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation is not safe to merge yet.

apply_compile() mutates model.__call__ on the instance, but Python does not use an instance attribute for obj(...) special-method dispatch. In other words, model.__call__ = compiled_call does not make model(...) use the compiled function.

Minimal repro:

class C:
    def __call__(self, x):
        return ('orig', x)

c = C()
def alt(x):
    return ('alt', x)

c.__call__ = alt
assert c.__call__(2) == ('alt', 2)
assert c(3) == ('orig', 3)

So this PR can set the compiled flag and log success while the normal inference path still executes the original uncompiled __call__.

I think this needs a different implementation before merge: either wrap the model in a callable proxy used by the engine, or compile the actual function object the engine invokes at call sites rather than writing to the instance __call__ attribute.

Copy link
Copy Markdown
Collaborator

@janhilgard janhilgard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreeing with Thump604's blocking review -- the __call__ dispatch issue is fundamental. To elaborate:

The core bug Thump604 identified is real and critical. Python resolves obj(args) via type(obj).__call__, not obj.__dict__["__call__"]. So model.__call__ = mx.compile(original_call, shapeless=True) sets an instance attribute that is only reached via model.__call__(x) (explicit attribute access), NOT via model(x) (which is how MLX generation code invokes the model). The compiled function is never actually called during inference.

The unit tests pass because apply_compile returns the same model object, and is_compiled checks the flag (which IS set), but the actual forward pass is unchanged. The test_compile_wraps_model test calls compiled_model(x) which goes through the unmodified class __call__, so mx.allclose passes trivially (same function, same output).

To fix this properly, the standard pattern in MLX is:

model.__class__.__call__ = mx.compile(model.__class__.__call__, shapeless=True)

But this mutates the class itself, affecting all instances. The safer approach would be to create a wrapper class or use mx.compile at the nn.Module level if MLX supports it.

Additional concerns beyond Thump604's review:

  1. bench-compile uses mlx_lm.stream_generate directly rather than going through the vllm-mlx engine. This means the benchmark does not measure the actual serving path (batched engine, continuous batching, etc.), so the numbers may not reflect real-world impact.

  2. server.py simple engine path does apply_compile(inner_model) where inner_model = getattr(_engine._model, "model", None). This reaches into private internals of MLXLanguageModel and will break if the attribute name changes. Better to expose a public method on the engine.

  3. No interaction testing with MTP, KV cache quantization, or MLLM. mx.compile(shapeless=True) may not be compatible with the dynamic shapes in BatchKVCache, MTP speculation, or vision preprocessing. This needs validation before marking as experimental.

  4. Branch has merge conflicts with current main.

@Thump604
Copy link
Copy Markdown
Collaborator

Hi @jackneil -- same question on this one. Review feedback is open and it has merge conflicts. Are you still planning to address the feedback and rebase? Will check back in two weeks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants