feat: add --compile flag for mx.compile model optimization #270
jackneil wants to merge 4 commits into waybarrios:main
Conversation
- Replace temp/temperature kwarg with a sampler for greedy decoding
- Use mx.clear_cache() instead of the deprecated mx.metal.clear_cache()

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Thump604 left a comment:
The current implementation is not safe to merge yet.
`apply_compile()` mutates `model.__call__` on the instance, but Python does not use an instance attribute for `obj(...)` special-method dispatch. In other words, `model.__call__ = compiled_call` does not make `model(...)` use the compiled function.
Minimal repro:

```python
class C:
    def __call__(self, x):
        return ('orig', x)

c = C()

def alt(x):
    return ('alt', x)

c.__call__ = alt
assert c.__call__(2) == ('alt', 2)
assert c(3) == ('orig', 3)  # still uses type(c).__call__
```

So this PR can set the compiled flag and log success while the normal inference path still executes the original uncompiled `__call__`.
I think this needs a different implementation before merge: either wrap the model in a callable proxy used by the engine, or compile the actual function object the engine invokes at call sites rather than writing to the instance `__call__` attribute.
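For reference, a minimal sketch of the callable-proxy option; `CompiledModelProxy` is an illustrative name, not the PR's API:

```python
import mlx.core as mx

class CompiledModelProxy:
    """Callable proxy around a model (illustrative sketch, not the PR's code).

    obj(...) dispatches through type(obj).__call__, so defining __call__ on
    the proxy *class* guarantees inference actually hits the compiled path.
    """

    def __init__(self, model):
        self._model = model
        # Close over the model so only arrays cross the compile boundary;
        # shapeless=True lets the traced graph accept varying input shapes.
        self._compiled = mx.compile(lambda *args: model(*args), shapeless=True)

    def __call__(self, *args):
        return self._compiled(*args)

    def __getattr__(self, name):
        # Parameters, config, etc. still resolve on the wrapped model.
        return getattr(self._model, name)
```

The engine would hold and call the proxy instead of the raw model; other attribute access still reaches the wrapped module through `__getattr__`.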
janhilgard left a comment:
Agreeing with Thump604's blocking review -- the `__call__` dispatch issue is fundamental. To elaborate:
The core bug Thump604 identified is real and critical. Python resolves `obj(args)` via `type(obj).__call__`, not `obj.__dict__["__call__"]`. So `model.__call__ = mx.compile(original_call, shapeless=True)` sets an instance attribute that is only reached via `model.__call__(x)` (explicit attribute access), NOT via `model(x)` (which is how MLX generation code invokes the model). The compiled function is never actually called during inference.
The unit tests pass because `apply_compile` returns the same model object and `is_compiled` checks the flag (which IS set), but the actual forward pass is unchanged. The `test_compile_wraps_model` test calls `compiled_model(x)`, which goes through the unmodified class `__call__`, so `mx.allclose` passes trivially (same function, same output).
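A test along these lines would have caught it (names are illustrative, not from the PR's suite); it fails under the instance-attribute patch because the original `__call__` still runs:

```python
def test_compile_affects_call_dispatch():
    calls = []

    class Toy:
        def __call__(self, x):
            calls.append("orig")
            return x

    m = Toy()

    def fake_compiled(x):
        calls.append("compiled")
        return x

    # What apply_compile effectively does today: set an instance attribute.
    m.__call__ = fake_compiled
    m(1)  # engine-style invocation, dispatches via type(m).__call__
    assert calls == ["compiled"], f"dispatch still hits the original: {calls}"
```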
To fix this properly, the standard pattern in MLX is:

```python
model.__class__.__call__ = mx.compile(model.__class__.__call__, shapeless=True)
```

But this mutates the class itself, affecting all instances. The safer approach would be to create a wrapper class or use `mx.compile` at the `nn.Module` level if MLX supports it. A per-instance variant is sketched below.
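A minimal sketch of that per-instance route, assuming an illustrative helper `compile_instance` (not part of the PR):

```python
import mlx.core as mx

def compile_instance(model):
    """Compile the forward pass of one instance only (illustrative sketch)."""
    base = type(model)
    # Close over the instance; shapeless=True permits varying input shapes.
    compiled = mx.compile(lambda *args: base.__call__(model, *args),
                          shapeless=True)
    # A one-off subclass whose __call__ routes to the compiled function.
    # Reassigning __class__ retargets just this object; other instances of
    # `base` keep the original, uncompiled __call__.
    patched = type(f"Compiled{base.__name__}", (base,),
                   {"__call__": lambda self, *args: compiled(*args)})
    model.__class__ = patched
    return model
```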
Additional concerns beyond Thump604's review:

- `bench-compile` uses `mlx_lm.stream_generate` directly rather than going through the vllm-mlx engine. This means the benchmark does not measure the actual serving path (batched engine, continuous batching, etc.), so the numbers may not reflect real-world impact.
- The `server.py` simple engine path does `apply_compile(inner_model)` where `inner_model = getattr(_engine._model, "model", None)`. This reaches into private internals of `MLXLanguageModel` and will break if the attribute name changes. Better to expose a public method on the engine (see the sketch after this list).
- No interaction testing with MTP, KV cache quantization, or MLLM. `mx.compile(shapeless=True)` may not be compatible with the dynamic shapes in BatchKVCache, MTP speculation, or vision preprocessing. This needs validation before marking the feature experimental.
- Branch has merge conflicts with current main.
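A hedged sketch of such a public hook; the method name `enable_compile` and the internal attribute layout are assumptions, not the actual vllm-mlx API:

```python
import mlx.core as mx

class MLXEngine:
    """Illustrative excerpt; the real vllm-mlx engine internals may differ."""

    def __init__(self, model):
        self._model = model        # private; external code shouldn't reach in
        self._forward = model      # the callable the inference loop invokes

    def enable_compile(self):
        """Public hook replacing apply_compile(getattr(_engine._model, ...))."""
        model = self._model
        # Compile a closure over the model; the engine's own call sites use
        # self._forward, so the compiled graph is guaranteed to be exercised.
        self._forward = mx.compile(lambda *args: model(*args), shapeless=True)
```

Call sites inside the engine would invoke `self._forward(...)`, making compilation a supported one-call operation for `server.py` instead of a reach into private attributes.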
Hi @jackneil -- same question on this one. Review feedback is still open and the branch has merge conflicts. Are you still planning to address the feedback and rebase? Will check back in two weeks.
Summary

- Adds a `--compile` flag to `vllm-mlx serve` that wraps the model forward pass with `mx.compile(shapeless=True)`
- Adds a `vllm-mlx bench-compile <model>` A/B benchmark command with warmup, cache clearing, and statistical reporting

Usage
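Presumably along these lines, reconstructed from the summary above (exact argument syntax is an assumption):

```
vllm-mlx serve <model> --compile
vllm-mlx bench-compile <model>
```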
Test plan
🤖 Generated with Claude Code