Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from types import SimpleNamespace


def _serve_args(**overrides):
args = {
"api_key": None,
"auto_unload_idle_seconds": 0.0,
"cache_memory_mb": None,
"cache_memory_percent": 0.2,
"chunked_prefill_tokens": 0,
"continuous_batching": False,
"default_min_p": None,
"default_presence_penalty": None,
"default_repetition_penalty": None,
"default_temperature": None,
"default_top_k": None,
"default_top_p": None,
"disable_prefix_cache": False,
"download_retries": 0,
"download_timeout": 1,
"embedding_model": None,
"enable_auto_tool_choice": False,
"enable_metrics": False,
"enable_mtp": False,
"enable_prefix_cache": True,
"gpu_memory_utilization": 0.9,
"host": "127.0.0.1",
"kv_cache_min_quantize_tokens": 256,
"kv_cache_quantization": False,
"kv_cache_quantization_bits": 8,
"kv_cache_quantization_group_size": 64,
"max_cache_blocks": 1000,
"max_num_seqs": 32,
"max_tokens": 16,
"mcp_config": None,
"mllm_prefill_step_size": None,
"mllm": False,
"model": "local-test-model",
"mtp_num_draft_tokens": 1,
"mtp_optimistic": False,
"no_memory_aware_cache": False,
"offline": True,
"paged_cache_block_size": 64,
"port": 8000,
"prefill_batch_size": 8,
"prefill_step_size": 512,
"prefix_cache_size": 100,
"rate_limit": 0,
"reasoning_parser": None,
"served_model_name": None,
"specprefill": False,
"specprefill_backbone_pct": 0.0,
"specprefill_draft_model": None,
"specprefill_keep_pct": 0.3,
"specprefill_threshold": 8192,
"stream_interval": 1,
"tool_call_parser": None,
"timeout": 300,
"lazy_load_model": False,
"use_paged_cache": False,
}
args.update(overrides)
return SimpleNamespace(**args)


def test_serve_command_propagates_all_sampling_defaults(monkeypatch):
    """serve_command must copy every --default-* sampling CLI value onto the
    matching ``server`` module global."""
    from vllm_mlx import cli, server
    from vllm_mlx.utils import download

    # Neutralise side effects: no real model download, no model load, and no
    # uvicorn event loop — we only want the global-propagation logic to run.
    monkeypatch.setattr(
        download, "ensure_model_downloaded", lambda *a, **kw: "local-test-model"
    )
    monkeypatch.setattr(server, "load_model", lambda *a, **kw: None)
    monkeypatch.setattr("uvicorn.run", lambda *a, **kw: None)

    expected = {
        "_default_temperature": 0.6,
        "_default_top_p": 0.95,
        "_default_top_k": 20,
        "_default_min_p": 0.0,
        "_default_presence_penalty": 0.0,
        "_default_repetition_penalty": 1.0,
    }

    # Start every global from a clean None so the assertions below prove the
    # values really came from serve_command, not from a prior test.
    for name in expected:
        monkeypatch.setattr(server, name, None)

    cli.serve_command(
        _serve_args(
            default_temperature=0.6,
            default_top_p=0.95,
            default_top_k=20,
            default_min_p=0.0,
            default_presence_penalty=0.0,
            default_repetition_penalty=1.0,
        )
    )

    for name, value in expected.items():
        assert getattr(server, name) == value
14 changes: 12 additions & 2 deletions tests/test_kv_cache_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,11 @@ class FakeModel:

def test_store_fetch_without_quantization(self):
model = self._make_cache_and_model()
config = MemoryCacheConfig(kv_quantize=False, max_memory_mb=500)
config = MemoryCacheConfig(
kv_quantize=False,
max_memory_mb=500,
min_prefix_tokens=1,
)
pc = MemoryAwarePrefixCache(model, config)

cache = _make_kv_cache(n_layers=2, seq_len=50)
Expand All @@ -187,6 +191,7 @@ def test_store_fetch_with_quantization(self):
kv_bits=8,
kv_min_quantize_tokens=0,
max_memory_mb=500,
min_prefix_tokens=1,
)
pc = MemoryAwarePrefixCache(model, config)

Expand Down Expand Up @@ -282,6 +287,7 @@ def test_store_skips_quantization_below_threshold(self):
kv_bits=8,
kv_min_quantize_tokens=256,
max_memory_mb=500,
min_prefix_tokens=1,
)
pc = MemoryAwarePrefixCache(model, config)

Expand Down Expand Up @@ -319,7 +325,11 @@ def test_store_quantizes_above_threshold(self):
def test_trim_applied_without_quantization(self):
"""Oversized arrays should be trimmed even without quantization."""
model = self._make_model()
config = MemoryCacheConfig(kv_quantize=False, max_memory_mb=500)
config = MemoryCacheConfig(
kv_quantize=False,
max_memory_mb=500,
min_prefix_tokens=1,
)
pc = MemoryAwarePrefixCache(model, config)

# Create oversized cache: arrays have 4096 but offset is 100
Expand Down
37 changes: 35 additions & 2 deletions tests/test_memory_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def test_default_config(self):
assert config.max_memory_percent == 0.20
assert config.max_entries == 1000
assert config.enable_memory_tracking is True
assert config.min_prefix_tokens == 128

def test_custom_config(self):
config = MemoryCacheConfig(
Expand Down Expand Up @@ -52,6 +53,10 @@ def test_invalid_max_entries(self):
with pytest.raises(ValueError, match="max_entries"):
MemoryCacheConfig(max_entries=0)

def test_invalid_min_prefix_tokens(self):
with pytest.raises(ValueError, match="min_prefix_tokens"):
MemoryCacheConfig(min_prefix_tokens=0)

def test_compute_memory_limit_explicit(self):
config = MemoryCacheConfig(max_memory_mb=1024)
assert config.compute_memory_limit() == 1024 * 1024 * 1024
Expand Down Expand Up @@ -239,7 +244,11 @@ def model(self):
@pytest.fixture
def small_cache(self, model):
"""Cache with 1MB limit."""
config = MemoryCacheConfig(max_memory_mb=1, max_entries=10)
config = MemoryCacheConfig(
max_memory_mb=1,
max_entries=10,
min_prefix_tokens=1,
)
return MemoryAwarePrefixCache(model, config)

@pytest.fixture
Expand Down Expand Up @@ -270,6 +279,26 @@ def test_store_and_fetch_exact_match(self, small_cache, mock_kv_cache):
assert result is kv # Same reference, no copy
assert remaining == []

def test_short_prefix_reuse_is_rejected(self, model, mock_kv_cache):
cache = MemoryAwarePrefixCache(
model,
MemoryCacheConfig(
max_memory_mb=10,
max_entries=10,
min_prefix_tokens=8,
),
)
short_tokens = [1, 2, 3, 4, 5]
kv = mock_kv_cache(1000)

assert cache.store(short_tokens, kv) is False
assert len(cache) == 0

result, remaining = cache.fetch(short_tokens)
assert result is None
assert remaining == short_tokens
assert cache.get_stats()["misses"] == 1

def test_fetch_prefix_match(self, small_cache, mock_kv_cache):
# Store shorter sequence
short_tokens = [1, 2, 3]
Expand All @@ -295,7 +324,11 @@ def test_fetch_miss(self, small_cache, mock_kv_cache):

def test_lru_eviction_on_memory_pressure(self, model, mock_kv_cache):
# Create cache with 500KB limit
config = MemoryCacheConfig(max_memory_mb=0.5, max_entries=100)
config = MemoryCacheConfig(
max_memory_mb=0.5,
max_entries=100,
min_prefix_tokens=1,
)
cache = MemoryAwarePrefixCache(model, config)

# Store entries that together exceed limit
Expand Down
4 changes: 3 additions & 1 deletion tests/test_memory_cache_mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,8 @@ def test_fetch_returns_sliced_cache_on_lcp_match(self):

model = MagicMock()
cache = MemoryAwarePrefixCache(
model, MemoryCacheConfig(max_memory_mb=64, max_entries=10)
model,
MemoryCacheConfig(max_memory_mb=64, max_entries=10, min_prefix_tokens=1),
)

# Stored: tokens [1..120] with 120 positions of KV data, the first 60
Expand Down Expand Up @@ -424,6 +425,7 @@ def test_dequantize_end_to_end_fetch_with_quantization(self):
kv_bits=8,
kv_group_size=64,
kv_min_quantize_tokens=0,
min_prefix_tokens=1,
),
)

Expand Down
58 changes: 58 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,64 @@ def test_max_tokens_must_be_positive(self):
)


class TestSamplingDefaults:
    """Test server-wide sampling default resolution."""

    def test_extended_sampling_defaults_resolve_from_server_globals(self):
        """A request value of None falls back to the server-wide default."""
        from vllm_mlx import server

        names = (
            "_default_top_k",
            "_default_min_p",
            "_default_presence_penalty",
            "_default_repetition_penalty",
        )
        # Snapshot the module globals so this test leaves no trace behind.
        saved = {name: getattr(server, name) for name in names}
        try:
            server._default_top_k = 20
            server._default_min_p = 0.05
            server._default_presence_penalty = 1.5
            server._default_repetition_penalty = 1.1

            assert server._resolve_top_k(None) == 20
            assert server._resolve_min_p(None) == 0.05
            assert server._resolve_presence_penalty(None) == 1.5
            assert server._resolve_repetition_penalty(None) == 1.1
        finally:
            for name, value in saved.items():
                setattr(server, name, value)

    def test_request_values_override_extended_sampling_defaults(self):
        """An explicit request value — even a falsy one — wins over globals."""
        from vllm_mlx import server

        names = (
            "_default_top_k",
            "_default_min_p",
            "_default_presence_penalty",
            "_default_repetition_penalty",
        )
        # Snapshot the module globals so this test leaves no trace behind.
        saved = {name: getattr(server, name) for name in names}
        try:
            server._default_top_k = 20
            server._default_min_p = 0.05
            server._default_presence_penalty = 1.5
            server._default_repetition_penalty = 1.1

            assert server._resolve_top_k(0) == 0
            assert server._resolve_min_p(0.0) == 0.0
            assert server._resolve_presence_penalty(0.0) == 0.0
            assert server._resolve_repetition_penalty(1.0) == 1.0
        finally:
            for name, value in saved.items():
                setattr(server, name, value)


class TestAnthropicRequest:
"""Test Anthropic request model."""

Expand Down
42 changes: 41 additions & 1 deletion vllm_mlx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,17 @@ def serve_command(args):
server._default_temperature = args.default_temperature
if args.default_top_p is not None:
server._default_top_p = args.default_top_p
server._default_chat_template_kwargs = args.default_chat_template_kwargs
server._default_chat_template_kwargs = getattr(
args, "default_chat_template_kwargs", None
)
if args.default_top_k is not None:
server._default_top_k = args.default_top_k
if args.default_min_p is not None:
server._default_min_p = args.default_min_p
if args.default_presence_penalty is not None:
server._default_presence_penalty = args.default_presence_penalty
if args.default_repetition_penalty is not None:
server._default_repetition_penalty = args.default_repetition_penalty
max_audio_upload_mb = getattr(args, "max_audio_upload_mb", 25)
max_tts_input_chars = getattr(args, "max_tts_input_chars", 4096)
server._max_audio_upload_bytes = max_audio_upload_mb * 1024 * 1024
Expand Down Expand Up @@ -1149,6 +1159,36 @@ def create_parser() -> argparse.ArgumentParser:
'existing server defaults (JSON object, e.g. {"enable_thinking": true})'
),
)
serve_parser.add_argument(
"--default-top-k",
type=int,
default=None,
help="Override default top_k for all requests (default: use model default)",
)
serve_parser.add_argument(
"--default-min-p",
type=float,
default=None,
help="Override default min_p for all requests (default: use model default)",
)
serve_parser.add_argument(
"--default-presence-penalty",
type=float,
default=None,
help=(
"Override default presence_penalty for all requests "
"(default: use model default)"
),
)
serve_parser.add_argument(
"--default-repetition-penalty",
type=float,
default=None,
help=(
"Override default repetition_penalty for all requests "
"(default: use model default)"
),
)
# Embedding model option
serve_parser.add_argument(
"--embedding-model",
Expand Down
Loading
Loading